# MPTrading / app.py
# Last commit 3d72991 by Almaatla: "Update app.py" (verified)
import asyncio
import json
import os
import random
import time
import json
from dataclasses import dataclass, asdict
from typing import Any, Dict, List, Optional
from fastapi import FastAPI, Header, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.responses import FileResponse
# ----------------------------
# Config
# ----------------------------
# Scenario file read eagerly at import time. The startup hook re-reads (and
# tolerates the absence of) the same file, so a missing file must not abort
# the import; fall back to an empty scenario instead.
_DEFAULT_SCENARIO_PATH = os.getenv("DEFAULT_SCENARIO_FILE", "default_scenario.json")
DEFAULT_SCENARIO: Dict[str, Any] = {}
if os.path.exists(_DEFAULT_SCENARIO_PATH):
    with open(_DEFAULT_SCENARIO_PATH, "r", encoding="utf-8") as _scenario_fh:
        DEFAULT_SCENARIO = json.load(_scenario_fh)
TICK_RATE = float(os.getenv("TICK_RATE", "2.0"))  # seconds per tick
MARKET_LENGTH = int(os.getenv("MARKET_LENGTH", "1900"))  # default timeline length
MIN_MARKET_LENGTH = int(os.getenv("MIN_MARKET_LENGTH", "600"))  # hard minimum
START_PRICE = float(os.getenv("START_PRICE", "100.0"))
DEFAULT_VOLATILITY = float(os.getenv("DEFAULT_VOLATILITY", "0.8"))
# If true, day wraps around at end of market (old behavior). If false, clamps at last day.
LOOP_MARKET = os.getenv("LOOP_MARKET", "0").strip().lower() in ("1", "true", "yes", "y")
ADMIN_TOKEN = os.getenv("ADMIN_TOKEN", "")  # set as HF Space Secret
ADMIN_HEADER = "X-ADMIN-TOKEN"
INDEX_FILE = os.getenv("INDEX_FILE", "index.html")
ADMIN_FILE = os.getenv("ADMIN_FILE", "admin.html")  # optional if you keep admin.html
# ----------------------------
# Data models
# ----------------------------
@dataclass
class ScenarioEvent:
    """One scheduled scenario event, applied by the tick loop on ``day``."""

    day: int  # market index (day) on which the event fires
    shockPct: float = 0.0  # percent price shock applied to the path from `day` onward (e.g. -5.0 scales closes by 0.95)
    # NOTE(review): in this file CURRENT_VOL is only reported via /admin/state;
    # it does not feed back into generated prices.
    volatility: Optional[float] = None  # when set, replaces CURRENT_VOL as the event fires
    news: Optional[str] = None  # when set (and non-empty), broadcast to clients as a NEWS message
# ----------------------------
# Market simulator
# ----------------------------
class MarketSimulator:
    """Deterministic random-walk price generator.

    Every call to ``generate_base_market`` builds a fresh RNG seeded with
    ``self.seed``, so identical arguments always yield an identical series.
    """

    def __init__(self, seed: int = 42):
        self.seed = seed

    def generate_base_market(self, length: int, start_price: float, vol: float) -> List[Dict[str, float]]:
        """Return ``length`` points shaped ``{"i": day, "close": price}``.

        Each step adds a fixed drift of 0.02 plus Gaussian noise with standard
        deviation ``vol``. The running price is floored at 1.0. Only the stored
        close is rounded to 2 decimals; the unrounded level keeps accumulating.
        """
        rng = random.Random(self.seed)
        step_drift = 0.02
        level = float(start_price)
        points: List[Dict[str, float]] = []
        for day in range(length):
            level = max(1.0, level + step_drift + rng.gauss(0.0, vol))
            points.append({"i": day, "close": round(level, 2)})
        return points
# ----------------------------
# Connection manager
# ----------------------------
class ConnectionManager:
    """Registry of connected WebSocket clients plus the shared leaderboard.

    ``active`` maps client_id -> websocket; ``leaderboard`` maps a player name
    -> {"equity", "roi", "ts"}. Both are guarded by one asyncio lock.
    """

    def __init__(self) -> None:
        self.active: Dict[str, "WebSocket"] = {}  # client_id -> websocket
        self.leaderboard: Dict[str, Dict[str, float]] = {}  # name -> {equity, roi, ts}
        self._lock = asyncio.Lock()

    async def connect(self, websocket: "WebSocket", client_id: str) -> None:
        """Accept the handshake and register ``websocket`` under ``client_id``.

        A reconnect with an existing client_id replaces the previous socket.
        """
        await websocket.accept()
        async with self._lock:
            self.active[client_id] = websocket

    async def disconnect(self, client_id: str, websocket: Optional["WebSocket"] = None) -> None:
        """Unregister ``client_id``.

        When ``websocket`` is given, the entry is removed only if it still
        refers to that exact socket, so a finished handler cannot evict a
        newer connection that reused the same client_id. Without it, the
        entry is removed unconditionally (original behavior).
        """
        async with self._lock:
            current = self.active.get(client_id)
            if current is None:
                return
            if websocket is None or current is websocket:
                del self.active[client_id]

    async def update_equity(self, name: str, equity: float, roi: float) -> None:
        """Insert or overwrite the leaderboard row for ``name``."""
        now = time.time()
        async with self._lock:
            self.leaderboard[name] = {"equity": float(equity), "roi": float(roi), "ts": now}

    async def _snapshot_leaderboard(self) -> List[Dict[str, Any]]:
        """Return up to 50 rows ``{"name", "equity", "roi"}``, highest equity first."""
        async with self._lock:
            entries = [
                {"name": name, "equity": row["equity"], "roi": row["roi"]}
                for name, row in self.leaderboard.items()
            ]
        entries.sort(key=lambda e: e["equity"], reverse=True)
        return entries[:50]

    async def broadcast(self, obj: Dict[str, Any]) -> None:
        """Send ``obj`` as JSON to every client, dropping clients whose send fails.

        Delivery is best-effort: any send error marks that socket as dead.
        A dead entry is pruned only if it still points at the socket that
        failed, so a client that reconnected mid-broadcast is kept.
        """
        msg = json.dumps(obj)
        async with self._lock:
            sockets = list(self.active.items())
        stale = []
        for client_id, ws in sockets:
            try:
                await ws.send_text(msg)
            except Exception:
                stale.append((client_id, ws))
        if stale:
            async with self._lock:
                for client_id, ws in stale:
                    if self.active.get(client_id) is ws:
                        del self.active[client_id]

    async def broadcast_tick(self, day: int) -> None:
        """Broadcast a TICK message carrying ``day`` and the leaderboard snapshot."""
        await self.broadcast({
            "type": "TICK",
            "payload": {
                "day": day,
                "leaderboard": await self._snapshot_leaderboard(),
            },
        })

    async def broadcast_news(self, day: int, text: str) -> None:
        """Broadcast a NEWS message with ``text`` for ``day``."""
        await self.broadcast({"type": "NEWS", "payload": {"day": day, "text": text}})
# Application singletons: the ASGI app, the WebSocket registry and the seeded
# price generator used for every market (re)generation.
app = FastAPI(title="MPTrading (FastAPI + WebSocket)")
manager = ConnectionManager()
sim = MarketSimulator(seed=42)
# ----------------------------
# Global game state
# ----------------------------
# DAY_LOCK is taken around reads/writes of CURRENT_DAY; STATE_LOCK around
# updates and snapshots of MARKET, EVENTS and CURRENT_VOL.
DAY_LOCK = asyncio.Lock()
STATE_LOCK = asyncio.Lock()
CURRENT_DAY = 0
CURRENT_VOL = DEFAULT_VOLATILITY
# Initial market; the startup hook or /admin/load_scenario may replace it.
MARKET: List[Dict[str, float]] = sim.generate_base_market(
    length=max(MARKET_LENGTH, MIN_MARKET_LENGTH),
    start_price=START_PRICE,
    vol=DEFAULT_VOLATILITY,
)
EVENTS: Dict[int, List[ScenarioEvent]] = {}  # day -> events scheduled for that day
# ----------------------------
# Helpers
# ----------------------------
def require_admin(token: Optional[str]) -> None:
    """Validate an admin token (callers pass the ADMIN_HEADER request header).

    Raises:
        HTTPException(403): no ADMIN_TOKEN is configured, so admin endpoints
            are disabled entirely.
        HTTPException(401): ``token`` is missing or does not match.
    """
    import hmac  # local import: constant-time comparison

    if not ADMIN_TOKEN:
        raise HTTPException(status_code=403, detail="Admin token not configured on server.")
    # compare_digest does not leak, via response timing, how many leading
    # characters matched; comparing UTF-8 bytes also supports non-ASCII tokens.
    if token is None or not hmac.compare_digest(token.encode("utf-8"), ADMIN_TOKEN.encode("utf-8")):
        raise HTTPException(status_code=401, detail="Invalid admin token.")
def parse_event(obj: Dict[str, Any]) -> ScenarioEvent:
    """Build a ScenarioEvent from a JSON-like dict.

    Keys: ``day`` (required, int-coercible); ``shockPct`` (number, missing or
    null means 0.0); ``volatility`` (number or null); ``news`` (stringified,
    null means no news).

    Raises:
        KeyError: ``day`` is missing.
        TypeError / ValueError: a field cannot be converted, or a numeric field
            is NaN or infinite. Non-finite values are rejected because they
            would propagate into prices and be serialized by ``json.dumps`` as
            NaN/Infinity, which standard JSON parsers (e.g. JS JSON.parse) reject.
    """
    import math  # local import: finiteness check

    day = int(obj["day"])
    raw_shock = obj.get("shockPct")
    shock = 0.0 if raw_shock is None else float(raw_shock)
    if not math.isfinite(shock):
        raise ValueError(f"shockPct must be a finite number, got {raw_shock!r}.")
    vol = obj.get("volatility")
    vol_f: Optional[float] = None
    if vol is not None:
        vol_f = float(vol)
        if not math.isfinite(vol_f):
            raise ValueError(f"volatility must be a finite number, got {vol!r}.")
    news = obj.get("news")
    if news is not None:
        news = str(news)
    return ScenarioEvent(day=day, shockPct=shock, volatility=vol_f, news=news)
def snapshot_events() -> List[Dict[str, Any]]:
    """Return every scheduled event as a plain dict, ordered by day.

    Events that share a day keep their insertion order.
    """
    return [asdict(event) for day in sorted(EVENTS) for event in EVENTS[day]]
def apply_price_shock_from_day(day: int, shock_pct: float) -> None:
    """Scale every close from ``day`` to the end of MARKET by ``1 + shock_pct/100``.

    Points are mutated in place; each new close is floored at 1.0 and rounded
    to 2 decimals. A ``day`` outside the market is silently ignored.
    """
    if not 0 <= day < len(MARKET):
        return
    multiplier = 1.0 + shock_pct / 100.0
    for point in MARKET[day:]:
        point["close"] = round(max(1.0, point["close"] * multiplier), 2)
def regen_market(length: int, start_price: float, vol: float) -> List[Dict[str, float]]:
    """Build a fresh price series with the shared seeded simulator ``sim``."""
    series = sim.generate_base_market(length=length, start_price=start_price, vol=vol)
    return series
# ----------------------------
# Static pages
# ----------------------------
@app.get("/")
async def root():
    """Serve the player UI page (INDEX_FILE).

    Responds 404 when the file is missing, like the /admin route, rather than
    letting FileResponse fail on a nonexistent path.
    """
    if not os.path.isfile(INDEX_FILE):
        raise HTTPException(status_code=404, detail=f"{INDEX_FILE} not found.")
    return FileResponse(INDEX_FILE)
@app.get("/admin")
async def admin_page():
    """Serve the optional admin console page, or respond 404 if it is absent."""
    if os.path.exists(ADMIN_FILE):
        return FileResponse(ADMIN_FILE)
    raise HTTPException(status_code=404, detail="admin.html not found in repo root.")
# ----------------------------
# Admin REST API
# ----------------------------
@app.get("/admin/state")
async def admin_state(x_admin_token: Optional[str] = Header(default=None, alias=ADMIN_HEADER)):
    """Report day, tick rate, market length, volatility and events (admin only)."""
    require_admin(x_admin_token)
    async with DAY_LOCK:
        current_day = CURRENT_DAY
    async with STATE_LOCK:
        snapshot = {
            "day": current_day,
            "tickRate": TICK_RATE,
            "marketLength": len(MARKET),
            "currentVolatility": CURRENT_VOL,
            "events": snapshot_events(),
        }
    return snapshot
@app.post("/admin/clear_events")
async def admin_clear_events(x_admin_token: Optional[str] = Header(default=None, alias=ADMIN_HEADER)):
    """Drop every scheduled event (admin only)."""
    require_admin(x_admin_token)
    async with STATE_LOCK:
        EVENTS.clear()
    result = {"ok": True}
    return result
@app.post("/admin/add_event")
async def admin_add_event(body: Dict[str, Any], x_admin_token: Optional[str] = Header(default=None, alias=ADMIN_HEADER)):
    """Schedule a single event (admin only).

    The target day is ``body["day"]`` when present, otherwise the current day
    plus ``body["offset"]``. The remaining keys (shockPct, volatility, news)
    are parsed by parse_event.

    Raises:
        HTTPException(400): neither 'day' nor 'offset' is given; a value cannot
            be parsed; the day is before the current day; or the day lies past
            the end of MARKET (the tick loop never reaches it, so the event
            could never fire).
    """
    require_admin(x_admin_token)
    async with DAY_LOCK:
        cur = CURRENT_DAY
    if "day" not in body and "offset" not in body:
        raise HTTPException(status_code=400, detail="Provide 'day' or 'offset'.")
    try:
        day = int(body["day"]) if "day" in body else cur + int(body["offset"])
    except (TypeError, ValueError) as exc:
        raise HTTPException(status_code=400, detail=f"Invalid 'day'/'offset': {exc}") from exc
    # NOTE(review): game_loop applies a day's events right after advancing to
    # it, so an event for day == cur has most likely been skipped already
    # (unless LOOP_MARKET wraps back around); consider requiring day > cur.
    if day < cur:
        raise HTTPException(status_code=400, detail=f"Event day {day} is in the past (current day {cur}).")
    try:
        ev = parse_event({**body, "day": day})
    except (TypeError, ValueError) as exc:
        raise HTTPException(status_code=400, detail=f"Invalid event: {exc}") from exc
    async with STATE_LOCK:
        market_len = len(MARKET)
        if day >= market_len:
            raise HTTPException(
                status_code=400,
                detail=f"Event day {day} is beyond the end of the market (length {market_len}).",
            )
        EVENTS.setdefault(day, []).append(ev)
    return {"ok": True, "event": asdict(ev)}
@app.post("/admin/load_scenario")
async def admin_load_scenario(body: Dict[str, Any], x_admin_token: Optional[str] = Header(default=None, alias=ADMIN_HEADER)):
    """Replace the whole market and event table with a new scenario (admin only).

    Recognized body keys:
        startDay: day to resume from (clamped into the market), default 0.
        basePrice: starting price for the regenerated market.
        defaultVolatility: volatility used to generate the market.
        events: list of event dicts (see parse_event).
        marketLength: optional requested length; missing/blank means "auto".

    The new market is never shorter than MIN_MARKET_LENGTH. Like the startup
    loader, it is also never shorter than needed to reach the last scheduled
    event, so no loaded event is silently unreachable.

    Raises:
        HTTPException(400): malformed or non-finite values.
    """
    global MARKET, CURRENT_VOL, CURRENT_DAY
    import math  # local import: finiteness checks

    require_admin(x_admin_token)
    try:
        start_day = int(body.get("startDay", 0))
        base_price = float(body.get("basePrice", START_PRICE))
        default_vol = float(body.get("defaultVolatility", DEFAULT_VOLATILITY))
    except (TypeError, ValueError) as exc:
        raise HTTPException(status_code=400, detail=f"Invalid scenario settings: {exc}") from exc
    if not (math.isfinite(base_price) and math.isfinite(default_vol)):
        raise HTTPException(status_code=400, detail="'basePrice' and 'defaultVolatility' must be finite numbers.")
    evs_raw = body.get("events", [])
    if not isinstance(evs_raw, list):
        raise HTTPException(status_code=400, detail="'events' must be a list.")
    try:
        evs = [parse_event(e) for e in evs_raw]
    except (AttributeError, KeyError, TypeError, ValueError) as exc:
        raise HTTPException(status_code=400, detail=f"Invalid event in scenario: {exc!r}") from exc
    max_day_in_scenario = max((ev.day for ev in evs), default=0)
    requested_len = body.get("marketLength", None)
    if requested_len is None or str(requested_len).strip() == "":
        desired_len = max(MIN_MARKET_LENGTH, MARKET_LENGTH, max_day_in_scenario + 1)
        mode = "auto"
    else:
        try:
            manual_len = int(requested_len)
        except (TypeError, ValueError) as exc:
            raise HTTPException(status_code=400, detail=f"Invalid 'marketLength': {exc}") from exc
        desired_len = max(MIN_MARKET_LENGTH, manual_len, max_day_in_scenario + 1)
        mode = "manual"
    async with STATE_LOCK:
        CURRENT_VOL = default_vol
        MARKET = regen_market(length=desired_len, start_price=base_price, vol=default_vol)
        EVENTS.clear()
        for ev in evs:
            EVENTS.setdefault(ev.day, []).append(ev)
        market_len = len(MARKET)
    async with DAY_LOCK:
        CURRENT_DAY = max(0, min(start_day, market_len - 1))
        new_day = CURRENT_DAY
    return {
        "ok": True,
        "startDay": new_day,
        "eventsLoaded": len(evs),
        "marketLength": market_len,
        "marketLengthMode": mode,
    }
# ----------------------------
# WebSocket endpoint
# ----------------------------
def _finite_float(value: Any, default: float = 0.0) -> float:
    """Coerce ``value`` to a finite float, returning ``default`` when impossible."""
    import math  # local import: finiteness check

    try:
        number = float(value)
    except (TypeError, ValueError):
        return default
    return number if math.isfinite(number) else default


@app.websocket("/ws/{client_id}")
async def websocket_endpoint(websocket: WebSocket, client_id: str):
    """Per-client socket: send INIT (market + current day), then read messages.

    Only ``{"type": "UPDATE_EQUITY", "payload": {"name", "equity", "roi"}}``
    is acted upon. Malformed, non-object or unknown messages are ignored. The
    client is unregistered whenever this handler exits.
    """
    await manager.connect(websocket, client_id)
    try:
        async with DAY_LOCK:
            day0 = CURRENT_DAY
        async with STATE_LOCK:
            init_payload = {"market": MARKET, "startDay": day0}
        await websocket.send_text(json.dumps({"type": "INIT", "payload": init_payload}))
        while True:
            raw = await websocket.receive_text()
            try:
                data = json.loads(raw)
            except (ValueError, RecursionError):
                continue
            if not isinstance(data, dict):
                continue
            payload = data.get("payload") or {}
            if not isinstance(payload, dict):
                continue
            if data.get("type") == "UPDATE_EQUITY":
                name = str(payload.get("name", client_id))
                # Non-numeric, NaN or infinite values fall back to 0.0: NaN/inf
                # would be re-broadcast in every TICK leaderboard and serialize
                # as invalid JSON for all clients.
                equity_f = _finite_float(payload.get("equity", 0.0))
                roi_f = _finite_float(payload.get("roi", 0.0))
                await manager.update_equity(name=name, equity=equity_f, roi=roi_f)
    except WebSocketDisconnect:
        pass
    except Exception:
        import logging  # local import: report unexpected connection errors

        logging.getLogger(__name__).exception("websocket %s closed after an unexpected error", client_id)
    finally:
        await manager.disconnect(client_id)
# ----------------------------
# Background tick loop
# ----------------------------
async def game_loop():
    """Advance the simulation one day every TICK_RATE seconds, forever.

    Each tick advances CURRENT_DAY (wrapping if LOOP_MARKET, otherwise stopping
    at the last day). It then applies that day's scheduled events (price shock,
    volatility change, news), broadcasts any news, and broadcasts a TICK with
    the day and leaderboard.

    In clamped mode, events fire only when the day actually changes. Without
    this, a market parked on its last day would re-apply that day's shocks
    (compounding them) and re-send its news on every tick. LOOP_MARKET keeps
    the original behavior: events re-fire on every lap.
    """
    global CURRENT_DAY, CURRENT_VOL
    import logging  # local import: report tick failures without killing the loop

    log = logging.getLogger(__name__)
    while True:
        await asyncio.sleep(TICK_RATE)
        try:
            async with DAY_LOCK:
                previous_day = CURRENT_DAY
                if LOOP_MARKET:
                    CURRENT_DAY = (CURRENT_DAY + 1) % len(MARKET)
                else:
                    CURRENT_DAY = min(CURRENT_DAY + 1, len(MARKET) - 1)
                day = CURRENT_DAY
            day_changed = LOOP_MARKET or day != previous_day
            news_to_broadcast: List[str] = []
            if day_changed:
                async with STATE_LOCK:
                    for ev in EVENTS.get(day, []):
                        if ev.shockPct:
                            apply_price_shock_from_day(day, ev.shockPct)
                        if ev.volatility is not None:
                            CURRENT_VOL = float(ev.volatility)
                        if ev.news:
                            news_to_broadcast.append(ev.news)
            for text in news_to_broadcast:
                await manager.broadcast_news(day, text)
            await manager.broadcast_tick(day)
        except Exception:
            # Nobody awaits this task, so an escaping exception would silently
            # stop the game; log it and carry on with the next tick instead.
            log.exception("game_loop tick failed")
# Strong reference to the background tick task. The event loop keeps only weak
# references to tasks, so an unreferenced task may be garbage-collected mid-run.
_GAME_LOOP_TASK: Optional[asyncio.Task] = None


@app.on_event("startup")
async def on_startup():
    """Load the default scenario file (if present) and start the tick loop.

    The path comes from the DEFAULT_SCENARIO_FILE env var (default
    "default_scenario.json"). When the file exists, it replaces the market,
    events, volatility and current day. A missing file keeps the import-time
    defaults. A malformed file raises from this hook.

    NOTE(review): game_loop advances the day before applying events, so events
    scheduled exactly on startDay appear never to fire; confirm intent.
    """
    global MARKET, EVENTS, CURRENT_DAY, CURRENT_VOL, _GAME_LOOP_TASK
    # ---- Load default scenario file (Option B) ----
    scenario_path = os.getenv("DEFAULT_SCENARIO_FILE", "default_scenario.json")
    if os.path.exists(scenario_path):
        with open(scenario_path, "r", encoding="utf-8") as f:
            scn = json.load(f)
        start_day = int(scn.get("startDay", 0))
        base_price = float(scn.get("basePrice", START_PRICE))
        default_vol = float(scn.get("defaultVolatility", DEFAULT_VOLATILITY))
        evs_raw = scn.get("events", [])
        if not isinstance(evs_raw, list):
            evs_raw = []
        evs = [parse_event(e) for e in evs_raw]
        max_day_in_scenario = max((ev.day for ev in evs), default=0)
        # If scenario provides marketLength, use it as a hint; always ensure it's big enough.
        requested_len = scn.get("marketLength", None)
        if requested_len is None or str(requested_len).strip() == "":
            desired_len = max(MIN_MARKET_LENGTH, MARKET_LENGTH, max_day_in_scenario + 1)
        else:
            desired_len = max(MIN_MARKET_LENGTH, int(requested_len), max_day_in_scenario + 1)
        async with STATE_LOCK:
            CURRENT_VOL = default_vol
            MARKET = regen_market(length=desired_len, start_price=base_price, vol=default_vol)
            EVENTS.clear()
            for ev in evs:
                EVENTS.setdefault(ev.day, []).append(ev)
        async with DAY_LOCK:
            CURRENT_DAY = max(0, min(start_day, len(MARKET) - 1))
    # ---- Start the tick loop ----
    # Keep a strong reference (see _GAME_LOOP_TASK) and never run two loops.
    if _GAME_LOOP_TASK is None or _GAME_LOOP_TASK.done():
        _GAME_LOOP_TASK = asyncio.create_task(game_loop())