Spaces:
Sleeping
Sleeping
| """ | |
| Tests for utils/model_config.py | |
| Tests model family lookups, configuration retrieval, and auto-selection logic. | |
| """ | |
| import pytest | |
| from utils.model_config import ( | |
| get_model_family, | |
| get_family_config, | |
| get_auto_selections, | |
| _pattern_matches_template, | |
| MODEL_TO_FAMILY, | |
| MODEL_FAMILIES | |
| ) | |
| class TestGetModelFamily: | |
| """Tests for get_model_family function.""" | |
| def test_known_gpt2_model(self): | |
| """Known GPT-2 model should return 'gpt2' family.""" | |
| assert get_model_family("gpt2") == "gpt2" | |
| assert get_model_family("gpt2-medium") == "gpt2" | |
| assert get_model_family("openai-community/gpt2") == "gpt2" | |
| def test_known_llama_model(self): | |
| """Known LLaMA-like models should return 'llama_like' family.""" | |
| assert get_model_family("Qwen/Qwen2.5-0.5B") == "llama_like" | |
| assert get_model_family("meta-llama/Llama-2-7b-hf") == "llama_like" | |
| assert get_model_family("mistralai/Mistral-7B-v0.1") == "llama_like" | |
| def test_known_opt_model(self): | |
| """Known OPT models should return 'opt' family.""" | |
| assert get_model_family("facebook/opt-125m") == "opt" | |
| assert get_model_family("facebook/opt-1.3b") == "opt" | |
| def test_unknown_model_returns_none(self): | |
| """Unknown models should return None.""" | |
| assert get_model_family("unknown/model-name") is None | |
| assert get_model_family("random-string") is None | |
| assert get_model_family("") is None | |
| class TestGetFamilyConfig: | |
| """Tests for get_family_config function.""" | |
| def test_valid_gpt2_config(self): | |
| """GPT-2 family config should have correct structure.""" | |
| config = get_family_config("gpt2") | |
| assert config is not None | |
| assert "templates" in config | |
| assert "attention_pattern" in config["templates"] | |
| assert config["templates"]["attention_pattern"] == "transformer.h.{N}.attn" | |
| assert config["norm_type"] == "layernorm" | |
| def test_valid_llama_config(self): | |
| """LLaMA-like family config should have correct structure.""" | |
| config = get_family_config("llama_like") | |
| assert config is not None | |
| assert config["templates"]["attention_pattern"] == "model.layers.{N}.self_attn" | |
| assert config["norm_type"] == "rmsnorm" | |
| assert config["norm_parameter"] == "model.norm.weight" | |
| def test_invalid_family_returns_none(self): | |
| """Invalid family name should return None.""" | |
| assert get_family_config("invalid_family") is None | |
| assert get_family_config("") is None | |
| assert get_family_config("GPT2") is None # Case-sensitive | |
| class TestPatternMatchesTemplate: | |
| """Tests for _pattern_matches_template function.""" | |
| def test_exact_match(self): | |
| """Pattern that exactly matches template should return True.""" | |
| assert _pattern_matches_template( | |
| "model.layers.{N}.self_attn", | |
| "model.layers.{N}.self_attn" | |
| ) is True | |
| def test_matching_with_n_placeholder(self): | |
| """Patterns with {N} placeholder should match correctly.""" | |
| assert _pattern_matches_template( | |
| "transformer.h.{N}.attn", | |
| "transformer.h.{N}.attn" | |
| ) is True | |
| def test_non_matching_pattern(self): | |
| """Different patterns should not match.""" | |
| assert _pattern_matches_template( | |
| "model.layers.{N}.self_attn", | |
| "transformer.h.{N}.attn" | |
| ) is False | |
| def test_empty_template_returns_false(self): | |
| """Empty template should return False.""" | |
| assert _pattern_matches_template("model.layers.{N}.self_attn", "") is False | |
| assert _pattern_matches_template("", "") is False | |
| class TestGetAutoSelections: | |
| """Tests for get_auto_selections function.""" | |
| def test_unknown_model_returns_empty_selections(self): | |
| """Unknown model should return empty selections.""" | |
| result = get_auto_selections( | |
| "unknown/model", | |
| {"model.layers.{N}.self_attn": ["model.layers.0.self_attn"]}, | |
| {"model.norm.weight": ["model.norm.weight"]} | |
| ) | |
| assert result["attention_selection"] == [] | |
| assert result["block_selection"] == [] | |
| assert result["norm_selection"] == [] | |
| assert result["family_name"] is None | |
| def test_known_model_matches_patterns(self, mock_module_patterns, mock_param_patterns): | |
| """Known model should match appropriate patterns.""" | |
| result = get_auto_selections( | |
| "Qwen/Qwen2.5-0.5B", # llama_like family | |
| mock_module_patterns, | |
| mock_param_patterns | |
| ) | |
| assert result["family_name"] == "llama_like" | |
| # Should find self_attn pattern | |
| assert "model.layers.{N}.self_attn" in result["attention_selection"] | |
| # Should find block pattern | |
| assert "model.layers.{N}" in result["block_selection"] | |
| # Should find norm pattern | |
| assert result["norm_selection"] == ["model.norm.weight"] | |
| def test_result_structure(self, mock_module_patterns, mock_param_patterns): | |
| """Result should have all required keys.""" | |
| result = get_auto_selections( | |
| "gpt2", | |
| {}, # Empty patterns - no matches expected | |
| {} | |
| ) | |
| assert "attention_selection" in result | |
| assert "block_selection" in result | |
| assert "norm_selection" in result | |
| assert "family_name" in result | |
| assert isinstance(result["attention_selection"], list) | |
| assert isinstance(result["norm_selection"], list) | |
| class TestModelRegistryIntegrity: | |
| """Tests to verify the model registry data is consistent.""" | |
| def test_all_families_have_required_fields(self): | |
| """All model families should have required configuration fields.""" | |
| required_fields = ["description", "templates", "norm_type"] | |
| for family_name, config in MODEL_FAMILIES.items(): | |
| for field in required_fields: | |
| assert field in config, f"Family {family_name} missing {field}" | |
| def test_all_mapped_families_exist(self): | |
| """All families referenced in MODEL_TO_FAMILY should exist in MODEL_FAMILIES.""" | |
| for model_name, family_name in MODEL_TO_FAMILY.items(): | |
| assert family_name in MODEL_FAMILIES, \ | |
| f"Model {model_name} references unknown family {family_name}" | |