import io

import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import transformers
from PIL import Image
from transformers import ViTModel, ViTFeatureExtractor, ViTImageProcessor, ViTForImageClassification
|
|
class geoNet(nn.Module):
    """MLP head mapping a 768-d feature vector to province logits and coordinates.

    The trunk is four ``Linear -> BatchNorm1d -> ReLU`` stages
    (768 -> 512 -> 256 -> 128 -> 64). Two parallel linear heads follow it:

    * ``classifier``: 63 logits, one per province/state class.
    * ``regressor``: 2 outputs, the coordinate prediction. NOTE(review): units
      and normalization are not visible here; confirm against training code.

    The attribute names (fc1..fc4, bn1..bn4, classifier, regressor) are the
    state_dict keys expected by ``load_state_dict``, so they must not change.
    """

    def __init__(self):
        super().__init__()
        self.name = "geo"
        # Layers are created in the original order so parameter initialization
        # consumes the RNG identically.
        self.fc1 = nn.Linear(768, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.fc2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.fc3 = nn.Linear(256, 128)
        self.bn3 = nn.BatchNorm1d(128)
        self.fc4 = nn.Linear(128, 64)
        self.bn4 = nn.BatchNorm1d(64)
        self.classifier = nn.Linear(64, 63)
        self.regressor = nn.Linear(64, 2)

    def forward(self, x):
        """Return ``(province_logits, coords)`` for a batch of inputs.

        Every dimension after the first is flattened, so the trailing
        dimensions must multiply to 768.
        """
        hidden = x.view(x.size(0), -1)
        stages = (
            (self.fc1, self.bn1),
            (self.fc2, self.bn2),
            (self.fc3, self.bn3),
            (self.fc4, self.bn4),
        )
        for linear, norm in stages:
            hidden = F.relu(norm(linear(hidden)))
        return self.classifier(hidden), self.regressor(hidden)
|
|
model_path = 'geo_53.07_15.52'
model_name = 'google/vit-base-patch16-224-in21k'


@st.cache_resource
def _load_geo_head(path):
    """Build a ``geoNet``, load its weights from *path* onto the CPU, and set eval mode.

    Wrapped in ``st.cache_resource`` so the checkpoint is read once per server
    process. Streamlit re-executes the whole script on every widget
    interaction, and without caching the weights were reloaded each time.
    """
    net = geoNet()
    # NOTE(review): torch.load unpickles the file. That is acceptable only
    # because this is a local, trusted checkpoint. Consider weights_only=True
    # if the installed torch version supports it.
    net.load_state_dict(torch.load(path, map_location=torch.device('cpu')))
    net.eval()
    return net


@st.cache_resource
def _load_vit_backbone(name):
    """Load the pretrained ViT encoder (in eval mode) and its image processor for *name*.

    Cached for the same reason as ``_load_geo_head``: ``from_pretrained`` is
    expensive and would otherwise run on every Streamlit rerun.
    """
    backbone = ViTModel.from_pretrained(name, attn_implementation="eager")
    backbone.eval()
    image_processor = ViTImageProcessor.from_pretrained(name)
    return backbone, image_processor


model = _load_geo_head(model_path)
vit, processor = _load_vit_backbone(model_name)
|
|
| |
def crop_bottom(img, *, pixels=18):
    """Return *img* with its bottom *pixels* rows removed.

    This is the first step of ``preprocess``. The default of 18 px is the
    value the pipeline has always used. NOTE(review): it presumably strips a
    footer/watermark band from the source imagery; confirm.

    Args:
        img: An object with a ``size`` of ``(width, height)`` and a
            ``crop((left, upper, right, lower))`` method (a PIL image here).
        pixels: Number of rows to remove from the bottom; must be >= 0.

    Returns:
        The cropped image. If the image is not taller than *pixels*, cropping
        would leave an empty or invalid box, so *img* is returned unchanged.

    Raises:
        ValueError: If *pixels* is negative. A negative value would extend the
            box below the image instead of cropping it.
    """
    if pixels < 0:
        raise ValueError(f"pixels must be non-negative, got {pixels}")
    width, height = img.size
    if height <= pixels:
        return img
    return img.crop((0, 0, width, height - pixels))
|
|
# Display-side pipeline used by ``load_image``. Steps, in order:
#   1. Crop the bottom strip via ``crop_bottom``.
#   2. Resize to 224x224.
#   3. Convert to a float (C, H, W) tensor in [0, 1].
#   4. Normalize each channel with mean 0.5 / std 0.5, mapping [0, 1] to [-1, 1].
# ``imshow`` undoes step 4 with ``x / 2 + 0.5``.
_preprocess_steps = [
    transforms.Lambda(crop_bottom),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,) * 3, (0.5,) * 3),
]
preprocess = transforms.Compose(_preprocess_steps)
|
|
| |
def load_image(image_path):
    """Open an image, run it through ``preprocess``, and prepend a batch axis.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts (a path or a binary
            file-like object).

    Returns:
        A float tensor of shape (1, 3, 224, 224). The image is converted to
        RGB and resized to 224x224 by ``preprocess``.
    """
    rgb_image = Image.open(image_path).convert('RGB')
    return preprocess(rgb_image)[None, ...]
|
|
def imshow(img):
    """Render a normalized (C, H, W) image tensor into the Streamlit page.

    The tensor is assumed to come from ``preprocess``, i.e. normalized with
    mean 0.5 / std 0.5. This function undoes that with ``x / 2 + 0.5`` and
    draws the result with matplotlib.

    The figure is always closed afterwards. pyplot keeps every figure created
    by ``plt.subplots`` alive until ``plt.close``, and Streamlit re-executes
    the script on every interaction, so unclosed figures pile up.
    """
    fig, ax = plt.subplots()
    try:
        # Clamp away float round-off outside [0, 1], which matplotlib would
        # otherwise clip with a warning. detach/cpu make .numpy() safe for
        # tensors that track grads or live on a GPU.
        pixels = (img / 2 + 0.5).clamp(0, 1).detach().cpu().numpy()
        ax.imshow(np.transpose(pixels, (1, 2, 0)))
        ax.axis('off')
        st.pyplot(fig)
    finally:
        plt.close(fig)
|
|
| |
# Province/state name -> class index of geoNet's 63-way classifier head.
class_to_idx = {
    'Alabama': 0, 'Alaska': 1, 'Alberta': 2, 'Arizona': 3, 'Arkansas': 4,
    'British Columbia': 5, 'California': 6, 'Colorado': 7, 'Connecticut': 8,
    'Delaware': 9, 'Florida': 10, 'Georgia': 11, 'Hawaii': 12, 'Idaho': 13,
    'Illinois': 14, 'Indiana': 15, 'Iowa': 16, 'Kansas': 17, 'Kentucky': 18,
    'Louisiana': 19, 'Maine': 20, 'Manitoba': 21, 'Maryland': 22,
    'Massachusetts': 23, 'Michigan': 24, 'Minnesota': 25, 'Mississippi': 26,
    'Missouri': 27, 'Montana': 28, 'Nebraska': 29, 'Nevada': 30,
    'New Brunswick': 31, 'New Hampshire': 32, 'New Jersey': 33,
    'New Mexico': 34, 'New York': 35, 'Newfoundland and Labrador': 36,
    'North Carolina': 37, 'North Dakota': 38, 'Northwest Territories': 39,
    'Nova Scotia': 40, 'Nunavut': 41, 'Ohio': 42, 'Oklahoma': 43,
    'Ontario': 44, 'Oregon': 45, 'Pennsylvania': 46,
    'Prince Edward Island': 47, 'Quebec': 48, 'Rhode Island': 49,
    'Saskatchewan': 50, 'South Carolina': 51, 'South Dakota': 52,
    'Tennessee': 53, 'Texas': 54, 'Utah': 55, 'Vermont': 56, 'Virginia': 57,
    'Washington': 58, 'West Virginia': 59, 'Wisconsin': 60, 'Wyoming': 61,
    'Yukon': 62,
}
idx_to_class = {idx: class_name for class_name, idx in class_to_idx.items()}


def _predict_location(pil_image):
    """Run the ViT encoder and the geoNet head on one RGB PIL image.

    Returns:
        ``(predicted_province, predicted_coords)``:
        - ``predicted_province`` is a tensor of argmax class indices, one per
          image in the batch (a single image here).
        - ``predicted_coords`` is the regressor output as a numpy array of
          shape (batch, 2).
    """
    inputs = processor(images=pil_image, return_tensors="pt")
    # Inference only. The ViT forward used to run outside no_grad and built
    # an autograd graph on every upload. Attentions were requested but never
    # used, so they are no longer computed.
    with torch.no_grad():
        outputs = vit(**inputs)
        cls_hidden_state = outputs.last_hidden_state[:, 0, :]
        province_preds, coord_preds = model(cls_hidden_state)
    predicted_province = province_preds.argmax(dim=1)
    return predicted_province, coord_preds.cpu().numpy()


st.title("Image Geolocation Prediction")

uploaded_file = st.file_uploader("Choose an image...", type="jpg")
if uploaded_file is not None:
    # Read the upload once and work from memory. Writing every upload to one
    # fixed "uploaded_image.jpg" let concurrent sessions overwrite each
    # other's file between the write and the later reads.
    image_bytes = uploaded_file.getvalue()

    image = load_image(io.BytesIO(image_bytes))
    st.image(image_bytes, caption='Uploaded Image', use_column_width=True)
    imshow(image.squeeze(0))

    # NOTE(review): the displayed tensor goes through `preprocess`, which crops
    # the bottom strip. The model input below is the uncropped image passed
    # through the ViT processor. Confirm which one matches the training
    # pipeline.
    pil_image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
    predicted_province, predicted_coords = _predict_location(pil_image)

    province_name = idx_to_class.get(predicted_province.item(), None)
    st.markdown(
        f"<h3 style='font-size:20px;'>Predicted Province/State: {province_name}</h3>",
        unsafe_allow_html=True,
    )
    st.markdown(
        f"<h3 style='font-size:20px;'>Predicted Coordinates: {predicted_coords}</h3>",
        unsafe_allow_html=True,
    )
else:
    st.write("Please upload an image to get predictions.")
|
|