Spaces:
Running on Zero
Running on Zero
| """Data/model loading helpers for the GRPO training notebook.""" | |
| from __future__ import annotations | |
| import json | |
| import logging | |
| from pathlib import Path | |
| from typing import Any | |
| logger = logging.getLogger(__name__) | |
def filter_questions_by_difficulty(
    questions: list[dict[str, Any]], allowed: list[str] | None
) -> list[dict[str, Any]]:
    """Keep only the question records whose difficulty label is allowed.

    Labels are compared case-insensitively. A record without a
    ``difficulty`` key is treated as having the empty-string label, and
    non-string labels are compared via ``str()``. When ``allowed`` is None
    or empty, no filtering happens and ``questions`` itself is returned.
    """
    if allowed:
        wanted = {label.lower() for label in allowed}

        def _matches(record: dict[str, Any]) -> bool:
            return str(record.get("difficulty", "")).lower() in wanted

        return list(filter(_matches, questions))
    return questions
def _drop_harness_broken(
    questions: list[dict[str, Any]], questions_path: str, db_dir: str
) -> list[dict[str, Any]]:
    """Remove questions whose gold answer the validator flags as harness-broken.

    ``broken_question_texts(questions_path, db_dir)`` (from the
    validate_questions script) supplies the flagged question texts. Records
    whose ``question_text`` is among them are dropped and a warning with the
    drop count is logged. When nothing is flagged, ``questions`` is returned
    as-is.
    """
    # Import order matters: the installed-package path (sql_env.scripts) comes
    # first because it is the only one that resolves when analyst-buddy is
    # pip-installed; the bare module paths serve pytest (pythonpath=".") and
    # direct script invocation.
    try:
        from sql_env.scripts.validate_questions import broken_question_texts
    except ImportError:  # pragma: no cover - fallback for non-installed contexts
        try:
            from scripts.validate_questions import broken_question_texts
        except ImportError:
            from validate_questions import broken_question_texts  # type: ignore

    flagged = broken_question_texts(questions_path, db_dir)
    if not flagged:
        return questions

    survivors = [
        record for record in questions if record.get("question_text") not in flagged
    ]
    dropped = len(questions) - len(survivors)
    logger.warning(
        "Excluded %d harness-broken question(s) from %s before training "
        "(run scripts/validate_questions.py for details).",
        dropped,
        questions_path,
    )
    return survivors
def load_question_prompts(
    questions_path: str,
    allowed: list[str] | None = None,
    db_dir: str | None = None,
) -> list[dict[str, str]]:
    """Load question text prompts from JSON and apply difficulty filtering.

    The file must contain a non-empty JSON list whose entries are objects.
    Records are kept when their ``difficulty`` matches one of ``allowed``
    (case-insensitive; None/empty keeps everything). Each surviving record
    with a truthy ``question_text`` becomes ``{"prompt": str(question_text)}``.

    When ``db_dir`` is provided, harness-broken questions (broken/empty gold
    answers) are dropped before prompts are built, so they never reach the
    gradient. Leave ``db_dir`` as None to skip the (slightly slower) gold check.

    Raises
    ------
    FileNotFoundError
        If ``questions_path`` does not exist.
    ValueError
        If the file is not valid JSON, is not a non-empty list, contains a
        non-object entry, or if no usable prompt survives filtering.
    """
    path = Path(questions_path)
    if not path.exists():
        raise FileNotFoundError(f"Questions file not found: {questions_path}")
    try:
        payload = json.loads(path.read_text(encoding="utf-8"))
    except json.JSONDecodeError as exc:
        raise ValueError(f"Invalid JSON in questions file: {questions_path}") from exc
    if not isinstance(payload, list) or not payload:
        raise ValueError(f"Questions file is empty or invalid: {questions_path}")
    # Every later step calls ``.get`` on each record. Reject non-object entries
    # here with a clear ValueError instead of an opaque AttributeError later.
    for index, item in enumerate(payload):
        if not isinstance(item, dict):
            raise ValueError(
                f"Question entry {index} in {questions_path} is not a JSON object"
            )
    filtered = filter_questions_by_difficulty(payload, allowed)
    if not filtered:
        raise ValueError(
            f"No questions match difficulty_filter={allowed} in {questions_path}"
        )
    if db_dir is not None:
        filtered = _drop_harness_broken(filtered, questions_path, db_dir)
        if not filtered:
            raise ValueError(
                f"All questions in {questions_path} are harness-broken "
                f"for db_dir={db_dir}"
            )
    prompts = [
        {"prompt": str(item["question_text"])}
        for item in filtered
        if item.get("question_text")
    ]
    if not prompts:
        raise ValueError(f"No usable question_text values found in {questions_path}")
    return prompts
def validate_no_data_leak(
    train_path: str,
    eval_path: str,
) -> None:
    """Assert zero question overlap between train and eval sets.

    Both files must hold a JSON list. Entries that are objects with a
    ``question_text`` key contribute that value. Other entries are ignored.
    Overlap is detected by exact equality of ``question_text`` values.

    Raises
    ------
    ValueError
        If any question text appears in both files, or if either file is not
        a JSON list (which would otherwise let the check pass vacuously).
    """

    def _question_texts(path: str) -> set[Any]:
        # Read one file and collect its question texts.
        payload = json.loads(Path(path).read_text(encoding="utf-8"))
        if not isinstance(payload, list):
            # A dict payload would iterate over its keys and silently yield
            # no texts, turning this guardrail into a no-op.
            raise ValueError(f"Questions file must contain a JSON list: {path}")
        return {
            item["question_text"]
            for item in payload
            if isinstance(item, dict) and "question_text" in item
        }

    overlap = _question_texts(train_path) & _question_texts(eval_path)
    if overlap:
        # Sort so the reported examples are stable across runs: set iteration
        # order of strings varies with hash randomization.
        examples = sorted(overlap, key=str)[:3]
        raise ValueError(
            f"Data leak: {len(overlap)} questions appear in both train and eval. "
            f"Examples: {examples}"
        )
def load_model_and_tokenizer(
    model_name: str,
    *,
    revision: str | None = None,
    torch_dtype: Any | None = None,
) -> tuple[Any, Any]:
    """Load a HuggingFace causal-LM and its tokenizer, failing fast on error.

    ``transformers`` is imported inside the function rather than at module
    level. This keeps the data/guardrail helpers in this module importable
    without the heavy training extras installed.

    ``revision`` (a Hub commit SHA or tag) is passed to both loaders to pin
    the exact weights. Without it, ``model_name`` resolves to whatever is HEAD
    on the Hub at download time, so runs weeks apart may load different
    weights. ``torch_dtype`` (e.g. "bfloat16") is forwarded to the model
    loader only, when set.

    Returns ``(model, tokenizer)``. Any loading failure is re-raised as
    ``RuntimeError`` chained to the original exception.
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # The revision pin is the only option both loaders share.
    shared: dict[str, Any] = {} if revision is None else {"revision": revision}
    tokenizer_kwargs: dict[str, Any] = dict(shared)
    model_kwargs: dict[str, Any] = dict(shared)
    if torch_dtype is not None:
        # Tokenizers have no dtype. Passing one there is at best ignored and
        # at worst an error, so it goes to the model loader only.
        model_kwargs["torch_dtype"] = torch_dtype

    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name, **tokenizer_kwargs)
        model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
    except Exception as exc:  # pragma: no cover - covered by monkeypatched tests
        raise RuntimeError(f"Cannot load model '{model_name}': {exc}") from exc
    return model, tokenizer