File size: 4,549 Bytes
c2ea5ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
"""Example Traces Router
Serves list/detail/import endpoints for Who_and_When dataset subsets.
"""
from __future__ import annotations

import json
import logging
from pathlib import Path
from typing import List, Optional

from fastapi import APIRouter, HTTPException, status, Depends
from pydantic import BaseModel

from backend.database.utils import save_trace, get_db
from sqlalchemy.orm import Session

# Module logger, named after this module's import path.
logger = logging.getLogger(__name__)
# Every route registered on this router carries the /api/example-traces prefix.
router = APIRouter(prefix="/api/example-traces", tags=["example-traces"])

# Dataset directory: take this file's resolved path, step up three `.parent`
# levels, then descend into datasets/example_traces.
DATA_DIR = Path(__file__).resolve().parent.parent.parent / "datasets" / "example_traces"
# Subset name -> JSONL file holding that subset's records. The keys are the
# only subset names _load_subset accepts, and list_examples walks them in
# this order when no subset filter is given.
SUBSET_FILES = {
    "Algorithm-Generated": DATA_DIR / "algorithm-generated.jsonl",
    "Hand-Crafted": DATA_DIR / "hand-crafted.jsonl",
}

# Full record model used internally and for detail endpoint
class ExampleTrace(BaseModel):
    # One record of a subset JSONL file; _load_subset builds an instance from
    # each parsed line. Documented with comments rather than a docstring so
    # the generated schema description is left as it was.

    # Required on every record.
    id: int
    subset: str
    mistake_step: int
    # Optional descriptive fields; absent keys default to None.
    question: Optional[str] = None
    agent: Optional[str] = None
    agents: Optional[List[str]] = None
    # Full trace text (required); import_example stores it as trace content.
    trace: str
    # Failure-analysis fields; all optional and default to None.
    is_correct: Optional[bool] = None
    question_id: Optional[str] = None
    ground_truth: Optional[str] = None
    mistake_agent: Optional[str] = None
    mistake_reason: Optional[str] = None


# Module-level cache {subset: List[ExampleTrace]}.
# _load_subset stores a subset's parsed records here after its first
# successful read and returns the stored list on later calls; nothing visible
# in this file evicts or refreshes an entry afterwards.
_cache: dict[str, List[ExampleTrace]] = {}


def _load_subset(subset: str) -> List[ExampleTrace]:
    """Return every example of *subset*, parsing its JSONL file on first use.

    The parsed list is stored in the module-level ``_cache`` so later calls
    for the same subset return it without touching the disk again.

    Args:
        subset: A key of ``SUBSET_FILES``.

    Returns:
        The subset's records in file order.

    Raises:
        HTTPException: 404 when *subset* is not a known subset name; 500 when
            the subset's data file does not exist on disk.

    Malformed JSON or a record that fails model validation is not caught
    here; the underlying exception propagates and nothing is cached.
    """
    if subset not in SUBSET_FILES:
        raise HTTPException(status_code=404, detail=f"Invalid subset {subset}")
    if subset in _cache:
        return _cache[subset]
    path = SUBSET_FILES[subset]
    if not path.exists():
        raise HTTPException(
            status_code=500,
            detail=f"Subset file {path} missing on server. Run fetch_example_dataset.py first.",
        )
    examples: List[ExampleTrace] = []
    with path.open("r", encoding="utf-8") as f:
        for line in f:
            # A blank or whitespace-only line is not a record. Without this
            # guard json.loads raises on it and the whole subset is unusable.
            if not line.strip():
                continue
            examples.append(ExampleTrace(**json.loads(line)))
    # Cache only after the whole file parsed, so a failed load is retried.
    _cache[subset] = examples
    return examples


class ExampleTraceLite(BaseModel):
    # Listing view of an example: the ExampleTrace fields minus `trace`,
    # `question_id` and `ground_truth`. Documented with comments rather than
    # a docstring so the generated schema description is left as it was.

    # Required on every record.
    id: int
    subset: str
    mistake_step: int
    # Optional descriptive fields; default to None.
    question: Optional[str] = None
    agent: Optional[str] = None
    agents: Optional[List[str]] = None
    # Key failure-analysis fields carried over into the lite version.
    is_correct: Optional[bool] = None
    mistake_agent: Optional[str] = None
    mistake_reason: Optional[str] = None


@router.get("/", response_model=List[ExampleTraceLite])
async def list_examples(subset: Optional[str] = None):
    """List all example traces (optionally filter by subset)."""
    # A falsy filter (None or "") selects every known subset, in
    # SUBSET_FILES order; otherwise only the named one is loaded.
    # _load_subset raises the 404/500 for unknown or missing subsets.
    wanted = SUBSET_FILES.keys() if not subset else [subset]
    # Project each full record onto the lite model, dropping the trace text.
    return [
        ExampleTraceLite(
            id=record.id,
            subset=record.subset,
            mistake_step=record.mistake_step,
            question=record.question,
            agent=record.agent,
            agents=record.agents,
            is_correct=record.is_correct,
            mistake_agent=record.mistake_agent,
            mistake_reason=record.mistake_reason,
        )
        for name in wanted
        for record in _load_subset(name)
    ]


@router.get("/{subset}/{example_id}", response_model=ExampleTrace)
async def get_example(subset: str, example_id: int):
    # Return the full record whose id equals `example_id` within `subset`.
    # The first match in file order wins; _load_subset raises the 404/500
    # for an unknown subset or a missing data file.
    for candidate in _load_subset(subset):
        if candidate.id == example_id:
            return candidate
    # Subset exists but holds no record with that id.
    raise HTTPException(status_code=404, detail="Example not found")


class ImportRequest(BaseModel):
    # Request body for POST /import: identifies one example to import.
    # Subset name; import_example passes it to _load_subset.
    subset: str
    # Matched against ExampleTrace.id within that subset.
    id: int


@router.post("/import")
async def import_example(req: ImportRequest, db: Session = Depends(get_db)):
    # Look up the requested example, store its trace text through save_trace,
    # and return the saved trace's to_dict() result.
    # _load_subset raises the 404/500 for an unknown subset or missing file.
    ex = next((e for e in _load_subset(req.subset) if e.id == req.id), None)
    if ex is None:
        raise HTTPException(status_code=404, detail="Example not found")

    # Lower-cased subset name with spaces turned into underscores; shared by
    # the generated filename and the second tag.
    subset_slug = req.subset.lower().replace(' ', '_')
    # Only the first 100 characters of the question go into the description;
    # a missing question yields an empty string.
    summary = (ex.question or "")[:100]

    # Save as trace using existing util
    trace = save_trace(
        session=db,
        content=ex.trace,
        filename=f"example_{subset_slug}_{req.id}.txt",
        title=f"Example {req.subset} #{req.id}",
        description=summary,
        trace_type="example",
        trace_source="example_dataset",
        tags=["example", subset_slug],
    )
    logger.info(f"Imported example trace {req.subset}#{req.id} as trace_id={trace.trace_id}")
    return trace.to_dict()