File size: 5,228 Bytes
57dd1f2
8d31b64
57dd1f2
8d31b64
1ee730d
57dd1f2
 
6719b6d
bd054d0
57dd1f2
 
 
b14c4ae
57dd1f2
b14c4ae
 
e9f56b6
 
 
 
 
 
 
 
 
 
57dd1f2
 
b14c4ae
e9f56b6
 
 
 
b14c4ae
 
 
57dd1f2
12600c7
85b699e
57dd1f2
 
 
e9f56b6
 
 
 
 
57dd1f2
 
 
 
 
 
 
 
 
f4eb0c0
 
57dd1f2
 
 
 
 
 
 
 
 
 
1b94069
57dd1f2
 
 
 
533adca
57dd1f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73ff029
e9f56b6
 
 
 
57dd1f2
e9f56b6
 
 
 
57dd1f2
 
 
 
 
1b94069
 
 
 
 
 
 
 
 
57dd1f2
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import streamlit as st
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import transformers
from transformers import ViTModel, ViTFeatureExtractor, ViTImageProcessor, ViTForImageClassification
from PIL import Image
import matplotlib.pyplot as plt

class geoNet(nn.Module):
    """Multi-task MLP head over a flattened feature vector.

    Four ``Linear -> BatchNorm1d -> ReLU`` stages narrow the features
    (in_features -> 512 -> 256 -> 128 -> 64); two linear heads then emit
    province/state logits and a coordinate regression.

    Args:
        in_features: Width of the flattened input (default 768, the value
            that was previously hard-coded).
        num_classes: Number of province/state classes (default 63).
        num_coords: Width of the regression head (default 2, for
            latitude/longitude).

    The attribute names (fc1..fc4, bn1..bn4, classifier, regressor) are the
    keys of saved state dicts and must stay unchanged for
    ``load_state_dict`` to keep working.
    """

    def __init__(self, in_features=768, num_classes=63, num_coords=2):
        super().__init__()
        self.name = "geo"
        self.fc1 = nn.Linear(in_features, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.fc2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.fc3 = nn.Linear(256, 128)
        self.bn3 = nn.BatchNorm1d(128)
        self.fc4 = nn.Linear(128, 64)
        self.bn4 = nn.BatchNorm1d(64)
        self.classifier = nn.Linear(64, num_classes)  # Classification head
        self.regressor = nn.Linear(64, num_coords)  # Regression head for (latitude, longitude)

    def forward(self, x):
        """Return ``(province_pred, coords_pred)`` for a batch ``x``.

        Every dimension after the first is flattened, so the product of the
        trailing dimensions must equal ``in_features``. ``torch.flatten`` also
        accepts non-contiguous inputs, which ``Tensor.view`` would reject.
        """
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.bn1(self.fc1(x)))
        x = F.relu(self.bn2(self.fc2(x)))
        x = F.relu(self.bn3(self.fc3(x)))
        x = F.relu(self.bn4(self.fc4(x)))
        province_pred = self.classifier(x)
        coords_pred = self.regressor(x)
        return province_pred, coords_pred

@st.cache_resource
def _load_geo_model(path):
    """Build geoNet, load its state dict from ``path`` onto the CPU and set eval mode.

    Wrapped in ``st.cache_resource`` so the checkpoint is read once per server
    process instead of on every Streamlit rerun of this script.
    """
    net = geoNet()
    # NOTE: depending on the torch version, torch.load may unpickle arbitrary
    # objects; only point this at trusted checkpoint files.
    state_dict = torch.load(path, map_location=torch.device('cpu'))
    net.load_state_dict(state_dict)
    net.eval()
    return net


@st.cache_resource
def _load_vit(name):
    """Load the pretrained ViT backbone (in eval mode) and its image processor, once per process."""
    backbone = ViTModel.from_pretrained(name, attn_implementation="eager")
    image_processor = ViTImageProcessor.from_pretrained(name)
    backbone.eval()
    return backbone, image_processor


model_path = 'geo_53.07_15.52'
model = _load_geo_model(model_path)

# Load pretrained ViT model for feature extraction
model_name = 'google/vit-base-patch16-224-in21k'
vit, processor = _load_vit(model_name)

# Define preprocessing function
def crop_bottom(img, label_height=18):
    """Return ``img`` with its bottom ``label_height`` pixel rows removed.

    The removed strip holds the author label. Images that are not taller
    than the strip are returned unchanged, because cropping them would give
    an empty or inverted box.

    Args:
        img: A PIL-style image exposing ``size`` as (width, height) and ``crop``.
        label_height: Number of rows to drop from the bottom (default 18).

    Returns:
        The cropped image, or ``img`` itself when it is too short to crop.
    """
    width, height = img.size
    if height <= label_height:
        return img
    return img.crop((0, 0, width, height - label_height))  # get rid of author label

# Torchvision pipeline: drop the author-label strip, resize to 224x224,
# convert to a float tensor in [0, 1], then map each channel to [-1, 1].
preprocess = transforms.Compose(
    [
        transforms.Lambda(crop_bottom),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# Load and preprocess the image
def load_image(image_path):
    """Open ``image_path`` as RGB and run it through ``preprocess``.

    Returns the transformed tensor with a leading batch dimension of size 1.
    """
    with Image.open(image_path) as raw:
        rgb = raw.convert('RGB')
    return torch.unsqueeze(preprocess(rgb), 0)

def imshow(img):
    """Render a normalized ``(C, H, W)`` image tensor on the Streamlit page.

    Undoes ``Normalize(mean=0.5, std=0.5)`` so pixel values are back in
    [0, 1], draws the result on a fresh matplotlib figure, and passes that
    figure to ``st.pyplot``.
    """
    fig, ax = plt.subplots()
    try:
        # detach/cpu make .numpy() safe for tensors that require grad or live on a GPU;
        # clamp removes float round-off just outside [0, 1].
        pixels = (img.detach().cpu() / 2 + 0.5).clamp(0, 1)
        ax.imshow(np.transpose(pixels.numpy(), (1, 2, 0)))  # CHW -> HWC
        ax.axis('off')  # Turn off axis labels
        st.pyplot(fig)
    finally:
        # pyplot keeps every figure alive until closed; without this each
        # Streamlit rerun would leak one figure.
        plt.close(fig)

# Streamlit app
st.title("Image Geolocation Prediction")

# Province/state name -> classifier output index (one entry per output of
# geoNet.classifier).
class_to_idx = {'Alabama': 0, 'Alaska': 1, 'Alberta': 2, 'Arizona': 3, 'Arkansas': 4, 'British Columbia': 5, 'California': 6, 'Colorado': 7, 'Connecticut': 8, 'Delaware': 9, 'Florida': 10, 'Georgia': 11, 'Hawaii': 12, 'Idaho': 13, 'Illinois': 14, 'Indiana': 15, 'Iowa': 16, 'Kansas': 17, 'Kentucky': 18, 'Louisiana': 19, 'Maine': 20, 'Manitoba': 21, 'Maryland': 22, 'Massachusetts': 23, 'Michigan': 24, 'Minnesota': 25, 'Mississippi': 26, 'Missouri': 27, 'Montana': 28, 'Nebraska': 29, 'Nevada': 30, 'New Brunswick': 31, 'New Hampshire': 32, 'New Jersey': 33, 'New Mexico': 34, 'New York': 35, 'Newfoundland and Labrador': 36, 'North Carolina': 37, 'North Dakota': 38, 'Northwest Territories': 39, 'Nova Scotia': 40, 'Nunavut': 41, 'Ohio': 42, 'Oklahoma': 43, 'Ontario': 44, 'Oregon': 45, 'Pennsylvania': 46, 'Prince Edward Island': 47, 'Quebec': 48, 'Rhode Island': 49, 'Saskatchewan': 50, 'South Carolina': 51, 'South Dakota': 52, 'Tennessee': 53, 'Texas': 54, 'Utah': 55, 'Vermont': 56, 'Virginia': 57, 'Washington': 58, 'West Virginia': 59, 'Wisconsin': 60, 'Wyoming': 61, 'Yukon': 62}
idx_to_class = {idx: class_name for class_name, idx in class_to_idx.items()}

# Upload image
uploaded_file = st.file_uploader("Choose an image...", type="jpg")
if uploaded_file is None:
    st.write("Please upload an image to get predictions.")
    st.stop()

# Decode the upload straight from memory. Writing it to one fixed path on disk
# let concurrent sessions overwrite each other's image before it was re-read.
try:
    uploaded_file.seek(0)
    pil_image = Image.open(uploaded_file).convert('RGB')
except OSError:  # PIL's UnidentifiedImageError and truncated-file errors subclass OSError
    st.error("Could not read the uploaded file as an image.")
    st.stop()

# Show the upload and the cropped/resized/normalized tensor derived from it.
st.image(pil_image, caption='Uploaded Image', use_column_width=True)
image = preprocess(pil_image).unsqueeze(0)  # Add batch dimension
imshow(image.squeeze(0))

# Extract features and predict. Both networks are loaded onto the CPU, so the
# inputs stay on the CPU too. no_grad stops the ViT forward pass from building
# an autograd graph that is never used.
# NOTE(review): the processor receives the *uncropped* upload, so the
# author-label crop applied by `preprocess` never reaches the model. Confirm
# which input pipeline geoNet was trained with before changing this.
with torch.no_grad():
    inputs = processor(images=pil_image, return_tensors="pt")
    outputs = vit(**inputs)
    cls_hidden_state = outputs.last_hidden_state[:, 0, :]  # first ([CLS]) token
    province_preds, coord_preds = model(cls_hidden_state)

predicted_province = province_preds.argmax(dim=1).item()
# The regression head outputs (latitude, longitude); units depend on how geoNet was trained.
lat, lon = coord_preds[0].cpu().tolist()

# Display predictions with increased font size
st.markdown(
    f"<h3 style='font-size:20px;'>Predicted Province/State: {idx_to_class.get(predicted_province, 'Unknown')}</h3>",
    unsafe_allow_html=True
)
st.markdown(
    f"<h3 style='font-size:20px;'>Predicted Coordinates: ({lat:.5f}, {lon:.5f})</h3>",
    unsafe_allow_html=True
)