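# NOTE(review): the text below looks like a status banner captured together
# with this file from the page it was copied from; it is not module code.
# Spaces: Runtime error / Runtime error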
import pandas as pd
import torch
from torch import nn

from models.networks.utils import UnormGPS
class Random(nn.Module):
    """Baseline that ignores the input content and emits random predictions.

    Only the batch size and the device of the input are used; the values
    themselves are drawn uniformly at random on every call.
    """

    def __init__(self, num_output):
        """Set up the random baseline.

        Args:
            num_output: number of values generated per sample, i.e. the
                width of the tensor handed to ``UnormGPS``.
        """
        super().__init__()
        self.num_output = num_output
        # NOTE(review): UnormGPS is defined in models.networks.utils and is
        # presumably the inverse of the [-1, 1] normalization -- confirm there.
        self.unorm = UnormGPS()

    def forward(self, x):
        """Predicts GPS coordinates from an image.

        Args:
            x: torch.Tensor with features (batch on the first dimension), or
                a batch mapping that stores that tensor under ``"img"``.

        Returns:
            dict with a single key ``"gps"`` holding ``self.unorm`` applied to
            a ``(batch, num_output)`` tensor sampled uniformly from [-1, 1).
        """
        # Accept the raw tensor as well as the {"img": tensor} batch layout
        # so this baseline can be called the same way as RandomCoords.
        if not isinstance(x, torch.Tensor):
            x = x["img"]
        batch_size = x.shape[0]
        # torch.rand samples from [0, 1); rescale to [-1, 1) on x's device.
        gps = torch.rand((batch_size, self.num_output), device=x.device) * 2 - 1
        return {"gps": self.unorm(gps)}
class RandomCoords(nn.Module):
    """Baseline that answers with coordinates drawn from a fixed list.

    The list is read once from a CSV file; every forward pass picks one
    entry per batch element uniformly at random, with replacement.
    """

    def __init__(self, coords_path: str):
        """Randomly sample from a list of coordinates.

        Args:
            coords_path: str with path to csv file with coordinates; the
                file must provide ``latitude`` and ``longitude`` columns.

        Raises:
            ValueError: if the csv file contains no rows, since there would
                be nothing to sample from in ``forward``.
        """
        super().__init__()
        coordinates = pd.read_csv(coords_path)
        # Fail here with a clear message instead of letting torch.randint
        # reject an empty range on the first forward pass.
        if len(coordinates) == 0:
            raise ValueError(f"no coordinates found in {coords_path!r}")

        # Scale latitude by 90 and longitude by 180 before storing them.
        # NOTE(review): this presumably maps degrees into [-1, 1] so that
        # UnormGPS can undo it -- confirm in models.networks.utils.
        latitudes = torch.tensor(coordinates["latitude"].values / 90)
        longitudes = torch.tensor(coordinates["longitude"].values / 180)

        self.N = len(coordinates)
        # One row per candidate, columns ordered (latitude, longitude).
        self.coordinates = torch.stack([latitudes, longitudes], dim=-1)
        self.unorm = UnormGPS()

    def forward(self, x):
        """Predicts GPS coordinates from an image.

        Args:
            x: batch mapping that stores the image tensor under ``"img"``,
                or that torch.Tensor directly; only its batch size (first
                dimension) and device are used.

        Returns:
            dict with a single key ``"gps"`` holding ``self.unorm`` applied to
            the sampled rows of ``self.coordinates``, moved to x's device.
        """
        if not isinstance(x, torch.Tensor):
            x = x["img"]
        # randomly select a coordinate in the list, one per batch element
        n = torch.randint(0, self.N, (x.shape[0],))
        # self.coordinates is a plain attribute (not a registered buffer), so
        # the selected rows are moved to the input's device on every call.
        return {"gps": self.unorm(self.coordinates[n].to(x.device))}