| | import json |
| | import os |
| | import hashlib |
| | from flask import Blueprint, request, jsonify |
| | from datasets import load_dataset, Dataset |
| |
|
# Blueprint grouping the dataset endpoints; all routes below are mounted
# under the /api/model/datasets prefix.
bp = Blueprint("model_datasets", __name__, url_prefix="/api/model/datasets")


# Module-level registry of loaded datasets, keyed by the id produced by
# _make_id. Each entry is written by the /load endpoint and holds the dataset
# object together with its repo, column, split and size information.
_cache: dict[str, dict] = {}
| |
|
| |
|
def _make_id(repo: str, column: str, split: str) -> str:
    """Return a short, deterministic id for a (repo, column, split) triple.

    The id is the first 12 hex characters of the MD5 digest of the string
    ``"<repo>:<column>:<split>"``. It is used as a cache key only.
    """
    joined = "{}:{}:{}".format(repo, column, split)
    digest = hashlib.md5(joined.encode())
    return digest.hexdigest()[:12]
| |
|
| |
|
def _load_hf_dataset(repo: str, split: str) -> Dataset:
    """Load a dataset either from a local parquet path or via ``load_dataset``.

    When ``repo`` names an existing filesystem path it is read with
    ``Dataset.from_parquet`` and ``split`` is not used. Otherwise ``repo`` and
    ``split`` are forwarded to ``load_dataset``.
    """
    is_local_path = os.path.exists(repo)
    if not is_local_path:
        return load_dataset(repo, split=split)
    return Dataset.from_parquet(repo)
| |
|
| |
|
def _detect_response_column(columns: list[str], preferred: str) -> str:
    """Pick the column that holds model responses.

    ``preferred`` wins when it is present in ``columns``; otherwise the first
    well-known response column name found in ``columns`` is used. If nothing
    matches, ``preferred`` is returned unchanged so the caller can report it
    as missing.
    """
    known_names = ("model_responses", "response", "responses", "output", "outputs")
    for candidate in (preferred, *known_names):
        if candidate in columns:
            return candidate
    return preferred
| |
|
| |
|
| | def _detect_prompt_column(columns: list[str], preferred: str) -> str | None: |
| | if preferred in columns: |
| | return preferred |
| | for fallback in ["formatted_prompt", "prompt", "question", "input"]: |
| | if fallback in columns: |
| | return fallback |
| | return None |
| |
|
| |
|
def _compute_question_fingerprint(ds: Dataset, n: int = 5) -> str:
    """Fingerprint a dataset's question set from its first ``n`` rows.

    For each of the first ``n`` rows, the first question-like column present
    (``question``, ``prompt``, ``input``, ``formatted_prompt``) contributes up
    to 200 characters of its stringified value; rows with none of those
    columns contribute nothing. The snippets are joined with ``"||"`` and the
    first 8 hex characters of the MD5 digest are returned.
    """
    question_columns = ("question", "prompt", "input", "formatted_prompt")
    snippets = []
    row_limit = min(n, len(ds))
    for idx in range(row_limit):
        record = ds[idx]
        found = next((name for name in question_columns if name in record), None)
        if found is None:
            continue
        snippets.append(str(record[found] or "")[:200])
    digest = hashlib.md5("||".join(snippets).encode())
    return digest.hexdigest()[:8]
| |
|
| |
|
def _count_samples(ds: Dataset, column: str) -> int:
    """Return the number of samples per row, judged from the first row only.

    An empty dataset yields 0. If the first row's ``column`` value is a list,
    its length is the sample count; any other value counts as one sample.
    """
    if not len(ds):
        return 0
    head_value = ds[0][column]
    return len(head_value) if isinstance(head_value, list) else 1
| |
|
| |
|
def _flatten_evals(evals) -> list[bool]:
    """Normalize an ``eval_correct`` cell into a flat list of booleans.

    A non-list value becomes a single-element list. Within a list, a nested
    non-empty list is represented by the truthiness of its last element, a
    nested empty list by ``False``, and any other entry by its truthiness.
    """
    if not isinstance(evals, list):
        return [bool(evals)]

    flags = []
    for entry in evals:
        if not isinstance(entry, list):
            flags.append(bool(entry))
        elif entry:
            flags.append(bool(entry[-1]))
        else:
            flags.append(False)
    return flags
| |
|
| |
|
| | def _extract_reasoning(meta: dict | None) -> str | None: |
| | """Extract reasoning/thinking content from response metadata's raw_response.""" |
| | if not meta or not isinstance(meta, dict): |
| | return None |
| | raw = meta.get("raw_response") |
| | if not raw or not isinstance(raw, dict): |
| | return None |
| | try: |
| | msg = raw["choices"][0]["message"] |
| | return ( |
| | msg.get("reasoning_content") |
| | or msg.get("thinking") |
| | or msg.get("reasoning") |
| | ) |
| | except (KeyError, IndexError, TypeError): |
| | return None |
| |
|
| |
|
| | def _merge_reasoning_into_response(response: str, reasoning: str | None) -> str: |
| | """Prepend <think>{reasoning}</think> to response if reasoning exists |
| | and isn't already present in the response.""" |
| | if not reasoning: |
| | return response or "" |
| | response = response or "" |
| | |
| | if "<think>" in response: |
| | return response |
| | return f"<think>{reasoning}</think>\n{response}" |
| |
|
| |
|
def _analyze_trace(text: str) -> dict:
    """Split a response into thinking/answer parts and count hesitation markers.

    The text up to and including the first ``</think>`` tag is the thinking
    part; the stripped remainder is the answer. If there is no closing tag,
    the whole text is treated as thinking and the answer is empty.

    Returns a dict with ``total_len``, ``think_len``, ``answer_len``,
    ``backtracks``, ``restarts``, ``think_text`` and ``answer_text``. The
    marker counts are case-insensitive substring counts over the full text.
    Empty or ``None`` input yields zeros and empty strings.
    """
    if not text:
        return dict(total_len=0, think_len=0, answer_len=0,
                    backtracks=0, restarts=0, think_text="", answer_text="")

    close_tag = "</think>"
    think_end = text.find(close_tag)
    # str.find returns -1 when absent; index 0 is a valid match (a response
    # that begins with the closing tag) and must still be split.
    if think_end != -1:
        split_at = think_end + len(close_tag)
        think_text = text[:split_at]
        answer_text = text[split_at:].strip()
    else:
        think_text = text
        answer_text = ""

    t = text.lower()
    backtracks = sum(t.count(w) for w in
                     ["wait,", "wait ", "hmm", "let me try", "try again",
                      "another approach", "let me reconsider"])
    restarts = sum(t.count(w) for w in
                   ["start over", "fresh approach", "different approach", "from scratch"])
    return dict(total_len=len(text), think_len=len(think_text),
                answer_len=len(answer_text), backtracks=backtracks,
                restarts=restarts, think_text=think_text, answer_text=answer_text)
| |
|
| |
|
@bp.route("/load", methods=["POST"])
def load_dataset_endpoint():
    """Load a dataset into the in-memory cache and return its description.

    Expects a JSON object body with:
      - ``repo`` (required): value forwarded to ``_load_hf_dataset``.
      - ``split`` (optional, default ``"train"``).
      - ``column`` (optional, default ``"model_responses"``): preferred
        response column.
      - ``prompt_column`` (optional, default ``"formatted_prompt"``):
        preferred prompt column.

    Responds with 400 and an ``error`` message when the body is not a JSON
    object, ``repo`` is missing/blank/not a string, loading fails, or no
    response column can be found. On success the dataset is stored in
    ``_cache`` and its id, columns, sizes and fingerprint are returned.
    """
    # silent=True yields None for a missing/invalid JSON body instead of
    # aborting, so every malformed request gets the same JSON error shape.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    repo = data.get("repo")
    if not isinstance(repo, str) or not repo.strip():
        return jsonify({"error": "repo is required"}), 400
    repo = repo.strip()

    # "or" rather than a .get default so an explicit null/empty value also
    # falls back to the default.
    split = data.get("split") or "train"
    preferred_column = data.get("column") or "model_responses"
    preferred_prompt_column = data.get("prompt_column") or "formatted_prompt"

    # Any loader failure is reported to the client as a 400 with the
    # loader's message, so the catch is deliberately broad.
    try:
        ds = _load_hf_dataset(repo, split)
    except Exception as e:
        return jsonify({"error": f"Failed to load dataset: {e}"}), 400

    columns = ds.column_names
    column = _detect_response_column(columns, preferred_column)
    prompt_column = _detect_prompt_column(columns, preferred_prompt_column)

    # _detect_response_column echoes the preferred name back when nothing
    # matched, so membership has to be re-checked here.
    if column not in columns:
        return jsonify({
            "error": f"Column '{column}' not found. Available: {columns}"
        }), 400

    n_rows = len(ds)
    n_samples = _count_samples(ds, column)
    ds_id = _make_id(repo, column, split)
    fingerprint = _compute_question_fingerprint(ds)

    _cache[ds_id] = {
        "dataset": ds,
        "repo": repo,
        "column": column,
        "prompt_column": prompt_column,
        "split": split,
        "n_rows": n_rows,
        "n_samples": n_samples,
        "question_fingerprint": fingerprint,
    }

    # Last path component of the repo; rsplit returns the whole string when
    # there is no "/".
    short_name = repo.rsplit("/", 1)[-1]

    return jsonify({
        "id": ds_id,
        "repo": repo,
        "name": short_name,
        "column": column,
        "prompt_column": prompt_column,
        "columns": columns,
        "split": split,
        "n_rows": n_rows,
        "n_samples": n_samples,
        "question_fingerprint": fingerprint,
    })
| |
|
| |
|
@bp.route("/", methods=["GET"])
def list_datasets():
    """Return a JSON array describing every dataset currently in the cache."""

    def _describe(ds_id, info):
        # The display name is the last path component of the repo.
        repo = info["repo"]
        return {
            "id": ds_id,
            "repo": repo,
            "name": repo.rsplit("/", 1)[-1],
            "column": info["column"],
            "split": info["split"],
            "n_rows": info["n_rows"],
            "n_samples": info["n_samples"],
            "question_fingerprint": info.get("question_fingerprint", ""),
        }

    return jsonify([_describe(ds_id, info) for ds_id, info in _cache.items()])
| |
|
| |
|
@bp.route("/<ds_id>/question/<int:idx>", methods=["GET"])
def get_question(ds_id, idx):
    """Return row ``idx`` of a cached dataset as JSON.

    The payload contains the question text, the prompt text, the responses
    (with reasoning merged in where metadata provides it), flattened
    ``eval_correct`` flags, extractions, row metadata and a per-response
    trace analysis.

    Responds 404 if ``ds_id`` is not in the cache and 400 if ``idx`` is out
    of range.
    """
    if ds_id not in _cache:
        return jsonify({"error": "Dataset not loaded"}), 404

    info = _cache[ds_id]
    ds = info["dataset"]
    column = info["column"]

    if idx < 0 or idx >= len(ds):
        return jsonify({"error": f"Index {idx} out of range (0-{len(ds)-1})"}), 400

    row = ds[idx]
    # Treat a single response the same as a one-element list of responses.
    responses_raw = row[column]
    if not isinstance(responses_raw, list):
        responses_raw = [responses_raw]

    # Optional per-response metadata lives in "<column>__metadata".
    meta_column = f"{column}__metadata"
    response_metas = None
    if meta_column in row:
        response_metas = row[meta_column]
        if not isinstance(response_metas, list):
            response_metas = [response_metas]

    # Pair each response with its metadata by position; responses without a
    # matching metadata entry are merged with no reasoning.
    merged_responses = []
    for i, resp in enumerate(responses_raw):
        meta = response_metas[i] if response_metas and i < len(response_metas) else None
        reasoning = _extract_reasoning(meta)
        merged_responses.append(_merge_reasoning_into_response(resp, reasoning))
    responses_raw = merged_responses

    # Prompt text comes from the prompt column chosen at load time; lists are
    # JSON-encoded, other non-string values are stringified, None gives "".
    prompt_text = ""
    prompt_col = info.get("prompt_column")
    if prompt_col and prompt_col in row:
        val = row[prompt_col]
        if isinstance(val, str):
            prompt_text = val
        elif isinstance(val, list):
            prompt_text = json.dumps(val)
        elif val is not None:
            prompt_text = str(val)

    # The question is taken from the first of these columns present in the
    # row; any falsy value is replaced by "" before conversion.
    question = ""
    for qcol in ["question", "prompt", "input", "problem", "formatted_prompt"]:
        if qcol in row:
            val = row[qcol] or ""
            if isinstance(val, str):
                question = val
            elif isinstance(val, list):
                question = json.dumps(val)
            else:
                question = str(val)
            break

    eval_correct = []
    if "eval_correct" in row:
        eval_correct = _flatten_evals(row["eval_correct"])

    # Extractions: prefer "<column>__extractions", then the generic
    # "response__extractions" column.
    # NOTE(review): as written here the search stops at the first candidate
    # column present, even when its value is not a list — confirm that the
    # fallback should not be tried in that case.
    extractions = []
    extractions_col = f"{column}__extractions"
    for ecol in [extractions_col, "response__extractions"]:
        if ecol in row:
            ext = row[ecol]
            if isinstance(ext, list):
                extractions = [str(e) for e in ext]
            break

    # Only dict-valued metadata is passed through; anything else becomes {}.
    metadata = {}
    if "metadata" in row:
        metadata = row["metadata"] if isinstance(row["metadata"], dict) else {}

    analyses = [_analyze_trace(r or "") for r in responses_raw]

    return jsonify({
        "question": question,
        "prompt_text": prompt_text,
        "responses": [r or "" for r in responses_raw],
        "eval_correct": eval_correct,
        "extractions": extractions,
        "metadata": metadata,
        "analyses": analyses,
        "n_samples": len(responses_raw),
        "index": idx,
    })
| |
|
| |
|
@bp.route("/<ds_id>/summary", methods=["GET"])
def get_summary(ds_id):
    """Return accuracy statistics for a cached dataset as JSON.

    Responds 404 if ``ds_id`` is not in the cache. When the dataset has no
    ``eval_correct`` column only the sizes are returned with
    ``has_eval: false``. Otherwise the payload adds:

      - ``sample_accuracy``: number of correct samples over
        ``n_rows * n_samples``.
      - ``pass_at``: for each k in 1, 2, 4, 8 not exceeding ``n_samples``,
        the number and rate of rows where any of the first k samples is
        correct.
    """
    if ds_id not in _cache:
        return jsonify({"error": "Dataset not loaded"}), 404

    info = _cache[ds_id]
    ds = info["dataset"]
    n_rows = info["n_rows"]
    n_samples = info["n_samples"]

    if "eval_correct" not in ds.column_names:
        return jsonify({
            "n_rows": n_rows,
            "n_samples": n_samples,
            "has_eval": False,
        })

    # Read and flatten the eval column once. Indexing ds[i] per row would
    # decode every column of every row, and would do so again for each k.
    row_evals = [_flatten_evals(cell) for cell in ds["eval_correct"]]

    pass_at = {}
    for k in (1, 2, 4, 8):
        if k > n_samples:
            break
        correct = sum(1 for flags in row_evals if any(flags[:k]))
        pass_at[k] = {"correct": correct, "total": n_rows,
                      "rate": correct / n_rows if n_rows > 0 else 0}

    total_samples = n_rows * n_samples
    total_correct = sum(sum(flags) for flags in row_evals)

    return jsonify({
        "n_rows": n_rows,
        "n_samples": n_samples,
        "has_eval": True,
        "sample_accuracy": {
            "correct": total_correct,
            "total": total_samples,
            "rate": total_correct / total_samples if total_samples > 0 else 0,
        },
        "pass_at": pass_at,
    })
| |
|
| |
|
@bp.route("/<ds_id>", methods=["DELETE"])
def unload_dataset(ds_id):
    """Remove a dataset from the cache; succeeds even if it was not loaded."""
    _cache.pop(ds_id, None)
    return jsonify({"status": "ok"})
| |
|