Eric Xu commited on
Commit
6dcf8e5
·
unverified ·
1 Parent(s): cb7e365

Fix production issues: session overwrite, error handling, memory leaks

Browse files

Critical fixes from audit:
- Fix entity_text overwrite: cohort endpoint no longer creates/overwrites
sessions — frontend creates session first, then uploads cohort
- Add try-catch around fut.result() in all SSE stream workers
- Add TTL cleanup for _cf_pending tickets (expire after 10min)
- Store goal/audience in session for consistency
- Remove unused asyncio.get_event_loop() call
- Add SESSION_MAX_AGE_HOURS constant for future cleanup

Files changed (1) hide show
  1. web/app.py +20 -19
web/app.py CHANGED
@@ -67,6 +67,7 @@ app.mount("/static", StaticFiles(directory=Path(__file__).parent / "static"), na
67
 
68
  # In-memory store for active sessions
69
  sessions: dict = {}
 
70
 
71
  # Nemotron dataset — loaded once if available
72
  _nemotron_ds = None
@@ -213,9 +214,12 @@ async def create_session(entity: EntityInput):
213
  sessions[sid] = {
214
  "id": sid,
215
  "entity_text": entity.entity_text,
 
 
216
  "cohort": None,
217
  "eval_results": None,
218
  "gradient": None,
 
219
  "created": datetime.now().isoformat(),
220
  }
221
  return {"session_id": sid}
@@ -412,7 +416,6 @@ If nothing specific is stated, return {{}}."""
412
  @app.post("/api/cohort/generate")
413
  async def generate_cohort_endpoint(config: CohortConfig):
414
  """Generate a cohort — from Nemotron if available, else LLM-generated."""
415
- sid = uuid.uuid4().hex[:12]
416
  total = sum(s.get("count", 8) for s in config.segments)
417
 
418
  ds = get_nemotron()
@@ -465,17 +468,8 @@ async def generate_cohort_endpoint(config: CohortConfig):
465
  for i, p in enumerate(all_personas):
466
  p["user_id"] = i
467
 
468
- sessions[sid] = {
469
- "id": sid,
470
- "entity_text": config.description,
471
- "cohort": all_personas,
472
- "eval_results": None,
473
- "gradient": None,
474
- "created": datetime.now().isoformat(),
475
- }
476
-
477
  return {
478
- "session_id": sid, "cohort_size": len(all_personas),
479
  "cohort": all_personas, "source": source,
480
  "filters": filters if ds is not None else None,
481
  }
@@ -516,8 +510,6 @@ async def evaluate_stream(sid: str, parallel: int = 5, bias_calibration: bool =
516
  results = [None] * total
517
  done = 0
518
  t0 = time.time()
519
- loop = asyncio.get_event_loop()
520
-
521
  with concurrent.futures.ThreadPoolExecutor(max_workers=parallel) as pool:
522
  futs = {
523
  pool.submit(evaluate_one, client, model, ev, entity_text,
@@ -526,7 +518,10 @@ async def evaluate_stream(sid: str, parallel: int = 5, bias_calibration: bool =
526
  }
527
  for fut in concurrent.futures.as_completed(futs):
528
  idx = futs[fut]
529
- result = fut.result()
 
 
 
530
  results[idx] = result
531
  done += 1
532
 
@@ -573,8 +568,8 @@ class CounterfactualRequest(BaseModel):
573
  parallel: int = 5
574
 
575
 
576
- # Store pending counterfactual configs for SSE pickup
577
- _cf_pending: dict = {}
578
 
579
 
580
  @app.post("/api/counterfactual/prepare/{sid}")
@@ -583,7 +578,12 @@ async def prepare_counterfactual(sid: str, req: CounterfactualRequest):
583
  if sid not in sessions:
584
  raise HTTPException(404, "Session not found")
585
  ticket = uuid.uuid4().hex[:8]
586
- _cf_pending[ticket] = req
 
 
 
 
 
587
  return {"ticket": ticket}
588
 
589
 
@@ -595,9 +595,10 @@ async def counterfactual_stream(sid: str, ticket: str):
595
  session = sessions[sid]
596
  if not session["eval_results"]:
597
  raise HTTPException(400, "Run evaluation first")
598
- req = _cf_pending.pop(ticket, None)
599
- if not req:
600
  raise HTTPException(400, "Invalid or expired ticket")
 
601
 
602
  all_changes = req.changes
603
  goal = req.goal
 
67
 
68
  # In-memory store for active sessions
69
  sessions: dict = {}
70
+ SESSION_MAX_AGE_HOURS = 24
71
 
72
  # Nemotron dataset — loaded once if available
73
  _nemotron_ds = None
 
214
  sessions[sid] = {
215
  "id": sid,
216
  "entity_text": entity.entity_text,
217
+ "goal": "",
218
+ "audience": "",
219
  "cohort": None,
220
  "eval_results": None,
221
  "gradient": None,
222
+ "bias_audit": None,
223
  "created": datetime.now().isoformat(),
224
  }
225
  return {"session_id": sid}
 
416
  @app.post("/api/cohort/generate")
417
  async def generate_cohort_endpoint(config: CohortConfig):
418
  """Generate a cohort — from Nemotron if available, else LLM-generated."""
 
419
  total = sum(s.get("count", 8) for s in config.segments)
420
 
421
  ds = get_nemotron()
 
468
  for i, p in enumerate(all_personas):
469
  p["user_id"] = i
470
 
 
 
 
 
 
 
 
 
 
471
  return {
472
+ "cohort_size": len(all_personas),
473
  "cohort": all_personas, "source": source,
474
  "filters": filters if ds is not None else None,
475
  }
 
510
  results = [None] * total
511
  done = 0
512
  t0 = time.time()
 
 
513
  with concurrent.futures.ThreadPoolExecutor(max_workers=parallel) as pool:
514
  futs = {
515
  pool.submit(evaluate_one, client, model, ev, entity_text,
 
518
  }
519
  for fut in concurrent.futures.as_completed(futs):
520
  idx = futs[fut]
521
+ try:
522
+ result = fut.result()
523
+ except Exception as e:
524
+ result = {"error": str(e), "_evaluator": {"name": "?"}}
525
  results[idx] = result
526
  done += 1
527
 
 
568
  parallel: int = 5
569
 
570
 
571
+ # Store pending counterfactual configs for SSE pickup (with timestamps)
572
+ _cf_pending: dict = {} # ticket -> {"req": CounterfactualRequest, "ts": time.time()}
573
 
574
 
575
  @app.post("/api/counterfactual/prepare/{sid}")
 
578
  if sid not in sessions:
579
  raise HTTPException(404, "Session not found")
580
  ticket = uuid.uuid4().hex[:8]
581
+ # Clean expired tickets (>10 min)
582
+ now = time.time()
583
+ expired = [k for k, v in _cf_pending.items() if now - v.get("ts", 0) > 600]
584
+ for k in expired:
585
+ del _cf_pending[k]
586
+ _cf_pending[ticket] = {"req": req, "ts": now}
587
  return {"ticket": ticket}
588
 
589
 
 
595
  session = sessions[sid]
596
  if not session["eval_results"]:
597
  raise HTTPException(400, "Run evaluation first")
598
+ entry = _cf_pending.pop(ticket, None)
599
+ if not entry:
600
  raise HTTPException(400, "Invalid or expired ticket")
601
+ req = entry["req"]
602
 
603
  all_changes = req.changes
604
  goal = req.goal