Spaces:
Sleeping
Sleeping
File size: 7,332 Bytes
9283b32 855fdc8 9283b32 7f9817c 855fdc8 9283b32 894aed8 9283b32 855fdc8 9283b32 855fdc8 9283b32 7f9817c 855fdc8 894aed8 9283b32 894aed8 a560523 963878a 2bb8939 9283b32 894aed8 a560523 894aed8 7f9817c 894aed8 855fdc8 894aed8 855fdc8 894aed8 7f9817c 894aed8 7f9817c 894aed8 855fdc8 894aed8 855fdc8 894aed8 855fdc8 894aed8 855fdc8 7f9817c 855fdc8 894aed8 7f9817c 963878a 7f9817c 9283b32 855fdc8 b44aa42 4c1811b 545aba6 4c1811b b44aa42 855fdc8 894aed8 855fdc8 a58a7bc 894aed8 9283b32 855fdc8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 |
import gradio as gr
import torch, numpy as np, json, os
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
import pygeohash as pgh
import plotly.graph_objects as go
import torch.nn as nn
EXPORT_DIR = "."
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# ---------------- Load metadata ----------------
# Read the exported metadata inside a context manager so the file handle is
# closed deterministically, and decode explicitly as UTF-8 (the platform
# default encoding may differ, e.g. on Windows).
with open(os.path.join(EXPORT_DIR, "metadata.json"), encoding="utf-8") as _metadata_file:
    metadata = json.load(_metadata_file)
# geohash string -> class index of the fine-grained classifier head.
geoh2id = metadata["geoh2id"]
# Inverse lookup (class index -> geohash string). int() normalises the stored
# ids; presumably they may be serialised as non-int values — confirm with the
# exporter.
id2geoh = {int(v): k for k, v in geoh2id.items()}
# Only its size is used here: it fixes the width of the country head.
country2id = metadata["country2id"]
# Name of the CLIP checkpoint used to embed images (passed to from_pretrained).
clip_model_name = metadata["clip_model"]
# Width of the CLIP image embedding fed into GeoHybridModel.
dim = metadata["embedding_dim"]
num_fine = len(geoh2id)
num_countries = len(country2id)
# ---------------- Model ----------------
class GeoHybridModel(nn.Module):
    """Multi-head MLP applied to (CLIP) image embeddings.

    A shared two-layer trunk (Linear -> ReLU -> Dropout, twice) feeds three
    linear heads:

    * ``classifier``         -- logits over fine-grained geohash cells,
    * ``regressor``          -- two raw outputs (the app scales them into a
                                lat/lon offset inside the predicted cell),
    * ``country_classifier`` -- logits over countries.

    The attribute names and the ``nn.Sequential`` layer order must stay as
    they are: they define the ``state_dict`` keys loaded from ``model.pt``.
    """

    def __init__(self, in_dim, num_classes, num_countries, hidden=1024, drop=0.3):
        super().__init__()
        half = hidden // 2
        trunk_layers = [
            nn.Linear(in_dim, hidden),
            nn.ReLU(),
            nn.Dropout(drop),
            nn.Linear(hidden, half),
            nn.ReLU(),
            nn.Dropout(drop),
        ]
        self.shared = nn.Sequential(*trunk_layers)
        self.classifier = nn.Linear(half, num_classes)
        self.regressor = nn.Linear(half, 2)
        self.country_classifier = nn.Linear(half, num_countries)

    def forward(self, x):
        """Return ``(cell_logits, offset, country_logits)`` for the batch ``x``."""
        features = self.shared(x)
        cell_logits = self.classifier(features)
        offset = self.regressor(features)
        country_logits = self.country_classifier(features)
        return cell_logits, offset, country_logits
model = GeoHybridModel(dim, num_fine, num_countries)
# model.pt only needs to contain a state_dict (it goes straight into
# load_state_dict), so restrict unpickling to tensors and plain containers:
# weights_only=True prevents arbitrary code execution from a tampered file.
_state_dict = torch.load(
    os.path.join(EXPORT_DIR, "model.pt"),
    map_location=DEVICE,
    weights_only=True,
)
model.load_state_dict(_state_dict)
model.to(DEVICE).eval()
# ---------------- CLIP ----------------
# The checkpoint name comes from metadata.json; presumably it must match the
# model that produced the training embeddings — confirm with the exporter.
clip_processor = CLIPProcessor.from_pretrained(clip_model_name)
clip_model = CLIPModel.from_pretrained(clip_model_name)
# nn.Module.to / .eval act in place, so no re-binding is needed.
clip_model.to(DEVICE)
clip_model.eval()
# ---------------- Haversine ----------------
def haversine(lat1, lon1, lat2, lon2, radius=6371.0):
    """Great-circle distance between two points given in decimal degrees.

    Works element-wise on scalars or NumPy arrays (normal broadcasting rules).

    Args:
        lat1, lon1: latitude / longitude of the first point, in degrees.
        lat2, lon2: latitude / longitude of the second point, in degrees.
        radius: sphere radius. The default 6371.0 is Earth's mean radius in
            kilometres, so results are in km unless overridden.

    Returns:
        The distance, in the same unit as ``radius``.
    """
    phi1 = np.radians(lat1)
    phi2 = np.radians(lat2)
    dphi = np.radians(lat2 - lat1)
    dlambda = np.radians(lon2 - lon1)
    a = np.sin(dphi / 2) ** 2 + np.cos(phi1) * np.cos(phi2) * np.sin(dlambda / 2) ** 2
    # Floating-point rounding can push ``a`` marginally above 1 for
    # (near-)antipodal points, which would make sqrt(1 - a) NaN; clamp it.
    a = np.clip(a, 0.0, 1.0)
    return 2 * radius * np.arctan2(np.sqrt(a), np.sqrt(1 - a))
# ---------------- Prediction + Map ----------------
def _distance_badge(dist_km):
    """Format a distance as '<marker> X.X km' (green <100, yellow <1000, red otherwise)."""
    if dist_km < 100:
        marker = "🟢"
    elif dist_km < 1000:
        marker = "🟡"
    else:
        marker = "🔴"
    return f"{marker} {dist_km:.1f} km"


def _build_map(lats, lons, true_lat=None, true_lon=None):
    """Build an OpenStreetMap Plotly figure of the predicted points.

    When both ``true_lat`` and ``true_lon`` are given, the true location is
    drawn in red, joined to every prediction by a green line, and used as the
    map centre; otherwise the map is centred on the top-1 prediction.
    ``lats``/``lons`` must be non-empty.
    """
    has_truth = true_lat is not None and true_lon is not None
    fig = go.Figure()
    # Predictions
    fig.add_trace(go.Scattermapbox(
        lat=lats, lon=lons, mode="markers+text",
        text=[f"Top-{i+1}" for i in range(len(lats))],
        textposition="top right",
        marker=go.scattermapbox.Marker(size=12, color="blue"),
        name="Predictions", showlegend=False
    ))
    if has_truth:
        # True location
        fig.add_trace(go.Scattermapbox(
            lat=[true_lat], lon=[true_lon], mode="markers+text",
            text=["True Location"], textposition="bottom right",
            marker=go.scattermapbox.Marker(size=14, color="red"),
            name="True Location", showlegend=False
        ))
        # Connect the true location to each prediction
        for lat_p, lon_p in zip(lats, lons):
            fig.add_trace(go.Scattermapbox(
                lat=[true_lat, lat_p],
                lon=[true_lon, lon_p],
                mode="lines",
                line=dict(width=2, color="green"),
                showlegend=False
            ))
    center_lat, center_lon = (true_lat, true_lon) if has_truth else (lats[0], lons[0])
    fig.update_layout(
        mapbox_style="open-street-map",
        hovermode="closest",
        mapbox=dict(center=go.layout.mapbox.Center(lat=center_lat, lon=center_lon), zoom=4),
        margin={"r": 0, "t": 0, "l": 0, "b": 0},
        showlegend=False
    )
    return fig


def predict_geohash_map(img: Image.Image, true_lat=None, true_lon=None, top_k=5):
    """Predict the top-k locations for an image and plot them on a map.

    Args:
        img: PIL image from the Gradio upload (None when nothing was uploaded).
        true_lat, true_lon: optional ground-truth coordinates in degrees. The
            Distance column and the map overlay are filled only when both are
            provided.
        top_k: number of geohash cells to report; coerced to int, minimum 1.

    Returns:
        ``(markdown_table, plotly_figure)``.

    Raises:
        gr.Error: if no image was supplied.
    """
    if img is None:
        raise gr.Error("Please upload an image first.")
    has_truth = true_lat is not None and true_lon is not None
    # Gradio sliders may deliver floats, which cannot be used as slice bounds,
    # and argsort()[-0:] would select *every* class, so clamp to >= 1.
    k = max(1, int(top_k))

    c_in = clip_processor(images=img, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        emb = clip_model.get_image_features(**c_in)
        # L2-normalise the CLIP embedding before the geo head.
        emb = emb / emb.norm(p=2, dim=-1, keepdim=True)
        out_class, out_offset, _ = model(emb)
    logits = out_class.cpu().numpy()[0]
    offset = out_offset.cpu().numpy()[0]
    topk_idx = logits.argsort()[-k:][::-1]

    lats, lons = [], []
    # Markdown table header
    preds_text = ["| Rank | Latitude | Longitude | Distance |",
                  "|------|----------|-----------|----------|"]
    for rank, cls_id in enumerate(topk_idx, 1):
        geoh = id2geoh[int(cls_id)]
        # decode_exactly yields the cell centre plus its lat/lon error margins;
        # the regression head's two outputs scale those margins into an offset.
        lat_base, lon_base, cell_lat, cell_lon = pgh.decode_exactly(geoh)
        lat_pred = float(lat_base + offset[0] * cell_lat)
        lon_pred = float(lon_base + offset[1] * cell_lon)
        # The regression head is unbounded: keep the point on the globe.
        lat_pred = min(max(lat_pred, -90.0), 90.0)
        if not -180.0 <= lon_pred <= 180.0:
            lon_pred = (lon_pred + 180.0) % 360.0 - 180.0
        dist_str = ""
        if has_truth:
            dist_str = _distance_badge(haversine(true_lat, true_lon, lat_pred, lon_pred))
        preds_text.append(
            f"| {rank} | {lat_pred:.5f} | {lon_pred:.5f} | {dist_str} |"
        )
        lats.append(lat_pred)
        lons.append(lon_pred)

    fig = _build_map(
        lats, lons,
        true_lat if has_truth else None,
        true_lon if has_truth else None,
    )
    return "\n".join(preds_text), fig
# ---------------- Gradio UI ----------------
# Introductory text shown at the top of the page (kept verbatim).
_INTRO_MD = """
# 🌍 Locus - GeoGuessr Image to Coordinates Model
Upload a **streetview-style image** (photo or screenshot).
The model can handle images with a small UI overlay (like from Google Street View),
as long as the **main scene is clearly visible**. Locus was trained on roughly 80k+ Google Street View images without textual descriptions.
You can optionally provide the actual coordinates to compare predictions and calculate distances on a world map.
## ⚠️ Important Usage Notice
Locus is developed **solely for research, experimentation, and educational exploration**.
It is intended for tasks such as:
- Studying computer vision and geolocation models
- Building games or challenges similar to GeoGuessr
- Learning about machine learning workflows and dataset training
You may **not** use Locus for any illegal, unethical, or harmful purposes, including but not limited to:
- Doxing (revealing private or personal information about individuals)
- Stalking, harassment, or invasion of privacy
- Law enforcement or military targeting
- Commercial exploitation without consent
- Any activity that violates local, national, or international laws
By using Locus, you agree to respect these boundaries and acknowledge that the model is a **research prototype**, not a tool for real-world deployment.
"""

with gr.Blocks() as demo:
    gr.Markdown(_INTRO_MD)

    # Inputs: image on the left, optional ground truth + top-k on the right.
    with gr.Row():
        img_in = gr.Image(type="pil", label="Upload Streetview Image")
        with gr.Column():
            lat_in = gr.Number(label="True Latitude (optional)")
            lon_in = gr.Number(label="True Longitude (optional)")
            topk_in = gr.Slider(
                minimum=1,
                maximum=10,
                value=5,
                step=1,
                label="Top K Predictions",
                info="Choose how many top predictions to display (1–10).",
            )

    # Outputs and trigger.
    out_text = gr.Markdown(label="Predictions (Markdown Table)")
    out_plot = gr.Plot(label="Map")
    run_btn = gr.Button("Run Prediction", variant="primary")
    run_btn.click(
        fn=predict_geohash_map,
        inputs=[img_in, lat_in, lon_in, topk_in],
        outputs=[out_text, out_plot],
    )
# Start the Gradio server only when run as a script (not on import).
if __name__ == "__main__":
    demo.launch()