JohanBeytell committed on
Commit
9283b32
·
verified ·
1 Parent(s): 3eaf89d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +87 -0
app.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch, numpy as np, json
3
+ from PIL import Image
4
+ from transformers import CLIPProcessor, CLIPModel
5
+ import pygeohash as pgh
6
+ import os
7
+
8
# ---------------- Configuration ----------------
# Directory containing the exported artifacts read by this script
# ("metadata.json" and "model.pt").
EXPORT_DIR = "."
# Run on the GPU when torch reports one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Number of highest-scoring geohash classes reported per image.
TOP_K = 3

# ---------------- Load metadata ----------------
# Use a context manager so the file handle is closed deterministically, and
# pin the encoding so parsing does not depend on the host locale.
with open(os.path.join(EXPORT_DIR, "metadata.json"), encoding="utf-8") as _metadata_file:
    metadata = json.load(_metadata_file)

geoh2id = metadata["geoh2id"]
# Inverse lookup: integer class id -> geohash string.
id2geoh = {int(class_id): geoh for geoh, class_id in geoh2id.items()}
country2id = metadata["country2id"]
clip_model_name = metadata["clip_model"]
dim = metadata["embedding_dim"]
num_fine = len(geoh2id)
num_countries = len(country2id)
21
+
22
+ # ---------------- Model definition ----------------
23
+ import torch.nn as nn
24
class GeoHybridModel(nn.Module):
    """Shared MLP trunk feeding three linear heads.

    The trunk maps an ``in_dim`` feature vector through two
    Linear -> ReLU -> Dropout stages (widths ``hidden`` and ``hidden // 2``).
    Three heads read the trunk output:

    * ``classifier``: ``num_classes`` scores,
    * ``regressor``: 2 values,
    * ``country_classifier``: ``num_countries`` scores.

    Args:
        in_dim: Size of the input feature vector.
        num_classes: Output width of ``classifier``.
        num_countries: Output width of ``country_classifier``.
        hidden: Width of the first trunk layer; the second is ``hidden // 2``.
        drop: Dropout probability used after each trunk activation.
    """

    def __init__(self, in_dim, num_classes, num_countries, hidden=1024, drop=0.3):
        super().__init__()
        trunk_width = hidden // 2
        # Layer order (and therefore the Sequential indices 0..5) must stay
        # fixed so saved state_dict keys such as "shared.3.weight" still match.
        trunk_layers = [
            nn.Linear(in_dim, hidden),
            nn.ReLU(),
            nn.Dropout(drop),
            nn.Linear(hidden, trunk_width),
            nn.ReLU(),
            nn.Dropout(drop),
        ]
        self.shared = nn.Sequential(*trunk_layers)
        self.classifier = nn.Linear(trunk_width, num_classes)
        self.regressor = nn.Linear(trunk_width, 2)
        self.country_classifier = nn.Linear(trunk_width, num_countries)

    def forward(self, x):
        """Return ``(classifier_out, regressor_out, country_classifier_out)`` for ``x``."""
        trunk_out = self.shared(x)
        class_out = self.classifier(trunk_out)
        regress_out = self.regressor(trunk_out)
        country_out = self.country_classifier(trunk_out)
        return class_out, regress_out, country_out
38
+
39
# Load weights
# Build the network with the sizes read from the metadata, then restore the
# trained parameters from "model.pt" directly onto the selected device.
# NOTE(review): torch.load may unpickle arbitrary objects depending on the
# installed torch version's default for `weights_only` — confirm model.pt is
# always a trusted file.
_weights_path = os.path.join(EXPORT_DIR, "model.pt")
model = GeoHybridModel(dim, num_fine, num_countries)
_state_dict = torch.load(_weights_path, map_location=DEVICE)
model.load_state_dict(_state_dict)
model.to(DEVICE)
model.eval()
43
+
44
# Load CLIP
# The processor and the model are both loaded from the checkpoint name
# stored in the metadata; the model is moved to DEVICE and put in eval mode.
clip_processor = CLIPProcessor.from_pretrained(clip_model_name)
clip_model = CLIPModel.from_pretrained(clip_model_name)
clip_model = clip_model.to(DEVICE)
clip_model.eval()
47
+
48
+ # ---------------- Haversine ----------------
49
def haversine(lat1, lon1, lat2, lon2):
    """Return the great-circle distance between two points via the haversine formula.

    Coordinates are given in degrees. Scalars and NumPy arrays are both
    accepted (all arithmetic goes through NumPy ufuncs, so array inputs
    broadcast). The result is in the same unit as the sphere radius used
    here, 6371.0 (Earth's mean radius in kilometres).

    Args:
        lat1: Latitude of the first point, in degrees.
        lon1: Longitude of the first point, in degrees.
        lat2: Latitude of the second point, in degrees.
        lon2: Longitude of the second point, in degrees.

    Returns:
        The distance along the sphere's surface, as a NumPy scalar or array.
    """
    earth_radius_km = 6371.0
    phi1, phi2 = np.radians(lat1), np.radians(lat2)
    dphi = np.radians(lat2 - lat1)
    dlambda = np.radians(lon2 - lon1)
    a = np.sin(dphi / 2) ** 2 + np.cos(phi1) * np.cos(phi2) * np.sin(dlambda / 2) ** 2
    # Floating-point rounding can push `a` marginally outside [0, 1] (e.g. for
    # near-antipodal points), which would make sqrt(1 - a) return NaN.
    a = np.clip(a, 0.0, 1.0)
    return 2 * earth_radius_km * np.arctan2(np.sqrt(a), np.sqrt(1 - a))
56
+
57
+ # ---------------- Prediction ----------------
58
def predict_geohash(img: "Image.Image | None") -> str:
    """Predict the TOP_K most likely geohash cells for an image.

    The image is embedded with the CLIP image encoder, the embedding is
    L2-normalised and fed to ``model``. The TOP_K highest-scoring classes are
    mapped back to geohashes via ``id2geoh``; each geohash is decoded and the
    model's two regression outputs are added as an offset scaled by the
    decoded cell extents.

    Args:
        img: The uploaded PIL image, or ``None`` when nothing was uploaded.

    Returns:
        One line per prediction, best first, formatted as
        ``"<geohash> → <lat>,<lon>"`` with five decimals, or a short prompt
        asking for an image when ``img`` is ``None``.
    """
    if img is None:
        # Gradio calls the function with None when the user submits without
        # providing an image; answer with a message instead of crashing
        # inside the CLIP processor.
        return "Please upload an image first."

    c_in = clip_processor(images=img, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        emb = clip_model.get_image_features(**c_in)
        # Unit-normalise the embedding before passing it to the geo model.
        emb = emb / emb.norm(p=2, dim=-1, keepdim=True)
        out_class, out_offset, _ = model(emb)

    # Only one image is processed, so take the single row of each output.
    class_scores = out_class.cpu().numpy()[0]
    lat_offset, lon_offset = out_offset.cpu().numpy()[0]

    # argsort is ascending: keep the last TOP_K entries and reverse them so
    # the highest-scoring class comes first.
    top_ids = class_scores.argsort()[-TOP_K:][::-1]

    preds = []
    for class_id in top_ids:
        geoh = id2geoh[int(class_id)]
        # NOTE(review): decode_exactly is expected to return
        # (lat, lon, lat_err, lon_err); the last two scale the predicted
        # offset here — confirm this matches how offsets were built in training.
        lat_base, lon_base, cell_lat, cell_lon = pgh.decode_exactly(geoh)
        lat_pred = lat_base + lat_offset * cell_lat
        lon_pred = lon_base + lon_offset * cell_lon
        preds.append(f"{geoh} → {lat_pred:.5f},{lon_pred:.5f}")
    return "\n".join(preds)
76
+
77
+ # ---------------- Gradio UI ----------------
78
# One PIL image in, one block of text out.
_image_input = gr.Image(type="pil")
_text_output = gr.Textbox()

iface = gr.Interface(
    fn=predict_geohash,
    inputs=_image_input,
    outputs=_text_output,
    title="Locus - GeoGuessr Image to Coordinates model",
    description="Upload a streetview image and get top-K predicted geohashes with lat/lon.",
)

# Start the web UI only when the file is executed as a script.
if __name__ == "__main__":
    iface.launch()