Spaces:
Running
Running
| """ | |
| server.py β WorldPolicy-Env V6.1 OpenEnv-compliant FastAPI backend. | |
| The base app is built by `openenv.core.env_server.http_server.create_app`, which | |
| gives us the standard OpenEnv contract for free: | |
| POST /reset POST /step GET /state GET /schema GET /health WS /ws | |
| GET /docs (FastAPI auto-docs) | |
| On top of that we keep every pre-existing route from the V6.1 demo: | |
| GET /groq-status (renamed from /health to avoid OpenEnv collision) | |
| GET /persona/{agent_id} | |
| GET /relationship-matrix | |
| GET /un-authority/{crisis_type} | |
| GET /vote-outcome/{round_id} | |
| GET /stream/debate (SSE) | |
| GET /stream/country-pnl (SSE) | |
| GET /stream/company-pnl (SSE) | |
| POST /live-debate -> { round_id } | |
| And we add new plan endpoints: | |
| POST /grader composite scoring across a finished episode | |
| GET /live-crisis/{type} GDELT live headline (cached + fallback) | |
| GET /tasks catalogue of 3 graduated tasks | |
| Plus the SPA routes (must stay AFTER all API routes due to /{fname:path} catch-all): | |
| GET / WorldPolicy V6.1.html | |
| GET /{fname:path} whitelisted static (.css .jsx .js .json .md ...) | |
| Run: | |
| python server.py # binds 0.0.0.0:7860 (HF Spaces convention) | |
| uvicorn server:app # equivalent | |
| """ | |
| from __future__ import annotations | |
| import asyncio | |
| import json | |
| import os | |
| import uuid | |
| from collections import deque | |
| from datetime import datetime, timezone | |
| from pathlib import Path | |
| from typing import Any, AsyncIterator | |
| from dotenv import load_dotenv | |
| from fastapi import HTTPException, Query | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import FileResponse, JSONResponse, PlainTextResponse, StreamingResponse | |
| from openenv.core.env_server.http_server import create_app | |
| # Load local .env before importing modules that read env at import time. | |
| # override=True ensures stale exported shell vars don't shadow .env edits. | |
| load_dotenv(override=True) | |
| from debate_orchestrator import DebateOrchestrator, UNMediator | |
| from environment import WorldPolicyEnvironment | |
| from graders import grade_episode | |
| from models import WorldPolicyAction, WorldPolicyObservation | |
| from persona_loader import PersonaLoader | |
| from tasks import list_tasks, TASKS | |
| from crisis_types import ALLOWED_CRISIS_TYPES as _ALLOWED_CRISIS | |
| # Optional: live-data layer (added in P1). Guarded so server still boots if missing. | |
| try: | |
| from live_data import get_live_crisis # noqa: F401 | |
| _LIVE_DATA_OK = True | |
| except Exception: | |
| _LIVE_DATA_OK = False | |
| # Optional: yfinance market layer (P3). Soft import β server boots cleanly if | |
| # yfinance is missing; the company ticker strip + /market-data both fall through | |
| # to the static seed in that case. | |
| try: | |
| from market_data import get_market_snapshot, get_company_prices | |
| _MARKET_DATA_OK = True | |
| except Exception: | |
| _MARKET_DATA_OK = False | |
| ROOT = Path(__file__).parent.resolve() | |
| PERSONAS_DIR = ROOT / "personas" | |
| INDEX_HTML = ROOT / "WorldPolicy V6.1.html" | |
| AGENT_IDS = {"USA", "CHN", "RUS", "IND", "DPRK", "SAU", "UN"} | |
| # V6: Input validation allowlists | |
| ALLOWED_CRISIS_TYPES = _ALLOWED_CRISIS | |
| MAX_DESCRIPTION_LEN = 500 | |
| MAX_ACTION_LEN = 100 | |
| # V2: CORS origins from env, default restrictive | |
| _cors_origins = os.environ.get("WP_CORS_ORIGINS", "*").split(",") | |
| # ββ App ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # OpenEnv builds the FastAPI app for us. max_concurrent_envs=4 matches the | |
| # standard GRPO 4-rollout pattern so the validator can fan out cleanly. | |
| app = create_app( | |
| WorldPolicyEnvironment, | |
| WorldPolicyAction, | |
| WorldPolicyObservation, | |
| env_name="worldpolicy_env", | |
| max_concurrent_envs=int(os.environ.get("WP_MAX_CONCURRENT_ENVS", 4)), | |
| ) | |
| app.title = "WorldPolicy-Env V6.1 Backend" | |
| app.version = "1.0.0" | |
| app.add_middleware( | |
| CORSMiddleware, | |
| allow_origins=_cors_origins, | |
| allow_methods=["GET", "POST", "OPTIONS"], | |
| allow_headers=["*"], | |
| ) | |
| # ββ Singletons βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| _loader = PersonaLoader() | |
| _orchestrator = DebateOrchestrator() | |
| _mediator = UNMediator() | |
| # round_id -> {"vote_tally": {...}, "crisis_type": ..., "utterances": [...]} | |
| _round_cache: dict[str, dict] = {} | |
| _recent_rounds: deque[str] = deque(maxlen=32) | |
| # ββ Helpers ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def _sse(payload: dict, event: str | None = None) -> str: | |
| """Format a single SSE frame.""" | |
| prefix = f"event: {event}\n" if event else "" | |
| return f"{prefix}data: {json.dumps(payload)}\n\n" | |
| def _store_round(round_id: str, crisis_type: str, utterances: list[dict], tally: dict): | |
| _round_cache[round_id] = { | |
| "round_id": round_id, | |
| "crisis_type": crisis_type, | |
| "vote_tally": tally, | |
| "utterances": utterances, | |
| "stored_at": datetime.now(timezone.utc).isoformat(), | |
| } | |
| _recent_rounds.append(round_id) | |
| if len(_round_cache) > _recent_rounds.maxlen: | |
| stale = set(_round_cache.keys()) - set(_recent_rounds) | |
| for sid in stale: | |
| _round_cache.pop(sid, None) | |
| # ββ Routes ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Note: /health is owned by OpenEnv (created by create_app). Our pre-existing | |
| # liveness payload moved to /groq-status to avoid the collision while preserving | |
| # the SPA's amber/teal LED behaviour. | |
| def groq_status(): | |
| backend = getattr(_orchestrator, "_backend", "none") | |
| return { | |
| "status": "ok", | |
| "debate_backend": backend, | |
| "live_debate_enabled": bool(getattr(_orchestrator, "_use_live", False)), | |
| "live_groq": backend == "groq", | |
| "live_trained_model": backend == "mappo", | |
| "live_data_layer": _LIVE_DATA_OK, | |
| "market_data_layer": _MARKET_DATA_OK, | |
| "timestamp": datetime.now(timezone.utc).isoformat(), | |
| } | |
| def market_data(): | |
| """P3: live yfinance market snapshot. Companies + country indices + cache info. | |
| Returns the same shape as `market_data.get_market_snapshot()`: | |
| {companies: [{symbol, name, countryId, currency, price, pct, live}, ...], | |
| indices: {AGENT_ID: {ticker, price, change_pct, live}, ...}, | |
| live: bool, yf_loaded: bool, fetched_at: float, cache_ttl: int} | |
| """ | |
| if not _MARKET_DATA_OK: | |
| return { | |
| "companies": [], | |
| "indices": {}, | |
| "live": False, | |
| "yf_loaded": False, | |
| "fetched_at": datetime.now(timezone.utc).timestamp(), | |
| "cache_ttl": 0, | |
| "error": "market_data module unavailable", | |
| } | |
| return get_market_snapshot() | |
| def get_tasks(): | |
| """Catalogue of the 3 graduated tasks (consumed by inference.py + UI).""" | |
| return {"tasks": list_tasks()} | |
| def grader(body: dict): | |
| """Composite scoring across a finished episode. | |
| Body shape: | |
| { "session_id": str | null, "task": "task_1" | "task_2" | "task_3", | |
| "rounds": [ {round_result dicts as emitted in step()'s metadata.round} ] } | |
| Returns: { task, raw_score, avg_per_round, normalized, step_count, target_range } | |
| """ | |
| task = str(body.get("task") or "task_1") | |
| rounds = body.get("rounds") or [] | |
| if not isinstance(rounds, list): | |
| raise HTTPException(400, "rounds must be a list of round_result dicts") | |
| result = grade_episode(rounds, task=task) | |
| # Attach the task's target reward range so callers can gate reward-hacking checks | |
| from tasks import get_task as _get_task | |
| cfg = _get_task(task) | |
| result["target_range"] = list(cfg.get("target_reward_range", (0.4, 0.8))) | |
| result["session_id"] = body.get("session_id") | |
| return result | |
| def live_crisis(crisis_type: str): | |
| """GDELT-backed live crisis headline. Falls back to static if live layer absent.""" | |
| if crisis_type not in ALLOWED_CRISIS_TYPES: | |
| raise HTTPException(400, f"unknown crisis_type; allowed: {sorted(ALLOWED_CRISIS_TYPES)}") | |
| if not _LIVE_DATA_OK: | |
| return {"type": crisis_type, "live": False, "headline": None, "fallback_reason": "live_data module missing"} | |
| from live_data import get_live_crisis as _gc | |
| return _gc(crisis_type) | |
| def country_sentiment(agent_id: str): | |
| """P4: GDELT tonechart-derived public sentiment for one agent's country.""" | |
| aid = agent_id.upper() | |
| if aid not in AGENT_IDS: | |
| raise HTTPException(404, f"unknown agent '{aid}'") | |
| if not _LIVE_DATA_OK: | |
| return {"agent_id": aid, "tone": 0.0, "label": "neutral", "live": False, | |
| "color": "#94a3b8", "sample_size": 0, "fallback_reason": "live_data module missing"} | |
| from live_data import get_country_sentiment as _gcs | |
| return _gcs(aid) | |
| def sentiment_snapshot(): | |
| """P4: snapshot of all 7 agents' sentiments. Frontend hits this on mount + every 60s.""" | |
| if not _LIVE_DATA_OK: | |
| return {"sentiments": {}, "live": False, "error": "live_data module unavailable"} | |
| from live_data import get_all_sentiments as _all | |
| snap = _all() | |
| any_live = any(v.get("live") for v in snap.values()) | |
| return {"sentiments": snap, "live": bool(any_live)} | |
| def get_persona(agent_id: str): | |
| agent_id = agent_id.upper() | |
| if agent_id not in AGENT_IDS: | |
| raise HTTPException(404, f"unknown agent '{agent_id}'") | |
| try: | |
| return _loader.load_persona(agent_id) | |
| except FileNotFoundError: | |
| raise HTTPException(404, f"persona file missing for {agent_id}") | |
| def get_matrix(): | |
| return { | |
| "matrix": _loader._relationships, | |
| "grudge_memory": _loader._grudge_memory, | |
| } | |
| def get_authority(crisis_type: str, limit: int = Query(3, ge=1, le=10)): | |
| if crisis_type not in ALLOWED_CRISIS_TYPES: | |
| raise HTTPException(400, f"unknown crisis type '{crisis_type}'; allowed: {sorted(ALLOWED_CRISIS_TYPES)}") | |
| articles = _mediator.get_articles_for_crisis(crisis_type, limit=limit) | |
| if not articles: | |
| raise HTTPException(404, f"no authority articles for crisis '{crisis_type}'") | |
| return { | |
| "crisis_type": crisis_type, | |
| "within_mandate": _mediator.is_within_mandate(crisis_type), | |
| "articles": articles, | |
| } | |
| def get_vote(round_id: str): | |
| record = _round_cache.get(round_id) | |
| if not record: | |
| raise HTTPException(404, f"round_id '{round_id}' not cached") | |
| return record | |
| # ββ Streaming debate ββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| _ALL_AGENTS = ["USA", "CHN", "RUS", "IND", "DPRK", "SAU", "UN"] | |
| _DEFAULT_INVOLVEMENT = { | |
| "involved": ["USA", "CHN", "RUS", "IND", "DPRK", "SAU"], | |
| "peripheral": ["UN"], | |
| "uninvolved": [], | |
| } | |
| def _derive_involvement(crisis_type: str) -> dict: | |
| """Derive involvement tiers from task config. All sovereign agents speak; | |
| primary_agents go first, rest follow, UN always last.""" | |
| for task_cfg in TASKS.values(): | |
| if task_cfg.get("crisis_type") == crisis_type: | |
| active = task_cfg["active_agents"] | |
| primary = task_cfg.get("primary_agents", []) | |
| sovereign = [a for a in active if a != "UN"] | |
| involved = [a for a in sovereign if a in primary] | |
| peripheral = [a for a in sovereign if a not in primary] + ["UN"] | |
| uninvolved = [a for a in _ALL_AGENTS if a not in active] | |
| return {"involved": involved, "peripheral": peripheral, "uninvolved": uninvolved} | |
| return dict(_DEFAULT_INVOLVEMENT) | |
| async def _debate_event_stream( | |
| crisis_type: str, | |
| crisis_description: str, | |
| mappo_action: str, | |
| force_canned: bool, | |
| max_rounds: int, | |
| ) -> AsyncIterator[str]: | |
| # Prefer real-time crisis headline when available so debate context is live. | |
| try: | |
| live = get_live_crisis(crisis_type) if _LIVE_DATA_OK else None | |
| live_headline = (live or {}).get("headline") if isinstance(live, dict) else None | |
| if live_headline: | |
| crisis_description = str(live_headline)[:MAX_DESCRIPTION_LEN] | |
| except Exception: | |
| pass | |
| involvement = _derive_involvement(crisis_type) | |
| task_cfg = next((t for t in TASKS.values() if t.get("crisis_type") == crisis_type), None) | |
| max_steps = task_cfg["max_steps"] if task_cfg else 10 | |
| world_state = { | |
| "step": max_steps // 2, | |
| "welfare_index": 0.50, | |
| "active_crises": [crisis_type], | |
| "crisis_description": crisis_description, | |
| } | |
| try: | |
| async for event in _orchestrator.run_multi_round_debate( | |
| crisis_type=crisis_type, | |
| crisis_description=crisis_description, | |
| mappo_action=mappo_action, | |
| world_state=world_state, | |
| involvement=involvement, | |
| force_canned=force_canned, | |
| max_rounds=max_rounds, | |
| ): | |
| etype = event.pop("_event", "utterance") | |
| yield _sse(event, event=etype) | |
| except Exception as exc: | |
| yield _sse({"error": str(exc)}, event="error_event") | |
| async def stream_debate( | |
| crisis_type: str = Query("natural_disaster"), | |
| crisis_description: str = Query("Severe cyclone hits Bay of Bengal; UNESCO heritage sites at risk."), | |
| mappo_action: str = Query("AID_DISPATCH_COORDINATED"), | |
| force_canned: bool = Query(True), | |
| max_rounds: int = Query(3, ge=1, le=3), | |
| ): | |
| if crisis_type not in ALLOWED_CRISIS_TYPES: | |
| raise HTTPException(400, f"unknown crisis_type; allowed: {sorted(ALLOWED_CRISIS_TYPES)}") | |
| crisis_description = crisis_description[:MAX_DESCRIPTION_LEN] | |
| mappo_action = mappo_action[:MAX_ACTION_LEN] | |
| return StreamingResponse( | |
| _debate_event_stream(crisis_type, crisis_description, mappo_action, force_canned, max_rounds), | |
| media_type="text/event-stream", | |
| headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}, | |
| ) | |
| async def live_debate( | |
| crisis_type: str = Query("natural_disaster"), | |
| crisis_description: str = Query("Severe cyclone hits Bay of Bengal; UNESCO heritage sites at risk."), | |
| mappo_action: str = Query("AID_DISPATCH_COORDINATED"), | |
| max_rounds: int = Query(3, ge=1, le=3), | |
| ): | |
| """Kick off a live Groq debate (or canned if no key). Returns metadata to subscribe to /stream/debate.""" | |
| if crisis_type not in ALLOWED_CRISIS_TYPES: | |
| raise HTTPException(400, f"unknown crisis_type; allowed: {sorted(ALLOWED_CRISIS_TYPES)}") | |
| crisis_description = crisis_description[:MAX_DESCRIPTION_LEN] | |
| mappo_action = mappo_action[:MAX_ACTION_LEN] | |
| if not _orchestrator._use_live: | |
| return JSONResponse( | |
| {"live": False, "reason": "No live debate backend configured (set HF_TOKEN for trained model or switch backend). /stream/debate will serve canned.", | |
| "subscribe": f"/stream/debate?force_canned=true&max_rounds={max_rounds}&crisis_type={crisis_type}"}, | |
| status_code=200, | |
| ) | |
| return { | |
| "live": True, | |
| "subscribe": f"/stream/debate?force_canned=false&max_rounds={max_rounds}&crisis_type={crisis_type}", | |
| } | |
| # ββ Country / Company P&L streams βββββββββββββββββββββββββββββββββββββββ | |
| _SCRIPTED_COUNTRY_TICKS = [ | |
| {"at": 2, "countryId": "USA", "deltas": {"gdp": -0.02, "welfare": -0.01}}, | |
| {"at": 5, "countryId": "CHN", "deltas": {"gdp": -0.01, "influence": 0.015}}, | |
| {"at": 8, "countryId": "RUS", "deltas": {"gdp": 0.005, "military": 0.03}}, | |
| {"at": 11, "countryId": "IND", "deltas": {"gdp": 0.02, "welfare": 0.018}}, | |
| {"at": 17, "countryId": "SAU", "deltas": {"gdp": -0.015, "energy": 0.02}}, | |
| {"at": 20, "countryId": "UN", "deltas": {"heritage": 0.04}}, | |
| {"at": 25, "countryId": "USA", "deltas": {"gdp": -0.01, "influence": -0.02}}, | |
| {"at": 34, "countryId": "IND", "deltas": {"gdp": 0.03, "influence": 0.025}}, | |
| {"at": 45, "countryId": "IND", "deltas": {"welfare": 0.04, "heritage": 0.02}}, | |
| ] | |
| _SCRIPTED_COMPANY_TICKS = [ | |
| {"at": 5, "symbol": "AAPL", "price": 189.32, "pct": 0.8}, | |
| {"at": 5, "symbol": "BYDDY", "price": 215.40, "pct": -0.6}, | |
| {"at": 12, "symbol": "GAZP", "price": 139.20, "pct": -3.8}, | |
| {"at": 12, "symbol": "RELI", "price": 2860.00, "pct": 0.9}, | |
| {"at": 20, "symbol": "2222", "price": 33.10, "pct": 2.2}, | |
| {"at": 30, "symbol": "GAZP", "price": 136.00, "pct": -5.1}, | |
| {"at": 30, "symbol": "KOMID", "price": 86.50, "pct": -2.2}, | |
| {"at": 40, "symbol": "AAPL", "price": 190.50, "pct": 1.4}, | |
| {"at": 40, "symbol": "2222", "price": 33.80, "pct": 3.5}, | |
| ] | |
| async def _pnl_stream(entries: list[dict], tick_ms: int = 800) -> AsyncIterator[str]: | |
| step = 0 | |
| max_step = max(e["at"] for e in entries) + 2 | |
| while step <= max_step: | |
| step += 1 | |
| emitted = [e for e in entries if e["at"] == step] | |
| for ev in emitted: | |
| payload = {**ev, "step": step, "ts": datetime.now(timezone.utc).isoformat()} | |
| yield _sse(payload, event="pnl_tick") | |
| await asyncio.sleep(tick_ms / 1000) | |
| yield _sse({"step": step, "done": True}, event="pnl_end") | |
| async def stream_country_pnl(tick_ms: int = Query(800, ge=50, le=5000)): | |
| return StreamingResponse( | |
| _pnl_stream(_SCRIPTED_COUNTRY_TICKS, tick_ms), | |
| media_type="text/event-stream", | |
| headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}, | |
| ) | |
| def _build_company_ticks_with_live() -> list[dict]: | |
| """P3: keep the scripted *cadence* of _SCRIPTED_COMPANY_TICKS but overwrite | |
| each tick's `price` + `pct` with the live yfinance snapshot when available. | |
| Result: the existing SSE stream the SPA already consumes ships LIVE prices | |
| on a deterministic schedule. Falls back to the scripted constants if the | |
| market layer is missing or returned no live data for that symbol. | |
| """ | |
| if not _MARKET_DATA_OK: | |
| return [{**t, "_demo": True, "live": False} for t in _SCRIPTED_COMPANY_TICKS] | |
| try: | |
| live = {c["symbol"]: c for c in get_company_prices() or []} | |
| except Exception: | |
| return [{**t, "_demo": True, "live": False} for t in _SCRIPTED_COMPANY_TICKS] | |
| out = [] | |
| for tick in _SCRIPTED_COMPANY_TICKS: | |
| sym = tick["symbol"] | |
| live_row = live.get(sym, {}) | |
| merged = {**tick} | |
| if live_row.get("live"): | |
| merged["price"] = live_row["price"] | |
| merged["pct"] = live_row["pct"] | |
| merged["live"] = True | |
| merged["_demo"] = False | |
| else: | |
| merged["live"] = False | |
| merged["_demo"] = True | |
| out.append(merged) | |
| return out | |
| async def stream_company_pnl(tick_ms: int = Query(800, ge=50, le=5000)): | |
| return StreamingResponse( | |
| _pnl_stream(_build_company_ticks_with_live(), tick_ms), | |
| media_type="text/event-stream", | |
| headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}, | |
| ) | |
| # ββ Static frontend (same-origin serve for HF Spaces single-container) ββ | |
| def root_index(): | |
| if INDEX_HTML.exists(): | |
| return FileResponse(INDEX_HTML, media_type="text/html") | |
| raise HTTPException(404, "index HTML missing") | |
| _STATIC_WHITELIST = {".css", ".jsx", ".js", ".json", ".md", ".png", ".jpg", ".svg", ".ico"} | |
| _MEDIA_TYPES = { | |
| ".jsx": "text/babel", | |
| ".js": "application/javascript", | |
| ".css": "text/css", | |
| ".md": "text/markdown", | |
| ".json": "application/json", | |
| ".png": "image/png", | |
| ".jpg": "image/jpeg", | |
| ".svg": "image/svg+xml", | |
| ".ico": "image/x-icon", | |
| } | |
| def serve_static(fname: str): | |
| """Serve project root files (CSS, JSX, personas/*) behind an extension whitelist. | |
| V1 FIX: Uses Path.resolve() + is_relative_to() to prevent path traversal. | |
| Rejects symlinks pointing outside ROOT. | |
| """ | |
| if not fname or fname.startswith("/"): | |
| raise HTTPException(400, "invalid path") | |
| # V1: Resolve to real path and verify containment | |
| target = (ROOT / fname).resolve() | |
| if not target.is_relative_to(ROOT): | |
| raise HTTPException(403, "access denied") | |
| if not target.exists() or not target.is_file(): | |
| raise HTTPException(404, "not found") | |
| if target.suffix.lower() not in _STATIC_WHITELIST: | |
| raise HTTPException(403, f"type not served: {target.suffix}") | |
| media = _MEDIA_TYPES.get(target.suffix.lower(), "application/octet-stream") | |
| return FileResponse(target, media_type=media) | |
| # ββ CLI entry ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| if __name__ == "__main__": | |
| import uvicorn | |
| port = int(os.environ.get("PORT", 7860)) | |
| uvicorn.run("server:app", host="0.0.0.0", port=port, reload=False) | |