worldpolicy-v6 / server.py
rajx77's picture
Refactored custom Llama 70B scenarios, mapped country arcs for globe, and fixed debate orchestration UI
aa05560
"""
server.py β€” WorldPolicy-Env V6.1 OpenEnv-compliant FastAPI backend.
The base app is built by `openenv.core.env_server.http_server.create_app`, which
gives us the standard OpenEnv contract for free:
POST /reset POST /step GET /state GET /schema GET /health WS /ws
GET /docs (FastAPI auto-docs)
On top of that we keep every pre-existing route from the V6.1 demo:
GET /groq-status (renamed from /health to avoid OpenEnv collision)
GET /persona/{agent_id}
GET /relationship-matrix
GET /un-authority/{crisis_type}
GET /vote-outcome/{round_id}
GET /stream/debate (SSE)
GET /stream/country-pnl (SSE)
GET /stream/company-pnl (SSE)
POST /live-debate -> { round_id }
And we add new plan endpoints:
POST /grader composite scoring across a finished episode
GET /live-crisis/{type} GDELT live headline (cached + fallback)
GET /tasks catalogue of 3 graduated tasks
Plus the SPA routes (must stay AFTER all API routes due to /{fname:path} catch-all):
GET / WorldPolicy V6.1.html
GET /{fname:path} whitelisted static (.css .jsx .js .json .md ...)
Run:
python server.py # binds 0.0.0.0:7860 (HF Spaces convention)
uvicorn server:app # equivalent
"""
from __future__ import annotations
import asyncio
import json
import os
import uuid
from collections import deque
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, AsyncIterator
from dotenv import load_dotenv
from fastapi import HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse, PlainTextResponse, StreamingResponse
from openenv.core.env_server.http_server import create_app
# Load local .env before importing modules that read env at import time.
# override=True ensures stale exported shell vars don't shadow .env edits.
load_dotenv(override=True)
from debate_orchestrator import DebateOrchestrator, UNMediator
from environment import WorldPolicyEnvironment
from graders import grade_episode
from models import WorldPolicyAction, WorldPolicyObservation
from persona_loader import PersonaLoader
from tasks import list_tasks, TASKS
from crisis_types import ALLOWED_CRISIS_TYPES as _ALLOWED_CRISIS
# Optional: live-data layer (added in P1). Guarded so server still boots if missing.
try:
from live_data import get_live_crisis # noqa: F401
_LIVE_DATA_OK = True
except Exception:
_LIVE_DATA_OK = False
# Optional: yfinance market layer (P3). Soft import β€” server boots cleanly if
# yfinance is missing; the company ticker strip + /market-data both fall through
# to the static seed in that case.
try:
from market_data import get_market_snapshot, get_company_prices
_MARKET_DATA_OK = True
except Exception:
_MARKET_DATA_OK = False
ROOT = Path(__file__).parent.resolve()
PERSONAS_DIR = ROOT / "personas"
INDEX_HTML = ROOT / "WorldPolicy V6.1.html"
AGENT_IDS = {"USA", "CHN", "RUS", "IND", "DPRK", "SAU", "UN"}
# V6: Input validation allowlists
ALLOWED_CRISIS_TYPES = _ALLOWED_CRISIS
MAX_DESCRIPTION_LEN = 500
MAX_ACTION_LEN = 100
# V2: CORS origins from env, default restrictive
_cors_origins = os.environ.get("WP_CORS_ORIGINS", "*").split(",")
# ── App ──────────────────────────────────────────────────────────────────
# OpenEnv builds the FastAPI app for us. max_concurrent_envs=4 matches the
# standard GRPO 4-rollout pattern so the validator can fan out cleanly.
app = create_app(
WorldPolicyEnvironment,
WorldPolicyAction,
WorldPolicyObservation,
env_name="worldpolicy_env",
max_concurrent_envs=int(os.environ.get("WP_MAX_CONCURRENT_ENVS", 4)),
)
app.title = "WorldPolicy-Env V6.1 Backend"
app.version = "1.0.0"
app.add_middleware(
CORSMiddleware,
allow_origins=_cors_origins,
allow_methods=["GET", "POST", "OPTIONS"],
allow_headers=["*"],
)
# ── Singletons ───────────────────────────────────────────────────────────
_loader = PersonaLoader()
_orchestrator = DebateOrchestrator()
_mediator = UNMediator()
# round_id -> {"vote_tally": {...}, "crisis_type": ..., "utterances": [...]}
_round_cache: dict[str, dict] = {}
_recent_rounds: deque[str] = deque(maxlen=32)
# ── Helpers ──────────────────────────────────────────────────────────────
def _sse(payload: dict, event: str | None = None) -> str:
"""Format a single SSE frame."""
prefix = f"event: {event}\n" if event else ""
return f"{prefix}data: {json.dumps(payload)}\n\n"
def _store_round(round_id: str, crisis_type: str, utterances: list[dict], tally: dict):
_round_cache[round_id] = {
"round_id": round_id,
"crisis_type": crisis_type,
"vote_tally": tally,
"utterances": utterances,
"stored_at": datetime.now(timezone.utc).isoformat(),
}
_recent_rounds.append(round_id)
if len(_round_cache) > _recent_rounds.maxlen:
stale = set(_round_cache.keys()) - set(_recent_rounds)
for sid in stale:
_round_cache.pop(sid, None)
# ── Routes ──────────────────────────────────────────────────────────────
# Note: /health is owned by OpenEnv (created by create_app). Our pre-existing
# liveness payload moved to /groq-status to avoid the collision while preserving
# the SPA's amber/teal LED behaviour.
@app.get("/groq-status")
def groq_status():
backend = getattr(_orchestrator, "_backend", "none")
return {
"status": "ok",
"debate_backend": backend,
"live_debate_enabled": bool(getattr(_orchestrator, "_use_live", False)),
"live_groq": backend == "groq",
"live_trained_model": backend == "mappo",
"live_data_layer": _LIVE_DATA_OK,
"market_data_layer": _MARKET_DATA_OK,
"timestamp": datetime.now(timezone.utc).isoformat(),
}
@app.get("/market-data")
def market_data():
"""P3: live yfinance market snapshot. Companies + country indices + cache info.
Returns the same shape as `market_data.get_market_snapshot()`:
{companies: [{symbol, name, countryId, currency, price, pct, live}, ...],
indices: {AGENT_ID: {ticker, price, change_pct, live}, ...},
live: bool, yf_loaded: bool, fetched_at: float, cache_ttl: int}
"""
if not _MARKET_DATA_OK:
return {
"companies": [],
"indices": {},
"live": False,
"yf_loaded": False,
"fetched_at": datetime.now(timezone.utc).timestamp(),
"cache_ttl": 0,
"error": "market_data module unavailable",
}
return get_market_snapshot()
@app.get("/tasks")
def get_tasks():
"""Catalogue of the 3 graduated tasks (consumed by inference.py + UI)."""
return {"tasks": list_tasks()}
@app.post("/grader")
def grader(body: dict):
"""Composite scoring across a finished episode.
Body shape:
{ "session_id": str | null, "task": "task_1" | "task_2" | "task_3",
"rounds": [ {round_result dicts as emitted in step()'s metadata.round} ] }
Returns: { task, raw_score, avg_per_round, normalized, step_count, target_range }
"""
task = str(body.get("task") or "task_1")
rounds = body.get("rounds") or []
if not isinstance(rounds, list):
raise HTTPException(400, "rounds must be a list of round_result dicts")
result = grade_episode(rounds, task=task)
# Attach the task's target reward range so callers can gate reward-hacking checks
from tasks import get_task as _get_task
cfg = _get_task(task)
result["target_range"] = list(cfg.get("target_reward_range", (0.4, 0.8)))
result["session_id"] = body.get("session_id")
return result
@app.get("/live-crisis/{crisis_type}")
def live_crisis(crisis_type: str):
"""GDELT-backed live crisis headline. Falls back to static if live layer absent."""
if crisis_type not in ALLOWED_CRISIS_TYPES:
raise HTTPException(400, f"unknown crisis_type; allowed: {sorted(ALLOWED_CRISIS_TYPES)}")
if not _LIVE_DATA_OK:
return {"type": crisis_type, "live": False, "headline": None, "fallback_reason": "live_data module missing"}
from live_data import get_live_crisis as _gc
return _gc(crisis_type)
@app.get("/country-sentiment/{agent_id}")
def country_sentiment(agent_id: str):
"""P4: GDELT tonechart-derived public sentiment for one agent's country."""
aid = agent_id.upper()
if aid not in AGENT_IDS:
raise HTTPException(404, f"unknown agent '{aid}'")
if not _LIVE_DATA_OK:
return {"agent_id": aid, "tone": 0.0, "label": "neutral", "live": False,
"color": "#94a3b8", "sample_size": 0, "fallback_reason": "live_data module missing"}
from live_data import get_country_sentiment as _gcs
return _gcs(aid)
@app.get("/sentiment")
def sentiment_snapshot():
"""P4: snapshot of all 7 agents' sentiments. Frontend hits this on mount + every 60s."""
if not _LIVE_DATA_OK:
return {"sentiments": {}, "live": False, "error": "live_data module unavailable"}
from live_data import get_all_sentiments as _all
snap = _all()
any_live = any(v.get("live") for v in snap.values())
return {"sentiments": snap, "live": bool(any_live)}
@app.get("/persona/{agent_id}", response_class=PlainTextResponse)
def get_persona(agent_id: str):
agent_id = agent_id.upper()
if agent_id not in AGENT_IDS:
raise HTTPException(404, f"unknown agent '{agent_id}'")
try:
return _loader.load_persona(agent_id)
except FileNotFoundError:
raise HTTPException(404, f"persona file missing for {agent_id}")
@app.get("/relationship-matrix")
def get_matrix():
return {
"matrix": _loader._relationships,
"grudge_memory": _loader._grudge_memory,
}
@app.get("/un-authority/{crisis_type}")
def get_authority(crisis_type: str, limit: int = Query(3, ge=1, le=10)):
if crisis_type not in ALLOWED_CRISIS_TYPES:
raise HTTPException(400, f"unknown crisis type '{crisis_type}'; allowed: {sorted(ALLOWED_CRISIS_TYPES)}")
articles = _mediator.get_articles_for_crisis(crisis_type, limit=limit)
if not articles:
raise HTTPException(404, f"no authority articles for crisis '{crisis_type}'")
return {
"crisis_type": crisis_type,
"within_mandate": _mediator.is_within_mandate(crisis_type),
"articles": articles,
}
@app.get("/vote-outcome/{round_id}")
def get_vote(round_id: str):
record = _round_cache.get(round_id)
if not record:
raise HTTPException(404, f"round_id '{round_id}' not cached")
return record
# ── Streaming debate ────────────────────────────────────────────────────
_ALL_AGENTS = ["USA", "CHN", "RUS", "IND", "DPRK", "SAU", "UN"]
_DEFAULT_INVOLVEMENT = {
"involved": ["USA", "CHN", "RUS", "IND", "DPRK", "SAU"],
"peripheral": ["UN"],
"uninvolved": [],
}
def _derive_involvement(crisis_type: str) -> dict:
"""Derive involvement tiers from task config. All sovereign agents speak;
primary_agents go first, rest follow, UN always last."""
for task_cfg in TASKS.values():
if task_cfg.get("crisis_type") == crisis_type:
active = task_cfg["active_agents"]
primary = task_cfg.get("primary_agents", [])
sovereign = [a for a in active if a != "UN"]
involved = [a for a in sovereign if a in primary]
peripheral = [a for a in sovereign if a not in primary] + ["UN"]
uninvolved = [a for a in _ALL_AGENTS if a not in active]
return {"involved": involved, "peripheral": peripheral, "uninvolved": uninvolved}
return dict(_DEFAULT_INVOLVEMENT)
async def _debate_event_stream(
crisis_type: str,
crisis_description: str,
mappo_action: str,
force_canned: bool,
max_rounds: int,
) -> AsyncIterator[str]:
# Prefer real-time crisis headline when available so debate context is live.
try:
live = get_live_crisis(crisis_type) if _LIVE_DATA_OK else None
live_headline = (live or {}).get("headline") if isinstance(live, dict) else None
if live_headline:
crisis_description = str(live_headline)[:MAX_DESCRIPTION_LEN]
except Exception:
pass
involvement = _derive_involvement(crisis_type)
task_cfg = next((t for t in TASKS.values() if t.get("crisis_type") == crisis_type), None)
max_steps = task_cfg["max_steps"] if task_cfg else 10
world_state = {
"step": max_steps // 2,
"welfare_index": 0.50,
"active_crises": [crisis_type],
"crisis_description": crisis_description,
}
try:
async for event in _orchestrator.run_multi_round_debate(
crisis_type=crisis_type,
crisis_description=crisis_description,
mappo_action=mappo_action,
world_state=world_state,
involvement=involvement,
force_canned=force_canned,
max_rounds=max_rounds,
):
etype = event.pop("_event", "utterance")
yield _sse(event, event=etype)
except Exception as exc:
yield _sse({"error": str(exc)}, event="error_event")
@app.get("/stream/debate")
async def stream_debate(
crisis_type: str = Query("natural_disaster"),
crisis_description: str = Query("Severe cyclone hits Bay of Bengal; UNESCO heritage sites at risk."),
mappo_action: str = Query("AID_DISPATCH_COORDINATED"),
force_canned: bool = Query(True),
max_rounds: int = Query(3, ge=1, le=3),
):
if crisis_type not in ALLOWED_CRISIS_TYPES:
raise HTTPException(400, f"unknown crisis_type; allowed: {sorted(ALLOWED_CRISIS_TYPES)}")
crisis_description = crisis_description[:MAX_DESCRIPTION_LEN]
mappo_action = mappo_action[:MAX_ACTION_LEN]
return StreamingResponse(
_debate_event_stream(crisis_type, crisis_description, mappo_action, force_canned, max_rounds),
media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
)
@app.post("/live-debate")
async def live_debate(
crisis_type: str = Query("natural_disaster"),
crisis_description: str = Query("Severe cyclone hits Bay of Bengal; UNESCO heritage sites at risk."),
mappo_action: str = Query("AID_DISPATCH_COORDINATED"),
max_rounds: int = Query(3, ge=1, le=3),
):
"""Kick off a live Groq debate (or canned if no key). Returns metadata to subscribe to /stream/debate."""
if crisis_type not in ALLOWED_CRISIS_TYPES:
raise HTTPException(400, f"unknown crisis_type; allowed: {sorted(ALLOWED_CRISIS_TYPES)}")
crisis_description = crisis_description[:MAX_DESCRIPTION_LEN]
mappo_action = mappo_action[:MAX_ACTION_LEN]
if not _orchestrator._use_live:
return JSONResponse(
{"live": False, "reason": "No live debate backend configured (set HF_TOKEN for trained model or switch backend). /stream/debate will serve canned.",
"subscribe": f"/stream/debate?force_canned=true&max_rounds={max_rounds}&crisis_type={crisis_type}"},
status_code=200,
)
return {
"live": True,
"subscribe": f"/stream/debate?force_canned=false&max_rounds={max_rounds}&crisis_type={crisis_type}",
}
# ── Country / Company P&L streams ───────────────────────────────────────
_SCRIPTED_COUNTRY_TICKS = [
{"at": 2, "countryId": "USA", "deltas": {"gdp": -0.02, "welfare": -0.01}},
{"at": 5, "countryId": "CHN", "deltas": {"gdp": -0.01, "influence": 0.015}},
{"at": 8, "countryId": "RUS", "deltas": {"gdp": 0.005, "military": 0.03}},
{"at": 11, "countryId": "IND", "deltas": {"gdp": 0.02, "welfare": 0.018}},
{"at": 17, "countryId": "SAU", "deltas": {"gdp": -0.015, "energy": 0.02}},
{"at": 20, "countryId": "UN", "deltas": {"heritage": 0.04}},
{"at": 25, "countryId": "USA", "deltas": {"gdp": -0.01, "influence": -0.02}},
{"at": 34, "countryId": "IND", "deltas": {"gdp": 0.03, "influence": 0.025}},
{"at": 45, "countryId": "IND", "deltas": {"welfare": 0.04, "heritage": 0.02}},
]
_SCRIPTED_COMPANY_TICKS = [
{"at": 5, "symbol": "AAPL", "price": 189.32, "pct": 0.8},
{"at": 5, "symbol": "BYDDY", "price": 215.40, "pct": -0.6},
{"at": 12, "symbol": "GAZP", "price": 139.20, "pct": -3.8},
{"at": 12, "symbol": "RELI", "price": 2860.00, "pct": 0.9},
{"at": 20, "symbol": "2222", "price": 33.10, "pct": 2.2},
{"at": 30, "symbol": "GAZP", "price": 136.00, "pct": -5.1},
{"at": 30, "symbol": "KOMID", "price": 86.50, "pct": -2.2},
{"at": 40, "symbol": "AAPL", "price": 190.50, "pct": 1.4},
{"at": 40, "symbol": "2222", "price": 33.80, "pct": 3.5},
]
async def _pnl_stream(entries: list[dict], tick_ms: int = 800) -> AsyncIterator[str]:
step = 0
max_step = max(e["at"] for e in entries) + 2
while step <= max_step:
step += 1
emitted = [e for e in entries if e["at"] == step]
for ev in emitted:
payload = {**ev, "step": step, "ts": datetime.now(timezone.utc).isoformat()}
yield _sse(payload, event="pnl_tick")
await asyncio.sleep(tick_ms / 1000)
yield _sse({"step": step, "done": True}, event="pnl_end")
@app.get("/stream/country-pnl")
async def stream_country_pnl(tick_ms: int = Query(800, ge=50, le=5000)):
return StreamingResponse(
_pnl_stream(_SCRIPTED_COUNTRY_TICKS, tick_ms),
media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
)
def _build_company_ticks_with_live() -> list[dict]:
"""P3: keep the scripted *cadence* of _SCRIPTED_COMPANY_TICKS but overwrite
each tick's `price` + `pct` with the live yfinance snapshot when available.
Result: the existing SSE stream the SPA already consumes ships LIVE prices
on a deterministic schedule. Falls back to the scripted constants if the
market layer is missing or returned no live data for that symbol.
"""
if not _MARKET_DATA_OK:
return [{**t, "_demo": True, "live": False} for t in _SCRIPTED_COMPANY_TICKS]
try:
live = {c["symbol"]: c for c in get_company_prices() or []}
except Exception:
return [{**t, "_demo": True, "live": False} for t in _SCRIPTED_COMPANY_TICKS]
out = []
for tick in _SCRIPTED_COMPANY_TICKS:
sym = tick["symbol"]
live_row = live.get(sym, {})
merged = {**tick}
if live_row.get("live"):
merged["price"] = live_row["price"]
merged["pct"] = live_row["pct"]
merged["live"] = True
merged["_demo"] = False
else:
merged["live"] = False
merged["_demo"] = True
out.append(merged)
return out
@app.get("/stream/company-pnl")
async def stream_company_pnl(tick_ms: int = Query(800, ge=50, le=5000)):
return StreamingResponse(
_pnl_stream(_build_company_ticks_with_live(), tick_ms),
media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
)
# ── Static frontend (same-origin serve for HF Spaces single-container) ──
@app.api_route("/", methods=["GET", "HEAD"], include_in_schema=False)
def root_index():
if INDEX_HTML.exists():
return FileResponse(INDEX_HTML, media_type="text/html")
raise HTTPException(404, "index HTML missing")
_STATIC_WHITELIST = {".css", ".jsx", ".js", ".json", ".md", ".png", ".jpg", ".svg", ".ico"}
_MEDIA_TYPES = {
".jsx": "text/babel",
".js": "application/javascript",
".css": "text/css",
".md": "text/markdown",
".json": "application/json",
".png": "image/png",
".jpg": "image/jpeg",
".svg": "image/svg+xml",
".ico": "image/x-icon",
}
@app.api_route("/{fname:path}", methods=["GET", "HEAD"], include_in_schema=False)
def serve_static(fname: str):
"""Serve project root files (CSS, JSX, personas/*) behind an extension whitelist.
V1 FIX: Uses Path.resolve() + is_relative_to() to prevent path traversal.
Rejects symlinks pointing outside ROOT.
"""
if not fname or fname.startswith("/"):
raise HTTPException(400, "invalid path")
# V1: Resolve to real path and verify containment
target = (ROOT / fname).resolve()
if not target.is_relative_to(ROOT):
raise HTTPException(403, "access denied")
if not target.exists() or not target.is_file():
raise HTTPException(404, "not found")
if target.suffix.lower() not in _STATIC_WHITELIST:
raise HTTPException(403, f"type not served: {target.suffix}")
media = _MEDIA_TYPES.get(target.suffix.lower(), "application/octet-stream")
return FileResponse(target, media_type=media)
# ── CLI entry ────────────────────────────────────────────────────────────
if __name__ == "__main__":
import uvicorn
port = int(os.environ.get("PORT", 7860))
uvicorn.run("server:app", host="0.0.0.0", port=port, reload=False)