natankatz committed on
Commit
7ef1fbc
·
verified ·
1 Parent(s): 339cf21

Upload utils_global.py

Browse files
models/classifications/utils_global.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from collections import OrderedDict
3
+ from pathlib import Path
4
+ from typing import Union, List
5
+
6
+ import torch
7
+ import torchvision
8
+
9
+
10
def check_is_valid_torchvision_architecture(architecture: str):
    """Raise a ValueError if *architecture* is not an available torchvision model."""
    # Public lowercase callables in torchvision.models are the model factories.
    candidates = [
        name
        for name in torchvision.models.__dict__
        if name.islower()
        and not name.startswith("__")
        and callable(torchvision.models.__dict__[name])
    ]
    candidates.sort()
    if architecture not in candidates:
        raise ValueError(f"{architecture} not in {candidates}")
22
+
23
+
24
def build_base_model(arch: str):
    """Build a pretrained torchvision backbone with the classification head removed.

    Parameters
    ----------
    arch : str
        Name of a torchvision model factory, e.g. "mobilenet_v2",
        "densenet121" or a ResNet variant.

    Returns
    -------
    (torch.nn.Module, int)
        The truncated backbone (with adaptive average pooling and a flatten
        appended) and the input dimension of the removed classifier.

    Raises
    ------
    NotImplementedError
        If *arch* does not belong to a supported model family.
    """
    # NOTE(review): `pretrained=True` is the legacy torchvision API; kept as-is
    # so checkpoints trained against this code keep loading identically.
    model = torchvision.models.__dict__[arch](pretrained=True)

    # Record the input dimension of the classification layer before stripping it.
    if arch in ["mobilenet_v2"]:
        nfeatures = model.classifier[-1].in_features
        model = torch.nn.Sequential(*list(model.children())[:-1])
    elif arch in ["densenet121", "densenet161", "densenet169"]:
        nfeatures = model.classifier.in_features
        model = torch.nn.Sequential(*list(model.children())[:-1])
    elif "resne" in arch:
        # Matches ResNet/ResNeXt/wide-ResNet variants ("resnet50", "resnext101_32x8d", ...).
        nfeatures = model.fc.in_features
        # Drop both the final pool and the fc layer; pooling is re-added below.
        model = torch.nn.Sequential(*list(model.children())[:-2])
    else:
        # Include the offending name so the failure is diagnosable at the call site.
        raise NotImplementedError(f"unsupported architecture: {arch}")

    # Attribute assignment on a Sequential appends these as the final modules.
    model.avgpool = torch.nn.AdaptiveAvgPool2d(1)
    model.flatten = torch.nn.Flatten(start_dim=1)
    return model, nfeatures
45
+
46
+
47
def load_weights_if_available(
    model: torch.nn.Module, classifier: torch.nn.Module, weights_path: Union[str, Path]
):
    """Load backbone and classifier weights from a Lightning-style checkpoint.

    The checkpoint's "state_dict" is expected to prefix backbone parameters
    with "model." and classifier parameters with "classifier."; keys with any
    other prefix are logged and skipped.

    Parameters
    ----------
    model : torch.nn.Module
        Backbone that receives the "model."-prefixed weights (strict load).
    classifier : torch.nn.Module
        Head that receives the "classifier."-prefixed weights (strict load).
    weights_path : str or Path
        Path to a checkpoint containing a "state_dict" entry.

    Returns
    -------
    (torch.nn.Module, torch.nn.Module)
        The same *model* and *classifier*, with weights loaded in place.
    """
    checkpoint = torch.load(weights_path, map_location=lambda storage, loc: storage)

    state_dict_features = OrderedDict()
    state_dict_classifier = OrderedDict()
    for k, w in checkpoint["state_dict"].items():
        # Strip only the leading prefix: str.replace(...) would also mangle any
        # later occurrence of "model."/"classifier." inside the key.
        if k.startswith("model."):
            state_dict_features[k[len("model."):]] = w
        elif k.startswith("classifier."):
            state_dict_classifier[k[len("classifier."):]] = w
        else:
            logging.warning(f"Unexpected prefix in state_dict: {k}")
    model.load_state_dict(state_dict_features, strict=True)
    # BUG FIX: the classifier weights were collected but never loaded before.
    classifier.load_state_dict(state_dict_classifier, strict=True)
    return model, classifier
64
+
65
+
66
def vectorized_gc_distance(latitudes, longitudes, latitudes_gt, longitudes_gt):
    """Great-circle (haversine) distance in km between predicted and
    ground-truth coordinates, element-wise over tensors of degrees."""
    R = 6371  # mean Earth radius in km
    factor_rad = 0.01745329252  # degrees -> radians
    lat = factor_rad * latitudes
    lat_gt = factor_rad * latitudes_gt
    lon = factor_rad * longitudes
    lon_gt = factor_rad * longitudes_gt
    half_dlat = (lat_gt - lat) / 2
    half_dlon = (lon_gt - lon) / 2
    # Haversine formula: a = sin²(Δφ/2) + cosφ₁·cosφ₂·sin²(Δλ/2)
    a = torch.sin(half_dlat) ** 2 + torch.cos(lat) * torch.cos(lat_gt) * torch.sin(half_dlon) ** 2
    return R * (2 * torch.asin(torch.sqrt(a)))
83
+
84
+
85
def gcd_threshold_eval(gc_dists, thresholds=(1, 25, 200, 750, 2500)):
    """Fraction of samples whose great-circle distance is within each threshold.

    Parameters
    ----------
    gc_dists : torch.Tensor
        1-D tensor of great-circle distances (km).
    thresholds : iterable of int, optional
        Distance thresholds in km. Default was a mutable list; a tuple avoids
        the shared-mutable-default pitfall with identical behavior.

    Returns
    -------
    dict
        Maps each threshold to the accuracy (float in [0, 1]) at that radius.
    """
    return {
        thres: torch.true_divide(torch.sum(gc_dists <= thres), len(gc_dists)).item()
        for thres in thresholds
    }
93
+
94
+
95
def accuracy(output, target, partitioning_shortnames: list, topk=(1, 5, 10)):
    """Compute top-k accuracy per partitioning.

    Parameters
    ----------
    output : sequence of torch.Tensor
        Per-partitioning logits, each of shape (batch, num_classes).
    target : sequence of torch.Tensor
        Per-partitioning ground-truth class indices, each of shape (batch,).
    partitioning_shortnames : list
        Short names used to build the metric keys.
    topk : tuple of int, optional
        The k values to evaluate; max(topk) must not exceed num_classes.

    Returns
    -------
    dict
        Maps "acc{k}_val/{pname}" to a 1-element accuracy tensor.
    """

    def _accuracy(output, target, topk=(1,)):
        """Computes the accuracy over the k top predictions for the specified values of k"""
        with torch.no_grad():
            maxk = max(topk)
            batch_size = target.size(0)

            _, pred = output.topk(maxk, 1, True, True)
            pred = pred.t()
            correct = pred.eq(target.view(1, -1).expand_as(pred))

            res = {}
            for k in topk:
                # BUG FIX: use .reshape instead of .view — `correct[:k]` comes
                # from a transposed (non-contiguous) tensor, and .view raises a
                # RuntimeError on non-contiguous input when maxk > 1.
                correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
                res[k] = correct_k / batch_size
            return res

    with torch.no_grad():
        out_dict = {}
        for i, pname in enumerate(partitioning_shortnames):
            res_dict = _accuracy(output[i], target[i], topk=topk)
            for k, v in res_dict.items():
                out_dict[f"acc{k}_val/{pname}"] = v

        return out_dict
120
+
121
+
122
def summarize_gcd_stats(pnames: List[str], outputs, hierarchy=None):
    """Aggregate per-batch great-circle distances into accuracy-at-threshold metrics.

    *outputs* is a list of per-batch dicts; each dict holds a distance tensor
    under "gcd_{pname}_val" (plus "gcd_hierarchy_val" when *hierarchy* is set).
    Returns a dict keyed "{metric_name}/{threshold}" -> accuracy.
    """
    summary = {}
    names = [f"gcd_{p}_val" for p in pnames]
    if hierarchy is not None:
        names.append("gcd_hierarchy_val")
    for name in names:
        # Concatenate the per-batch distance tensors into one flat tensor.
        all_dists = torch.cat([batch[name] for batch in outputs], dim=0)
        for threshold, acc in gcd_threshold_eval(all_dists).items():
            summary[f"{name}/{threshold}"] = acc
    return summary
134
+
135
+
136
def summarize_test_gcd(pnames, outputs, hierarchy=None):
    """Compute acc@km at the default thresholds for one or several testsets.

    Parameters
    ----------
    pnames : list
        Partitioning names; each batch dict must hold a distance tensor per name.
    outputs : list
        Either a list of per-batch dicts (single testset) or a list of such
        lists (multiple testsets).
    hierarchy : optional
        When not None, the "hierarchy" distances are evaluated as well.

    Returns
    -------
    dict
        Single testset: maps "acc_test/{pname}" to its threshold->accuracy dict.
        Multiple testsets: maps testset index to such a dict.

    Raises
    ------
    TypeError
        If *outputs* has neither of the two expected shapes.
    """
    # BUG FIX: work on a copy. The original appended "hierarchy" to the
    # caller's list — and did so once per testset, duplicating the entry.
    eval_names = list(pnames)
    if hierarchy is not None:
        eval_names.append("hierarchy")

    def _eval(output):
        # calculate acc@km for a list of given thresholds
        accuracy_outputs = {}
        for pname in eval_names:
            # concat batches of distances
            distances_flat = torch.cat([x[pname] for x in output], dim=0)
            # acc for all distances
            accuracy_outputs[f"acc_test/{pname}"] = gcd_threshold_eval(distances_flat)
        return accuracy_outputs

    if isinstance(outputs[0], dict):  # only one testset
        return _eval(outputs)
    if isinstance(outputs[0], list):  # multiple testsets
        return {index: _eval(output) for index, output in enumerate(outputs)}
    raise TypeError(f"expected list of dicts or list of lists, got {type(outputs[0])}")
161
+
162
+
163
def summarize_loss_acc_stats(pnames: List[str], outputs, topk=(1, 5, 10)):
    """Average validation loss and top-k accuracy metrics over batches.

    Parameters
    ----------
    pnames : List[str]
        Partitioning names used in the metric keys.
    outputs : list of dict
        Per-batch dicts; each must contain "loss_val/total",
        "acc{k}_val/{p}" for every k in *topk*, and "loss_val/{p}".
    topk : iterable of int, optional
        The k values whose accuracies are summarized. Default was a mutable
        list; a tuple avoids the shared-mutable-default pitfall.

    Returns
    -------
    dict
        Maps each metric name to its mean over all batches.
    """
    metric_names = ["loss_val/total"]
    for k in topk:
        metric_names.extend(f"acc{k}_val/{p}" for p in pnames)
    metric_names.extend(f"loss_val/{p}" for p in pnames)
    # Mean over batches for every metric.
    return {
        name: sum(output[name] for output in outputs) / len(outputs)
        for name in metric_names
    }