File size: 2,521 Bytes
364b36a 6001f90 364b36a 28673db 364b36a 28673db 364b36a 28673db 6001f90 28673db 6001f90 28673db 364b36a 28673db 364b36a 28673db 364b36a fa3608f 28673db 364b36a fa3608f 28673db 364b36a 28673db 364b36a fa3608f 364b36a fa3608f 364b36a 28673db 364b36a 28673db 364b36a 28673db 364b36a fa3608f 28673db 364b36a fa3608f 28673db 364b36a 28673db 364b36a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
import torch
import torch.nn as nn
import torchvision.models as models
class IMG2GPS(nn.Module):
    """
    EfficientNet-B0 model for GPS coordinate prediction from images.

    The backbone regresses two normalized values per image; ``forward``
    rescales them with the hardcoded training-set statistics stored on the
    instance so callers receive raw latitude/longitude.

    Input: Batch of images (N, 3, 224, 224) - ImageNet normalized
    Output: Batch of GPS coordinates (N, 2) - raw lat/lon in degrees
    """

    def __init__(self):
        super().__init__()
        # Build the EfficientNet-B0 architecture WITHOUT pre-trained weights.
        # Calling with no arguments is equivalent to the former
        # ``pretrained=False`` and avoids the deprecated ``pretrained``
        # keyword on newer torchvision releases.
        # NOTE(review): trained weights are presumably loaded by the caller
        # (e.g. via load_state_dict) - confirm against the backend.
        self.backbone = models.efficientnet_b0()
        # Replace the final classifier layer to output 2 values
        num_features = self.backbone.classifier[1].in_features
        self.backbone.classifier[1] = nn.Linear(num_features, 2)
        # Hardcoded normalization statistics from training set
        self.lat_mean = 39.951525
        self.lat_std = 0.000652
        self.lon_mean = -75.191400
        self.lon_std = 0.000598

    def forward(self, x):
        """
        Forward pass through the model.

        Args:
            x: Input tensor of shape (N, 3, 224, 224) - normalized images,
                or a list/tuple of per-image tensors that is stacked into
                a batch.

        Returns:
            Tensor of shape (N, 2) - denormalized lat/lon in degrees. The
            result is at least float32 even for half-precision inputs.
        """
        # Handle case where input is a sequence of tensors
        if isinstance(x, (list, tuple)):
            x = torch.stack(x)
        # Model outputs normalized GPS coordinates
        normalized_coords = self.backbone(x)  # Shape: (N, 2)
        # float16/bfloat16 cannot represent the offsets (~39.95 / ~-75.19)
        # finely enough relative to the ~6e-4 degree spread, so denormalize
        # in float32 whenever the backbone ran in reduced precision.
        if normalized_coords.dtype in (torch.float16, torch.bfloat16):
            normalized_coords = normalized_coords.float()
        # Derive dtype/device from the backbone output rather than the input
        # so the constants always match the tensor they are combined with.
        scale = normalized_coords.new_tensor([self.lat_std, self.lon_std])
        offset = normalized_coords.new_tensor([self.lat_mean, self.lon_mean])
        # Denormalize to get raw lat/lon in degrees
        return normalized_coords * scale + offset

    def predict(self, batch):
        """
        Inference method for compatibility with backend.

        Switches the module to eval mode (it is left in eval mode
        afterwards) and runs the forward pass with gradients disabled.

        Args:
            batch: Input tensor of shape (N, 3, 224, 224),
                or a list/tuple of per-image tensors.

        Returns:
            numpy array of shape (N, 2) with raw lat/lon in degrees
        """
        self.eval()
        with torch.no_grad():
            # Call the module itself (not .forward) so registered hooks run;
            # forward() already stacks list/tuple inputs.
            output = self(batch)
        return output.cpu().numpy()
def get_model():
    """
    Build a fresh IMG2GPS network.

    Returns:
        A newly constructed IMG2GPS instance.
    """
    model = IMG2GPS()
    return model
|