Almaatla commited on
Commit
2379cd3
·
verified ·
1 Parent(s): e1bf013

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +230 -62
app.py CHANGED
@@ -3,41 +3,66 @@ import json
3
  import os
4
  import random
5
  import time
6
- from typing import Dict, List, Optional, Any
 
7
 
8
- from fastapi import FastAPI, WebSocket, WebSocketDisconnect
9
  from fastapi.responses import FileResponse
10
- from fastapi.staticfiles import StaticFiles
11
 
12
  # ----------------------------
13
  # Config
14
  # ----------------------------
15
- TICK_RATE = float(os.getenv("TICK_RATE", "1.0")) # seconds
16
- MARKET_LENGTH = int(os.getenv("MARKET_LENGTH", "300")) # number of "days"/ticks
17
  START_PRICE = float(os.getenv("START_PRICE", "100.0"))
18
 
19
- # If your frontend file is named differently, adjust here.
 
 
20
  INDEX_FILE = os.getenv("INDEX_FILE", "index.html")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
- app = FastAPI(title="Trading Game (WebSocket)")
23
 
24
  # ----------------------------
25
  # Market simulator
26
  # ----------------------------
27
  class MarketSimulator:
28
- def __init__(self, seed: int = 42, start_price: float = 100.0):
29
  self.seed = seed
30
  self.start_price = start_price
 
31
 
32
- def _generate_market(self, length: int) -> List[Dict[str, float]]:
33
  rng = random.Random(self.seed)
34
  price = self.start_price
35
  series: List[Dict[str, float]] = []
 
36
 
37
- # Simple random walk with gaussian noise, clipped to avoid negative prices.
38
  for i in range(length):
39
- drift = 0.02 # small upward drift per step
40
- shock = rng.gauss(0.0, 0.8)
41
  price = max(1.0, price + drift + shock)
42
  series.append({"i": i, "close": round(price, 2)})
43
 
@@ -68,30 +93,18 @@ class ConnectionManager:
68
  self.leaderboard[name] = {"equity": float(equity), "roi": float(roi), "ts": now}
69
 
70
  async def _snapshot_leaderboard(self) -> List[Dict[str, Any]]:
71
- # Sort by equity desc, show top N.
72
  async with self._lock:
73
  entries = [
74
  {"name": n, "equity": v["equity"], "roi": v["roi"], "ts": v.get("ts", 0.0)}
75
  for n, v in self.leaderboard.items()
76
  ]
77
-
78
  entries.sort(key=lambda x: x["equity"], reverse=True)
79
- # Remove ts before sending to client
80
  for e in entries:
81
  e.pop("ts", None)
82
  return entries[:50]
83
 
84
- async def broadcast_tick(self, day: int) -> None:
85
- payload = {
86
- "type": "TICK",
87
- "payload": {
88
- "day": day,
89
- "leaderboard": await self._snapshot_leaderboard(),
90
- },
91
- }
92
-
93
- msg = json.dumps(payload)
94
- # Copy sockets to avoid holding the lock while sending
95
  async with self._lock:
96
  sockets = list(self.active.items())
97
 
@@ -102,40 +115,188 @@ class ConnectionManager:
102
  except Exception:
103
  stale.append(client_id)
104
 
105
- # Clean up dead connections
106
  if stale:
107
  async with self._lock:
108
  for cid in stale:
109
  self.active.pop(cid, None)
110
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
 
 
 
112
  manager = ConnectionManager()
113
- market_sim = MarketSimulator(seed=42, start_price=START_PRICE)
114
- MARKET = market_sim._generate_market(MARKET_LENGTH)
115
 
116
- # Global game clock
117
- CURRENT_DAY = 0
 
118
  DAY_LOCK = asyncio.Lock()
 
 
 
 
 
 
 
 
 
 
 
 
119
 
120
  # ----------------------------
121
- # Static files (frontend)
122
  # ----------------------------
123
- # Put your index.html + assets in the Space repo root or a ./static folder.
124
- # This configuration supports both patterns:
125
- # - If you have ./static, it will be mounted.
126
- # - Otherwise, root file serving will still work via "/" -> index.html
127
- if os.path.isdir("static"):
128
- app.mount("/static", StaticFiles(directory="static"), name="static")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
 
130
 
 
 
 
131
  @app.get("/")
132
  async def root():
133
- # If you keep index.html in ./static, serve it from there; else from repo root.
134
- if os.path.exists(os.path.join("static", INDEX_FILE)):
135
- return FileResponse(os.path.join("static", INDEX_FILE))
136
  return FileResponse(INDEX_FILE)
137
 
138
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
  # ----------------------------
140
  # WebSocket endpoint
141
  # ----------------------------
@@ -144,14 +305,9 @@ async def websocket_endpoint(websocket: WebSocket, client_id: str):
144
  await manager.connect(websocket, client_id)
145
 
146
  # Send INIT immediately
147
- init_msg = {
148
- "type": "INIT",
149
- "payload": {
150
- "market": MARKET,
151
- "startDay": 0,
152
- },
153
- }
154
- await websocket.send_text(json.dumps(init_msg))
155
 
156
  try:
157
  while True:
@@ -159,7 +315,6 @@ async def websocket_endpoint(websocket: WebSocket, client_id: str):
159
  try:
160
  data = json.loads(raw)
161
  except Exception:
162
- # Ignore invalid JSON
163
  continue
164
 
165
  msg_type = data.get("type")
@@ -167,28 +322,20 @@ async def websocket_endpoint(websocket: WebSocket, client_id: str):
167
 
168
  if msg_type == "UPDATE_EQUITY":
169
  name = str(payload.get("name", client_id))
170
- equity = payload.get("equity", 0.0)
171
- roi = payload.get("roi", 0.0)
172
-
173
- # Defensive parsing
174
  try:
175
- equity_f = float(equity)
176
  except Exception:
177
  equity_f = 0.0
178
  try:
179
- roi_f = float(roi)
180
  except Exception:
181
  roi_f = 0.0
182
 
183
  await manager.update_equity(name=name, equity=equity_f, roi=roi_f)
184
- else:
185
- # Unknown message types are ignored to keep protocol forward-compatible
186
- continue
187
 
188
  except WebSocketDisconnect:
189
  await manager.disconnect(client_id)
190
  except Exception:
191
- # Any other error: drop connection cleanly
192
  await manager.disconnect(client_id)
193
 
194
 
@@ -196,16 +343,37 @@ async def websocket_endpoint(websocket: WebSocket, client_id: str):
196
  # Background tick loop
197
  # ----------------------------
198
  async def game_loop():
199
- global CURRENT_DAY
 
200
  while True:
201
  await asyncio.sleep(TICK_RATE)
 
 
202
  async with DAY_LOCK:
203
  CURRENT_DAY = (CURRENT_DAY + 1) % len(MARKET)
204
  day = CURRENT_DAY
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
  await manager.broadcast_tick(day)
206
 
207
 
208
  @app.on_event("startup")
209
  async def on_startup():
210
- # Start ticker loop once.
211
  asyncio.create_task(game_loop())
 
3
  import os
4
  import random
5
  import time
6
+ from dataclasses import dataclass, asdict
7
+ from typing import Any, Dict, List, Optional, Tuple
8
 
9
+ from fastapi import FastAPI, Header, HTTPException, WebSocket, WebSocketDisconnect
10
  from fastapi.responses import FileResponse
11
+
12
 
13
  # ----------------------------
14
  # Config
15
  # ----------------------------
16
+ TICK_RATE = float(os.getenv("TICK_RATE", "1.0")) # seconds
17
+ MARKET_LENGTH = int(os.getenv("MARKET_LENGTH", "300")) # number of ticks/days
18
  START_PRICE = float(os.getenv("START_PRICE", "100.0"))
19
 
20
+ ADMIN_TOKEN = os.getenv("ADMIN_TOKEN", "") # set as HF Space Secret
21
+ ADMIN_HEADER = "X-ADMIN-TOKEN"
22
+
23
  INDEX_FILE = os.getenv("INDEX_FILE", "index.html")
24
+ ADMIN_FILE = os.getenv("ADMIN_FILE", "admin.html")
25
+
26
+ INITIAL_VOLATILITY = float(os.getenv("DEFAULT_VOLATILITY", "0.8"))
27
+
28
+
29
+ # ----------------------------
30
+ # Data models
31
+ # ----------------------------
32
+ @dataclass
33
+ class ScenarioEvent:
34
+ day: int
35
+ shockPct: float = 0.0 # e.g. +5.0 means +5%
36
+ volatility: Optional[float] = None
37
+ news: Optional[str] = None
38
+
39
+
40
+ @dataclass
41
+ class Scenario:
42
+ name: str = "default"
43
+ startDay: int = 0
44
+ basePrice: float = START_PRICE
45
+ defaultVolatility: float = INITIAL_VOLATILITY
46
+ events: List[ScenarioEvent] = None
47
 
 
48
 
49
  # ----------------------------
50
  # Market simulator
51
  # ----------------------------
52
  class MarketSimulator:
53
+ def __init__(self, seed: int = 42, start_price: float = 100.0, base_vol: float = 0.8):
54
  self.seed = seed
55
  self.start_price = start_price
56
+ self.base_vol = base_vol
57
 
58
+ def generate_base_market(self, length: int) -> List[Dict[str, float]]:
59
  rng = random.Random(self.seed)
60
  price = self.start_price
61
  series: List[Dict[str, float]] = []
62
+ drift = 0.02
63
 
 
64
  for i in range(length):
65
+ shock = rng.gauss(0.0, self.base_vol)
 
66
  price = max(1.0, price + drift + shock)
67
  series.append({"i": i, "close": round(price, 2)})
68
 
 
93
  self.leaderboard[name] = {"equity": float(equity), "roi": float(roi), "ts": now}
94
 
95
  async def _snapshot_leaderboard(self) -> List[Dict[str, Any]]:
 
96
  async with self._lock:
97
  entries = [
98
  {"name": n, "equity": v["equity"], "roi": v["roi"], "ts": v.get("ts", 0.0)}
99
  for n, v in self.leaderboard.items()
100
  ]
 
101
  entries.sort(key=lambda x: x["equity"], reverse=True)
 
102
  for e in entries:
103
  e.pop("ts", None)
104
  return entries[:50]
105
 
106
+ async def broadcast(self, obj: Dict[str, Any]) -> None:
107
+ msg = json.dumps(obj)
 
 
 
 
 
 
 
 
 
108
  async with self._lock:
109
  sockets = list(self.active.items())
110
 
 
115
  except Exception:
116
  stale.append(client_id)
117
 
 
118
  if stale:
119
  async with self._lock:
120
  for cid in stale:
121
  self.active.pop(cid, None)
122
 
123
+ async def broadcast_tick(self, day: int) -> None:
124
+ payload = {
125
+ "type": "TICK",
126
+ "payload": {
127
+ "day": day,
128
+ "leaderboard": await self._snapshot_leaderboard(),
129
+ },
130
+ }
131
+ await self.broadcast(payload)
132
+
133
+ async def broadcast_news(self, day: int, text: str) -> None:
134
+ payload = {"type": "NEWS", "payload": {"day": day, "text": text}}
135
+ await self.broadcast(payload)
136
 
137
+
138
+ app = FastAPI(title="Trading Game (WebSocket + Admin)")
139
  manager = ConnectionManager()
 
 
140
 
141
+ # ----------------------------
142
+ # Global game state
143
+ # ----------------------------
144
  DAY_LOCK = asyncio.Lock()
145
+ STATE_LOCK = asyncio.Lock() # protects scenario/events/volatility/market
146
+
147
+ CURRENT_DAY = 0
148
+ CURRENT_VOL = INITIAL_VOLATILITY
149
+
150
+ # Market is the "timeline" clients receive on INIT; we will adjust it when events happen.
151
+ market_sim = MarketSimulator(seed=42, start_price=START_PRICE, base_vol=INITIAL_VOLATILITY)
152
+ MARKET: List[Dict[str, float]] = market_sim.generate_base_market(MARKET_LENGTH)
153
+
154
+ # Scheduled events by day
155
+ EVENTS: Dict[int, List[ScenarioEvent]] = {} # day -> [ScenarioEvent,...]
156
+
157
 
158
  # ----------------------------
159
+ # Helpers
160
  # ----------------------------
161
+ def require_admin(token: Optional[str]) -> None:
162
+ if not ADMIN_TOKEN:
163
+ # If not set, keep endpoints locked down by default.
164
+ raise HTTPException(status_code=403, detail="Admin token not configured on server.")
165
+ if token != ADMIN_TOKEN:
166
+ raise HTTPException(status_code=401, detail="Invalid admin token.")
167
+
168
+
169
+ def parse_event(obj: Dict[str, Any]) -> ScenarioEvent:
170
+ day = int(obj["day"])
171
+ shock = float(obj.get("shockPct", 0.0))
172
+ vol = obj.get("volatility", None)
173
+ vol_f = float(vol) if vol is not None else None
174
+ news = obj.get("news", None)
175
+ if news is not None:
176
+ news = str(news)
177
+ return ScenarioEvent(day=day, shockPct=shock, volatility=vol_f, news=news)
178
+
179
+
180
+ def snapshot_events() -> List[Dict[str, Any]]:
181
+ out: List[Dict[str, Any]] = []
182
+ for d in sorted(EVENTS.keys()):
183
+ for ev in EVENTS[d]:
184
+ out.append(asdict(ev))
185
+ return out
186
+
187
+
188
+ def apply_price_shock_from_day(day: int, shock_pct: float) -> None:
189
+ """
190
+ Applies a multiplicative factor to MARKET[day:] so the event shifts the future path.
191
+ This matches the "future path shift" interpretation.
192
+ """
193
+ if day < 0 or day >= len(MARKET):
194
+ return
195
+ factor = 1.0 + (shock_pct / 100.0)
196
+ for i in range(day, len(MARKET)):
197
+ MARKET[i]["close"] = round(max(1.0, MARKET[i]["close"] * factor), 2)
198
+
199
+
200
+ def regen_market_with_volatility(seed: int, start_price: float, base_vol: float) -> List[Dict[str, float]]:
201
+ sim = MarketSimulator(seed=seed, start_price=start_price, base_vol=base_vol)
202
+ return sim.generate_base_market(MARKET_LENGTH)
203
 
204
 
205
+ # ----------------------------
206
+ # Static pages
207
+ # ----------------------------
208
  @app.get("/")
209
  async def root():
 
 
 
210
  return FileResponse(INDEX_FILE)
211
 
212
 
213
+ @app.get("/admin")
214
+ async def admin_page():
215
+ return FileResponse(ADMIN_FILE)
216
+
217
+
218
+ # ----------------------------
219
+ # Admin REST API
220
+ # ----------------------------
221
+ @app.get("/admin/state")
222
+ async def admin_state(x_admin_token: Optional[str] = Header(default=None, alias=ADMIN_HEADER)):
223
+ require_admin(x_admin_token)
224
+ async with DAY_LOCK:
225
+ day = CURRENT_DAY
226
+ async with STATE_LOCK:
227
+ return {
228
+ "day": day,
229
+ "tickRate": TICK_RATE,
230
+ "marketLength": len(MARKET),
231
+ "currentVolatility": CURRENT_VOL,
232
+ "events": snapshot_events(),
233
+ }
234
+
235
+
236
+ @app.post("/admin/clear_events")
237
+ async def admin_clear_events(x_admin_token: Optional[str] = Header(default=None, alias=ADMIN_HEADER)):
238
+ require_admin(x_admin_token)
239
+ async with STATE_LOCK:
240
+ EVENTS.clear()
241
+ return {"ok": True}
242
+
243
+
244
+ @app.post("/admin/add_event")
245
+ async def admin_add_event(body: Dict[str, Any], x_admin_token: Optional[str] = Header(default=None, alias=ADMIN_HEADER)):
246
+ require_admin(x_admin_token)
247
+
248
+ # Supports either "day" (absolute) or "offset" (relative to current day)
249
+ async with DAY_LOCK:
250
+ cur = CURRENT_DAY
251
+
252
+ if "day" in body:
253
+ day = int(body["day"])
254
+ elif "offset" in body:
255
+ day = cur + int(body["offset"])
256
+ else:
257
+ raise HTTPException(status_code=400, detail="Provide 'day' or 'offset'.")
258
+
259
+ if day < cur:
260
+ raise HTTPException(status_code=400, detail=f"Event day {day} is in the past (current day {cur}).")
261
+
262
+ ev = parse_event({**body, "day": day})
263
+
264
+ async with STATE_LOCK:
265
+ EVENTS.setdefault(day, []).append(ev)
266
+
267
+ return {"ok": True, "event": asdict(ev)}
268
+
269
+
270
+ @app.post("/admin/load_scenario")
271
+ async def admin_load_scenario(body: Dict[str, Any], x_admin_token: Optional[str] = Header(default=None, alias=ADMIN_HEADER)):
272
+ require_admin(x_admin_token)
273
+
274
+ name = str(body.get("name", "scenario"))
275
+ start_day = int(body.get("startDay", 0))
276
+ base_price = float(body.get("basePrice", START_PRICE))
277
+ default_vol = float(body.get("defaultVolatility", INITIAL_VOLATILITY))
278
+ evs_raw = body.get("events", [])
279
+ if not isinstance(evs_raw, list):
280
+ raise HTTPException(status_code=400, detail="'events' must be a list.")
281
+
282
+ evs = [parse_event(e) for e in evs_raw]
283
+
284
+ # Reset market + events, reset day, reset volatility
285
+ async with STATE_LOCK:
286
+ global MARKET, CURRENT_VOL
287
+ CURRENT_VOL = default_vol
288
+ MARKET = regen_market_with_volatility(seed=42, start_price=base_price, base_vol=default_vol)
289
+ EVENTS.clear()
290
+ for ev in evs:
291
+ EVENTS.setdefault(ev.day, []).append(ev)
292
+
293
+ async with DAY_LOCK:
294
+ global CURRENT_DAY
295
+ CURRENT_DAY = max(0, min(start_day, len(MARKET) - 1))
296
+
297
+ return {"ok": True, "name": name, "startDay": CURRENT_DAY, "eventsLoaded": len(evs)}
298
+
299
+
300
  # ----------------------------
301
  # WebSocket endpoint
302
  # ----------------------------
 
305
  await manager.connect(websocket, client_id)
306
 
307
  # Send INIT immediately
308
+ async with STATE_LOCK:
309
+ init_payload = {"market": MARKET, "startDay": 0}
310
+ await websocket.send_text(json.dumps({"type": "INIT", "payload": init_payload}))
 
 
 
 
 
311
 
312
  try:
313
  while True:
 
315
  try:
316
  data = json.loads(raw)
317
  except Exception:
 
318
  continue
319
 
320
  msg_type = data.get("type")
 
322
 
323
  if msg_type == "UPDATE_EQUITY":
324
  name = str(payload.get("name", client_id))
 
 
 
 
325
  try:
326
+ equity_f = float(payload.get("equity", 0.0))
327
  except Exception:
328
  equity_f = 0.0
329
  try:
330
+ roi_f = float(payload.get("roi", 0.0))
331
  except Exception:
332
  roi_f = 0.0
333
 
334
  await manager.update_equity(name=name, equity=equity_f, roi=roi_f)
 
 
 
335
 
336
  except WebSocketDisconnect:
337
  await manager.disconnect(client_id)
338
  except Exception:
 
339
  await manager.disconnect(client_id)
340
 
341
 
 
343
  # Background tick loop
344
  # ----------------------------
345
  async def game_loop():
346
+ global CURRENT_DAY, CURRENT_VOL
347
+
348
  while True:
349
  await asyncio.sleep(TICK_RATE)
350
+
351
+ # advance day
352
  async with DAY_LOCK:
353
  CURRENT_DAY = (CURRENT_DAY + 1) % len(MARKET)
354
  day = CURRENT_DAY
355
+
356
+ # apply scheduled events
357
+ news_to_broadcast: List[str] = []
358
+ async with STATE_LOCK:
359
+ if day in EVENTS and EVENTS[day]:
360
+ for ev in EVENTS[day]:
361
+ if ev.shockPct:
362
+ # shift entire future path from this day onwards
363
+ apply_price_shock_from_day(day, ev.shockPct)
364
+ if ev.volatility is not None:
365
+ CURRENT_VOL = float(ev.volatility)
366
+ if ev.news:
367
+ news_to_broadcast.append(ev.news)
368
+
369
+ # broadcast optional news first
370
+ for text in news_to_broadcast:
371
+ await manager.broadcast_news(day, text)
372
+
373
+ # broadcast tick
374
  await manager.broadcast_tick(day)
375
 
376
 
377
  @app.on_event("startup")
378
  async def on_startup():
 
379
  asyncio.create_task(game_loop())