Spaces:
Running
Running
| """Tests for the model registry and capabilities database.""" | |
| from __future__ import annotations | |
| import pytest | |
| from headroom.models import ( | |
| ModelInfo, | |
| ModelRegistry, | |
| get_model_info, | |
| list_models, | |
| register_model, | |
| ) | |
class TestModelInfo:
    """Checks ModelInfo defaults, explicit overrides, and immutability."""

    def test_default_values(self):
        """With only name/provider given, the remaining fields use defaults."""
        model = ModelInfo(name="test", provider="test-provider")

        assert model.context_window == 128000
        assert model.max_output_tokens == 4096

        # Capability flags are compared by identity so that only real
        # booleans pass.
        expected_flags = {
            "supports_tools": True,
            "supports_vision": False,
            "supports_streaming": True,
        }
        for attr, flag in expected_flags.items():
            assert getattr(model, attr) is flag

    def test_custom_values(self):
        """Values passed to the constructor replace the defaults."""
        overrides = {
            "context_window": 32000,
            "max_output_tokens": 8192,
            "supports_tools": False,
            "supports_vision": True,
        }
        model = ModelInfo(name="custom-model", provider="custom", **overrides)

        assert model.context_window == 32000
        assert model.max_output_tokens == 8192
        assert model.supports_tools is False
        assert model.supports_vision is True

    def test_frozen(self):
        """Assigning to a field of an existing ModelInfo raises AttributeError."""
        model = ModelInfo(name="test", provider="test")

        with pytest.raises(AttributeError):
            model.name = "changed"
class TestModelRegistry:
    """Lookup, registration, filtering, and cost checks for ModelRegistry."""

    def test_get_openai_model(self):
        """"gpt-4o" resolves to an openai entry with a 128000-token window."""
        entry = ModelRegistry.get("gpt-4o")

        assert entry is not None
        assert entry.provider == "openai"
        assert entry.context_window == 128000

    def test_get_anthropic_model(self):
        """A dated Claude name resolves to an anthropic entry (200000 window)."""
        entry = ModelRegistry.get("claude-3-5-sonnet-20241022")

        assert entry is not None
        assert entry.provider == "anthropic"
        assert entry.context_window == 200000

    def test_get_google_model(self):
        """"gemini-1.5-pro" resolves to a google entry."""
        entry = ModelRegistry.get("gemini-1.5-pro")

        assert entry is not None
        assert entry.provider == "google"
        # Two million tokens of context.
        assert entry.context_window == 2000000

    def test_get_by_alias(self):
        """A dated alias maps back to the canonical model name."""
        entry = ModelRegistry.get("gpt-4o-2024-11-20")

        assert entry is not None
        assert entry.name == "gpt-4o"

    def test_get_unknown_model(self):
        """Looking up an unrecognised name yields None."""
        assert ModelRegistry.get("unknown-model-xyz") is None

    def test_get_prefix_matching(self):
        """A name that merely starts with a known model still resolves to it."""
        entry = ModelRegistry.get("gpt-4o-new-version")

        assert entry is not None
        assert entry.name == "gpt-4o"

    def test_register_custom_model(self):
        """A registered model is returned and can be fetched again by name."""
        spec = {
            "provider": "custom",
            "context_window": 64000,
            "supports_vision": True,
        }
        registered = ModelRegistry.register("my-custom-model", **spec)

        assert registered.name == "my-custom-model"
        assert registered.provider == "custom"
        assert registered.context_window == 64000

        # The registry must hand the same data back on lookup.
        fetched = ModelRegistry.get("my-custom-model")
        assert fetched is not None
        assert fetched.context_window == 64000

    def test_list_models_all(self):
        """An unfiltered listing is non-empty."""
        everything = ModelRegistry.list_models()

        assert len(everything) > 0

    def test_list_models_by_provider(self):
        """Filtering by provider returns only that provider's models."""
        matches = ModelRegistry.list_models(provider="openai")

        assert len(matches) > 0
        for model in matches:
            assert model.provider == "openai"

    def test_list_models_with_tools(self):
        """Filtering on tool support returns only tool-capable models."""
        matches = ModelRegistry.list_models(supports_tools=True)

        assert len(matches) > 0
        for model in matches:
            assert model.supports_tools

    def test_list_models_with_vision(self):
        """Filtering on vision support returns only vision-capable models."""
        matches = ModelRegistry.list_models(supports_vision=True)

        assert len(matches) > 0
        for model in matches:
            assert model.supports_vision

    def test_list_models_min_context(self):
        """Filtering on min_context returns only models at or above it."""
        floor = 1000000
        matches = ModelRegistry.list_models(min_context=floor)

        assert len(matches) > 0
        for model in matches:
            assert model.context_window >= floor

    def test_list_providers(self):
        """The provider listing contains the three major vendors."""
        known = ModelRegistry.list_providers()

        for vendor in ("openai", "anthropic", "google"):
            assert vendor in known

    def test_get_context_limit(self):
        """The context limit of "gpt-4o" is 128000."""
        assert ModelRegistry.get_context_limit("gpt-4o") == 128000

    def test_get_context_limit_unknown(self):
        """For an unknown model the supplied default is returned."""
        fallback = ModelRegistry.get_context_limit("unknown", default=32000)

        assert fallback == 32000

    def test_estimate_cost(self):
        """Cost of 1M input plus 0.5M output tokens on gpt-4o is about $7.50."""
        total = ModelRegistry.estimate_cost(
            model="gpt-4o",
            input_tokens=1000000,
            output_tokens=500000,
        )

        assert total is not None
        # $2.50 per 1M input, plus half of $10.00 per 1M output: 2.50 + 5.00.
        deviation = abs(total - 7.50)
        assert deviation < 0.01

    def test_estimate_cost_with_cache(self):
        """Cost estimation when a cached-token count is also supplied.

        Note: LiteLLM's basic cost estimation doesn't support cached token
        pricing, so cached_tokens is accepted but does not currently change
        the result.
        """
        total = ModelRegistry.estimate_cost(
            model="gpt-4o",
            input_tokens=1000000,
            output_tokens=0,
            cached_tokens=500000,  # Not currently used by LiteLLM
        )

        assert total is not None
        # The full 1M tokens are billed at the plain input rate: $2.50.
        deviation = abs(total - 2.50)
        assert deviation < 0.01

    def test_estimate_cost_unknown_model(self):
        """No estimate (None) is produced for an unknown model."""
        total = ModelRegistry.estimate_cost(
            model="unknown-model",
            input_tokens=1000,
            output_tokens=500,
        )

        assert total is None
class TestConvenienceFunctions:
    """Smoke tests for the module-level helper functions."""

    def test_get_model_info(self):
        """get_model_info returns the entry for a known model name."""
        entry = get_model_info("gpt-4o")

        assert entry is not None
        assert entry.name == "gpt-4o"

    def test_list_models(self):
        """list_models with a provider filter yields a non-empty result."""
        anthropic_models = list_models(provider="anthropic")

        assert len(anthropic_models) > 0

    def test_register_model(self):
        """register_model returns an entry carrying the registered name."""
        model_name = "test-function-model"
        registered = register_model(
            model_name,
            provider="test",
            context_window=16000,
        )

        assert registered.name == "test-function-model"
class TestBuiltInModels:
    """Tests for built-in model data.

    Every lookup is guarded with ``is not None`` before its attributes are
    read, so a model missing from the registry fails with a clear assertion
    instead of an ``AttributeError`` on ``None``. Prices are compared with
    ``pytest.approx`` rather than exact float equality.
    """

    def test_gpt4o_info(self):
        """Test GPT-4o model info and pricing."""
        info = get_model_info("gpt-4o")
        assert info is not None
        assert info.provider == "openai"
        assert info.context_window == 128000
        assert info.supports_tools is True
        assert info.supports_vision is True

        # Pricing is now fetched from LiteLLM, not stored in ModelInfo
        pricing = ModelRegistry.get_pricing("gpt-4o")
        assert pricing is not None
        # Tolerant comparison: exact == on float prices is brittle.
        assert pricing[0] == pytest.approx(2.50)  # input cost per 1M
        assert pricing[1] == pytest.approx(10.00)  # output cost per 1M

    def test_o1_info(self):
        """Test o1 model info."""
        info = get_model_info("o1")
        assert info is not None
        assert info.provider == "openai"
        assert info.context_window == 200000  # 200K context
        assert info.max_output_tokens == 100000  # 100K output

    def test_claude_info(self):
        """Test Claude model info and pricing, including a retired alias."""
        info = get_model_info("claude-3-5-sonnet-20241022")
        assert info is not None
        assert info.provider == "anthropic"
        assert info.context_window == 200000

        # Pricing fetched from LiteLLM (falls back to alias for retired models)
        pricing = ModelRegistry.get_pricing("claude-sonnet-4-20250514")
        assert pricing is not None
        assert pricing[0] == pytest.approx(3.00)  # input cost per 1M
        assert pricing[1] == pytest.approx(15.00)  # output cost per 1M

        # Retired model alias should also resolve
        alias_pricing = ModelRegistry.get_pricing("claude-3-5-sonnet-20241022")
        assert alias_pricing is not None

    def test_gemini_info(self):
        """Test Gemini model info."""
        info = get_model_info("gemini-1.5-pro")
        assert info is not None
        assert info.provider == "google"
        assert info.context_window == 2000000  # 2M tokens!

    def test_llama_info(self):
        """Test Llama model info."""
        info = get_model_info("llama-3.1-8b")
        assert info is not None
        assert info.provider == "meta"
        assert info.context_window == 128000
        assert info.tokenizer_backend == "huggingface"

    def test_mistral_info(self):
        """Test Mistral model info."""
        info = get_model_info("mistral-large")
        assert info is not None
        assert info.provider == "mistral"
        assert info.supports_tools is True