import asyncio
import json
import math
import os
import random
import time
from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Optional, Tuple

from fastapi import FastAPI, Header, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.responses import FileResponse

# ----------------------------
# Config
# ----------------------------

# Bundled scenario, parsed once at import time. The startup hook only loads a
# scenario file when it exists, so a missing file must not stop the module
# from importing; fall back to an empty scenario instead.
try:
    with open("default_scenario.json", "r", encoding="utf-8") as f:
        DEFAULT_SCENARIO = json.load(f)
except FileNotFoundError:
    DEFAULT_SCENARIO = {}

TICK_RATE = float(os.getenv("TICK_RATE", "2.0"))  # seconds per tick
MARKET_LENGTH = int(os.getenv("MARKET_LENGTH", "1900"))  # default timeline length
MIN_MARKET_LENGTH = int(os.getenv("MIN_MARKET_LENGTH", "600"))  # hard minimum
START_PRICE = float(os.getenv("START_PRICE", "100.0"))
DEFAULT_VOLATILITY = float(os.getenv("DEFAULT_VOLATILITY", "0.8"))

# If true, day wraps around at end of market (old behavior). If false, clamps at last day.
LOOP_MARKET = os.getenv("LOOP_MARKET", "0").strip().lower() in ("1", "true", "yes", "y")

ADMIN_TOKEN = os.getenv("ADMIN_TOKEN", "")  # set as HF Space Secret
ADMIN_HEADER = "X-ADMIN-TOKEN"

INDEX_FILE = os.getenv("INDEX_FILE", "index.html")
ADMIN_FILE = os.getenv("ADMIN_FILE", "admin.html")  # optional if you keep admin.html


# ----------------------------
# Data models
# ----------------------------
@dataclass
class ScenarioEvent:
    """A scripted event that the tick loop applies when the game reaches ``day``."""

    day: int
    shockPct: float = 0.0  # price shock applied to future path from day onward
    volatility: Optional[float] = None  # when set, replaces the current volatility
    news: Optional[str] = None  # when set, broadcast to clients as a NEWS message


# ----------------------------
# Market simulator
# ----------------------------
class MarketSimulator:
    """Seeded random-walk generator for the price series."""

    def __init__(self, seed: int = 42):
        self.seed = seed

    def generate_base_market(self, length: int, start_price: float, vol: float) -> List[Dict[str, float]]:
        """Build a price path of ``length`` bars.

        A fresh ``random.Random(self.seed)`` is created on every call, so the
        same arguments always produce the same series.

        Args:
            length: Number of bars to generate.
            start_price: Price the walk starts from (the first bar already
                includes one step of drift and noise).
            vol: Standard deviation of the Gaussian noise added per bar.

        Returns:
            A list of ``{"i": index, "close": price}`` dicts, with prices
            floored at 1.0 and rounded to two decimals.
        """
        rng = random.Random(self.seed)
        price = float(start_price)
        drift = 0.02  # constant upward bias added to every step
        series: List[Dict[str, float]] = []
        for i in range(length):
            shock = rng.gauss(0.0, vol)
            # The floor is applied to the running price, so the walk continues
            # from the clamped value.
            price = max(1.0, price + drift + shock)
            series.append({"i": i, "close": round(price, 2)})
        return series


# ----------------------------
# Connection manager
# ----------------------------
def _finite_or_zero(value: float) -> float:
    """Return ``value`` as a float, replacing NaN and +/-infinity with 0.0."""
    number = float(value)
    return number if math.isfinite(number) else 0.0


class ConnectionManager:
    """Tracks connected WebSocket clients and the shared leaderboard."""

    # Maximum number of rows included in a leaderboard snapshot.
    LEADERBOARD_LIMIT = 50

    def __init__(self) -> None:
        self.active: Dict[str, WebSocket] = {}  # client_id -> websocket
        self.leaderboard: Dict[str, Dict[str, float]] = {}  # name -> {equity, roi, ts}
        self._lock = asyncio.Lock()  # guards both dicts above

    async def connect(self, websocket: WebSocket, client_id: str) -> None:
        """Accept the handshake and register the socket under ``client_id``.

        A socket already registered under the same id is replaced.
        """
        await websocket.accept()
        async with self._lock:
            self.active[client_id] = websocket

    async def disconnect(self, client_id: str, websocket: Optional[WebSocket] = None) -> None:
        """Forget the socket registered under ``client_id``.

        Args:
            client_id: Id the socket was registered with.
            websocket: Optional. When given, the entry is removed only if it
                is still this exact socket, so a late disconnect of an old
                connection cannot evict a newer one that reused the id.
        """
        async with self._lock:
            if websocket is None or self.active.get(client_id) is websocket:
                self.active.pop(client_id, None)

    async def update_equity(self, name: str, equity: float, roi: float) -> None:
        """Record the latest equity and ROI reported for ``name``.

        Non-finite values are stored as 0.0: ``json.dumps`` would otherwise
        write them as NaN/Infinity tokens into every broadcast, which strict
        JSON parsers reject, and NaN has no defined sort order.
        """
        entry = {
            "equity": _finite_or_zero(equity),
            "roi": _finite_or_zero(roi),
            "ts": time.time(),
        }
        async with self._lock:
            self.leaderboard[name] = entry

    async def _snapshot_leaderboard(self) -> List[Dict[str, Any]]:
        """Return the top entries as ``{"name", "equity", "roi"}`` dicts, best equity first."""
        async with self._lock:
            entries = [
                {"name": name, "equity": row["equity"], "roi": row["roi"]}
                for name, row in self.leaderboard.items()
            ]
        # Sorting works on the copied rows, so it does not need the lock.
        entries.sort(key=lambda row: row["equity"], reverse=True)
        return entries[: self.LEADERBOARD_LIMIT]

    async def broadcast(self, obj: Dict[str, Any]) -> None:
        """Send ``obj`` as JSON text to every registered socket.

        Sockets whose send fails are unregistered afterwards.
        """
        msg = json.dumps(obj)
        async with self._lock:
            sockets = list(self.active.items())

        # Sending happens outside the lock so one slow client cannot block
        # connects and disconnects.
        stale: List[Tuple[str, WebSocket]] = []
        for client_id, ws in sockets:
            try:
                await ws.send_text(msg)
            except Exception:
                # Best-effort delivery: any send failure marks the socket dead.
                stale.append((client_id, ws))

        if stale:
            async with self._lock:
                for client_id, ws in stale:
                    # The client may have reconnected under the same id while
                    # we were sending; only drop the socket that failed.
                    if self.active.get(client_id) is ws:
                        self.active.pop(client_id, None)

    async def broadcast_tick(self, day: int) -> None:
        """Broadcast a TICK message carrying the day and the current leaderboard."""
        await self.broadcast({
            "type": "TICK",
            "payload": {
                "day": day,
                "leaderboard": await self._snapshot_leaderboard(),
            },
        })

    async def broadcast_news(self, day: int, text: str) -> None:
        """Broadcast a NEWS message carrying the day and the headline text."""
        await self.broadcast({"type": "NEWS", "payload": {"day": day, "text": text}})


app = FastAPI(title="MPTrading (FastAPI + WebSocket)")
manager = ConnectionManager()
sim = MarketSimulator(seed=42)

# ----------------------------
# Global game state
# ----------------------------
DAY_LOCK = asyncio.Lock()  # guards CURRENT_DAY
STATE_LOCK = asyncio.Lock()  # guards MARKET, EVENTS and CURRENT_VOL
CURRENT_DAY = 0
CURRENT_VOL = DEFAULT_VOLATILITY
MARKET: List[Dict[str, float]] = sim.generate_base_market(
    max(MIN_MARKET_LENGTH, MARKET_LENGTH), START_PRICE, DEFAULT_VOLATILITY
)
EVENTS: Dict[int, List[ScenarioEvent]] = {}

# Strong reference to the background tick task. The event loop only keeps weak
# references to tasks, so an otherwise unreferenced task may be garbage
# collected while it is still running.
_GAME_LOOP_TASK: Optional[asyncio.Task] = None

# Exceptions raised by int()/float()/dict access on malformed request data.
_BAD_INPUT_ERRORS = (KeyError, TypeError, ValueError, OverflowError)


# ----------------------------
# Helpers
# ----------------------------
def require_admin(token: Optional[str]) -> None:
    """Validate the admin token supplied with a request.

    Raises:
        HTTPException: 403 when the server has no ADMIN_TOKEN configured,
            401 when ``token`` does not match it.
    """
    import hmac  # imported locally so this helper is self-contained

    if not ADMIN_TOKEN:
        raise HTTPException(status_code=403, detail="Admin token not configured on server.")
    # Constant-time comparison; bytes are used because compare_digest only
    # accepts ASCII-only str arguments.
    supplied = (token or "").encode("utf-8")
    if not hmac.compare_digest(supplied, ADMIN_TOKEN.encode("utf-8")):
        raise HTTPException(status_code=401, detail="Invalid admin token.")


def parse_event(obj: Dict[str, Any]) -> ScenarioEvent:
    """Build a ScenarioEvent from a raw mapping.

    ``day`` is required; ``shockPct``, ``volatility`` and ``news`` are optional.

    Raises:
        KeyError: ``day`` is missing.
        TypeError, ValueError: a field cannot be converted to its number type.
    """
    day = int(obj["day"])
    shock = float(obj.get("shockPct", 0.0))
    vol = obj.get("volatility", None)
    vol_f = float(vol) if vol is not None else None
    news = obj.get("news", None)
    if news is not None:
        news = str(news)
    return ScenarioEvent(day=day, shockPct=shock, volatility=vol_f, news=news)


def snapshot_events() -> List[Dict[str, Any]]:
    """Return every scheduled event as a plain dict, ordered by day."""
    return [asdict(ev) for day in sorted(EVENTS) for ev in EVENTS[day]]


def apply_price_shock_from_day(day: int, shock_pct: float) -> None:
    """Scale every close from ``day`` to the end of MARKET by ``shock_pct`` percent.

    Does nothing when ``day`` is outside the market. Prices stay floored at
    1.0 and rounded to two decimals.
    """
    if not 0 <= day < len(MARKET):
        return
    factor = 1.0 + (shock_pct / 100.0)
    for bar in MARKET[day:]:
        bar["close"] = round(max(1.0, bar["close"] * factor), 2)


def regen_market(length: int, start_price: float, vol: float) -> List[Dict[str, float]]:
    """Generate a fresh price series with the shared simulator."""
    return sim.generate_base_market(length, start_price, vol)


def _length_unspecified(requested_len: Any) -> bool:
    """Return True when a scenario gives no usable ``marketLength`` value."""
    return requested_len is None or str(requested_len).strip() == ""


def _resolve_market_length(requested_len: Any, max_event_day: int) -> int:
    """Pick the market length for a scenario.

    The result is never below MIN_MARKET_LENGTH and always long enough to
    contain ``max_event_day``. Without a requested length, MARKET_LENGTH is
    used as the baseline.

    Raises:
        TypeError, ValueError: ``requested_len`` is not convertible to int.
    """
    baseline = MARKET_LENGTH if _length_unspecified(requested_len) else int(requested_len)
    return max(MIN_MARKET_LENGTH, baseline, max_event_day + 1)


async def _install_scenario(
    start_day: int,
    base_price: float,
    default_vol: float,
    events: List[ScenarioEvent],
    market_length: int,
) -> None:
    """Replace the market, the event schedule, the volatility and the current day."""
    global MARKET, CURRENT_VOL, CURRENT_DAY

    async with STATE_LOCK:
        CURRENT_VOL = default_vol
        MARKET = regen_market(length=market_length, start_price=base_price, vol=default_vol)
        EVENTS.clear()
        for ev in events:
            EVENTS.setdefault(ev.day, []).append(ev)

    async with DAY_LOCK:
        # Clamp the start day into the new market.
        CURRENT_DAY = max(0, min(start_day, len(MARKET) - 1))


def _safe_float(value: Any) -> float:
    """Convert client-supplied ``value`` to a finite float, or 0.0 if that fails."""
    import math  # imported locally so this helper is self-contained

    try:
        number = float(value)
    except (TypeError, ValueError, OverflowError):
        return 0.0
    # NaN/Infinity would be written as non-standard tokens by json.dumps and
    # break strict JSON parsers on the receiving side.
    return number if math.isfinite(number) else 0.0


# ----------------------------
# Static pages
# ----------------------------
@app.get("/")
async def root():
    """Serve the player page, or 404 when the file is missing."""
    if not os.path.exists(INDEX_FILE):
        raise HTTPException(status_code=404, detail="index.html not found in repo root.")
    return FileResponse(INDEX_FILE)


@app.get("/admin")
async def admin_page():
    """Serve the admin page, or 404 when the file is missing."""
    if not os.path.exists(ADMIN_FILE):
        raise HTTPException(status_code=404, detail="admin.html not found in repo root.")
    return FileResponse(ADMIN_FILE)


# ----------------------------
# Admin REST API
# ----------------------------
@app.get("/admin/state")
async def admin_state(x_admin_token: Optional[str] = Header(default=None, alias=ADMIN_HEADER)):
    """Return the current day, tick rate, market length, volatility and events."""
    require_admin(x_admin_token)
    async with DAY_LOCK:
        day = CURRENT_DAY
    async with STATE_LOCK:
        return {
            "day": day,
            "tickRate": TICK_RATE,
            "marketLength": len(MARKET),
            "currentVolatility": CURRENT_VOL,
            "events": snapshot_events(),
        }


@app.post("/admin/clear_events")
async def admin_clear_events(x_admin_token: Optional[str] = Header(default=None, alias=ADMIN_HEADER)):
    """Remove every scheduled event."""
    require_admin(x_admin_token)
    async with STATE_LOCK:
        EVENTS.clear()
    return {"ok": True}


@app.post("/admin/add_event")
async def admin_add_event(body: Dict[str, Any], x_admin_token: Optional[str] = Header(default=None, alias=ADMIN_HEADER)):
    """Schedule one event at an absolute ``day`` or at ``offset`` days from now.

    NOTE: the tick loop processes events only for the day it has just advanced
    to, and never moves past the last market day.

    Raises:
        HTTPException: 400 when neither ``day`` nor ``offset`` is given, when
            the day is in the past, or when a field is malformed.
    """
    require_admin(x_admin_token)

    async with DAY_LOCK:
        cur = CURRENT_DAY

    if "day" not in body and "offset" not in body:
        raise HTTPException(status_code=400, detail="Provide 'day' or 'offset'.")

    try:
        # An explicit 'day' wins over 'offset' when both are present.
        day = int(body["day"]) if "day" in body else cur + int(body["offset"])
    except _BAD_INPUT_ERRORS as exc:
        raise HTTPException(status_code=400, detail="'day'/'offset' must be an integer.") from exc

    if day < cur:
        raise HTTPException(status_code=400, detail=f"Event day {day} is in the past (current day {cur}).")

    try:
        ev = parse_event({**body, "day": day})
    except _BAD_INPUT_ERRORS as exc:
        raise HTTPException(status_code=400, detail=f"Invalid event: {exc!r}") from exc

    async with STATE_LOCK:
        EVENTS.setdefault(day, []).append(ev)

    return {"ok": True, "event": asdict(ev)}


@app.post("/admin/load_scenario")
async def admin_load_scenario(body: Dict[str, Any], x_admin_token: Optional[str] = Header(default=None, alias=ADMIN_HEADER)):
    """Regenerate the market and replace the event schedule from a scenario body.

    The market is always made long enough to contain the scenario's last event
    day, including when ``marketLength`` is given explicitly. This matches how
    the startup hook sizes the market for the default scenario file.

    Raises:
        HTTPException: 400 when ``events`` is not a list or any field is malformed.
    """
    require_admin(x_admin_token)

    try:
        start_day = int(body.get("startDay", 0))
        base_price = float(body.get("basePrice", START_PRICE))
        default_vol = float(body.get("defaultVolatility", DEFAULT_VOLATILITY))

        evs_raw = body.get("events", [])
        if not isinstance(evs_raw, list):
            raise HTTPException(status_code=400, detail="'events' must be a list.")

        evs = [parse_event(e) for e in evs_raw]
        max_day_in_scenario = max((ev.day for ev in evs), default=0)

        requested_len = body.get("marketLength", None)
        desired_len = _resolve_market_length(requested_len, max_day_in_scenario)
    except _BAD_INPUT_ERRORS as exc:
        raise HTTPException(status_code=400, detail=f"Invalid scenario: {exc!r}") from exc

    mode = "auto" if _length_unspecified(requested_len) else "manual"

    await _install_scenario(start_day, base_price, default_vol, evs, desired_len)

    return {
        "ok": True,
        "startDay": CURRENT_DAY,
        "eventsLoaded": len(evs),
        "marketLength": len(MARKET),
        "marketLengthMode": mode,
    }


# ----------------------------
# WebSocket endpoint
# ----------------------------
@app.websocket("/ws/{client_id}")
async def websocket_endpoint(websocket: WebSocket, client_id: str):
    """Serve one player: send an INIT snapshot, then accept UPDATE_EQUITY messages.

    Messages that are not JSON objects, or whose type is not UPDATE_EQUITY,
    are ignored. Unparsable or non-finite equity/roi values are recorded as 0.0.
    """
    await manager.connect(websocket, client_id)
    try:
        async with DAY_LOCK:
            day0 = CURRENT_DAY
        async with STATE_LOCK:
            init_payload = {"market": MARKET, "startDay": day0}
        await websocket.send_text(json.dumps({"type": "INIT", "payload": init_payload}))

        while True:
            raw = await websocket.receive_text()
            try:
                data = json.loads(raw)
            except (ValueError, RecursionError):
                continue

            # Valid JSON that is not an object (e.g. a list or a number) is
            # skipped just like invalid JSON, instead of ending the session.
            if not isinstance(data, dict) or data.get("type") != "UPDATE_EQUITY":
                continue

            payload = data.get("payload") or {}
            if not isinstance(payload, dict):
                continue

            await manager.update_equity(
                name=str(payload.get("name", client_id)),
                equity=_safe_float(payload.get("equity", 0.0)),
                roi=_safe_float(payload.get("roi", 0.0)),
            )
    except Exception:
        # Best-effort session: WebSocketDisconnect (a normal departure) and
        # any other failure both simply end it; cleanup happens below.
        pass
    finally:
        # Unregister only if this socket is still the one registered: the
        # client may already have reconnected under the same id.
        if manager.active.get(client_id) is websocket:
            await manager.disconnect(client_id)


# ----------------------------
# Background tick loop
# ----------------------------
async def game_loop():
    """Advance the game one day per TICK_RATE seconds and broadcast the result.

    Events scheduled for the new day are applied before the TICK broadcast.
    They are applied only when the day actually changed, so once the day is
    clamped at the end of the market its events do not fire again on every
    tick. With LOOP_MARKET enabled, events still fire once per lap.
    """
    global CURRENT_DAY, CURRENT_VOL

    while True:
        await asyncio.sleep(TICK_RATE)

        async with DAY_LOCK:
            previous_day = CURRENT_DAY
            size = len(MARKET)
            if size:  # an empty market has no day to advance to
                if LOOP_MARKET:
                    CURRENT_DAY = (previous_day + 1) % size
                else:
                    CURRENT_DAY = min(previous_day + 1, size - 1)
            day = CURRENT_DAY

        news_to_broadcast: List[str] = []
        if day != previous_day:
            async with STATE_LOCK:
                for ev in EVENTS.get(day, ()):
                    if ev.shockPct:
                        apply_price_shock_from_day(day, ev.shockPct)
                    if ev.volatility is not None:
                        CURRENT_VOL = float(ev.volatility)
                    if ev.news:
                        news_to_broadcast.append(ev.news)

        # Broadcast outside the lock so slow clients cannot stall state updates.
        for text in news_to_broadcast:
            await manager.broadcast_news(day, text)

        await manager.broadcast_tick(day)


@app.on_event("startup")
async def on_startup():
    """Load the default scenario file if it exists, then start the tick loop."""
    global _GAME_LOOP_TASK

    # ---- Load default scenario file (Option B) ----
    scenario_path = os.getenv("DEFAULT_SCENARIO_FILE", "default_scenario.json")
    if os.path.exists(scenario_path):
        with open(scenario_path, "r", encoding="utf-8") as f:
            scn = json.load(f)

        start_day = int(scn.get("startDay", 0))
        base_price = float(scn.get("basePrice", START_PRICE))
        default_vol = float(scn.get("defaultVolatility", DEFAULT_VOLATILITY))

        evs_raw = scn.get("events", [])
        if not isinstance(evs_raw, list):
            evs_raw = []

        evs = [parse_event(e) for e in evs_raw]
        max_day_in_scenario = max((ev.day for ev in evs), default=0)

        # If scenario provides marketLength, use it as a hint; always ensure it's big enough.
        desired_len = _resolve_market_length(scn.get("marketLength", None), max_day_in_scenario)

        await _install_scenario(start_day, base_price, default_vol, evs, desired_len)

    # ---- Start the tick loop ----
    # Keep the task referenced so it cannot be garbage collected mid-run.
    _GAME_LOOP_TASK = asyncio.create_task(game_loop())