| |
|
| | import torch |
| | import torch.nn as nn |
| | import torchvision.models as models |
| |
|
class IMG2GPS(nn.Module):
    """
    EfficientNet-B0 model for GPS coordinate prediction from images.

    The backbone's final classifier layer is replaced by a 2-unit linear
    head. The head's output is treated as standardised (lat, lon) and is
    converted back to degrees in ``forward`` using the statistics stored on
    the instance.

    Input: Batch of images (N, 3, 224, 224) - ImageNet normalized
    Output: Batch of GPS coordinates (N, 2) - raw lat/lon in degrees
    """

    def __init__(self):
        super().__init__()

        # Randomly initialised backbone (no pretrained weights requested).
        # Relying on the default avoids the deprecated ``pretrained=``
        # keyword while staying compatible with older torchvision releases.
        self.backbone = models.efficientnet_b0()

        # Swap the stock classification layer for a 2-output regression
        # head: (latitude, longitude) in standardised units.
        num_features = self.backbone.classifier[1].in_features
        self.backbone.classifier[1] = nn.Linear(num_features, 2)

        # Statistics used to map the head's output back to degrees.
        # NOTE(review): presumably the training-set mean/std of the
        # targets -- confirm against the training pipeline.
        self.lat_mean = 39.951525
        self.lat_std = 0.000652
        self.lon_mean = -75.191400
        self.lon_std = 0.000598

    def forward(self, x):
        """
        Forward pass through the model.

        Args:
            x: Input tensor of shape (N, 3, 224, 224) - normalized images,
               or a list/tuple of per-image tensors (stacked along a new
               leading batch dimension).

        Returns:
            Tensor of shape (N, 2) - denormalized lat/lon in degrees. The
            result is at least float32: half-precision backbone outputs are
            promoted, because fp16/bf16 cannot resolve offsets of ~1e-4
            degrees around a mean of ~40 degrees.
        """
        if isinstance(x, (list, tuple)):
            x = torch.stack(x)

        normalized_coords = self.backbone(x)

        # Build the statistics next to the backbone output (same device),
        # in a dtype that is never narrower than float32.
        stats_dtype = torch.promote_types(normalized_coords.dtype, torch.float32)
        scale = torch.tensor(
            [self.lat_std, self.lon_std],
            device=normalized_coords.device,
            dtype=stats_dtype,
        )
        offset = torch.tensor(
            [self.lat_mean, self.lon_mean],
            device=normalized_coords.device,
            dtype=stats_dtype,
        )

        # Undo the standardisation: degrees = z * std + mean.
        return normalized_coords * scale + offset

    def predict(self, batch):
        """
        Inference method for compatibility with backend.

        Switches the module to eval mode (it is left in eval mode
        afterwards) and runs ``forward`` with gradients disabled.

        Args:
            batch: Input tensor of shape (N, 3, 224, 224)
                   or a list/tuple of per-image tensors.

        Returns:
            numpy array of shape (N, 2) with raw lat/lon in degrees
        """
        if isinstance(batch, (list, tuple)):
            batch = torch.stack(batch)

        # Run on whatever device the weights live on so that a CPU batch
        # handed to a GPU-resident model does not raise a device mismatch.
        first_param = next(self.parameters(), None)
        if first_param is not None:
            batch = batch.to(first_param.device)

        self.eval()
        with torch.no_grad():
            output = self.forward(batch)
        return output.cpu().numpy()
| |
|
| |
|
def get_model():
    """
    Build a fresh IMG2GPS network.

    Returns:
        A newly constructed IMG2GPS instance (the constructor is called
        without arguments).
    """
    model = IMG2GPS()
    return model
| |
|