analyst-buddy / training /data_loading.py
hjerpe's picture
F006/F008: serve Qwen models + model switcher (vanilla-first)
656f91e verified
Raw
History Blame Contribute Delete
5.92 kB
"""Data/model loading helpers for the GRPO training notebook."""
from __future__ import annotations
import json
import logging
from pathlib import Path
from typing import Any
# Module-level logger, named after this module so handlers and levels can be
# configured per module by the application.
logger = logging.getLogger(__name__)
def filter_questions_by_difficulty(
    questions: list[dict[str, Any]], allowed: list[str] | None
) -> list[dict[str, Any]]:
    """Keep only the question records whose difficulty is in ``allowed``.

    Matching is case-insensitive on both sides. A record's ``"difficulty"``
    value is stringified before comparison; a record without that key is
    compared as the empty string.

    Parameters
    ----------
    questions
        Question records (dicts) to filter.
    allowed
        Difficulty labels to keep. ``None`` or an empty list disables
        filtering.

    Returns
    -------
    list[dict[str, Any]]
        The matching records in their original order. When filtering is
        disabled, the input list object itself is returned (not a copy).
    """
    if not allowed:
        # No filter requested: hand back the caller's list untouched.
        return questions

    wanted = {label.lower() for label in allowed}

    selected: list[dict[str, Any]] = []
    for record in questions:
        label = str(record.get("difficulty", "")).lower()
        if label in wanted:
            selected.append(record)
    return selected
def _drop_harness_broken(
    questions: list[dict[str, Any]], questions_path: str, db_dir: str
) -> list[dict[str, Any]]:
    """Drop questions whose gold answer is harness-broken (see validator).

    Parameters
    ----------
    questions
        Candidate question records; each is matched by its
        ``"question_text"`` value.
    questions_path
        Path of the questions file, forwarded to the validator and used in
        the log message.
    db_dir
        Database directory, forwarded to the validator.

    Returns
    -------
    list[dict[str, Any]]
        ``questions`` itself when the validator reports nothing broken,
        otherwise a new list without the records whose ``"question_text"``
        is in the validator's result.

    Notes
    -----
    ``broken_question_texts`` is assumed to return a collection of question
    texts that supports ``in`` membership tests — confirm against the
    validator.
    """
    # The installed-package form (sql_env.scripts) is tried first — it is the
    # only one that resolves when analyst-buddy is pip-installed. The bare forms
    # cover pytest (pythonpath=".") and direct script invocation.
    try:
        from sql_env.scripts.validate_questions import broken_question_texts
    except ImportError:  # pragma: no cover - fallback for non-installed contexts
        try:
            from scripts.validate_questions import broken_question_texts
        except ImportError:
            from validate_questions import broken_question_texts  # type: ignore

    broken = broken_question_texts(questions_path, db_dir)
    if not broken:
        return questions

    kept = [q for q in questions if q.get("question_text") not in broken]

    # The validator is handed the file path, not ``questions``, so its result
    # may name entries that are not in ``questions`` (e.g. already filtered
    # out by the caller). Only warn when something was actually removed, so
    # the log never reports "Excluded 0 ...".
    dropped = len(questions) - len(kept)
    if dropped:
        logger.warning(
            "Excluded %d harness-broken question(s) from %s before training "
            "(run scripts/validate_questions.py for details).",
            dropped,
            questions_path,
        )
    return kept
def load_question_prompts(
    questions_path: str,
    allowed: list[str] | None = None,
    db_dir: str | None = None,
) -> list[dict[str, str]]:
    """Read a questions JSON file and turn it into a list of prompt dicts.

    The file must contain a non-empty JSON list of question records. Records
    are first filtered by difficulty (``allowed``). When ``db_dir`` is
    provided, harness-broken questions (broken/empty gold answers) are then
    dropped so they never reach the gradient. Leave ``db_dir`` as None to
    skip the (slightly slower) gold check.

    Parameters
    ----------
    questions_path
        Path to the UTF-8 encoded JSON questions file.
    allowed
        Difficulty labels to keep; ``None`` keeps every difficulty.
    db_dir
        Database directory used for the harness-broken check, or ``None`` to
        skip that check.

    Returns
    -------
    list[dict[str, str]]
        One ``{"prompt": <question text>}`` dict per surviving record that
        has a truthy ``"question_text"``.

    Raises
    ------
    FileNotFoundError
        If ``questions_path`` does not exist.
    ValueError
        If the file is not valid JSON, is not a non-empty list, no record
        survives the difficulty filter or the harness-broken check, or no
        surviving record has a usable ``"question_text"``.
    """
    source = Path(questions_path)
    if not source.exists():
        raise FileNotFoundError(f"Questions file not found: {questions_path}")

    try:
        records = json.loads(source.read_text(encoding="utf-8"))
    except json.JSONDecodeError as exc:
        raise ValueError(f"Invalid JSON in questions file: {questions_path}") from exc

    if not (isinstance(records, list) and records):
        raise ValueError(f"Questions file is empty or invalid: {questions_path}")

    candidates = filter_questions_by_difficulty(records, allowed)
    if not candidates:
        raise ValueError(
            f"No questions match difficulty_filter={allowed} in {questions_path}"
        )

    if db_dir is not None:
        candidates = _drop_harness_broken(candidates, questions_path, db_dir)
        # Candidates were non-empty before the check, so an empty result here
        # can only mean the harness-broken check removed all of them.
        if not candidates:
            raise ValueError(
                f"All questions in {questions_path} are harness-broken "
                f"for db_dir={db_dir}"
            )

    prompts: list[dict[str, str]] = []
    for record in candidates:
        text = record.get("question_text")
        if text:
            # Missing or empty texts are skipped rather than emitted as prompts.
            prompts.append({"prompt": str(text)})

    if not prompts:
        raise ValueError(f"No usable question_text values found in {questions_path}")
    return prompts
def _read_question_texts(path: str) -> set[Any]:
    """Return the set of ``question_text`` values in the JSON file at *path*."""
    records = json.loads(Path(path).read_text(encoding="utf-8"))
    # Records without a "question_text" key are ignored.
    return {
        record["question_text"] for record in records if "question_text" in record
    }


def validate_no_data_leak(
    train_path: str,
    eval_path: str,
) -> None:
    """Assert zero question overlap between train and eval sets.

    Both files are read as UTF-8 JSON and compared on the exact
    ``question_text`` value of each record.

    Parameters
    ----------
    train_path
        Path to the training questions JSON file.
    eval_path
        Path to the evaluation questions JSON file.

    Raises
    ------
    ValueError
        If any question text appears in both files. The message reports the
        overlap count and up to three example texts.
    """
    overlap = _read_question_texts(train_path) & _read_question_texts(eval_path)
    if overlap:
        # Set iteration order is not stable across runs, so sort before
        # sampling; the same leak then always yields the same message.
        examples = sorted(overlap, key=str)[:3]
        raise ValueError(
            f"Data leak: {len(overlap)} questions appear in both train and eval. "
            f"Examples: {examples}"
        )
def load_model_and_tokenizer(
    model_name: str,
    *,
    revision: str | None = None,
    torch_dtype: Any | None = None,
) -> tuple[Any, Any]:
    """Load a HuggingFace causal-LM and its tokenizer, failing fast on error.

    ``transformers`` is imported lazily here (not at module top) so the data /
    guardrail helpers in this module stay importable without the heavy training
    extras installed.

    Pass ``revision`` (a Hub commit SHA or tag) to pin the exact weights for
    reproducibility — without it, ``model_name`` resolves to whatever is HEAD on
    the Hub at download time, so two runs weeks apart can load different weights.
    ``torch_dtype`` (e.g. "bfloat16") is forwarded when set.

    Parameters
    ----------
    model_name
        Model identifier passed to both ``from_pretrained`` calls.
    revision
        Optional revision forwarded to the tokenizer and the model.
    torch_dtype
        Optional dtype forwarded to the model only.

    Returns
    -------
    tuple[Any, Any]
        ``(model, tokenizer)`` — note the model comes first.

    Raises
    ------
    RuntimeError
        If loading the tokenizer or the model raises any exception; the
        original exception is chained as the cause.
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # The revision pin is shared by both loaders.
    tokenizer_kwargs: dict[str, Any] = (
        {} if revision is None else {"revision": revision}
    )
    model_kwargs: dict[str, Any] = dict(tokenizer_kwargs)
    if torch_dtype is not None:
        # dtype applies to the model only — the tokenizer has no dtype, and passing
        # it there is at best ignored, at worst an error on some tokenizers.
        model_kwargs["torch_dtype"] = torch_dtype

    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name, **tokenizer_kwargs)
        model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
    except Exception as exc:  # pragma: no cover - covered by monkeypatched tests
        raise RuntimeError(f"Cannot load model '{model_name}': {exc}") from exc
    return model, tokenizer