JohanBeytell committed on
Commit
2bb8939
·
verified ·
1 Parent(s): aa114d6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -11
app.py CHANGED
@@ -3,7 +3,10 @@ import torch, numpy as np, json
3
  from PIL import Image
4
  from transformers import CLIPProcessor, CLIPModel
5
  import pygeohash as pgh
 
6
  import os
 
 
7
 
8
  EXPORT_DIR = "."
9
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
@@ -36,7 +39,7 @@ class GeoHybridModel(nn.Module):
36
  feat = self.shared(x)
37
  return self.classifier(feat), self.regressor(feat), self.country_classifier(feat)
38
 
39
- # Load weights
40
  model = GeoHybridModel(dim, num_fine, num_countries)
41
  model.load_state_dict(torch.load(os.path.join(EXPORT_DIR, "model.pt"), map_location=DEVICE))
42
  model.to(DEVICE).eval()
@@ -54,8 +57,8 @@ def haversine(lat1, lon1, lat2, lon2):
54
  a = np.sin(dphi/2)**2 + np.cos(phi1)*np.cos(phi2)*np.sin(dlambda/2)**2
55
  return 2*R*np.arctan2(np.sqrt(a), np.sqrt(1-a))
56
 
57
- # ---------------- Prediction ----------------
58
- def predict_geohash(img: Image.Image):
59
  c_in = clip_processor(images=img, return_tensors="pt").to(DEVICE)
60
  with torch.no_grad():
61
  emb = clip_model.get_image_features(**c_in)
@@ -65,22 +68,40 @@ def predict_geohash(img: Image.Image):
65
  out_offset_np = out_offset.cpu().numpy()[0]
66
 
67
  topk_idx = out_class_np.argsort()[-TOP_K:][::-1]
68
- preds = []
69
- for i in topk_idx:
 
 
 
70
  geoh = id2geoh[i]
71
  lat_base, lon_base, cell_lat, cell_lon = pgh.decode_exactly(geoh)
72
  lat_pred = lat_base + out_offset_np[0]*cell_lat
73
  lon_pred = lon_base + out_offset_np[1]*cell_lon
74
- preds.append(f"{geoh} {lat_pred:.5f},{lon_pred:.5f}")
75
- return "\n".join(preds)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
  # ---------------- Gradio UI ----------------
78
  iface = gr.Interface(
79
- fn=predict_geohash,
80
  inputs=gr.Image(type="pil"),
81
- outputs=gr.Textbox(),
82
- title="Locus - GeoGuessr Image to Coordinates model",
83
- description="Upload a streetview image and get top-K predicted geohashes with lat/lon."
84
  )
85
 
86
  if __name__ == "__main__":
 
3
  from PIL import Image
4
  from transformers import CLIPProcessor, CLIPModel
5
  import pygeohash as pgh
6
+ import folium
7
  import os
8
+ from io import BytesIO
9
+ import base64
10
 
11
  EXPORT_DIR = "."
12
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
39
  feat = self.shared(x)
40
  return self.classifier(feat), self.regressor(feat), self.country_classifier(feat)
41
 
42
+ # Load model weights
43
  model = GeoHybridModel(dim, num_fine, num_countries)
44
  model.load_state_dict(torch.load(os.path.join(EXPORT_DIR, "model.pt"), map_location=DEVICE))
45
  model.to(DEVICE).eval()
 
57
  a = np.sin(dphi/2)**2 + np.cos(phi1)*np.cos(phi2)*np.sin(dlambda/2)**2
58
  return 2*R*np.arctan2(np.sqrt(a), np.sqrt(1-a))
59
 
60
+ # ---------------- Prediction + map ----------------
61
+ def predict_geohash_map(img: Image.Image):
62
  c_in = clip_processor(images=img, return_tensors="pt").to(DEVICE)
63
  with torch.no_grad():
64
  emb = clip_model.get_image_features(**c_in)
 
68
  out_offset_np = out_offset.cpu().numpy()[0]
69
 
70
  topk_idx = out_class_np.argsort()[-TOP_K:][::-1]
71
+ preds_text = []
72
+ map_center = None
73
+ fmap = folium.Map(tiles="OpenStreetMap", zoom_start=2)
74
+
75
+ for rank, i in enumerate(topk_idx, 1):
76
  geoh = id2geoh[i]
77
  lat_base, lon_base, cell_lat, cell_lon = pgh.decode_exactly(geoh)
78
  lat_pred = lat_base + out_offset_np[0]*cell_lat
79
  lon_pred = lon_base + out_offset_np[1]*cell_lon
80
+ if map_center is None:
81
+ map_center = [lat_pred, lon_pred]
82
+ fmap.location = map_center
83
+ fmap.zoom_start = 6
84
+ preds_text.append(f"{rank}. {geoh} → {lat_pred:.5f},{lon_pred:.5f}")
85
+ folium.Marker(
86
+ location=[lat_pred, lon_pred],
87
+ popup=f"Top-{rank}: {geoh}",
88
+ icon=folium.Icon(color="blue" if rank==1 else "green")
89
+ ).add_to(fmap)
90
+
91
+ # Render the folium map to a standalone HTML document string (no iframe wrapper is built here)
92
+ fmap_file = BytesIO()
93
+ fmap.save(fmap_file, close_file=False)
94
+ fmap_html = fmap_file.getvalue().decode()
95
+
96
+ return "\n".join(preds_text), fmap_html
97
 
98
  # ---------------- Gradio UI ----------------
99
  iface = gr.Interface(
100
+ fn=predict_geohash_map,
101
  inputs=gr.Image(type="pil"),
102
+ outputs=[gr.Textbox(label="Top-K Geohash Predictions"), gr.HTML(label="Map")],
103
+ title="GeoGuessr CLIP Top-K Predictor",
104
+ description="Upload a streetview image and see top-K predicted geohashes and map locations."
105
  )
106
 
107
  if __name__ == "__main__":