Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -5,15 +5,16 @@ from transformers import CLIPProcessor, CLIPModel
|
|
| 5 |
import pygeohash as pgh
|
| 6 |
import plotly.graph_objects as go
|
| 7 |
import torch.nn as nn
|
|
|
|
|
|
|
| 8 |
|
| 9 |
EXPORT_DIR = "."
|
| 10 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 11 |
-
TOP_K = 5
|
| 12 |
|
| 13 |
# ---------------- Load metadata ----------------
|
| 14 |
metadata = json.load(open(os.path.join(EXPORT_DIR, "metadata.json")))
|
| 15 |
geoh2id = metadata["geoh2id"]
|
| 16 |
-
id2geoh = {int(v): k for k,v in geoh2id.items()}
|
| 17 |
country2id = metadata["country2id"]
|
| 18 |
clip_model_name = metadata["clip_model"]
|
| 19 |
dim = metadata["embedding_dim"]
|
|
@@ -53,8 +54,18 @@ def haversine(lat1, lon1, lat2, lon2):
|
|
| 53 |
a = np.sin(dphi/2)**2 + np.cos(phi1)*np.cos(phi2)*np.sin(dlambda/2)**2
|
| 54 |
return 2*R*np.arctan2(np.sqrt(a), np.sqrt(1-a))
|
| 55 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
# ---------------- Prediction + Map ----------------
|
| 57 |
-
def predict_geohash_map(img: Image.Image, true_lat=None, true_lon=None):
|
| 58 |
c_in = clip_processor(images=img, return_tensors="pt").to(DEVICE)
|
| 59 |
with torch.no_grad():
|
| 60 |
emb = clip_model.get_image_features(**c_in)
|
|
@@ -63,52 +74,67 @@ def predict_geohash_map(img: Image.Image, true_lat=None, true_lon=None):
|
|
| 63 |
out_class_np = out_class.cpu().numpy()[0]
|
| 64 |
out_offset_np = out_offset.cpu().numpy()[0]
|
| 65 |
|
| 66 |
-
topk_idx = out_class_np.argsort()[-
|
| 67 |
-
|
| 68 |
-
|
|
|
|
|
|
|
|
|
|
| 69 |
|
| 70 |
for rank, i in enumerate(topk_idx, 1):
|
| 71 |
geoh = id2geoh[i]
|
| 72 |
lat_base, lon_base, cell_lat, cell_lon = pgh.decode_exactly(geoh)
|
| 73 |
-
lat_pred = lat_base + out_offset_np[0]*cell_lat
|
| 74 |
-
lon_pred = lon_base + out_offset_np[1]*cell_lon
|
| 75 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
lats.append(lat_pred)
|
| 77 |
lons.append(lon_pred)
|
| 78 |
-
labels.append(f"Top-{rank}: {geoh}")
|
| 79 |
|
| 80 |
-
#
|
| 81 |
fig = go.Figure()
|
| 82 |
|
| 83 |
-
#
|
| 84 |
fig.add_trace(go.Scattermapbox(
|
| 85 |
lat=lats, lon=lons, mode="markers+text",
|
| 86 |
-
text=
|
|
|
|
| 87 |
marker=go.scattermapbox.Marker(size=12, color="blue"),
|
| 88 |
-
name="Predictions"
|
| 89 |
))
|
| 90 |
|
| 91 |
-
# True location
|
| 92 |
if true_lat is not None and true_lon is not None:
|
| 93 |
fig.add_trace(go.Scattermapbox(
|
| 94 |
lat=[true_lat], lon=[true_lon], mode="markers+text",
|
| 95 |
text=["True Location"], textposition="bottom right",
|
| 96 |
marker=go.scattermapbox.Marker(size=14, color="red"),
|
| 97 |
-
name="True Location"
|
| 98 |
))
|
| 99 |
-
#
|
| 100 |
-
for
|
| 101 |
-
dist = haversine(true_lat, true_lon, lat_p, lon_p)
|
| 102 |
fig.add_trace(go.Scattermapbox(
|
| 103 |
lat=[true_lat, lat_p],
|
| 104 |
lon=[true_lon, lon_p],
|
| 105 |
-
mode="lines
|
| 106 |
-
text=[None, f"{dist:.1f} km"],
|
| 107 |
-
textposition="middle right",
|
| 108 |
line=dict(width=2, color="green"),
|
| 109 |
showlegend=False
|
| 110 |
))
|
| 111 |
-
preds_text.append(f"Dist to Top-{rank}: {dist:.1f} km")
|
| 112 |
|
| 113 |
# Layout
|
| 114 |
center_lat = true_lat if true_lat is not None else lats[0]
|
|
@@ -117,7 +143,8 @@ def predict_geohash_map(img: Image.Image, true_lat=None, true_lon=None):
|
|
| 117 |
mapbox_style="open-street-map",
|
| 118 |
hovermode="closest",
|
| 119 |
mapbox=dict(center=go.layout.mapbox.Center(lat=center_lat, lon=center_lon), zoom=4),
|
| 120 |
-
margin={"r":0,"t":0,"l":0,"b":0}
|
|
|
|
| 121 |
)
|
| 122 |
|
| 123 |
return "\n".join(preds_text), fig
|
|
@@ -126,27 +153,30 @@ def predict_geohash_map(img: Image.Image, true_lat=None, true_lon=None):
|
|
| 126 |
with gr.Blocks() as demo:
|
| 127 |
gr.Markdown(
|
| 128 |
"""
|
| 129 |
-
# π Locus - GeoGuessr Image to Coordinates Model
|
| 130 |
Upload a **streetview-style image** (photo or screenshot).
|
| 131 |
The model can handle images with a small UI overlay (like from Google Street View),
|
| 132 |
-
as long as the **main scene is clearly visible**.
|
| 133 |
|
| 134 |
-
You can optionally provide the actual coordinates to compare predictions and calculate distances on a world map.
|
| 135 |
-
|
| 136 |
-
> **Note:**
|
| 137 |
-
> Locus was trained solely on publicly accessible Google Street View images, not on **actual GeoGuessr data**.
|
| 138 |
"""
|
| 139 |
)
|
| 140 |
with gr.Row():
|
| 141 |
img_in = gr.Image(type="pil", label="Upload Streetview Image")
|
| 142 |
with gr.Column():
|
| 143 |
-
lat_in = gr.Number(label="True Latitude (optional)"
|
| 144 |
-
lon_in = gr.Number(label="True Longitude (optional)"
|
| 145 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
out_plot = gr.Plot(label="Map")
|
| 147 |
|
| 148 |
run_btn = gr.Button("Run Prediction", variant="primary")
|
| 149 |
-
run_btn.click(predict_geohash_map,
|
|
|
|
|
|
|
| 150 |
|
| 151 |
if __name__ == "__main__":
|
| 152 |
demo.launch()
|
|
|
|
| 5 |
import pygeohash as pgh
|
| 6 |
import plotly.graph_objects as go
|
| 7 |
import torch.nn as nn
|
| 8 |
+
import reverse_geocoder as rg
|
| 9 |
+
import pycountry # for full country names
|
| 10 |
|
| 11 |
EXPORT_DIR = "."
|
| 12 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
| 13 |
|
| 14 |
# ---------------- Load metadata ----------------
|
| 15 |
metadata = json.load(open(os.path.join(EXPORT_DIR, "metadata.json")))
|
| 16 |
geoh2id = metadata["geoh2id"]
|
| 17 |
+
id2geoh = {int(v): k for k, v in geoh2id.items()}
|
| 18 |
country2id = metadata["country2id"]
|
| 19 |
clip_model_name = metadata["clip_model"]
|
| 20 |
dim = metadata["embedding_dim"]
|
|
|
|
| 54 |
a = np.sin(dphi/2)**2 + np.cos(phi1)*np.cos(phi2)*np.sin(dlambda/2)**2
|
| 55 |
return 2*R*np.arctan2(np.sqrt(a), np.sqrt(1-a))
|
| 56 |
|
| 57 |
+
# ---------------- Country lookup ----------------
|
| 58 |
+
def latlon_to_country(lat, lon):
|
| 59 |
+
try:
|
| 60 |
+
res = rg.search((lat, lon))[0]
|
| 61 |
+
cc = res['cc']
|
| 62 |
+
country = pycountry.countries.get(alpha_2=cc)
|
| 63 |
+
return country.name if country else cc, cc
|
| 64 |
+
except:
|
| 65 |
+
return "Unknown", "??"
|
| 66 |
+
|
| 67 |
# ---------------- Prediction + Map ----------------
|
| 68 |
+
def predict_geohash_map(img: Image.Image, true_lat=None, true_lon=None, top_k=5):
|
| 69 |
c_in = clip_processor(images=img, return_tensors="pt").to(DEVICE)
|
| 70 |
with torch.no_grad():
|
| 71 |
emb = clip_model.get_image_features(**c_in)
|
|
|
|
| 74 |
out_class_np = out_class.cpu().numpy()[0]
|
| 75 |
out_offset_np = out_offset.cpu().numpy()[0]
|
| 76 |
|
| 77 |
+
topk_idx = out_class_np.argsort()[-top_k:][::-1]
|
| 78 |
+
lats, lons = [], []
|
| 79 |
+
|
| 80 |
+
# Markdown table header
|
| 81 |
+
preds_text = ["| Rank | Country | Code | Latitude | Longitude | Distance |",
|
| 82 |
+
"|------|---------|------|----------|-----------|----------|"]
|
| 83 |
|
| 84 |
for rank, i in enumerate(topk_idx, 1):
|
| 85 |
geoh = id2geoh[i]
|
| 86 |
lat_base, lon_base, cell_lat, cell_lon = pgh.decode_exactly(geoh)
|
| 87 |
+
lat_pred = lat_base + out_offset_np[0] * cell_lat
|
| 88 |
+
lon_pred = lon_base + out_offset_np[1] * cell_lon
|
| 89 |
+
|
| 90 |
+
country_name, country_code = latlon_to_country(lat_pred, lon_pred)
|
| 91 |
+
|
| 92 |
+
dist_str = ""
|
| 93 |
+
if true_lat is not None and true_lon is not None:
|
| 94 |
+
dist = haversine(true_lat, true_lon, lat_pred, lon_pred)
|
| 95 |
+
if dist < 100:
|
| 96 |
+
dist_str = f"π’ {dist:.1f} km"
|
| 97 |
+
elif dist < 1000:
|
| 98 |
+
dist_str = f"π‘ {dist:.1f} km"
|
| 99 |
+
else:
|
| 100 |
+
dist_str = f"π΄ {dist:.1f} km"
|
| 101 |
+
|
| 102 |
+
preds_text.append(
|
| 103 |
+
f"| {rank} | {country_name} | {country_code} | {lat_pred:.5f} | {lon_pred:.5f} | {dist_str} |"
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
lats.append(lat_pred)
|
| 107 |
lons.append(lon_pred)
|
|
|
|
| 108 |
|
| 109 |
+
# -------- Map Plot --------
|
| 110 |
fig = go.Figure()
|
| 111 |
|
| 112 |
+
# Predictions
|
| 113 |
fig.add_trace(go.Scattermapbox(
|
| 114 |
lat=lats, lon=lons, mode="markers+text",
|
| 115 |
+
text=[f"Top-{i+1}" for i in range(len(lats))],
|
| 116 |
+
textposition="top right",
|
| 117 |
marker=go.scattermapbox.Marker(size=12, color="blue"),
|
| 118 |
+
name="Predictions", showlegend=False
|
| 119 |
))
|
| 120 |
|
| 121 |
+
# True location (optional)
|
| 122 |
if true_lat is not None and true_lon is not None:
|
| 123 |
fig.add_trace(go.Scattermapbox(
|
| 124 |
lat=[true_lat], lon=[true_lon], mode="markers+text",
|
| 125 |
text=["True Location"], textposition="bottom right",
|
| 126 |
marker=go.scattermapbox.Marker(size=14, color="red"),
|
| 127 |
+
name="True Location", showlegend=False
|
| 128 |
))
|
| 129 |
+
# Connect with lines
|
| 130 |
+
for lat_p, lon_p in zip(lats, lons):
|
|
|
|
| 131 |
fig.add_trace(go.Scattermapbox(
|
| 132 |
lat=[true_lat, lat_p],
|
| 133 |
lon=[true_lon, lon_p],
|
| 134 |
+
mode="lines",
|
|
|
|
|
|
|
| 135 |
line=dict(width=2, color="green"),
|
| 136 |
showlegend=False
|
| 137 |
))
|
|
|
|
| 138 |
|
| 139 |
# Layout
|
| 140 |
center_lat = true_lat if true_lat is not None else lats[0]
|
|
|
|
| 143 |
mapbox_style="open-street-map",
|
| 144 |
hovermode="closest",
|
| 145 |
mapbox=dict(center=go.layout.mapbox.Center(lat=center_lat, lon=center_lon), zoom=4),
|
| 146 |
+
margin={"r": 0, "t": 0, "l": 0, "b": 0},
|
| 147 |
+
showlegend=False
|
| 148 |
)
|
| 149 |
|
| 150 |
return "\n".join(preds_text), fig
|
|
|
|
| 153 |
with gr.Blocks() as demo:
|
| 154 |
gr.Markdown(
|
| 155 |
"""
|
| 156 |
+
# π Locus - GeoGuessr Image to Coordinates Model
|
| 157 |
Upload a **streetview-style image** (photo or screenshot).
|
| 158 |
The model can handle images with a small UI overlay (like from Google Street View),
|
| 159 |
+
as long as the **main scene is clearly visible**.
|
| 160 |
|
| 161 |
+
You can optionally provide the actual coordinates to compare predictions and calculate distances on a world map.
|
|
|
|
|
|
|
|
|
|
| 162 |
"""
|
| 163 |
)
|
| 164 |
with gr.Row():
|
| 165 |
img_in = gr.Image(type="pil", label="Upload Streetview Image")
|
| 166 |
with gr.Column():
|
| 167 |
+
lat_in = gr.Number(label="True Latitude (optional)")
|
| 168 |
+
lon_in = gr.Number(label="True Longitude (optional)")
|
| 169 |
+
topk_in = gr.Slider(1, 10, value=5, step=1,
|
| 170 |
+
label="Top K Predictions",
|
| 171 |
+
info="Choose how many top predictions to display (1β10).")
|
| 172 |
+
|
| 173 |
+
out_text = gr.Markdown(label="Predictions (Markdown Table)")
|
| 174 |
out_plot = gr.Plot(label="Map")
|
| 175 |
|
| 176 |
run_btn = gr.Button("Run Prediction", variant="primary")
|
| 177 |
+
run_btn.click(predict_geohash_map,
|
| 178 |
+
[img_in, lat_in, lon_in, topk_in],
|
| 179 |
+
[out_text, out_plot])
|
| 180 |
|
| 181 |
if __name__ == "__main__":
|
| 182 |
demo.launch()
|