JohanBeytell commited on
Commit
894aed8
Β·
verified Β·
1 Parent(s): a60b794

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -34
app.py CHANGED
@@ -5,15 +5,16 @@ from transformers import CLIPProcessor, CLIPModel
5
  import pygeohash as pgh
6
  import plotly.graph_objects as go
7
  import torch.nn as nn
 
 
8
 
9
  EXPORT_DIR = "."
10
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
11
- TOP_K = 5
12
 
13
  # ---------------- Load metadata ----------------
14
  metadata = json.load(open(os.path.join(EXPORT_DIR, "metadata.json")))
15
  geoh2id = metadata["geoh2id"]
16
- id2geoh = {int(v): k for k,v in geoh2id.items()}
17
  country2id = metadata["country2id"]
18
  clip_model_name = metadata["clip_model"]
19
  dim = metadata["embedding_dim"]
@@ -53,8 +54,18 @@ def haversine(lat1, lon1, lat2, lon2):
53
  a = np.sin(dphi/2)**2 + np.cos(phi1)*np.cos(phi2)*np.sin(dlambda/2)**2
54
  return 2*R*np.arctan2(np.sqrt(a), np.sqrt(1-a))
55
 
 
 
 
 
 
 
 
 
 
 
56
  # ---------------- Prediction + Map ----------------
57
- def predict_geohash_map(img: Image.Image, true_lat=None, true_lon=None):
58
  c_in = clip_processor(images=img, return_tensors="pt").to(DEVICE)
59
  with torch.no_grad():
60
  emb = clip_model.get_image_features(**c_in)
@@ -63,52 +74,67 @@ def predict_geohash_map(img: Image.Image, true_lat=None, true_lon=None):
63
  out_class_np = out_class.cpu().numpy()[0]
64
  out_offset_np = out_offset.cpu().numpy()[0]
65
 
66
- topk_idx = out_class_np.argsort()[-TOP_K:][::-1]
67
- preds_text = []
68
- lats, lons, labels = [], [], []
 
 
 
69
 
70
  for rank, i in enumerate(topk_idx, 1):
71
  geoh = id2geoh[i]
72
  lat_base, lon_base, cell_lat, cell_lon = pgh.decode_exactly(geoh)
73
- lat_pred = lat_base + out_offset_np[0]*cell_lat
74
- lon_pred = lon_base + out_offset_np[1]*cell_lon
75
- preds_text.append(f"{rank}. {geoh} β†’ {lat_pred:.5f},{lon_pred:.5f}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  lats.append(lat_pred)
77
  lons.append(lon_pred)
78
- labels.append(f"Top-{rank}: {geoh}")
79
 
80
- # Base map
81
  fig = go.Figure()
82
 
83
- # Prediction markers
84
  fig.add_trace(go.Scattermapbox(
85
  lat=lats, lon=lons, mode="markers+text",
86
- text=labels, textposition="top right",
 
87
  marker=go.scattermapbox.Marker(size=12, color="blue"),
88
- name="Predictions"
89
  ))
90
 
91
- # True location + lines
92
  if true_lat is not None and true_lon is not None:
93
  fig.add_trace(go.Scattermapbox(
94
  lat=[true_lat], lon=[true_lon], mode="markers+text",
95
  text=["True Location"], textposition="bottom right",
96
  marker=go.scattermapbox.Marker(size=14, color="red"),
97
- name="True Location"
98
  ))
99
- # Add connecting lines + distance annotations
100
- for rank, (lat_p, lon_p) in enumerate(zip(lats, lons), 1):
101
- dist = haversine(true_lat, true_lon, lat_p, lon_p)
102
  fig.add_trace(go.Scattermapbox(
103
  lat=[true_lat, lat_p],
104
  lon=[true_lon, lon_p],
105
- mode="lines+text",
106
- text=[None, f"{dist:.1f} km"],
107
- textposition="middle right",
108
  line=dict(width=2, color="green"),
109
  showlegend=False
110
  ))
111
- preds_text.append(f"Dist to Top-{rank}: {dist:.1f} km")
112
 
113
  # Layout
114
  center_lat = true_lat if true_lat is not None else lats[0]
@@ -117,7 +143,8 @@ def predict_geohash_map(img: Image.Image, true_lat=None, true_lon=None):
117
  mapbox_style="open-street-map",
118
  hovermode="closest",
119
  mapbox=dict(center=go.layout.mapbox.Center(lat=center_lat, lon=center_lon), zoom=4),
120
- margin={"r":0,"t":0,"l":0,"b":0}
 
121
  )
122
 
123
  return "\n".join(preds_text), fig
@@ -126,27 +153,30 @@ def predict_geohash_map(img: Image.Image, true_lat=None, true_lon=None):
126
  with gr.Blocks() as demo:
127
  gr.Markdown(
128
  """
129
- # 🌍 Locus - GeoGuessr Image to Coordinates Model
130
  Upload a **streetview-style image** (photo or screenshot).
131
  The model can handle images with a small UI overlay (like from Google Street View),
132
- as long as the **main scene is clearly visible**.
133
 
134
- You can optionally provide the actual coordinates to compare predictions and calculate distances on a world map.
135
-
136
- > **Note:**
137
- > Locus was trained solely on publicly accessible Google Street View images, not on **actual GeoGuessr data**.
138
  """
139
  )
140
  with gr.Row():
141
  img_in = gr.Image(type="pil", label="Upload Streetview Image")
142
  with gr.Column():
143
- lat_in = gr.Number(label="True Latitude (optional)", info="Enter the known latitude if available.")
144
- lon_in = gr.Number(label="True Longitude (optional)", info="Enter the known longitude if available.")
145
- out_text = gr.Textbox(label="Predictions + Distances")
 
 
 
 
146
  out_plot = gr.Plot(label="Map")
147
 
148
  run_btn = gr.Button("Run Prediction", variant="primary")
149
- run_btn.click(predict_geohash_map, [img_in, lat_in, lon_in], [out_text, out_plot])
 
 
150
 
151
  if __name__ == "__main__":
152
  demo.launch()
 
5
  import pygeohash as pgh
6
  import plotly.graph_objects as go
7
  import torch.nn as nn
8
+ import reverse_geocoder as rg
9
+ import pycountry # for full country names
10
 
11
  EXPORT_DIR = "."
12
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
13
 
14
  # ---------------- Load metadata ----------------
15
  metadata = json.load(open(os.path.join(EXPORT_DIR, "metadata.json")))
16
  geoh2id = metadata["geoh2id"]
17
+ id2geoh = {int(v): k for k, v in geoh2id.items()}
18
  country2id = metadata["country2id"]
19
  clip_model_name = metadata["clip_model"]
20
  dim = metadata["embedding_dim"]
 
54
  a = np.sin(dphi/2)**2 + np.cos(phi1)*np.cos(phi2)*np.sin(dlambda/2)**2
55
  return 2*R*np.arctan2(np.sqrt(a), np.sqrt(1-a))
56
 
57
+ # ---------------- Country lookup ----------------
58
+ def latlon_to_country(lat, lon):
59
+ try:
60
+ res = rg.search((lat, lon))[0]
61
+ cc = res['cc']
62
+ country = pycountry.countries.get(alpha_2=cc)
63
+ return country.name if country else cc, cc
64
+ except:
65
+ return "Unknown", "??"
66
+
67
  # ---------------- Prediction + Map ----------------
68
+ def predict_geohash_map(img: Image.Image, true_lat=None, true_lon=None, top_k=5):
69
  c_in = clip_processor(images=img, return_tensors="pt").to(DEVICE)
70
  with torch.no_grad():
71
  emb = clip_model.get_image_features(**c_in)
 
74
  out_class_np = out_class.cpu().numpy()[0]
75
  out_offset_np = out_offset.cpu().numpy()[0]
76
 
77
+ topk_idx = out_class_np.argsort()[-top_k:][::-1]
78
+ lats, lons = [], []
79
+
80
+ # Markdown table header
81
+ preds_text = ["| Rank | Country | Code | Latitude | Longitude | Distance |",
82
+ "|------|---------|------|----------|-----------|----------|"]
83
 
84
  for rank, i in enumerate(topk_idx, 1):
85
  geoh = id2geoh[i]
86
  lat_base, lon_base, cell_lat, cell_lon = pgh.decode_exactly(geoh)
87
+ lat_pred = lat_base + out_offset_np[0] * cell_lat
88
+ lon_pred = lon_base + out_offset_np[1] * cell_lon
89
+
90
+ country_name, country_code = latlon_to_country(lat_pred, lon_pred)
91
+
92
+ dist_str = ""
93
+ if true_lat is not None and true_lon is not None:
94
+ dist = haversine(true_lat, true_lon, lat_pred, lon_pred)
95
+ if dist < 100:
96
+ dist_str = f"🟒 {dist:.1f} km"
97
+ elif dist < 1000:
98
+ dist_str = f"🟑 {dist:.1f} km"
99
+ else:
100
+ dist_str = f"πŸ”΄ {dist:.1f} km"
101
+
102
+ preds_text.append(
103
+ f"| {rank} | {country_name} | {country_code} | {lat_pred:.5f} | {lon_pred:.5f} | {dist_str} |"
104
+ )
105
+
106
  lats.append(lat_pred)
107
  lons.append(lon_pred)
 
108
 
109
+ # -------- Map Plot --------
110
  fig = go.Figure()
111
 
112
+ # Predictions
113
  fig.add_trace(go.Scattermapbox(
114
  lat=lats, lon=lons, mode="markers+text",
115
+ text=[f"Top-{i+1}" for i in range(len(lats))],
116
+ textposition="top right",
117
  marker=go.scattermapbox.Marker(size=12, color="blue"),
118
+ name="Predictions", showlegend=False
119
  ))
120
 
121
+ # True location (optional)
122
  if true_lat is not None and true_lon is not None:
123
  fig.add_trace(go.Scattermapbox(
124
  lat=[true_lat], lon=[true_lon], mode="markers+text",
125
  text=["True Location"], textposition="bottom right",
126
  marker=go.scattermapbox.Marker(size=14, color="red"),
127
+ name="True Location", showlegend=False
128
  ))
129
+ # Connect with lines
130
+ for lat_p, lon_p in zip(lats, lons):
 
131
  fig.add_trace(go.Scattermapbox(
132
  lat=[true_lat, lat_p],
133
  lon=[true_lon, lon_p],
134
+ mode="lines",
 
 
135
  line=dict(width=2, color="green"),
136
  showlegend=False
137
  ))
 
138
 
139
  # Layout
140
  center_lat = true_lat if true_lat is not None else lats[0]
 
143
  mapbox_style="open-street-map",
144
  hovermode="closest",
145
  mapbox=dict(center=go.layout.mapbox.Center(lat=center_lat, lon=center_lon), zoom=4),
146
+ margin={"r": 0, "t": 0, "l": 0, "b": 0},
147
+ showlegend=False
148
  )
149
 
150
  return "\n".join(preds_text), fig
 
153
  with gr.Blocks() as demo:
154
  gr.Markdown(
155
  """
156
+ # 🌍 Locus - GeoGuessr Image to Coordinates Model
157
  Upload a **streetview-style image** (photo or screenshot).
158
  The model can handle images with a small UI overlay (like from Google Street View),
159
+ as long as the **main scene is clearly visible**.
160
 
161
+ You can optionally provide the actual coordinates to compare predictions and calculate distances on a world map.
 
 
 
162
  """
163
  )
164
  with gr.Row():
165
  img_in = gr.Image(type="pil", label="Upload Streetview Image")
166
  with gr.Column():
167
+ lat_in = gr.Number(label="True Latitude (optional)")
168
+ lon_in = gr.Number(label="True Longitude (optional)")
169
+ topk_in = gr.Slider(1, 10, value=5, step=1,
170
+ label="Top K Predictions",
171
+ info="Choose how many top predictions to display (1–10).")
172
+
173
+ out_text = gr.Markdown(label="Predictions (Markdown Table)")
174
  out_plot = gr.Plot(label="Map")
175
 
176
  run_btn = gr.Button("Run Prediction", variant="primary")
177
+ run_btn.click(predict_geohash_map,
178
+ [img_in, lat_in, lon_in, topk_in],
179
+ [out_text, out_plot])
180
 
181
  if __name__ == "__main__":
182
  demo.launch()