Spaces:
Sleeping
Sleeping
File size: 9,073 Bytes
6761f70 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 | """
Tests for Step 9 — Standalone Training Script Export.
Covers:
• Generated script is valid Python (ast.parse succeeds)
• CONFIG dict contains correct hyperparameters from recipe
• base_model name is in the script
• QLoRA: BitsAndBytesConfig import present
• LoRA: peft imports present
• Full fine-tune: no LoRA/PEFT imports
• Label names embedded correctly
• CodeGenerationError never raised on valid inputs
• Empty training_result → handled gracefully (no KeyError)
"""
from __future__ import annotations
import ast
import pytest
# Locate the code generator (it's in agents/services/, not agents/agents/)
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "services"))
from code_generator import generate_training_script, CodeGenerationError
# ── Fixtures ──────────────────────────────────────────────────────────────────
def _task_spec(**kwargs):
base = {
"task_type": "text_classification",
"num_labels": 3,
"label_names": ["positive", "neutral", "negative"],
"input_column": "text",
"label_column": "label",
}
base.update(kwargs)
return base
def _data_profile(**kwargs):
base = {
"num_rows": 500,
"num_classes": 3,
"label_distribution": {"positive": 200, "neutral": 150, "negative": 150},
}
base.update(kwargs)
return base
def _recipe(approach: str = "full_finetune", **kwargs):
base = {
"base_model": "bert-base-uncased",
"training_approach": approach,
"learning_rate": 2e-5,
"num_epochs": 3,
"batch_size": 16,
"max_length": 128,
"warmup_ratio": 0.1,
"weight_decay": 0.01,
"lora_r": 16,
"lora_alpha": 32,
}
base.update(kwargs)
return base
def _training_result(**kwargs):
base = {
"accuracy": 0.88,
"f1": 0.87,
"label_names": ["positive", "neutral", "negative"],
"num_labels": 3,
}
base.update(kwargs)
return base
# ── Syntax correctness ────────────────────────────────────────────────────────
class TestSyntaxCorrectness:
    """Every generated script variant must parse as valid Python source."""

    @staticmethod
    def _assert_parses(task_spec, data_profile, recipe, training_result):
        """Generate a script from the given inputs; fail if it does not parse."""
        source = generate_training_script(
            task_spec, data_profile, recipe, training_result
        )
        ast.parse(source)  # raises SyntaxError on malformed output

    def test_full_finetune_is_valid_python(self):
        self._assert_parses(
            _task_spec(), _data_profile(), _recipe("full_finetune"), _training_result()
        )

    def test_lora_is_valid_python(self):
        self._assert_parses(
            _task_spec(), _data_profile(), _recipe("lora"), _training_result()
        )

    def test_qlora_is_valid_python(self):
        self._assert_parses(
            _task_spec(), _data_profile(), _recipe("qlora"), _training_result()
        )

    def test_many_labels_is_valid_python(self):
        many = [f"class_{idx}" for idx in range(20)]
        self._assert_parses(
            _task_spec(label_names=many, num_labels=20),
            _data_profile(),
            _recipe(),
            _training_result(label_names=many, num_labels=20),
        )

    def test_empty_label_names_is_valid_python(self):
        """Falls back to empty list — still valid Python."""
        self._assert_parses(
            _task_spec(label_names=None),
            _data_profile(label_distribution={}),
            _recipe(),
            _training_result(label_names=None),
        )
# ── CONFIG dict content ───────────────────────────────────────────────────────
class TestConfigContent:
    """Recipe hyperparameters and task-spec fields must appear in the script."""

    @staticmethod
    def _numeric_literals(script):
        """Return the set of int/float literals in *script*'s parsed code.

        Comparing parsed literals instead of raw substrings keeps the checks
        from passing vacuously. A bare ``"7" in script`` also matches inside
        metric values such as 0.87 if those are rendered into the script, and
        ``"32" in script`` is satisfied by the default ``lora_alpha``. Booleans
        are excluded because ``True == 1``.
        """
        return {
            node.value
            for node in ast.walk(ast.parse(script))
            if isinstance(node, ast.Constant)
            and isinstance(node.value, (int, float))
            and not isinstance(node.value, bool)
        }

    def test_base_model_in_config(self):
        script = generate_training_script(
            _task_spec(), _data_profile(),
            _recipe(base_model="roberta-base"), _training_result()
        )
        assert '"roberta-base"' in script or "'roberta-base'" in script

    def test_learning_rate_in_config(self):
        script = generate_training_script(
            _task_spec(), _data_profile(),
            _recipe(learning_rate=3e-5), _training_result()
        )
        assert "3e-05" in script or "3e-5" in script or "0.00003" in script

    def test_num_epochs_in_config(self):
        script = generate_training_script(
            _task_spec(), _data_profile(),
            _recipe(num_epochs=7), _training_result()
        )
        assert 7 in self._numeric_literals(script)

    def test_batch_size_in_config(self):
        # 24 differs from every fixture default (notably lora_alpha=32 and
        # lora_r=16), so a match can only come from the overridden batch_size.
        script = generate_training_script(
            _task_spec(), _data_profile(),
            _recipe(batch_size=24), _training_result()
        )
        assert 24 in self._numeric_literals(script)

    def test_label_names_in_script(self):
        script = generate_training_script(
            _task_spec(label_names=["spam", "ham"]), _data_profile(),
            _recipe(), _training_result(label_names=["spam", "ham"])
        )
        assert "spam" in script
        assert "ham" in script

    def test_input_column_in_script(self):
        script = generate_training_script(
            _task_spec(input_column="review_text"), _data_profile(),
            _recipe(), _training_result()
        )
        assert "review_text" in script

    def test_label_column_in_script(self):
        script = generate_training_script(
            _task_spec(label_column="sentiment"), _data_profile(),
            _recipe(), _training_result()
        )
        assert "sentiment" in script
# ── Approach-specific content ─────────────────────────────────────────────────
class TestApproachContent:
    """Imports and adapter setup must match the recipe's training_approach."""

    @staticmethod
    def _script_for(approach):
        """Generate a script from the default fixtures using *approach*."""
        return generate_training_script(
            _task_spec(), _data_profile(), _recipe(approach), _training_result()
        )

    def test_qlora_has_bnb_import(self):
        qlora_script = self._script_for("qlora")
        assert "BitsAndBytesConfig" in qlora_script

    def test_lora_has_peft_import(self):
        lora_script = self._script_for("lora")
        for required in ("peft", "LoraConfig"):
            assert required in lora_script

    def test_full_finetune_no_qlora(self):
        full_script = self._script_for("full_finetune")
        assert "BitsAndBytesConfig" not in full_script

    def test_full_finetune_no_lora_adapter(self):
        full_script = self._script_for("full_finetune")
        for forbidden in ("LoraConfig", "get_peft_model"):
            assert forbidden not in full_script

    def test_qlora_has_load_in_4bit(self):
        qlora_script = self._script_for("qlora")
        assert "load_in_4bit" in qlora_script
# ── Edge cases ────────────────────────────────────────────────────────────────
class TestEdgeCases:
    """Unusual-but-valid inputs and structural requirements of the script."""

    @staticmethod
    def _is_main_guard(stmt):
        """Return True if *stmt* is ``if __name__ == "__main__":`` (either quotes)."""
        if not isinstance(stmt, ast.If):
            return False
        test = stmt.test
        return (
            isinstance(test, ast.Compare)
            and isinstance(test.left, ast.Name)
            and test.left.id == "__name__"
            and len(test.ops) == 1
            and isinstance(test.ops[0], ast.Eq)
            and isinstance(test.comparators[0], ast.Constant)
            and test.comparators[0].value == "__main__"
        )

    def test_empty_training_result_no_crash(self):
        """Missing metrics → header shows 'unavailable' but script still valid."""
        script = generate_training_script(
            _task_spec(), _data_profile(), _recipe(), {}
        )
        ast.parse(script)
        # "metrics unavailable" contains "unavailable", so one check covers both.
        assert "unavailable" in script

    def test_special_chars_in_label_names(self):
        """Label names with special characters are repr()-escaped safely."""
        labels = ["class/A", "class B", "class'C"]
        script = generate_training_script(
            _task_spec(label_names=labels),
            _data_profile(),
            _recipe(),
            _training_result(label_names=labels),
        )
        tree = ast.parse(script)  # must not have syntax errors
        # Escaping is only "safe" if each label survives as an intact string
        # literal (parsing alone would also pass if a quote were dropped).
        string_literals = {
            node.value
            for node in ast.walk(tree)
            if isinstance(node, ast.Constant) and isinstance(node.value, str)
        }
        for label in labels:
            assert label in string_literals, f"label {label!r} not embedded intact"

    def test_returns_string(self):
        script = generate_training_script(
            _task_spec(), _data_profile(), _recipe(), _training_result()
        )
        assert isinstance(script, str)
        assert len(script) > 500  # non-trivial script

    def test_has_main_guard(self):
        script = generate_training_script(
            _task_spec(), _data_profile(), _recipe(), _training_result()
        )
        # Require a real top-level guard statement; a bare "__main__" substring
        # could come from a comment or string and would not prove the guard exists.
        tree = ast.parse(script)
        assert any(self._is_main_guard(stmt) for stmt in tree.body)

    def test_has_argparse(self):
        """Script must accept --data_path argument."""
        script = generate_training_script(
            _task_spec(), _data_profile(), _recipe(), _training_result()
        )
        assert "argparse" in script
        assert "--data_path" in script
|