# agg-trace-visualizer/backend/api/model_datasets.py
# Zayne Rea Sprague
# small tweak
# cdf803d
import json
import os
import hashlib
from flask import Blueprint, request, jsonify
from datasets import load_dataset, Dataset
# Flask blueprint for the model-dataset endpoints, mounted at /api/model/datasets.
bp = Blueprint("model_datasets", __name__, url_prefix="/api/model/datasets")
# In-memory cache keyed by dataset id. Each entry is a dict holding:
# dataset, repo, column, prompt_column, split, n_rows, n_samples,
# question_fingerprint.
_cache: dict[str, dict] = {}
def _make_id(repo: str, column: str, split: str) -> str:
    """Derive a short, deterministic id for a (repo, column, split) triple.

    The three parts are joined with ``:``, MD5-hashed, and the first 12 hex
    characters of the digest are returned.
    """
    raw_key = "{}:{}:{}".format(repo, column, split)
    digest = hashlib.md5(raw_key.encode())
    return digest.hexdigest()[:12]
def _load_hf_dataset(repo: str, split: str) -> Dataset:
    """Load a dataset from a local parquet path or via ``load_dataset``.

    When ``repo`` exists on the local filesystem it is read with
    ``Dataset.from_parquet`` and ``split`` is not used. Otherwise ``repo``
    and ``split`` are passed to ``load_dataset``.
    """
    is_local_path = os.path.exists(repo)
    if not is_local_path:
        return load_dataset(repo, split=split)
    return Dataset.from_parquet(repo)
def _detect_response_column(columns: list[str], preferred: str) -> str:
    """Choose which column holds the model responses.

    ``preferred`` wins when it is present in ``columns``. Otherwise the first
    matching name from a fixed fallback list is used. If nothing matches,
    ``preferred`` is returned unchanged (it may not exist in ``columns``).
    """
    candidates = (
        preferred,
        "model_responses",
        "response",
        "responses",
        "output",
        "outputs",
    )
    return next((name for name in candidates if name in columns), preferred)
def _detect_prompt_column(columns: list[str], preferred: str) -> str | None:
if preferred in columns:
return preferred
for fallback in ["formatted_prompt", "prompt", "question", "input"]:
if fallback in columns:
return fallback
return None
def _compute_question_fingerprint(ds: Dataset, n: int = 5) -> str:
    """Fingerprint the question set by hashing the first ``n`` question texts.

    For each of the first ``n`` rows, the first column present among
    ``question``, ``prompt``, ``input`` and ``formatted_prompt`` is stringified
    (falsy values become ``""``) and truncated to 200 characters. Rows with
    none of those columns contribute nothing. The snippets are joined with
    ``||`` and the first 8 hex characters of the MD5 digest are returned.
    """
    question_columns = ("question", "prompt", "input", "formatted_prompt")
    snippets = []
    for index in range(min(n, len(ds))):
        record = ds[index]
        present = next((col for col in question_columns if col in record), None)
        if present is None:
            continue
        snippets.append(str(record[present] or "")[:200])
    joined = "||".join(snippets)
    return hashlib.md5(joined.encode()).hexdigest()[:8]
def _count_samples(ds: Dataset, column: str) -> int:
    """Return how many samples the first row holds in ``column``.

    An empty dataset yields 0. A list-valued cell yields its length; any
    other value counts as a single sample. Only row 0 is inspected.
    """
    if not len(ds):
        return 0
    head_value = ds[0][column]
    return len(head_value) if isinstance(head_value, list) else 1
def _flatten_evals(evals) -> list[bool]:
    """Normalize an ``eval_correct`` cell into a flat list of booleans.

    A non-list value becomes a one-element list. Within a list, a nested
    non-empty list is represented by the truthiness of its last element, an
    empty nested list counts as False, and any other entry is coerced with
    ``bool``.
    """
    if not isinstance(evals, list):
        return [bool(evals)]
    flags = []
    for entry in evals:
        if not isinstance(entry, list):
            flags.append(bool(entry))
        elif len(entry) > 0:
            flags.append(bool(entry[-1]))
        else:
            flags.append(False)
    return flags
def _extract_reasoning(meta: dict | None) -> str | None:
"""Extract reasoning/thinking content from response metadata's raw_response."""
if not meta or not isinstance(meta, dict):
return None
raw = meta.get("raw_response")
if not raw or not isinstance(raw, dict):
return None
try:
msg = raw["choices"][0]["message"]
return (
msg.get("reasoning_content")
or msg.get("thinking")
or msg.get("reasoning")
)
except (KeyError, IndexError, TypeError):
return None
def _merge_reasoning_into_response(response: str, reasoning: str | None) -> str:
"""Prepend <think>{reasoning}</think> to response if reasoning exists
and isn't already present in the response."""
if not reasoning:
return response or ""
response = response or ""
# Don't double-add if response already contains the thinking
if "<think>" in response:
return response
return f"<think>{reasoning}</think>\n{response}"
def _analyze_trace(text: str) -> dict:
    """Split a response into thinking/answer parts and count hesitation markers.

    The text up to and including the first ``</think>`` is the thinking part;
    the stripped remainder is the answer. Without a ``</think>`` tag the whole
    text counts as thinking and the answer is empty.

    Returns a dict with ``total_len``, ``think_len``, ``answer_len``,
    ``backtracks``, ``restarts``, ``think_text`` and ``answer_text``.
    ``backtracks`` and ``restarts`` are case-insensitive substring counts of
    fixed phrase lists over the whole text. Empty input yields all-zero /
    empty values.
    """
    if not text:
        return dict(total_len=0, think_len=0, answer_len=0,
                    backtracks=0, restarts=0, think_text="", answer_text="")

    close_tag = "</think>"
    think_end = text.find(close_tag)
    # str.find signals "absent" with -1; index 0 is a real match (a response
    # that starts with the closing tag) and must still be split.
    if think_end != -1:
        split_at = think_end + len(close_tag)
        # Keep raw tags so display is 1:1 with HuggingFace data
        think_text = text[:split_at]
        answer_text = text[split_at:].strip()
    else:
        think_text = text
        answer_text = ""

    lowered = text.lower()
    backtrack_markers = ["wait,", "wait ", "hmm", "let me try", "try again",
                         "another approach", "let me reconsider"]
    restart_markers = ["start over", "fresh approach", "different approach",
                       "from scratch"]
    backtracks = sum(lowered.count(marker) for marker in backtrack_markers)
    restarts = sum(lowered.count(marker) for marker in restart_markers)

    return dict(total_len=len(text), think_len=len(think_text),
                answer_len=len(answer_text), backtracks=backtracks,
                restarts=restarts, think_text=think_text, answer_text=answer_text)
@bp.route("/load", methods=["POST"])
def load_dataset_endpoint():
    """Load a dataset, cache it in memory and return a JSON description of it.

    Request JSON fields:
        repo: required; a local parquet path or a dataset name.
        split: optional, defaults to ``"train"``.
        column: optional preferred response column, defaults to
            ``"model_responses"``.
        prompt_column: optional preferred prompt column, defaults to
            ``"formatted_prompt"``.

    Responds with 400 and an ``error`` message when ``repo`` is missing or
    not a non-empty string, when loading fails, or when no usable response
    column exists. On success the dataset is stored in ``_cache`` under its
    id and the id, resolved columns, row/sample counts and question
    fingerprint are returned.
    """
    # silent=True yields None for a missing or malformed JSON body; combined
    # with the dict check, a null / list / non-JSON payload gets the normal
    # "repo is required" 400 instead of an unhandled AttributeError.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}

    repo = data.get("repo", "")
    if not isinstance(repo, str) or not repo.strip():
        return jsonify({"error": "repo is required"}), 400
    repo = repo.strip()

    split = data.get("split", "train")
    preferred_column = data.get("column", "model_responses")
    preferred_prompt_column = data.get("prompt_column", "formatted_prompt")

    # Loading can fail in many library-specific ways; this is the request
    # boundary, so any failure is reported back to the client as a 400.
    try:
        ds = _load_hf_dataset(repo, split)
    except Exception as e:
        return jsonify({"error": f"Failed to load dataset: {e}"}), 400

    columns = ds.column_names
    column = _detect_response_column(columns, preferred_column)
    prompt_column = _detect_prompt_column(columns, preferred_prompt_column)
    # _detect_response_column returns the preferred name even when absent
    if column not in columns:
        return jsonify({
            "error": f"Column '{column}' not found. Available: {columns}"
        }), 400

    n_samples = _count_samples(ds, column)
    ds_id = _make_id(repo, column, split)
    fingerprint = _compute_question_fingerprint(ds)
    _cache[ds_id] = {
        "dataset": ds,
        "repo": repo,
        "column": column,
        "prompt_column": prompt_column,
        "split": split,
        "n_rows": len(ds),
        "n_samples": n_samples,
        "question_fingerprint": fingerprint,
    }

    short_name = repo.rsplit("/", 1)[-1] if "/" in repo else repo
    return jsonify({
        "id": ds_id,
        "repo": repo,
        "name": short_name,
        "column": column,
        "prompt_column": prompt_column,
        "columns": columns,
        "split": split,
        "n_rows": len(ds),
        "n_samples": n_samples,
        "question_fingerprint": fingerprint,
    })
@bp.route("/", methods=["GET"])
def list_datasets():
    """Return a JSON array describing every dataset currently in the cache.

    Each element carries the id, repo, short name (the part of ``repo`` after
    its last ``/``), response column, split, row and sample counts, and the
    question fingerprint (``""`` if the cache entry has none).
    """
    summaries = [
        {
            "id": ds_id,
            "repo": entry["repo"],
            # rsplit leaves a slash-free repo untouched, so no guard is needed
            "name": entry["repo"].rsplit("/", 1)[-1],
            "column": entry["column"],
            "split": entry["split"],
            "n_rows": entry["n_rows"],
            "n_samples": entry["n_samples"],
            "question_fingerprint": entry.get("question_fingerprint", ""),
        }
        for ds_id, entry in _cache.items()
    ]
    return jsonify(summaries)
@bp.route("/<ds_id>/question/<int:idx>", methods=["GET"])
def get_question(ds_id, idx):
    """Return one dataset row as JSON: question, prompt, responses and analysis.

    Responds 404 when ``ds_id`` is not cached and 400 when ``idx`` is out of
    range. Otherwise the row's responses (with reasoning merged in from the
    ``{column}__metadata`` column when available) are returned together with
    the question text, prompt text, flattened ``eval_correct`` flags,
    extractions, row metadata and a per-response trace analysis.
    """
    if ds_id not in _cache:
        return jsonify({"error": "Dataset not loaded"}), 404
    entry = _cache[ds_id]
    dataset = entry["dataset"]
    response_col = entry["column"]

    total_rows = len(dataset)
    if not 0 <= idx < total_rows:
        return jsonify({"error": f"Index {idx} out of range (0-{total_rows-1})"}), 400
    record = dataset[idx]

    # Responses are always handled as a list, even for single-sample rows
    raw_responses = record[response_col]
    if not isinstance(raw_responses, list):
        raw_responses = [raw_responses]

    # Optional per-response metadata used to recover reasoning/thinking text
    meta_key = f"{response_col}__metadata"
    metas = None
    if meta_key in record:
        metas = record[meta_key]
        if not isinstance(metas, list):
            metas = [metas]

    responses = []
    for position, raw in enumerate(raw_responses):
        has_meta = bool(metas) and position < len(metas)
        reasoning = _extract_reasoning(metas[position] if has_meta else None)
        responses.append(_merge_reasoning_into_response(raw, reasoning))

    # Prompt text comes from the prompt column chosen at load time
    prompt_text = ""
    prompt_col = entry.get("prompt_column")
    if prompt_col and prompt_col in record:
        prompt_val = record[prompt_col]
        if isinstance(prompt_val, list):
            prompt_text = json.dumps(prompt_val)
        elif isinstance(prompt_val, str):
            prompt_text = prompt_val
        elif prompt_val is not None:
            prompt_text = str(prompt_val)

    # Question text: the first known question-like column present in the row
    question = ""
    question_cols = ("question", "prompt", "input", "problem", "formatted_prompt")
    question_key = next((col for col in question_cols if col in record), None)
    if question_key is not None:
        question_val = record[question_key] or ""
        if isinstance(question_val, list):
            question = json.dumps(question_val)
        elif isinstance(question_val, str):
            question = question_val
        else:
            question = str(question_val)

    if "eval_correct" in record:
        eval_correct = _flatten_evals(record["eval_correct"])
    else:
        eval_correct = []

    # Extractions: prefer the column-aware name, then the generic one
    extractions = []
    extraction_cols = (f"{response_col}__extractions", "response__extractions")
    extraction_key = next((col for col in extraction_cols if col in record), None)
    if extraction_key is not None:
        extraction_val = record[extraction_key]
        if isinstance(extraction_val, list):
            extractions = [str(item) for item in extraction_val]

    row_meta = record["metadata"] if "metadata" in record else None
    metadata = row_meta if isinstance(row_meta, dict) else {}

    cleaned = [resp or "" for resp in responses]
    analyses = [_analyze_trace(resp) for resp in cleaned]

    return jsonify({
        "question": question,
        "prompt_text": prompt_text,
        "responses": cleaned,
        "eval_correct": eval_correct,
        "extractions": extractions,
        "metadata": metadata,
        "analyses": analyses,
        "n_samples": len(responses),
        "index": idx,
    })
@bp.route("/<ds_id>/summary", methods=["GET"])
def get_summary(ds_id):
    """Return accuracy statistics for a cached dataset as JSON.

    Responds 404 when ``ds_id`` is not cached. When the dataset has no
    ``eval_correct`` column only the row/sample counts are returned with
    ``has_eval`` False. Otherwise the response includes:

    - ``sample_accuracy``: number of truthy eval flags over
      ``n_rows * n_samples``.
    - ``pass_at``: for each k in 1, 2, 4, 8 not exceeding ``n_samples``, the
      number and rate of rows with at least one truthy flag among their
      first k.
    """
    if ds_id not in _cache:
        return jsonify({"error": "Dataset not loaded"}), 404
    info = _cache[ds_id]
    ds = info["dataset"]
    n_rows = info["n_rows"]
    n_samples = info["n_samples"]

    if "eval_correct" not in ds.column_names:
        return jsonify({
            "n_rows": n_rows,
            "n_samples": n_samples,
            "has_eval": False,
        })

    # Read only the eval column and flatten each cell once, instead of
    # materializing every full row again for each k and for the total.
    row_flags = [_flatten_evals(cell) for cell in ds["eval_correct"]]

    pass_at = {}
    for k in (1, 2, 4, 8):
        if k > n_samples:
            break
        correct = sum(1 for flags in row_flags if any(flags[:k]))
        pass_at[k] = {"correct": correct, "total": n_rows,
                      "rate": correct / n_rows if n_rows > 0 else 0}

    total_samples = n_rows * n_samples
    total_correct = sum(sum(flags) for flags in row_flags)

    return jsonify({
        "n_rows": n_rows,
        "n_samples": n_samples,
        "has_eval": True,
        "sample_accuracy": {
            "correct": total_correct,
            "total": total_samples,
            "rate": total_correct / total_samples if total_samples > 0 else 0,
        },
        "pass_at": pass_at,
    })
@bp.route("/<ds_id>", methods=["DELETE"])
def unload_dataset(ds_id):
    """Drop a dataset from the in-memory cache.

    Always responds with ``{"status": "ok"}``, whether or not ``ds_id`` was
    cached.
    """
    _cache.pop(ds_id, None)
    return jsonify({"status": "ok"})