File size: 5,917 Bytes
656f91e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
"""Data/model loading helpers for the GRPO training notebook."""

from __future__ import annotations

import json
import logging
from pathlib import Path
from typing import Any

# Module-level logger named after this module, so it can be configured per module.
logger = logging.getLogger(__name__)


def filter_questions_by_difficulty(
    questions: list[dict[str, Any]], allowed: list[str] | None
) -> list[dict[str, Any]]:
    """Filter question records by case-insensitive difficulty labels."""

    if not allowed:
        return questions

    allowed_set = {level.lower() for level in allowed}
    return [
        question
        for question in questions
        if str(question.get("difficulty", "")).lower() in allowed_set
    ]


def _drop_harness_broken(
    questions: list[dict[str, Any]], questions_path: str, db_dir: str
) -> list[dict[str, Any]]:
    """Drop questions whose gold answer is harness-broken (see validator).

    ``broken_question_texts(questions_path, db_dir)`` supplies the texts to
    exclude (presumably a set/collection of ``question_text`` values -- confirm
    against the validator). Records whose ``question_text`` is in that
    collection are removed.

    The validator is handed the file path, while ``questions`` may already be a
    subset of that file (``load_question_prompts`` filters by difficulty
    first). The broken collection can therefore be non-empty without matching
    anything here, so the warning is only logged when something was actually
    excluded.

    Returns ``questions`` itself when the validator reports nothing broken,
    otherwise a new list with the broken records removed.
    """
    # The installed-package form (sql_env.scripts) is tried first — it is the
    # only one that resolves when analyst-buddy is pip-installed. The bare forms
    # cover pytest (pythonpath=".") and direct script invocation.
    try:
        from sql_env.scripts.validate_questions import broken_question_texts
    except ImportError:  # pragma: no cover - fallback for non-installed contexts
        try:
            from scripts.validate_questions import broken_question_texts
        except ImportError:
            from validate_questions import broken_question_texts  # type: ignore

    broken = broken_question_texts(questions_path, db_dir)
    if not broken:
        return questions

    kept = [q for q in questions if q.get("question_text") not in broken]
    excluded = len(questions) - len(kept)
    if excluded:
        # Skip the message for a zero count: "Excluded 0 ..." is misleading noise.
        logger.warning(
            "Excluded %d harness-broken question(s) from %s before training "
            "(run scripts/validate_questions.py for details).",
            excluded,
            questions_path,
        )
    return kept


def load_question_prompts(
    questions_path: str,
    allowed: list[str] | None = None,
    db_dir: str | None = None,
) -> list[dict[str, str]]:
    """Load question text prompts from JSON and apply difficulty filtering.

    When ``db_dir`` is provided, harness-broken questions (broken/empty gold
    answers) are dropped before prompts are built, so they never reach the
    gradient. Leave ``db_dir`` as None to skip the (slightly slower) gold check.
    """

    path = Path(questions_path)
    if not path.exists():
        raise FileNotFoundError(f"Questions file not found: {questions_path}")

    try:
        payload = json.loads(path.read_text(encoding="utf-8"))
    except json.JSONDecodeError as exc:
        raise ValueError(f"Invalid JSON in questions file: {questions_path}") from exc

    if not isinstance(payload, list) or not payload:
        raise ValueError(f"Questions file is empty or invalid: {questions_path}")

    filtered = filter_questions_by_difficulty(payload, allowed)
    if not filtered:
        raise ValueError(
            f"No questions match difficulty_filter={allowed} in {questions_path}"
        )

    if db_dir is not None:
        filtered = _drop_harness_broken(filtered, questions_path, db_dir)
        if not filtered:
            raise ValueError(
                f"All questions in {questions_path} are harness-broken "
                f"for db_dir={db_dir}"
            )

    prompts = [
        {"prompt": str(item["question_text"])}
        for item in filtered
        if item.get("question_text")
    ]
    if not prompts:
        raise ValueError(f"No usable question_text values found in {questions_path}")

    return prompts


def validate_no_data_leak(
    train_path: str,
    eval_path: str,
) -> None:
    """Assert zero question overlap between train and eval sets.

    Both files are read as UTF-8 JSON. The ``question_text`` values of the
    entries that have that key are compared by exact equality; entries without
    the key are ignored.

    Raises
    ------
    ValueError
        If any question text appears in both files. The message reports the
        overlap count and up to three examples, chosen in sorted order so the
        message is the same on every run.
    """
    train = json.loads(Path(train_path).read_text(encoding="utf-8"))
    eval_ = json.loads(Path(eval_path).read_text(encoding="utf-8"))

    train_qs = {q["question_text"] for q in train if "question_text" in q}
    eval_qs = {q["question_text"] for q in eval_ if "question_text" in q}
    overlap = train_qs & eval_qs

    if overlap:
        # Set iteration order varies between runs (string hash randomization),
        # so sort before truncating; key=str keeps mixed value types sortable.
        examples = sorted(overlap, key=str)[:3]
        raise ValueError(
            f"Data leak: {len(overlap)} questions appear in both train and eval. "
            f"Examples: {examples}"
        )


def load_model_and_tokenizer(
    model_name: str,
    *,
    revision: str | None = None,
    torch_dtype: Any | None = None,
) -> tuple[Any, Any]:
    """Load a HuggingFace causal-LM and its tokenizer, failing fast on error.

    ``transformers`` is imported inside the function rather than at module top,
    so the data / guardrail helpers in this module remain importable when the
    heavy training extras are not installed.

    ``revision`` (a Hub commit SHA or tag) pins the exact weights for
    reproducibility; without it, ``model_name`` resolves to whatever is HEAD on
    the Hub at download time, so runs weeks apart can load different weights.
    When given, it is passed to both the tokenizer and the model loader.
    ``torch_dtype`` (e.g. "bfloat16") is forwarded to the model loader only.

    Returns ``(model, tokenizer)`` -- note the model comes first.

    Raises
    ------
    RuntimeError
        If either ``from_pretrained`` call raises; the original exception is
        chained as the cause.
    """

    from transformers import AutoModelForCausalLM, AutoTokenizer

    # A pinned revision is shared by both loaders.
    pinned: dict[str, Any] = {} if revision is None else {"revision": revision}
    tokenizer_kwargs = dict(pinned)
    model_kwargs = dict(pinned)

    # dtype applies to the model only — the tokenizer has no dtype, and passing
    # it there is at best ignored, at worst an error on some tokenizers.
    if torch_dtype is not None:
        model_kwargs["torch_dtype"] = torch_dtype

    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name, **tokenizer_kwargs)
        model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
    except Exception as exc:  # pragma: no cover - covered by monkeypatched tests
        raise RuntimeError(f"Cannot load model '{model_name}': {exc}") from exc
    else:
        return model, tokenizer