# modelforge-backend / agents /tests /test_semantic_validators.py
# ModelForge CI
# deploy: 2026-06-19 19:24 UTC
# 6761f70
# Raw
# History Blame Contribute Delete
# 15.1 kB
"""
Tests for Step 2 β€” Semantic Field Validators.
All checks are deterministic (no LLM). Covers:
β€’ QLoRA on encoder-only model β†’ ERROR
β€’ QLoRA on unknown model β†’ no error (skip check)
β€’ QLoRA on decoder-only (unknown family) β†’ no error
β€’ Learning rate too high β†’ ERROR
β€’ Learning rate too low β†’ WARNING
β€’ Learning rate in-range β†’ no issue
β€’ lora_r > 64 β†’ WARNING
β€’ lora_alpha < lora_r β†’ WARNING
β€’ max_length > 512 on BERT-family β†’ ERROR
β€’ max_length > 512 on longformer β†’ no error
β€’ batch_size > 25% dataset β†’ WARNING
β€’ num_epochs > 10 on small dataset β†’ WARNING
β€’ label count mismatch β†’ ERROR
β€’ None / empty recipe β†’ immediately valid
β€’ Unknown model family β†’ no model-specific errors emitted
"""
from __future__ import annotations
import pytest
from agents.validators import validate_recipe_semantics, ValidationResult
def _make_recipe(**kwargs) -> dict:
"""Build a minimal valid recipe with any field overrides."""
base = {
"base_model": "bert-base-uncased",
"training_approach": "lora",
"lora_r": 16,
"lora_alpha": 32,
"learning_rate": 2e-4,
"num_epochs": 3,
"batch_size": 16,
"max_length": 128,
"warmup_ratio": 0.1,
"weight_decay": 0.01,
}
base.update(kwargs)
return base
def _make_profile(**kwargs) -> dict:
"""Build a minimal data profile with any field overrides."""
base = {
"num_rows": 1000,
"num_classes": 3,
"label_distribution": {"pos": 400, "neg": 400, "neu": 200},
}
base.update(kwargs)
return base
# ── QLoRA + encoder-only checks ───────────────────────────────────────────────
class TestQLoRAEncoder:
    """QLoRA is an ERROR on encoder-only models and is not flagged for unknown families."""

    def test_qlora_on_bert_is_error(self):
        recipe = _make_recipe(base_model="bert-base-uncased", training_approach="qlora")
        result = validate_recipe_semantics(recipe, _make_profile())
        assert not result.is_valid
        assert any("QLoRA" in e and "encoder" in e for e in result.errors)

    def test_qlora_on_distilbert_is_error(self):
        recipe = _make_recipe(base_model="distilbert-base-uncased", training_approach="qlora")
        result = validate_recipe_semantics(recipe, _make_profile())
        assert not result.is_valid
        # Pin the failure to the QLoRA check rather than accepting any error.
        assert any("QLoRA" in e for e in result.errors)

    def test_qlora_on_roberta_is_error(self):
        recipe = _make_recipe(base_model="roberta-base", training_approach="qlora")
        result = validate_recipe_semantics(recipe, _make_profile())
        assert not result.is_valid
        assert any("QLoRA" in e for e in result.errors)

    def test_qlora_on_deberta_is_error(self):
        recipe = _make_recipe(base_model="microsoft/deberta-v3-small", training_approach="qlora")
        result = validate_recipe_semantics(recipe, _make_profile())
        assert not result.is_valid
        assert any("QLoRA" in e for e in result.errors)

    def test_qlora_on_unknown_model_is_skipped(self):
        """Unknown model family → skip the QLoRA check (can't rule it out)."""
        recipe = _make_recipe(base_model="meta-llama/Llama-3-8B-Instruct", training_approach="qlora")
        result = validate_recipe_semantics(recipe, _make_profile())
        # Should have NO QLoRA error for an unknown family.
        assert not any("QLoRA" in e for e in result.errors)

    def test_lora_on_bert_is_fine(self):
        recipe = _make_recipe(base_model="bert-base-uncased", training_approach="lora")
        result = validate_recipe_semantics(recipe, _make_profile())
        # Plain LoRA must never trigger the QLoRA/encoder error.
        assert not any("QLoRA" in e for e in result.errors)

    def test_full_finetune_on_bert_is_fine(self):
        recipe = _make_recipe(base_model="bert-base-uncased", training_approach="full_finetune",
                              lora_r=None, lora_alpha=None)
        result = validate_recipe_semantics(recipe, _make_profile())
        assert not any("QLoRA" in e for e in result.errors)
# ── Learning rate checks ──────────────────────────────────────────────────────
class TestLearningRate:
    """learning_rate >= 1e-3 is an ERROR; below 1e-6 is only a WARNING."""

    def test_lr_above_1e3_is_error(self):
        recipe = _make_recipe(learning_rate=0.5)
        result = validate_recipe_semantics(recipe, _make_profile())
        assert not result.is_valid
        assert any("learning_rate" in e and "1e-3" in e for e in result.errors)

    def test_lr_exactly_1e3_is_error(self):
        # The upper bound is inclusive: 1e-3 itself is rejected.
        recipe = _make_recipe(learning_rate=1e-3)
        result = validate_recipe_semantics(recipe, _make_profile())
        assert not result.is_valid

    def test_lr_below_1e6_is_warning(self):
        recipe = _make_recipe(learning_rate=1e-7)
        result = validate_recipe_semantics(recipe, _make_profile())
        # A too-low lr is only a warning: the recipe must still be valid.
        assert result.is_valid
        assert any("too low" in w.lower() or "learning_rate" in w for w in result.warnings)

    def test_lr_in_range_no_issue(self):
        for lr in [2e-5, 1e-4, 3e-4, 5e-4, 9.9e-4]:
            recipe = _make_recipe(learning_rate=lr)
            result = validate_recipe_semantics(recipe, _make_profile())
            assert not any("learning_rate" in e for e in result.errors), f"Unexpected error for lr={lr}"
# ── LoRA parameter checks ─────────────────────────────────────────────────────
class TestLoRAParams:
    """lora_r / lora_alpha sanity warnings, which must not fire for full fine-tunes."""

    def test_lora_r_above_64_is_warning(self):
        recipe = _make_recipe(training_approach="lora", lora_r=128, lora_alpha=256)
        result = validate_recipe_semantics(recipe, _make_profile())
        assert any("lora_r" in w and "128" in w for w in result.warnings)

    def test_lora_r_64_is_fine(self):
        # 64 is the boundary: only values strictly above it should warn.
        recipe = _make_recipe(training_approach="lora", lora_r=64, lora_alpha=128)
        result = validate_recipe_semantics(recipe, _make_profile())
        assert not any("lora_r" in w for w in result.warnings)

    def test_lora_alpha_less_than_r_is_warning(self):
        recipe = _make_recipe(training_approach="lora", lora_r=16, lora_alpha=8)
        result = validate_recipe_semantics(recipe, _make_profile())
        assert any("lora_alpha" in w or "scaling" in w.lower() for w in result.warnings)

    def test_lora_alpha_equal_r_is_fine(self):
        recipe = _make_recipe(training_approach="lora", lora_r=16, lora_alpha=16)
        result = validate_recipe_semantics(recipe, _make_profile())
        assert not any("lora_alpha" in w for w in result.warnings)

    def test_lora_alpha_double_r_is_fine(self):
        recipe = _make_recipe(training_approach="lora", lora_r=16, lora_alpha=32)
        result = validate_recipe_semantics(recipe, _make_profile())
        assert not any("lora_alpha" in w for w in result.warnings)

    def test_full_finetune_ignores_lora_params(self):
        """Full fine-tune has no LoRA params — no false positives."""
        recipe = _make_recipe(training_approach="full_finetune", lora_r=None, lora_alpha=None)
        result = validate_recipe_semantics(recipe, _make_profile())
        assert not any("lora" in w.lower() for w in result.warnings)
# ── max_length checks ─────────────────────────────────────────────────────────
class TestMaxLength:
    """max_length > 512 is an ERROR for BERT-family models; Longformer and unknown families are exempt."""

    def test_bert_max_length_above_512_is_error(self):
        recipe = _make_recipe(base_model="bert-base-uncased", max_length=1024)
        result = validate_recipe_semantics(recipe, _make_profile())
        assert not result.is_valid
        assert any("max_length" in e and "512" in e for e in result.errors)

    def test_bert_max_length_512_is_fine(self):
        # 512 itself is allowed; only values above it are rejected.
        recipe = _make_recipe(base_model="bert-base-uncased", max_length=512)
        result = validate_recipe_semantics(recipe, _make_profile())
        assert not any("max_length" in e for e in result.errors)

    def test_roberta_max_length_above_512_is_error(self):
        recipe = _make_recipe(base_model="roberta-base", max_length=1024)
        result = validate_recipe_semantics(recipe, _make_profile())
        assert not result.is_valid
        # Pin the failure to the max_length check rather than accepting any error.
        assert any("max_length" in e for e in result.errors)

    def test_longformer_max_length_above_512_is_fine(self):
        """Longformer handles sequences longer than 512 — no error."""
        recipe = _make_recipe(base_model="allenai/longformer-base-4096", max_length=2048)
        result = validate_recipe_semantics(recipe, _make_profile())
        assert not any("max_length" in e for e in result.errors)

    def test_unknown_model_max_length_above_512_no_error(self):
        """Unknown model family → skip the max_length check."""
        recipe = _make_recipe(base_model="meta-llama/Llama-3-8B", max_length=2048)
        result = validate_recipe_semantics(recipe, _make_profile())
        assert not any("max_length" in e for e in result.errors)
# ── Batch size checks ─────────────────────────────────────────────────────────
class TestBatchSize:
    """batch_size strictly greater than 25% of num_rows should WARN."""

    def test_batch_larger_than_25pct_dataset_is_warning(self):
        recipe = _make_recipe(batch_size=64)
        profile = _make_profile(num_rows=100)  # 64 > 100/4 = 25
        result = validate_recipe_semantics(recipe, profile)
        assert any("batch_size" in w or "25%" in w for w in result.warnings)

    def test_batch_exactly_25pct_no_warning(self):
        recipe = _make_recipe(batch_size=25)
        profile = _make_profile(num_rows=100)  # 25 == 100/4, exactly the threshold
        result = validate_recipe_semantics(recipe, profile)
        # 25 > 25 is False → no warning (the comparison is strict).
        assert not any("batch_size" in w for w in result.warnings)

    def test_batch_below_25pct_no_warning(self):
        recipe = _make_recipe(batch_size=16)
        profile = _make_profile(num_rows=1000)  # 16 < 250
        result = validate_recipe_semantics(recipe, profile)
        assert not any("batch_size" in w for w in result.warnings)
# ── Epoch count checks ────────────────────────────────────────────────────────
class TestEpochCount:
    """More than 10 epochs on a small dataset should WARN; large datasets are exempt."""

    def test_many_epochs_small_dataset_is_warning(self):
        outcome = validate_recipe_semantics(
            _make_recipe(num_epochs=15),
            _make_profile(num_rows=100),
        )
        flagged = [w for w in outcome.warnings
                   if "num_epochs" in w or "overfitting" in w.lower()]
        assert flagged

    def test_10_epochs_small_dataset_no_warning(self):
        # 10 is the boundary: only values above it should warn.
        outcome = validate_recipe_semantics(
            _make_recipe(num_epochs=10),
            _make_profile(num_rows=100),
        )
        assert all("num_epochs" not in w for w in outcome.warnings)

    def test_many_epochs_large_dataset_no_warning(self):
        outcome = validate_recipe_semantics(
            _make_recipe(num_epochs=15),
            _make_profile(num_rows=5000),
        )
        assert all("num_epochs" not in w for w in outcome.warnings)
# ── Label count checks ────────────────────────────────────────────────────────
class TestLabelCount:
    """num_classes vs. the number of label_distribution keys must agree."""

    def test_label_count_mismatch_is_error(self):
        """An inconsistent profile (4 classes, 3 distribution keys) is an ERROR.

        The validator compares num_classes against len(label_distribution),
        both taken from the data profile. The mismatch is therefore forced by
        passing an internally inconsistent profile.
        """
        recipe = _make_recipe()
        profile = _make_profile(
            num_classes=4,
            label_distribution={"a": 100, "b": 100, "c": 100},  # only 3 keys
        )
        result = validate_recipe_semantics(recipe, profile)
        assert not result.is_valid
        assert any("mismatch" in e.lower() or "label" in e.lower() for e in result.errors)

    def test_matching_label_count_no_error(self):
        recipe = _make_recipe()
        profile = _make_profile(
            num_classes=3,
            label_distribution={"a": 100, "b": 100, "c": 100},
        )
        result = validate_recipe_semantics(recipe, profile)
        assert not any("mismatch" in e.lower() for e in result.errors)

    def test_zero_label_count_skips_check(self):
        """Can't compare if one side is unknown."""
        recipe = _make_recipe()
        profile = _make_profile(num_classes=0, label_distribution={})
        result = validate_recipe_semantics(recipe, profile)
        assert not any("mismatch" in e.lower() for e in result.errors)
# ── Edge cases ────────────────────────────────────────────────────────────────
class TestEdgeCases:
    """Empty/None recipes, a fully valid recipe, and error aggregation."""

    def test_empty_recipe_returns_valid(self):
        result = validate_recipe_semantics({}, _make_profile())
        assert result.is_valid
        assert result.errors == []
        assert result.warnings == []

    def test_none_recipe_returns_valid(self):
        result = validate_recipe_semantics(None, _make_profile())  # type: ignore[arg-type]
        # None must be handled exactly like an empty recipe.
        assert result.is_valid
        assert result.errors == []
        assert result.warnings == []

    def test_valid_recipe_no_issues(self):
        recipe = _make_recipe(
            base_model="bert-base-uncased",
            training_approach="lora",
            lora_r=16,
            lora_alpha=32,
            learning_rate=2e-4,
            num_epochs=3,
            batch_size=16,
            max_length=128,
        )
        profile = _make_profile(num_rows=1000, num_classes=3,
                                label_distribution={"a": 400, "b": 400, "c": 200})
        result = validate_recipe_semantics(recipe, profile)
        assert result.is_valid
        assert result.errors == []

    def test_multiple_errors_all_returned(self):
        """All errors should be collected, not short-circuited after the first."""
        recipe = _make_recipe(
            base_model="bert-base-uncased",
            training_approach="qlora",  # ERROR: QLoRA on encoder
            learning_rate=0.5,          # ERROR: lr too high
            max_length=1024,            # ERROR: exceeds BERT limit
        )
        result = validate_recipe_semantics(recipe, _make_profile())
        assert not result.is_valid
        # This validator reports each of these on its own (see the per-check
        # test classes), so every category must appear together here.
        assert any("QLoRA" in e for e in result.errors)
        assert any("learning_rate" in e for e in result.errors)
        assert any("max_length" in e for e in result.errors)