Spaces:
Paused
Paused
| # stdlib imports | |
| import json | |
| import os | |
| import time | |
| from unittest.mock import patch | |
| import pytest | |
| # third party imports | |
| from click.testing import CliRunner | |
| # local imports | |
| from litellm.proxy.client.cli import cli | |
| from litellm.proxy.client.cli.commands.models import ( | |
| format_cost_per_1k_tokens, | |
| format_iso_datetime_str, | |
| format_timestamp, | |
| ) | |
def mock_client():
    """Yield a mock standing in for the ``Client`` class used by the models commands.

    The patch on ``litellm.proxy.client.cli.commands.models.Client`` is active
    only while the generator is suspended at its ``yield``; it is undone when
    the generator is resumed and finishes.

    NOTE(review): no ``@pytest.fixture`` decorator is visible on this function
    in the reviewed text, yet tests request ``mock_client`` by name -- confirm
    that it is registered as a fixture.
    """
    client_patcher = patch("litellm.proxy.client.cli.commands.models.Client")
    with client_patcher as MockClient:
        yield MockClient
def cli_runner():
    """Return a fresh Click ``CliRunner`` for invoking the CLI in-process.

    NOTE(review): no ``@pytest.fixture`` decorator is visible on this function
    in the reviewed text, yet tests request ``cli_runner`` by name -- confirm
    that it is registered as a fixture.
    """
    runner = CliRunner()
    return runner
def mock_env():
    """Expose the proxy URL and API key through ``os.environ`` while suspended.

    ``patch.dict`` restores the previous environment once the generator is
    resumed and finishes.

    NOTE(review): tests assert these values without requesting ``mock_env``,
    so this is presumably meant to be an autouse fixture; no decorator is
    visible in the reviewed text -- confirm.
    """
    proxy_settings = {
        "LITELLM_PROXY_URL": "http://localhost:4000",
        "LITELLM_PROXY_API_KEY": "sk-test",
    }
    env_patch = patch.dict(os.environ, proxy_settings)
    with env_patch:
        yield
def mock_models_list(mock_client):
    """Configure ``mock_client`` so that ``models.list()`` returns two models.

    The payload is installed on ``mock_client.return_value.models.list`` and
    the same ``mock_client`` object is returned, after checking that it has
    not been called yet.

    NOTE(review): no ``@pytest.fixture`` decorator is visible on this function
    in the reviewed text -- confirm that it is registered as a fixture.
    """
    first_model = {
        "id": "model-123",
        "object": "model",
        "created": 1699848889,
        "owned_by": "organization-123",
    }
    second_model = {
        "id": "model-456",
        "object": "model",
        "created": 1699848890,
        "owned_by": "organization-456",
    }
    models_api = mock_client.return_value.models
    models_api.list.return_value = [first_model, second_model]

    # Start from a clean slate: nothing may have instantiated the client yet.
    mock_client.assert_not_called()
    return mock_client
def mock_models_info(mock_client):
    """Configure ``mock_client`` so that ``models.info()`` returns one model.

    The single ``gpt-4`` entry is installed on
    ``mock_client.return_value.models.info`` and the same ``mock_client``
    object is returned, after checking that it has not been called yet.

    NOTE(review): no ``@pytest.fixture`` decorator is visible on this function
    in the reviewed text -- confirm that it is registered as a fixture.
    """
    gpt4_entry = {
        "model_name": "gpt-4",
        "litellm_params": {"model": "gpt-4", "litellm_credential_name": "openai-1"},
        "model_info": {
            "id": "model-123",
            "created_at": "2025-04-29T21:31:43.843000+00:00",
            "updated_at": "2025-04-29T21:31:43.843000+00:00",
            "input_cost_per_token": 0.00001,
            "output_cost_per_token": 0.00002,
        },
    }
    models_api = mock_client.return_value.models
    models_api.info.return_value = [gpt4_entry]

    # Nothing may have instantiated the client before the test body runs.
    mock_client.assert_not_called()
    return mock_client
def force_utc_tz():
    """Force the process time zone to UTC while the generator is suspended.

    Before the ``yield`` the ``TZ`` environment variable is set to ``"UTC"``;
    afterwards the previous value is put back, or ``TZ`` is removed if it was
    not set before. ``time.tzset()`` is called after each change when the
    platform provides it.

    The restoration lives in a ``finally`` clause so that it also runs when
    the generator is closed or an exception is thrown into it at the
    ``yield``, not only when it is resumed normally.

    NOTE(review): no ``@pytest.fixture`` decorator is visible on this function
    in the reviewed text, yet ``test_format_timestamp`` requests it by name --
    confirm that it is registered as a fixture.
    """
    previous_tz = os.environ.get("TZ")
    os.environ["TZ"] = "UTC"
    # time.tzset() is not available on every platform, hence the guard.
    if hasattr(time, "tzset"):
        time.tzset()
    try:
        yield
    finally:
        # Put the environment back exactly as it was found.
        if previous_tz is None:
            os.environ.pop("TZ", None)
        else:
            os.environ["TZ"] = previous_tz
        if hasattr(time, "tzset"):
            time.tzset()
def test_models_list_json_format(mock_models_list, cli_runner):
    """``models list --format json`` prints the mocked model list as JSON."""
    argv = ["models", "list", "--format", "json"]
    result = cli_runner.invoke(cli, argv)

    # The command must exit cleanly.
    assert result.exit_code == 0

    # The decoded output must equal the payload configured on the mock.
    client_instance = mock_models_list.return_value
    assert json.loads(result.output) == client_instance.models.list.return_value

    # One client built with the expected settings, and one list() call on it.
    mock_models_list.assert_called_once_with(
        base_url="http://localhost:4000", api_key="sk-test"
    )
    client_instance.models.list.assert_called_once()
def test_models_list_table_format(mock_models_list, cli_runner):
    """``models list`` without ``--format`` renders the models as a table."""
    result = cli_runner.invoke(cli, ["models", "list"])

    # The command must exit cleanly.
    assert result.exit_code == 0

    # Column headers followed by values taken from the first mocked model.
    expected_fragments = (
        "ID",
        "Object",
        "Created",
        "Owned By",
        "model-123",
        "organization-123",
    )
    for fragment in expected_fragments:
        assert fragment in result.output
    # The creation time is rendered through the same helper the command uses.
    assert format_timestamp(1699848889) in result.output

    # One client built with the expected settings, and one list() call on it.
    client_instance = mock_models_list.return_value
    mock_models_list.assert_called_once_with(
        base_url="http://localhost:4000", api_key="sk-test"
    )
    client_instance.models.list.assert_called_once()
def test_models_list_with_base_url(mock_models_list, cli_runner):
    """``--base-url`` on the command line takes precedence over the env var."""
    custom_base_url = "http://custom.server:8000"
    argv = ["--base-url", custom_base_url, "models", "list"]

    result = cli_runner.invoke(cli, argv)

    # The command must exit cleanly.
    assert result.exit_code == 0

    # The base URL comes from the flag; the API key still comes from the env var.
    expected_kwargs = {"base_url": custom_base_url, "api_key": "sk-test"}
    mock_models_list.assert_called_once_with(**expected_kwargs)
def test_models_list_with_api_key(mock_models_list, cli_runner):
    """``--api-key`` on the command line takes precedence over the env var."""
    custom_api_key = "custom-test-key"
    argv = ["--api-key", custom_api_key, "models", "list"]

    result = cli_runner.invoke(cli, argv)

    # The command must exit cleanly.
    assert result.exit_code == 0

    # The API key comes from the flag; the base URL still comes from the env var.
    expected_kwargs = {"base_url": "http://localhost:4000", "api_key": custom_api_key}
    mock_models_list.assert_called_once_with(**expected_kwargs)
def test_models_list_error_handling(mock_client, cli_runner):
    """A failure raised by ``models.list()`` makes the command exit non-zero."""
    # Make the mocked client's list() call blow up.
    list_mock = mock_client.return_value.models.list
    list_mock.side_effect = Exception("API Error")

    result = cli_runner.invoke(cli, ["models", "list"])

    # The command must fail and the captured exception must carry the message.
    assert result.exit_code != 0
    assert "API Error" in str(result.exception)

    # The client was still constructed once, from the env var values.
    mock_client.assert_called_once_with(
        base_url="http://localhost:4000", api_key="sk-test"
    )
def test_models_info_json_format(mock_models_info, cli_runner):
    """``models info --format json`` prints the mocked model info as JSON."""
    argv = ["models", "info", "--format", "json"]
    result = cli_runner.invoke(cli, argv)

    # The command must exit cleanly.
    assert result.exit_code == 0

    # The decoded output must equal the payload configured on the mock.
    client_instance = mock_models_info.return_value
    assert json.loads(result.output) == client_instance.models.info.return_value

    # One client built from the env var values, and one info() call on it.
    mock_models_info.assert_called_once_with(
        base_url="http://localhost:4000", api_key="sk-test"
    )
    client_instance.models.info.assert_called_once()
def test_models_info_table_format(mock_models_info, cli_runner):
    """``models info`` with default columns renders a minute-precision table."""
    result = cli_runner.invoke(cli, ["models", "info"])

    # The command must exit cleanly.
    assert result.exit_code == 0

    # Headers, the model name and the timestamp truncated to minutes.
    must_appear = (
        "Public Model",
        "Upstream Model",
        "Updated At",
        "gpt-4",
        "2025-04-29 21:31",
    )
    for fragment in must_appear:
        assert fragment in result.output

    # Seconds and microseconds must not be shown.
    must_not_appear = ("21:31:43", "843000")
    for fragment in must_not_appear:
        assert fragment not in result.output

    # One client built from the env var values, and one info() call on it.
    client_instance = mock_models_info.return_value
    mock_models_info.assert_called_once_with(
        base_url="http://localhost:4000", api_key="sk-test"
    )
    client_instance.models.info.assert_called_once()
def test_models_import_only_models_matching_regex(tmp_path, mock_client, cli_runner):
    """``models import --only-models-matching-regex`` imports only matching models.

    A four-model YAML config is written under ``tmp_path`` and imported with
    the regex ``gpt``. Only the entries whose ``litellm_params.model`` contains
    ``gpt`` may reach ``client.models.new``.
    """
    import yaml as pyyaml

    # A mix of models: two match the regex, two do not.
    yaml_content = {
        "model_list": [
            {
                "model_name": "gpt-4-model",
                "litellm_params": {"model": "gpt-4"},
                "model_info": {"id": "id-1"},
            },
            {
                "model_name": "gpt-3.5-model",
                "litellm_params": {"model": "gpt-3.5-turbo"},
                "model_info": {"id": "id-2"},
            },
            {
                "model_name": "llama2-model",
                "litellm_params": {"model": "llama2"},
                "model_info": {"id": "id-3"},
            },
            {
                "model_name": "other-model",
                "litellm_params": {"model": "other"},
                "model_info": {"id": "id-4"},
            },
        ]
    }
    yaml_file = tmp_path / "models.yaml"
    # An explicit encoding keeps the file independent of the locale default.
    yaml_file.write_text(pyyaml.safe_dump(yaml_content), encoding="utf-8")

    # client.models.new is the call whose invocations are inspected below.
    mock_new = mock_client.return_value.models.new

    result = cli_runner.invoke(
        cli, ["models", "import", str(yaml_file), "--only-models-matching-regex", "gpt"]
    )

    # The command must exit cleanly.
    assert result.exit_code == 0

    # Only the two gpt models may have been imported.
    calls = [call.kwargs["model_params"]["model"] for call in mock_new.call_args_list]
    assert set(calls) == {"gpt-4", "gpt-3.5-turbo"}
    # Implied by the set comparison, kept as explicit documentation.
    assert "llama2" not in calls
    assert "other" not in calls

    # The summary must mention the imported models. The previous form OR-ed
    # two expressions that both reduce to this same substring check.
    assert "gpt" in result.output
def test_models_import_only_access_groups_matching_regex(
    tmp_path, mock_client, cli_runner
):
    """``models import --only-access-groups-matching-regex`` filters by access group.

    A five-model YAML config is written under ``tmp_path`` and imported with
    the regex ``beta``. Only the entries whose ``model_info.access_groups``
    contains a group matching ``beta`` may reach ``client.models.new``; the
    entry without any ``access_groups`` must be skipped as well.
    """
    import yaml as pyyaml

    # Two entries carry "beta-models"; the rest must be filtered out.
    yaml_content = {
        "model_list": [
            {
                "model_name": "gpt-4-model",
                "litellm_params": {"model": "gpt-4"},
                "model_info": {
                    "id": "id-1",
                    "access_groups": ["beta-models", "prod-models"],
                },
            },
            {
                "model_name": "gpt-3.5-model",
                "litellm_params": {"model": "gpt-3.5-turbo"},
                "model_info": {"id": "id-2", "access_groups": ["alpha-models"]},
            },
            {
                "model_name": "llama2-model",
                "litellm_params": {"model": "llama2"},
                "model_info": {"id": "id-3", "access_groups": ["beta-models"]},
            },
            {
                "model_name": "other-model",
                "litellm_params": {"model": "other"},
                "model_info": {"id": "id-4", "access_groups": ["other-group"]},
            },
            {
                "model_name": "no-access-group-model",
                "litellm_params": {"model": "no-access"},
                "model_info": {"id": "id-5"},
            },
        ]
    }
    yaml_file = tmp_path / "models.yaml"
    # An explicit encoding keeps the file independent of the locale default.
    yaml_file.write_text(pyyaml.safe_dump(yaml_content), encoding="utf-8")

    # client.models.new is the call whose invocations are inspected below.
    mock_new = mock_client.return_value.models.new

    result = cli_runner.invoke(
        cli,
        [
            "models",
            "import",
            str(yaml_file),
            "--only-access-groups-matching-regex",
            "beta",
        ],
    )

    # The command must exit cleanly.
    assert result.exit_code == 0

    # Only the two models tagged with 'beta-models' may have been imported.
    calls = [call.kwargs["model_params"]["model"] for call in mock_new.call_args_list]
    assert set(calls) == {"gpt-4", "llama2"}
    # Implied by the set comparison, kept as explicit documentation.
    assert "gpt-3.5-turbo" not in calls
    assert "other" not in calls
    assert "no-access" not in calls

    # The summary must mention the imported models. The previous form OR-ed
    # two expressions that both reduce to this same substring check.
    assert "gpt" in result.output
def test_format_iso_datetime_str(input_str, expected):
    """``format_iso_datetime_str(input_str)`` must equal ``expected``.

    NOTE(review): no parametrize decorator or fixture supplying ``input_str``
    and ``expected`` is visible in the reviewed text -- confirm how these
    arguments are provided.
    """
    actual = format_iso_datetime_str(input_str)
    assert actual == expected
def test_format_timestamp(input_val, expected, force_utc_tz):
    """``format_timestamp(input_val)`` must equal ``expected``.

    ``force_utc_tz`` is requested only for its side effect of forcing the
    time zone to UTC; its value is not used.

    NOTE(review): no parametrize decorator or fixture supplying ``input_val``
    and ``expected`` is visible in the reviewed text -- confirm how these
    arguments are provided.
    """
    actual = format_timestamp(input_val)
    mismatch = actual != expected
    if mismatch:
        # Print the offending triple before the assertion fails.
        print(f"input: {input_val}, expected: {expected}, actual: {actual}")
    assert actual == expected
def test_format_cost_per_1k_tokens(input_val, expected):
    """``format_cost_per_1k_tokens(input_val)`` must equal ``expected``.

    NOTE(review): no parametrize decorator or fixture supplying ``input_val``
    and ``expected`` is visible in the reviewed text -- confirm how these
    arguments are provided.
    """
    actual = format_cost_per_1k_tokens(input_val)
    mismatch = actual != expected
    if mismatch:
        # Print the offending triple before the assertion fails.
        print(f"input: {input_val}, expected: {expected}, actual: {actual}")
    assert actual == expected