# modelforge-backend/agents/tests/test_code_generator.py
# ModelForge CI
# deploy: 2026-06-19 19:24 UTC
# 6761f70
# Raw
# History Blame Contribute Delete
# 9.07 kB
"""
Tests for Step 9 — Standalone Training Script Export.
Covers:
• Generated script is valid Python (ast.parse succeeds)
• CONFIG dict contains correct hyperparameters from recipe
• base_model name is in the script
• QLoRA: BitsAndBytesConfig import present
• LoRA: peft imports present
• Full fine-tune: no LoRA/PEFT imports
• Label names embedded correctly
• CodeGenerationError never raised on valid inputs
• Empty training_result → handled gracefully (no KeyError)
"""
from __future__ import annotations
import ast
import pytest
# Locate the code generator (it's in agents/services/, not agents/agents/)
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "services"))
from code_generator import generate_training_script, CodeGenerationError
# ── Fixtures ──────────────────────────────────────────────────────────────────
def _task_spec(**kwargs: object) -> dict[str, object]:
    """Build a task-spec dict for a 3-class text-classification task.

    Args:
        **kwargs: Keys to override or add on top of the defaults.

    Returns:
        A new dict on every call, so callers may mutate it freely.
    """
    base = {
        "task_type": "text_classification",
        "num_labels": 3,
        "label_names": ["positive", "neutral", "negative"],
        "input_column": "text",
        "label_column": "label",
    }
    # Overrides win over defaults; unknown keys are simply added.
    base.update(kwargs)
    return base
def _data_profile(**kwargs: object) -> dict[str, object]:
    """Build a data-profile dict describing a 500-row, 3-class dataset.

    Args:
        **kwargs: Keys to override or add on top of the defaults.

    Returns:
        A new dict on every call, so callers may mutate it freely.
    """
    base = {
        "num_rows": 500,
        "num_classes": 3,
        "label_distribution": {"positive": 200, "neutral": 150, "negative": 150},
    }
    # Overrides win over defaults; unknown keys are simply added.
    base.update(kwargs)
    return base
def _recipe(approach: str = "full_finetune", **kwargs: object) -> dict[str, object]:
    """Build a training-recipe dict for the given training approach.

    Args:
        approach: Value stored under ``"training_approach"``; the tests in
            this module use ``"full_finetune"``, ``"lora"`` and ``"qlora"``.
        **kwargs: Keys to override or add on top of the defaults.

    Returns:
        A new dict on every call, so callers may mutate it freely.
    """
    base = {
        "base_model": "bert-base-uncased",
        "training_approach": approach,
        "learning_rate": 2e-5,
        "num_epochs": 3,
        "batch_size": 16,
        "max_length": 128,
        "warmup_ratio": 0.1,
        "weight_decay": 0.01,
        "lora_r": 16,
        "lora_alpha": 32,
    }
    # Overrides win over defaults (including over ``approach`` itself).
    base.update(kwargs)
    return base
def _training_result(**kwargs: object) -> dict[str, object]:
    """Build a training-result dict with metrics and label metadata.

    Args:
        **kwargs: Keys to override or add on top of the defaults.

    Returns:
        A new dict on every call, so callers may mutate it freely.
    """
    base = {
        "accuracy": 0.88,
        "f1": 0.87,
        "label_names": ["positive", "neutral", "negative"],
        "num_labels": 3,
    }
    # Overrides win over defaults; unknown keys are simply added.
    base.update(kwargs)
    return base
# ── Syntax correctness ────────────────────────────────────────────────────────
class TestSyntaxCorrectness:
    """The generated script must parse as Python for every supported input.

    Each test builds a script with ``generate_training_script`` and feeds it
    to ``ast.parse``; a ``SyntaxError`` raised there fails the test.
    """

    def test_full_finetune_is_valid_python(self):
        """Script for the ``full_finetune`` approach parses."""
        script = generate_training_script(
            _task_spec(), _data_profile(), _recipe("full_finetune"), _training_result()
        )
        ast.parse(script)  # no exception = valid

    def test_lora_is_valid_python(self):
        """Script for the ``lora`` approach parses."""
        script = generate_training_script(
            _task_spec(), _data_profile(), _recipe("lora"), _training_result()
        )
        ast.parse(script)

    def test_qlora_is_valid_python(self):
        """Script for the ``qlora`` approach parses."""
        script = generate_training_script(
            _task_spec(), _data_profile(), _recipe("qlora"), _training_result()
        )
        ast.parse(script)

    def test_many_labels_is_valid_python(self):
        """Script with 20 label names parses."""
        labels = [f"class_{i}" for i in range(20)]
        script = generate_training_script(
            _task_spec(label_names=labels, num_labels=20),
            _data_profile(),
            _recipe(),
            _training_result(label_names=labels, num_labels=20),
        )
        ast.parse(script)

    def test_empty_label_names_is_valid_python(self):
        """Falls back to empty list — still valid Python."""
        script = generate_training_script(
            _task_spec(label_names=None),
            _data_profile(label_distribution={}),
            _recipe(),
            _training_result(label_names=None),
        )
        ast.parse(script)
# ── CONFIG dict content ───────────────────────────────────────────────────────
class TestConfigContent:
    """Recipe and task-spec values must show up in the generated script."""

    @staticmethod
    def _has_number_token(script: str, value: int) -> bool:
        """Return True if *value* occurs in *script* as a standalone number.

        A plain substring test is too weak for short numbers: ``"7"`` is a
        substring of the default ``f1`` metric ``0.87``, so the assertion
        could pass without the hyperparameter being emitted at all. The
        match is therefore rejected when the digits are preceded by a word
        character or a dot, or followed by a word character.
        """
        import re  # local import: the module header does not import ``re``

        pattern = rf"(?<![\w.]){re.escape(str(value))}(?!\w)"
        return re.search(pattern, script) is not None

    def test_base_model_in_config(self):
        """The base model name appears as a quoted string."""
        script = generate_training_script(
            _task_spec(), _data_profile(),
            _recipe(base_model="roberta-base"), _training_result()
        )
        assert '"roberta-base"' in script or "'roberta-base'" in script

    def test_learning_rate_in_config(self):
        """The learning rate appears in one of its usual float renderings."""
        script = generate_training_script(
            _task_spec(), _data_profile(),
            _recipe(learning_rate=3e-5), _training_result()
        )
        assert "3e-05" in script or "3e-5" in script or "0.00003" in script

    def test_num_epochs_in_config(self):
        """The epoch count appears as a standalone number."""
        script = generate_training_script(
            _task_spec(), _data_profile(),
            _recipe(num_epochs=7), _training_result()
        )
        assert self._has_number_token(script, 7)

    def test_batch_size_in_config(self):
        """The batch size appears as a standalone number."""
        # The default lora_alpha is also 32; move it out of the way so a
        # match can only come from the batch size under test.
        script = generate_training_script(
            _task_spec(), _data_profile(),
            _recipe(batch_size=32, lora_alpha=64), _training_result()
        )
        assert self._has_number_token(script, 32)

    def test_label_names_in_script(self):
        """Custom label names are embedded in the script."""
        script = generate_training_script(
            _task_spec(label_names=["spam", "ham"]), _data_profile(),
            _recipe(), _training_result(label_names=["spam", "ham"])
        )
        assert "spam" in script
        assert "ham" in script

    def test_input_column_in_script(self):
        """A custom input column name is embedded in the script."""
        script = generate_training_script(
            _task_spec(input_column="review_text"), _data_profile(),
            _recipe(), _training_result()
        )
        assert "review_text" in script

    def test_label_column_in_script(self):
        """A custom label column name is embedded in the script."""
        script = generate_training_script(
            _task_spec(label_column="sentiment"), _data_profile(),
            _recipe(), _training_result()
        )
        assert "sentiment" in script
# ── Approach-specific content ─────────────────────────────────────────────────
class TestApproachContent:
    """Each training approach must emit (or omit) its specific code."""

    def test_qlora_has_bnb_import(self):
        """QLoRA scripts reference ``BitsAndBytesConfig``."""
        script = generate_training_script(
            _task_spec(), _data_profile(), _recipe("qlora"), _training_result()
        )
        assert "BitsAndBytesConfig" in script

    def test_lora_has_peft_import(self):
        """LoRA scripts reference ``peft`` and ``LoraConfig``."""
        script = generate_training_script(
            _task_spec(), _data_profile(), _recipe("lora"), _training_result()
        )
        assert "peft" in script
        assert "LoraConfig" in script

    def test_full_finetune_no_qlora(self):
        """Full fine-tune scripts do not mention ``BitsAndBytesConfig``."""
        script = generate_training_script(
            _task_spec(), _data_profile(), _recipe("full_finetune"), _training_result()
        )
        assert "BitsAndBytesConfig" not in script

    def test_full_finetune_no_lora_adapter(self):
        """Full fine-tune scripts do not mention the LoRA adapter API."""
        script = generate_training_script(
            _task_spec(), _data_profile(), _recipe("full_finetune"), _training_result()
        )
        assert "LoraConfig" not in script
        assert "get_peft_model" not in script

    def test_qlora_has_load_in_4bit(self):
        """QLoRA scripts reference ``load_in_4bit``."""
        script = generate_training_script(
            _task_spec(), _data_profile(), _recipe("qlora"), _training_result()
        )
        assert "load_in_4bit" in script
# ── Edge cases ────────────────────────────────────────────────────────────────
class TestEdgeCases:
    """Degenerate inputs and structural requirements of the generated script."""

    @staticmethod
    def _has_main_guard(tree: ast.AST) -> bool:
        """Return True if *tree* contains ``if __name__ == "__main__":``.

        Checking the AST rather than the raw text means a mere mention of
        ``__main__`` in a comment or string does not count as a guard, and
        the quote style used by the generator does not matter.
        """
        for node in ast.walk(tree):
            if not isinstance(node, ast.If):
                continue
            test = node.test
            if (
                isinstance(test, ast.Compare)
                and isinstance(test.left, ast.Name)
                and test.left.id == "__name__"
                and len(test.ops) == 1
                and isinstance(test.ops[0], ast.Eq)
                and len(test.comparators) == 1
                and isinstance(test.comparators[0], ast.Constant)
                and test.comparators[0].value == "__main__"
            ):
                return True
        return False

    def test_empty_training_result_no_crash(self):
        """Missing metrics → script mentions 'unavailable' but is still valid."""
        script = generate_training_script(
            _task_spec(), _data_profile(), _recipe(), {}
        )
        ast.parse(script)
        # "metrics unavailable" contains "unavailable", so one check covers both.
        assert "unavailable" in script

    def test_special_chars_in_label_names(self):
        """Label names with slashes, spaces and quotes must not break syntax."""
        labels = ["class/A", "class B", "class'C"]
        script = generate_training_script(
            _task_spec(label_names=labels),
            _data_profile(),
            _recipe(),
            _training_result(label_names=labels),
        )
        ast.parse(script)  # must not have syntax errors

    def test_returns_string(self):
        """The generator returns a non-trivial string."""
        script = generate_training_script(
            _task_spec(), _data_profile(), _recipe(), _training_result()
        )
        assert isinstance(script, str)
        assert len(script) > 500  # non-trivial script

    def test_has_main_guard(self):
        """Script must contain an ``if __name__ == "__main__"`` guard."""
        script = generate_training_script(
            _task_spec(), _data_profile(), _recipe(), _training_result()
        )
        assert self._has_main_guard(ast.parse(script))

    def test_has_argparse(self):
        """Script must accept --data_path argument."""
        script = generate_training_script(
            _task_spec(), _data_profile(), _recipe(), _training_result()
        )
        assert "argparse" in script
        assert "--data_path" in script