|
|
|
|
|
""" |
|
|
GeoVLM with 3D Globe Visualization |
|
|
Interactive 3D globe that flies to predicted locations |
|
|
""" |
|
|
|
|
|
import gradio as gr |
|
|
from PIL import Image |
|
|
from transformers import AutoProcessor, AutoModelForImageTextToText |
|
|
import torch |
|
|
import re |
|
|
import json |
|
|
from dataclasses import dataclass |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
class Coords:
    """Immutable latitude/longitude pair.

    Instances are frozen, so attributes cannot be reassigned after creation.
    """

    lat: float  # latitude component
    lon: float  # longitude component
|
|
|
|
|
@dataclass(frozen=True)
class ParsedResponse:
    """Immutable, structured view of one model reply.

    Every location field is optional: it is ``None`` when the reply did not
    contain a usable value for it.
    """

    # Free-text place names taken from the reply.
    city: str | None
    region: str | None
    country: str | None
    # Latitude/longitude pair, or None when either part was missing/invalid.
    coords: Coords | None
    # The reply exactly as it was handed to the parser.
    raw_text: str
    # Whether the reply was judged to follow the requested format.
    format_valid: bool
|
|
|
|
|
# Instruction sent to the model alongside the image. Each entry below is one
# line of the prompt; the final string ends with a trailing newline.
PROMPT_TEMPLATE = "\n".join(
    [
        "Look at the image and guess the location.",
        "Respond with EXACTLY these 5 lines, no extra text:",
        "City: <city name>",
        "Region: <state or region>",
        "Country: <country name or ISO-2 code>",
        "Latitude: <number between -90 and 90>",
        "Longitude: <number between -180 and 180>",
    ]
) + "\n"
|
|
|
|
|
# Maps every accepted (lower-cased) reply key to its canonical field name.
# Built from "canonical -> accepted spellings" groups so that synonyms for the
# same field sit next to each other.
KEY_ALIASES = {
    alias: canonical
    for canonical, aliases in (
        ("city", ("city",)),
        ("country", ("country",)),
        ("region", ("region", "state", "province")),
        ("lat", ("latitude", "lat")),
        ("lon", ("longitude", "lon")),
    )
    for alias in aliases
}
|
|
|
|
|
# One "Key: value" line. Accepts an optional list bullet and optional Markdown
# emphasis/backtick markers on either side of the key, so replies such as
# "**City:** Paris" or "- **City**: Paris" are recognised as well as the plain
# "City: Paris" form. Compiled once at import time instead of on every call.
_KEY_VALUE_LINE = re.compile(
    r"^\s*(?:[-*+\u2022]\s*)?[*_`]*"
    r"(?P<key>[A-Za-z][A-Za-z0-9\s\-/_.]*?)[*_`]*\s*:\s*(?P<value>.+)$"
)

# First decimal number in a value (comma or dot as decimal separator),
# optionally followed by a degree sign and a hemisphere marker (N/S/E/W or the
# spelled-out word).
_COORDINATE = re.compile(
    r"(?P<number>-?\d+(?:[.,]\d+)?)"
    r"(?:\s*°?\s*(?P<hemi>north|south|east|west|[nsew])\b)?",
    re.IGNORECASE,
)


def _parse_coordinate(value: str) -> float | None:
    """Return the first number in *value* as a signed float, or None.

    A positive number followed by a south/west marker is negated, so
    "33.87° S" yields -33.87. A number that is already negative is left as is.
    """
    found = _COORDINATE.search(value)
    if found is None:
        return None
    number = float(found.group("number").replace(",", "."))
    hemisphere = (found.group("hemi") or "").lower()
    if hemisphere[:1] in ("s", "w") and number > 0:
        number = -number
    return number


def parse_response(text: str) -> ParsedResponse:
    """Parse the model's "Key: value" reply into a ParsedResponse.

    The reply is scanned line by line. Keys are matched case-insensitively
    through KEY_ALIASES; lines with unknown keys, and lines that are not of
    the form "Key: value", are ignored. For each field the first usable
    occurrence wins.

    Args:
        text: Raw reply text. Empty or None yields an all-None result with
            ``format_valid=False``.

    Returns:
        A ParsedResponse. ``coords`` is set only when both latitude and
        longitude were found and lie within [-90, 90] and [-180, 180]
        respectively. ``format_valid`` is True when at least two distinct
        fields were recognised.
    """
    if not text:
        return ParsedResponse(None, None, None, None, text, False)

    parsed: dict[str, str | float] = {}

    for line in text.splitlines():
        match = _KEY_VALUE_LINE.match(line)
        if match is None:
            continue

        key = match.group("key").strip().lower().strip("*_`\"' ")
        key = re.sub(r"\s+", " ", key)
        canonical = KEY_ALIASES.get(key)
        # Unknown key, or this field was already filled by an earlier line.
        if canonical is None or canonical in parsed:
            continue

        # Drop surrounding quotes/backticks and Markdown emphasis markers.
        value = match.group("value").strip().strip("`\"' \t")
        value = re.sub(r"^[*_`]+|[*_`]+$", "", value).strip()

        if canonical in ("lat", "lon"):
            number = _parse_coordinate(value)
            if number is not None:
                parsed[canonical] = number
        elif value:
            parsed[canonical] = value

    lat = parsed.get("lat")
    lon = parsed.get("lon")
    coords = None
    if (
        lat is not None
        and lon is not None
        and -90 <= lat <= 90
        and -180 <= lon <= 180
    ):
        coords = Coords(lat=lat, lon=lon)

    return ParsedResponse(
        city=parsed.get("city"),
        region=parsed.get("region"),
        country=parsed.get("country"),
        coords=coords,
        raw_text=text,
        format_valid=len(parsed) >= 2,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Hugging Face identifier of the checkpoint passed to from_pretrained().
MODEL_NAME = "Qwen/Qwen2-VL-2B-Instruct"

# Module-level singletons. Both start out empty and are assigned by
# load_model() the first time it runs.
model = None
processor = None
|
|
|
|
|
def load_model():
    """Populate the module-level ``model`` and ``processor`` on first use.

    Does nothing when ``model`` is already set. Otherwise loads the processor
    and the model for MODEL_NAME, choosing float16 with an automatic device
    map when CUDA is available and float32 on the CPU when it is not.
    """
    global model, processor

    if model is not None:
        return

    print(f"Loading model: {MODEL_NAME}")
    processor = AutoProcessor.from_pretrained(MODEL_NAME)

    has_cuda = torch.cuda.is_available()
    model = AutoModelForImageTextToText.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float16 if has_cuda else torch.float32,
        device_map="auto" if has_cuda else "cpu",
    )
    print("Model loaded!")
|
|
|
|
|
def predict_location(image):
    """Run the model on *image* and build the three UI outputs.

    Args:
        image: A PIL image, or any object accepted by ``Image.fromarray``.
            ``None`` means nothing was uploaded.

    Returns:
        A 3-tuple ``(markdown_report, globe_html, info_card_html)``. When
        *image* is ``None`` the report is a prompt to upload an image and the
        two HTML strings are empty.
    """
    if image is None:
        return "Please upload an image.", "", ""

    load_model()

    # Normalise the input to an RGB PIL image.
    pil_image = image if isinstance(image, Image.Image) else Image.fromarray(image)
    pil_image = pil_image.convert("RGB")

    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": PROMPT_TEMPLATE},
            ],
        }
    ]
    prompt = processor.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
    batch = processor(
        text=[prompt], images=[pil_image], return_tensors="pt", padding=True
    )
    batch = {name: tensor.to(model.device) for name, tensor in batch.items()}

    # Greedy decoding, no gradients needed.
    with torch.no_grad():
        output_ids = model.generate(**batch, max_new_tokens=256, do_sample=False)

    # Keep only the newly generated tokens (everything after the prompt).
    prompt_length = batch["input_ids"].shape[1]
    new_tokens = output_ids[0][prompt_length:]
    response = processor.decode(new_tokens, skip_special_tokens=True).strip()
    parsed = parse_response(response)

    if parsed.coords:
        coords_text = f"{parsed.coords.lat:.6f}°, {parsed.coords.lon:.6f}°"
    else:
        coords_text = "Not found"

    output = f"""
## 🤖 AI Prediction

**📍 Location Details:**
- **City:** {parsed.city or "Unknown"}
- **Region:** {parsed.region or "Unknown"}
- **Country:** {parsed.country or "Unknown"}
- **Coordinates:** {coords_text}

---

## 🔍 Raw Response:
```
{response}
```
"""

    if parsed.coords:
        globe_html = create_globe_html(parsed)
    else:
        globe_html = "<div style='text-align:center; padding:50px; color:#666;'>No coordinates found</div>"

    info_html = create_info_card(parsed)

    return output, globe_html, info_html
|
|
|
|
|
def create_globe_html(parsed: ParsedResponse) -> str:
    """Build the globe.gl page that marks and flies to the predicted location.

    The page has day/night texture buttons, a country-borders toggle, a marker
    with a hover label, a pulsing ring and a decorative arc.

    Args:
        parsed: Parsed model reply. Only ``coords``, ``city``, ``region`` and
            ``country`` are read.

    Returns:
        A complete HTML document as a string, or ``""`` when *parsed* has no
        coordinates.
    """
    # Local import: the file's top-level imports do not include ``html``.
    from html import escape

    if not parsed.coords:
        return ""

    lat, lon = parsed.coords.lat, parsed.coords.lon

    # The place names come from model output and end up both inside a
    # <script> block and, via the hover label, inside HTML. HTML-escape them
    # first (so the label cannot inject markup and no "<" can close the script
    # element), then JSON-encode the whole marker so quotes and backslashes
    # cannot break the JavaScript literal (e.g. a city such as "Xi'an").
    marker_json = json.dumps(
        {
            "lat": lat,
            "lng": lon,
            "size": 0.5,
            "color": "#ff4444",
            "label": escape(parsed.city or "Location"),
            "city": escape(parsed.city or "Unknown"),
            "region": escape(parsed.region or "Unknown"),
            "country": escape(parsed.country or "Unknown"),
        }
    )

    # NOTE(review): this returns a full document that relies on <script> tags;
    # confirm that the component rendering it actually executes scripts.
    page = f"""
    <!DOCTYPE html>
    <html>
    <head>
        <style>
            body {{ margin: 0; padding: 0; overflow: hidden; background: #000; position: relative; }}
            #globeViz {{ width: 100%; height: 600px; }}
            .location-label {{
                color: white;
                font-size: 16px;
                font-family: Arial, sans-serif;
                background: rgba(0,0,0,0.7);
                padding: 8px 12px;
                border-radius: 4px;
                pointer-events: none;
            }}
            .controls {{
                position: absolute;
                top: 20px;
                right: 20px;
                z-index: 100;
                display: flex;
                gap: 10px;
            }}
            .control-btn {{
                background: rgba(255,255,255,0.9);
                border: none;
                padding: 10px 16px;
                border-radius: 6px;
                cursor: pointer;
                font-weight: bold;
                font-size: 14px;
                transition: all 0.3s;
                box-shadow: 0 2px 8px rgba(0,0,0,0.3);
            }}
            .control-btn:hover {{
                background: white;
                transform: translateY(-2px);
                box-shadow: 0 4px 12px rgba(0,0,0,0.4);
            }}
            .control-btn.active {{
                background: #667eea;
                color: white;
            }}
        </style>
    </head>
    <body>
        <div class="controls">
            <button class="control-btn active" id="dayBtn" onclick="setDayMode()">☀️ Day</button>
            <button class="control-btn" id="nightBtn" onclick="setNightMode()">🌙 Night</button>
            <button class="control-btn" id="bordersBtn" onclick="toggleBorders()">🗺️ Borders</button>
        </div>
        <div id="globeViz"></div>

        <script src="//unpkg.com/globe.gl"></script>
        <script>
            let showBorders = false;
            let currentMode = 'day';

            const myGlobe = Globe()
                .globeImageUrl('//unpkg.com/three-globe/example/img/earth-blue-marble.jpg')
                .bumpImageUrl('//unpkg.com/three-globe/example/img/earth-topology.png')
                .backgroundImageUrl('//unpkg.com/three-globe/example/img/night-sky.png')
                .pointOfView({{ lat: {lat}, lng: {lon}, altitude: 2.5 }}, 0)
                .atmosphereColor('lightskyblue')
                .atmosphereAltitude(0.15)
                (document.getElementById('globeViz'));

            // Load country borders
            fetch('//unpkg.com/world-atlas/countries-50m.json')
                .then(res => res.json())
                .then(countries => {{
                    window.countriesData = countries;
                }});

            // Add marker point (values are escaped and JSON-encoded server-side)
            const markerData = [{marker_json}];

            myGlobe
                .pointsData(markerData)
                .pointAltitude('size')
                .pointColor('color')
                .pointRadius(0.6)
                .pointLabel(d => `
                    <div class="location-label">
                        <b>${{d.city}}</b><br/>
                        ${{d.region}}, ${{d.country}}<br/>
                        ${{d.lat.toFixed(4)}}°, ${{d.lng.toFixed(4)}}°
                    </div>
                `);

            // Animate to location
            myGlobe.pointOfView({{ lat: {lat}, lng: {lon}, altitude: 1.5 }}, 3000);

            // Auto-rotate
            myGlobe.controls().autoRotate = true;
            myGlobe.controls().autoRotateSpeed = 0.3;

            // Add pulsing ring animation
            const ringData = [{{
                lat: {lat},
                lng: {lon},
                maxR: 10,
                propagationSpeed: 2,
                repeatPeriod: 1500
            }}];

            myGlobe
                .ringsData(ringData)
                .ringColor(() => 'rgba(255,68,68,0.5)')
                .ringMaxRadius('maxR')
                .ringPropagationSpeed('propagationSpeed')
                .ringRepeatPeriod('repeatPeriod');

            // Add arcs for visual effect
            const arcData = [{{
                startLat: {lat},
                startLng: {lon},
                endLat: {lat + 10},
                endLng: {lon + 10},
                color: ['rgba(255,68,68,0.4)', 'rgba(255,68,68,0.1)']
            }}];

            myGlobe
                .arcsData(arcData)
                .arcColor('color')
                .arcDashLength(0.4)
                .arcDashGap(0.2)
                .arcDashAnimateTime(2000)
                .arcStroke(0.5);

            // Mode switching functions
            function setDayMode() {{
                currentMode = 'day';
                myGlobe
                    .globeImageUrl('//unpkg.com/three-globe/example/img/earth-blue-marble.jpg')
                    .bumpImageUrl('//unpkg.com/three-globe/example/img/earth-topology.png');

                document.getElementById('dayBtn').classList.add('active');
                document.getElementById('nightBtn').classList.remove('active');
            }}

            function setNightMode() {{
                currentMode = 'night';
                myGlobe
                    .globeImageUrl('//unpkg.com/three-globe/example/img/earth-night.jpg')
                    .bumpImageUrl('//unpkg.com/three-globe/example/img/earth-topology.png');

                document.getElementById('nightBtn').classList.add('active');
                document.getElementById('dayBtn').classList.remove('active');
            }}

            function toggleBorders() {{
                showBorders = !showBorders;
                const btn = document.getElementById('bordersBtn');

                if (showBorders && window.countriesData) {{
                    const countries = topojson.feature(window.countriesData, window.countriesData.objects.countries);
                    myGlobe
                        .polygonsData(countries.features)
                        .polygonAltitude(0.01)
                        .polygonCapColor(() => 'rgba(200, 200, 200, 0.1)')
                        .polygonSideColor(() => 'rgba(200, 200, 200, 0.05)')
                        .polygonStrokeColor(() => '#ffffff')
                        .polygonLabel(({{ properties: d }}) => `
                            <div class="location-label">
                                <b>${{d.name}}</b>
                            </div>
                        `);
                    btn.classList.add('active');
                }} else {{
                    myGlobe.polygonsData([]);
                    btn.classList.remove('active');
                }}
            }}
        </script>
        <script src="//unpkg.com/topojson-client"></script>
    </body>
    </html>
    """
    return page
|
|
|
|
|
def create_info_card(parsed: ParsedResponse) -> str:
    """Build the HTML summary card with place names, coordinates and map links.

    Args:
        parsed: Parsed model reply. Only ``coords``, ``city``, ``region`` and
            ``country`` are read.

    Returns:
        An HTML fragment, or ``""`` when *parsed* has no coordinates.
    """
    # Local imports: neither module is among the file's top-level imports.
    from html import escape
    from urllib.parse import quote_plus

    if not parsed.coords:
        return ""

    lat, lon = parsed.coords.lat, parsed.coords.lon

    # Place names come from model output, so escape them before putting them
    # into markup.
    city = escape(parsed.city or "Unknown")
    region = escape(parsed.region or "Unknown")
    country = escape(parsed.country or "Unknown")

    # Search for whichever of city/country is present, URL-encoded. Falling
    # back to the coordinates avoids a literal "None+None" query when both
    # names are missing.
    search_terms = " ".join(part for part in (parsed.city, parsed.country) if part)
    search_query = escape(quote_plus(search_terms or f"{lat},{lon}"))

    card = f"""
    <div style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
                border-radius: 12px; padding: 24px; color: white; margin-top: 20px;">
        <h2 style="margin: 0 0 16px 0; font-size: 24px;">📍 Predicted Location</h2>

        <div style="display: grid; grid-template-columns: 1fr 1fr; gap: 16px; margin-bottom: 20px;">
            <div style="background: rgba(255,255,255,0.1); padding: 12px; border-radius: 8px;">
                <div style="font-size: 12px; opacity: 0.8;">City</div>
                <div style="font-size: 18px; font-weight: bold;">{city}</div>
            </div>
            <div style="background: rgba(255,255,255,0.1); padding: 12px; border-radius: 8px;">
                <div style="font-size: 12px; opacity: 0.8;">Region</div>
                <div style="font-size: 18px; font-weight: bold;">{region}</div>
            </div>
            <div style="background: rgba(255,255,255,0.1); padding: 12px; border-radius: 8px;">
                <div style="font-size: 12px; opacity: 0.8;">Country</div>
                <div style="font-size: 18px; font-weight: bold;">{country}</div>
            </div>
            <div style="background: rgba(255,255,255,0.1); padding: 12px; border-radius: 8px;">
                <div style="font-size: 12px; opacity: 0.8;">Coordinates</div>
                <div style="font-size: 14px; font-weight: bold;">{lat:.4f}°, {lon:.4f}°</div>
            </div>
        </div>

        <div style="display: flex; gap: 12px; flex-wrap: wrap;">
            <a href="https://www.google.com/maps?q={lat},{lon}" target="_blank"
               style="background: #4285f4; color: white; padding: 10px 20px;
                      border-radius: 6px; text-decoration: none; font-weight: bold;">
                🗺️ Google Maps
            </a>
            <a href="https://www.openstreetmap.org/?mlat={lat}&mlon={lon}#map=12/{lat}/{lon}" target="_blank"
               style="background: #7ebc6f; color: white; padding: 10px 20px;
                      border-radius: 6px; text-decoration: none; font-weight: bold;">
                🌍 OpenStreetMap
            </a>
            <a href="https://www.google.com/search?q={search_query}" target="_blank"
               style="background: #ea4335; color: white; padding: 10px 20px;
                      border-radius: 6px; text-decoration: none; font-weight: bold;">
                🔍 Learn More
            </a>
        </div>
    </div>
    """
    return card
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Static text for the page, kept apart from the layout code below.
_PAGE_CSS = """
.gradio-container {max-width: 1400px !important;}
.globe-container {height: 600px !important;}
"""

_HEADER_MD = """
# 🌍 GeoVLM - AI Geolocation with 3D Globe

Upload any image and watch the AI predict its location on an interactive 3D globe!

**Powered by:** [vlm-gym](https://github.com/sdan/vlm-gym) | Vision-Language Models | Three.js Globe
"""

_TIPS_MD = """
### 💡 Tips:
- Outdoor images work best
- Street views are ideal
- Landmarks help accuracy
- Clear, well-lit photos

### 🎯 Features:
- 3D interactive globe
- Flies to predicted location
- Pulsing marker animation
- Auto-rotating globe
"""

_FOOTER_MD = """
---

### 🎮 How It Works:

1. **Upload** any image with visible location clues
2. **AI analyzes** architecture, vegetation, signs, landscape
3. **Globe flies** to the predicted location in 3D
4. **Explore** the area with interactive controls

### 🔬 Technology:
- **Vision Model:** Qwen2-VL-2B-Instruct
- **Training:** Reinforcement learning on 5M geotagged images
- **Visualization:** Three.js Globe.GL
- **Dataset:** OSV5M (OpenStreetView 5M)

### 🚀 Use Cases:
- **OSINT Research** - Verify photo locations
- **Education** - Learn world geography
- **Travel** - Discover new places
- **Training** - Practice geolocation skills

---

Built with ❤️ by AceXRoux | [GitHub](https://github.com/sdan/vlm-gym) | [LinkedIn](https://linkedin.com/in/your-profile)
"""

# Page layout: upload controls on the left, result tabs on the right, followed
# by a static footer. The button feeds the uploaded image to predict_location.
with gr.Blocks(title="GeoVLM - 3D Globe", theme=gr.themes.Soft(), css=_PAGE_CSS) as demo:

    gr.Markdown(_HEADER_MD)

    with gr.Row():
        # Left column: image upload, trigger button and usage tips.
        with gr.Column(scale=1):
            image_input = gr.Image(
                type="pil",
                label="📸 Upload Image",
                height=400,
            )

            predict_btn = gr.Button(
                "🔍 Analyze & Locate",
                variant="primary",
                size="lg",
            )

            gr.Markdown(_TIPS_MD)

        # Right column: globe view and textual details in separate tabs.
        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.Tab("🌐 3D Globe"):
                    globe_output = gr.HTML(
                        label="Interactive Globe",
                        elem_classes=["globe-container"],
                    )

                with gr.Tab("📊 Details"):
                    info_output = gr.HTML(label="Location Info")
                    output_text = gr.Markdown(label="Analysis")

    gr.Markdown(_FOOTER_MD)

    # Output order matches predict_location's return tuple:
    # (markdown report, globe HTML, info card HTML).
    predict_btn.click(
        fn=predict_location,
        inputs=image_input,
        outputs=[output_text, globe_output, info_output],
    )
|
|
|
|
|
# Script entry point: load the model up front, then serve the UI.
if __name__ == "__main__":
    print("🌍 Starting GeoVLM with 3D Globe...")
    load_model()

    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
    }
    demo.launch(**launch_options)