# RisingZhang's picture
# add new ver
# a73d0ab
# raw
# history blame
# 20.3 kB
from __future__ import annotations
import csv
import datetime as dt
import math
import random
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import numpy as np
# Define the ``np.bool`` alias when the installed NumPy does not expose it.
# This runs before pandas is imported so that code still referring to
# ``np.bool`` keeps working.
if not hasattr(np, "bool"):  # pandas<2 compatibility on numpy>=2
    np.bool = bool  # type: ignore[attr-defined]
import pandas as pd
from PIL import Image, ImageDraw
import gradio as gr
# Filesystem layout, resolved relative to this module rather than the
# current working directory.
BASE_DIR = Path(__file__).resolve().parent
DATASET_PATH = BASE_DIR.joinpath("data", "dataset.csv")
AUDIO_BASE_DIR = BASE_DIR.joinpath("data", "audios")
ASSETS_DIR = BASE_DIR.joinpath("assets")
MAP_IMAGE_PATH = ASSETS_DIR.joinpath("world_map.png")
LOG_PATH = BASE_DIR.joinpath("player_runs.csv")

# Log columns shown in the "recent submissions" table, in display order.
RECENT_COLUMNS = ["timestamp", "player_id", "question_id", "distance_km"]
@dataclass
class Sample:
    """One playable audio clip with its ground-truth location and metadata."""

    # Identifier taken from the dataset's "key" column, stored as a string.
    question_id: str
    # Path of the clip's audio file under AUDIO_BASE_DIR.
    audio_path: Path
    # True recording position, in degrees.
    longitude: float
    latitude: float
    # Descriptive text copied from the dataset columns of the same names.
    city: str
    country: str
    continent: str
    description: str
    title: str
def _load_samples() -> List[Sample]:
    """Load the playable samples from the tail of the dataset CSV.

    The last 10% of the rows of ``DATASET_PATH`` (in file order) are used as
    the test split.  A row is skipped when a coordinate is missing or not a
    finite number, or when its audio file is absent from ``AUDIO_BASE_DIR``.
    The number of rows skipped for each reason is printed.

    Returns:
        The playable samples, in dataset order.

    Raises:
        FileNotFoundError: If the dataset CSV does not exist.
        RuntimeError: If no row of the test split is playable.
    """

    def _text(value: object) -> str:
        # Empty CSV cells are read as NaN.  NaN is truthy, so ``value or ""``
        # would keep it and ``str`` would store the literal text "nan".
        if value is None or pd.isna(value):
            return ""
        return str(value)

    def _coordinate(value: object) -> Optional[float]:
        # Anything that is not a finite number counts as a missing coordinate.
        try:
            number = float(value)
        except (TypeError, ValueError):
            return None
        return number if math.isfinite(number) else None

    if not DATASET_PATH.exists():
        raise FileNotFoundError(f"Dataset not found at {DATASET_PATH}")

    df = pd.read_csv(DATASET_PATH)
    start_idx = int(len(df) * 0.9)
    test_df = df.iloc[start_idx:].reset_index(drop=True)

    samples: List[Sample] = []
    missing_audio = 0
    missing_coords = 0
    for row in test_df.itertuples():
        longitude = _coordinate(getattr(row, "longitude"))
        latitude = _coordinate(getattr(row, "latitude"))
        if longitude is None or latitude is None:
            missing_coords += 1
            continue

        # A blank file name would resolve to the audio directory itself, so
        # require a non-empty name that points at a regular file.
        audio_name = _text(getattr(row, "mp3name")).strip()
        audio_path = AUDIO_BASE_DIR / audio_name
        if not audio_name or not audio_path.is_file():
            missing_audio += 1
            continue

        samples.append(
            Sample(
                question_id=str(getattr(row, "key")),
                audio_path=audio_path,
                longitude=longitude,
                latitude=latitude,
                city=_text(getattr(row, "city", "")),
                country=_text(getattr(row, "country", "")),
                continent=_text(getattr(row, "continent", "")),
                description=_text(getattr(row, "description", "")),
                title=_text(getattr(row, "title", "")),
            )
        )

    if not samples:
        raise RuntimeError("No playable samples were found in the test split.")
    if missing_audio:
        print(f"[game_app] Skipped {missing_audio} samples because audio files are missing.")
    if missing_coords:
        print(f"[game_app] Skipped {missing_coords} samples because coordinates are missing.")
    return samples
# Loaded once, at import time.
SAMPLES: List[Sample] = _load_samples()

# World map image that players click on; its pixel size drives the
# pixel-to-coordinate conversion.
BASE_MAP_IMAGE = Image.open(MAP_IMAGE_PATH).convert("RGB")
MAP_WIDTH = BASE_MAP_IMAGE.width
MAP_HEIGHT = BASE_MAP_IMAGE.height
def _random_queue() -> List[int]:
    """Return every index into ``SAMPLES`` exactly once, in shuffled order."""
    order = [*range(len(SAMPLES))]
    random.shuffle(order)
    return order
def _pixel_to_latlon(x: int, y: int) -> Tuple[float, float]:
    """Convert a map pixel to ``(latitude, longitude)`` in degrees.

    The mapping is linear: column 0 is longitude -180 and the last column is
    +180; row 0 is latitude +90 and the last row is -90.  Both results are
    rounded to six decimal places.
    """
    x_fraction = x / (MAP_WIDTH - 1)
    y_fraction = y / (MAP_HEIGHT - 1)
    longitude = x_fraction * 360.0 - 180.0
    latitude = 90.0 - y_fraction * 180.0
    return round(latitude, 6), round(longitude, 6)
def _latlon_to_text(lat: float, lon: float) -> str:
    """Describe the selected position with three decimals per coordinate."""
    return "Selected latitude {:.3f}°, longitude {:.3f}°".format(lat, lon)
def _haversine(lat1: float, lon1: float, lat2: float, lon2: float) -> float:
    """Return the great-circle distance in kilometres between two points.

    Uses the haversine formula on a sphere of radius 6371 km.

    Args:
        lat1: Latitude of the first point, in degrees.
        lon1: Longitude of the first point, in degrees.
        lat2: Latitude of the second point, in degrees.
        lon2: Longitude of the second point, in degrees.

    Returns:
        The distance along the sphere's surface, in kilometres.
    """
    r = 6371.0
    phi1, phi2 = math.radians(lat1), math.radians(lat2)
    d_phi = math.radians(lat2 - lat1)
    d_lambda = math.radians(lon2 - lon1)
    a = math.sin(d_phi / 2.0) ** 2 + math.cos(phi1) * math.cos(phi2) * math.sin(d_lambda / 2.0) ** 2
    # Floating-point rounding can push ``a`` marginally outside [0, 1] for
    # (near-)antipodal points, which would make sqrt(1 - a) raise.
    a = min(1.0, max(0.0, a))
    c = 2.0 * math.atan2(math.sqrt(a), math.sqrt(1.0 - a))
    return r * c
def _map_with_marker(x: int, y: int) -> np.ndarray:
    """Return the world map as an array with a round marker centred on (x, y).

    The marker is drawn on a copy, so ``BASE_MAP_IMAGE`` is left untouched.
    Its radius is ``MAP_WIDTH // 150`` pixels, with a minimum of 6.
    """
    radius = max(6, MAP_WIDTH // 150)
    bounding_box = (x - radius, y - radius, x + radius, y + radius)
    canvas = BASE_MAP_IMAGE.copy()
    ImageDraw.Draw(canvas).ellipse(
        bounding_box,
        fill=(225, 64, 64),
        outline=(0, 0, 0),
        width=2,
    )
    return np.array(canvas)
def _base_map_array() -> np.ndarray:
    """Return the unmarked world map as a new NumPy array."""
    clean_map = np.array(BASE_MAP_IMAGE)
    return clean_map
def _prepare_clip_info(sample: Sample, round_idx: int) -> str:
    """Build the Markdown shown for the current round.

    Only the round number is rendered; ``sample`` is accepted but not used.
    """
    return f"**Round:** {round_idx}"
def _append_log(entry: Dict[str, object]) -> None:
    """Append one record to the CSV log at ``LOG_PATH``.

    The keys of ``entry`` are used as the column names.  A header row is
    written first when the log file is missing or still empty.

    Args:
        entry: Column name to value mapping for the new row.
    """
    # An existing zero-byte file still needs its header, so check the size
    # rather than mere existence.
    try:
        write_header = LOG_PATH.stat().st_size == 0
    except FileNotFoundError:
        write_header = True

    with LOG_PATH.open("a", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=list(entry.keys()))
        if write_header:
            writer.writeheader()
        writer.writerow(entry)
def _load_recent_runs(limit: int = 5) -> List[Dict[str, object]]:
    """Return the last ``limit`` rows of the CSV log, oldest first.

    Args:
        limit: Maximum number of rows to return.  Zero or a negative value
            yields an empty list.

    Returns:
        One dict per row, keyed by the log's header; empty when the log file
        does not exist.
    """
    from collections import deque

    if limit <= 0:
        return []
    try:
        with LOG_PATH.open("r", newline="", encoding="utf-8") as f:
            # A bounded deque keeps only the newest rows while streaming, so
            # the whole log is never held in memory.
            return list(deque(csv.DictReader(f), maxlen=limit))
    except FileNotFoundError:
        return []
def _format_recent_rows(rows: List[Dict[str, object]]) -> List[List[object]]:
    """Turn log records into table rows ordered like ``RECENT_COLUMNS``.

    The timestamp is shown as ``YYYY-MM-DD HH:MM:SS`` when it parses as an
    ISO-8601 value, unchanged when it does not, and as an empty string when
    it is missing.  The other columns are copied through, with ``""`` for
    absent keys.
    """
    other_columns = RECENT_COLUMNS[1:]
    table: List[List[object]] = []
    for record in rows:
        raw = record.get("timestamp", "")
        shown = ""
        if raw:
            try:
                # A trailing "Z" is rewritten as an explicit UTC offset
                # before parsing.
                parsed = dt.datetime.fromisoformat(str(raw).replace('Z', '+00:00'))
                shown = parsed.strftime("%Y-%m-%d %H:%M:%S")
            except (ValueError, AttributeError):
                shown = str(raw)
        table.append([shown, *(record.get(col, "") for col in other_columns)])
    return table
def initialize_round() -> Tuple[Dict[str, object], Dict[str, Optional[float]], str, str, np.ndarray, str, List[List[object]], str, str]:
    """Start a new game: shuffle the samples and serve the first clip.

    Returns:
        A 9-tuple, in the order expected by the UI outputs: game state, empty
        guess state, audio file path, clip Markdown, unmarked map array,
        prompt text, recent-submissions rows, and two empty strings for the
        longitude and latitude boxes.
    """
    remaining = _random_queue()
    first_index = remaining.pop()
    first_sample = SAMPLES[first_index]
    state = {
        "queue": remaining,
        "current_index": first_index,
        "round": 1,
    }
    return (
        state,
        {"lat": None, "lon": None, "pixel": None},
        str(first_sample.audio_path),
        _prepare_clip_info(first_sample, state["round"]),
        _base_map_array(),
        "Click once on the map to pick your guess. The marker will update to your last selection.",
        _format_recent_rows(_load_recent_runs()),
        "",
        "",
    )
def _next_sample(state: Dict[str, object]) -> Sample:
    """Advance ``state`` to the next clip and return that sample.

    The queue stored in ``state`` is refilled with a fresh shuffle once it is
    empty.  ``state`` is updated in place: ``current_index`` becomes the new
    sample's index and ``round`` is incremented.
    """
    pending: List[int] = state["queue"]
    if not pending:
        pending.extend(_random_queue())
    upcoming = pending.pop()
    state.update(current_index=upcoming, round=state.get("round", 1) + 1)
    return SAMPLES[upcoming]
def _ensure_guess_state(state: Optional[Dict[str, Optional[float]]]) -> Dict[str, Optional[float]]:
    """Return a new guess-state dict with exactly ``lat``, ``lon`` and ``pixel``.

    Values are copied from ``state`` when it is a dict; missing keys, or a
    ``state`` that is not a dict, give ``None``.
    """
    source = state if isinstance(state, dict) else {}
    return {key: source.get(key) for key in ("lat", "lon", "pixel")}
def _extract_click_xy(evt: object) -> Tuple[Optional[object], Optional[object]]:
    """Pull the raw ``(x, y)`` values out of a select event.

    ``evt.index`` is read first (a sequence or a dict with ``x``/``y`` keys).
    ``evt.value`` is consulted only when a coordinate is still missing.  Each
    coordinate that cannot be found is returned as ``None``.
    """
    index = getattr(evt, "index", None)
    value = getattr(evt, "value", None)
    x = y = None
    if isinstance(index, (tuple, list)) and len(index) >= 2:
        x, y = index[0], index[1]
    elif isinstance(index, dict):
        x, y = index.get("x"), index.get("y")
    if x is None or y is None:
        if isinstance(value, dict):
            x = value.get("x", x)
            y = value.get("y", y)
        elif isinstance(value, (tuple, list)) and len(value) >= 2:
            x = value[0] if x is None else x
            y = value[1] if y is None else y
    return x, y


def handle_map_click(
    evt: gr.SelectData,
    current_guess_state: Optional[Dict[str, Optional[float]]],
) -> Tuple[np.ndarray, str, Dict[str, Optional[float]], str, str]:
    """Record a click on the map as the player's current guess.

    Args:
        evt: Select event describing where the image was clicked.
        current_guess_state: Guess state from before the click.

    Returns:
        A 5-tuple of map array, status text, guess state, longitude text and
        latitude text.  When the click position cannot be read, the unmarked
        map, an error message, the previous guess state and two empty strings
        are returned.
    """
    guess_state = _ensure_guess_state(current_guess_state)

    def _unreadable() -> Tuple[np.ndarray, str, Dict[str, Optional[float]], str, str]:
        return _base_map_array(), "Unable to read selection. Please try again.", guess_state, "", ""

    if evt is None:
        return _unreadable()

    raw_x, raw_y = _extract_click_xy(evt)
    try:
        x, y = int(raw_x), int(raw_y)
    except (TypeError, ValueError, OverflowError):
        # Missing, non-numeric, NaN or infinite coordinates are reported to
        # the player instead of raising.
        return _unreadable()

    # Keep the point on the image so the derived coordinates always fall
    # inside the ranges that submit_guess accepts.
    x = min(max(x, 0), MAP_WIDTH - 1)
    y = min(max(y, 0), MAP_HEIGHT - 1)

    lat, lon = _pixel_to_latlon(x, y)
    guess_state = {"lat": lat, "lon": lon, "pixel": (x, y)}
    return _map_with_marker(x, y), _latlon_to_text(lat, lon), guess_state, f"{lon:.6f}", f"{lat:.6f}"
def submit_guess(
    player_id: str,
    game_state: Dict[str, object],
    guess_state: Optional[Dict[str, Optional[float]]],
    longitude: str,
    latitude: str,
) -> Tuple[
    Dict[str, object],
    Dict[str, Optional[float]],
    str,
    str,
    np.ndarray,
    str,
    List[List[object]],
    str,
    str,
]:
    """Validate and score a guess, log it, and move on to the next clip.

    Args:
        player_id: Identifier typed by the player; must not be blank.
        game_state: Session state holding ``queue``, ``current_index`` and
            ``round``.
        guess_state: Guess state produced by the last map click, if any.
        longitude: Longitude text box content, in degrees (-180 to 180).
        latitude: Latitude text box content, in degrees (-90 to 90).

    Returns:
        A 9-tuple in the order expected by the UI outputs: game state, guess
        state, audio file path, clip Markdown, map array, message text,
        recent-submissions rows, longitude text and latitude text.  On a
        validation error the current round is kept and the message explains
        the problem.  On success the message shows the true location and the
        error distance, followed by the prompt for the next round.
    """
    if not isinstance(game_state, dict) or "current_index" not in game_state:
        # The session state is filled in by the page-load handler; without it
        # there is no round to score, so start a fresh one.
        return initialize_round()

    guess_state = _ensure_guess_state(guess_state)
    current_sample = SAMPLES[game_state["current_index"]]

    def _reject(message: str, lon_text: str, lat_text: str):
        # Same round, same clip, unmarked map; only the message changes.
        return (
            game_state,
            guess_state,
            str(current_sample.audio_path),
            _prepare_clip_info(current_sample, game_state.get("round", 1)),
            _base_map_array(),
            message,
            _format_recent_rows(_load_recent_runs()),
            lon_text,
            lat_text,
        )

    player_id = (player_id or "").strip()
    if not player_id:
        return _reject("Please enter your player ID before submitting.", longitude, latitude)

    # Parse longitude and latitude from text inputs
    longitude = longitude.strip() if longitude else ""
    latitude = latitude.strip() if latitude else ""
    if not longitude or not latitude:
        return _reject(
            "Please enter both longitude and latitude, or click on the map to select a location.",
            longitude,
            latitude,
        )
    try:
        guess_lon = float(longitude)
        guess_lat = float(latitude)
    except ValueError:
        return _reject("Invalid coordinates. Please enter valid numbers.", longitude, latitude)

    # Validate ranges (NaN fails both comparisons and is rejected here too)
    if not (-180 <= guess_lon <= 180):
        return _reject("Longitude must be between -180 and 180.", longitude, latitude)
    if not (-90 <= guess_lat <= 90):
        return _reject("Latitude must be between -90 and 90.", longitude, latitude)

    true_lat = current_sample.latitude
    true_lon = current_sample.longitude
    distance_km = _haversine(true_lat, true_lon, guess_lat, guess_lon)
    reveal_lines = [
        f"Real location: {current_sample.city or 'Unknown city'}, {current_sample.country or 'Unknown country'} ({true_lat:.3f}°, {true_lon:.3f}°)",
        f"Your guess: ({guess_lat:.3f}°, {guess_lon:.3f}°)",
        f"Error distance: {distance_km:.1f} km",
    ]

    # Naive UTC timestamp in the same text format as datetime.utcnow(),
    # without calling the deprecated function.
    timestamp = dt.datetime.now(dt.timezone.utc).replace(tzinfo=None).isoformat()
    _append_log(
        {
            "timestamp": timestamp,
            "player_id": player_id,
            "question_id": current_sample.question_id,
            "audio_path": str(current_sample.audio_path),
            "guess_latitude": guess_lat,
            "guess_longitude": guess_lon,
            "true_latitude": true_lat,
            "true_longitude": true_lon,
            "distance_km": round(distance_km, 3),
            "city": current_sample.city,
            "country": current_sample.country,
            "continent": current_sample.continent,
            "title": current_sample.title,
            "description": current_sample.description,
        }
    )

    next_sample = _next_sample(game_state)
    # Show the result of the round just played above the usual prompt; it
    # stays visible until the next map click replaces the message.
    prompt_text = "\n\n".join(
        reveal_lines
        + ["Click once on the map to pick your guess. The marker will update to your last selection."]
    )
    return (
        game_state,
        {"lat": None, "lon": None, "pixel": None},
        str(next_sample.audio_path),
        _prepare_clip_info(next_sample, game_state["round"]),
        _base_map_array(),
        prompt_text,
        _format_recent_rows(_load_recent_runs()),
        "",  # Reset longitude input
        "",  # Reset latitude input
    )
# Stylesheet passed to gr.Blocks(css=...).  It styles the page title and
# intro text, caps the map image at 600px high outside fullscreen/modal
# containers, and sets font sizes for the text, form and table-label classes.
custom_css = """
h1 {
font-size: 2.5rem !important;
font-weight: 700 !important;
margin-bottom: 1rem !important;
text-align: center !important;
}
.intro-text {
font-size: 1.1rem !important;
line-height: 1.6 !important;
margin-bottom: 1.5rem !important;
color: #555 !important;
text-align: center !important;
}
/* Apply size restrictions to regular (non-fullscreen) image */
.gradio-image:not([class*="fullscreen"]) {
max-width: 100% !important;
max-height: 600px !important;
margin: 0 auto !important;
display: flex !important;
justify-content: center !important;
}
.gradio-image:not([class*="fullscreen"]) > div {
max-width: 100% !important;
max-height: 600px !important;
margin: 0 auto !important;
display: flex !important;
justify-content: center !important;
}
.gradio-image:not([class*="fullscreen"]) img {
max-width: 100% !important;
max-height: 600px !important;
width: auto !important;
height: auto !important;
object-fit: contain !important;
margin: 0 auto !important;
}
/* Allow fullscreen modal to override size restrictions - higher specificity */
/* Target common Gradio fullscreen containers */
div[class*="modal"] .gradio-image,
div[class*="modal"] .gradio-image > div,
div[class*="modal"] .gradio-image img,
div[class*="fullscreen"] .gradio-image,
div[class*="fullscreen"] .gradio-image > div,
div[class*="fullscreen"] .gradio-image img,
div[id*="lightbox"] .gradio-image,
div[id*="lightbox"] .gradio-image img,
.gradio-image[class*="fullscreen"],
.gradio-image[class*="fullscreen"] > div,
.gradio-image[class*="fullscreen"] img {
max-width: none !important;
max-height: none !important;
width: 100% !important;
height: 100% !important;
}
.selection-text {
font-size: 1.15rem !important;
text-align: left !important;
color: #666 !important;
margin: 0.5rem 0 !important;
}
.selection-text * {
font-size: 1.15rem !important;
}
.feedback-box {
font-size: 1.15rem !important;
line-height: 1.6 !important;
}
.feedback-box p,
.feedback-box div,
.feedback-box span {
font-size: 1.15rem !important;
line-height: 1.6 !important;
}
.clip-info {
font-size: 1.1rem !important;
}
.clip-info p,
.clip-info div,
.clip-info span {
font-size: 1.1rem !important;
}
.clip-info label {
font-size: 1.1rem !important;
font-weight: 600 !important;
}
label {
font-size: 1rem !important;
font-weight: 500 !important;
}
.form-text input {
font-size: 0.95rem !important;
}
/* Table label styling - only target the label, not table content */
.gradio-dataframe label,
[data-testid="dataframe"] label,
label[for*="recent"],
.form-group label {
font-size: 1.15rem !important;
font-weight: 600 !important;
}
"""
# Build the UI and wire the three event handlers.
# NOTE(review): the nesting below (submit button inside the right-hand
# column, recent-submissions table below the main row) is the most natural
# reading of the statement order -- confirm against the intended layout.
with gr.Blocks(title="Audio Geo-Localization Game", theme=gr.themes.Soft(), css=custom_css) as demo:
    gr.Markdown("# Audio Geo-Localization Game")
    gr.HTML('<p class="intro-text">Welcome to the Audio Geo-Localization Game. Listen to an ambient audio clip, then guess where it was recorded by clicking on the world map. Submit to see the true location and how close you came.</p>')
    # State holders: the game state (queue, current index, round) and the
    # latest map selection.
    game_state = gr.State()
    guess_state = gr.State()
    clip_info = gr.Markdown(label="Clip details", elem_classes=["clip-info"])
    with gr.Row():
        with gr.Column(scale=2):
            # Clickable world map; clicks are handled by handle_map_click.
            map_component = gr.Image(
                value=_base_map_array(),
                label="World map (click to set your guess)",
                interactive=True,
                type="numpy",
                image_mode="RGB",
                elem_classes=["gradio-image"],
            )
        with gr.Column(scale=1):
            player_id = gr.Textbox(label="Player ID", placeholder="Enter an identifier so scores can be tracked", elem_classes=["form-text"])
            audio_player = gr.Audio(label="Mystery audio clip", autoplay=False, interactive=False, streaming=False)
            selection_text = gr.Markdown("Click once on the map to pick your guess. The marker will update to your last selection.", elem_classes=["selection-text"])
            with gr.Row():
                # Filled in by a map click, or typed by hand.
                longitude_input = gr.Textbox(
                    label="Longitude",
                    placeholder="Enter longitude (-180 to 180) or click map",
                    elem_classes=["form-text"]
                )
                latitude_input = gr.Textbox(
                    label="Latitude",
                    placeholder="Enter latitude (-90 to 90) or click map",
                    elem_classes=["form-text"]
                )
            submit_button = gr.Button("Submit Guess", variant="primary")
    # Headers mirror RECENT_COLUMNS.
    recent_table = gr.Dataframe(
        headers=[
            "timestamp",
            "player_id",
            "question_id",
            "distance_km",
        ],
        datatype=["str", "str", "str", "number"],
        value=[],
        interactive=False,
        label="Recent submissions (latest last)",
        wrap=True,
    )
    # Page load: start a new game and populate every output.
    demo.load(
        initialize_round,
        inputs=None,
        outputs=[game_state, guess_state, audio_player, clip_info, map_component, selection_text, recent_table, longitude_input, latitude_input],
    )
    # Map click: draw the marker and fill in the coordinate boxes.
    map_component.select(
        handle_map_click,
        inputs=[guess_state],
        outputs=[map_component, selection_text, guess_state, longitude_input, latitude_input],
        preprocess=False,
    )
    # Submit: validate, log, and advance to the next clip.
    submit_button.click(
        submit_guess,
        inputs=[player_id, game_state, guess_state, longitude_input, latitude_input],
        outputs=[game_state, guess_state, audio_player, clip_info, map_component, selection_text, recent_table, longitude_input, latitude_input],
    )
# Start the web server only when this file is run as a script.
if __name__ == "__main__":
    # Listens on all network interfaces, port 3828.
    demo.launch(server_name="0.0.0.0", server_port=3828)