# Uploaded by limich19: "Add model.py for project submission" (commit 28673db, verified)
import torch
import torch.nn as nn
import torchvision.models as models
class IMG2GPS(nn.Module):
    """
    EfficientNet-B0 model for GPS coordinate prediction from images.

    The backbone regresses coordinates normalized with the training-set
    statistics stored on the instance; ``forward`` maps them back to degrees.

    Input: Batch of images (N, 3, 224, 224) - ImageNet normalized
    Output: Batch of GPS coordinates (N, 2) - raw lat/lon in degrees
    """

    def __init__(self):
        super().__init__()
        # EfficientNet-B0 architecture from torchvision, randomly initialised:
        # no ImageNet weights are downloaded here. Trained weights are
        # presumably loaded into this module afterwards (not shown here).
        self.backbone = models.efficientnet_b0(weights=None)
        # classifier[1] is the final Linear layer; replace it with a
        # 2-output regression head (normalized lat, normalized lon).
        num_features = self.backbone.classifier[1].in_features
        self.backbone.classifier[1] = nn.Linear(num_features, 2)
        # Hardcoded normalization statistics from the training set (degrees).
        # Plain float attributes, so they are not part of the state_dict.
        self.lat_mean = 39.951525
        self.lat_std = 0.000652
        self.lon_mean = -75.191400
        self.lon_std = 0.000598

    def forward(self, x):
        """
        Forward pass: predict GPS coordinates in degrees.

        Args:
            x: Input tensor of shape (N, 3, 224, 224) - normalized images,
               or a list of equally-shaped tensors, stacked into a batch.

        Returns:
            Tensor of shape (N, 2) - denormalized [lat, lon] in degrees.
            Its dtype is the backbone output dtype promoted to at least
            float32. Half precision cannot resolve these coordinates: float16
            spacing near 40 degrees is ~0.03 deg, far coarser than the
            ~0.0006 deg std used for denormalization.
        """
        if isinstance(x, list):
            x = torch.stack(x)
        # The backbone outputs coordinates normalized by the training stats.
        normalized_coords = self.backbone(x)  # Shape: (N, 2)
        # Denormalize in >= float32 so the large mean does not swamp the
        # tiny std through rounding.
        out_dtype = torch.promote_types(normalized_coords.dtype, torch.float32)
        normalized_coords = normalized_coords.to(out_dtype)
        # new_tensor inherits the (promoted) dtype and device of the output.
        std = normalized_coords.new_tensor([self.lat_std, self.lon_std])
        mean = normalized_coords.new_tensor([self.lat_mean, self.lon_mean])
        return normalized_coords * std + mean

    def predict(self, batch):
        """
        Inference method for compatibility with backend.

        Runs ``forward`` in eval mode with autograd disabled. Afterwards it
        restores the module's previous train/eval mode, so calling it during
        training does not silently leave the model in eval mode.

        Args:
            batch: Input tensor of shape (N, 3, 224, 224), or a list of
                tensors, stacked into a batch. It is moved to the device of
                the model's parameters (a no-op if it is already there).

        Returns:
            numpy array of shape (N, 2) with raw lat/lon in degrees
        """
        if isinstance(batch, list):
            batch = torch.stack(batch)
        # Follow the weights' device so CPU inputs work with a GPU model.
        first_param = next(self.parameters(), None)
        if first_param is not None:
            batch = batch.to(first_param.device)
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                output = self.forward(batch)
        finally:
            # Restore the caller's mode even if forward raises.
            self.train(was_training)
        return output.cpu().numpy()
def get_model():
    """
    Build a new IMG2GPS network.

    Entry point used to obtain the model without referring to the class
    directly.

    Returns:
        A freshly constructed IMG2GPS instance.
    """
    model = IMG2GPS()
    return model