"""Data/model loading helpers for the GRPO training notebook."""

from __future__ import annotations

import json
import logging
from pathlib import Path
from typing import Any

logger = logging.getLogger(__name__)


def filter_questions_by_difficulty(
    questions: list[dict[str, Any]], allowed: list[str] | str | None
) -> list[dict[str, Any]]:
    """Filter question records by case-insensitive difficulty labels.

    Parameters
    ----------
    questions:
        Question records. Each record's ``"difficulty"`` value is compared as
        ``str(value).lower()``; a missing key compares as ``""``.
    allowed:
        Difficulty labels to keep. ``None`` or an empty value disables
        filtering and returns ``questions`` itself (not a copy). A single
        label string is treated as a one-element list: iterating a bare string
        would otherwise yield its characters and silently match nothing.

    Returns
    -------
    list
        The records whose difficulty label is in ``allowed``, in input order.
    """
    if not allowed:
        return questions
    if isinstance(allowed, str):
        allowed = [allowed]
    allowed_set = {level.lower() for level in allowed}
    return [
        question
        for question in questions
        if str(question.get("difficulty", "")).lower() in allowed_set
    ]


def _drop_harness_broken(
    questions: list[dict[str, Any]], questions_path: str, db_dir: str
) -> list[dict[str, Any]]:
    """Drop questions whose gold answer is harness-broken (see validator).

    ``broken_question_texts(questions_path, db_dir)`` supplies the question
    texts to exclude; records are matched on their ``"question_text"`` value.
    A warning with the excluded count is logged when anything is dropped.
    """
    # The installed-package form (sql_env.scripts) is tried first — it is the
    # only one that resolves when analyst-buddy is pip-installed. The bare forms
    # cover pytest (pythonpath=".") and direct script invocation.
    try:
        from sql_env.scripts.validate_questions import broken_question_texts
    except ImportError:  # pragma: no cover - fallback for non-installed contexts
        try:
            from scripts.validate_questions import broken_question_texts
        except ImportError:
            from validate_questions import broken_question_texts  # type: ignore

    broken = broken_question_texts(questions_path, db_dir)
    if not broken:
        return questions

    kept = [q for q in questions if q.get("question_text") not in broken]
    logger.warning(
        "Excluded %d harness-broken question(s) from %s before training "
        "(run scripts/validate_questions.py for details).",
        len(questions) - len(kept),
        questions_path,
    )
    return kept


def load_question_prompts(
    questions_path: str,
    allowed: list[str] | str | None = None,
    db_dir: str | None = None,
) -> list[dict[str, str]]:
    """Load question text prompts from JSON and apply difficulty filtering.

    When ``db_dir`` is provided, harness-broken questions (broken/empty gold
    answers) are dropped before prompts are built, so they never reach the
    gradient. Leave ``db_dir`` as None to skip the (slightly slower) gold check.

    Returns
    -------
    list of dict
        One ``{"prompt": <question text>}`` per surviving record. Records with
        a missing or empty ``question_text`` are skipped.

    Raises
    ------
    FileNotFoundError
        If ``questions_path`` does not exist.
    ValueError
        If the file is not valid JSON, is not a non-empty list of JSON objects,
        no record matches ``allowed``, every match is harness-broken, or no
        usable ``question_text`` remains.
    """
    path = Path(questions_path)
    if not path.exists():
        raise FileNotFoundError(f"Questions file not found: {questions_path}")

    try:
        payload = json.loads(path.read_text(encoding="utf-8"))
    except json.JSONDecodeError as exc:
        raise ValueError(f"Invalid JSON in questions file: {questions_path}") from exc

    if not isinstance(payload, list) or not payload:
        raise ValueError(f"Questions file is empty or invalid: {questions_path}")
    # Records are read with ``.get`` below; a stray string/number entry would
    # otherwise surface as an opaque AttributeError instead of a clear error.
    if not all(isinstance(item, dict) for item in payload):
        raise ValueError(
            f"Questions file contains non-object entries: {questions_path}"
        )

    filtered = filter_questions_by_difficulty(payload, allowed)
    if not filtered:
        raise ValueError(
            f"No questions match difficulty_filter={allowed} in {questions_path}"
        )

    if db_dir is not None:
        filtered = _drop_harness_broken(filtered, questions_path, db_dir)
        if not filtered:
            raise ValueError(
                f"All questions in {questions_path} are harness-broken "
                f"for db_dir={db_dir}"
            )

    prompts = [
        {"prompt": str(item["question_text"])}
        for item in filtered
        if item.get("question_text")
    ]
    if not prompts:
        raise ValueError(f"No usable question_text values found in {questions_path}")
    return prompts


def _question_texts(path: str) -> set[str]:
    """Return the non-empty question texts of a JSON question file as strings.

    Uses the same rule as ``load_question_prompts``: only object records with a
    truthy ``question_text`` count, normalized with ``str()``. Missing/``None``/
    empty texts are ignored so they cannot register as spurious overlaps.

    Raises
    ------
    ValueError
        If the top-level JSON value is not a list. Iterating a JSON object would
        walk its keys and silently defeat the leak check.
    """
    payload = json.loads(Path(path).read_text(encoding="utf-8"))
    if not isinstance(payload, list):
        raise ValueError(f"Questions file must contain a JSON list: {path}")
    return {
        str(record["question_text"])
        for record in payload
        if isinstance(record, dict) and record.get("question_text")
    }


def validate_no_data_leak(
    train_path: str,
    eval_path: str,
) -> None:
    """Check that no question text appears in both train and eval sets.

    Raises
    ------
    FileNotFoundError
        If either file does not exist.
    ValueError
        If any question text appears in both files, or a file is not valid
        JSON / not a JSON list.
    """
    overlap = _question_texts(train_path) & _question_texts(eval_path)
    if overlap:
        # Sorted so the reported examples are stable across runs; set iteration
        # order for strings varies with hash randomization.
        examples = sorted(overlap)[:3]
        raise ValueError(
            f"Data leak: {len(overlap)} questions appear in both train and eval. "
            f"Examples: {examples}"
        )


def load_model_and_tokenizer(
    model_name: str,
    *,
    revision: str | None = None,
    torch_dtype: Any | None = None,
) -> tuple[Any, Any]:
    """Load HuggingFace tokenizer and model with fail-fast errors.

    ``transformers`` is imported lazily here (not at module top) so the data /
    guardrail helpers in this module stay importable without the heavy
    training extras installed.

    Pass ``revision`` (a Hub commit SHA or tag) to pin the exact weights for
    reproducibility — without it, ``model_name`` resolves to whatever is HEAD
    on the Hub at download time, so two runs weeks apart can load different
    weights. ``torch_dtype`` (e.g. "bfloat16") is forwarded when set.

    Returns
    -------
    tuple
        ``(model, tokenizer)`` — note the model comes first.

    Raises
    ------
    RuntimeError
        If either ``from_pretrained`` call fails; the original error is chained.
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer_kwargs: dict[str, Any] = {}
    model_kwargs: dict[str, Any] = {}
    if revision is not None:
        tokenizer_kwargs["revision"] = revision
        model_kwargs["revision"] = revision
    if torch_dtype is not None:
        # dtype applies to the model only — the tokenizer has no dtype, and passing
        # it there is at best ignored, at worst an error on some tokenizers.
        model_kwargs["torch_dtype"] = torch_dtype

    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name, **tokenizer_kwargs)
        model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
    except Exception as exc:  # pragma: no cover - covered by monkeypatched tests
        raise RuntimeError(f"Cannot load model '{model_name}': {exc}") from exc
    return model, tokenizer