Eric Xu commited on
Move counterfactual params from GET query to POST+ticket to avoid log leakage
Browse files- web/app.py +33 -7
- web/static/index.html +13 -7
web/app.py
CHANGED
|
@@ -565,19 +565,45 @@ async def evaluate_stream(sid: str, parallel: int = 5, bias_calibration: bool =
|
|
| 565 |
return EventSourceResponse(event_generator())
|
| 566 |
|
| 567 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 568 |
@app.get("/api/counterfactual/stream/{sid}")
|
| 569 |
-
async def counterfactual_stream(
|
| 570 |
-
|
| 571 |
-
min_score: int = 4, max_score: int = 7, parallel: int = 5
|
| 572 |
-
):
|
| 573 |
-
"""Run counterfactual probes with SSE progress. Goal enables VJP weighting."""
|
| 574 |
if sid not in sessions:
|
| 575 |
raise HTTPException(404, "Session not found")
|
| 576 |
session = sessions[sid]
|
| 577 |
if not session["eval_results"]:
|
| 578 |
raise HTTPException(400, "Run evaluation first")
|
| 579 |
-
|
| 580 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 581 |
|
| 582 |
async def event_generator():
|
| 583 |
client = get_client()
|
|
|
|
| 565 |
return EventSourceResponse(event_generator())
|
| 566 |
|
| 567 |
|
| 568 |
+
class CounterfactualRequest(BaseModel):
|
| 569 |
+
changes: list[dict]
|
| 570 |
+
goal: str = ""
|
| 571 |
+
min_score: int = 4
|
| 572 |
+
max_score: int = 7
|
| 573 |
+
parallel: int = 5
|
| 574 |
+
|
| 575 |
+
|
| 576 |
+
# Store pending counterfactual configs for SSE pickup
|
| 577 |
+
_cf_pending: dict = {}
|
| 578 |
+
|
| 579 |
+
|
| 580 |
+
@app.post("/api/counterfactual/prepare/{sid}")
|
| 581 |
+
async def prepare_counterfactual(sid: str, req: CounterfactualRequest):
|
| 582 |
+
"""Stage counterfactual config, return a ticket for the SSE stream."""
|
| 583 |
+
if sid not in sessions:
|
| 584 |
+
raise HTTPException(404, "Session not found")
|
| 585 |
+
ticket = uuid.uuid4().hex[:8]
|
| 586 |
+
_cf_pending[ticket] = req
|
| 587 |
+
return {"ticket": ticket}
|
| 588 |
+
|
| 589 |
+
|
| 590 |
@app.get("/api/counterfactual/stream/{sid}")
|
| 591 |
+
async def counterfactual_stream(sid: str, ticket: str, **_):
|
| 592 |
+
"""Run counterfactual probes with SSE progress."""
|
|
|
|
|
|
|
|
|
|
| 593 |
if sid not in sessions:
|
| 594 |
raise HTTPException(404, "Session not found")
|
| 595 |
session = sessions[sid]
|
| 596 |
if not session["eval_results"]:
|
| 597 |
raise HTTPException(400, "Run evaluation first")
|
| 598 |
+
req = _cf_pending.pop(ticket, None)
|
| 599 |
+
if not req:
|
| 600 |
+
raise HTTPException(400, "Invalid or expired ticket")
|
| 601 |
+
|
| 602 |
+
all_changes = req.changes
|
| 603 |
+
goal = req.goal
|
| 604 |
+
min_score = req.min_score
|
| 605 |
+
max_score = req.max_score
|
| 606 |
+
parallel = req.parallel
|
| 607 |
|
| 608 |
async def event_generator():
|
| 609 |
client = get_client()
|
web/static/index.html
CHANGED
|
@@ -905,16 +905,22 @@ async function runDirections() {
|
|
| 905 |
// Phase 3: Run counterfactual probes
|
| 906 |
document.getElementById('cfProgressText').textContent = 'Testing changes against persuadable middle...';
|
| 907 |
|
| 908 |
-
|
| 909 |
-
|
| 910 |
-
|
| 911 |
-
|
| 912 |
-
|
| 913 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 914 |
});
|
|
|
|
| 915 |
|
| 916 |
await new Promise((resolve, reject) => {
|
| 917 |
-
const es = new EventSource(`/api/counterfactual/stream/${sessionId}?${
|
| 918 |
|
| 919 |
es.addEventListener('start', (e) => {
|
| 920 |
const d = JSON.parse(e.data);
|
|
|
|
| 905 |
// Phase 3: Run counterfactual probes
|
| 906 |
document.getElementById('cfProgressText').textContent = 'Testing changes against persuadable middle...';
|
| 907 |
|
| 908 |
+
// POST config first, get a ticket, then SSE with just the ticket
|
| 909 |
+
const prepResp = await fetch(`/api/counterfactual/prepare/${sessionId}`, {
|
| 910 |
+
method: 'POST',
|
| 911 |
+
headers: {'Content-Type': 'application/json'},
|
| 912 |
+
body: JSON.stringify({
|
| 913 |
+
changes: suggestedChanges,
|
| 914 |
+
goal: goal,
|
| 915 |
+
min_score: 4,
|
| 916 |
+
max_score: 7,
|
| 917 |
+
parallel: 5,
|
| 918 |
+
}),
|
| 919 |
});
|
| 920 |
+
const {ticket} = await prepResp.json();
|
| 921 |
|
| 922 |
await new Promise((resolve, reject) => {
|
| 923 |
+
const es = new EventSource(`/api/counterfactual/stream/${sessionId}?ticket=${ticket}`);
|
| 924 |
|
| 925 |
es.addEventListener('start', (e) => {
|
| 926 |
const d = JSON.parse(e.data);
|