import torch
import torch.nn as nn
import torchvision.models as models


class IMG2GPS(nn.Module):
    """
    EfficientNet-B0 model for GPS coordinate prediction from images.

    Input: Batch of images (N, 3, 224, 224) - ImageNet normalized
    Output: Batch of GPS coordinates (N, 2) - raw lat/lon in degrees

    The backbone regresses standardized coordinates; ``forward`` maps them
    back to degrees using the training-set statistics stored on the instance.
    """

    def __init__(self):
        super().__init__()

        # EfficientNet-B0 architecture from torchvision, randomly initialised
        # (no ImageNet weights are downloaded). Calling the builder with no
        # arguments matches the legacy ``pretrained=False`` on old torchvision
        # and ``weights=None`` on >= 0.13, without the deprecation warning.
        # NOTE(review): trained weights are presumably loaded afterwards via
        # ``load_state_dict`` by the caller — confirm.
        self.backbone = models.efficientnet_b0()

        # Replace the final classifier layer to output 2 values (lat, lon).
        num_features = self.backbone.classifier[1].in_features
        self.backbone.classifier[1] = nn.Linear(num_features, 2)

        # Hardcoded normalization statistics from training set (degrees).
        # Kept as plain floats (not buffers) so the state_dict layout, and
        # therefore checkpoint compatibility, is unchanged.
        self.lat_mean = 39.951525
        self.lat_std = 0.000652
        self.lon_mean = -75.191400
        self.lon_std = 0.000598

    def forward(self, x):
        """
        Forward pass through the model.

        Args:
            x: Input tensor of shape (N, 3, 224, 224) - normalized images,
               or a list of per-image tensors, which is stacked first.

        Returns:
            Tensor of shape (N, 2) - denormalized lat/lon in degrees. The
            result is computed in at least float32, even when the model runs
            in float16/bfloat16 (see ``_denormalize``).
        """
        # Handle case where input is a list of tensors
        if isinstance(x, list):
            x = torch.stack(x)

        # Model outputs normalized GPS coordinates
        normalized_coords = self.backbone(x)  # Shape: (N, 2)
        return self._denormalize(normalized_coords)

    def _denormalize(self, normalized_coords):
        """Map standardized (lat, lon) predictions back to degrees."""
        # Never denormalize in half precision: float16 has a step of
        # ~0.03-0.06 degrees at |40|..|75| and bfloat16 ~0.25-0.5 degrees,
        # i.e. tens to hundreds of training-set standard deviations.
        # promote_types keeps float32/float64 unchanged and lifts
        # float16/bfloat16 to float32.
        dtype = torch.promote_types(normalized_coords.dtype, torch.float32)
        coords = normalized_coords.to(dtype)
        std = torch.tensor(
            [self.lat_std, self.lon_std], device=coords.device, dtype=dtype
        )
        mean = torch.tensor(
            [self.lat_mean, self.lon_mean], device=coords.device, dtype=dtype
        )
        return coords * std + mean

    def predict(self, batch):
        """
        Inference method for compatibility with backend.

        Args:
            batch: Input tensor of shape (N, 3, 224, 224) or list of tensors.
                   It is moved to the device of the model's parameters.

        Returns:
            numpy array of shape (N, 2) with raw lat/lon in degrees.

        The model is put in eval mode for the duration of the call. Its
        previous train/eval mode is restored afterwards, so calling this in
        the middle of training does not silently disable dropout / BatchNorm
        statistic updates.
        """
        # Handle case where batch is a list of tensors
        if isinstance(batch, list):
            batch = torch.stack(batch)

        device = next(self.parameters()).device
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                output = self.forward(batch.to(device))
        finally:
            self.train(was_training)
        return output.cpu().numpy()


def get_model():
    """
    Factory function to instantiate the model.

    Returns:
        IMG2GPS model instance
    """
    return IMG2GPS()