JohanBeytell committed on
Commit
855fdc8
·
verified ·
1 Parent(s): 60ec685

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -28
app.py CHANGED
@@ -1,10 +1,10 @@
1
  import gradio as gr
2
- import torch, numpy as np, json
3
  from PIL import Image
4
  from transformers import CLIPProcessor, CLIPModel
5
  import pygeohash as pgh
6
- import os
7
  import plotly.graph_objects as go
 
8
 
9
  EXPORT_DIR = "."
10
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
@@ -20,8 +20,7 @@ dim = metadata["embedding_dim"]
20
  num_fine = len(geoh2id)
21
  num_countries = len(country2id)
22
 
23
- # ---------------- Model definition ----------------
24
- import torch.nn as nn
25
  class GeoHybridModel(nn.Module):
26
  def __init__(self, in_dim, num_classes, num_countries, hidden=1024, drop=0.3):
27
  super().__init__()
@@ -37,12 +36,11 @@ class GeoHybridModel(nn.Module):
37
  feat = self.shared(x)
38
  return self.classifier(feat), self.regressor(feat), self.country_classifier(feat)
39
 
40
- # Load model weights
41
  model = GeoHybridModel(dim, num_fine, num_countries)
42
  model.load_state_dict(torch.load(os.path.join(EXPORT_DIR, "model.pt"), map_location=DEVICE))
43
  model.to(DEVICE).eval()
44
 
45
- # Load CLIP
46
  clip_processor = CLIPProcessor.from_pretrained(clip_model_name)
47
  clip_model = CLIPModel.from_pretrained(clip_model_name).to(DEVICE).eval()
48
 
@@ -55,8 +53,8 @@ def haversine(lat1, lon1, lat2, lon2):
55
  a = np.sin(dphi/2)**2 + np.cos(phi1)*np.cos(phi2)*np.sin(dlambda/2)**2
56
  return 2*R*np.arctan2(np.sqrt(a), np.sqrt(1-a))
57
 
58
- # ---------------- Prediction + Plotly map ----------------
59
- def predict_geohash_map(img: Image.Image):
60
  c_in = clip_processor(images=img, return_tensors="pt").to(DEVICE)
61
  with torch.no_grad():
62
  emb = clip_model.get_image_features(**c_in)
@@ -74,42 +72,68 @@ def predict_geohash_map(img: Image.Image):
74
  lat_base, lon_base, cell_lat, cell_lon = pgh.decode_exactly(geoh)
75
  lat_pred = lat_base + out_offset_np[0]*cell_lat
76
  lon_pred = lon_base + out_offset_np[1]*cell_lon
77
-
78
  preds_text.append(f"{rank}. {geoh} → {lat_pred:.5f},{lon_pred:.5f}")
79
  lats.append(lat_pred)
80
  lons.append(lon_pred)
81
  labels.append(f"Top-{rank}: {geoh}")
82
 
83
- # Plotly scattermapbox
84
- fig = go.Figure(go.Scattermapbox(
85
- lat=lats,
86
- lon=lons,
87
- mode="markers+text",
88
- text=labels,
89
- textposition="top right",
90
  marker=go.scattermapbox.Marker(size=12, color="blue"),
 
91
  ))
92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  fig.update_layout(
94
  mapbox_style="open-street-map",
95
  hovermode="closest",
96
- mapbox=dict(
97
- center=go.layout.mapbox.Center(lat=lats[0], lon=lons[0]),
98
- zoom=4
99
- ),
100
  margin={"r":0,"t":0,"l":0,"b":0}
101
  )
102
 
103
  return "\n".join(preds_text), fig
104
 
105
  # ---------------- Gradio UI ----------------
106
- iface = gr.Interface(
107
- fn=predict_geohash_map,
108
- inputs=gr.Image(type="pil"),
109
- outputs=[gr.Textbox(label="Top-K Geohash Predictions"), gr.Plot(label="Map")],
110
- title="GeoGuessr CLIP Top-K Predictor",
111
- description="Upload a streetview image and see top-K predicted geohashes and map locations."
112
- )
 
 
 
 
113
 
114
  if __name__ == "__main__":
115
- iface.launch()
 
1
  import gradio as gr
2
+ import torch, numpy as np, json, os
3
  from PIL import Image
4
  from transformers import CLIPProcessor, CLIPModel
5
  import pygeohash as pgh
 
6
  import plotly.graph_objects as go
7
+ import torch.nn as nn
8
 
9
  EXPORT_DIR = "."
10
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
20
  num_fine = len(geoh2id)
21
  num_countries = len(country2id)
22
 
23
+ # ---------------- Model ----------------
 
24
  class GeoHybridModel(nn.Module):
25
  def __init__(self, in_dim, num_classes, num_countries, hidden=1024, drop=0.3):
26
  super().__init__()
 
36
  feat = self.shared(x)
37
  return self.classifier(feat), self.regressor(feat), self.country_classifier(feat)
38
 
 
39
  model = GeoHybridModel(dim, num_fine, num_countries)
40
  model.load_state_dict(torch.load(os.path.join(EXPORT_DIR, "model.pt"), map_location=DEVICE))
41
  model.to(DEVICE).eval()
42
 
43
+ # ---------------- CLIP ----------------
44
  clip_processor = CLIPProcessor.from_pretrained(clip_model_name)
45
  clip_model = CLIPModel.from_pretrained(clip_model_name).to(DEVICE).eval()
46
 
 
53
  a = np.sin(dphi/2)**2 + np.cos(phi1)*np.cos(phi2)*np.sin(dlambda/2)**2
54
  return 2*R*np.arctan2(np.sqrt(a), np.sqrt(1-a))
55
 
56
+ # ---------------- Prediction + Map ----------------
57
+ def predict_geohash_map(img: Image.Image, true_lat=None, true_lon=None):
58
  c_in = clip_processor(images=img, return_tensors="pt").to(DEVICE)
59
  with torch.no_grad():
60
  emb = clip_model.get_image_features(**c_in)
 
72
  lat_base, lon_base, cell_lat, cell_lon = pgh.decode_exactly(geoh)
73
  lat_pred = lat_base + out_offset_np[0]*cell_lat
74
  lon_pred = lon_base + out_offset_np[1]*cell_lon
 
75
  preds_text.append(f"{rank}. {geoh} → {lat_pred:.5f},{lon_pred:.5f}")
76
  lats.append(lat_pred)
77
  lons.append(lon_pred)
78
  labels.append(f"Top-{rank}: {geoh}")
79
 
80
+ # Base map
81
+ fig = go.Figure()
82
+
83
+ # Prediction markers
84
+ fig.add_trace(go.Scattermapbox(
85
+ lat=lats, lon=lons, mode="markers+text",
86
+ text=labels, textposition="top right",
87
  marker=go.scattermapbox.Marker(size=12, color="blue"),
88
+ name="Predictions"
89
  ))
90
 
91
+ # True location + lines
92
+ if true_lat is not None and true_lon is not None:
93
+ fig.add_trace(go.Scattermapbox(
94
+ lat=[true_lat], lon=[true_lon], mode="markers+text",
95
+ text=["True Location"], textposition="bottom right",
96
+ marker=go.scattermapbox.Marker(size=14, color="red"),
97
+ name="True Location"
98
+ ))
99
+ # Add connecting lines + distance annotations
100
+ for rank, (lat_p, lon_p) in enumerate(zip(lats, lons), 1):
101
+ dist = haversine(true_lat, true_lon, lat_p, lon_p)
102
+ fig.add_trace(go.Scattermapbox(
103
+ lat=[true_lat, lat_p],
104
+ lon=[true_lon, lon_p],
105
+ mode="lines+text",
106
+ text=[None, f"{dist:.1f} km"],
107
+ textposition="middle right",
108
+ line=dict(width=2, color="green"),
109
+ showlegend=False
110
+ ))
111
+ preds_text.append(f"Dist to Top-{rank}: {dist:.1f} km")
112
+
113
+ # Layout
114
+ center_lat = true_lat if true_lat is not None else lats[0]
115
+ center_lon = true_lon if true_lon is not None else lons[0]
116
  fig.update_layout(
117
  mapbox_style="open-street-map",
118
  hovermode="closest",
119
+ mapbox=dict(center=go.layout.mapbox.Center(lat=center_lat, lon=center_lon), zoom=4),
 
 
 
120
  margin={"r":0,"t":0,"l":0,"b":0}
121
  )
122
 
123
  return "\n".join(preds_text), fig
124
 
125
  # ---------------- Gradio UI ----------------
126
+ with gr.Blocks() as demo:
127
+ with gr.Row():
128
+ img_in = gr.Image(type="pil", label="Upload Streetview Image")
129
+ with gr.Column():
130
+ lat_in = gr.Number(label="True Latitude (optional)")
131
+ lon_in = gr.Number(label="True Longitude (optional)")
132
+ out_text = gr.Textbox(label="Predictions + Distances")
133
+ out_plot = gr.Plot(label="Map")
134
+
135
+ run_btn = gr.Button("Run Prediction")
136
+ run_btn.click(predict_geohash_map, [img_in, lat_in, lon_in], [out_text, out_plot])
137
 
138
  if __name__ == "__main__":
139
+ demo.launch()