File size: 7,332 Bytes
9283b32
855fdc8
9283b32
 
 
7f9817c
855fdc8
9283b32
 
 
 
 
 
 
894aed8
9283b32
 
 
 
 
 
855fdc8
9283b32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
855fdc8
9283b32
 
 
7f9817c
 
 
 
 
 
 
 
 
855fdc8
894aed8
9283b32
 
 
 
 
 
 
 
894aed8
 
 
 
a560523
 
963878a
2bb8939
9283b32
 
894aed8
 
 
 
 
 
 
 
 
 
 
 
 
 
a560523
894aed8
 
7f9817c
 
 
894aed8
855fdc8
 
894aed8
855fdc8
 
894aed8
 
7f9817c
894aed8
7f9817c
 
894aed8
855fdc8
 
 
 
 
894aed8
855fdc8
894aed8
 
855fdc8
 
 
894aed8
855fdc8
 
 
 
 
 
 
7f9817c
 
 
855fdc8
894aed8
 
7f9817c
963878a
7f9817c
9283b32
 
855fdc8
b44aa42
 
4c1811b
 
 
 
 
 
 
545aba6
4c1811b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b44aa42
855fdc8
 
 
894aed8
 
 
 
 
 
 
855fdc8
 
a58a7bc
894aed8
 
 
9283b32
 
855fdc8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
import gradio as gr
import torch, numpy as np, json, os
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
import pygeohash as pgh
import plotly.graph_objects as go
import torch.nn as nn

EXPORT_DIR = "."
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ---------------- Load metadata ----------------
# Read metadata.json from EXPORT_DIR. A context manager closes the handle
# deterministically (a bare `json.load(open(...))` leaks it until GC), and an
# explicit UTF-8 encoding avoids depending on the platform's locale encoding.
with open(os.path.join(EXPORT_DIR, "metadata.json"), encoding="utf-8") as _meta_file:
    metadata = json.load(_meta_file)

# geohash string -> class id; inverted below for decoding predictions.
geoh2id = metadata["geoh2id"]
# int() normalises ids so they match the integer class indices used at inference.
id2geoh = {int(v): k for k, v in geoh2id.items()}
country2id = metadata["country2id"]
# Name passed to CLIPProcessor/CLIPModel.from_pretrained.
clip_model_name = metadata["clip_model"]
# Input width of GeoHybridModel (size of the CLIP image features).
dim = metadata["embedding_dim"]
num_fine = len(geoh2id)
num_countries = len(country2id)

# ---------------- Model ----------------
class GeoHybridModel(nn.Module):
    """Multi-head MLP applied to an image embedding.

    A shared two-layer trunk (Linear -> ReLU -> Dropout, twice) feeds three
    independent linear heads:

    * ``classifier``         -> ``num_classes`` logits (fine geohash cells),
    * ``regressor``          -> 2 outputs (used by the caller as a lat/lon offset),
    * ``country_classifier`` -> ``num_countries`` logits.

    The attribute names and the layer order inside ``shared`` define the
    ``state_dict`` keys, so they must not change or saved checkpoints will
    no longer load.
    """

    def __init__(self, in_dim, num_classes, num_countries, hidden=1024, drop=0.3):
        super().__init__()
        half = hidden // 2
        trunk_layers = [
            nn.Linear(in_dim, hidden),
            nn.ReLU(),
            nn.Dropout(drop),
            nn.Linear(hidden, half),
            nn.ReLU(),
            nn.Dropout(drop),
        ]
        self.shared = nn.Sequential(*trunk_layers)
        self.classifier = nn.Linear(half, num_classes)
        self.regressor = nn.Linear(half, 2)
        self.country_classifier = nn.Linear(half, num_countries)

    def forward(self, x):
        """Return ``(class_logits, offset, country_logits)`` computed from ``x``."""
        features = self.shared(x)
        heads = (self.classifier, self.regressor, self.country_classifier)
        return tuple(head(features) for head in heads)

model = GeoHybridModel(dim, num_fine, num_countries)
# map_location=DEVICE remaps the saved tensors onto the runtime device, so a
# checkpoint written on GPU still loads on a CPU-only host.
model.load_state_dict(
    torch.load(os.path.join(EXPORT_DIR, "model.pt"), map_location=DEVICE)
)
model.to(DEVICE)
model.eval()  # inference mode: the trunk's Dropout layers become no-ops

# ---------------- CLIP ----------------
clip_processor = CLIPProcessor.from_pretrained(clip_model_name)
clip_model = CLIPModel.from_pretrained(clip_model_name)
clip_model.to(DEVICE)
clip_model.eval()

# ---------------- Haversine ----------------
def haversine(lat1, lon1, lat2, lon2, radius=6371.0):
    """Great-circle distance between two points given in decimal degrees.

    All operations are NumPy ufuncs, so scalars or broadcastable arrays work.

    Args:
        lat1, lon1: first point, in degrees.
        lat2, lon2: second point, in degrees.
        radius: sphere radius. The default 6371.0 is Earth's mean radius in
            kilometres, which is why callers label the result "km".

    Returns:
        The distance, in the same unit as ``radius``.
    """
    phi1, phi2 = np.radians(lat1), np.radians(lat2)
    dphi = np.radians(lat2 - lat1)
    dlambda = np.radians(lon2 - lon1)
    a = np.sin(dphi / 2) ** 2 + np.cos(phi1) * np.cos(phi2) * np.sin(dlambda / 2) ** 2
    # Floating-point rounding can push `a` a hair outside [0, 1] (e.g. for
    # antipodal points), which would make sqrt(1 - a) NaN. Clamp it.
    a = np.clip(a, 0.0, 1.0)
    return 2 * radius * np.arctan2(np.sqrt(a), np.sqrt(1 - a))

# ---------------- Prediction + Map ----------------
def _embed_and_classify(img):
    """Encode ``img`` with CLIP and run it through the geo model.

    Returns ``(class_logits, offset)`` as NumPy arrays for the single image in
    the batch: one logit per fine geohash class, and the regressor's 2 outputs.
    """
    clip_inputs = clip_processor(images=img, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        emb = clip_model.get_image_features(**clip_inputs)
        # Unit-normalise the CLIP embedding before the MLP head.
        emb = emb / emb.norm(p=2, dim=-1, keepdim=True)
        out_class, out_offset, _ = model(emb)  # country logits unused here
        return out_class.cpu().numpy()[0], out_offset.cpu().numpy()[0]


def _distance_badge(dist_km):
    """Format a distance as "<emoji> <d> km": green < 100, yellow < 1000, else red."""
    if dist_km < 100:
        icon = "🟢"
    elif dist_km < 1000:
        icon = "🟡"
    else:
        icon = "🔴"
    return f"{icon} {dist_km:.1f} km"


def _build_prediction_map(lats, lons, true_lat=None, true_lon=None):
    """Plot predictions (blue) and, if both true coords are given, the truth (red).

    When the true location is known, a green line joins it to every prediction
    and the map is centred on it; otherwise the map is centred on the top-1
    prediction. ``lats``/``lons`` must be non-empty.
    """
    fig = go.Figure()

    # Predictions
    fig.add_trace(go.Scattermapbox(
        lat=lats, lon=lons, mode="markers+text",
        text=[f"Top-{i+1}" for i in range(len(lats))],
        textposition="top right",
        marker=go.scattermapbox.Marker(size=12, color="blue"),
        name="Predictions", showlegend=False
    ))

    if true_lat is not None and true_lon is not None:
        fig.add_trace(go.Scattermapbox(
            lat=[true_lat], lon=[true_lon], mode="markers+text",
            text=["True Location"], textposition="bottom right",
            marker=go.scattermapbox.Marker(size=14, color="red"),
            name="True Location", showlegend=False
        ))
        # Connect the true location to each prediction.
        for lat_p, lon_p in zip(lats, lons):
            fig.add_trace(go.Scattermapbox(
                lat=[true_lat, lat_p],
                lon=[true_lon, lon_p],
                mode="lines",
                line=dict(width=2, color="green"),
                showlegend=False
            ))
        center_lat, center_lon = true_lat, true_lon
    else:
        # Centre on a consistent point; never mix a true lat with a predicted lon.
        center_lat, center_lon = lats[0], lons[0]

    fig.update_layout(
        mapbox_style="open-street-map",
        hovermode="closest",
        mapbox=dict(center=go.layout.mapbox.Center(lat=center_lat, lon=center_lon), zoom=4),
        margin={"r": 0, "t": 0, "l": 0, "b": 0},
        showlegend=False
    )
    return fig


def predict_geohash_map(img: Image.Image, true_lat=None, true_lon=None, top_k=5):
    """Predict the top-k locations for a street-view image and plot them.

    Args:
        img: PIL image from the Gradio Image component (``None`` when nothing
            was uploaded).
        true_lat, true_lon: optional ground-truth coordinates in degrees.
            Distances and the "True Location" marker are shown only when both
            are provided.
        top_k: number of highest-scoring geohash cells to report. It is
            coerced to ``int`` and clamped to at least 1.

    Returns:
        ``(markdown_table, plotly_figure)``.

    Raises:
        gr.Error: if no image was provided.
    """
    if img is None:
        raise gr.Error("Please upload an image first.")
    # Slicing needs an int, and top_k <= 0 would make `[-0:]` select every class.
    top_k = max(1, int(top_k))
    has_truth = true_lat is not None and true_lon is not None

    class_logits, offset = _embed_and_classify(img)
    topk_idx = class_logits.argsort()[-top_k:][::-1]  # best score first

    lats, lons = [], []
    # Markdown table header
    preds_text = ["| Rank | Latitude | Longitude | Distance |",
                  "|------|----------|-----------|----------|"]

    for rank, class_id in enumerate(topk_idx, 1):
        geoh = id2geoh[int(class_id)]
        # pygeohash.decode_exactly -> (lat, lon, lat_err, lon_err) for the cell.
        lat_base, lon_base, cell_lat, cell_lon = pgh.decode_exactly(geoh)
        # NOTE(review): the model emits a single offset per image, and it is
        # applied to every top-k cell. Confirm this matches how the regressor
        # was trained.
        lat_pred = lat_base + offset[0] * cell_lat
        lon_pred = lon_base + offset[1] * cell_lon

        dist_str = ""
        if has_truth:
            dist_str = _distance_badge(haversine(true_lat, true_lon, lat_pred, lon_pred))

        preds_text.append(
            f"| {rank} | {lat_pred:.5f} | {lon_pred:.5f} | {dist_str} |"
        )
        lats.append(lat_pred)
        lons.append(lon_pred)

    if has_truth:
        fig = _build_prediction_map(lats, lons, true_lat, true_lon)
    else:
        fig = _build_prediction_map(lats, lons)

    return "\n".join(preds_text), fig

# ---------------- Gradio UI ----------------
# Intro and usage-policy text rendered at the top of the page.
_INTRO_MD = """
# 🌍 Locus - GeoGuessr Image to Coordinates Model  
Upload a **streetview-style image** (photo or screenshot).  
The model can handle images with a small UI overlay (like from Google Street View),  
as long as the **main scene is clearly visible**. Locus was trained on roughly 80k+ Google Street View images without textual descriptions.

You can optionally provide the actual coordinates to compare predictions and calculate distances on a world map.  

## ⚠️ Important Usage Notice  
Locus is developed **solely for research, experimentation, and educational exploration**.  
It is intended for tasks such as:  
- Studying computer vision and geolocation models  
- Building games or challenges similar to GeoGuessr  
- Learning about machine learning workflows and dataset training  

You may **not** use Locus for any illegal, unethical, or harmful purposes, including but not limited to:  
- Doxing (revealing private or personal information about individuals)  
- Stalking, harassment, or invasion of privacy  
- Law enforcement or military targeting  
- Commercial exploitation without consent  
- Any activity that violates local, national, or international laws  

By using Locus, you agree to respect these boundaries and acknowledge that the model is a **research prototype**, not a tool for real-world deployment.  
"""

with gr.Blocks() as demo:
    gr.Markdown(_INTRO_MD)

    # Inputs: image on the left, optional ground truth + top-k selector on the right.
    with gr.Row():
        img_in = gr.Image(type="pil", label="Upload Streetview Image")
        with gr.Column():
            lat_in = gr.Number(label="True Latitude (optional)")
            lon_in = gr.Number(label="True Longitude (optional)")
            topk_in = gr.Slider(
                minimum=1,
                maximum=10,
                value=5,
                step=1,
                label="Top K Predictions",
                info="Choose how many top predictions to display (1–10).",
            )

    # Outputs: Markdown table of predictions and the Plotly map.
    out_text = gr.Markdown(label="Predictions (Markdown Table)")
    out_plot = gr.Plot(label="Map")

    run_btn = gr.Button("Run Prediction", variant="primary")
    run_btn.click(
        fn=predict_geohash_map,
        inputs=[img_in, lat_in, lon_in, topk_in],
        outputs=[out_text, out_plot],
    )

if __name__ == "__main__":
    demo.launch()