JohanBeytell commited on
Commit
7f9817c
·
verified ·
1 Parent(s): 963878a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -11
app.py CHANGED
@@ -4,6 +4,7 @@ from PIL import Image
4
  from transformers import CLIPProcessor, CLIPModel
5
  import pygeohash as pgh
6
  import os
 
7
 
8
  EXPORT_DIR = "."
9
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
@@ -45,7 +46,16 @@ model.to(DEVICE).eval()
45
  clip_processor = CLIPProcessor.from_pretrained(clip_model_name)
46
  clip_model = CLIPModel.from_pretrained(clip_model_name).to(DEVICE).eval()
47
 
48
- # ---------------- Prediction ----------------
 
 
 
 
 
 
 
 
 
49
  def predict_geohash_map(img: Image.Image):
50
  c_in = clip_processor(images=img, return_tensors="pt").to(DEVICE)
51
  with torch.no_grad():
@@ -55,30 +65,48 @@ def predict_geohash_map(img: Image.Image):
55
  out_class_np = out_class.cpu().numpy()[0]
56
  out_offset_np = out_offset.cpu().numpy()[0]
57
 
58
- # Get top-k class predictions
59
  topk_idx = out_class_np.argsort()[-TOP_K:][::-1]
60
  preds_text = []
61
- coords = []
62
 
63
  for rank, i in enumerate(topk_idx, 1):
64
  geoh = id2geoh[i]
65
  lat_base, lon_base, cell_lat, cell_lon = pgh.decode_exactly(geoh)
66
- lat_pred = float(lat_base + out_offset_np[0]*cell_lat)
67
- lon_pred = float(lon_base + out_offset_np[1]*cell_lon)
68
 
69
  preds_text.append(f"{rank}. {geoh} → {lat_pred:.5f},{lon_pred:.5f}")
70
- coords.append([lat_pred, lon_pred])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
- return "\n".join(preds_text), coords
73
 
74
  # ---------------- Gradio UI ----------------
75
  iface = gr.Interface(
76
  fn=predict_geohash_map,
77
  inputs=gr.Image(type="pil"),
78
- outputs=[
79
- gr.Textbox(label="Top-K Geohash Predictions"),
80
- gr.Map(label="Predicted Locations") # <-- Gradio native map
81
- ],
82
  title="GeoGuessr CLIP Top-K Predictor",
83
  description="Upload a streetview image and see top-K predicted geohashes and map locations."
84
  )
 
4
  from transformers import CLIPProcessor, CLIPModel
5
  import pygeohash as pgh
6
  import os
7
+ import plotly.graph_objects as go
8
 
9
  EXPORT_DIR = "."
10
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
46
  clip_processor = CLIPProcessor.from_pretrained(clip_model_name)
47
  clip_model = CLIPModel.from_pretrained(clip_model_name).to(DEVICE).eval()
48
 
49
+ # ---------------- Haversine ----------------
50
+ def haversine(lat1, lon1, lat2, lon2):
51
+ R = 6371.0
52
+ phi1,phi2 = np.radians(lat1), np.radians(lat2)
53
+ dphi = np.radians(lat2-lat1)
54
+ dlambda = np.radians(lon2-lon1)
55
+ a = np.sin(dphi/2)**2 + np.cos(phi1)*np.cos(phi2)*np.sin(dlambda/2)**2
56
+ return 2*R*np.arctan2(np.sqrt(a), np.sqrt(1-a))
57
+
58
+ # ---------------- Prediction + Plotly map ----------------
59
  def predict_geohash_map(img: Image.Image):
60
  c_in = clip_processor(images=img, return_tensors="pt").to(DEVICE)
61
  with torch.no_grad():
 
65
  out_class_np = out_class.cpu().numpy()[0]
66
  out_offset_np = out_offset.cpu().numpy()[0]
67
 
 
68
  topk_idx = out_class_np.argsort()[-TOP_K:][::-1]
69
  preds_text = []
70
+ lats, lons, labels = [], [], []
71
 
72
  for rank, i in enumerate(topk_idx, 1):
73
  geoh = id2geoh[i]
74
  lat_base, lon_base, cell_lat, cell_lon = pgh.decode_exactly(geoh)
75
+ lat_pred = lat_base + out_offset_np[0]*cell_lat
76
+ lon_pred = lon_base + out_offset_np[1]*cell_lon
77
 
78
  preds_text.append(f"{rank}. {geoh} → {lat_pred:.5f},{lon_pred:.5f}")
79
+ lats.append(lat_pred)
80
+ lons.append(lon_pred)
81
+ labels.append(f"Top-{rank}: {geoh}")
82
+
83
+ # Plotly scattermapbox
84
+ fig = go.Figure(go.Scattermapbox(
85
+ lat=lats,
86
+ lon=lons,
87
+ mode="markers+text",
88
+ text=labels,
89
+ textposition="top right",
90
+ marker=go.scattermapbox.Marker(size=12, color="blue"),
91
+ ))
92
+
93
+ fig.update_layout(
94
+ mapbox_style="open-street-map",
95
+ hovermode="closest",
96
+ mapbox=dict(
97
+ center=go.layout.mapbox.Center(lat=lats[0], lon=lons[0]),
98
+ zoom=4
99
+ ),
100
+ margin={"r":0,"t":0,"l":0,"b":0}
101
+ )
102
 
103
+ return "\n".join(preds_text), fig
104
 
105
  # ---------------- Gradio UI ----------------
106
  iface = gr.Interface(
107
  fn=predict_geohash_map,
108
  inputs=gr.Image(type="pil"),
109
+ outputs=[gr.Textbox(label="Top-K Geohash Predictions"), gr.Plot(label="Map")],
 
 
 
110
  title="GeoGuessr CLIP Top-K Predictor",
111
  description="Upload a streetview image and see top-K predicted geohashes and map locations."
112
  )