JohanBeytell commited on
Commit
963878a
·
verified ·
1 Parent(s): a2e2256

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -35
app.py CHANGED
@@ -3,10 +3,7 @@ import torch, numpy as np, json
3
  from PIL import Image
4
  from transformers import CLIPProcessor, CLIPModel
5
  import pygeohash as pgh
6
- import folium
7
  import os
8
- from io import BytesIO
9
- import base64
10
 
11
  EXPORT_DIR = "."
12
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
@@ -48,16 +45,7 @@ model.to(DEVICE).eval()
48
  clip_processor = CLIPProcessor.from_pretrained(clip_model_name)
49
  clip_model = CLIPModel.from_pretrained(clip_model_name).to(DEVICE).eval()
50
 
51
- # ---------------- Haversine ----------------
52
- def haversine(lat1, lon1, lat2, lon2):
53
- R = 6371.0
54
- phi1,phi2 = np.radians(lat1), np.radians(lat2)
55
- dphi = np.radians(lat2-lat1)
56
- dlambda = np.radians(lon2-lon1)
57
- a = np.sin(dphi/2)**2 + np.cos(phi1)*np.cos(phi2)*np.sin(dlambda/2)**2
58
- return 2*R*np.arctan2(np.sqrt(a), np.sqrt(1-a))
59
-
60
- # ---------------- Prediction + map ----------------
61
  def predict_geohash_map(img: Image.Image):
62
  c_in = clip_processor(images=img, return_tensors="pt").to(DEVICE)
63
  with torch.no_grad():
@@ -67,39 +55,30 @@ def predict_geohash_map(img: Image.Image):
67
  out_class_np = out_class.cpu().numpy()[0]
68
  out_offset_np = out_offset.cpu().numpy()[0]
69
 
 
70
  topk_idx = out_class_np.argsort()[-TOP_K:][::-1]
71
  preds_text = []
72
- map_center = None
73
- fmap = folium.Map(tiles="OpenStreetMap", zoom_start=2)
74
-
75
  for rank, i in enumerate(topk_idx, 1):
76
  geoh = id2geoh[i]
77
  lat_base, lon_base, cell_lat, cell_lon = pgh.decode_exactly(geoh)
78
- lat_pred = lat_base + out_offset_np[0]*cell_lat
79
- lon_pred = lon_base + out_offset_np[1]*cell_lon
80
- if map_center is None:
81
- map_center = [lat_pred, lon_pred]
82
- fmap.location = map_center
83
- fmap.zoom_start = 6
84
  preds_text.append(f"{rank}. {geoh} → {lat_pred:.5f},{lon_pred:.5f}")
85
- folium.Marker(
86
- location=[lat_pred, lon_pred],
87
- popup=f"Top-{rank}: {geoh}",
88
- icon=folium.Icon(color="blue" if rank==1 else "green")
89
- ).add_to(fmap)
90
-
91
- # Convert folium map to HTML iframe
92
- fmap_file = BytesIO()
93
- fmap.save(fmap_file, close_file=False)
94
- fmap_html = fmap_file.getvalue().decode()
95
-
96
- return "\n".join(preds_text), fmap_html
97
 
98
  # ---------------- Gradio UI ----------------
99
  iface = gr.Interface(
100
  fn=predict_geohash_map,
101
  inputs=gr.Image(type="pil"),
102
- outputs=[gr.Textbox(label="Top-K Geohash Predictions"), gr.HTML(label="Map")],
 
 
 
103
  title="GeoGuessr CLIP Top-K Predictor",
104
  description="Upload a streetview image and see top-K predicted geohashes and map locations."
105
  )
 
3
  from PIL import Image
4
  from transformers import CLIPProcessor, CLIPModel
5
  import pygeohash as pgh
 
6
  import os
 
 
7
 
8
  EXPORT_DIR = "."
9
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
45
  clip_processor = CLIPProcessor.from_pretrained(clip_model_name)
46
  clip_model = CLIPModel.from_pretrained(clip_model_name).to(DEVICE).eval()
47
 
48
+ # ---------------- Prediction ----------------
 
 
 
 
 
 
 
 
 
49
  def predict_geohash_map(img: Image.Image):
50
  c_in = clip_processor(images=img, return_tensors="pt").to(DEVICE)
51
  with torch.no_grad():
 
55
  out_class_np = out_class.cpu().numpy()[0]
56
  out_offset_np = out_offset.cpu().numpy()[0]
57
 
58
+ # Get top-k class predictions
59
  topk_idx = out_class_np.argsort()[-TOP_K:][::-1]
60
  preds_text = []
61
+ coords = []
62
+
 
63
  for rank, i in enumerate(topk_idx, 1):
64
  geoh = id2geoh[i]
65
  lat_base, lon_base, cell_lat, cell_lon = pgh.decode_exactly(geoh)
66
+ lat_pred = float(lat_base + out_offset_np[0]*cell_lat)
67
+ lon_pred = float(lon_base + out_offset_np[1]*cell_lon)
68
+
 
 
 
69
  preds_text.append(f"{rank}. {geoh} → {lat_pred:.5f},{lon_pred:.5f}")
70
+ coords.append([lat_pred, lon_pred])
71
+
72
+ return "\n".join(preds_text), coords
 
 
 
 
 
 
 
 
 
73
 
74
  # ---------------- Gradio UI ----------------
75
  iface = gr.Interface(
76
  fn=predict_geohash_map,
77
  inputs=gr.Image(type="pil"),
78
+ outputs=[
79
+ gr.Textbox(label="Top-K Geohash Predictions"),
80
+ gr.Map(label="Predicted Locations") # <-- Gradio native map
81
+ ],
82
  title="GeoGuessr CLIP Top-K Predictor",
83
  description="Upload a streetview image and see top-K predicted geohashes and map locations."
84
  )