File size: 9,073 Bytes
6761f70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
"""
Tests for Step 9 — Standalone Training Script Export.

Covers:
  • Generated script is valid Python (ast.parse succeeds)
  • CONFIG dict contains correct hyperparameters from recipe
  • base_model name is in the script
  • QLoRA: BitsAndBytesConfig import present
  • LoRA: peft imports present
  • Full fine-tune: no LoRA/PEFT imports
  • Label names embedded correctly
  • CodeGenerationError never raised on valid inputs
  • Empty training_result → handled gracefully (no KeyError)
"""
from __future__ import annotations

import ast

import pytest

# Locate the code generator (it's in agents/services/, not agents/agents/)
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "services"))

from code_generator import generate_training_script, CodeGenerationError


# ── Fixtures ──────────────────────────────────────────────────────────────────

def _task_spec(**kwargs):
    base = {
        "task_type": "text_classification",
        "num_labels": 3,
        "label_names": ["positive", "neutral", "negative"],
        "input_column": "text",
        "label_column": "label",
    }
    base.update(kwargs)
    return base


def _data_profile(**kwargs):
    base = {
        "num_rows": 500,
        "num_classes": 3,
        "label_distribution": {"positive": 200, "neutral": 150, "negative": 150},
    }
    base.update(kwargs)
    return base


def _recipe(approach: str = "full_finetune", **kwargs):
    base = {
        "base_model":         "bert-base-uncased",
        "training_approach":  approach,
        "learning_rate":      2e-5,
        "num_epochs":         3,
        "batch_size":         16,
        "max_length":         128,
        "warmup_ratio":       0.1,
        "weight_decay":       0.01,
        "lora_r":             16,
        "lora_alpha":         32,
    }
    base.update(kwargs)
    return base


def _training_result(**kwargs):
    base = {
        "accuracy":    0.88,
        "f1":          0.87,
        "label_names": ["positive", "neutral", "negative"],
        "num_labels":  3,
    }
    base.update(kwargs)
    return base


# ── Syntax correctness ────────────────────────────────────────────────────────

class TestSyntaxCorrectness:
    """Every supported configuration must generate parseable Python source."""

    def test_full_finetune_is_valid_python(self):
        script = generate_training_script(
            _task_spec(), _data_profile(), _recipe("full_finetune"), _training_result()
        )
        ast.parse(script)  # raises SyntaxError if the script is not valid Python

    def test_lora_is_valid_python(self):
        script = generate_training_script(
            _task_spec(), _data_profile(), _recipe("lora"), _training_result()
        )
        ast.parse(script)

    def test_qlora_is_valid_python(self):
        script = generate_training_script(
            _task_spec(), _data_profile(), _recipe("qlora"), _training_result()
        )
        ast.parse(script)

    def test_many_labels_is_valid_python(self):
        labels = [f"class_{i}" for i in range(20)]
        # All four inputs describe the same 20-class problem. The default
        # profile (3 classes, positive/neutral/negative) would contradict the
        # task spec, so whichever input the generator reads labels from, it
        # now sees 20 of them.
        script = generate_training_script(
            _task_spec(label_names=labels, num_labels=20),
            _data_profile(
                num_classes=20,
                label_distribution=dict.fromkeys(labels, 25),
            ),
            _recipe(),
            _training_result(label_names=labels, num_labels=20),
        )
        ast.parse(script)

    def test_empty_label_names_is_valid_python(self):
        """label_names=None should fall back to an empty list — still valid Python."""
        script = generate_training_script(
            _task_spec(label_names=None),
            _data_profile(label_distribution={}),
            _recipe(),
            _training_result(label_names=None),
        )
        ast.parse(script)


# ── CONFIG dict content ───────────────────────────────────────────────────────

class TestConfigContent:
    def test_base_model_in_config(self):
        script = generate_training_script(
            _task_spec(), _data_profile(),
            _recipe(base_model="roberta-base"), _training_result()
        )
        assert '"roberta-base"' in script or "'roberta-base'" in script

    def test_learning_rate_in_config(self):
        script = generate_training_script(
            _task_spec(), _data_profile(),
            _recipe(learning_rate=3e-5), _training_result()
        )
        assert "3e-05" in script or "3e-5" in script or "0.00003" in script

    def test_num_epochs_in_config(self):
        script = generate_training_script(
            _task_spec(), _data_profile(),
            _recipe(num_epochs=7), _training_result()
        )
        assert "7" in script

    def test_batch_size_in_config(self):
        script = generate_training_script(
            _task_spec(), _data_profile(),
            _recipe(batch_size=32), _training_result()
        )
        assert "32" in script

    def test_label_names_in_script(self):
        script = generate_training_script(
            _task_spec(label_names=["spam", "ham"]), _data_profile(),
            _recipe(), _training_result(label_names=["spam", "ham"])
        )
        assert "spam" in script
        assert "ham" in script

    def test_input_column_in_script(self):
        script = generate_training_script(
            _task_spec(input_column="review_text"), _data_profile(),
            _recipe(), _training_result()
        )
        assert "review_text" in script

    def test_label_column_in_script(self):
        script = generate_training_script(
            _task_spec(label_column="sentiment"), _data_profile(),
            _recipe(), _training_result()
        )
        assert "sentiment" in script


# ── Approach-specific content ─────────────────────────────────────────────────

class TestApproachContent:
    def test_qlora_has_bnb_import(self):
        script = generate_training_script(
            _task_spec(), _data_profile(), _recipe("qlora"), _training_result()
        )
        assert "BitsAndBytesConfig" in script

    def test_lora_has_peft_import(self):
        script = generate_training_script(
            _task_spec(), _data_profile(), _recipe("lora"), _training_result()
        )
        assert "peft" in script
        assert "LoraConfig" in script

    def test_full_finetune_no_qlora(self):
        script = generate_training_script(
            _task_spec(), _data_profile(), _recipe("full_finetune"), _training_result()
        )
        assert "BitsAndBytesConfig" not in script

    def test_full_finetune_no_lora_adapter(self):
        script = generate_training_script(
            _task_spec(), _data_profile(), _recipe("full_finetune"), _training_result()
        )
        assert "LoraConfig" not in script
        assert "get_peft_model" not in script

    def test_qlora_has_load_in_4bit(self):
        script = generate_training_script(
            _task_spec(), _data_profile(), _recipe("qlora"), _training_result()
        )
        assert "load_in_4bit" in script


# ── Edge cases ────────────────────────────────────────────────────────────────

class TestEdgeCases:
    def test_empty_training_result_no_crash(self):
        """Missing metrics β†’ header shows 'unavailable' but script still valid."""
        script = generate_training_script(
            _task_spec(), _data_profile(), _recipe(), {}
        )
        ast.parse(script)
        assert "metrics unavailable" in script or "unavailable" in script

    def test_special_chars_in_label_names(self):
        """Label names with special characters are repr()-escaped safely."""
        labels = ["class/A", "class B", "class'C"]
        script = generate_training_script(
            _task_spec(label_names=labels),
            _data_profile(),
            _recipe(),
            _training_result(label_names=labels),
        )
        ast.parse(script)  # must not have syntax errors

    def test_returns_string(self):
        script = generate_training_script(
            _task_spec(), _data_profile(), _recipe(), _training_result()
        )
        assert isinstance(script, str)
        assert len(script) > 500  # non-trivial script

    def test_has_main_guard(self):
        script = generate_training_script(
            _task_spec(), _data_profile(), _recipe(), _training_result()
        )
        assert '__name__ == "__main__"' in script or "__main__" in script

    def test_has_argparse(self):
        """Script must accept --data_path argument."""
        script = generate_training_script(
            _task_spec(), _data_profile(), _recipe(), _training_result()
        )
        assert "argparse" in script
        assert "--data_path" in script