# Page header captured with this file: Spaces status "Running"; file size 4,549 bytes; commit c2ea5ed.
"""Example Traces Router
Serves list/detail/import endpoints for Who_and_When dataset subsets.
"""
from __future__ import annotations
import json
import logging
from pathlib import Path
from typing import List, Optional
from fastapi import APIRouter, HTTPException, status, Depends
from pydantic import BaseModel
from backend.database.utils import save_trace, get_db
from sqlalchemy.orm import Session
# Logger named after this module.
logger = logging.getLogger(__name__)

# Every route declared on this router lives under /api/example-traces.
router = APIRouter(prefix="/api/example-traces", tags=["example-traces"])

# Third ancestor directory of this (resolved) file, then datasets/example_traces.
DATA_DIR = Path(__file__).resolve().parent.parent.parent.joinpath(
    "datasets", "example_traces"
)

# Subset name -> JSONL file that holds the records of that subset.
SUBSET_FILES = {
    "Algorithm-Generated": DATA_DIR.joinpath("algorithm-generated.jsonl"),
    "Hand-Crafted": DATA_DIR.joinpath("hand-crafted.jsonl"),
}
# Full record model used internally and for the detail endpoint.
class ExampleTrace(BaseModel):
    """Complete example-trace record, including the trace text itself.

    ``id``, ``subset``, ``mistake_step`` and ``trace`` are required; every
    other field defaults to ``None`` when absent from the input.
    """

    id: int
    subset: str
    mistake_step: int
    question: Optional[str] = None
    agent: Optional[str] = None
    agents: Optional[List[str]] = None
    trace: str
    # Failure analysis fields (all optional).
    is_correct: Optional[bool] = None
    question_id: Optional[str] = None
    ground_truth: Optional[str] = None
    mistake_agent: Optional[str] = None
    mistake_reason: Optional[str] = None
# Module-level cache {subset: List[ExampleTrace]}
_cache: dict[str, List[ExampleTrace]] = {}


def _load_subset(subset: str) -> List[ExampleTrace]:
    """Return every record of ``subset``, reading its JSONL file at most once.

    The parsed records are stored in the module-level ``_cache`` so later
    calls for the same subset do not touch the file again.

    Args:
        subset: A key of ``SUBSET_FILES``.

    Returns:
        The records of the subset, in file order.

    Raises:
        HTTPException: 404 when ``subset`` is not a key of ``SUBSET_FILES``;
            500 when the subset's file does not exist.
    """
    if subset not in SUBSET_FILES:
        raise HTTPException(status_code=404, detail=f"Invalid subset {subset}")

    cached = _cache.get(subset)
    if cached is not None:
        return cached

    path = SUBSET_FILES[subset]
    if not path.exists():
        raise HTTPException(
            status_code=500,
            detail=f"Subset file {path} missing on server. Run fetch_example_dataset.py first.",
        )

    with path.open("r", encoding="utf-8") as f:
        # Skip whitespace-only lines: json.loads() rejects empty input, so a
        # blank or trailing empty line would otherwise fail the whole load.
        examples = [
            ExampleTrace(**json.loads(line))
            for line in f
            if line.strip()
        ]

    _cache[subset] = examples
    return examples
class ExampleTraceLite(BaseModel):
    """Summary form of an example trace, used as the list endpoint's item type.

    Carries no ``trace``, ``question_id`` or ``ground_truth`` field.
    """

    id: int
    subset: str
    mistake_step: int
    question: Optional[str] = None
    agent: Optional[str] = None
    agents: Optional[List[str]] = None
    # Key failure analysis fields kept in the lite version.
    is_correct: Optional[bool] = None
    mistake_agent: Optional[str] = None
    mistake_reason: Optional[str] = None
@router.get("/", response_model=List[ExampleTraceLite])
async def list_examples(subset: Optional[str] = None):
    """List example traces in lite form, optionally restricted to one subset.

    When ``subset`` is missing or empty, every subset in ``SUBSET_FILES`` is
    included, in that mapping's order.
    """
    wanted = [subset] if subset else list(SUBSET_FILES)
    return [
        ExampleTraceLite(
            id=record.id,
            subset=record.subset,
            mistake_step=record.mistake_step,
            question=record.question,
            agent=record.agent,
            agents=record.agents,
            # Failure analysis fields carried over to the lite form.
            is_correct=record.is_correct,
            mistake_agent=record.mistake_agent,
            mistake_reason=record.mistake_reason,
        )
        for name in wanted
        for record in _load_subset(name)
    ]
@router.get("/{subset}/{example_id}", response_model=ExampleTrace)
async def get_example(subset: str, example_id: int):
    """Return the full record of one example trace.

    Responds with 404 when no record in ``subset`` has id ``example_id``.
    """
    found = next(
        (record for record in _load_subset(subset) if record.id == example_id),
        None,
    )
    if found is None:
        raise HTTPException(status_code=404, detail="Example not found")
    return found
class ImportRequest(BaseModel):
    """Request body of the import endpoint: names the example to import."""

    # Subset name; the import handler passes it to _load_subset.
    subset: str
    # Id of the wanted example; matched against each record's ``id``.
    id: int
@router.post("/import")
async def import_example(req: ImportRequest, db: Session = Depends(get_db)):
    """Save one example trace through ``save_trace`` and return it as a dict.

    Responds with 404 when no record in ``req.subset`` has id ``req.id``.
    The saved description is the example's question cut to 100 characters,
    or an empty string when there is no question.
    """
    ex = next((e for e in _load_subset(req.subset) if e.id == req.id), None)
    if ex is None:
        raise HTTPException(status_code=404, detail="Example not found")

    # Save as trace using existing util
    filename = f"example_{req.subset.lower().replace(' ', '_')}_{req.id}.txt"
    subset_tag = req.subset.replace(" ", "_").lower()
    short_question = (ex.question or "")[:100]
    trace = save_trace(
        session=db,
        content=ex.trace,
        filename=filename,
        title=f"Example {req.subset} #{req.id}",
        description=short_question,
        trace_type="example",
        trace_source="example_dataset",
        tags=["example", subset_tag],
    )
    logger.info(f"Imported example trace {req.subset}#{req.id} as trace_id={trace.trace_id}")
    return trace.to_dict()