|
|
|
|
|
""" |
|
|
GeoVLM - AI-Powered Geolocation |
|
|
Upload any image and predict where it was taken using Vision-Language Models |
|
|
""" |
|
|
|
|
|
import gradio as gr |
|
|
from PIL import Image |
|
|
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration |
|
|
import torch |
|
|
import re |
|
|
from dataclasses import dataclass |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
class Coords:
    """Immutable latitude/longitude pair.

    The class stores the two numbers as given and performs no validation
    of its own.
    """

    lat: float  # latitude component
    lon: float  # longitude component
|
|
|
|
|
@dataclass(frozen=True)
class ParsedResponse:
    """Immutable result of parsing one model reply.

    Text fields are ``None`` when the reply did not supply them; ``coords``
    is ``None`` when no usable latitude/longitude pair was found.
    """

    city: str | None  # value of the "City" line, if any
    region: str | None  # value of the "Region"/"State"/"Province" line, if any
    country: str | None  # value of the "Country" line, if any
    coords: Coords | None  # parsed coordinate pair, if any
    raw_text: str  # the unmodified text that was parsed
    format_valid: bool  # whether the reply followed the expected layout well enough
|
|
|
|
|
# Instruction sent to the model together with the image. It asks for five
# "Key: value" lines, one per field, each terminated by a newline.
PROMPT_TEMPLATE = "".join(
    [
        "Look at the image and guess the location.\n",
        "Respond with EXACTLY these 5 lines, no extra text:\n",
        "City: <city name>\n",
        "Region: <state or region>\n",
        "Country: <country name or ISO-2 code>\n",
        "Latitude: <number between -90 and 90>\n",
        "Longitude: <number between -180 and 180>\n",
    ]
)
|
|
|
|
|
# Lookup from every accepted (lower-case) field label to the canonical key
# the parser stores it under. Grouped by canonical key so that adding a new
# alias is a one-word change.
KEY_ALIASES = {
    alias: canonical
    for canonical, aliases in (
        ("city", ("city",)),
        ("country", ("country",)),
        ("region", ("region", "state", "province")),
        ("lat", ("latitude", "lat")),
        ("lon", ("longitude", "lon")),
    )
    for alias in aliases
}
|
|
|
|
|
# One "Key: value" line. An optional list bullet and optional Markdown
# emphasis markers (*, _, `) may surround the key, so replies such as
# "- **City:** Paris" or "**City**: Paris" are recognised as well as the
# plain "City: Paris".
_KEY_VALUE_RE = re.compile(
    r"^\s*(?:[-*+\u2022]\s*)?[*_`]*"
    r"(?P<key>[A-Za-z][A-Za-z0-9\s\-/_.]*?)"
    r"[*_`]*\s*:\s*(?P<value>.+)$"
)

# First number in a coordinate value, with an optional sign (ASCII or the
# Unicode minus U+2212), an optional degree marker and an optional
# hemisphere (upper-case N/S/E/W, or the full word in any case).
_COORD_RE = re.compile(
    r"(?P<num>[-+\u2212]?\d+(?:[.,]\d+)?)"
    r"(?:\s*(?:\u00b0|degrees?))?"
    r"(?:\s*(?P<hemi>[NSEW]|(?i:north|south|east|west))(?![A-Za-z]))?"
)


def _parse_coordinate(value):
    """Return the first number in *value* as a signed float, or None if absent."""
    found = _COORD_RE.search(value)
    if found is None:
        return None
    # Accept a decimal comma and normalise the Unicode minus before float().
    number = float(found.group("num").replace(",", ".").replace("\u2212", "-"))
    hemisphere = (found.group("hemi") or "")[:1].upper()
    if hemisphere in ("S", "W"):
        # "33.87 S" and "-33.87 S" both mean the southern/western value.
        number = -abs(number)
    return number


def parse_response(text: str) -> ParsedResponse:
    """Parse the model's "Key: value" reply into a ParsedResponse.

    Each line is matched independently; unrecognised lines and unknown keys
    are skipped. Keys are looked up case-insensitively in ``KEY_ALIASES``.
    For every field the first usable occurrence wins.

    Args:
        text: Raw text produced by the model. May be empty.

    Returns:
        A ParsedResponse whose ``coords`` is set only when both latitude and
        longitude were found and lie within [-90, 90] and [-180, 180]
        respectively. ``format_valid`` is True when at least two distinct
        fields were extracted. For empty input every field is None,
        ``raw_text`` is ``""`` and ``format_valid`` is False.
    """
    if not text:
        return ParsedResponse(None, None, None, None, text or "", False)

    parsed = {}

    for line in text.splitlines():
        match = _KEY_VALUE_RE.match(line)
        if match is None:
            continue

        # Normalise the label: lower-case, drop stray markup/quotes, collapse
        # internal whitespace.
        key = match.group("key").strip().lower().strip("*_`\"' ")
        key = re.sub(r"\s+", " ", key)
        canonical = KEY_ALIASES.get(key)
        if canonical is None or canonical in parsed:
            continue  # unknown label, or field already filled by an earlier line

        # Strip quotes and leading/trailing Markdown emphasis from the value
        # (this also removes the "**" left over from "**City:** Paris").
        value = match.group("value").strip().strip("`\"' \t")
        value = re.sub(r"^[*_`]+", "", value)
        value = re.sub(r"[*_`]+$", "", value)
        value = value.strip()

        if canonical in {"city", "region", "country"}:
            if value:
                parsed[canonical] = value
        elif canonical in {"lat", "lon"}:
            number = _parse_coordinate(value)
            if number is not None:
                parsed[canonical] = number

    lat = parsed.get("lat")
    lon = parsed.get("lon")
    coords = None
    if lat is not None and lon is not None:
        if -90 <= lat <= 90 and -180 <= lon <= 180:
            coords = Coords(lat=lat, lon=lon)

    return ParsedResponse(
        city=parsed.get("city"),
        region=parsed.get("region"),
        country=parsed.get("country"),
        coords=coords,
        raw_text=text,
        format_valid=len(parsed) >= 2,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Checkpoint identifier handed to ``from_pretrained``.
MODEL_NAME = "Qwen/Qwen2-VL-2B-Instruct"

# Module-level singletons; both stay None until load_model() fills them in.
model = processor = None
|
|
|
|
|
def load_model():
    """Populate the module-level ``model`` and ``processor`` on first use.

    Does nothing when ``model`` is already set. Any exception raised while
    loading is printed and then re-raised to the caller.
    """
    global model, processor

    if model is not None:
        return

    print(f"🔄 Loading model: {MODEL_NAME}")
    try:
        processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)
        # Half precision when CUDA is available, full precision otherwise.
        weights_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        model = Qwen2VLForConditionalGeneration.from_pretrained(
            MODEL_NAME,
            torch_dtype=weights_dtype,
            device_map="auto",
            trust_remote_code=True,
        )
        print("✅ Model loaded successfully!")
    except Exception as exc:
        print(f"❌ Error loading model: {exc}")
        raise
|
|
|
|
|
def _fenced_block(text):
    """Wrap *text* in a Markdown code fence longer than any backtick run in it."""
    longest_run = max((len(run) for run in re.findall(r"`+", text)), default=0)
    fence = "`" * max(3, longest_run + 1)
    return f"{fence}\n{text}\n{fence}"


def _format_prediction_markdown(parsed, response):
    """Render the parsed fields plus the raw model text as a Markdown report."""
    if parsed.coords is not None:
        coords_text = f"{parsed.coords.lat:.6f}°, {parsed.coords.lon:.6f}°"
    else:
        coords_text = "Not found"

    # Built line by line so the Markdown never carries source-code indentation
    # (indented lines would otherwise risk being rendered as code blocks).
    return "\n".join(
        [
            "",
            "## 🤖 AI Prediction",
            "",
            "**📍 Location Details:**",
            f"- **City:** {parsed.city or 'Unknown'}",
            f"- **Region:** {parsed.region or 'Unknown'}",
            f"- **Country:** {parsed.country or 'Unknown'}",
            f"- **Coordinates:** {coords_text}",
            "",
            "---",
            "",
            "## 🔍 Raw Response:",
            # The model may emit its own ``` fences; use a longer one around it.
            _fenced_block(response),
            "",
        ]
    )


def _build_map_html(coords):
    """Return the embedded-map HTML for *coords*, or a notice when it is None."""
    if coords is None:
        return "<div style='text-align: center; padding: 20px; color: #666;'>❌ No valid coordinates found</div>"

    lat, lon = coords.lat, coords.lon
    # 0.1-degree margin around the marker, kept inside valid lat/lon ranges so
    # points near the poles or the antimeridian still give a well-formed bbox.
    west = max(-180.0, lon - 0.1)
    south = max(-90.0, lat - 0.1)
    east = min(180.0, lon + 0.1)
    north = min(90.0, lat + 0.1)

    return "\n".join(
        [
            "",
            '<div style="margin-top: 20px;">',
            "<iframe",
            'width="100%"',
            'height="450"',
            'frameborder="0"',
            'scrolling="no"',
            'marginheight="0"',
            'marginwidth="0"',
            f'src="https://www.openstreetmap.org/export/embed.html?bbox={west},{south},{east},{north}&marker={lat},{lon}"',
            'style="border: 2px solid #ddd; border-radius: 8px;">',
            "</iframe>",
            '<div style="margin-top: 10px; text-align: center;">',
            f'<a href="https://www.google.com/maps?q={lat},{lon}" target="_blank" style="margin: 0 10px; color: #4285f4; text-decoration: none; font-weight: bold;">',
            "🗺️ Google Maps",
            "</a>",
            '<span style="color: #666;">|</span>',
            f'<a href="https://www.openstreetmap.org/?mlat={lat}&mlon={lon}#map=12/{lat}/{lon}" target="_blank" style="margin: 0 10px; color: #7ebc6f; text-decoration: none; font-weight: bold;">',
            "🌍 OpenStreetMap",
            "</a>",
            "</div>",
            "</div>",
            "",
        ]
    )


def predict_location(image):
    """Predict where *image* was taken and render the result for the UI.

    Args:
        image: A PIL image, or any array accepted by ``Image.fromarray``.
            ``None`` means nothing was uploaded.

    Returns:
        A ``(markdown, html)`` tuple: the Markdown report and the map HTML.
        When no image is given, or when anything fails, the first element is
        a user-facing message and the second is an empty string.
    """
    if image is None:
        return "⚠️ Please upload an image.", ""

    # This is the UI callback boundary: any failure is reported to the user
    # as text (and logged with a traceback) instead of propagating.
    try:
        load_model()

        print("📸 Processing image...")

        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        image = image.convert("RGB")

        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": PROMPT_TEMPLATE},
                ],
            }
        ]

        text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = processor(text=[text], images=[image], return_tensors="pt", padding=True)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        print("🤖 Generating prediction...")

        with torch.no_grad():
            output_ids = model.generate(**inputs, max_new_tokens=256, do_sample=False)

        # Drop the prompt tokens so only the newly generated text is decoded.
        prompt_length = inputs["input_ids"].shape[1]
        generated_ids = output_ids[0][prompt_length:]
        response = processor.decode(generated_ids, skip_special_tokens=True).strip()

        print("✅ Response generated")

        parsed = parse_response(response)
        return _format_prediction_markdown(parsed, response), _build_map_html(parsed.coords)

    except Exception as e:
        error_msg = f"❌ Error: {e}"
        print(error_msg)
        import traceback

        traceback.print_exc()
        return error_msg, ""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Declarative UI definition. Components are created in the order and nesting
# in which they should appear, so statement order here is significant.
with gr.Blocks(title="GeoVLM - AI Geolocation", theme=gr.themes.Soft()) as demo:
    # Page header.
    gr.Markdown(
        """
        # 🌍 GeoVLM - AI-Powered Geolocation

        Upload any image and let AI predict where it was taken!

        **Powered by [vlm-gym](https://github.com/sdan/vlm-gym)** | Model: Qwen2-VL-2B-Instruct
        """
    )

    with gr.Row():
        # Left column: image upload, trigger button and usage tips.
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="📸 Upload Image", height=400)
            predict_btn = gr.Button("🔍 Predict Location", variant="primary", size="lg")

            gr.Markdown(
                """
                ### 💡 Tips:
                - Outdoor images work best
                - Street views are ideal
                - Clear photos with visible landmarks
                - Unique architectural or natural features help
                """
            )

        # Right column: the two outputs returned by predict_location.
        with gr.Column(scale=1):
            output_text = gr.Markdown(label="📊 Results")
            map_output = gr.HTML(label="🗺️ Map Location")

    # Page footer.
    gr.Markdown(
        """
        ---
        ### 🎯 Use Cases:
        - **OSINT Research** - Verify photo locations
        - **GeoGuessr Training** - Practice location identification
        - **Education** - Learn world geography
        - **Travel** - Discover interesting places

        ---

        **Note:** Predictions take 2-5 minutes on CPU. Accuracy varies by location.

        Built by [Vance Poitier](https://www.linkedin.com/in/vance-poitier/) | Based on [vlm-gym](https://github.com/sdan/vlm-gym)
        """
    )

    # Wire the button: the uploaded image goes in, (markdown, html) comes out.
    predict_btn.click(fn=predict_location, inputs=image_input, outputs=[output_text, map_output])
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: announce start-up, load the model up front rather
    # than on the first request, then serve the UI.
    print("🚀 Starting GeoVLM...")
    load_model()
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    demo.launch(**launch_options)