RolfeD11 commited on
Commit
e9f56b6
·
verified ·
1 Parent(s): 533adca

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -16
app.py CHANGED
@@ -12,13 +12,23 @@ class geoNet(nn.Module):
12
  def __init__(self):
13
  super(geoNet, self).__init__()
14
  self.name = "geo"
15
- self.fc1 = nn.Linear(256*6*6, 244)
16
- self.classifier = nn.Linear(244, 63) # Classification head
17
- self.regressor = nn.Linear(244, 2) # Regression head for (latitude, longitude)
 
 
 
 
 
 
 
18
 
19
  def forward(self, x):
20
  x = x.view(x.size(0), -1)
21
- x = F.relu(self.fc1(x))
 
 
 
22
  province_pred = self.classifier(x)
23
  coords_pred = self.regressor(x)
24
  return province_pred, coords_pred
@@ -28,10 +38,11 @@ model_path = 'model_geo_epoch79.pth'
28
  model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
29
  model.eval()
30
 
31
- # Load pretrained AlexNet model for feature extraction
32
- alexnet = torchvision.models.alexnet(pretrained=True)
33
- alexnet_features = alexnet.features # only need feature extraction
34
- alexnet_features.eval()
 
35
 
36
  # Define preprocessing function
37
  def crop_bottom(img):
@@ -82,15 +93,15 @@ if uploaded_file is not None:
82
  image = image.cuda()
83
 
84
  # Extract features
 
 
 
 
85
  with torch.no_grad():
86
- features = alexnet_features(image)
87
- if torch.cuda.is_available():
88
- features = features.cuda()
89
-
90
- # Predict province and coordinates
91
- province_preds, coord_preds = model(features)
92
- _, predicted_province = torch.max(province_preds, 1)
93
- predicted_coords = coord_preds.cpu().numpy()
94
 
95
  # Load class to index mapping
96
  class_to_idx = {'Alabama': 0, 'Alaska': 1, 'Alberta': 2, 'Arizona': 3, 'Arkansas': 4, 'British Columbia': 5, 'California': 6, 'Colorado': 7, 'Connecticut': 8, 'Delaware': 9, 'Florida': 10, 'Georgia': 11, 'Hawaii': 12, 'Idaho': 13, 'Illinois': 14, 'Indiana': 15, 'Iowa': 16, 'Kansas': 17, 'Kentucky': 18, 'Louisiana': 19, 'Maine': 20, 'Manitoba': 21, 'Maryland': 22, 'Massachusetts': 23, 'Michigan': 24, 'Minnesota': 25, 'Mississippi': 26, 'Missouri': 27, 'Montana': 28, 'Nebraska': 29, 'Nevada': 30, 'New Brunswick': 31, 'New Hampshire': 32, 'New Jersey': 33, 'New Mexico': 34, 'New York': 35, 'Newfoundland and Labrador': 36, 'North Carolina': 37, 'North Dakota': 38, 'Northwest Territories': 39, 'Nova Scotia': 40, 'Nunavut': 41, 'Ohio': 42, 'Oklahoma': 43, 'Ontario': 44, 'Oregon': 45, 'Pennsylvania': 46, 'Prince Edward Island': 47, 'Quebec': 48, 'Rhode Island': 49, 'Saskatchewan': 50, 'South Carolina': 51, 'South Dakota': 52, 'Tennessee': 53, 'Texas': 54, 'Utah': 55, 'Vermont': 56, 'Virginia': 57, 'Washington': 58, 'West Virginia': 59, 'Wisconsin': 60, 'Wyoming': 61, 'Yukon': 62}
 
12
  def __init__(self):
13
  super(geoNet, self).__init__()
14
  self.name = "geo"
15
+ self.fc1 = nn.Linear(768, 512)
16
+ self.bn1 = nn.BatchNorm1d(512)
17
+ self.fc2 = nn.Linear(512, 256)
18
+ self.bn2 = nn.BatchNorm1d(256)
19
+ self.fc3 = nn.Linear(256, 128)
20
+ self.bn3 = nn.BatchNorm1d(128)
21
+ self.fc4 = nn.Linear(128, 64)
22
+ self.bn4 = nn.BatchNorm1d(64)
23
+ self.classifier = nn.Linear(64, 63) # Classification head
24
+ self.regressor = nn.Linear(64, 2) # Regression head for (latitude, longitude)
25
 
26
  def forward(self, x):
27
  x = x.view(x.size(0), -1)
28
+ x = F.relu(self.bn1(self.fc1(x)))
29
+ x = F.relu(self.bn2(self.fc2(x)))
30
+ x = F.relu(self.bn3(self.fc3(x)))
31
+ x = F.relu(self.bn4(self.fc4(x)))
32
  province_pred = self.classifier(x)
33
  coords_pred = self.regressor(x)
34
  return province_pred, coords_pred
 
38
  model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
39
  model.eval()
40
 
41
+ # Load pretrained ViT model for feature extraction
42
+ model_name = 'google/vit-base-patch16-224-in21k'
43
+ vit = ViTModel.from_pretrained(model_name, attn_implementation="eager")
44
+ processor = ViTImageProcessor.from_pretrained(model_name)
45
+ vit.eval()
46
 
47
  # Define preprocessing function
48
  def crop_bottom(img):
 
93
  image = image.cuda()
94
 
95
  # Extract features
96
+ inputs = processor(images=image, return_tensors="pt")
97
+ outputs = vit(**inputs, output_attentions = True)
98
+ last_hidden_states = outputs.last_hidden_state
99
+ cls_hidden_state = last_hidden_states[:, 0, :]
100
  with torch.no_grad():
101
+ province_preds, coord_preds = model(cls_hidden_state)
102
+
103
+ _, predicted_province = torch.max(province_preds, 1)
104
+ predicted_coords = coord_preds.cpu().numpy()
 
 
 
 
105
 
106
  # Load class to index mapping
107
  class_to_idx = {'Alabama': 0, 'Alaska': 1, 'Alberta': 2, 'Arizona': 3, 'Arkansas': 4, 'British Columbia': 5, 'California': 6, 'Colorado': 7, 'Connecticut': 8, 'Delaware': 9, 'Florida': 10, 'Georgia': 11, 'Hawaii': 12, 'Idaho': 13, 'Illinois': 14, 'Indiana': 15, 'Iowa': 16, 'Kansas': 17, 'Kentucky': 18, 'Louisiana': 19, 'Maine': 20, 'Manitoba': 21, 'Maryland': 22, 'Massachusetts': 23, 'Michigan': 24, 'Minnesota': 25, 'Mississippi': 26, 'Missouri': 27, 'Montana': 28, 'Nebraska': 29, 'Nevada': 30, 'New Brunswick': 31, 'New Hampshire': 32, 'New Jersey': 33, 'New Mexico': 34, 'New York': 35, 'Newfoundland and Labrador': 36, 'North Carolina': 37, 'North Dakota': 38, 'Northwest Territories': 39, 'Nova Scotia': 40, 'Nunavut': 41, 'Ohio': 42, 'Oklahoma': 43, 'Ontario': 44, 'Oregon': 45, 'Pennsylvania': 46, 'Prince Edward Island': 47, 'Quebec': 48, 'Rhode Island': 49, 'Saskatchewan': 50, 'South Carolina': 51, 'South Dakota': 52, 'Tennessee': 53, 'Texas': 54, 'Utah': 55, 'Vermont': 56, 'Virginia': 57, 'Washington': 58, 'West Virginia': 59, 'Wisconsin': 60, 'Wyoming': 61, 'Yukon': 62}