Eric Xu commited on
Fix production issues: session overwrite, error handling, memory leaks
Browse filesCritical fixes from audit:
- Fix entity_text overwrite: cohort endpoint no longer creates/overwrites
sessions — frontend creates session first, then uploads cohort
- Add try-catch around fut.result() in all SSE stream workers
- Add TTL cleanup for _cf_pending tickets (expire after 10min)
- Store goal/audience in session for consistency
- Remove unused asyncio.get_event_loop() call
- Add SESSION_MAX_AGE_HOURS constant for future cleanup
- web/app.py +20 -19
web/app.py
CHANGED
|
@@ -67,6 +67,7 @@ app.mount("/static", StaticFiles(directory=Path(__file__).parent / "static"), na
|
|
| 67 |
|
| 68 |
# In-memory store for active sessions
|
| 69 |
sessions: dict = {}
|
|
|
|
| 70 |
|
| 71 |
# Nemotron dataset — loaded once if available
|
| 72 |
_nemotron_ds = None
|
|
@@ -213,9 +214,12 @@ async def create_session(entity: EntityInput):
|
|
| 213 |
sessions[sid] = {
|
| 214 |
"id": sid,
|
| 215 |
"entity_text": entity.entity_text,
|
|
|
|
|
|
|
| 216 |
"cohort": None,
|
| 217 |
"eval_results": None,
|
| 218 |
"gradient": None,
|
|
|
|
| 219 |
"created": datetime.now().isoformat(),
|
| 220 |
}
|
| 221 |
return {"session_id": sid}
|
|
@@ -412,7 +416,6 @@ If nothing specific is stated, return {{}}."""
|
|
| 412 |
@app.post("/api/cohort/generate")
|
| 413 |
async def generate_cohort_endpoint(config: CohortConfig):
|
| 414 |
"""Generate a cohort — from Nemotron if available, else LLM-generated."""
|
| 415 |
-
sid = uuid.uuid4().hex[:12]
|
| 416 |
total = sum(s.get("count", 8) for s in config.segments)
|
| 417 |
|
| 418 |
ds = get_nemotron()
|
|
@@ -465,17 +468,8 @@ async def generate_cohort_endpoint(config: CohortConfig):
|
|
| 465 |
for i, p in enumerate(all_personas):
|
| 466 |
p["user_id"] = i
|
| 467 |
|
| 468 |
-
sessions[sid] = {
|
| 469 |
-
"id": sid,
|
| 470 |
-
"entity_text": config.description,
|
| 471 |
-
"cohort": all_personas,
|
| 472 |
-
"eval_results": None,
|
| 473 |
-
"gradient": None,
|
| 474 |
-
"created": datetime.now().isoformat(),
|
| 475 |
-
}
|
| 476 |
-
|
| 477 |
return {
|
| 478 |
-
"
|
| 479 |
"cohort": all_personas, "source": source,
|
| 480 |
"filters": filters if ds is not None else None,
|
| 481 |
}
|
|
@@ -516,8 +510,6 @@ async def evaluate_stream(sid: str, parallel: int = 5, bias_calibration: bool =
|
|
| 516 |
results = [None] * total
|
| 517 |
done = 0
|
| 518 |
t0 = time.time()
|
| 519 |
-
loop = asyncio.get_event_loop()
|
| 520 |
-
|
| 521 |
with concurrent.futures.ThreadPoolExecutor(max_workers=parallel) as pool:
|
| 522 |
futs = {
|
| 523 |
pool.submit(evaluate_one, client, model, ev, entity_text,
|
|
@@ -526,7 +518,10 @@ async def evaluate_stream(sid: str, parallel: int = 5, bias_calibration: bool =
|
|
| 526 |
}
|
| 527 |
for fut in concurrent.futures.as_completed(futs):
|
| 528 |
idx = futs[fut]
|
| 529 |
-
|
|
|
|
|
|
|
|
|
|
| 530 |
results[idx] = result
|
| 531 |
done += 1
|
| 532 |
|
|
@@ -573,8 +568,8 @@ class CounterfactualRequest(BaseModel):
|
|
| 573 |
parallel: int = 5
|
| 574 |
|
| 575 |
|
| 576 |
-
# Store pending counterfactual configs for SSE pickup
|
| 577 |
-
_cf_pending: dict = {}
|
| 578 |
|
| 579 |
|
| 580 |
@app.post("/api/counterfactual/prepare/{sid}")
|
|
@@ -583,7 +578,12 @@ async def prepare_counterfactual(sid: str, req: CounterfactualRequest):
|
|
| 583 |
if sid not in sessions:
|
| 584 |
raise HTTPException(404, "Session not found")
|
| 585 |
ticket = uuid.uuid4().hex[:8]
|
| 586 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 587 |
return {"ticket": ticket}
|
| 588 |
|
| 589 |
|
|
@@ -595,9 +595,10 @@ async def counterfactual_stream(sid: str, ticket: str):
|
|
| 595 |
session = sessions[sid]
|
| 596 |
if not session["eval_results"]:
|
| 597 |
raise HTTPException(400, "Run evaluation first")
|
| 598 |
-
|
| 599 |
-
if not
|
| 600 |
raise HTTPException(400, "Invalid or expired ticket")
|
|
|
|
| 601 |
|
| 602 |
all_changes = req.changes
|
| 603 |
goal = req.goal
|
|
|
|
| 67 |
|
| 68 |
# In-memory store for active sessions
|
| 69 |
sessions: dict = {}
|
| 70 |
+
SESSION_MAX_AGE_HOURS = 24
|
| 71 |
|
| 72 |
# Nemotron dataset — loaded once if available
|
| 73 |
_nemotron_ds = None
|
|
|
|
| 214 |
sessions[sid] = {
|
| 215 |
"id": sid,
|
| 216 |
"entity_text": entity.entity_text,
|
| 217 |
+
"goal": "",
|
| 218 |
+
"audience": "",
|
| 219 |
"cohort": None,
|
| 220 |
"eval_results": None,
|
| 221 |
"gradient": None,
|
| 222 |
+
"bias_audit": None,
|
| 223 |
"created": datetime.now().isoformat(),
|
| 224 |
}
|
| 225 |
return {"session_id": sid}
|
|
|
|
| 416 |
@app.post("/api/cohort/generate")
|
| 417 |
async def generate_cohort_endpoint(config: CohortConfig):
|
| 418 |
"""Generate a cohort — from Nemotron if available, else LLM-generated."""
|
|
|
|
| 419 |
total = sum(s.get("count", 8) for s in config.segments)
|
| 420 |
|
| 421 |
ds = get_nemotron()
|
|
|
|
| 468 |
for i, p in enumerate(all_personas):
|
| 469 |
p["user_id"] = i
|
| 470 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 471 |
return {
|
| 472 |
+
"cohort_size": len(all_personas),
|
| 473 |
"cohort": all_personas, "source": source,
|
| 474 |
"filters": filters if ds is not None else None,
|
| 475 |
}
|
|
|
|
| 510 |
results = [None] * total
|
| 511 |
done = 0
|
| 512 |
t0 = time.time()
|
|
|
|
|
|
|
| 513 |
with concurrent.futures.ThreadPoolExecutor(max_workers=parallel) as pool:
|
| 514 |
futs = {
|
| 515 |
pool.submit(evaluate_one, client, model, ev, entity_text,
|
|
|
|
| 518 |
}
|
| 519 |
for fut in concurrent.futures.as_completed(futs):
|
| 520 |
idx = futs[fut]
|
| 521 |
+
try:
|
| 522 |
+
result = fut.result()
|
| 523 |
+
except Exception as e:
|
| 524 |
+
result = {"error": str(e), "_evaluator": {"name": "?"}}
|
| 525 |
results[idx] = result
|
| 526 |
done += 1
|
| 527 |
|
|
|
|
| 568 |
parallel: int = 5
|
| 569 |
|
| 570 |
|
| 571 |
+
# Store pending counterfactual configs for SSE pickup (with timestamps)
|
| 572 |
+
_cf_pending: dict = {} # ticket -> {"req": CounterfactualRequest, "ts": time.time()}
|
| 573 |
|
| 574 |
|
| 575 |
@app.post("/api/counterfactual/prepare/{sid}")
|
|
|
|
| 578 |
if sid not in sessions:
|
| 579 |
raise HTTPException(404, "Session not found")
|
| 580 |
ticket = uuid.uuid4().hex[:8]
|
| 581 |
+
# Clean expired tickets (>10 min)
|
| 582 |
+
now = time.time()
|
| 583 |
+
expired = [k for k, v in _cf_pending.items() if now - v.get("ts", 0) > 600]
|
| 584 |
+
for k in expired:
|
| 585 |
+
del _cf_pending[k]
|
| 586 |
+
_cf_pending[ticket] = {"req": req, "ts": now}
|
| 587 |
return {"ticket": ticket}
|
| 588 |
|
| 589 |
|
|
|
|
| 595 |
session = sessions[sid]
|
| 596 |
if not session["eval_results"]:
|
| 597 |
raise HTTPException(400, "Run evaluation first")
|
| 598 |
+
entry = _cf_pending.pop(ticket, None)
|
| 599 |
+
if not entry:
|
| 600 |
raise HTTPException(400, "Invalid or expired ticket")
|
| 601 |
+
req = entry["req"]
|
| 602 |
|
| 603 |
all_changes = req.changes
|
| 604 |
goal = req.goal
|