Spaces:
Sleeping
Sleeping
"""
Tests for Step 2 — Semantic Field Validators.

All checks are deterministic (no LLM). Covers:
  • QLoRA on encoder-only model → ERROR
  • QLoRA on unknown model → no error (skip check)
  • QLoRA on decoder-only (unknown family) → no error
  • Learning rate too high → ERROR
  • Learning rate too low → WARNING
  • Learning rate in-range → no issue
  • lora_r > 64 → WARNING
  • lora_alpha < lora_r → WARNING
  • max_length > 512 on BERT-family → ERROR
  • max_length > 512 on longformer → no error
  • batch_size > 25% dataset → WARNING
  • num_epochs > 10 on small dataset → WARNING
  • label count mismatch → ERROR
  • None / empty recipe → immediately valid
  • Unknown model family → no model-specific errors emitted
"""
| from __future__ import annotations | |
| import pytest | |
| from agents.validators import validate_recipe_semantics, ValidationResult | |
def _make_recipe(**kwargs) -> dict:
    """Return a baseline recipe dict with any fields overridden by ``kwargs``.

    The baseline describes a LoRA run on ``bert-base-uncased``. A fresh dict is
    built on every call, so tests may mutate the result freely. Keyword
    arguments replace baseline values; keys that are not part of the baseline
    are simply added.
    """
    defaults = {
        "base_model": "bert-base-uncased",
        "training_approach": "lora",
        "lora_r": 16,
        "lora_alpha": 32,
        "learning_rate": 2e-4,
        "num_epochs": 3,
        "batch_size": 16,
        "max_length": 128,
        "warmup_ratio": 0.1,
        "weight_decay": 0.01,
    }
    # Overrides come last so they win over the baseline values.
    return {**defaults, **kwargs}
def _make_profile(**kwargs) -> dict:
    """Return a baseline data-profile dict with any fields overridden by ``kwargs``.

    The baseline is a 1000-row, three-class profile whose ``label_distribution``
    has exactly three keys, matching ``num_classes``. A fresh dict (including a
    fresh nested ``label_distribution``) is built on every call.
    """
    defaults = {
        "num_rows": 1000,
        "num_classes": 3,
        "label_distribution": {"pos": 400, "neg": 400, "neu": 200},
    }
    # Overrides come last so they win over the baseline values.
    return {**defaults, **kwargs}
# ── QLoRA + encoder-only checks ───────────────────────────────────────────────
class TestQLoRAEncoder:
    """QLoRA is expected to be rejected on encoder-only models and ignored otherwise."""

    def test_qlora_on_bert_is_error(self):
        """BERT + QLoRA is invalid and the error names both QLoRA and the encoder."""
        recipe = _make_recipe(base_model="bert-base-uncased", training_approach="qlora")
        result = validate_recipe_semantics(recipe, _make_profile())
        assert not result.is_valid
        assert any("QLoRA" in e and "encoder" in e for e in result.errors)

    def test_qlora_on_distilbert_is_error(self):
        """DistilBERT + QLoRA is invalid and reports at least one error."""
        recipe = _make_recipe(
            base_model="distilbert-base-uncased", training_approach="qlora"
        )
        result = validate_recipe_semantics(recipe, _make_profile())
        assert not result.is_valid
        assert len(result.errors) >= 1

    def test_qlora_on_roberta_is_error(self):
        """RoBERTa + QLoRA is invalid."""
        recipe = _make_recipe(base_model="roberta-base", training_approach="qlora")
        result = validate_recipe_semantics(recipe, _make_profile())
        assert not result.is_valid

    def test_qlora_on_deberta_is_error(self):
        """DeBERTa + QLoRA is invalid."""
        recipe = _make_recipe(
            base_model="microsoft/deberta-v3-small", training_approach="qlora"
        )
        result = validate_recipe_semantics(recipe, _make_profile())
        assert not result.is_valid

    def test_qlora_on_unknown_model_is_skipped(self):
        """Unknown model family: the QLoRA check is skipped (can't rule it out)."""
        recipe = _make_recipe(
            base_model="meta-llama/Llama-3-8B-Instruct", training_approach="qlora"
        )
        result = validate_recipe_semantics(recipe, _make_profile())
        # There must be no QLoRA error for a family the validator does not know.
        assert not any("QLoRA" in e for e in result.errors)

    def test_lora_on_bert_is_fine(self):
        """Plain LoRA on BERT must not trigger the QLoRA error."""
        recipe = _make_recipe(base_model="bert-base-uncased", training_approach="lora")
        result = validate_recipe_semantics(recipe, _make_profile())
        # Same check as the full fine-tune case below: only the absence of a
        # QLoRA error matters here, not overall validity.
        assert not any("QLoRA" in e for e in result.errors)

    def test_full_finetune_on_bert_is_fine(self):
        """Full fine-tuning on BERT (no LoRA params) must not trigger the QLoRA error."""
        recipe = _make_recipe(
            base_model="bert-base-uncased",
            training_approach="full_finetune",
            lora_r=None,
            lora_alpha=None,
        )
        result = validate_recipe_semantics(recipe, _make_profile())
        assert not any("QLoRA" in e for e in result.errors)
# ── Learning rate checks ──────────────────────────────────────────────────────
class TestLearningRate:
    """Learning-rate bounds: too high is an error, too low is only a warning."""

    def test_lr_above_1e3_is_error(self):
        """A learning rate far above 1e-3 is invalid and the error cites the limit."""
        recipe = _make_recipe(learning_rate=0.5)
        result = validate_recipe_semantics(recipe, _make_profile())
        assert not result.is_valid
        assert any("learning_rate" in e and "1e-3" in e for e in result.errors)

    def test_lr_exactly_1e3_is_error(self):
        """The upper bound is exclusive: exactly 1e-3 is already invalid."""
        recipe = _make_recipe(learning_rate=1e-3)
        result = validate_recipe_semantics(recipe, _make_profile())
        assert not result.is_valid

    def test_lr_below_1e6_is_warning(self):
        """A very small learning rate produces a warning, not an error."""
        recipe = _make_recipe(learning_rate=1e-7)
        result = validate_recipe_semantics(recipe, _make_profile())
        assert any(
            "too low" in w.lower() or "learning_rate" in w for w in result.warnings
        )
        # lr=1e-7 should only warn; it must not also be reported as an error.
        assert not any("learning_rate" in e for e in result.errors)

    def test_lr_in_range_no_issue(self):
        """Typical learning rates up to just under 1e-3 raise no learning-rate error."""
        for lr in [2e-5, 1e-4, 3e-4, 5e-4, 9.9e-4]:
            recipe = _make_recipe(learning_rate=lr)
            result = validate_recipe_semantics(recipe, _make_profile())
            assert not any(
                "learning_rate" in e for e in result.errors
            ), f"Unexpected error for lr={lr}"
# ── LoRA parameter checks ─────────────────────────────────────────────────────
class TestLoRAParams:
    """Warnings expected for unusual ``lora_r`` / ``lora_alpha`` combinations."""

    def test_lora_r_above_64_is_warning(self):
        """``lora_r`` of 128 produces a warning that mentions the value."""
        recipe = _make_recipe(training_approach="lora", lora_r=128, lora_alpha=256)
        result = validate_recipe_semantics(recipe, _make_profile())
        assert any("lora_r" in w and "128" in w for w in result.warnings)

    def test_lora_r_64_is_fine(self):
        """``lora_r`` of exactly 64 is still accepted without a warning."""
        recipe = _make_recipe(training_approach="lora", lora_r=64, lora_alpha=128)
        result = validate_recipe_semantics(recipe, _make_profile())
        assert not any("lora_r" in w for w in result.warnings)

    def test_lora_alpha_less_than_r_is_warning(self):
        """``lora_alpha`` below ``lora_r`` produces a warning."""
        recipe = _make_recipe(training_approach="lora", lora_r=16, lora_alpha=8)
        result = validate_recipe_semantics(recipe, _make_profile())
        assert any(
            "lora_alpha" in w or "scaling" in w.lower() for w in result.warnings
        )

    def test_lora_alpha_equal_r_is_fine(self):
        """``lora_alpha`` equal to ``lora_r`` produces no ``lora_alpha`` warning."""
        recipe = _make_recipe(training_approach="lora", lora_r=16, lora_alpha=16)
        result = validate_recipe_semantics(recipe, _make_profile())
        assert not any("lora_alpha" in w for w in result.warnings)

    def test_lora_alpha_double_r_is_fine(self):
        """``lora_alpha`` of twice ``lora_r`` produces no ``lora_alpha`` warning."""
        recipe = _make_recipe(training_approach="lora", lora_r=16, lora_alpha=32)
        result = validate_recipe_semantics(recipe, _make_profile())
        assert not any("lora_alpha" in w for w in result.warnings)

    def test_full_finetune_ignores_lora_params(self):
        """Full fine-tune has no LoRA params, so no LoRA warnings (no false positives)."""
        recipe = _make_recipe(
            training_approach="full_finetune", lora_r=None, lora_alpha=None
        )
        result = validate_recipe_semantics(recipe, _make_profile())
        assert not any("lora" in w.lower() for w in result.warnings)
# ── max_length checks ─────────────────────────────────────────────────────────
class TestMaxLength:
    """``max_length`` above 512 is an error only for BERT-family models."""

    def test_bert_max_length_above_512_is_error(self):
        """BERT with ``max_length`` 1024 is invalid and the error cites the 512 limit."""
        recipe = _make_recipe(base_model="bert-base-uncased", max_length=1024)
        result = validate_recipe_semantics(recipe, _make_profile())
        assert not result.is_valid
        assert any("max_length" in e and "512" in e for e in result.errors)

    def test_bert_max_length_512_is_fine(self):
        """Exactly 512 is within the limit: no ``max_length`` error."""
        recipe = _make_recipe(base_model="bert-base-uncased", max_length=512)
        result = validate_recipe_semantics(recipe, _make_profile())
        assert not any("max_length" in e for e in result.errors)

    def test_roberta_max_length_above_512_is_error(self):
        """RoBERTa with ``max_length`` 1024 is invalid."""
        recipe = _make_recipe(base_model="roberta-base", max_length=1024)
        result = validate_recipe_semantics(recipe, _make_profile())
        assert not result.is_valid

    def test_longformer_max_length_above_512_is_fine(self):
        """Longformer handles sequences longer than 512, so no error is expected."""
        recipe = _make_recipe(
            base_model="allenai/longformer-base-4096", max_length=2048
        )
        result = validate_recipe_semantics(recipe, _make_profile())
        assert not any("max_length" in e for e in result.errors)

    def test_unknown_model_max_length_above_512_no_error(self):
        """Unknown model family: the ``max_length`` check is skipped."""
        recipe = _make_recipe(base_model="meta-llama/Llama-3-8B", max_length=2048)
        result = validate_recipe_semantics(recipe, _make_profile())
        assert not any("max_length" in e for e in result.errors)
# ── Batch size checks ─────────────────────────────────────────────────────────
class TestBatchSize:
    """A batch size above 25% of the dataset rows is expected to warn."""

    def test_batch_larger_than_25pct_dataset_is_warning(self):
        """Batch 64 on 100 rows exceeds the 25-row threshold (100 / 4) and warns."""
        recipe = _make_recipe(batch_size=64)
        profile = _make_profile(num_rows=100)
        result = validate_recipe_semantics(recipe, profile)
        assert any("batch_size" in w or "25%" in w for w in result.warnings)

    def test_batch_exactly_25pct_no_warning(self):
        """Batch 25 on 100 rows sits exactly on the threshold and must not warn."""
        recipe = _make_recipe(batch_size=25)
        profile = _make_profile(num_rows=100)
        result = validate_recipe_semantics(recipe, profile)
        # The comparison is expected to be strict: 25 > 25 is False.
        assert not any("batch_size" in w for w in result.warnings)

    def test_batch_below_25pct_no_warning(self):
        """Batch 16 on 1000 rows is far below the 250-row threshold: no warning."""
        recipe = _make_recipe(batch_size=16)
        profile = _make_profile(num_rows=1000)
        result = validate_recipe_semantics(recipe, profile)
        assert not any("batch_size" in w for w in result.warnings)
# ── Epoch count checks ────────────────────────────────────────────────────────
class TestEpochCount:
    """More than 10 epochs is expected to warn only when the dataset is small."""

    def test_many_epochs_small_dataset_is_warning(self):
        """15 epochs on 100 rows produces an epoch/overfitting warning."""
        recipe = _make_recipe(num_epochs=15)
        profile = _make_profile(num_rows=100)
        result = validate_recipe_semantics(recipe, profile)
        assert any(
            "num_epochs" in w or "overfitting" in w.lower() for w in result.warnings
        )

    def test_10_epochs_small_dataset_no_warning(self):
        """Exactly 10 epochs on 100 rows is still accepted without a warning."""
        recipe = _make_recipe(num_epochs=10)
        profile = _make_profile(num_rows=100)
        result = validate_recipe_semantics(recipe, profile)
        assert not any("num_epochs" in w for w in result.warnings)

    def test_many_epochs_large_dataset_no_warning(self):
        """15 epochs on 5000 rows does not warn."""
        recipe = _make_recipe(num_epochs=15)
        profile = _make_profile(num_rows=5000)
        result = validate_recipe_semantics(recipe, profile)
        assert not any("num_epochs" in w for w in result.warnings)
# ── Label count checks ────────────────────────────────────────────────────────
class TestLabelCount:
    """Consistency between ``num_classes`` and the ``label_distribution`` keys."""

    def test_label_count_mismatch_is_error(self):
        """A profile claiming 4 classes but listing 3 labels is invalid."""
        recipe = _make_recipe()
        # The recipe carries no label count of its own, so the mismatch has to
        # be forced inside the profile: num_classes says 4, the distribution
        # has only 3 keys.
        profile = _make_profile(
            num_classes=4,
            label_distribution={"a": 100, "b": 100, "c": 100},
        )
        result = validate_recipe_semantics(recipe, profile)
        assert not result.is_valid
        assert any(
            "mismatch" in e.lower() or "label" in e.lower() for e in result.errors
        )

    def test_matching_label_count_no_error(self):
        """3 classes with 3 labels produces no mismatch error."""
        recipe = _make_recipe()
        profile = _make_profile(
            num_classes=3,
            label_distribution={"a": 100, "b": 100, "c": 100},
        )
        result = validate_recipe_semantics(recipe, profile)
        assert not any("mismatch" in e.lower() for e in result.errors)

    def test_zero_label_count_skips_check(self):
        """Can't compare if one side is unknown: zero classes and no labels is not a mismatch."""
        recipe = _make_recipe()
        profile = _make_profile(num_classes=0, label_distribution={})
        result = validate_recipe_semantics(recipe, profile)
        assert not any("mismatch" in e.lower() for e in result.errors)
# ── Edge cases ────────────────────────────────────────────────────────────────
class TestEdgeCases:
    """Empty input, a fully valid recipe, and accumulation of multiple errors."""

    def test_empty_recipe_returns_valid(self):
        """An empty recipe dict is valid, with no errors and no warnings."""
        result = validate_recipe_semantics({}, _make_profile())
        assert result.is_valid
        assert result.errors == []
        assert result.warnings == []

    def test_none_recipe_returns_valid(self):
        """A ``None`` recipe is treated as valid rather than raising."""
        result = validate_recipe_semantics(None, _make_profile())  # type: ignore[arg-type]
        assert result.is_valid

    def test_valid_recipe_no_issues(self):
        """A conventional LoRA-on-BERT recipe with a consistent profile has no errors."""
        recipe = _make_recipe(
            base_model="bert-base-uncased",
            training_approach="lora",
            lora_r=16,
            lora_alpha=32,
            learning_rate=2e-4,
            num_epochs=3,
            batch_size=16,
            max_length=128,
        )
        profile = _make_profile(
            num_rows=1000,
            num_classes=3,
            label_distribution={"a": 400, "b": 400, "c": 200},
        )
        result = validate_recipe_semantics(recipe, profile)
        assert result.is_valid
        assert result.errors == []

    def test_multiple_errors_all_returned(self):
        """All errors should be collected, not short-circuited after the first."""
        recipe = _make_recipe(
            base_model="bert-base-uncased",
            training_approach="qlora",  # ERROR: QLoRA on encoder
            learning_rate=0.5,  # ERROR: lr too high
            max_length=1024,  # ERROR: exceeds BERT limit
        )
        result = validate_recipe_semantics(recipe, _make_profile())
        assert not result.is_valid
        # At least QLoRA + max_length (lr is usually caught by pydantic).
        assert len(result.errors) >= 2