Spaces:
Configuration error
Configuration error
| from __future__ import annotations | |
| import csv | |
| import datetime as dt | |
| import math | |
| import random | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Dict, List, Optional, Tuple | |
| import numpy as np | |
| if not hasattr(np, "bool"): # pandas<2 compatibility on numpy>=2 | |
| np.bool = bool # type: ignore[attr-defined] | |
| import pandas as pd | |
| from PIL import Image, ImageDraw | |
| import gradio as gr | |
# Filesystem layout, resolved relative to this file so the app does not
# depend on the current working directory.
BASE_DIR = Path(__file__).resolve().parent
DATASET_PATH = BASE_DIR.joinpath("data", "dataset.csv")
AUDIO_BASE_DIR = BASE_DIR.joinpath("data", "audios")
ASSETS_DIR = BASE_DIR.joinpath("assets")
MAP_IMAGE_PATH = ASSETS_DIR.joinpath("world_map.png")
LOG_PATH = BASE_DIR.joinpath("player_runs.csv")

# Log columns shown in the "recent submissions" table, in display order.
RECENT_COLUMNS = [
    "timestamp",
    "player_id",
    "question_id",
    "distance_km",
]
@dataclass
class Sample:
    """One playable audio clip together with its ground-truth location.

    The loader constructs instances with keyword arguments, so the class must
    be a dataclass: a plain class with bare annotations has no such
    constructor.
    """

    question_id: str  # dataset "key" column, stringified
    audio_path: Path  # audio file under the audio base directory
    longitude: float  # true longitude in degrees
    latitude: float  # true latitude in degrees
    city: str
    country: str
    continent: str
    description: str
    title: str
def _text_or_empty(value: object) -> str:
    """Return a CSV cell as text, mapping None, NaN and falsy cells to ``""``."""
    # pandas represents a missing cell as float NaN, which is truthy, so a
    # plain ``value or ""`` would turn it into the literal string "nan".
    if value is None or (isinstance(value, float) and math.isnan(value)):
        return ""
    return str(value or "")


def _load_samples() -> List[Sample]:
    """Load the playable samples from the last 10% of rows of the dataset CSV.

    Rows with missing or non-numeric coordinates, or whose audio file does
    not exist under ``AUDIO_BASE_DIR``, are skipped and counted; the counts
    are printed once loading is done.

    Returns:
        The samples that have both coordinates and an existing audio file.

    Raises:
        FileNotFoundError: If the dataset CSV does not exist.
        RuntimeError: If no row of the split is playable.
    """
    if not DATASET_PATH.exists():
        raise FileNotFoundError(f"Dataset not found at {DATASET_PATH}")

    df = pd.read_csv(DATASET_PATH)
    start_idx = int(len(df) * 0.9)
    test_df = df.iloc[start_idx:].reset_index(drop=True)

    samples: List[Sample] = []
    missing_audio = 0
    missing_coords = 0
    for row in test_df.itertuples():
        try:
            longitude = float(getattr(row, "longitude"))
            latitude = float(getattr(row, "latitude"))
        except (TypeError, ValueError):
            # A non-numeric cell is treated like a missing coordinate.
            longitude = latitude = math.nan
        if math.isnan(longitude) or math.isnan(latitude):
            missing_coords += 1
            continue

        # A missing file name must not crash the path join; count the row
        # as missing audio instead.
        audio_name = _text_or_empty(getattr(row, "mp3name"))
        audio_path = AUDIO_BASE_DIR / audio_name
        if not audio_name or not audio_path.exists():
            missing_audio += 1
            continue

        samples.append(
            Sample(
                question_id=str(getattr(row, "key")),
                audio_path=audio_path,
                longitude=longitude,
                latitude=latitude,
                city=_text_or_empty(getattr(row, "city", "")),
                country=_text_or_empty(getattr(row, "country", "")),
                continent=_text_or_empty(getattr(row, "continent", "")),
                description=_text_or_empty(getattr(row, "description", "")),
                title=_text_or_empty(getattr(row, "title", "")),
            )
        )

    if not samples:
        raise RuntimeError("No playable samples were found in the test split.")
    if missing_audio:
        print(f"[game_app] Skipped {missing_audio} samples because audio files are missing.")
    if missing_coords:
        print(f"[game_app] Skipped {missing_coords} samples because coordinates are missing.")
    return samples
# Loaded once at import time: the playable samples and the base world map.
SAMPLES: List[Sample] = _load_samples()
BASE_MAP_IMAGE = Image.open(MAP_IMAGE_PATH).convert("RGB")
MAP_WIDTH = BASE_MAP_IMAGE.size[0]
MAP_HEIGHT = BASE_MAP_IMAGE.size[1]
def _random_queue() -> List[int]:
    """Return every index into ``SAMPLES`` exactly once, in random order."""
    indices = [position for position, _ in enumerate(SAMPLES)]
    random.shuffle(indices)
    return indices
def _pixel_to_latlon(x: int, y: int) -> Tuple[float, float]:
    """Convert a map pixel to ``(latitude, longitude)`` in degrees.

    The mapping is linear: column 0 is -180° and the last column +180°;
    row 0 is +90° and the last row -90°. Both values are rounded to six
    decimals.
    """
    # NOTE(review): a linear mapping presumes the map image is an
    # equirectangular world map covering the full globe — confirm.
    x_fraction = x / (MAP_WIDTH - 1)
    y_fraction = y / (MAP_HEIGHT - 1)
    longitude = x_fraction * 360.0 - 180.0
    latitude = 90.0 - y_fraction * 180.0
    return round(latitude, 6), round(longitude, 6)
| def _latlon_to_text(lat: float, lon: float) -> str: | |
| return f"Selected latitude {lat:.3f}°, longitude {lon:.3f}°" | |
| def _haversine(lat1: float, lon1: float, lat2: float, lon2: float) -> float: | |
| r = 6371.0 | |
| phi1, phi2 = math.radians(lat1), math.radians(lat2) | |
| d_phi = math.radians(lat2 - lat1) | |
| d_lambda = math.radians(lon2 - lon1) | |
| a = math.sin(d_phi / 2.0) ** 2 + math.cos(phi1) * math.cos(phi2) * math.sin(d_lambda / 2.0) ** 2 | |
| c = 2.0 * math.atan2(math.sqrt(a), math.sqrt(1.0 - a)) | |
| return r * c | |
def _map_with_marker(x: int, y: int) -> np.ndarray:
    """Return a copy of the base map with a red, black-outlined dot centred on ``(x, y)``."""
    # The radius grows with the map width but is never smaller than 6 pixels.
    radius = max(6, MAP_WIDTH // 150)
    bounding_box = (x - radius, y - radius, x + radius, y + radius)

    canvas = BASE_MAP_IMAGE.copy()  # never draw on the shared base image
    ImageDraw.Draw(canvas).ellipse(
        bounding_box,
        fill=(225, 64, 64),
        outline=(0, 0, 0),
        width=2,
    )
    return np.array(canvas)
def _base_map_array() -> np.ndarray:
    """Return the unmarked base map converted to a NumPy array."""
    pristine_map = np.array(BASE_MAP_IMAGE)
    return pristine_map
| def _prepare_clip_info(sample: Sample, round_idx: int) -> str: | |
| intro_lines = [ | |
| f"**Round:** {round_idx}" | |
| ] | |
| return "\n\n".join(intro_lines) | |
def _append_log(entry: Dict[str, object]) -> None:
    """Append one submission to the CSV log at ``LOG_PATH``.

    The column order is taken from the keys of ``entry``. A header row is
    written first when the log does not exist yet or exists but is empty;
    checking existence alone would leave an empty file without a header.
    """
    needs_header = not LOG_PATH.exists() or LOG_PATH.stat().st_size == 0
    with LOG_PATH.open("a", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=list(entry))
        if needs_header:
            writer.writeheader()
        writer.writerow(entry)
def _load_recent_runs(limit: int = 5) -> List[Dict[str, object]]:
    """Return the last ``limit`` rows of the CSV log, oldest first.

    Args:
        limit: Maximum number of rows to return. Zero or a negative value
            yields an empty list (the previous slice ``rows[-0:]`` returned
            the whole log for ``limit=0``).

    Returns:
        One dict per row keyed by the header names, or an empty list when
        the log file does not exist.
    """
    from collections import deque

    if limit <= 0 or not LOG_PATH.exists():
        return []
    # newline="" lets the csv module handle line endings itself, so quoted
    # fields containing newlines are read back intact.
    with LOG_PATH.open("r", newline="", encoding="utf-8") as f:
        # A bounded deque keeps only the tail instead of the whole log.
        return list(deque(csv.DictReader(f), maxlen=limit))
def _format_recent_rows(rows: List[Dict[str, object]]) -> List[List[object]]:
    """Turn log rows into table rows ordered like ``RECENT_COLUMNS``.

    The timestamp is shown as ``YYYY-MM-DD HH:MM:SS`` when it parses as an
    ISO timestamp, otherwise as its raw text. A missing or empty timestamp,
    like any other missing column, becomes an empty string.
    """
    table: List[List[object]] = []
    for record in rows:
        raw_stamp = record.get("timestamp", "")
        if not raw_stamp:
            shown_stamp = ""
        else:
            try:
                # A trailing "Z" is rewritten as an explicit UTC offset
                # before parsing.
                parsed = dt.datetime.fromisoformat(str(raw_stamp).replace('Z', '+00:00'))
                shown_stamp = parsed.strftime("%Y-%m-%d %H:%M:%S")
            except (ValueError, AttributeError):
                shown_stamp = str(raw_stamp)
        other_cells = [record.get(column, "") for column in RECENT_COLUMNS[1:]]
        table.append([shown_stamp, *other_cells])
    return table
def initialize_round() -> Tuple[Dict[str, object], Dict[str, Optional[float]], str, str, np.ndarray, str, List[List[object]], str, str]:
    """Start a fresh game: shuffle the samples and serve the first clip.

    Returns the initial values for, in order: game state, guess state, audio
    path, clip Markdown, base map, prompt text, recent-runs table, and the
    (empty) longitude and latitude inputs.
    """
    pending = _random_queue()
    first_index = pending.pop()
    state = {"queue": pending, "current_index": first_index, "round": 1}
    first_sample = SAMPLES[first_index]

    empty_guess = {"lat": None, "lon": None, "pixel": None}
    return (
        state,
        empty_guess,
        str(first_sample.audio_path),
        _prepare_clip_info(first_sample, state["round"]),
        _base_map_array(),
        "Click once on the map to pick your guess. The marker will update to your last selection.",
        _format_recent_rows(_load_recent_runs()),
        "",
        "",
    )
def _next_sample(state: Dict[str, object]) -> Sample:
    """Advance ``state`` to the next round and return the sample to play.

    When the queue is exhausted it is refilled in place with a fresh
    shuffle. ``state`` is mutated: ``current_index`` is replaced and
    ``round`` is incremented (treated as 1 when absent).
    """
    pending: List[int] = state["queue"]
    if not pending:
        pending.extend(_random_queue())
    chosen = pending.pop()
    state.update(current_index=chosen, round=state.get("round", 1) + 1)
    return SAMPLES[chosen]
| def _ensure_guess_state(state: Optional[Dict[str, Optional[float]]]) -> Dict[str, Optional[float]]: | |
| if not isinstance(state, dict): | |
| return {"lat": None, "lon": None, "pixel": None} | |
| return { | |
| "lat": state.get("lat"), | |
| "lon": state.get("lon"), | |
| "pixel": state.get("pixel"), | |
| } | |
def handle_map_click(
    evt: gr.SelectData,
    current_guess_state: Optional[Dict[str, Optional[float]]],
) -> Tuple[np.ndarray, str, Dict[str, Optional[float]], str, str]:
    """Record a click on the map as the player's current guess.

    The pixel position is read from ``evt.index`` (a sequence or an
    ``x``/``y`` dict), falling back to ``evt.value`` in the same two shapes.

    Returns, in order: the map image, the selection text, the guess state,
    and the longitude and latitude input texts. When no position can be
    read, the unmarked map, an error message, the normalized previous guess
    state and two empty strings are returned.
    """
    guess_state = _ensure_guess_state(current_guess_state)
    unreadable = "Unable to read selection. Please try again."
    if evt is None:
        return _base_map_array(), unreadable, guess_state, "", ""

    index = getattr(evt, "index", None)
    value = getattr(evt, "value", None)

    # Primary source: the event index.
    if isinstance(index, (tuple, list)) and len(index) >= 2:
        x, y = index[0], index[1]
    elif isinstance(index, dict):
        x, y = index.get("x"), index.get("y")
    else:
        x = y = None

    # Fallback: fill whatever is still missing from the event value.
    incomplete = x is None or y is None
    if incomplete and isinstance(value, dict):
        x, y = value.get("x", x), value.get("y", y)
    elif incomplete and isinstance(value, (tuple, list)) and len(value) >= 2:
        x = value[0] if x is None else x
        y = value[1] if y is None else y

    if x is None or y is None:
        return _base_map_array(), unreadable, guess_state, "", ""

    x, y = int(x), int(y)
    lat, lon = _pixel_to_latlon(x, y)
    guess_state = {"lat": lat, "lon": lon, "pixel": (x, y)}
    marked_map = _map_with_marker(x, y)
    return marked_map, _latlon_to_text(lat, lon), guess_state, f"{lon:.6f}", f"{lat:.6f}"
def _rejected_submission(
    game_state: Dict[str, object],
    guess_state: Dict[str, Optional[float]],
    message: str,
    longitude: str,
    latitude: str,
) -> Tuple[
    Dict[str, object],
    Dict[str, Optional[float]],
    str,
    str,
    np.ndarray,
    str,
    List[List[object]],
    str,
    str,
]:
    """Build the outputs for a submission that was not accepted.

    The round is left untouched: the current clip is served again, the map
    is reset to the unmarked base image, ``message`` replaces the prompt
    text and the coordinate inputs are echoed back as given.
    """
    current_sample = SAMPLES[game_state["current_index"]]
    return (
        game_state,
        guess_state,
        str(current_sample.audio_path),
        _prepare_clip_info(current_sample, game_state.get("round", 1)),
        _base_map_array(),
        message,
        _format_recent_rows(_load_recent_runs()),
        longitude,
        latitude,
    )


def submit_guess(
    player_id: str,
    game_state: Dict[str, object],
    guess_state: Optional[Dict[str, Optional[float]]],
    longitude: str,
    latitude: str,
) -> Tuple[
    Dict[str, object],
    Dict[str, Optional[float]],
    str,
    str,
    np.ndarray,
    str,
    List[List[object]],
    str,
    str,
]:
    """Validate and score a guess, log it, and move on to the next clip.

    Args:
        player_id: Identifier typed by the player; must not be blank.
        game_state: Round state holding ``queue``, ``current_index`` and
            ``round``; it is advanced in place on success.
        guess_state: Last map selection; normalized but otherwise unused
            for scoring (the text inputs are authoritative).
        longitude: Longitude text, must parse to a number in [-180, 180].
        latitude: Latitude text, must parse to a number in [-90, 90].

    Returns:
        In order: game state, guess state, audio path, clip Markdown, map
        image, prompt text, recent-runs table, longitude text and latitude
        text. On a rejected submission the current round is kept and the
        prompt text carries the reason; on success a log row is appended,
        the next clip is served and both coordinate inputs are cleared.
    """
    guess_state = _ensure_guess_state(guess_state)

    player_id = (player_id or "").strip()
    if not player_id:
        return _rejected_submission(
            game_state,
            guess_state,
            "Please enter your player ID before submitting.",
            longitude,
            latitude,
        )

    longitude = longitude.strip() if longitude else ""
    latitude = latitude.strip() if latitude else ""
    if not longitude or not latitude:
        return _rejected_submission(
            game_state,
            guess_state,
            "Please enter both longitude and latitude, or click on the map to select a location.",
            longitude,
            latitude,
        )

    # Only the numeric conversion can raise ValueError, so keep the try minimal.
    try:
        guess_lon = float(longitude)
        guess_lat = float(latitude)
    except ValueError:
        return _rejected_submission(
            game_state,
            guess_state,
            "Invalid coordinates. Please enter valid numbers.",
            longitude,
            latitude,
        )

    # The range checks also reject NaN, because every comparison with NaN is false.
    if not (-180 <= guess_lon <= 180):
        return _rejected_submission(
            game_state,
            guess_state,
            "Longitude must be between -180 and 180.",
            longitude,
            latitude,
        )
    if not (-90 <= guess_lat <= 90):
        return _rejected_submission(
            game_state,
            guess_state,
            "Latitude must be between -90 and 90.",
            longitude,
            latitude,
        )

    current_sample = SAMPLES[game_state["current_index"]]
    true_lat = current_sample.latitude
    true_lon = current_sample.longitude
    distance_km = _haversine(true_lat, true_lon, guess_lat, guess_lon)

    # NOTE(review): the intro text promises the true location after
    # submitting, but the only feedback surfaced is distance_km in the
    # recent-runs table. The reveal text built here before was never
    # displayed, so it was removed — confirm whether it should be shown.
    log_entry = {
        # Naive UTC timestamp in the same format datetime.utcnow() produced,
        # without calling the deprecated function.
        "timestamp": dt.datetime.now(dt.timezone.utc).replace(tzinfo=None).isoformat(),
        "player_id": player_id,
        "question_id": current_sample.question_id,
        "audio_path": str(current_sample.audio_path),
        "guess_latitude": guess_lat,
        "guess_longitude": guess_lon,
        "true_latitude": true_lat,
        "true_longitude": true_lon,
        "distance_km": round(distance_km, 3),
        "city": current_sample.city,
        "country": current_sample.country,
        "continent": current_sample.continent,
        "title": current_sample.title,
        "description": current_sample.description,
    }
    _append_log(log_entry)

    next_sample = _next_sample(game_state)
    return (
        game_state,
        {"lat": None, "lon": None, "pixel": None},
        str(next_sample.audio_path),
        _prepare_clip_info(next_sample, game_state["round"]),
        _base_map_array(),
        "Click once on the map to pick your guess. The marker will update to your last selection.",
        _format_recent_rows(_load_recent_runs()),
        "",  # Reset longitude input
        "",  # Reset latitude input
    )
# Stylesheet handed to gr.Blocks(css=...). It sets heading and intro text
# styling, caps the regular (non-fullscreen) map image at 600px height while
# lifting that cap for modal/fullscreen/lightbox containers, and sets font
# sizes for the selection text, clip info, form inputs and table label.
custom_css = """
h1 {
font-size: 2.5rem !important;
font-weight: 700 !important;
margin-bottom: 1rem !important;
text-align: center !important;
}
.intro-text {
font-size: 1.1rem !important;
line-height: 1.6 !important;
margin-bottom: 1.5rem !important;
color: #555 !important;
text-align: center !important;
}
/* Apply size restrictions to regular (non-fullscreen) image */
.gradio-image:not([class*="fullscreen"]) {
max-width: 100% !important;
max-height: 600px !important;
margin: 0 auto !important;
display: flex !important;
justify-content: center !important;
}
.gradio-image:not([class*="fullscreen"]) > div {
max-width: 100% !important;
max-height: 600px !important;
margin: 0 auto !important;
display: flex !important;
justify-content: center !important;
}
.gradio-image:not([class*="fullscreen"]) img {
max-width: 100% !important;
max-height: 600px !important;
width: auto !important;
height: auto !important;
object-fit: contain !important;
margin: 0 auto !important;
}
/* Allow fullscreen modal to override size restrictions - higher specificity */
/* Target common Gradio fullscreen containers */
div[class*="modal"] .gradio-image,
div[class*="modal"] .gradio-image > div,
div[class*="modal"] .gradio-image img,
div[class*="fullscreen"] .gradio-image,
div[class*="fullscreen"] .gradio-image > div,
div[class*="fullscreen"] .gradio-image img,
div[id*="lightbox"] .gradio-image,
div[id*="lightbox"] .gradio-image img,
.gradio-image[class*="fullscreen"],
.gradio-image[class*="fullscreen"] > div,
.gradio-image[class*="fullscreen"] img {
max-width: none !important;
max-height: none !important;
width: 100% !important;
height: 100% !important;
}
.selection-text {
font-size: 1.15rem !important;
text-align: left !important;
color: #666 !important;
margin: 0.5rem 0 !important;
}
.selection-text * {
font-size: 1.15rem !important;
}
.feedback-box {
font-size: 1.15rem !important;
line-height: 1.6 !important;
}
.feedback-box p,
.feedback-box div,
.feedback-box span {
font-size: 1.15rem !important;
line-height: 1.6 !important;
}
.clip-info {
font-size: 1.1rem !important;
}
.clip-info p,
.clip-info div,
.clip-info span {
font-size: 1.1rem !important;
}
.clip-info label {
font-size: 1.1rem !important;
font-weight: 600 !important;
}
label {
font-size: 1rem !important;
font-weight: 500 !important;
}
.form-text input {
font-size: 0.95rem !important;
}
/* Table label styling - only target the label, not table content */
.gradio-dataframe label,
[data-testid="dataframe"] label,
label[for*="recent"],
.form-group label {
font-size: 1.15rem !important;
font-weight: 600 !important;
}
"""
# UI definition and event wiring.
# NOTE(review): the source lost its indentation, so the nesting below is
# reconstructed (map in the wide left column, controls in the right column,
# recent-runs table underneath) — confirm against the deployed layout.
with gr.Blocks(title="Audio Geo-Localization Game", theme=gr.themes.Soft(), css=custom_css) as demo:
    gr.Markdown("# Audio Geo-Localization Game")
    gr.HTML('<p class="intro-text">Welcome to the Audio Geo-Localization Game. Listen to an ambient audio clip, then guess where it was recorded by clicking on the world map. Submit to see the true location and how close you came.</p>')
    # State holders: the round queue/counter and the last map selection.
    game_state = gr.State()
    guess_state = gr.State()
    clip_info = gr.Markdown(label="Clip details", elem_classes=["clip-info"])
    with gr.Row():
        with gr.Column(scale=2):
            map_component = gr.Image(
                value=_base_map_array(),
                label="World map (click to set your guess)",
                interactive=True,
                type="numpy",
                image_mode="RGB",
                elem_classes=["gradio-image"],
            )
        with gr.Column(scale=1):
            player_id = gr.Textbox(label="Player ID", placeholder="Enter an identifier so scores can be tracked", elem_classes=["form-text"])
            audio_player = gr.Audio(label="Mystery audio clip", autoplay=False, interactive=False, streaming=False)
            selection_text = gr.Markdown("Click once on the map to pick your guess. The marker will update to your last selection.", elem_classes=["selection-text"])
            with gr.Row():
                longitude_input = gr.Textbox(
                    label="Longitude",
                    placeholder="Enter longitude (-180 to 180) or click map",
                    elem_classes=["form-text"]
                )
                latitude_input = gr.Textbox(
                    label="Latitude",
                    placeholder="Enter latitude (-90 to 90) or click map",
                    elem_classes=["form-text"]
                )
            submit_button = gr.Button("Submit Guess", variant="primary")
    # Headers mirror RECENT_COLUMNS, which _format_recent_rows uses to order cells.
    recent_table = gr.Dataframe(
        headers=[
            "timestamp",
            "player_id",
            "question_id",
            "distance_km",
        ],
        datatype=["str", "str", "str", "number"],
        value=[],
        interactive=False,
        label="Recent submissions (latest last)",
        wrap=True,
    )
    # Page load starts a fresh game; the output order must match the tuple
    # returned by initialize_round.
    demo.load(
        initialize_round,
        inputs=None,
        outputs=[game_state, guess_state, audio_player, clip_info, map_component, selection_text, recent_table, longitude_input, latitude_input],
    )
    # A click on the map updates the marker, the selection text and both
    # coordinate text inputs.
    map_component.select(
        handle_map_click,
        inputs=[guess_state],
        outputs=[map_component, selection_text, guess_state, longitude_input, latitude_input],
        preprocess=False,
    )
    # Submitting scores the guess from the text inputs; the output order
    # must match the tuple returned by submit_guess.
    submit_button.click(
        submit_guess,
        inputs=[player_id, game_state, guess_state, longitude_input, latitude_input],
        outputs=[game_state, guess_state, audio_player, clip_info, map_component, selection_text, recent_table, longitude_input, latitude_input],
    )
# Serve on all network interfaces, port 3828, when run as a script.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=3828)