AceXRoux committed on
Commit 48499b0 · verified · 1 Parent(s): e112c49

Upload 2 files


Adding the main files

Files changed (2)
  1. app.py +309 -0
  2. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,309 @@
+ #!/usr/bin/env python3
+ """
+ GeoVLM - AI-Powered Geolocation
+ Upload any image and predict where it was taken using Vision-Language Models
+ """
+
+ import gradio as gr
+ from PIL import Image
+ from transformers import AutoProcessor, AutoModelForImageTextToText
+ import torch
+ import re
+ import math
+ from dataclasses import dataclass
+
+ # ============================================================================
+ # Simplified Geolocation Parser (from vlm-gym)
+ # ============================================================================
+
+ @dataclass(frozen=True)
+ class Coords:
+     """Geographic coordinates"""
+     lat: float
+     lon: float
+
+ @dataclass(frozen=True)
+ class ParsedResponse:
+     """Structured model output"""
+     city: str | None
+     region: str | None
+     country: str | None
+     coords: Coords | None
+     raw_text: str
+     format_valid: bool
+
+ PROMPT_TEMPLATE = (
+     "Look at the image and guess the location.\n"
+     "Respond with EXACTLY these 5 lines, no extra text:\n"
+     "City: <city name>\n"
+     "Region: <state or region>\n"
+     "Country: <country name or ISO-2 code>\n"
+     "Latitude: <number between -90 and 90>\n"
+     "Longitude: <number between -180 and 180>\n"
+ )
+
+ KEY_ALIASES = {
+     "city": "city",
+     "country": "country",
+     "region": "region",
+     "state": "region",
+     "province": "region",
+     "latitude": "lat",
+     "lat": "lat",
+     "longitude": "lon",
+     "lon": "lon",
+ }
+
+ def parse_response(text: str) -> ParsedResponse:
+     """Parse structured 5-line format"""
+     parsed = {}
+
+     if not text:
+         return ParsedResponse(None, None, None, None, text, False)
+
+     # Parse key-value lines
+     key_pattern = re.compile(
+         r'^\s*(?:[-*+\u2022]\s*)?(?P<key>[A-Za-z][A-Za-z0-9\s\-/_.]*?)\s*:\s*(?P<value>.+)$'
+     )
+
+     for line in text.splitlines():
+         match = key_pattern.match(line)
+         if not match:
+             continue
+
+         key_raw = match.group("key").strip().lower()
+         key_raw = key_raw.strip("*_`\"' ")
+         key_raw = re.sub(r"\s+", " ", key_raw)
+         canonical = KEY_ALIASES.get(key_raw)
+
+         if canonical is None:
+             continue
+
+         value_raw = match.group("value").strip()
+         value_raw = value_raw.strip("`\"' \t")
+         value_raw = re.sub(r"^[*_`]+", "", value_raw)
+         value_raw = re.sub(r"[*_`]+$", "", value_raw)
+         value_raw = value_raw.strip()
+
+         if canonical in {"city", "region", "country"}:
+             if value_raw and canonical not in parsed:
+                 parsed[canonical] = value_raw
+         elif canonical in {"lat", "lon"}:
+             if canonical not in parsed:
+                 match_num = re.search(r"-?\d+(?:[.,]\d+)?", value_raw)
+                 if match_num:
+                     try:
+                         parsed[canonical] = float(match_num.group(0).replace(",", "."))
+                     except ValueError:
+                         pass
+
+     # Build coords if available
+     coords = None
+     if "lat" in parsed and "lon" in parsed:
+         try:
+             lat = parsed["lat"]
+             lon = parsed["lon"]
+             if -90 <= lat <= 90 and -180 <= lon <= 180:
+                 coords = Coords(lat=lat, lon=lon)
+         except (ValueError, TypeError):
+             pass
+
+     format_valid = bool(len(parsed) >= 2)
+
+     return ParsedResponse(
+         city=parsed.get("city"),
+         region=parsed.get("region"),
+         country=parsed.get("country"),
+         coords=coords,
+         raw_text=text,
+         format_valid=format_valid,
+     )
+
+ # ============================================================================
+ # Model Setup
+ # ============================================================================
+
+ model = None
+ processor = None
+ MODEL_NAME = "Qwen/Qwen2-VL-2B-Instruct"
+
+ def load_model():
+     """Load model once on startup"""
+     global model, processor
+     if model is None:
+         print(f"Loading model: {MODEL_NAME}")
+         processor = AutoProcessor.from_pretrained(MODEL_NAME)
+         model = AutoModelForImageTextToText.from_pretrained(
+             MODEL_NAME,
+             torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+             device_map="auto" if torch.cuda.is_available() else "cpu"
+         )
+         print("Model loaded successfully!")
+
+ def predict_location(image):
+     """Predict geolocation from an image"""
+     if image is None:
+         return "Please upload an image.", ""
+
+     # Ensure model is loaded
+     load_model()
+
+     # Convert to PIL if needed
+     if not isinstance(image, Image.Image):
+         image = Image.fromarray(image).convert("RGB")
+     else:
+         image = image.convert("RGB")
+
+     # Prepare prompt
+     messages = [
+         {
+             "role": "user",
+             "content": [
+                 {"type": "image"},
+                 {"type": "text", "text": PROMPT_TEMPLATE}
+             ]
+         }
+     ]
+
+     # Process inputs
+     text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+     inputs = processor(text=[text], images=[image], return_tensors="pt", padding=True)
+
+     # Move to device
+     inputs = {k: v.to(model.device) for k, v in inputs.items()}
+
+     # Generate
+     with torch.no_grad():
+         output_ids = model.generate(
+             **inputs,
+             max_new_tokens=256,
+             do_sample=False,
+         )
+
+     # Decode
+     generated_ids = output_ids[0][inputs['input_ids'].shape[1]:]
+     response = processor.decode(generated_ids, skip_special_tokens=True).strip()
+
+     # Parse
+     parsed = parse_response(response)
+
+     # Format output
+     output = f"""
+ ## 🤖 Raw Model Response:
+ ```
+ {response}
+ ```
+
+ ---
+
+ ## 📍 Parsed Prediction:
+
+ **City:** {parsed.city or "Not provided"}
+ **Region:** {parsed.region or "Not provided"}
+ **Country:** {parsed.country or "Not provided"}
+ **Coordinates:** {f"{parsed.coords.lat:.6f}, {parsed.coords.lon:.6f}" if parsed.coords else "Not provided"}
+ **Format Valid:** {"✅ Yes" if parsed.format_valid else "❌ No"}
+ """
+
+     # Create map embed
+     map_html = ""
+     if parsed.coords:
+         map_html = f"""
+         <div style="margin-top: 20px;">
+             <iframe
+                 width="100%"
+                 height="450"
+                 frameborder="0"
+                 scrolling="no"
+                 marginheight="0"
+                 marginwidth="0"
+                 src="https://www.openstreetmap.org/export/embed.html?bbox={parsed.coords.lon-0.1},{parsed.coords.lat-0.1},{parsed.coords.lon+0.1},{parsed.coords.lat+0.1}&marker={parsed.coords.lat},{parsed.coords.lon}"
+                 style="border: 2px solid #ddd; border-radius: 8px;">
+             </iframe>
+             <div style="margin-top: 10px; text-align: center;">
+                 <a href="https://www.google.com/maps?q={parsed.coords.lat},{parsed.coords.lon}" target="_blank" style="margin: 0 10px; color: #4285f4; text-decoration: none; font-weight: bold;">
+                     🗺️ View on Google Maps
+                 </a>
+                 <span style="color: #666;">|</span>
+                 <a href="https://www.openstreetmap.org/?mlat={parsed.coords.lat}&mlon={parsed.coords.lon}#map=12/{parsed.coords.lat}/{parsed.coords.lon}" target="_blank" style="margin: 0 10px; color: #7ebc6f; text-decoration: none; font-weight: bold;">
+                     🌍 View on OpenStreetMap
+                 </a>
+             </div>
+         </div>
+         """
+     else:
+         map_html = "<div style='text-align: center; padding: 20px; color: #666;'>No valid coordinates found</div>"
+
+     return output, map_html
+
+ # ============================================================================
+ # Gradio Interface
+ # ============================================================================
+
+ with gr.Blocks(title="GeoVLM - AI Geolocation", theme=gr.themes.Soft()) as demo:
+     gr.Markdown(
+         """
+ # 🌍 GeoVLM - AI-Powered Geolocation
+
+ Upload any image and let AI predict where it was taken using vision-language models!
+
+ ### How it works:
+ - Analyzes visual features: architecture, vegetation, road signs, landscape
+ - Uses state-of-the-art vision-language models (Qwen2-VL)
+ - Predicts city, region, country, and GPS coordinates
+
+ **Powered by [vlm-gym](https://github.com/sdan/vlm-gym)** | Model: Qwen2-VL-2B-Instruct
+         """
+     )
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             image_input = gr.Image(
+                 type="pil",
+                 label="📸 Upload Image",
+                 height=400
+             )
+             predict_btn = gr.Button("🔍 Predict Location", variant="primary", size="lg")
+
+             gr.Markdown(
+                 """
+ ### 💡 Tips:
+ - Outdoor images work best
+ - Street views are ideal
+ - Clear photos with visible landmarks
+ - Unique architectural or natural features help
+                 """
+             )
+
+         with gr.Column(scale=1):
+             output_text = gr.Markdown(label="Results")
+             map_output = gr.HTML(label="Map")
+
+     gr.Markdown(
+         """
+ ---
+ ### 🎯 Use Cases:
+ - **OSINT Research** - Verify photo locations for investigations
+ - **GeoGuessr Training** - Practice location identification
+ - **Education** - Learn about geographic features and cultures
+ - **Travel Planning** - Identify interesting locations from photos
+
+ ---
+
+ **Note:** This is a demo. Predictions may not always be accurate. Use responsibly for educational and research purposes.
+
+ Built with ❤️ using [Gradio](https://gradio.app) and [Hugging Face Transformers](https://huggingface.co/transformers)
+         """
+     )
+
+     # Event handlers
+     predict_btn.click(
+         fn=predict_location,
+         inputs=image_input,
+         outputs=[output_text, map_output]
+     )
+
+ if __name__ == "__main__":
+     print("🚀 Starting GeoVLM...")
+     load_model()
+     demo.launch()
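
For a quick sanity check of the parser without downloading the model, a minimal sketch of exercising parse_response on a hand-written model reply (illustrative only, not part of the committed files; it assumes app.py is on the import path and its dependencies are installed, and the sample location values are made up):

# Illustrative sketch -- not part of app.py.
# Assumes app.py is importable (its top-level imports need gradio, torch,
# transformers and Pillow installed); importing it does not load the model.
from app import parse_response

sample_reply = (
    "City: Lisbon\n"
    "Region: Lisbon District\n"
    "Country: PT\n"
    "Latitude: 38.7223\n"
    "Longitude: -9.1393\n"
)

parsed = parse_response(sample_reply)
print(parsed.city, parsed.region, parsed.country)  # Lisbon Lisbon District PT
print(parsed.coords)                               # Coords(lat=38.7223, lon=-9.1393)
print(parsed.format_valid)                         # True
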
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ gradio
+ torch
+ torchvision
+ transformers
+ accelerate
+ Pillow
+ qwen-vl-utils
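
With these dependencies installed (for example via pip install -r requirements.txt), the __main__ block in app.py suggests the demo can also be started locally with python app.py, which pre-loads the model and then launches the Gradio interface; on the Hub, the Space runtime installs requirements.txt and serves app.py automatically.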