Spaces:
Running on Zero
Running on Zero
File size: 5,917 Bytes
656f91e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 | """Data/model loading helpers for the GRPO training notebook."""
from __future__ import annotations
import json
import logging
from pathlib import Path
from typing import Any
# Module-level logger; _drop_harness_broken uses it to report questions it excludes.
logger = logging.getLogger(__name__)
def filter_questions_by_difficulty(
    questions: list[dict[str, Any]], allowed: list[str] | None
) -> list[dict[str, Any]]:
    """Return the questions whose ``difficulty`` label is in ``allowed``.

    Labels are compared case-insensitively. A record without a ``difficulty``
    key is treated as having the empty label ``""``, and non-string values are
    compared through ``str()``.

    Args:
        questions: Question records to filter.
        allowed: Difficulty labels to keep. ``None`` or an empty list disables
            filtering, in which case the input list object itself is returned.

    Returns:
        The matching records, in their original order.
    """
    if not allowed:
        return questions

    wanted: set[str] = set()
    for label in allowed:
        wanted.add(label.lower())

    def _label_of(record: dict[str, Any]) -> str:
        # str() lets numeric or None difficulty values take part in the match.
        return str(record.get("difficulty", "")).lower()

    return [record for record in questions if _label_of(record) in wanted]
def _drop_harness_broken(
    questions: list[dict[str, Any]], questions_path: str, db_dir: str
) -> list[dict[str, Any]]:
    """Drop questions whose gold answer is harness-broken (see validator).

    The validator's ``broken_question_texts(questions_path, db_dir)`` is asked
    which question texts are broken. Any record in ``questions`` whose
    ``question_text`` is among them is removed.

    Args:
        questions: Records to screen. These may already be a difficulty-filtered
            subset of the file at ``questions_path``.
        questions_path: Questions file handed to the validator.
        db_dir: Database directory handed to the validator.

    Returns:
        ``questions`` itself when the validator reports nothing broken.
        Otherwise, a new list of the surviving records in their original order.
    """
    # The installed-package form (sql_env.scripts) is tried first — it is the
    # only one that resolves when analyst-buddy is pip-installed. The bare forms
    # cover pytest (pythonpath=".") and direct script invocation.
    try:
        from sql_env.scripts.validate_questions import broken_question_texts
    except ImportError:  # pragma: no cover - fallback for non-installed contexts
        try:
            from scripts.validate_questions import broken_question_texts
        except ImportError:
            from validate_questions import broken_question_texts  # type: ignore

    broken = broken_question_texts(questions_path, db_dir)
    if not broken:
        return questions

    kept = [q for q in questions if q.get("question_text") not in broken]
    excluded = len(questions) - len(kept)
    # The validator is given the file path, not ``questions``. It can therefore
    # report broken texts that are absent from an already-filtered subset. Only
    # warn when something was actually removed here, rather than logging a
    # misleading "Excluded 0" message.
    if excluded:
        logger.warning(
            "Excluded %d harness-broken question(s) from %s before training "
            "(run scripts/validate_questions.py for details).",
            excluded,
            questions_path,
        )
    return kept
def load_question_prompts(
    questions_path: str,
    allowed: list[str] | None = None,
    db_dir: str | None = None,
) -> list[dict[str, str]]:
    """Load question text prompts from JSON and apply difficulty filtering.

    When ``db_dir`` is provided, harness-broken questions (broken/empty gold
    answers) are dropped before prompts are built, so they never reach the
    gradient. Leave ``db_dir`` as None to skip the (slightly slower) gold check.

    Args:
        questions_path: Path to a UTF-8 JSON file holding a non-empty list of
            question objects.
        allowed: Difficulty labels to keep (case-insensitive). ``None`` or an
            empty list keeps every question.
        db_dir: Database directory for the harness-broken check. ``None``
            skips the check.

    Returns:
        One ``{"prompt": <question_text>}`` dict per surviving question. Records
        with a missing or empty ``question_text`` are skipped.

    Raises:
        FileNotFoundError: If ``questions_path`` does not exist.
        ValueError: If the file is not valid JSON, is not a non-empty list,
            contains non-object entries, or leaves no usable questions after
            filtering.
    """
    path = Path(questions_path)
    if not path.exists():
        raise FileNotFoundError(f"Questions file not found: {questions_path}")
    try:
        payload = json.loads(path.read_text(encoding="utf-8"))
    except json.JSONDecodeError as exc:
        raise ValueError(f"Invalid JSON in questions file: {questions_path}") from exc
    if not isinstance(payload, list) or not payload:
        raise ValueError(f"Questions file is empty or invalid: {questions_path}")
    # Every record is read with ``.get`` below. Reject non-object entries up
    # front so a malformed file fails with a clear ValueError instead of an
    # AttributeError deep inside filtering.
    if not all(isinstance(item, dict) for item in payload):
        raise ValueError(
            f"Questions file contains non-object entries: {questions_path}"
        )

    filtered = filter_questions_by_difficulty(payload, allowed)
    if not filtered:
        raise ValueError(
            f"No questions match difficulty_filter={allowed} in {questions_path}"
        )

    if db_dir is not None:
        filtered = _drop_harness_broken(filtered, questions_path, db_dir)
        if not filtered:
            raise ValueError(
                f"All questions in {questions_path} are harness-broken "
                f"for db_dir={db_dir}"
            )

    prompts = [
        {"prompt": str(item["question_text"])}
        for item in filtered
        if item.get("question_text")
    ]
    if not prompts:
        raise ValueError(f"No usable question_text values found in {questions_path}")
    return prompts
def validate_no_data_leak(
    train_path: str,
    eval_path: str,
) -> None:
    """Assert zero question overlap between train and eval sets.

    Each file is read as UTF-8 JSON and iterated as a sequence of records.
    Only dict records carrying a ``question_text`` key are compared; any other
    entries are ignored.

    Raises
    ------
    ValueError
        If any question text appears in both files. The message lists up to
        three examples, sorted so the message is stable across runs.
    """

    def _question_texts(path: str) -> set[Any]:
        # Collect the question texts of one file, skipping non-dict entries
        # (e.g. stray strings or numbers) rather than crashing on them.
        records = json.loads(Path(path).read_text(encoding="utf-8"))
        return {
            record["question_text"]
            for record in records
            if isinstance(record, dict) and "question_text" in record
        }

    overlap = _question_texts(train_path) & _question_texts(eval_path)
    if overlap:
        # Set iteration order for str depends on per-process hash
        # randomization. Sort (key=str tolerates mixed value types) so the
        # reported examples are reproducible.
        examples = sorted(overlap, key=str)[:3]
        raise ValueError(
            f"Data leak: {len(overlap)} questions appear in both train and eval. "
            f"Examples: {examples}"
        )
def load_model_and_tokenizer(
    model_name: str,
    *,
    revision: str | None = None,
    torch_dtype: Any | None = None,
) -> tuple[Any, Any]:
    """Fetch a HuggingFace causal-LM and its tokenizer, failing fast on errors.

    ``transformers`` is only imported when this function runs. The data and
    guardrail helpers in this module therefore stay importable without the
    heavy training extras installed.

    Args:
        model_name: Hub model id (or local path) passed to ``from_pretrained``.
        revision: Optional Hub commit SHA or tag, given to both the tokenizer
            and the model. Pinning it makes runs reproducible. Without it,
            ``model_name`` resolves to whatever is HEAD on the Hub at download
            time, so two runs weeks apart can load different weights.
        torch_dtype: Optional dtype (e.g. "bfloat16"), forwarded to the model
            loader only.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        RuntimeError: Wrapping any exception raised while loading either the
            tokenizer or the model. An ImportError from the deferred import is
            not wrapped.
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Options common to both loaders.
    shared_kwargs: dict[str, Any] = {} if revision is None else {"revision": revision}

    # The model gets its own copy so dtype never leaks to the tokenizer. The
    # tokenizer has no dtype, and passing one there is at best ignored, at worst
    # an error on some tokenizers.
    model_kwargs: dict[str, Any] = dict(shared_kwargs)
    if torch_dtype is not None:
        model_kwargs["torch_dtype"] = torch_dtype

    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name, **shared_kwargs)
        model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
    except Exception as exc:  # pragma: no cover - covered by monkeypatched tests
        raise RuntimeError(f"Cannot load model '{model_name}': {exc}") from exc
    return model, tokenizer
|