import streamlit as st import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import torchvision import torchvision.transforms as transforms import transformers from transformers import ViTModel, ViTFeatureExtractor, ViTImageProcessor, ViTForImageClassification from PIL import Image import matplotlib.pyplot as plt class geoNet(nn.Module): def __init__(self): super(geoNet, self).__init__() self.name = "geo" self.fc1 = nn.Linear(768, 512) self.bn1 = nn.BatchNorm1d(512) self.fc2 = nn.Linear(512, 256) self.bn2 = nn.BatchNorm1d(256) self.fc3 = nn.Linear(256, 128) self.bn3 = nn.BatchNorm1d(128) self.fc4 = nn.Linear(128, 64) self.bn4 = nn.BatchNorm1d(64) self.classifier = nn.Linear(64, 63) # Classification head self.regressor = nn.Linear(64, 2) # Regression head for (latitude, longitude) def forward(self, x): x = x.view(x.size(0), -1) x = F.relu(self.bn1(self.fc1(x))) x = F.relu(self.bn2(self.fc2(x))) x = F.relu(self.bn3(self.fc3(x))) x = F.relu(self.bn4(self.fc4(x))) province_pred = self.classifier(x) coords_pred = self.regressor(x) return province_pred, coords_pred model = geoNet() model_path = 'geo_53.07_15.52' model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'))) model.eval() # Load pretrained ViT model for feature extraction model_name = 'google/vit-base-patch16-224-in21k' vit = ViTModel.from_pretrained(model_name, attn_implementation="eager") processor = ViTImageProcessor.from_pretrained(model_name) vit.eval() # Define preprocessing function def crop_bottom(img): width, height = img.size return img.crop((0, 0, width, height - 18)) # get rid of author label preprocess = transforms.Compose([ transforms.Lambda(crop_bottom), transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) # Load and preprocess the image def load_image(image_path): image = Image.open(image_path).convert('RGB') image = preprocess(image) image = image.unsqueeze(0) # Add batch dimension return image def imshow(img): fig, ax = plt.subplots() img = img / 2 + 0.5 # Denormalize the image npimg = img.numpy() plt.imshow(np.transpose(npimg, (1, 2, 0))) plt.axis('off') # Turn off axis labels st.pyplot(fig) # Streamlit app st.title("Image Geolocation Prediction") # Upload image uploaded_file = st.file_uploader("Choose an image...", type="jpg") if uploaded_file is not None: # Save uploaded image to disk image_path = "uploaded_image.jpg" with open(image_path, "wb") as f: f.write(uploaded_file.getbuffer()) # Load and show the image image = load_image(image_path) st.image(image_path, caption='Uploaded Image', use_column_width=True) image_to_show = image.squeeze(0) imshow(image_to_show) # Move image to GPU if available if torch.cuda.is_available(): image = image.cuda() # Extract features image = Image.open(image_path).convert('RGB') inputs = processor(images=image, return_tensors="pt") outputs = vit(**inputs, output_attentions = True) last_hidden_states = outputs.last_hidden_state cls_hidden_state = last_hidden_states[:, 0, :] with torch.no_grad(): province_preds, coord_preds = model(cls_hidden_state) _, predicted_province = torch.max(province_preds, 1) predicted_coords = coord_preds.cpu().numpy() # Load class to index mapping class_to_idx = {'Alabama': 0, 'Alaska': 1, 'Alberta': 2, 'Arizona': 3, 'Arkansas': 4, 'British Columbia': 5, 'California': 6, 'Colorado': 7, 'Connecticut': 8, 'Delaware': 9, 'Florida': 10, 'Georgia': 11, 'Hawaii': 12, 'Idaho': 13, 'Illinois': 14, 'Indiana': 15, 'Iowa': 16, 'Kansas': 17, 'Kentucky': 18, 'Louisiana': 19, 'Maine': 20, 'Manitoba': 21, 'Maryland': 22, 'Massachusetts': 23, 'Michigan': 24, 'Minnesota': 25, 'Mississippi': 26, 'Missouri': 27, 'Montana': 28, 'Nebraska': 29, 'Nevada': 30, 'New Brunswick': 31, 'New Hampshire': 32, 'New Jersey': 33, 'New Mexico': 34, 'New York': 35, 'Newfoundland and Labrador': 36, 'North Carolina': 37, 'North Dakota': 38, 'Northwest Territories': 39, 'Nova Scotia': 40, 'Nunavut': 41, 'Ohio': 42, 'Oklahoma': 43, 'Ontario': 44, 'Oregon': 45, 'Pennsylvania': 46, 'Prince Edward Island': 47, 'Quebec': 48, 'Rhode Island': 49, 'Saskatchewan': 50, 'South Carolina': 51, 'South Dakota': 52, 'Tennessee': 53, 'Texas': 54, 'Utah': 55, 'Vermont': 56, 'Virginia': 57, 'Washington': 58, 'West Virginia': 59, 'Wisconsin': 60, 'Wyoming': 61, 'Yukon': 62} idx_to_class = {idx: class_name for class_name, idx in class_to_idx.items()} # Display predictions with increased font size st.markdown( f"