#!/usr/bin/env python3
"""
GeoVLM - AI-Powered Geolocation
Upload an image and predict where it was taken using a vision-language model.
"""
import re
from dataclasses import dataclass

import gradio as gr
import torch
from PIL import Image
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
# ============================================================================
# Simplified Geolocation Parser (from vlm-gym)
# ============================================================================
@dataclass(frozen=True)
class Coords:
    """Geographic coordinates"""
    lat: float
    lon: float


@dataclass(frozen=True)
class ParsedResponse:
    """Structured model output"""
    city: str | None
    region: str | None
    country: str | None
    coords: Coords | None
    raw_text: str
    format_valid: bool
PROMPT_TEMPLATE = (
    "Look at the image and guess the location.\n"
    "Respond with EXACTLY these 5 lines, no extra text:\n"
    "City: <city name>\n"
    "Region: <state or region>\n"
    "Country: <country name or ISO-2 code>\n"
    "Latitude: <number between -90 and 90>\n"
    "Longitude: <number between -180 and 180>\n"
)
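# A well-formed reply to PROMPT_TEMPLATE looks like this (values are made up
# for illustration):
#
#   City: Lisbon
#   Region: Lisbon District
#   Country: PT
#   Latitude: 38.7223
#   Longitude: -9.1393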
# Map common key spellings in model output to canonical field names.
KEY_ALIASES = {
    "city": "city",
    "country": "country",
    "region": "region",
    "state": "region",
    "province": "region",
    "latitude": "lat",
    "lat": "lat",
    "longitude": "lon",
    "lon": "lon",
}
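# Example: a response line "State: Bavaria" lowercases to the alias "state",
# which normalizes to "region", so it fills parsed["region"] in parse_response.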
def parse_response(text: str) -> ParsedResponse:
    """Parse the structured 5-line format into a ParsedResponse."""
    parsed = {}
    if not text:
        return ParsedResponse(None, None, None, None, text, False)
    # Match "Key: value" lines, tolerating list bullets before the key.
    key_pattern = re.compile(
        r'^\s*(?:[-*+\u2022]\s*)?(?P<key>[A-Za-z][A-Za-z0-9\s\-/_.]*?)\s*:\s*(?P<value>.+)$'
    )
    for line in text.splitlines():
        match = key_pattern.match(line)
        if not match:
            continue
        key_raw = match.group("key").strip().lower()
        key_raw = key_raw.strip("*_`\"' ")
        key_raw = re.sub(r"\s+", " ", key_raw)
        canonical = KEY_ALIASES.get(key_raw)
        if canonical is None:
            continue
        # Strip surrounding markdown emphasis and quotes from the value.
        value_raw = match.group("value").strip()
        value_raw = value_raw.strip("`\"' \t")
        value_raw = re.sub(r"^[*_`]+", "", value_raw)
        value_raw = re.sub(r"[*_`]+$", "", value_raw)
        value_raw = value_raw.strip()
        if canonical in {"city", "region", "country"}:
            if value_raw and canonical not in parsed:
                parsed[canonical] = value_raw
        elif canonical in {"lat", "lon"}:
            if canonical not in parsed:
                # Accept both "." and "," as the decimal separator.
                match_num = re.search(r"-?\d+(?:[.,]\d+)?", value_raw)
                if match_num:
                    try:
                        parsed[canonical] = float(match_num.group(0).replace(",", "."))
                    except ValueError:
                        pass
    coords = None
    if "lat" in parsed and "lon" in parsed:
        lat = parsed["lat"]
        lon = parsed["lon"]
        if -90 <= lat <= 90 and -180 <= lon <= 180:
            coords = Coords(lat=lat, lon=lon)
    # Consider the format valid if at least two fields were recovered.
    format_valid = len(parsed) >= 2
    return ParsedResponse(
        city=parsed.get("city"),
        region=parsed.get("region"),
        country=parsed.get("country"),
        coords=coords,
        raw_text=text,
        format_valid=format_valid,
    )
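# Usage sketch (doctest-style, with illustrative values):
#
#     >>> sample = ("City: Paris\nRegion: Île-de-France\nCountry: FR\n"
#     ...           "Latitude: 48.8566\nLongitude: 2.3522")
#     >>> parse_response(sample).coords
#     Coords(lat=48.8566, lon=2.3522)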
# ============================================================================
# Model Setup
# ============================================================================
model = None
processor = None
MODEL_NAME = "Qwen/Qwen2-VL-2B-Instruct"
def load_model():
    """Load the model and processor once, on first use."""
    global model, processor
    if model is None:
        print(f"🔄 Loading model: {MODEL_NAME}")
        try:
            processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)
            model = Qwen2VLForConditionalGeneration.from_pretrained(
                MODEL_NAME,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                device_map="auto",
                trust_remote_code=True,
            )
            print("✅ Model loaded successfully!")
        except Exception as e:
            print(f"❌ Error loading model: {e}")
            raise
def predict_location(image):
    """Predict a geolocation from an image; returns (markdown, map_html)."""
    try:
        if image is None:
            return "⚠️ Please upload an image.", ""
        load_model()
        print("📸 Processing image...")
        # Gradio may hand us a numpy array; normalize to an RGB PIL image.
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image).convert("RGB")
        else:
            image = image.convert("RGB")
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": PROMPT_TEMPLATE},
                ],
            }
        ]
        text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = processor(text=[text], images=[image], return_tensors="pt", padding=True)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        print("🤖 Generating prediction...")
        with torch.no_grad():
            output_ids = model.generate(**inputs, max_new_tokens=256, do_sample=False)
        # Decode only the newly generated tokens, not the echoed prompt.
        generated_ids = output_ids[0][inputs["input_ids"].shape[1]:]
        response = processor.decode(generated_ids, skip_special_tokens=True).strip()
        print("✅ Response generated")
        parsed = parse_response(response)
        output = f"""
## 🤖 AI Prediction

**📍 Location Details:**
- **City:** {parsed.city or "Unknown"}
- **Region:** {parsed.region or "Unknown"}
- **Country:** {parsed.country or "Unknown"}
- **Coordinates:** {f"{parsed.coords.lat:.6f}°, {parsed.coords.lon:.6f}°" if parsed.coords else "Not found"}

---
## 🔍 Raw Response:
```
{response}
```
"""
        map_html = ""
        if parsed.coords:
            # Embed an OpenStreetMap iframe centered on the predicted point.
            map_html = f"""
<div style="margin-top: 20px;">
  <iframe
    width="100%"
    height="450"
    frameborder="0"
    scrolling="no"
    marginheight="0"
    marginwidth="0"
    src="https://www.openstreetmap.org/export/embed.html?bbox={parsed.coords.lon-0.1},{parsed.coords.lat-0.1},{parsed.coords.lon+0.1},{parsed.coords.lat+0.1}&marker={parsed.coords.lat},{parsed.coords.lon}"
    style="border: 2px solid #ddd; border-radius: 8px;">
  </iframe>
  <div style="margin-top: 10px; text-align: center;">
    <a href="https://www.google.com/maps?q={parsed.coords.lat},{parsed.coords.lon}" target="_blank" style="margin: 0 10px; color: #4285f4; text-decoration: none; font-weight: bold;">
      🗺️ Google Maps
    </a>
    <span style="color: #666;">|</span>
    <a href="https://www.openstreetmap.org/?mlat={parsed.coords.lat}&mlon={parsed.coords.lon}#map=12/{parsed.coords.lat}/{parsed.coords.lon}" target="_blank" style="margin: 0 10px; color: #7ebc6f; text-decoration: none; font-weight: bold;">
      🌍 OpenStreetMap
    </a>
  </div>
</div>
"""
        else:
            map_html = "<div style='text-align: center; padding: 20px; color: #666;'>❌ No valid coordinates found</div>"
        return output, map_html
    except Exception as e:
        error_msg = f"❌ Error: {e}"
        print(error_msg)
        import traceback
        traceback.print_exc()
        return error_msg, ""
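# Standalone usage sketch (no Gradio; "photo.jpg" is an illustrative path):
#
#     from PIL import Image
#     markdown, map_html = predict_location(Image.open("photo.jpg"))
#     print(markdown)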
# ============================================================================
# Gradio Interface
# ============================================================================
with gr.Blocks(title="GeoVLM - AI Geolocation", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
# 🌍 GeoVLM - AI-Powered Geolocation
Upload any image and let AI predict where it was taken!

**Powered by [vlm-gym](https://github.com/sdan/vlm-gym)** | Model: Qwen2-VL-2B-Instruct
"""
    )
    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="📸 Upload Image", height=400)
            predict_btn = gr.Button("🔍 Predict Location", variant="primary", size="lg")
            gr.Markdown(
                """
### 💡 Tips:
- Outdoor images work best
- Street views are ideal
- Clear photos with visible landmarks
- Unique architectural or natural features help
"""
            )
        with gr.Column(scale=1):
            output_text = gr.Markdown(label="📊 Results")
            map_output = gr.HTML(label="🗺️ Map Location")
    gr.Markdown(
        """
---
### 🎯 Use Cases:
- **OSINT Research** - Verify photo locations
- **GeoGuessr Training** - Practice location identification
- **Education** - Learn world geography
- **Travel** - Discover interesting places

---
**Note:** Predictions take 2-5 minutes on CPU. Accuracy varies by location.

Built by [Vance Poitier](https://www.linkedin.com/in/vance-poitier/) | Based on [vlm-gym](https://github.com/sdan/vlm-gym)
"""
    )
    predict_btn.click(fn=predict_location, inputs=image_input, outputs=[output_text, map_output])
if __name__ == "__main__":
    print("🚀 Starting GeoVLM...")
    load_model()
    demo.launch(server_name="0.0.0.0", server_port=7860)