#!/usr/bin/env python3
"""
GeoVLM - AI-Powered Geolocation

Upload any image and predict where it was taken using Vision-Language Models.
"""

import html
import re
import traceback
from dataclasses import dataclass

import gradio as gr
import torch
from PIL import Image
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

# ============================================================================
# Simplified Geolocation Parser (from vlm-gym)
# ============================================================================


@dataclass(frozen=True)
class Coords:
    """Geographic coordinates in decimal degrees (south/west are negative)."""

    lat: float
    lon: float


@dataclass(frozen=True)
class ParsedResponse:
    """Structured model output produced by ``parse_response``."""

    city: str | None  # first non-empty "City:" value
    region: str | None  # first non-empty Region/State/Province value
    country: str | None  # first non-empty "Country:" value
    coords: Coords | None  # set only when lat AND lon parsed and are in range
    raw_text: str  # the unmodified text that was parsed
    format_valid: bool  # True when at least two known fields were extracted


PROMPT_TEMPLATE = (
    "Look at the image and guess the location.\n"
    "Respond with EXACTLY these 5 lines, no extra text:\n"
    "City: \n"
    "Region: \n"
    "Country: \n"
    "Latitude: \n"
    "Longitude: \n"
)

# Lower-cased key as written by the model -> canonical field name.
KEY_ALIASES = {
    "city": "city",
    "country": "country",
    "region": "region",
    "state": "region",
    "province": "region",
    "latitude": "lat",
    "lat": "lat",
    "longitude": "lon",
    "lon": "lon",
    "lng": "lon",
    "long": "lon",
}

# One "Key: value" line. Tolerates an optional list bullet and Markdown
# emphasis around the key, e.g. "- **City:** Paris" or "*Country*: France".
_KEY_PATTERN = re.compile(
    r"^\s*(?:[-*+\u2022]\s*)?[*_`]*"
    r"(?P<key>[A-Za-z][A-Za-z0-9\s\-/_.]*?)"
    r"\s*[*_`]*\s*:\s*(?P<value>.+)$"
)

# A decimal number ("," or "." separator), optionally followed by a degree
# sign and a hemisphere marker such as "S", "W" or "South".
_COORD_PATTERN = re.compile(
    r"(?P<num>-?\d+(?:[.,]\d+)?)"
    r"(?:\s*[°º˚]?\s*(?P<hemi>[NSEW]|(?i:north|south|east|west))(?![A-Za-z]))?"
)

# Hemisphere initial that makes each axis negative.
_NEGATIVE_HEMISPHERE = {"lat": "S", "lon": "W"}


def _parse_coordinate(value: str, axis: str) -> float | None:
    """Return the first number in *value* as a signed coordinate, or None.

    A hemisphere suffix of "S"/"South" (latitude) or "W"/"West" (longitude)
    forces the result negative; a Unicode minus sign is treated as "-".
    """
    match = _COORD_PATTERN.search(value.replace("\u2212", "-"))
    if match is None:
        return None
    number = float(match.group("num").replace(",", "."))
    hemisphere = (match.group("hemi") or "")[:1].upper()
    if hemisphere == _NEGATIVE_HEMISPHERE[axis]:
        number = -abs(number)
    return number


def parse_response(text: str) -> ParsedResponse:
    """Parse the model's 5-line "Key: value" answer into a ParsedResponse.

    Keys are matched case-insensitively via KEY_ALIASES; unknown keys are
    ignored and the first occurrence of each field wins. Coordinates accept
    "," or "." as decimal separator and an optional hemisphere suffix
    ("33.87° S" -> -33.87), and are kept only when both latitude and
    longitude are present and within [-90, 90] / [-180, 180].

    Args:
        text: Raw generated text; may be empty or None.

    Returns:
        A ParsedResponse; ``format_valid`` is True when at least two known
        fields were extracted.
    """
    if not text:
        return ParsedResponse(None, None, None, None, text, False)

    parsed: dict[str, str | float] = {}
    for line in text.splitlines():
        match = _KEY_PATTERN.match(line)
        if not match:
            continue

        key = match.group("key").strip().lower().strip("*_`\"' ")
        key = re.sub(r"\s+", " ", key)
        canonical = KEY_ALIASES.get(key)
        if canonical is None or canonical in parsed:
            continue  # unknown key, or field already filled (first one wins)

        # Strip quotes and Markdown emphasis markers around the value.
        value = match.group("value").strip().strip("`\"' \t")
        value = re.sub(r"^[*_`]+", "", value)
        value = re.sub(r"[*_`]+$", "", value).strip()

        if canonical in {"city", "region", "country"}:
            if value:
                parsed[canonical] = value
        else:  # "lat" / "lon"
            number = _parse_coordinate(value, canonical)
            if number is not None:
                parsed[canonical] = number

    coords = None
    lat, lon = parsed.get("lat"), parsed.get("lon")
    if lat is not None and lon is not None and -90 <= lat <= 90 and -180 <= lon <= 180:
        coords = Coords(lat=lat, lon=lon)

    return ParsedResponse(
        city=parsed.get("city"),
        region=parsed.get("region"),
        country=parsed.get("country"),
        coords=coords,
        raw_text=text,
        format_valid=len(parsed) >= 2,
    )


# ============================================================================
# Model Setup
# ============================================================================

model = None
processor = None
MODEL_NAME = "Qwen/Qwen2-VL-2B-Instruct"


def load_model():
    """Load the processor and model into module globals on first call.

    Uses float16 when CUDA is available, float32 otherwise. Subsequent calls
    are no-ops once ``model`` is set; load errors are printed and re-raised.
    """
    global model, processor
    if model is None:
        print(f"🔄 Loading model: {MODEL_NAME}")
        try:
            processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)
            model = Qwen2VLForConditionalGeneration.from_pretrained(
                MODEL_NAME,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                device_map="auto",
                trust_remote_code=True,
            )
            print("✅ Model loaded successfully!")
        except Exception as e:
            print(f"❌ Error loading model: {e}")
            raise


# ============================================================================
# Rendering helpers
# ============================================================================

_NO_COORDS_HTML = (
    '<div style="padding: 1em; text-align: center;">❌ No valid coordinates found</div>'
)


def _build_map_html(coords: Coords, span: float = 0.05) -> str:
    """Return an OpenStreetMap iframe centred on *coords*.

    *span* is the half-width of the bounding box in degrees; the box is
    clamped to valid latitude/longitude ranges.
    """
    lat, lon = coords.lat, coords.lon
    bbox = ",".join(
        f"{v:.6f}"
        for v in (
            max(lon - span, -180.0),
            max(lat - span, -90.0),
            min(lon + span, 180.0),
            min(lat + span, 90.0),
        )
    )
    embed_url = (
        "https://www.openstreetmap.org/export/embed.html"
        f"?bbox={bbox}&layer=mapnik&marker={lat:.6f},{lon:.6f}"
    )
    link_url = (
        f"https://www.openstreetmap.org/?mlat={lat:.6f}&mlon={lon:.6f}"
        f"#map=12/{lat:.6f}/{lon:.6f}"
    )
    return (
        '<iframe width="100%" height="400" frameborder="0" scrolling="no" '
        'style="border: 1px solid #ccc; border-radius: 8px;" '
        f'src="{html.escape(embed_url)}"></iframe>'
        f'<p><a href="{html.escape(link_url)}" target="_blank" rel="noopener">'
        "Open in OpenStreetMap</a></p>"
    )


def _format_report(parsed: ParsedResponse, response: str) -> str:
    """Render parsed fields plus the raw model answer as Markdown."""
    if parsed.coords:
        coords_text = f"{parsed.coords.lat:.6f}°, {parsed.coords.lon:.6f}°"
    else:
        coords_text = "Not found"

    # Use a fence longer than any backtick run in the response so model
    # output containing ``` cannot terminate the code block early.
    fence = "```"
    while fence in response:
        fence += "`"

    return f"""
## 🤖 AI Prediction

**📍 Location Details:**
- **City:** {parsed.city or "Unknown"}
- **Region:** {parsed.region or "Unknown"}
- **Country:** {parsed.country or "Unknown"}
- **Coordinates:** {coords_text}

---

## 🔍 Raw Response:
{fence}
{response}
{fence}
"""


def predict_location(image):
    """Predict where *image* was taken.

    Args:
        image: PIL image (or array accepted by ``Image.fromarray``), or None.

    Returns:
        ``(markdown_report, map_html)`` for the Gradio outputs. Any exception
        is caught here (UI boundary), printed with its traceback, and
        returned as an error message with an empty map.
    """
    try:
        if image is None:
            return "⚠️ Please upload an image.", ""

        load_model()

        print("📸 Processing image...")
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        image = image.convert("RGB")

        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": PROMPT_TEMPLATE},
                ],
            }
        ]
        text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = processor(text=[text], images=[image], return_tensors="pt", padding=True)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        print("🤖 Generating prediction...")
        with torch.no_grad():
            output_ids = model.generate(**inputs, max_new_tokens=256, do_sample=False)
        # Drop the prompt tokens so only the newly generated answer is decoded.
        generated_ids = output_ids[0][inputs["input_ids"].shape[1]:]
        response = processor.decode(generated_ids, skip_special_tokens=True).strip()
        print("✅ Response generated")

        parsed = parse_response(response)
        output = _format_report(parsed, response)
        map_html = _build_map_html(parsed.coords) if parsed.coords else _NO_COORDS_HTML
        return output, map_html

    except Exception as e:
        error_msg = f"❌ Error: {e}"
        print(error_msg)
        traceback.print_exc()
        return error_msg, ""


# ============================================================================
# Gradio Interface
# ============================================================================

with gr.Blocks(title="GeoVLM - AI Geolocation", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🌍 GeoVLM - AI-Powered Geolocation

        Upload any image and let AI predict where it was taken!

        **Powered by [vlm-gym](https://github.com/sdan/vlm-gym)** | Model: Qwen2-VL-2B-Instruct
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="📸 Upload Image", height=400)
            predict_btn = gr.Button("🔍 Predict Location", variant="primary", size="lg")
            gr.Markdown(
                """
                ### 💡 Tips:
                - Outdoor images work best
                - Street views are ideal
                - Clear photos with visible landmarks
                - Unique architectural or natural features help
                """
            )

        with gr.Column(scale=1):
            output_text = gr.Markdown(label="📊 Results")
            map_output = gr.HTML(label="🗺️ Map Location")

    gr.Markdown(
        """
        ---
        ### 🎯 Use Cases:
        - **OSINT Research** - Verify photo locations
        - **GeoGuessr Training** - Practice location identification
        - **Education** - Learn world geography
        - **Travel** - Discover interesting places

        ---

        **Note:** Predictions take 2-5 minutes on CPU. Accuracy varies by location.

        Built by [Vance Poitier](https://www.linkedin.com/in/vance-poitier/) | Based on [vlm-gym](https://github.com/sdan/vlm-gym)
        """
    )

    predict_btn.click(fn=predict_location, inputs=image_input, outputs=[output_text, map_output])


if __name__ == "__main__":
    print("🚀 Starting GeoVLM...")
    load_model()
    demo.launch(server_name="0.0.0.0", server_port=7860)