import streamlit as st
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import transformers
from transformers import ViTModel, ViTFeatureExtractor, ViTImageProcessor, ViTForImageClassification
from PIL import Image
import matplotlib.pyplot as plt
class geoNet(nn.Module):
    """MLP head over a 768-d ViT [CLS] embedding.

    Produces two outputs per sample: logits over 63 provinces/states and a
    (latitude, longitude) regression pair. Attribute names (fc1..bn4,
    classifier, regressor) are part of the saved checkpoint's state-dict
    keys and must not be renamed.
    """

    def __init__(self):
        super(geoNet, self).__init__()
        self.name = "geo"
        self.fc1 = nn.Linear(768, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.fc2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.fc3 = nn.Linear(256, 128)
        self.bn3 = nn.BatchNorm1d(128)
        self.fc4 = nn.Linear(128, 64)
        self.bn4 = nn.BatchNorm1d(64)
        self.classifier = nn.Linear(64, 63)  # classification head (63 classes)
        self.regressor = nn.Linear(64, 2)    # regression head: (latitude, longitude)

    def forward(self, x):
        """Return (province_logits, coords) for a batch of embeddings."""
        h = x.view(x.size(0), -1)  # flatten anything beyond the batch dim
        trunk = (
            (self.fc1, self.bn1),
            (self.fc2, self.bn2),
            (self.fc3, self.bn3),
            (self.fc4, self.bn4),
        )
        for linear, norm in trunk:
            h = F.relu(norm(linear(h)))
        return self.classifier(h), self.regressor(h)
# --- One-time model setup (runs at import) ---

# Trained geolocation head over ViT [CLS] features.
model = geoNet()
model_path = 'geo_53.07_15.52'
# weights_only=True restricts unpickling to tensors/containers, preventing
# arbitrary code execution if the checkpoint file was tampered with.
model.load_state_dict(
    torch.load(model_path, map_location=torch.device('cpu'), weights_only=True)
)
model.eval()  # use BatchNorm running stats; required for batch-size-1 inference

# Load pretrained ViT model for feature extraction (frozen backbone).
model_name = 'google/vit-base-patch16-224-in21k'
vit = ViTModel.from_pretrained(model_name, attn_implementation="eager")
processor = ViTImageProcessor.from_pretrained(model_name)
vit.eval()
# Define preprocessing function
def crop_bottom(img):
    """Return *img* with the bottom 18-pixel strip removed (drops the author label)."""
    width, height = img.size
    crop_box = (0, 0, width, height - 18)
    return img.crop(crop_box)
# Preprocessing pipeline: strip the author label, resize to the ViT input
# resolution, convert to a tensor, and scale pixels into [-1, 1].
_preprocess_steps = [
    transforms.Lambda(crop_bottom),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
preprocess = transforms.Compose(_preprocess_steps)
# Load and preprocess the image
def load_image(image_path):
    """Open *image_path*, run the preprocessing pipeline, and return a 1xCxHxW tensor."""
    pil_img = Image.open(image_path).convert('RGB')
    tensor = preprocess(pil_img)
    return tensor.unsqueeze(0)  # add batch dimension
def imshow(img):
    """Render a normalized CxHxW image tensor inside the Streamlit app.

    The tensor is assumed to have been normalized with mean 0.5 / std 0.5
    (see `preprocess`), so `img / 2 + 0.5` maps it back to [0, 1].
    """
    fig, ax = plt.subplots()
    denorm = img / 2 + 0.5  # undo Normalize((0.5, ...), (0.5, ...))
    npimg = denorm.numpy()
    # Draw on the axes we created instead of pyplot's implicit current axes.
    ax.imshow(np.transpose(npimg, (1, 2, 0)))  # CHW -> HWC for matplotlib
    ax.axis('off')
    st.pyplot(fig)
    # Close the figure so reruns don't accumulate figures in pyplot's
    # global registry (a memory leak in long-lived Streamlit sessions).
    plt.close(fig)
# Streamlit app
st.title("Image Geolocation Prediction")
# Upload image. Accept the common raster extensions: the original bare
# "jpg" filter rejected .jpeg/.png uploads of identical content.
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
    # Persist the upload so both st.image and PIL can re-read it from disk.
    image_path = "uploaded_image.jpg"
    with open(image_path, "wb") as f:
        f.write(uploaded_file.getbuffer())

    # Show the raw upload and the preprocessed (cropped/normalized) version.
    image = load_image(image_path)
    st.image(image_path, caption='Uploaded Image', use_column_width=True)
    image_to_show = image.squeeze(0)
    imshow(image_to_show)

    # NOTE(review): the original moved `image` to CUDA and then immediately
    # overwrote it with a fresh PIL reload, so the device transfer was dead
    # code; inference below runs on CPU, matching the CPU-loaded head.
    pil_image = Image.open(image_path).convert('RGB')
    inputs = processor(images=pil_image, return_tensors="pt")

    # Pure inference: no_grad avoids building an autograd graph. The unused
    # output_attentions=True request was dropped — attention maps were never
    # read and only cost memory.
    with torch.no_grad():
        outputs = vit(**inputs)
        cls_hidden_state = outputs.last_hidden_state[:, 0, :]  # [CLS] token
        province_preds, coord_preds = model(cls_hidden_state)

    _, predicted_province = torch.max(province_preds, 1)
    predicted_coords = coord_preds.cpu().numpy()

    # Class-name <-> index mapping used at training time.
    class_to_idx = {'Alabama': 0, 'Alaska': 1, 'Alberta': 2, 'Arizona': 3, 'Arkansas': 4, 'British Columbia': 5, 'California': 6, 'Colorado': 7, 'Connecticut': 8, 'Delaware': 9, 'Florida': 10, 'Georgia': 11, 'Hawaii': 12, 'Idaho': 13, 'Illinois': 14, 'Indiana': 15, 'Iowa': 16, 'Kansas': 17, 'Kentucky': 18, 'Louisiana': 19, 'Maine': 20, 'Manitoba': 21, 'Maryland': 22, 'Massachusetts': 23, 'Michigan': 24, 'Minnesota': 25, 'Mississippi': 26, 'Missouri': 27, 'Montana': 28, 'Nebraska': 29, 'Nevada': 30, 'New Brunswick': 31, 'New Hampshire': 32, 'New Jersey': 33, 'New Mexico': 34, 'New York': 35, 'Newfoundland and Labrador': 36, 'North Carolina': 37, 'North Dakota': 38, 'Northwest Territories': 39, 'Nova Scotia': 40, 'Nunavut': 41, 'Ohio': 42, 'Oklahoma': 43, 'Ontario': 44, 'Oregon': 45, 'Pennsylvania': 46, 'Prince Edward Island': 47, 'Quebec': 48, 'Rhode Island': 49, 'Saskatchewan': 50, 'South Carolina': 51, 'South Dakota': 52, 'Tennessee': 53, 'Texas': 54, 'Utah': 55, 'Vermont': 56, 'Virginia': 57, 'Washington': 58, 'West Virginia': 59, 'Wisconsin': 60, 'Wyoming': 61, 'Yukon': 62}
    idx_to_class = {idx: class_name for class_name, idx in class_to_idx.items()}

    # Display predictions with increased font size. The label now says
    # "Province/State" (not "Index") because a name, not an index, is shown.
    st.markdown(
        f"<h3 style='font-size:20px;'>Predicted Province/State: {idx_to_class.get(predicted_province.item(), None)}</h3>",
        unsafe_allow_html=True
    )
    st.markdown(
        f"<h3 style='font-size:20px;'>Predicted Coordinates: {predicted_coords}</h3>",
        unsafe_allow_html=True
    )
else:
    st.write("Please upload an image to get predictions.")