File size: 2,521 Bytes
364b36a
 
 
6001f90
364b36a
 
 
 
28673db
364b36a
 
 
28673db
364b36a
 
28673db
6001f90
 
28673db
6001f90
 
 
28673db
364b36a
 
 
 
 
28673db
364b36a
 
 
28673db
364b36a
 
fa3608f
28673db
364b36a
 
 
fa3608f
 
 
28673db
364b36a
 
28673db
364b36a
 
fa3608f
 
364b36a
 
fa3608f
 
364b36a
 
28673db
364b36a
28673db
364b36a
 
 
28673db
364b36a
 
fa3608f
28673db
364b36a
 
 
fa3608f
 
 
28673db
364b36a
 
 
 
 
 
 
 
 
28673db
364b36a
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90

import torch
import torch.nn as nn
import torchvision.models as models

class IMG2GPS(nn.Module):
    """
    EfficientNet-B0 model for GPS coordinate prediction from images.

    The backbone's final linear layer is replaced by a 2-unit head that
    predicts normalized (lat, lon). ``forward`` rescales that output with the
    hardcoded mean/std statistics stored on the instance.

    Input: Batch of images (N, 3, 224, 224) - ImageNet normalized
    Output: Batch of GPS coordinates (N, 2) - raw lat/lon in degrees
    """

    def __init__(self) -> None:
        super().__init__()

        # Build the EfficientNet-B0 architecture from torchvision. With
        # pretrained=False no ImageNet weights are loaded here.
        # NOTE(review): presumably trained weights are loaded afterwards via
        # load_state_dict - confirm against the caller. `pretrained` is
        # deprecated in torchvision >= 0.13 in favour of `weights=None`; it is
        # kept so that older torchvision installs keep working.
        self.backbone = models.efficientnet_b0(pretrained=False)

        # Replace the final classifier layer to output 2 values (lat, lon)
        num_features = self.backbone.classifier[1].in_features
        self.backbone.classifier[1] = nn.Linear(num_features, 2)

        # Hardcoded normalization statistics from training set. Kept as plain
        # float attributes (not buffers) so the state_dict keys are unchanged.
        self.lat_mean = 39.951525
        self.lat_std = 0.000652
        self.lon_mean = -75.191400
        self.lon_std = 0.000598

    @staticmethod
    def _as_batch(x):
        """Stack a list/tuple of per-sample tensors into one batch tensor."""
        if isinstance(x, (list, tuple)):
            return torch.stack(x)
        return x

    def forward(self, x) -> torch.Tensor:
        """
        Forward pass through the model.

        Args:
            x: Input tensor of shape (N, 3, 224, 224) - normalized images,
               or a list/tuple of per-image tensors (stacked along dim 0)

        Returns:
            Tensor of shape (N, 2) - denormalized lat/lon in degrees. The
            dtype is the backbone output dtype promoted to at least float32.
        """
        x = self._as_batch(x)

        # Model outputs normalized GPS coordinates
        normalized_coords = self.backbone(x)  # Shape: (N, 2)

        # Denormalize in at least float32. The offsets (~39.95 / ~-75.19) are
        # huge compared with the stds (~6e-4); building them in float16 or
        # bfloat16 would round them to a grid of about 0.03 degrees or coarser
        # and wipe out the prediction entirely.
        out_dtype = torch.promote_types(normalized_coords.dtype, torch.float32)
        scale = torch.tensor(
            [self.lat_std, self.lon_std],
            device=normalized_coords.device,
            dtype=out_dtype,
        )
        offset = torch.tensor(
            [self.lat_mean, self.lon_mean],
            device=normalized_coords.device,
            dtype=out_dtype,
        )

        return normalized_coords.to(out_dtype) * scale + offset

    def predict(self, batch):
        """
        Inference method for compatibility with backend.

        Switches the module to eval mode (and leaves it there), runs the
        forward pass without gradient tracking and returns the result as a
        numpy array on the host.

        Args:
            batch: Input tensor of shape (N, 3, 224, 224)
                   or list/tuple of per-image tensors

        Returns:
            numpy array of shape (N, 2) with raw lat/lon in degrees
        """
        batch = self._as_batch(batch)

        # Run on whatever device the weights live on, so a CPU batch can be
        # passed to a model that was moved to an accelerator. A module with
        # no parameters leaves the batch where it is.
        first_param = next(self.parameters(), None)
        if first_param is not None:
            batch = batch.to(first_param.device)

        self.eval()
        with torch.no_grad():
            output = self.forward(batch)
        return output.cpu().numpy()


def get_model():
    """
    Build a fresh IMG2GPS network.

    Takes no arguments; the model is constructed with its default
    configuration.

    Returns:
        A newly constructed IMG2GPS instance
    """
    model = IMG2GPS()
    return model