AceXRoux committed on
Commit
0f4b5d1
·
verified ·
1 Parent(s): 514541f

Update app.py

Browse files

Fixed side extension

Files changed (1) hide show
  1. app.py +112 -125
app.py CHANGED
@@ -6,10 +6,9 @@ Upload any image and predict where it was taken using Vision-Language Models
6
 
7
  import gradio as gr
8
  from PIL import Image
9
- from transformers import AutoProcessor, AutoModelForImageTextToText
10
  import torch
11
  import re
12
- import math
13
  from dataclasses import dataclass
14
 
15
  # ============================================================================
@@ -61,7 +60,6 @@ def parse_response(text: str) -> ParsedResponse:
61
  if not text:
62
  return ParsedResponse(None, None, None, None, text, False)
63
 
64
- # Parse key-value lines
65
  key_pattern = re.compile(
66
  r'^\s*(?:[-*+\u2022]\s*)?(?P<key>[A-Za-z][A-Za-z0-9\s\-/_.]*?)\s*:\s*(?P<value>.+)$'
67
  )
@@ -97,7 +95,6 @@ def parse_response(text: str) -> ParsedResponse:
97
  except ValueError:
98
  pass
99
 
100
- # Build coords if available
101
  coords = None
102
  if "lat" in parsed and "lon" in parsed:
103
  try:
@@ -131,110 +128,114 @@ def load_model():
131
  """Load model once on startup"""
132
  global model, processor
133
  if model is None:
134
- print(f"Loading model: {MODEL_NAME}")
135
- processor = AutoProcessor.from_pretrained(MODEL_NAME)
136
- model = AutoModelForImageTextToText.from_pretrained(
137
- MODEL_NAME,
138
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
139
- device_map="auto" if torch.cuda.is_available() else "cpu"
140
- )
141
- print("Model loaded successfully!")
 
 
 
 
 
142
 
143
  def predict_location(image):
144
  """Predict geolocation from an image"""
145
- if image is None:
146
- return "Please upload an image.", ""
147
-
148
- # Ensure model is loaded
149
- load_model()
150
-
151
- # Convert to PIL if needed
152
- if not isinstance(image, Image.Image):
153
- image = Image.fromarray(image).convert("RGB")
154
- else:
155
- image = image.convert("RGB")
156
-
157
- # Prepare prompt
158
- messages = [
159
- {
160
- "role": "user",
161
- "content": [
162
- {"type": "image"},
163
- {"type": "text", "text": PROMPT_TEMPLATE}
164
- ]
165
- }
166
- ]
167
-
168
- # Process inputs
169
- text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
170
- inputs = processor(text=[text], images=[image], return_tensors="pt", padding=True)
171
-
172
- # Move to device
173
- inputs = {k: v.to(model.device) for k, v in inputs.items()}
174
-
175
- # Generate
176
- with torch.no_grad():
177
- output_ids = model.generate(
178
- **inputs,
179
- max_new_tokens=256,
180
- do_sample=False,
181
- )
182
-
183
- # Decode
184
- generated_ids = output_ids[0][inputs['input_ids'].shape[1]:]
185
- response = processor.decode(generated_ids, skip_special_tokens=True).strip()
186
-
187
- # Parse
188
- parsed = parse_response(response)
189
-
190
- # Format output
191
- output = f"""
192
- ## 🤖 Raw Model Response:
193
- ```
194
- {response}
195
- ```
196
 
197
- ---
 
 
 
 
198
 
199
- ## 📍 Parsed Prediction:
200
 
201
- **City:** {parsed.city or "Not provided"}
202
- **Region:** {parsed.region or "Not provided"}
203
- **Country:** {parsed.country or "Not provided"}
204
- **Coordinates:** {f"{parsed.coords.lat:.6f}, {parsed.coords.lon:.6f}" if parsed.coords else "Not provided"}
205
- **Format Valid:** {"✅ Yes" if parsed.format_valid else "❌ No"}
206
  """
207
-
208
- # Create map embed
209
- map_html = ""
210
- if parsed.coords:
211
- map_html = f"""
212
- <div style="margin-top: 20px;">
213
- <iframe
214
- width="100%"
215
- height="450"
216
- frameborder="0"
217
- scrolling="no"
218
- marginheight="0"
219
- marginwidth="0"
220
- src="https://www.openstreetmap.org/export/embed.html?bbox={parsed.coords.lon-0.1},{parsed.coords.lat-0.1},{parsed.coords.lon+0.1},{parsed.coords.lat+0.1}&marker={parsed.coords.lat},{parsed.coords.lon}"
221
- style="border: 2px solid #ddd; border-radius: 8px;">
222
- </iframe>
223
- <div style="margin-top: 10px; text-align: center;">
224
- <a href="https://www.google.com/maps?q={parsed.coords.lat},{parsed.coords.lon}" target="_blank" style="margin: 0 10px; color: #4285f4; text-decoration: none; font-weight: bold;">
225
- 🗺️ View on Google Maps
226
- </a>
227
- <span style="color: #666;">|</span>
228
- <a href="https://www.openstreetmap.org/?mlat={parsed.coords.lat}&mlon={parsed.coords.lon}#map=12/{parsed.coords.lat}/{parsed.coords.lon}" target="_blank" style="margin: 0 10px; color: #7ebc6f; text-decoration: none; font-weight: bold;">
229
- 🌍 View on OpenStreetMap
230
- </a>
231
  </div>
232
- </div>
233
- """
234
- else:
235
- map_html = "<div style='text-align: center; padding: 20px; color: #666;'>No valid coordinates found</div>"
236
-
237
- return output, map_html
 
 
 
 
 
 
238
 
239
  # ============================================================================
240
  # Gradio Interface
@@ -245,12 +246,7 @@ with gr.Blocks(title="GeoVLM - AI Geolocation", theme=gr.themes.Soft()) as demo:
245
  """
246
  # 🌍 GeoVLM - AI-Powered Geolocation
247
 
248
- Upload any image and let AI predict where it was taken using vision-language models!
249
-
250
- ### How it works:
251
- - Analyzes visual features: architecture, vegetation, road signs, landscape
252
- - Uses state-of-the-art vision-language models (Qwen2-VL)
253
- - Predicts city, region, country, and GPS coordinates
254
 
255
  **Powered by [vlm-gym](https://github.com/sdan/vlm-gym)** | Model: Qwen2-VL-2B-Instruct
256
  """
@@ -258,11 +254,7 @@ with gr.Blocks(title="GeoVLM - AI Geolocation", theme=gr.themes.Soft()) as demo:
258
 
259
  with gr.Row():
260
  with gr.Column(scale=1):
261
- image_input = gr.Image(
262
- type="pil",
263
- label="📸 Upload Image",
264
- height=400
265
- )
266
  predict_btn = gr.Button("🔍 Predict Location", variant="primary", size="lg")
267
 
268
  gr.Markdown(
@@ -276,34 +268,29 @@ with gr.Blocks(title="GeoVLM - AI Geolocation", theme=gr.themes.Soft()) as demo:
276
  )
277
 
278
  with gr.Column(scale=1):
279
- output_text = gr.Markdown(label="Results")
280
- map_output = gr.HTML(label="Map")
281
 
282
  gr.Markdown(
283
  """
284
  ---
285
  ### 🎯 Use Cases:
286
- - **OSINT Research** - Verify photo locations for investigations
287
  - **GeoGuessr Training** - Practice location identification
288
- - **Education** - Learn about geographic features and cultures
289
- - **Travel Planning** - Identify interesting locations from photos
290
 
291
  ---
292
 
293
- **Note:** This is a demo. Predictions may not always be accurate. Use responsibly for educational and research purposes.
294
 
295
- Built with ❤️ using [Gradio](https://gradio.app) and [Hugging Face Transformers](https://huggingface.co/transformers)
296
  """
297
  )
298
 
299
- # Event handlers
300
- predict_btn.click(
301
- fn=predict_location,
302
- inputs=image_input,
303
- outputs=[output_text, map_output]
304
- )
305
 
306
  if __name__ == "__main__":
307
  print("🚀 Starting GeoVLM...")
308
  load_model()
309
- demo.launch()
 
6
 
7
  import gradio as gr
8
  from PIL import Image
9
+ from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
10
  import torch
11
  import re
 
12
  from dataclasses import dataclass
13
 
14
  # ============================================================================
 
60
  if not text:
61
  return ParsedResponse(None, None, None, None, text, False)
62
 
 
63
  key_pattern = re.compile(
64
  r'^\s*(?:[-*+\u2022]\s*)?(?P<key>[A-Za-z][A-Za-z0-9\s\-/_.]*?)\s*:\s*(?P<value>.+)$'
65
  )
 
95
  except ValueError:
96
  pass
97
 
 
98
  coords = None
99
  if "lat" in parsed and "lon" in parsed:
100
  try:
 
128
  """Load model once on startup"""
129
  global model, processor
130
  if model is None:
131
+ print(f"🔄 Loading model: {MODEL_NAME}")
132
+ try:
133
+ processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)
134
+ model = Qwen2VLForConditionalGeneration.from_pretrained(
135
+ MODEL_NAME,
136
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
137
+ device_map="auto",
138
+ trust_remote_code=True
139
+ )
140
+ print("✅ Model loaded successfully!")
141
+ except Exception as e:
142
+ print(f"❌ Error loading model: {e}")
143
+ raise
144
 
145
  def predict_location(image):
146
  """Predict geolocation from an image"""
147
+ try:
148
+ if image is None:
149
+ return "⚠️ Please upload an image.", ""
150
+
151
+ load_model()
152
+
153
+ print("📸 Processing image...")
154
+
155
+ if not isinstance(image, Image.Image):
156
+ image = Image.fromarray(image).convert("RGB")
157
+ else:
158
+ image = image.convert("RGB")
159
+
160
+ messages = [
161
+ {
162
+ "role": "user",
163
+ "content": [
164
+ {"type": "image", "image": image},
165
+ {"type": "text", "text": PROMPT_TEMPLATE}
166
+ ]
167
+ }
168
+ ]
169
+
170
+ text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
171
+ inputs = processor(text=[text], images=[image], return_tensors="pt", padding=True)
172
+ inputs = {k: v.to(model.device) for k, v in inputs.items()}
173
+
174
+ print("🤖 Generating prediction...")
175
+
176
+ with torch.no_grad():
177
+ output_ids = model.generate(**inputs, max_new_tokens=256, do_sample=False)
178
+
179
+ generated_ids = output_ids[0][inputs['input_ids'].shape[1]:]
180
+ response = processor.decode(generated_ids, skip_special_tokens=True).strip()
181
+
182
+ print(f"✅ Response generated")
183
+
184
+ parsed = parse_response(response)
185
+
186
+ output = f"""
187
+ ## 🤖 AI Prediction
 
 
 
 
 
 
 
 
 
 
188
 
189
+ **📍 Location Details:**
190
+ - **City:** {parsed.city or "Unknown"}
191
+ - **Region:** {parsed.region or "Unknown"}
192
+ - **Country:** {parsed.country or "Unknown"}
193
+ - **Coordinates:** {f"{parsed.coords.lat:.6f}°, {parsed.coords.lon:.6f}°" if parsed.coords else "Not found"}
194
 
195
+ ---
196
 
197
+ ## 🔍 Raw Response:
198
+ ```
199
+ {response}
200
+ ```
 
201
  """
202
+
203
+ map_html = ""
204
+ if parsed.coords:
205
+ map_html = f"""
206
+ <div style="margin-top: 20px;">
207
+ <iframe
208
+ width="100%"
209
+ height="450"
210
+ frameborder="0"
211
+ scrolling="no"
212
+ marginheight="0"
213
+ marginwidth="0"
214
+ src="https://www.openstreetmap.org/export/embed.html?bbox={parsed.coords.lon-0.1},{parsed.coords.lat-0.1},{parsed.coords.lon+0.1},{parsed.coords.lat+0.1}&marker={parsed.coords.lat},{parsed.coords.lon}"
215
+ style="border: 2px solid #ddd; border-radius: 8px;">
216
+ </iframe>
217
+ <div style="margin-top: 10px; text-align: center;">
218
+ <a href="https://www.google.com/maps?q={parsed.coords.lat},{parsed.coords.lon}" target="_blank" style="margin: 0 10px; color: #4285f4; text-decoration: none; font-weight: bold;">
219
+ 🗺️ Google Maps
220
+ </a>
221
+ <span style="color: #666;">|</span>
222
+ <a href="https://www.openstreetmap.org/?mlat={parsed.coords.lat}&mlon={parsed.coords.lon}#map=12/{parsed.coords.lat}/{parsed.coords.lon}" target="_blank" style="margin: 0 10px; color: #7ebc6f; text-decoration: none; font-weight: bold;">
223
+ 🌍 OpenStreetMap
224
+ </a>
225
+ </div>
226
  </div>
227
+ """
228
+ else:
229
+ map_html = "<div style='text-align: center; padding: 20px; color: #666;'>❌ No valid coordinates found</div>"
230
+
231
+ return output, map_html
232
+
233
+ except Exception as e:
234
+ error_msg = f"❌ Error: {str(e)}"
235
+ print(error_msg)
236
+ import traceback
237
+ traceback.print_exc()
238
+ return error_msg, ""
239
 
240
  # ============================================================================
241
  # Gradio Interface
 
246
  """
247
  # 🌍 GeoVLM - AI-Powered Geolocation
248
 
249
+ Upload any image and let AI predict where it was taken!
 
 
 
 
 
250
 
251
  **Powered by [vlm-gym](https://github.com/sdan/vlm-gym)** | Model: Qwen2-VL-2B-Instruct
252
  """
 
254
 
255
  with gr.Row():
256
  with gr.Column(scale=1):
257
+ image_input = gr.Image(type="pil", label="📸 Upload Image", height=400)
 
 
 
 
258
  predict_btn = gr.Button("🔍 Predict Location", variant="primary", size="lg")
259
 
260
  gr.Markdown(
 
268
  )
269
 
270
  with gr.Column(scale=1):
271
+ output_text = gr.Markdown(label="📊 Results")
272
+ map_output = gr.HTML(label="🗺️ Map Location")
273
 
274
  gr.Markdown(
275
  """
276
  ---
277
  ### 🎯 Use Cases:
278
+ - **OSINT Research** - Verify photo locations
279
  - **GeoGuessr Training** - Practice location identification
280
+ - **Education** - Learn world geography
281
+ - **Travel** - Discover interesting places
282
 
283
  ---
284
 
285
+ **Note:** Predictions take 2-5 minutes on CPU. Accuracy varies by location.
286
 
287
+ Built by [Vance Poitier](https://www.linkedin.com/in/vance-poitier/) | Based on [vlm-gym](https://github.com/sdan/vlm-gym)
288
  """
289
  )
290
 
291
+ predict_btn.click(fn=predict_location, inputs=image_input, outputs=[output_text, map_output])
 
 
 
 
 
292
 
293
  if __name__ == "__main__":
294
  print("🚀 Starting GeoVLM...")
295
  load_model()
296
+ demo.launch(server_name="0.0.0.0", server_port=7860)