limich19 committed on
Commit
6001f90
·
verified ·
1 Parent(s): dba8f48

Add model.py for project submission

Browse files
Files changed (1) hide show
  1. model.py +22 -18
model.py CHANGED
@@ -1,46 +1,50 @@
1
 
2
  import torch
3
  import torch.nn as nn
4
- import timm
5
 
6
  class IMG2GPS(nn.Module):
7
  """
8
  EfficientNet-B0 model for GPS coordinate prediction from images.
9
-
10
  Input: Batch of images (N, 3, 224, 224) - ImageNet normalized
11
  Output: Batch of GPS coordinates (N, 2) - raw lat/lon in degrees
12
  """
13
-
14
  def __init__(self):
15
  super().__init__()
16
-
17
- # Load pre-trained EfficientNet-B0
18
- self.backbone = timm.create_model('efficientnet_b0', pretrained=False, num_classes=2)
19
-
 
 
 
 
20
  # Hardcoded normalization statistics from training set
21
  self.lat_mean = 39.951525
22
  self.lat_std = 0.000652
23
  self.lon_mean = -75.191400
24
  self.lon_std = 0.000598
25
-
26
  def forward(self, x):
27
  """
28
  Forward pass through the model.
29
-
30
  Args:
31
  x: Input tensor of shape (N, 3, 224, 224) - normalized images
32
  or list of tensors
33
-
34
  Returns:
35
  Tensor of shape (N, 2) - denormalized lat/lon in degrees
36
  """
37
  # Handle case where input is a list of tensors
38
  if isinstance(x, list):
39
  x = torch.stack(x)
40
-
41
  # Model outputs normalized GPS coordinates
42
  normalized_coords = self.backbone(x) # Shape: (N, 2)
43
-
44
  # Denormalize to get raw lat/lon in degrees
45
  denormalized_coords = normalized_coords * torch.tensor(
46
  [self.lat_std, self.lon_std],
@@ -51,24 +55,24 @@ class IMG2GPS(nn.Module):
51
  device=x.device,
52
  dtype=x.dtype
53
  )
54
-
55
  return denormalized_coords
56
-
57
  def predict(self, batch):
58
  """
59
  Inference method for compatibility with backend.
60
-
61
  Args:
62
  batch: Input tensor of shape (N, 3, 224, 224)
63
  or list of tensors
64
-
65
  Returns:
66
  numpy array of shape (N, 2) with raw lat/lon in degrees
67
  """
68
  # Handle case where batch is a list of tensors
69
  if isinstance(batch, list):
70
  batch = torch.stack(batch)
71
-
72
  self.eval()
73
  with torch.no_grad():
74
  output = self.forward(batch)
@@ -78,7 +82,7 @@ class IMG2GPS(nn.Module):
78
  def get_model():
79
  """
80
  Factory function to instantiate the model.
81
-
82
  Returns:
83
  IMG2GPS model instance
84
  """
 
1
 
2
  import torch
3
  import torch.nn as nn
4
+ import torchvision.models as models
5
 
6
  class IMG2GPS(nn.Module):
7
  """
8
  EfficientNet-B0 model for GPS coordinate prediction from images.
9
+
10
  Input: Batch of images (N, 3, 224, 224) - ImageNet normalized
11
  Output: Batch of GPS coordinates (N, 2) - raw lat/lon in degrees
12
  """
13
+
14
  def __init__(self):
15
  super().__init__()
16
+
17
+ # Build the EfficientNet-B0 architecture from torchvision (pretrained=False, so no pre-trained weights are loaded here)
18
+ self.backbone = models.efficientnet_b0(pretrained=False)
19
+
20
+ # Replace the final classifier layer to output 2 values
21
+ num_features = self.backbone.classifier[1].in_features
22
+ self.backbone.classifier[1] = nn.Linear(num_features, 2)
23
+
24
  # Hardcoded normalization statistics from training set
25
  self.lat_mean = 39.951525
26
  self.lat_std = 0.000652
27
  self.lon_mean = -75.191400
28
  self.lon_std = 0.000598
29
+
30
  def forward(self, x):
31
  """
32
  Forward pass through the model.
33
+
34
  Args:
35
  x: Input tensor of shape (N, 3, 224, 224) - normalized images
36
  or list of tensors
37
+
38
  Returns:
39
  Tensor of shape (N, 2) - denormalized lat/lon in degrees
40
  """
41
  # Handle case where input is a list of tensors
42
  if isinstance(x, list):
43
  x = torch.stack(x)
44
+
45
  # Model outputs normalized GPS coordinates
46
  normalized_coords = self.backbone(x) # Shape: (N, 2)
47
+
48
  # Denormalize to get raw lat/lon in degrees
49
  denormalized_coords = normalized_coords * torch.tensor(
50
  [self.lat_std, self.lon_std],
 
55
  device=x.device,
56
  dtype=x.dtype
57
  )
58
+
59
  return denormalized_coords
60
+
61
  def predict(self, batch):
62
  """
63
  Inference method for compatibility with backend.
64
+
65
  Args:
66
  batch: Input tensor of shape (N, 3, 224, 224)
67
  or list of tensors
68
+
69
  Returns:
70
  numpy array of shape (N, 2) with raw lat/lon in degrees
71
  """
72
  # Handle case where batch is a list of tensors
73
  if isinstance(batch, list):
74
  batch = torch.stack(batch)
75
+
76
  self.eval()
77
  with torch.no_grad():
78
  output = self.forward(batch)
 
82
  def get_model():
83
  """
84
  Factory function to instantiate the model.
85
+
86
  Returns:
87
  IMG2GPS model instance
88
  """