Update app.py
Browse files
app.py
CHANGED
|
@@ -12,13 +12,23 @@ class geoNet(nn.Module):
|
|
| 12 |
def __init__(self):
|
| 13 |
super(geoNet, self).__init__()
|
| 14 |
self.name = "geo"
|
| 15 |
-
self.fc1 = nn.Linear(
|
| 16 |
-
self.
|
| 17 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
def forward(self, x):
|
| 20 |
x = x.view(x.size(0), -1)
|
| 21 |
-
x = F.relu(self.fc1(x))
|
|
|
|
|
|
|
|
|
|
| 22 |
province_pred = self.classifier(x)
|
| 23 |
coords_pred = self.regressor(x)
|
| 24 |
return province_pred, coords_pred
|
|
@@ -28,10 +38,11 @@ model_path = 'model_geo_epoch79.pth'
|
|
| 28 |
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
|
| 29 |
model.eval()
|
| 30 |
|
| 31 |
-
# Load pretrained
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
|
|
|
| 35 |
|
| 36 |
# Define preprocessing function
|
| 37 |
def crop_bottom(img):
|
|
@@ -82,15 +93,15 @@ if uploaded_file is not None:
|
|
| 82 |
image = image.cuda()
|
| 83 |
|
| 84 |
# Extract features
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
with torch.no_grad():
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
# Predict province and coordinates
|
| 91 |
-
province_preds, coord_preds = model(features)
|
| 92 |
-
_, predicted_province = torch.max(province_preds, 1)
|
| 93 |
-
predicted_coords = coord_preds.cpu().numpy()
|
| 94 |
|
| 95 |
# Load class to index mapping
|
| 96 |
class_to_idx = {'Alabama': 0, 'Alaska': 1, 'Alberta': 2, 'Arizona': 3, 'Arkansas': 4, 'British Columbia': 5, 'California': 6, 'Colorado': 7, 'Connecticut': 8, 'Delaware': 9, 'Florida': 10, 'Georgia': 11, 'Hawaii': 12, 'Idaho': 13, 'Illinois': 14, 'Indiana': 15, 'Iowa': 16, 'Kansas': 17, 'Kentucky': 18, 'Louisiana': 19, 'Maine': 20, 'Manitoba': 21, 'Maryland': 22, 'Massachusetts': 23, 'Michigan': 24, 'Minnesota': 25, 'Mississippi': 26, 'Missouri': 27, 'Montana': 28, 'Nebraska': 29, 'Nevada': 30, 'New Brunswick': 31, 'New Hampshire': 32, 'New Jersey': 33, 'New Mexico': 34, 'New York': 35, 'Newfoundland and Labrador': 36, 'North Carolina': 37, 'North Dakota': 38, 'Northwest Territories': 39, 'Nova Scotia': 40, 'Nunavut': 41, 'Ohio': 42, 'Oklahoma': 43, 'Ontario': 44, 'Oregon': 45, 'Pennsylvania': 46, 'Prince Edward Island': 47, 'Quebec': 48, 'Rhode Island': 49, 'Saskatchewan': 50, 'South Carolina': 51, 'South Dakota': 52, 'Tennessee': 53, 'Texas': 54, 'Utah': 55, 'Vermont': 56, 'Virginia': 57, 'Washington': 58, 'West Virginia': 59, 'Wisconsin': 60, 'Wyoming': 61, 'Yukon': 62}
|
|
|
|
| 12 |
def __init__(self):
|
| 13 |
super(geoNet, self).__init__()
|
| 14 |
self.name = "geo"
|
| 15 |
+
self.fc1 = nn.Linear(768, 512)
|
| 16 |
+
self.bn1 = nn.BatchNorm1d(512)
|
| 17 |
+
self.fc2 = nn.Linear(512, 256)
|
| 18 |
+
self.bn2 = nn.BatchNorm1d(256)
|
| 19 |
+
self.fc3 = nn.Linear(256, 128)
|
| 20 |
+
self.bn3 = nn.BatchNorm1d(128)
|
| 21 |
+
self.fc4 = nn.Linear(128, 64)
|
| 22 |
+
self.bn4 = nn.BatchNorm1d(64)
|
| 23 |
+
self.classifier = nn.Linear(64, 63) # Classification head
|
| 24 |
+
self.regressor = nn.Linear(64, 2) # Regression head for (latitude, longitude)
|
| 25 |
|
| 26 |
def forward(self, x):
|
| 27 |
x = x.view(x.size(0), -1)
|
| 28 |
+
x = F.relu(self.bn1(self.fc1(x)))
|
| 29 |
+
x = F.relu(self.bn2(self.fc2(x)))
|
| 30 |
+
x = F.relu(self.bn3(self.fc3(x)))
|
| 31 |
+
x = F.relu(self.bn4(self.fc4(x)))
|
| 32 |
province_pred = self.classifier(x)
|
| 33 |
coords_pred = self.regressor(x)
|
| 34 |
return province_pred, coords_pred
|
|
|
|
| 38 |
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
|
| 39 |
model.eval()
|
| 40 |
|
| 41 |
+
# Load pretrained ViT model for feature extraction
|
| 42 |
+
model_name = 'google/vit-base-patch16-224-in21k'
|
| 43 |
+
vit = ViTModel.from_pretrained(model_name, attn_implementation="eager")
|
| 44 |
+
processor = ViTImageProcessor.from_pretrained(model_name)
|
| 45 |
+
vit.eval()
|
| 46 |
|
| 47 |
# Define preprocessing function
|
| 48 |
def crop_bottom(img):
|
|
|
|
| 93 |
image = image.cuda()
|
| 94 |
|
| 95 |
# Extract features
|
| 96 |
+
inputs = processor(images=image, return_tensors="pt")
|
| 97 |
+
outputs = vit(**inputs, output_attentions = True)
|
| 98 |
+
last_hidden_states = outputs.last_hidden_state
|
| 99 |
+
cls_hidden_state = last_hidden_states[:, 0, :]
|
| 100 |
with torch.no_grad():
|
| 101 |
+
province_preds, coord_preds = model(cls_hidden_state)
|
| 102 |
+
|
| 103 |
+
_, predicted_province = torch.max(province_preds, 1)
|
| 104 |
+
predicted_coords = coord_preds.cpu().numpy()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
|
| 106 |
# Load class to index mapping
|
| 107 |
class_to_idx = {'Alabama': 0, 'Alaska': 1, 'Alberta': 2, 'Arizona': 3, 'Arkansas': 4, 'British Columbia': 5, 'California': 6, 'Colorado': 7, 'Connecticut': 8, 'Delaware': 9, 'Florida': 10, 'Georgia': 11, 'Hawaii': 12, 'Idaho': 13, 'Illinois': 14, 'Indiana': 15, 'Iowa': 16, 'Kansas': 17, 'Kentucky': 18, 'Louisiana': 19, 'Maine': 20, 'Manitoba': 21, 'Maryland': 22, 'Massachusetts': 23, 'Michigan': 24, 'Minnesota': 25, 'Mississippi': 26, 'Missouri': 27, 'Montana': 28, 'Nebraska': 29, 'Nevada': 30, 'New Brunswick': 31, 'New Hampshire': 32, 'New Jersey': 33, 'New Mexico': 34, 'New York': 35, 'Newfoundland and Labrador': 36, 'North Carolina': 37, 'North Dakota': 38, 'Northwest Territories': 39, 'Nova Scotia': 40, 'Nunavut': 41, 'Ohio': 42, 'Oklahoma': 43, 'Ontario': 44, 'Oregon': 45, 'Pennsylvania': 46, 'Prince Edward Island': 47, 'Quebec': 48, 'Rhode Island': 49, 'Saskatchewan': 50, 'South Carolina': 51, 'South Dakota': 52, 'Tennessee': 53, 'Texas': 54, 'Utah': 55, 'Vermont': 56, 'Virginia': 57, 'Washington': 58, 'West Virginia': 59, 'Wisconsin': 60, 'Wyoming': 61, 'Yukon': 62}
|