natankatz commited on
Commit
a11cb51
·
verified ·
1 Parent(s): 0e057e7

Upload 7 files

Browse files
models/networks/heads/__init__ .py ADDED
File without changes
models/networks/heads/auxilliary.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ from models.networks.utils import UnormGPS
3
+ from torch.nn.functional import tanh, sigmoid, softmax
4
+
5
+
6
+ class AuxHead(nn.Module):
7
+ def __init__(self, aux_data=[], use_tanh=False):
8
+ super().__init__()
9
+ self.aux_data = aux_data
10
+ self.unorm = UnormGPS()
11
+ self.use_tanh = use_tanh
12
+
13
+ def forward(self, x):
14
+ """Forward pass of the network.
15
+ x : Union[torch.Tensor, dict] with the output of the backbone.
16
+ """
17
+ if self.use_tanh:
18
+ gps = tanh(x["gps"])
19
+ gps = self.unorm(gps)
20
+ output = {"gps": gps}
21
+ if "land_cover" in self.aux_data:
22
+ output["land_cover"] = softmax(x["land_cover"])
23
+ if "road_index" in self.aux_data:
24
+ output["road_index"] = x["road_index"]
25
+ if "drive_side" in self.aux_data:
26
+ output["drive_side"] = sigmoid(x["drive_side"])
27
+ if "climate" in self.aux_data:
28
+ output["climate"] = softmax(x["climate"])
29
+ if "soil" in self.aux_data:
30
+ output["soil"] = softmax(x["soil"])
31
+ if "dist_sea" in self.aux_data:
32
+ output["dist_sea"] = x["dist_sea"]
33
+ return output
models/networks/heads/classification.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class ClassificationHead(nn.Module):
6
+ """Classification head for the network."""
7
+
8
+ def __init__(self, id_to_gps):
9
+ super().__init__()
10
+ self.id_to_gps = id_to_gps
11
+
12
+ def forward(self, x):
13
+ """Forward pass of the network.
14
+ x : Union[torch.Tensor, dict] with the output of the backbone.
15
+ """
16
+ gps = self.id_to_gps(x.argmax(dim=-1))
17
+ return {"label": x, **gps}
models/networks/heads/hybrid.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import pandas as pd
4
+
5
+ from models.networks.utils import UnormGPS
6
+
7
+
8
+ class HybridHead(nn.Module):
9
+ """Classification head followed by regression head for the network."""
10
+
11
+ def __init__(self, final_dim, quadtree_path, use_tanh, scale_tanh):
12
+ super().__init__()
13
+ self.final_dim = final_dim
14
+ self.use_tanh = use_tanh
15
+ self.scale_tanh = scale_tanh
16
+
17
+ self.unorm = UnormGPS()
18
+
19
+ if quadtree_path is not None:
20
+ quadtree = pd.read_csv(quadtree_path)
21
+ self.init_quadtree(quadtree)
22
+
23
+ def init_quadtree(self, quadtree):
24
+ quadtree[["min_lat", "max_lat"]] /= 90.0
25
+ quadtree[["min_lon", "max_lon"]] /= 180.0
26
+ self.register_buffer(
27
+ "cell_center",
28
+ 0.5 * torch.tensor(quadtree[["max_lat", "max_lon"]].values)
29
+ + 0.5 * torch.tensor(quadtree[["min_lat", "min_lon"]].values),
30
+ )
31
+ self.register_buffer(
32
+ "cell_size",
33
+ torch.tensor(quadtree[["max_lat", "max_lon"]].values)
34
+ - torch.tensor(quadtree[["min_lat", "min_lon"]].values),
35
+ )
36
+
37
+ def forward(self, x, gt_label):
38
+ """Forward pass of the network.
39
+ x : Union[torch.Tensor, dict] with the output of the backbone.
40
+ """
41
+
42
+ classification_logits = x[..., : self.final_dim]
43
+ classification = classification_logits.argmax(dim=-1)
44
+
45
+ regression = x[..., self.final_dim :]
46
+
47
+ if self.use_tanh:
48
+ regression = self.scale_tanh * torch.tanh(regression)
49
+
50
+ regression = regression.view(regression.shape[0], -1, 2)
51
+
52
+ if self.training:
53
+ regression = torch.gather(
54
+ regression,
55
+ 1,
56
+ gt_label.unsqueeze(-1).unsqueeze(-1).expand(regression.shape[0], 1, 2),
57
+ )[:, 0, :]
58
+ size = 2.0 / self.cell_size[gt_label]
59
+ center = self.cell_center[gt_label]
60
+ gps = (
61
+ self.cell_center[gt_label] + regression * self.cell_size[gt_label] / 2.0
62
+ )
63
+ else:
64
+ regression = torch.gather(
65
+ regression,
66
+ 1,
67
+ classification.unsqueeze(-1)
68
+ .unsqueeze(-1)
69
+ .expand(regression.shape[0], 1, 2),
70
+ )[:, 0, :]
71
+ size = 2.0 / self.cell_size[classification]
72
+ center = self.cell_center[classification]
73
+ gps = (
74
+ self.cell_center[classification]
75
+ + regression * self.cell_size[classification] / 2.0
76
+ )
77
+
78
+ gps = self.unorm(gps)
79
+
80
+ return {
81
+ "label": classification_logits,
82
+ "gps": gps,
83
+ "size": size,
84
+ "center": center,
85
+ "reg": regression,
86
+ }
87
+
88
+ class HybridHeadCentroid(nn.Module):
89
+ """Classification head followed by regression head for the network."""
90
+
91
+ def __init__(self, final_dim, quadtree_path, use_tanh, scale_tanh):
92
+ super().__init__()
93
+ self.final_dim = final_dim
94
+ self.use_tanh = use_tanh
95
+ self.scale_tanh = scale_tanh
96
+
97
+ self.unorm = UnormGPS()
98
+ if quadtree_path is not None:
99
+ quadtree = pd.read_csv(quadtree_path)
100
+ self.init_quadtree(quadtree)
101
+
102
+ def init_quadtree(self, quadtree):
103
+ quadtree[["min_lat", "max_lat", "mean_lat"]] /= 90.0
104
+ quadtree[["min_lon", "max_lon", "mean_lon"]] /= 180.0
105
+ self.cell_center = torch.tensor(quadtree[["mean_lat", "mean_lon"]].values)
106
+ self.cell_size_up = torch.tensor(quadtree[["max_lat", "max_lon"]].values) - torch.tensor(quadtree[["mean_lat", "mean_lon"]].values)
107
+ self.cell_size_down = torch.tensor(quadtree[["mean_lat", "mean_lon"]].values) - torch.tensor(quadtree[["min_lat", "min_lon"]].values)
108
+
109
+ def forward(self, x, gt_label):
110
+ """Forward pass of the network.
111
+ x : Union[torch.Tensor, dict] with the output of the backbone.
112
+ """
113
+ classification_logits = x[..., : self.final_dim]
114
+ classification = classification_logits.argmax(dim=-1)
115
+ self.cell_size_up = self.cell_size_up.to(classification.device)
116
+ self.cell_center = self.cell_center.to(classification.device)
117
+ self.cell_size_down = self.cell_size_down.to(classification.device)
118
+
119
+ regression = x[..., self.final_dim :]
120
+
121
+ if self.use_tanh:
122
+ regression = self.scale_tanh * torch.tanh(regression)
123
+
124
+ regression = regression.view(regression.shape[0], -1, 2)
125
+
126
+ if self.training:
127
+ regression = torch.gather(
128
+ regression,
129
+ 1,
130
+ gt_label.unsqueeze(-1).unsqueeze(-1).expand(regression.shape[0], 1, 2),
131
+ )[:, 0, :]
132
+ size = torch.where(
133
+ regression > 0,
134
+ self.cell_size_up[gt_label],
135
+ self.cell_size_down[gt_label],
136
+ )
137
+ center = self.cell_center[gt_label]
138
+ gps = self.cell_center[gt_label] + regression * size
139
+ else:
140
+ regression = torch.gather(
141
+ regression,
142
+ 1,
143
+ classification.unsqueeze(-1)
144
+ .unsqueeze(-1)
145
+ .expand(regression.shape[0], 1, 2),
146
+ )[:, 0, :]
147
+ size = torch.where(
148
+ regression > 0,
149
+ self.cell_size_up[classification],
150
+ self.cell_size_down[classification],
151
+ )
152
+ center = self.cell_center[classification]
153
+ gps = self.cell_center[classification] + regression * size
154
+
155
+ gps = self.unorm(gps)
156
+
157
+ return {
158
+ "label": classification_logits,
159
+ "gps": gps,
160
+ "size": 1.0 / size,
161
+ "center": center,
162
+ "reg": regression,
163
+ }
164
+
165
+
166
+ class SharedHybridHead(HybridHead):
167
+ """Classification head followed by SHARED regression head for the network."""
168
+
169
+ def forward(self, x, gt_label):
170
+ """Forward pass of the network.
171
+ x : Union[torch.Tensor, dict] with the output of the backbone.
172
+ """
173
+
174
+ classification_logits = x[..., : self.final_dim]
175
+ classification = classification_logits.argmax(dim=-1)
176
+
177
+ regression = x[..., self.final_dim :]
178
+
179
+ if self.use_tanh:
180
+ regression = self.scale_tanh * torch.tanh(regression)
181
+
182
+ if self.training:
183
+ gps = (
184
+ self.cell_center[gt_label] + regression * self.cell_size[gt_label] / 2.0
185
+ )
186
+ else:
187
+ gps = (
188
+ self.cell_center[classification]
189
+ + regression * self.cell_size[classification] / 2.0
190
+ )
191
+
192
+ gps = self.unorm(gps)
193
+
194
+ return {"label": classification_logits, "gps": gps}
models/networks/heads/id_to_gps.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from models.networks.utils import UnormGPS
3
+ import torch.nn as nn
4
+ import numpy as np
5
+
6
+
7
+ class IdToGPS(nn.Module):
8
+ def __init__(self, id_to_gps: str):
9
+ """Map index to gps coordinates (indices can be country or city ids)"""
10
+ super().__init__()
11
+ if "quadtree" in id_to_gps:
12
+ self.id_to_gps = torch.load(
13
+ "_".join(id_to_gps.split("_")[:-4] + id_to_gps.split("_")[-3:])
14
+ )
15
+ else:
16
+ self.id_to_gps = torch.load(id_to_gps)
17
+ #self.unorm = UnormGPS()
18
+
19
+ def forward(self, x):
20
+ """Mapping from country id to gps coordinates
21
+ Args:
22
+ x: torch.Tensor with features
23
+ """
24
+
25
+ if isinstance(x, dict):
26
+ # for oracle
27
+ labels, x = x["label"], x["img"]
28
+ else:
29
+ # predicted labels
30
+ labels = x
31
+ self.id_to_gps = self.id_to_gps.to(labels.device)
32
+ #return {"gps": self.unorm(self.id_to_gps[labels])}
33
+ return {"gps": self.id_to_gps[labels]}
models/networks/heads/random.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import torch
3
+ from torch import nn
4
+ from models.networks.utils import UnormGPS
5
+
6
+
7
+ class Random(nn.Module):
8
+ def __init__(self, num_output):
9
+ """Random"""
10
+ super().__init__()
11
+ self.num_output = num_output
12
+ self.unorm = UnormGPS()
13
+
14
+ def forward(self, x):
15
+ """Predicts GPS coordinates from an image.
16
+ Args:
17
+ x: torch.Tensor with features
18
+ """
19
+ #x = x["img"]
20
+ gps = torch.rand((x.shape[0], self.num_output), device=x.device) * 2 - 1
21
+ return {"gps": self.unorm(gps)}
22
+
23
+
24
+ class RandomCoords(nn.Module):
25
+ def __init__(self, coords_path: str):
26
+ """Randomly sample from a list of coordinates
27
+ Args:
28
+ coords_path: str with path to csv file with coordinates
29
+ """
30
+ super().__init__()
31
+ coordinates = pd.read_csv(coords_path)
32
+ longitudes = coordinates["longitude"].values / 180
33
+ latitudes = coordinates["latitude"].values / 90
34
+ self.unorm = UnormGPS()
35
+ del coordinates
36
+
37
+ self.N = len(longitudes)
38
+ assert len(longitudes) == len(latitudes)
39
+ self.coordinates = torch.stack(
40
+ [torch.tensor(latitudes), torch.tensor(longitudes)],
41
+ dim=-1,
42
+ )
43
+ del longitudes, latitudes
44
+
45
+ def forward(self, x):
46
+ """Predicts GPS coordinates from an image.
47
+ Args:
48
+ x: torch.Tensor with features
49
+ """
50
+ x = x["img"]
51
+ # randomly select a coordinate in the list
52
+ n = torch.randint(0, self.N, (x.shape[0],))
53
+ return {"gps": self.unorm(self.coordinates[n].to(x.device))}
models/networks/heads/regression.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from models.networks.utils import UnormGPS
2
+ import torch.nn as nn
3
+ from torch.nn.functional import tanh
4
+ import torch
5
+
6
+
7
+ class RegressionHead(nn.Module):
8
+ def __init__(self, use_tanh=False):
9
+ super().__init__()
10
+ self.unorm = UnormGPS()
11
+ self.use_tanh = use_tanh
12
+
13
+ def forward(self, x):
14
+ """Forward pass of the network.
15
+ x : Union[torch.Tensor, dict] with the output of the backbone.
16
+ """
17
+ if self.use_tanh:
18
+ x = tanh(x)
19
+ gps = self.unorm(x)
20
+ return {"gps": gps}
21
+
22
+
23
+ class RegressionHeadAngle(nn.Module):
24
+ def __init__(self):
25
+ super().__init__()
26
+ self.unorm = UnormGPS()
27
+
28
+ def forward(self, x):
29
+ """Forward pass of the network.
30
+ x : Union[torch.Tensor, dict] with the output of the backbone.
31
+ """
32
+ x1 = x[:, 0].pow(2)
33
+ x2 = x[:, 1].pow(2)
34
+ x3 = x[:, 2].pow(2)
35
+ x4 = x[:, 3].pow(2)
36
+ cos_lambda = x1 / (x1 + x2)
37
+ sin_lambda = x2 / (x1 + x2)
38
+ cos_phi = x3 / (x3 + x4)
39
+ sin_phi = x4 / (x3 + x4)
40
+ lbd = torch.atan2(sin_lambda, cos_lambda)
41
+ phi = torch.atan2(sin_phi, cos_phi)
42
+ gps = torch.cat((lbd.unsqueeze(1), phi.unsqueeze(1)), dim=1)
43
+ # gps = self.unorm(x)
44
+ return {"gps": gps}