Spaces:
Sleeping
Sleeping
| """ | |
| Tests for Step 9 — Standalone Training Script Export. | |
| Covers: | |
| • Generated script is valid Python (ast.parse succeeds) | |
| • CONFIG dict contains correct hyperparameters from recipe | |
| • base_model name is in the script | |
| • QLoRA: BitsAndBytesConfig import present | |
| • LoRA: peft imports present | |
| • Full fine-tune: no LoRA/PEFT imports | |
| • Label names embedded correctly | |
| • CodeGenerationError never raised on valid inputs | |
| • Empty training_result → handled gracefully (no KeyError) | |
| """ | |
| from __future__ import annotations | |
| import ast | |
| import pytest | |
| # Locate the code generator (it's in agents/services/, not agents/agents/) | |
| import sys | |
| import os | |
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "services")) | |
| from code_generator import generate_training_script, CodeGenerationError | |
| # ── Fixtures ────────────────────────────────────────────────────────────────── | |
| def _task_spec(**kwargs): | |
| base = { | |
| "task_type": "text_classification", | |
| "num_labels": 3, | |
| "label_names": ["positive", "neutral", "negative"], | |
| "input_column": "text", | |
| "label_column": "label", | |
| } | |
| base.update(kwargs) | |
| return base | |
| def _data_profile(**kwargs): | |
| base = { | |
| "num_rows": 500, | |
| "num_classes": 3, | |
| "label_distribution": {"positive": 200, "neutral": 150, "negative": 150}, | |
| } | |
| base.update(kwargs) | |
| return base | |
| def _recipe(approach: str = "full_finetune", **kwargs): | |
| base = { | |
| "base_model": "bert-base-uncased", | |
| "training_approach": approach, | |
| "learning_rate": 2e-5, | |
| "num_epochs": 3, | |
| "batch_size": 16, | |
| "max_length": 128, | |
| "warmup_ratio": 0.1, | |
| "weight_decay": 0.01, | |
| "lora_r": 16, | |
| "lora_alpha": 32, | |
| } | |
| base.update(kwargs) | |
| return base | |
| def _training_result(**kwargs): | |
| base = { | |
| "accuracy": 0.88, | |
| "f1": 0.87, | |
| "label_names": ["positive", "neutral", "negative"], | |
| "num_labels": 3, | |
| } | |
| base.update(kwargs) | |
| return base | |
| # ── Syntax correctness ──────────────────────────────────────────────────────── | |
class TestSyntaxCorrectness:
    """The generated script must parse as Python for every input variant."""

    @staticmethod
    def _assert_parses(task_spec, data_profile, recipe, training_result):
        """Generate a script from the four inputs and parse it with ast."""
        source = generate_training_script(
            task_spec, data_profile, recipe, training_result
        )
        # ast.parse raises SyntaxError when the source is not valid Python.
        ast.parse(source)

    def test_full_finetune_is_valid_python(self):
        self._assert_parses(
            _task_spec(), _data_profile(), _recipe("full_finetune"), _training_result()
        )

    def test_lora_is_valid_python(self):
        self._assert_parses(
            _task_spec(), _data_profile(), _recipe("lora"), _training_result()
        )

    def test_qlora_is_valid_python(self):
        self._assert_parses(
            _task_spec(), _data_profile(), _recipe("qlora"), _training_result()
        )

    def test_many_labels_is_valid_python(self):
        many_labels = ["class_" + str(index) for index in range(20)]
        self._assert_parses(
            _task_spec(label_names=many_labels, num_labels=20),
            _data_profile(),
            _recipe(),
            _training_result(label_names=many_labels, num_labels=20),
        )

    def test_empty_label_names_is_valid_python(self):
        """With label_names set to None the output must still be valid Python."""
        self._assert_parses(
            _task_spec(label_names=None),
            _data_profile(label_distribution={}),
            _recipe(),
            _training_result(label_names=None),
        )
| # ── CONFIG dict content ─────────────────────────────────────────────────────── | |
class TestConfigContent:
    """Recipe and task-spec values must be carried into the generated script.

    Numeric hyperparameters are matched as integer literals in the parsed
    script, and label names as quoted literals. Bare substring checks such as
    ``"7" in script`` pass for unrelated text (a 0.87 metric, another setting
    with the same value, or a longer word containing the label).
    """

    @staticmethod
    def _has_int_literal(script, value):
        """Return True if *script* contains the integer literal *value* in code.

        Walking the AST ignores digits that only occur inside comments,
        strings, floats or longer numbers.
        """
        return any(
            isinstance(node, ast.Constant)
            and isinstance(node.value, int)
            and not isinstance(node.value, bool)  # True == 1 must not count
            and node.value == value
            for node in ast.walk(ast.parse(script))
        )

    @staticmethod
    def _has_quoted(script, text):
        """Return True if *text* appears in *script* as a quoted literal."""
        return f'"{text}"' in script or f"'{text}'" in script

    def test_base_model_in_config(self):
        script = generate_training_script(
            _task_spec(), _data_profile(),
            _recipe(base_model="roberta-base"), _training_result()
        )
        assert self._has_quoted(script, "roberta-base")

    def test_learning_rate_in_config(self):
        script = generate_training_script(
            _task_spec(), _data_profile(),
            _recipe(learning_rate=3e-5), _training_result()
        )
        assert "3e-05" in script or "3e-5" in script or "0.00003" in script

    def test_num_epochs_in_config(self):
        script = generate_training_script(
            _task_spec(), _data_profile(),
            _recipe(num_epochs=7), _training_result()
        )
        # A plain '"7" in script' would already pass on a 0.87 metric.
        assert self._has_int_literal(script, 7)

    def test_batch_size_in_config(self):
        # 48 is used because the default recipe's lora_alpha is 32, which
        # would satisfy a check for 32 without batch_size being emitted.
        script = generate_training_script(
            _task_spec(), _data_profile(),
            _recipe(batch_size=48), _training_result()
        )
        assert self._has_int_literal(script, 48)

    def test_label_names_in_script(self):
        labels = ["spam", "ham"]
        script = generate_training_script(
            _task_spec(label_names=labels, num_labels=2), _data_profile(),
            _recipe(), _training_result(label_names=labels, num_labels=2)
        )
        # "ham" is a substring of "spam", so each label is matched with its
        # surrounding quotes rather than as a bare substring.
        for label in labels:
            assert self._has_quoted(script, label)

    def test_input_column_in_script(self):
        script = generate_training_script(
            _task_spec(input_column="review_text"), _data_profile(),
            _recipe(), _training_result()
        )
        assert "review_text" in script

    def test_label_column_in_script(self):
        script = generate_training_script(
            _task_spec(label_column="sentiment"), _data_profile(),
            _recipe(), _training_result()
        )
        assert "sentiment" in script
| # ── Approach-specific content ───────────────────────────────────────────────── | |
class TestApproachContent:
    """Each training approach must emit (or omit) its specific identifiers."""

    @staticmethod
    def _script_for(approach):
        """Generate a script from the default fixtures for one approach."""
        return generate_training_script(
            _task_spec(), _data_profile(), _recipe(approach), _training_result()
        )

    def test_qlora_has_bnb_import(self):
        generated = self._script_for("qlora")
        assert "BitsAndBytesConfig" in generated

    def test_lora_has_peft_import(self):
        generated = self._script_for("lora")
        assert "peft" in generated
        assert "LoraConfig" in generated

    def test_full_finetune_no_qlora(self):
        generated = self._script_for("full_finetune")
        assert "BitsAndBytesConfig" not in generated

    def test_full_finetune_no_lora_adapter(self):
        generated = self._script_for("full_finetune")
        assert "LoraConfig" not in generated
        assert "get_peft_model" not in generated

    def test_qlora_has_load_in_4bit(self):
        generated = self._script_for("qlora")
        assert "load_in_4bit" in generated
| # ── Edge cases ──────────────────────────────────────────────────────────────── | |
class TestEdgeCases:
    """Unusual inputs and general structural expectations of the script."""

    @staticmethod
    def _default_script():
        """Generate a script from the unmodified default fixtures."""
        return generate_training_script(
            _task_spec(), _data_profile(), _recipe(), _training_result()
        )

    def test_empty_training_result_no_crash(self):
        """An empty training result still yields a valid script that says 'unavailable'."""
        generated = generate_training_script(
            _task_spec(), _data_profile(), _recipe(), {}
        )
        ast.parse(generated)
        markers = ("metrics unavailable", "unavailable")
        assert any(marker in generated for marker in markers)

    def test_special_chars_in_label_names(self):
        """Labels containing a slash, a space or a quote must not break the syntax."""
        tricky_labels = ["class/A", "class B", "class'C"]
        generated = generate_training_script(
            _task_spec(label_names=tricky_labels),
            _data_profile(),
            _recipe(),
            _training_result(label_names=tricky_labels),
        )
        # ast.parse raises SyntaxError if a label was embedded unescaped.
        ast.parse(generated)

    def test_returns_string(self):
        generated = self._default_script()
        assert isinstance(generated, str)
        # More than 500 characters: the output is a non-trivial script.
        assert len(generated) > 500

    def test_has_main_guard(self):
        generated = self._default_script()
        markers = ('__name__ == "__main__"', "__main__")
        assert any(marker in generated for marker in markers)

    def test_has_argparse(self):
        """The script must use argparse and accept a --data_path argument."""
        generated = self._default_script()
        assert "argparse" in generated
        assert "--data_path" in generated