limich19's picture
Add model.py for project submission
364b36a verified
raw
history blame
2.08 kB
import torch
import torch.nn as nn
import timm
class IMG2GPS(nn.Module):
    """EfficientNet-B0 regressor mapping images to GPS coordinates.

    The backbone emits two values per image, which ``forward`` treats as
    standardized (z-scored) latitude and longitude and rescales to degrees
    using the statistics stored on the instance.

    Input: Batch of images (N, 3, 224, 224) - ImageNet normalized
    Output: Batch of GPS coordinates (N, 2) - raw lat/lon in degrees
    """

    def __init__(self):
        super().__init__()
        # Build the EfficientNet-B0 architecture with a 2-unit output head.
        # pretrained=False means no weights are fetched here; presumably the
        # trained weights are loaded afterwards by the caller (confirm).
        self.backbone = timm.create_model(
            'efficientnet_b0', pretrained=False, num_classes=2
        )
        # Hardcoded normalization statistics from training set
        self.lat_mean = 39.951525
        self.lat_std = 0.000652
        self.lon_mean = -75.191400
        self.lon_std = 0.000598

    def forward(self, x):
        """Run the backbone and convert its output to degrees.

        Args:
            x: Input tensor of shape (N, 3, 224, 224) - normalized images

        Returns:
            Tensor of shape (N, 2) - denormalized lat/lon in degrees. The
            dtype is at least float32, even for float16/bfloat16 inputs.
        """
        # Model outputs normalized GPS coordinates, shape (N, 2).
        normalized_coords = self.backbone(x)

        # The offsets (~40 and ~-75) are several orders of magnitude larger
        # than the scales (~6e-4). In float16/bfloat16 the offset's rounding
        # step exceeds the whole predicted range, so the affine step is done
        # in float32 or wider to keep the prediction meaningful.
        calc_dtype = torch.promote_types(normalized_coords.dtype, torch.float32)
        normalized_coords = normalized_coords.to(calc_dtype)

        # Create the constants where the backbone output lives.
        scale = torch.tensor(
            [self.lat_std, self.lon_std],
            device=normalized_coords.device,
            dtype=calc_dtype,
        )
        offset = torch.tensor(
            [self.lat_mean, self.lon_mean],
            device=normalized_coords.device,
            dtype=calc_dtype,
        )
        return normalized_coords * scale + offset

    def predict(self, batch):
        """Inference method for compatibility with backend.

        Switches the module to eval mode (it is left in eval mode
        afterwards) and runs ``forward`` with gradients disabled.

        Args:
            batch: Input tensor of shape (N, 3, 224, 224)

        Returns:
            numpy array of shape (N, 2) with raw lat/lon in degrees
        """
        self.eval()
        # Move the batch to the model's device so a CPU batch can be fed to
        # a model that was moved to an accelerator (and vice versa).
        first_param = next(self.parameters(), None)
        if first_param is not None and batch.device != first_param.device:
            batch = batch.to(first_param.device)
        with torch.no_grad():
            output = self.forward(batch)
        return output.cpu().numpy()
def get_model():
    """Build and return a new ``IMG2GPS`` model.

    Returns:
        A freshly constructed IMG2GPS instance.
    """
    model = IMG2GPS()
    return model