GeoAgent / app.py
ghost233lism's picture
Update app.py
2b6c61e verified
"""
GeoAgent - Geolocation Battle Gradio Space
Players race an AI to geolocate images; the AI's answers are pre-computed and stored with each question.
"""
import json
import math
import random
import re
import warnings
from pathlib import Path
import folium
import gradio as gr
from branca.element import MacroElement, Template
from geopy.distance import geodesic
from gradio_folium import Folium
from PIL import Image
# Path configuration: all game data lives under <app dir>/data.
BASE_DIR = Path(__file__).parent
DATA_DIR = BASE_DIR / "data"
QUESTIONS_PATH = DATA_DIR / "questions.json"  # question bank read by load_questions()
IMAGES_DIR = DATA_DIR / "images"  # question images; also where the placeholder image is generated
# GeoScore parameters used by geo_score(): score = GEO_SCORE_MAX * exp(-10 * distance_km / sigma).
# Full score is 5000; sigma (km) controls the decay and differs by mode.
GEO_SCORE_MAX = 5000
GEO_SCORE_SIGMA_CHINA = 5000 # km
GEO_SCORE_SIGMA_WORLD = 18050 # km
UI_THEME = gr.themes.Soft()
# Markdown shown in the AI analysis panel until the player submits the round.
ANALYSIS_PLACEHOLDER = "### AI Analysis\nThe explanation for this round will be shown after submitting an answer."
# Custom CSS for the Gradio Blocks (presumably passed as css= in create_ui — not visible here).
# - Hides #coord_input, the textbox that SingleClickMarker's click script writes "lat,lng" into.
# - Lays out the player / AI columns side by side without wrapping.
# - Pins both maps (#player_map, #ai_map) and their iframes to 480px tall.
# - Styles the scrollable AI analysis panel (#ai_analysis_box).
UI_CSS = """
#coord_input { display: none !important; }
#control_row, #battle_row, #round_section {
width: 100% !important;
max-width: none !important;
margin-left: 0 !important;
margin-right: 0 !important;
}
#battle_row {
align-items: stretch !important;
flex-wrap: nowrap !important;
gap: clamp(8px, 1.2vw, 18px) !important;
}
#battle_row > .gr-column {
align-self: stretch !important;
min-width: 0 !important;
}
#player_col, #ai_col {
flex: 1.05 1 0 !important;
min-width: 0 !important;
}
#player_image, #ai_image {
overflow: hidden !important;
cursor: grab;
}
#player_image img, #ai_image img {
object-fit: contain !important;
transform-origin: center center;
transition: transform 0.05s ease-out;
-webkit-user-drag: none;
user-drag: none;
user-select: none;
}
#player_map, #ai_map { margin-top: 0 !important; margin-bottom: 0 !important; line-height: 0 !important; height: 480px !important; overflow: hidden !important; }
#player_map > div, #ai_map > div { margin-top: 0 !important; margin-bottom: 0 !important; padding-top: 0 !important; padding-bottom: 0 !important; height: 480px !important; min-height: 0 !important; }
#player_map iframe, #ai_map iframe { display: block !important; vertical-align: top !important; height: 480px !important; }
#player_image, #ai_image, #player_map, #ai_map {
width: 100% !important;
max-width: none !important;
min-width: 0 !important;
margin-left: 0 !important;
margin-right: 0 !important;
}
#ai_media_col {
flex: 1 1 0 !important;
max-width: none !important;
min-width: 0 !important;
}
#ai_analysis_col {
flex: 0.95 1 0 !important;
min-width: 0 !important;
max-width: 620px !important;
display: flex !important;
flex-direction: column !important;
}
#ai_analysis_box {
flex: 1 1 auto !important;
min-height: 0 !important;
max-height: none !important;
height: auto !important;
overflow: auto;
border: 1px solid #d1d5db;
border-radius: 10px;
padding: 14px;
background: #ffffff;
line-height: 1.65;
text-align: justify;
text-justify: inter-word;
}
#ai_analysis_box h3 {
margin: 0 0 10px 0 !important;
}
#ai_analysis_box h4 {
margin: 10px 0 6px 0 !important;
}
#ai_analysis_box p {
margin: 6px 0 !important;
}
#ai_analysis_box ul, #ai_analysis_box ol {
margin: 6px 0 !important;
padding-left: 20px !important;
}
#ai_analysis_box li {
margin: 3px 0 !important;
}
#ai_analysis_box hr {
margin: 10px 0 !important;
}
"""
# Script injected into the page head (presumably via head= in create_ui — TODO confirm).
# For the #player_image / #ai_image components it adds:
#   - mouse-wheel zoom centred on the cursor, 1.1x per step, clamped to [1x, 5x];
#   - left-button drag to pan;
#   - double-click to reset to 1x;
#   - suppression of the browser's native image drag.
# MutationObservers bind newly rendered <img> elements and reset the zoom whenever an
# image's src changes (i.e. when a new round's picture is shown).
# NOTE(review): zoomAtPoint() measures the cursor relative to the container's rect while
# the transform (origin '0 0') is applied to the <img>; if the img is offset inside its
# container, zoom-at-cursor drifts slightly. Zooming back out to 1x also keeps tx/ty.
# NOTE(review): the mousedown handler calls preventDefault() for any element inside the
# container, which may include the image component's own toolbar buttons — confirm.
UI_HEAD = """
<script>
(function () {
var SCALE_MIN = 1.0;
var SCALE_MAX = 5.0;
var ZOOM_FACTOR = 1.1;
var stateMap = new WeakMap();
var dragging = null;
function getImageContainer(target) {
if (!target || !target.closest) return null;
return target.closest('#player_image, #ai_image');
}
function getState(container) {
var st = stateMap.get(container);
if (!st) {
st = { scale: 1, tx: 0, ty: 0 };
stateMap.set(container, st);
}
return st;
}
function applyTransform(container) {
var img = container.querySelector('img');
if (!img) return;
var st = getState(container);
img.dataset.zoomScale = String(st.scale);
img.draggable = false;
img.style.transformOrigin = '0 0';
img.style.transform = 'translate(' + st.tx + 'px, ' + st.ty + 'px) scale(' + st.scale + ')';
}
function resetImageZoom(container) {
var st = getState(container);
st.scale = 1;
st.tx = 0;
st.ty = 0;
container.style.cursor = 'grab';
applyTransform(container);
}
function zoomAtPoint(container, clientX, clientY, zoomIn) {
var st = getState(container);
var rect = container.getBoundingClientRect();
var px = clientX - rect.left;
var py = clientY - rect.top;
var beforeX = (px - st.tx) / st.scale;
var beforeY = (py - st.ty) / st.scale;
var nextScale = zoomIn ? st.scale * ZOOM_FACTOR : st.scale / ZOOM_FACTOR;
nextScale = Math.max(SCALE_MIN, Math.min(SCALE_MAX, nextScale));
st.tx = px - beforeX * nextScale;
st.ty = py - beforeY * nextScale;
st.scale = nextScale;
applyTransform(container);
}
// Wheel zoom: zoom at mouse cursor
document.addEventListener('wheel', function (e) {
var container = getImageContainer(e.target);
if (!container) return;
e.preventDefault();
zoomAtPoint(container, e.clientX, e.clientY, e.deltaY < 0);
}, { passive: false });
// Drag to pan (map style)
document.addEventListener('mousedown', function (e) {
if (e.button !== 0) return;
var container = getImageContainer(e.target);
if (!container) return;
e.preventDefault();
var st = getState(container);
dragging = {
container: container,
startX: e.clientX,
startY: e.clientY,
startTx: st.tx,
startTy: st.ty,
};
container.style.cursor = 'grabbing';
});
document.addEventListener('mousemove', function (e) {
if (!dragging) return;
var st = getState(dragging.container);
st.tx = dragging.startTx + (e.clientX - dragging.startX);
st.ty = dragging.startTy + (e.clientY - dragging.startY);
applyTransform(dragging.container);
});
document.addEventListener('mouseup', function () {
if (!dragging) return;
dragging.container.style.cursor = 'grab';
dragging = null;
});
// Double click to reset zoom to 1x
document.addEventListener('dblclick', function (e) {
var container = getImageContainer(e.target);
if (!container) return;
e.preventDefault();
resetImageZoom(container);
});
// Prevent default browser drag of image files
document.addEventListener('dragstart', function (e) {
var container = getImageContainer(e.target);
if (!container) return;
e.preventDefault();
});
function observeImage(img) {
if (img.dataset.zoomObserved === '1') return;
img.dataset.zoomObserved = '1';
var container = img.closest('#player_image, #ai_image');
if (!container) return;
resetImageZoom(container);
var obs = new MutationObserver(function (records) {
for (var i = 0; i < records.length; i++) {
if (records[i].type === 'attributes' && records[i].attributeName === 'src') {
resetImageZoom(container);
}
}
});
obs.observe(img, { attributes: true, attributeFilter: ['src'] });
}
function bindImageObservers() {
var imgs = document.querySelectorAll('#player_image img, #ai_image img');
for (var i = 0; i < imgs.length; i++) observeImage(imgs[i]);
}
bindImageObservers();
var pageObserver = new MutationObserver(bindImageObservers);
pageObserver.observe(document.body, { childList: true, subtree: true });
})();
</script>
"""
# Column headers of the per-round results table; one per column of each row built in
# build_round_table_and_stats().
ROUND_TABLE_HEADERS = ["Round", "Player Dist (km)", "Player GeoScore", "AI Dist (km)", "AI GeoScore", "Winner"]
MEDIA_WIDTH = 700  # px; presumably used for the image components in create_ui (not visible here)
MEDIA_HEIGHT = 480  # px; also the Folium map height. NOTE: 700x480 is not 4:3 (4:3 at 480px tall is 640 wide)
class SingleClickMarker(MacroElement):
    """Folium plugin that keeps a single click-placed marker on the map.

    Each click removes the previous marker, drops a new one where the user
    clicked and writes ``"lat,lng"`` into the hidden ``#coord_input`` Gradio
    textbox (looked up in the parent document when the map is rendered in an
    iframe). It then dispatches ``input`` and ``change`` events so the backend
    receives the value.

    The reported coordinate is wrapped into [-180, 180] with Leaflet's
    ``LatLng.wrap()``. When the world map is panned across the antimeridian,
    raw click longitudes can otherwise fall outside that range (e.g. 200).
    Failures to reach the parent document (e.g. cross-origin) are logged to
    the browser console instead of being swallowed silently.
    """

    def __init__(self):
        super().__init__()
        self._template = Template(
            """
{% macro script(this, kwargs) %}
    var map = {{ this._parent.get_name() }};
    var selectedMarker = null;
    map.on('click', function(e) {
        if (selectedMarker) {
            map.removeLayer(selectedMarker);
        }
        // Keep the marker where the user clicked (possibly on a repeated world copy),
        // but report coordinates wrapped into [-180, 180].
        var ll = e.latlng.wrap();
        selectedMarker = L.marker(e.latlng).addTo(map);
        selectedMarker.bindPopup(
            "Lat: " + ll.lat.toFixed(6) +
            "<br>Lng: " + ll.lng.toFixed(6) +
            "<br><small>Selected as your answer</small>"
        ).openPopup();
        // Sync to hidden Gradio input for backend submission
        var coord = ll.lat + "," + ll.lng;
        try {
            var doc = (window.parent && window.parent !== window) ? window.parent.document : document;
            var container = doc.getElementById("coord_input");
            if (container) {
                var inp = container.querySelector("textarea") || container.querySelector("input");
                if (inp) {
                    inp.value = coord;
                    inp.dispatchEvent(new Event("input", { bubbles: true }));
                    inp.dispatchEvent(new Event("change", { bubbles: true }));
                }
            }
        } catch (err) {
            // Best effort: the parent document may be inaccessible (e.g. cross-origin
            // SecurityError). Log instead of failing silently.
            if (window.console && console.warn) {
                console.warn("SingleClickMarker: could not sync coordinates", err);
            }
        }
    });
{% endmacro %}
"""
        )
def load_questions(mode: str | None = None) -> list[dict]:
    """Read the question bank from ``QUESTIONS_PATH``.

    Args:
        mode: When truthy and not ``"all"``, keep only questions whose
            ``"mode"`` field equals it; otherwise return every question.

    Returns:
        The (optionally filtered) list of question dicts, or ``[]`` when the
        questions file does not exist.
    """
    if not QUESTIONS_PATH.exists():
        return []
    questions = json.loads(QUESTIONS_PATH.read_text(encoding="utf-8"))
    keep_everything = not mode or mode == "all"
    if keep_everything:
        return questions
    return [item for item in questions if item.get("mode") == mode]
def sample_questions(mode: str, num_rounds: int) -> list[dict]:
    """Draw up to ``num_rounds`` distinct random questions for ``mode``.

    The count is capped at the size of the matching pool; ``[]`` is returned
    when no question matches.
    """
    candidates = load_questions(mode)
    if candidates:
        return random.sample(candidates, min(num_rounds, len(candidates)))
    return []
def get_image_path(rel_path: str) -> Path:
    """Resolve a question image (file name or sub-path) against ``IMAGES_DIR``.

    NOTE(review): ``rel_path`` is not sanitized (".." or absolute paths escape
    IMAGES_DIR); it comes from questions.json, presumably trusted data.
    """
    return Path(IMAGES_DIR, rel_path)
def _load_placeholder_font(size: int):
    """Return a TrueType font of ``size`` px (Arial, then DejaVu Sans), or None if neither loads."""
    from PIL import ImageFont

    for font_name in ("arial.ttf", "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf"):
        try:
            return ImageFont.truetype(font_name, size)
        except OSError:
            continue
    return None


def _create_placeholder_image() -> Path:
    """Create (once) and return the placeholder shown for missing question images.

    The image is cached at ``IMAGES_DIR / "_placeholder.png"``; an existing file
    is returned unchanged. The hint text is split over two centered lines, and
    the TrueType font is shrunk (24px down to 10px) until the text fits. A single
    24px line is wider than the 400px canvas with common fonts, so it used to be
    clipped.

    Returns:
        Path of the placeholder PNG.
    """
    from PIL import ImageDraw, ImageFont

    placeholder_path = IMAGES_DIR / "_placeholder.png"
    if placeholder_path.exists():
        return placeholder_path
    IMAGES_DIR.mkdir(parents=True, exist_ok=True)

    width, height, margin = 400, 300, 16
    img = Image.new("RGB", (width, height), color=(240, 240, 240))
    draw = ImageDraw.Draw(img)
    text = "Image Placeholder\n(Add images at data/images/)"

    size = 24
    font = _load_placeholder_font(size)
    # Shrink until the widest line fits inside the margins.
    while font is not None and size > 10:
        left, _, right, _ = draw.multiline_textbbox((0, 0), text, font=font, align="center")
        if right - left <= width - 2 * margin:
            break
        size -= 2
        font = _load_placeholder_font(size)
    if font is None:
        font = ImageFont.load_default()

    try:
        left, top, right, bottom = draw.multiline_textbbox((0, 0), text, font=font, align="center")
    except ValueError:
        # Very old Pillow versions cannot measure bitmap fonts; fall back to a fixed box.
        left, top, right, bottom = 0, 0, width - 2 * margin, 0
    x = (width - (right - left)) / 2 - left
    y = (height - (bottom - top)) / 2 - top
    draw.multiline_text((x, y), text, fill=(100, 100, 100), font=font, align="center")
    img.save(placeholder_path)
    return placeholder_path
def get_image_or_placeholder(rel_path: str) -> str:
    """Return the question image path as a string.

    Falls back to the generated placeholder image when the file is missing.
    """
    candidate = get_image_path(rel_path)
    chosen = candidate if candidate.exists() else _create_placeholder_image()
    return str(chosen)
def geodesic_km(lat1: float, lng1: float, lat2: float, lng2: float) -> float:
    """Return the geodesic distance in kilometres between two (lat, lng) points.

    Uses geopy's ``geodesic`` with its default ellipsoid.
    """
    origin = (lat1, lng1)
    destination = (lat2, lng2)
    return geodesic(origin, destination).kilometers
def geo_score(distance_km: float, mode: str = "world") -> float:
    """Convert a distance into a GeoScore in [0, GEO_SCORE_MAX].

    score = GEO_SCORE_MAX * exp(-10 * distance_km / sigma), where sigma is
    GEO_SCORE_SIGMA_CHINA for mode "china" and GEO_SCORE_SIGMA_WORLD otherwise.

    Negative, infinite and NaN distances score 0.0. NaN used to propagate into
    the score, so a later ``int(score)`` raised ValueError.
    """
    if not math.isfinite(distance_km) or distance_km < 0:
        return 0.0
    sigma = GEO_SCORE_SIGMA_CHINA if mode == "china" else GEO_SCORE_SIGMA_WORLD
    return GEO_SCORE_MAX * math.exp(-10 * (distance_km / sigma))
def resolve_num_rounds(num_rounds_choice, custom_rounds) -> int:
    """Interpret the "number of rounds" setting.

    Args:
        num_rounds_choice: ``"Custom"`` to use ``custom_rounds``, otherwise a
            value accepted by ``int()`` (e.g. ``"5"`` or ``5``).
        custom_rounds: Round count used when the choice is ``"Custom"``.

    Returns:
        The round count, at least 1. Unparseable values fall back to 5,
        including NaN and infinities. ``int(inf)`` raises OverflowError, which
        previously escaped and crashed the handler.
    """
    raw = custom_rounds if str(num_rounds_choice) == "Custom" else num_rounds_choice
    try:
        val = int(raw)
    except (TypeError, ValueError, OverflowError):
        val = 5
    return max(1, val)
def resolve_time_limit_seconds(time_limit_choice: str, custom_time_limit) -> int:
    """Interpret the time-limit setting as seconds per round (0 = no limit).

    Preset choices: "Off" -> 0, "30s" -> 30, "60s" -> 60, "3min" -> 180.
    Unknown choices map to 180. For "Custom", ``custom_time_limit`` is converted
    with ``int()`` and clamped to at least 1. Unparseable values fall back to
    180, including NaN and infinities (``int(inf)`` raises OverflowError).
    """
    mapping = {"Off": 0, "30s": 30, "60s": 60, "3min": 180}
    if time_limit_choice != "Custom":
        return mapping.get(time_limit_choice, 180)
    try:
        val = int(custom_time_limit)
    except (TypeError, ValueError, OverflowError):
        val = 180
    return max(1, val)
def build_round_table_and_stats(round_results: list[dict]) -> tuple[list[list], str]:
    """Turn per-round results into table rows plus an averages summary line.

    Each result dict needs "round", "player_distance", "player_score",
    "ai_distance" and "ai_score". A row is
    [round, player km (1 dp), int player score, AI km (1 dp), int AI score, winner],
    where the winner is "Player", "AI" or "Draw" on equal scores.

    Returns:
        (rows, summary). With no results: ([], "Average distance & GeoScore: No data yet.").
    """
    if not round_results:
        return [], "Average distance & GeoScore: No data yet."

    def _winner(player_pts: float, ai_pts: float) -> str:
        if player_pts > ai_pts:
            return "Player"
        if ai_pts > player_pts:
            return "AI"
        return "Draw"

    rows: list[list] = [
        [
            entry["round"],
            f"{entry['player_distance']:.1f}",
            int(entry["player_score"]),
            f"{entry['ai_distance']:.1f}",
            int(entry["ai_score"]),
            _winner(entry["player_score"], entry["ai_score"]),
        ]
        for entry in round_results
    ]

    count = len(round_results)
    player_km_avg = sum((e["player_distance"] for e in round_results), 0.0) / count
    ai_km_avg = sum((e["ai_distance"] for e in round_results), 0.0) / count
    player_pts_avg = sum((e["player_score"] for e in round_results), 0.0) / count
    ai_pts_avg = sum((e["ai_score"] for e in round_results), 0.0) / count
    summary = (
        f"Player avg. distance: {player_km_avg:.1f} km | Player GeoScore: {int(player_pts_avg)} | "
        f"AI avg. distance: {ai_km_avg:.1f} km | AI GeoScore: {int(ai_pts_avg)}"
    )
    return rows, summary
def _extract_markdown_json_block(text: str) -> str:
"""Extract JSON text from ```json ... ``` or ``` ... ``` (robust for stored strings)."""
t = text.strip()
# 1) Regex: 匹配 ```json 或 ``` 与结尾 ```
m = re.search(r"```(?:json)?\s*([\s\S]*?)\s*```", t, flags=re.IGNORECASE)
if m:
return m.group(1).strip()
# 2) 手动剥离:以 ``` 开头时去掉首行并截到最后一个 ```
if t.startswith("```"):
first_line_end = t.find("\n")
if first_line_end >= 0:
t = t[first_line_end + 1 :]
end = t.rfind("```")
if end >= 0:
t = t[:end]
return t.strip()
return t
def _extract_balanced_json_object(text: str) -> str | None:
"""Extract first balanced-bracket JSON from a string."""
start = text.find("{")
if start < 0:
return None
depth = 0
for idx in range(start, len(text)):
ch = text[idx]
if ch == "{":
depth += 1
elif ch == "}":
depth -= 1
if depth == 0:
return text[start : idx + 1]
return None
def _normalize_section(section: dict | None) -> dict:
"""Normalize section, fill missing values."""
section = section or {}
clues = section.get("Clues", [])
if isinstance(clues, str):
clues = [clues]
if not isinstance(clues, list):
clues = []
return {
"Clues": [str(c) for c in clues],
"Reasoning": str(section.get("Reasoning", "")),
"Conclusion": str(section.get("Conclusion", "")),
"Uncertainty": str(section.get("Uncertainty", "")),
}
def parse_ai_analysis_payload(raw: str | dict | None) -> tuple[dict, str]:
    """Parse an AI chain-of-thought payload into normalized sections.

    Strategies, tried in order; the first that yields a JSON *object* wins:
      1. ``raw`` is already a dict -> used as-is ("dict-direct").
      2. json.loads on the contents of a ```json fence, or on the stripped
         text when there is no fence ("markdown-json").
      3. json.loads on the whole stripped text ("json-direct").
      4. json.loads on the first balanced {...} object of the fence contents,
         then of the whole text ("balanced-object-json").
      5. Regex extraction of the Conclusion/Reasoning/Uncertainty strings and
         FinalAnswer ("fallback-regex").

    JSON that decodes to a non-object (list, string, number) counts as a
    failed attempt, so later strategies still run. Previously it
    short-circuited and yielded an empty analysis. Only JSON decoding errors
    are caught.

    Returns:
        (parsed, method). ``parsed`` has keys "country", "regional", "precise"
        (each shaped by _normalize_section) and "final_answer". When the
        payload has no FinalAnswer, "final_answer" is the non-empty section
        conclusions joined with "; ".
    """

    def _loads_object(snippet: str | None) -> dict | None:
        # Decoded JSON object, or None when the snippet is empty, invalid or not an object.
        if not snippet:
            return None
        try:
            decoded = json.loads(snippet)
        except (ValueError, RecursionError):
            return None
        return decoded if isinstance(decoded, dict) else None

    if isinstance(raw, dict):
        data = raw
        method = "dict-direct"
    else:
        text = str(raw or "").strip()
        data = None
        method = "fallback-regex"
        if text:
            # Strip ```json ... ``` fences first; stored payloads are often wrapped in them.
            candidate = _extract_markdown_json_block(text)
            attempts = (
                ("markdown-json", candidate),
                ("json-direct", text),
                # Balanced-brace extraction guards against trailing junk after the object.
                ("balanced-object-json", _extract_balanced_json_object(candidate)),
                ("balanced-object-json", _extract_balanced_json_object(text)),
            )
            for label, snippet in attempts:
                parsed_obj = _loads_object(snippet)
                if parsed_obj is not None:
                    data, method = parsed_obj, label
                    break
        if data is None:
            # Fallback: regex extract partial info
            def _rx(pattern: str) -> str:
                m = re.search(pattern, text, flags=re.IGNORECASE | re.DOTALL)
                return (m.group(1).strip() if m else "")

            data = {
                "ChainOfThought": {
                    "CountryIdentification": {
                        "Conclusion": _rx(r"CountryIdentification[\s\S]*?Conclusion\"\s*:\s*\"([^\"]+)"),
                        "Reasoning": _rx(r"CountryIdentification[\s\S]*?Reasoning\"\s*:\s*\"([^\"]+)"),
                        "Uncertainty": _rx(r"CountryIdentification[\s\S]*?Uncertainty\"\s*:\s*\"([^\"]+)"),
                        "Clues": [],
                    },
                    "RegionalGuess": {
                        "Conclusion": _rx(r"RegionalGuess[\s\S]*?Conclusion\"\s*:\s*\"([^\"]+)"),
                        "Reasoning": _rx(r"RegionalGuess[\s\S]*?Reasoning\"\s*:\s*\"([^\"]+)"),
                        "Uncertainty": _rx(r"RegionalGuess[\s\S]*?Uncertainty\"\s*:\s*\"([^\"]+)"),
                        "Clues": [],
                    },
                    "PreciseLocalization": {
                        "Conclusion": _rx(r"PreciseLocalization[\s\S]*?Conclusion\"\s*:\s*\"([^\"]+)"),
                        "Reasoning": _rx(r"PreciseLocalization[\s\S]*?Reasoning\"\s*:\s*\"([^\"]+)"),
                        "Uncertainty": _rx(r"PreciseLocalization[\s\S]*?Uncertainty\"\s*:\s*\"([^\"]+)"),
                        "Clues": [],
                    },
                },
                "FinalAnswer": _rx(r"FinalAnswer\"\s*:\s*\"([^\"]+)"),
            }

    cot = data.get("ChainOfThought", {})
    if not isinstance(cot, dict):
        cot = {}

    def _section(key: str) -> dict:
        # Only pass real dicts on; malformed sections are normalized as empty.
        value = cot.get(key)
        return _normalize_section(value if isinstance(value, dict) else None)

    country = _section("CountryIdentification")
    regional = _section("RegionalGuess")
    precise = _section("PreciseLocalization")
    final_answer = str(data.get("FinalAnswer", "") or "")
    if not final_answer:
        parts = [country["Conclusion"], regional["Conclusion"], precise["Conclusion"]]
        final_answer = "; ".join([p for p in parts if p])
    return {
        "country": country,
        "regional": regional,
        "precise": precise,
        "final_answer": final_answer,
    }, method
def _render_section_md(title: str, section: dict) -> str:
clues = ", ".join(section["Clues"]) if section["Clues"] else "None"
reasoning = section["Reasoning"] or "None"
conclusion = section["Conclusion"] or "None"
uncertainty = section["Uncertainty"] or "Unknown"
return (
f"#### {title}\n"
f"**Clues**: {clues}\n\n"
f"**Reasoning**: {reasoning}\n\n"
f"**Conclusion**: **{conclusion}**\n\n"
f"**Uncertainty**: {uncertainty}\n"
)
def build_ai_analysis_markdown(question: dict) -> str:
    """Render a question's stored AI reasoning as Markdown for the analysis panel.

    The payload is taken from the first key with a truthy value, searched in
    order: ai_analysis_json, ai_analysis, analysis_json, analysis, cot_json,
    cot, llm_output, ai_output. It is parsed by parse_ai_analysis_payload()
    and rendered as three sections separated by rules, followed by the final
    answer. A "No analysis data." message is returned when no key has a value.
    """
    payload_keys = (
        "ai_analysis_json",
        "ai_analysis",
        "analysis_json",
        "analysis",
        "cot_json",
        "cot",
        "llm_output",
        "ai_output",
    )
    raw = next(
        (question.get(k) for k in payload_keys if k in question and question.get(k)),
        None,
    )
    if raw is None:
        return "### AI Analysis\n\nNo analysis data."

    parsed, _method = parse_ai_analysis_payload(raw)
    sections = (
        ("Country Judgement", parsed["country"]),
        ("Region Judgement", parsed["regional"]),
        ("Precise Location", parsed["precise"]),
    )
    body = "\n---\n".join(_render_section_md(heading, sec) for heading, sec in sections)
    final_text = parsed["final_answer"] or "None"
    return "### AI Analysis\n" + body + f"\n### Final Answer\n**{final_text}**\n"
def _clamp_lat(lat: float) -> float:
"""Clamp latitude to Folium displayable."""
return max(min(lat, 85.0), -85.0)
def _normalize_lng(lng: float) -> float:
"""Normalize longitude to [-180, 180]."""
return ((lng + 180.0) % 360.0) - 180.0
def _expand_answer_point_for_visual_distance(
    mode: str,
    true_lat: float,
    true_lng: float,
    answer_lat: float,
    answer_lng: float,
) -> tuple[tuple[float, float], tuple[float, float], bool]:
    """Push a too-close answer marker away from the true location, for display only.

    The true location is never moved, so player and AI maps share an identical
    true point. Distance is estimated with a local flat approximation: 111 km
    per degree of latitude, scaled by cos(mid-latitude) for longitude.

    The minimum separation is base_min_km (120 km in "china" mode, 260 km
    otherwise). If the answer is closer than that, it is moved along the
    true->answer direction so the markers end up base_min_km apart, or
    1.2 * base_min_km when they started closer than 0.4 * base_min_km.
    Coincident points are pushed due east.

    The longitude difference is wrapped into [-180, 180) first, so points on
    either side of the antimeridian are treated as close rather than ~360
    degrees apart.

    Returns:
        ((true_lat, true_lng), (display_answer_lat, display_answer_lng), moved)
    """
    mid_lat = (true_lat + answer_lat) / 2
    km_per_lat = 111.0
    km_per_lng = max(111.0 * math.cos(math.radians(mid_lat)), 1e-6)
    dlng_deg = _normalize_lng(answer_lng - true_lng)
    dy_km = (answer_lat - true_lat) * km_per_lat
    dx_km = dlng_deg * km_per_lng
    dist_km = math.hypot(dx_km, dy_km)
    base_min_km = 120.0 if mode == "china" else 260.0
    target_dist_km = base_min_km * 1.2 if dist_km < base_min_km * 0.4 else base_min_km
    target_dist_km = max(dist_km, target_dist_km)
    if target_dist_km <= dist_km + 1e-6:
        return (true_lat, true_lng), (answer_lat, answer_lng), False
    if dist_km < 1e-6:
        ux, uy = 1.0, 0.0  # no direction to follow: push due east
    else:
        ux, uy = dx_km / dist_km, dy_km / dist_km
    move_km = target_dist_km - dist_km
    dlat = (uy * move_km) / km_per_lat
    dlng = (ux * move_km) / km_per_lng
    disp_answer_lat = _clamp_lat(answer_lat + dlat)
    disp_answer_lng = _normalize_lng(answer_lng + dlng)
    return (true_lat, true_lng), (disp_answer_lat, disp_answer_lng), True
def _result_zoom_by_distance(distance_km: float, mode: str) -> int:
"""Choose map zoom by true distance (visual clarity, not coords)."""
if distance_km <= 1:
return 13
if distance_km <= 5:
return 11
if distance_km <= 20:
return 10
if distance_km <= 80:
return 9
if distance_km <= 300:
return 8
if distance_km <= 1200:
return 6 if mode == "world" else 7
return 4 if mode == "world" else 5
def _create_folium_map(
    mode: str,
    center: tuple[float, float] | None = None,
    zoom: int | None = None,
    markers: list[dict] | None = None,
    lines: list[dict] | None = None,
    add_click_for_marker: bool = False,
) -> folium.Map:
    """Build a Folium map of height MEDIA_HEIGHT for the given mode.

    Args:
        mode: "china" defaults to center (35, 105) at zoom 4; anything else
            defaults to (20, 0) at zoom 2.
        center, zoom: Overrides for the defaults (a falsy center uses the default).
        markers: Dicts with "lat", "lng" and optional "color" (fill, default "red")
            and "label" (popup text), drawn as white-outlined circle markers.
        lines: Dicts with "locations" and optional "color", "weight",
            "opacity", "dash_array" and "tooltip", drawn as polylines.
        add_click_for_marker: Attach SingleClickMarker so clicks pick an answer.
    """
    default_center, default_zoom = ((35, 105), 4) if mode == "china" else ((20, 0), 2)
    location = center or default_center
    zoom_start = default_zoom if zoom is None else zoom
    # Do NOT use fit_bounds here: in gradio-folium it sometimes breaks click
    # handling (esp. for "china").
    fmap = folium.Map(location=location, zoom_start=zoom_start, height=f"{MEDIA_HEIGHT}px")
    if add_click_for_marker:
        fmap.add_child(SingleClickMarker())
    for spec in markers or []:
        folium.CircleMarker(
            location=[spec["lat"], spec["lng"]],
            radius=8,
            color="white",
            fill=True,
            fill_color=spec.get("color", "red"),
            fill_opacity=1,
            popup=spec.get("label", ""),
        ).add_to(fmap)
    for spec in lines or []:
        folium.PolyLine(
            locations=spec.get("locations", []),
            color=spec.get("color", "#4f46e5"),
            weight=spec.get("weight", 2),
            opacity=spec.get("opacity", 0.9),
            dash_array=spec.get("dash_array", "8, 8"),
            tooltip=spec.get("tooltip", "Connection"),
        ).add_to(fmap)
    return fmap
def build_empty_map(mode: str, clickable: bool = True, uid: str = "map") -> folium.Map:
    """Return a marker-free map for ``mode``.

    With ``clickable`` the player can click to pick an answer. ``uid`` is
    accepted for call-site compatibility but unused.
    """
    return _create_folium_map(mode, add_click_for_marker=clickable, markers=[])
def _build_comparison_map(
    mode: str,
    true_lat: float,
    true_lng: float,
    guess_lat: float,
    guess_lng: float,
    guess_label: str,
    tooltip: str,
) -> folium.Map:
    """Draw the true location (green) and a guess (red) joined by a dashed line.

    The zoom level comes from the geodesic distance. If the two longitudes
    are more than 180 degrees apart, the guess is drawn on the world copy
    nearest the true location. The line and map center then take the short
    way across the antimeridian (179 vs. -179 is 2 degrees apart, not 358);
    a plain average of the longitudes centered such maps on the wrong side
    of the globe.
    """
    dist_km = geodesic_km(true_lat, true_lng, guess_lat, guess_lng)
    map_zoom = _result_zoom_by_distance(dist_km, mode)
    delta_lng = guess_lng - true_lng
    disp_guess_lng = true_lng + _normalize_lng(delta_lng) if abs(delta_lng) > 180.0 else guess_lng
    markers = [
        {"lat": true_lat, "lng": true_lng, "color": "green", "label": "True Location"},
        {"lat": guess_lat, "lng": disp_guess_lng, "color": "red", "label": guess_label},
    ]
    lines = [
        {
            "locations": [[true_lat, true_lng], [guess_lat, disp_guess_lng]],
            "tooltip": tooltip,
        }
    ]
    center = ((true_lat + guess_lat) / 2, (true_lng + disp_guess_lng) / 2)
    return _create_folium_map(mode=mode, center=center, zoom=map_zoom, markers=markers, lines=lines)


def build_result_map(
    mode: str,
    true_lat: float,
    true_lng: float,
    user_lat: float,
    user_lng: float,
    uid: str = "result",
) -> folium.Map:
    """Show the true location vs. the player's answer. ``uid`` is unused (kept for compatibility)."""
    return _build_comparison_map(
        mode, true_lat, true_lng, user_lat, user_lng, "Your Answer", "True vs. Player"
    )


def build_ai_map(
    mode: str,
    true_lat: float,
    true_lng: float,
    ai_lat: float,
    ai_lng: float,
    uid: str = "ai",
) -> folium.Map:
    """Show the true location vs. the AI's answer. ``uid`` is unused (kept for compatibility)."""
    return _build_comparison_map(
        mode, true_lat, true_lng, ai_lat, ai_lng, "AI Answer", "True vs. AI"
    )
def _coord_from_lat_lng(lat, lng) -> str:
"""Make coordinate string from lat/lng."""
if lat is not None and lng is not None:
return f"{lat},{lng}"
return ""
def on_lat_lng_change(lat, lng, state: dict) -> tuple:
    """Preview a manually entered coordinate and record it as the pending answer.

    Returns (map update, coordinate-text update, state):
      - if lat or lng is None: no map change, the coordinate text is cleared,
        and state is unchanged;
      - if either value is not numeric: no changes;
      - otherwise: a map centered on the point (zoom 8) with a red
        "To Be Confirmed" marker, the "lat,lng" text, and state with
        pending_lat/pending_lng set.
    """
    coord = _coord_from_lat_lng(lat, lng)
    if not coord:
        return gr.update(), gr.update(value=""), state
    try:
        lat_f = float(lat)
        lng_f = float(lng)
    except (TypeError, ValueError):
        return gr.update(), gr.update(), state
    preview = _create_folium_map(
        mode=state.get("mode", "world"),
        center=(lat_f, lng_f),
        zoom=8,
        markers=[{"lat": lat_f, "lng": lng_f, "color": "red", "label": "To Be Confirmed"}],
    )
    updated_state = dict(state, pending_lat=lat_f, pending_lng=lng_f)
    return gr.update(value=preview), gr.update(value=coord), updated_state
def on_coord_change(coord_str: str, state: dict) -> tuple:
    """Record the "lat,lng" written by the map click script as the pending answer.

    Returns (state, message update):
      - no active game: state unchanged plus a "Please click 'Start Game' first" message;
      - unparseable, non-finite, or latitude outside [-90, 90]: state unchanged.
        geopy rejects such points, so storing them would crash the later submit;
      - otherwise: state with pending_lat/pending_lng set.
    """
    if not state.get("questions"):
        return state, gr.update(value="Please click 'Start Game' first")
    if not coord_str or "," not in coord_str:
        return state, gr.update()
    try:
        lat_str, lng_str = coord_str.strip().split(",", 1)
        lat_f = float(lat_str)
        lng_f = float(lng_str)
    except (ValueError, TypeError):
        return state, gr.update()
    # float() happily accepts "nan"/"inf"; reject them and impossible latitudes.
    if not (math.isfinite(lat_f) and math.isfinite(lng_f) and -90.0 <= lat_f <= 90.0):
        return state, gr.update()
    return {**state, "pending_lat": lat_f, "pending_lng": lng_f}, gr.update()
def start_game(
    num_rounds,
    custom_rounds: float | None,
    mode: str,
    time_limit: str,
    custom_time_limit: float | None,
    state: dict,
) -> tuple:
    """Start a new game: sample questions and reset maps, scores and the round table.

    Args:
        num_rounds: Rounds setting ("Custom" means use ``custom_rounds``);
            resolved by resolve_num_rounds().
        custom_rounds: Custom round count.
        mode: "China Mode" selects the "china" question pool; anything else
            selects "world".
        time_limit: Time-limit choice, resolved by resolve_time_limit_seconds().
        custom_time_limit: Custom time limit in seconds.
        state: Previous session state; its keys are carried over and the game
            keys set below are overwritten.

    Returns:
        A 16-tuple in this order: round text, player map, AI map, question
        image, two values cleared to None, three cleared to "", the AI
        analysis placeholder, two visibility updates (presumably the submit
        and next-round buttons), timer activation, round table rows, stats
        text, and the new state. The component targets come from the event
        wiring in create_ui (not visible here).
    """
    mode_key = "china" if mode == "China Mode" else "world"
    num_rounds_val = resolve_num_rounds(num_rounds, custom_rounds)
    questions = sample_questions(mode_key, num_rounds_val)
    limit_seconds = resolve_time_limit_seconds(time_limit, custom_time_limit)
    if not questions:
        # Nothing to play: show an error, non-clickable maps, and an empty game state.
        return (
            gr.update(value="Question bank is empty or no matching problems found. Please check data/questions.json"),
            gr.update(value=build_empty_map(mode_key, clickable=False)),
            gr.update(value=build_empty_map(mode_key, clickable=False)),
            gr.update(value=None),
            gr.update(value=None),
            gr.update(value=None),
            gr.update(value=""),
            gr.update(value=""),
            gr.update(value=""),
            gr.update(value=ANALYSIS_PLACEHOLDER),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(active=False),
            gr.update(value=[]),
            gr.update(value="Average distance & GeoScore: No data yet."),
            {
                **state,
                "questions": [],
                "current": 0,
                "mode": mode_key,
                "time_limit": time_limit,
                "time_limit_seconds": limit_seconds,
                "total_player": 0,
                "total_ai": 0,
                "round_results": [],
            },
        )
    q = questions[0]
    img_display = get_image_or_placeholder(q["image_path"])
    # A limit of 0 ("Off") disables the countdown timer.
    timer_active = limit_seconds > 0
    time_left = limit_seconds
    round_text = f"Round 1 / {len(questions)}"
    if timer_active:
        round_text += f" | Time left: {time_left} s"
    return (
        gr.update(value=round_text),
        gr.update(value=build_empty_map(mode_key)),  # player map: clickable for answering
        gr.update(value=build_empty_map(mode_key, clickable=False)),  # AI map: display only
        gr.update(value=img_display),
        gr.update(value=None),
        gr.update(value=None),
        gr.update(value=""),
        gr.update(value=""),
        gr.update(value=""),
        gr.update(value=ANALYSIS_PLACEHOLDER),
        gr.update(visible=True),
        gr.update(visible=False),
        gr.update(active=timer_active),
        gr.update(value=[]),
        gr.update(value="Average distance & GeoScore: No data yet."),
        {
            **state,
            "questions": questions,
            "current": 0,
            "mode": mode_key,
            "time_limit": time_limit,
            "time_limit_seconds": limit_seconds,
            "time_left": time_left,
            "timer_active": timer_active,
            "pending_lat": None,
            "pending_lng": None,
            "total_player": 0,
            "total_ai": 0,
            "round_results": [],
        },
    )
def timer_tick(state: dict) -> tuple:
    """Advance the round countdown by one tick; auto-submit when it reaches 0.

    Each call decrements ``state["time_left"]`` by 1 (presumably once per
    second via a gr.Timer in create_ui — TODO confirm the interval).

    Returns:
        A 13-tuple: new state, round-info text, ten further updates, and the
        timer activation update.

    On timeout, the pending map pick (or "0,0" when there is none) is
    submitted through submit_answer(), and its 16 outputs are re-ordered into
    this layout.
    NOTE(review): submit_answer()'s AI analysis (index 9), round table (13)
    and stats (14) updates are dropped here. Those panels are therefore not
    refreshed for a round that timed out; fixing this needs a wider outputs
    list in create_ui.
    """
    if not state.get("timer_active") or not state.get("questions"):
        # Timer off or no game running: change nothing and stop the timer.
        return (
            state,
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(active=False),
        )
    time_left = state.get("time_left", 0) - 1
    new_state = {**state, "time_left": time_left}
    if time_left <= 0:
        # Disable the timer in state so that a late tick returns early above
        # instead of auto-submitting again.
        new_state["timer_active"] = False
        coord_str = (
            f"{state.get('pending_lat', 0)},{state.get('pending_lng', 0)}"
            if state.get("pending_lat") is not None
            else "0,0"
        )
        res = submit_answer(coord_str, None, None, new_state)
        # submit_answer's output order differs from this handler's; re-order it:
        # state first, then its updates 0-8 and the two visibility updates (10, 11).
        return (
            res[15], res[0], res[1], res[2], res[3], res[4], res[5],
            res[6], res[7], res[8], res[10], res[11],
            gr.update(active=False),
        )
    # Only update timer display
    questions = state["questions"]
    num_total = len(questions)
    current = state["current"]
    round_info = f"Round {current + 1} / {num_total} | Time left: {time_left} s"
    return (
        new_state,
        gr.update(value=round_info),
        gr.update(),
        gr.update(),
        gr.update(),
        gr.update(),
        gr.update(),
        gr.update(),
        gr.update(),
        gr.update(),
        gr.update(),
        gr.update(),
        gr.update(active=True),
    )
def _is_valid_lat_lng(lat, lng) -> bool:
    """Return True when lat and lng are finite numbers and lat lies in [-90, 90].

    geopy rejects non-finite coordinates and out-of-range latitudes, so such
    values must not reach geodesic_km().
    """
    try:
        return math.isfinite(lat) and math.isfinite(lng) and -90.0 <= lat <= 90.0
    except TypeError:
        return False


def _parse_answer_coord(coord_str: str | None) -> tuple[float, float] | None:
    """Parse "lat,lng" into floats; return None if malformed or not a valid point."""
    if not coord_str or "," not in coord_str:
        return None
    try:
        parts = coord_str.strip().split(",")
        lat, lng = float(parts[0]), float(parts[1])
    except (ValueError, IndexError):
        return None
    return (lat, lng) if _is_valid_lat_lng(lat, lng) else None


def submit_answer(
    coord_str: str,
    lat_val: float | None,
    lng_val: float | None,
    state: dict,
) -> tuple:
    """Score the current round for player and AI and build the result view.

    Args:
        coord_str: "lat,lng" from the hidden coordinate box written by the map
            click script. When missing or invalid, the pending_lat/pending_lng
            from state are used instead.
        lat_val, lng_val: Unused; accepted so the existing event wiring (and
            timer_tick(), which passes None) keeps working.
        state: Per-session game state.

    Returns:
        A 16-tuple: round info, player map, player result text, AI map, AI
        result text, question image, two values cleared to None, one cleared
        to "", AI analysis markdown, two visibility updates (presumably submit
        and next-round buttons), timer activation, round table rows, stats
        text, and the new state (index 15, read by timer_tick()).

    Coordinates that are non-finite or have |lat| > 90 are treated as "no
    answer" instead of crashing inside geopy.
    """
    if not state.get("questions"):
        return (
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(value=None),
            gr.update(value=None),
            gr.update(value=""),
            gr.update(),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(active=False),
            gr.update(),
            gr.update(),
            state,
        )
    q = state["questions"][state["current"]]
    ai_analysis_md = build_ai_analysis_markdown(q)
    true_lat, true_lng = q["true_lat"], q["true_lng"]
    ai_lat, ai_lng = q["ai_lat"], q["ai_lng"]
    mode = state.get("mode", "world")

    # Prefer the coordinate written by the map click; fall back to the pending pick.
    answer = _parse_answer_coord(coord_str)
    if answer is None:
        pending = (state.get("pending_lat"), state.get("pending_lng"))
        if pending[0] is not None and _is_valid_lat_lng(*pending):
            answer = pending
    if answer is None:
        # No answer selected: do not submit, and do not reuse stale coordinates.
        return (
            gr.update(value="Please select a point on the player map before submitting your answer."),
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(value=None),
            gr.update(value=None),
            gr.update(value=""),
            gr.update(value=ANALYSIS_PLACEHOLDER),
            gr.update(visible=True),
            gr.update(visible=False),
            gr.update(active=state.get("timer_active", False)),
            gr.update(),
            gr.update(),
            state,
        )
    user_lat, user_lng = answer

    # Player calculation
    dist_km = geodesic_km(true_lat, true_lng, user_lat, user_lng)
    score = geo_score(dist_km, mode)
    player_result = f"Distance: {dist_km:.1f} km | GeoScore: {int(score)}"
    player_map = build_result_map(mode, true_lat, true_lng, user_lat, user_lng, uid="p1")
    # AI calculation
    ai_dist_km = geodesic_km(true_lat, true_lng, ai_lat, ai_lng)
    ai_score = geo_score(ai_dist_km, mode)
    ai_result = f"Distance: {ai_dist_km:.1f} km | GeoScore: {int(ai_score)}"
    ai_map = build_ai_map(mode, true_lat, true_lng, ai_lat, ai_lng, uid="ai")

    # Scoring
    total_player = state.get("total_player", 0) + score
    total_ai = state.get("total_ai", 0) + ai_score
    round_results = [
        *state.get("round_results", []),
        {
            "round": state["current"] + 1,
            "player_distance": dist_km,
            "player_score": score,
            "ai_distance": ai_dist_km,
            "ai_score": ai_score,
        },
    ]
    round_table_rows, round_stats_md = build_round_table_and_stats(round_results)

    next_current = state["current"] + 1
    num_total = len(state["questions"])
    if next_current >= num_total:
        # Game over: clear the question list so further submits/ticks are no-ops.
        round_info = f"Game Over! Total Score - Player: {int(total_player)} | AI: {int(total_ai)}"
        next_img = None
        next_round_visible = False
        new_state = {
            **state,
            "questions": [],
            "current": 0,
            "total_player": total_player,
            "total_ai": total_ai,
            "round_results": round_results,
        }
    else:
        # Show this round's result and wait for "Next Round"; keep the image visible.
        round_info = f"Round {state['current'] + 1} / {num_total} Result | Click 'Next Round' to continue"
        next_img = get_image_or_placeholder(q["image_path"])
        next_round_visible = True
        new_state = {
            **state,
            "current": next_current,
            "total_player": total_player,
            "total_ai": total_ai,
            "round_results": round_results,
        }
    return (
        gr.update(value=round_info),
        gr.update(value=player_map),
        gr.update(value=player_result),
        gr.update(value=ai_map),
        gr.update(value=ai_result),
        gr.update(value=next_img),
        gr.update(value=None),
        gr.update(value=None),
        gr.update(value=""),
        gr.update(value=ai_analysis_md),
        gr.update(visible=False),
        gr.update(visible=next_round_visible),
        # The countdown always stops once a round is scored; go_next_round() restarts it.
        gr.update(active=False),
        gr.update(value=round_table_rows),
        gr.update(value=round_stats_md),
        new_state,
    )
def go_next_round(state: dict) -> tuple:
    """Proceed to next round.

    Resets the per-round widgets and shows the image for ``state["current"]``.
    The returned tuple lines up one-to-one with the ``outputs`` list wired to
    ``next_round_btn.click`` in ``create_ui``, in this order:

        round_info, player_map, player_result, ai_map, ai_result, ai_analysis,
        player_image, lat_input, lng_input, coord_input, submit_btn,
        next_round_btn, game_timer, state

    Args:
        state: Per-session game state. This function reads ``questions``,
            ``current``, ``mode``, ``timer_active`` and ``time_limit_seconds``.

    Returns:
        A 14-element tuple: 13 Gradio updates followed by the new state dict.
    """
    if not state.get("questions"):
        # No game in progress: leave the displays untouched, clear the pending
        # coordinate inputs, hide both action buttons and stop the timer.
        # Slot order MUST match next_round_btn.click's outputs (ai_analysis
        # comes before player_image here, unlike submit_answer's outputs).
        return (
            gr.update(),  # round_info
            gr.update(),  # player_map
            gr.update(),  # player_result
            gr.update(),  # ai_map
            gr.update(),  # ai_result
            gr.update(),  # ai_analysis
            gr.update(),  # player_image
            gr.update(value=None),  # lat_input (gr.Number: None, never "")
            gr.update(value=None),  # lng_input (gr.Number: None, never "")
            gr.update(value=""),  # coord_input (gr.Textbox)
            gr.update(visible=False),  # submit_btn
            gr.update(visible=False),  # next_round_btn
            gr.update(active=False),  # game_timer
            state,
        )
    current = state["current"]
    questions = state["questions"]
    num_total = len(questions)
    mode = state.get("mode", "world")
    if current >= num_total:
        # All rounds consumed: show blank, non-clickable maps, reset the
        # analysis panel, hide the buttons and clear the question list.
        return (
            gr.update(value="Game Over"),  # round_info
            gr.update(value=build_empty_map(mode, clickable=False)),  # player_map
            gr.update(value=""),  # player_result
            gr.update(value=build_empty_map(mode, clickable=False)),  # ai_map
            gr.update(value=""),  # ai_result
            gr.update(value=ANALYSIS_PLACEHOLDER),  # ai_analysis
            gr.update(value=None),  # player_image
            gr.update(value=None),  # lat_input
            gr.update(value=None),  # lng_input
            gr.update(value=""),  # coord_input
            gr.update(visible=False),  # submit_btn
            gr.update(visible=False),  # next_round_btn
            gr.update(active=False),  # game_timer
            {**state, "questions": [], "current": 0},
        )
    q = questions[current]
    next_img = get_image_or_placeholder(q["image_path"])
    round_info = f"Round {current + 1} / {num_total}"
    if state.get("timer_active"):
        # Restart the countdown from the configured limit for the new round.
        time_left = int(state.get("time_limit_seconds", 0))
        round_info += f" | Time left: {time_left} s"
    else:
        time_left = 0
    return (
        gr.update(value=round_info),  # round_info
        # Only the player's map accepts clicks; the AI map stays read-only.
        gr.update(value=build_empty_map(mode, clickable=True)),  # player_map
        gr.update(value=""),  # player_result
        gr.update(value=build_empty_map(mode, clickable=False)),  # ai_map
        gr.update(value=""),  # ai_result
        gr.update(value=ANALYSIS_PLACEHOLDER),  # ai_analysis
        gr.update(value=next_img),  # player_image
        gr.update(value=None),  # lat_input
        gr.update(value=None),  # lng_input
        gr.update(value=""),  # coord_input
        gr.update(visible=True),  # submit_btn
        gr.update(visible=False),  # next_round_btn
        gr.update(active=state.get("timer_active", False)),  # game_timer
        # Drop any guess left over from the previous round.
        {**state, "pending_lat": None, "pending_lng": None, "time_left": time_left},
    )
def create_ui():
    """Build Gradio UI

    Assembles the whole app inside one ``gr.Blocks`` context:

    - a control row (rounds, mode, time limit, start button);
    - a battle row with player, AI and analysis columns;
    - a per-round results table;
    - hidden coordinate inputs.

    It then wires every event handler (start, submit, next round, timer
    tick, coordinate changes).

    Returns:
        gr.Blocks: the assembled app; it is not launched here.
    """
    # Silence Gradio's deprecation warnings about passing ``theme`` / ``css``
    # to the Blocks constructor (slated for removal in Gradio 6.0).
    warnings.filterwarnings(
        "ignore",
        message="The 'theme' parameter in the Blocks constructor will be removed in Gradio 6.0.*",
        category=DeprecationWarning,
    )
    warnings.filterwarnings(
        "ignore",
        message="The 'css' parameter in the Blocks constructor will be removed in Gradio 6.0.*",
        category=DeprecationWarning,
    )
    with gr.Blocks(
        title="GeoAgent - Geolocation Battle",
        theme=UI_THEME,
        css=UI_CSS,
        head=UI_HEAD,
    ) as demo:
        # NOTE: Gradio renders components in the order they are constructed
        # inside these context managers, so statement order defines layout.
        gr.Markdown("# GeoAgent - Geolocation Battle")
        gr.Markdown("Compete with AI in geolocation! Look at the image and guess the location by clicking on the map.")
        # Per-session game state shared by every handler below (passed in as
        # an input and returned as the last output of each handler).
        state = gr.State({
            "questions": [],
            "current": 0,
            "mode": "world",
            "time_limit": "off",
            "time_limit_seconds": 180,
            "time_left": 180,
            "timer_active": False,
            "pending_lat": None,
            "pending_lng": None,
            "total_player": 0,
            "total_ai": 0,
            "round_results": [],
        })
        # Ticks every 1 s; starts inactive and is switched on/off by the
        # handlers through gr.update(active=...).
        game_timer = gr.Timer(value=1, active=False)
        with gr.Row(elem_id="control_row"):
            num_rounds = gr.Dropdown(
                choices=["3", "5", "10", "Custom"],
                value="5",
                label="Rounds",
            )
            # Only shown when "Custom" is selected (see num_rounds.change below).
            custom_rounds = gr.Number(label="Custom Rounds", value=5, minimum=1, precision=0, visible=False)
            mode = gr.Radio(
                choices=["China Mode", "World Mode"],
                value="World Mode",
                label="Mode",
            )
            time_limit = gr.Dropdown(
                choices=["Off", "30s", "60s", "3min", "Custom"],
                value="3min",
                label="Time Limit",
            )
            # Only shown when "Custom" is selected (see time_limit.change below).
            custom_time_limit = gr.Number(label="Custom Time Limit (s)", value=180, minimum=1, precision=0, visible=False)
            start_btn = gr.Button("Start Game", variant="primary")
        # Status line: round counter, remaining time, final totals.
        round_info = gr.Markdown("Click 'Start Game' to begin")
        with gr.Row(equal_height=False, elem_id="battle_row"):
            with gr.Column(scale=3, min_width=0, elem_id="player_col"):
                gr.Markdown("### Player Area")
                player_image = gr.Image(
                    label="Current Image",
                    type="filepath",
                    interactive=False,
                    elem_id="player_image",
                    width=MEDIA_WIDTH,
                    height=MEDIA_HEIGHT,
                )
                # Starts non-clickable; handlers swap in a clickable map per round.
                player_map = Folium(
                    value=build_empty_map("world", clickable=False),
                    height=MEDIA_HEIGHT,
                    elem_id="player_map",
                )
                player_result = gr.Markdown("")
                with gr.Row():
                    # Handlers toggle visibility so only one button shows at a time.
                    submit_btn = gr.Button("Submit Answer", visible=False)
                    next_round_btn = gr.Button("Next Round", visible=False, variant="primary")
            with gr.Column(scale=3, min_width=0, elem_id="ai_col"):
                gr.Markdown("### AI Area")
                # Mirrors player_image via sync_image below.
                ai_image = gr.Image(
                    label="Current Image",
                    type="filepath",
                    interactive=False,
                    elem_id="ai_image",
                    width=MEDIA_WIDTH,
                    height=MEDIA_HEIGHT,
                )
                ai_map = Folium(
                    value=build_empty_map("world", clickable=False),
                    height=MEDIA_HEIGHT,
                    elem_id="ai_map",
                )
                ai_result = gr.Markdown("")
            with gr.Column(scale=4, min_width=0, elem_id="ai_analysis_col"):
                gr.Markdown("### Analysis Area")
                ai_analysis = gr.Markdown(
                    ANALYSIS_PLACEHOLDER,
                    elem_id="ai_analysis_box",
                )
        with gr.Column(elem_id="round_section"):
            gr.Markdown("### Results per Round")
            round_table = gr.Dataframe(
                headers=ROUND_TABLE_HEADERS,
                value=[],
                row_count=(1, "dynamic"),
                col_count=(len(ROUND_TABLE_HEADERS), "fixed"),
                interactive=False,
            )
            round_stats = gr.Markdown("Average distance & GeoScore: No data yet.")
        # Hide input widgets outside visible layout for UI neatness
        lat_input = gr.Number(label="Latitude", value=None, visible=False)
        lng_input = gr.Number(label="Longitude", value=None, visible=False)
        # Must render this to DOM for map click JS to find it
        # (visible=True so it exists in the DOM; UI_CSS hides #coord_input).
        coord_input = gr.Textbox(visible=True, elem_id="coord_input")
        # Synchronize image to AI area
        def sync_image(img):
            # Identity pass-through: copies the player image value into ai_image.
            return img
        player_image.change(fn=sync_image, inputs=player_image, outputs=ai_image)
        # Show the custom-value inputs only while "Custom" is selected.
        num_rounds.change(
            fn=lambda v: gr.update(visible=str(v) == "Custom"),
            inputs=[num_rounds],
            outputs=[custom_rounds],
        )
        time_limit.change(
            fn=lambda v: gr.update(visible=v == "Custom"),
            inputs=[time_limit],
            outputs=[custom_time_limit],
        )
        # Output order must match the tuple returned by start_game.
        start_btn.click(
            fn=start_game,
            inputs=[num_rounds, custom_rounds, mode, time_limit, custom_time_limit, state],
            outputs=[
                round_info,
                player_map,
                ai_map,
                player_image,
                lat_input,
                lng_input,
                coord_input,
                player_result,
                ai_result,
                ai_analysis,
                submit_btn,
                next_round_btn,
                game_timer,
                round_table,
                round_stats,
                state,
            ],
            show_progress=False,
        ).then(
            # Re-sync the AI image after the handler updated player_image.
            # NOTE(review): player_image.change above may already cover this.
            fn=sync_image,
            inputs=player_image,
            outputs=ai_image,
        )
        def _on_lat_lng(lat, lng, s):
            # Thin adapter so both coordinate fields share one handler.
            return on_lat_lng_change(lat, lng, s)
        # A change to either coordinate field redraws the player map and
        # updates coord_input and state.
        lat_input.change(
            fn=_on_lat_lng,
            inputs=[lat_input, lng_input, state],
            outputs=[player_map, coord_input, state],
        )
        lng_input.change(
            fn=_on_lat_lng,
            inputs=[lat_input, lng_input, state],
            outputs=[player_map, coord_input, state],
        )
        # coord_input is the bridge written by the map-click JS (see comment
        # on coord_input); its handler updates state and the status line.
        coord_input.change(
            fn=on_coord_change,
            inputs=[coord_input, state],
            outputs=[state, round_info],
        )
        # Fires once per second while game_timer is active.
        # Output order must match the tuple returned by timer_tick.
        game_timer.tick(
            fn=timer_tick,
            inputs=[state],
            outputs=[
                state,
                round_info,
                player_map,
                player_result,
                ai_map,
                ai_result,
                player_image,
                lat_input,
                lng_input,
                coord_input,
                submit_btn,
                next_round_btn,
                game_timer,
            ],
        )
        # Output order must match the tuple returned by submit_answer.
        submit_btn.click(
            fn=submit_answer,
            inputs=[coord_input, lat_input, lng_input, state],
            outputs=[
                round_info,
                player_map,
                player_result,
                ai_map,
                ai_result,
                player_image,
                lat_input,
                lng_input,
                coord_input,
                ai_analysis,
                submit_btn,
                next_round_btn,
                game_timer,
                round_table,
                round_stats,
                state,
            ],
            show_progress=False,
        ).then(
            fn=sync_image,
            inputs=player_image,
            outputs=ai_image,
        )
        # Output order must match the tuple returned by go_next_round
        # (note: ai_analysis precedes player_image here).
        next_round_btn.click(
            fn=go_next_round,
            inputs=[state],
            outputs=[
                round_info,
                player_map,
                player_result,
                ai_map,
                ai_result,
                ai_analysis,
                player_image,
                lat_input,
                lng_input,
                coord_input,
                submit_btn,
                next_round_btn,
                game_timer,
                state,
            ],
            show_progress=False,
        ).then(
            fn=sync_image,
            inputs=player_image,
            outputs=ai_image,
        )
    return demo
# Build the app at import time so a module-level ``demo`` object exists.
# Presumably a hosting runner (e.g. a Gradio Space) imports it; verify
# before renaming.
demo = create_ui()
if __name__ == "__main__":
    # Running this file directly starts the Gradio server.
    demo.launch()