Spaces:
Paused
Paused
| import asyncio | |
| import json | |
| import os | |
| import sys | |
| from unittest.mock import AsyncMock, MagicMock, patch | |
| sys.path.insert( | |
| 0, os.path.abspath("../../..") | |
| ) # Adds the parent directory to the system path | |
| from datetime import datetime, timedelta | |
| import pytest | |
| import litellm | |
| from litellm.proxy._types import ( | |
| LiteLLM_ObjectPermissionTable, | |
| LiteLLM_TeamTable, | |
| LiteLLM_UserTable, | |
| LitellmUserRoles, | |
| ProxyErrorTypes, | |
| ProxyException, | |
| SSOUserDefinedValues, | |
| UserAPIKeyAuth, | |
| ) | |
| from litellm.proxy.auth.auth_checks import ( | |
| ExperimentalUIJWTToken, | |
| _can_object_call_vector_stores, | |
| get_user_object, | |
| vector_store_access_check, | |
| ) | |
| from litellm.proxy.common_utils.encrypt_decrypt_utils import decrypt_value_helper | |
| from litellm.utils import get_utc_datetime | |
| def set_salt_key(monkeypatch): | |
| """Automatically set LITELLM_SALT_KEY for all tests""" | |
| monkeypatch.setenv("LITELLM_SALT_KEY", "sk-1234") | |
| def valid_sso_user_defined_values(): | |
| return LiteLLM_UserTable( | |
| user_id="test_user", | |
| user_email="test@example.com", | |
| user_role=LitellmUserRoles.PROXY_ADMIN.value, | |
| models=["gpt-3.5-turbo"], | |
| max_budget=100.0, | |
| ) | |
| def invalid_sso_user_defined_values(): | |
| return LiteLLM_UserTable( | |
| user_id="test_user", | |
| user_email="test@example.com", | |
| user_role=None, # Missing user role | |
| models=["gpt-3.5-turbo"], | |
| max_budget=100.0, | |
| ) | |
| def test_get_experimental_ui_login_jwt_auth_token_valid(valid_sso_user_defined_values): | |
| """Test generating JWT token with valid user role""" | |
| token = ExperimentalUIJWTToken.get_experimental_ui_login_jwt_auth_token( | |
| valid_sso_user_defined_values | |
| ) | |
| # Decrypt and verify token contents | |
| decrypted_token = decrypt_value_helper(token, exception_type="debug") | |
| # Check that decrypted_token is not None before using json.loads | |
| assert decrypted_token is not None | |
| token_data = json.loads(decrypted_token) | |
| assert token_data["user_id"] == "test_user" | |
| assert token_data["user_role"] == LitellmUserRoles.PROXY_ADMIN.value | |
| assert token_data["models"] == ["gpt-3.5-turbo"] | |
| assert token_data["max_budget"] == litellm.max_ui_session_budget | |
| # Verify expiration time is set and valid | |
| assert "expires" in token_data | |
| expires = datetime.fromisoformat(token_data["expires"].replace("Z", "+00:00")) | |
| assert expires > get_utc_datetime() | |
| assert expires <= get_utc_datetime() + timedelta(minutes=10) | |
| def test_get_experimental_ui_login_jwt_auth_token_invalid( | |
| invalid_sso_user_defined_values, | |
| ): | |
| """Test generating JWT token with missing user role""" | |
| with pytest.raises(Exception) as exc_info: | |
| ExperimentalUIJWTToken.get_experimental_ui_login_jwt_auth_token( | |
| invalid_sso_user_defined_values | |
| ) | |
| assert str(exc_info.value) == "User role is required for experimental UI login" | |
| def test_get_key_object_from_ui_hash_key_valid( | |
| valid_sso_user_defined_values, monkeypatch | |
| ): | |
| """Test getting key object from valid UI hash key""" | |
| monkeypatch.setenv("EXPERIMENTAL_UI_LOGIN", "True") | |
| # Generate a valid token | |
| token = ExperimentalUIJWTToken.get_experimental_ui_login_jwt_auth_token( | |
| valid_sso_user_defined_values | |
| ) | |
| # Get key object | |
| key_object = ExperimentalUIJWTToken.get_key_object_from_ui_hash_key(token) | |
| assert key_object is not None | |
| assert key_object.user_id == "test_user" | |
| assert key_object.user_role == LitellmUserRoles.PROXY_ADMIN | |
| assert key_object.models == ["gpt-3.5-turbo"] | |
| assert key_object.max_budget == litellm.max_ui_session_budget | |
| def test_get_key_object_from_ui_hash_key_invalid(): | |
| """Test getting key object from invalid UI hash key""" | |
| # Test with invalid token | |
| key_object = ExperimentalUIJWTToken.get_key_object_from_ui_hash_key("invalid_token") | |
| assert key_object is None | |
| async def test_default_internal_user_params_with_get_user_object(monkeypatch): | |
| """Test that default_internal_user_params is used when creating a new user via get_user_object""" | |
| # Set up default_internal_user_params | |
| default_params = { | |
| "models": ["gpt-4", "claude-3-opus"], | |
| "max_budget": 200.0, | |
| "user_role": "internal_user", | |
| } | |
| monkeypatch.setattr(litellm, "default_internal_user_params", default_params) | |
| # Mock the necessary dependencies | |
| mock_prisma_client = MagicMock() | |
| mock_db = AsyncMock() | |
| mock_prisma_client.db = mock_db | |
| # Set up the user creation mock - create a complete user model that can be converted to a dict | |
| mock_user = MagicMock() | |
| mock_user.user_id = "new_test_user" | |
| mock_user.models = ["gpt-4", "claude-3-opus"] | |
| mock_user.max_budget = 200.0 | |
| mock_user.user_role = "internal_user" | |
| mock_user.organization_memberships = [] | |
| # Make the mock model_dump or dict method return appropriate data | |
| mock_user.dict = lambda: { | |
| "user_id": "new_test_user", | |
| "models": ["gpt-4", "claude-3-opus"], | |
| "max_budget": 200.0, | |
| "user_role": "internal_user", | |
| "organization_memberships": [], | |
| } | |
| # Setup the mock returns | |
| mock_prisma_client.db.litellm_usertable.find_unique = AsyncMock(return_value=None) | |
| mock_prisma_client.db.litellm_usertable.create = AsyncMock(return_value=mock_user) | |
| # Create a mock cache - use AsyncMock for async methods | |
| mock_cache = MagicMock() | |
| mock_cache.async_get_cache = AsyncMock(return_value=None) | |
| mock_cache.async_set_cache = AsyncMock() | |
| # Call get_user_object with user_id_upsert=True to trigger user creation | |
| try: | |
| user_obj = await get_user_object( | |
| user_id="new_test_user", | |
| prisma_client=mock_prisma_client, | |
| user_api_key_cache=mock_cache, | |
| user_id_upsert=True, | |
| proxy_logging_obj=None, | |
| ) | |
| except Exception as e: | |
| # this fails since the mock object is a MagicMock and not a LiteLLM_UserTable | |
| print(e) | |
| # Verify the user was created with the default params | |
| mock_prisma_client.db.litellm_usertable.create.assert_called_once() | |
| creation_args = mock_prisma_client.db.litellm_usertable.create.call_args[1]["data"] | |
| # Verify defaults were applied to the creation args | |
| assert "models" in creation_args | |
| assert creation_args["models"] == ["gpt-4", "claude-3-opus"] | |
| assert creation_args["max_budget"] == 200.0 | |
| assert creation_args["user_role"] == "internal_user" | |
| # Vector Store Auth Check Tests | |
| async def test_vector_store_access_check_early_returns( | |
| prisma_client, vector_store_registry, expected_result | |
| ): | |
| """Test vector_store_access_check returns True for early exit conditions""" | |
| request_body = {"messages": [{"role": "user", "content": "test"}]} | |
| if vector_store_registry: | |
| vector_store_registry.get_vector_store_ids_to_run.return_value = None | |
| with patch("litellm.proxy.proxy_server.prisma_client", prisma_client), patch( | |
| "litellm.vector_store_registry", vector_store_registry | |
| ): | |
| result = await vector_store_access_check( | |
| request_body=request_body, | |
| team_object=None, | |
| valid_token=None, | |
| ) | |
| assert result == expected_result | |
| def test_can_object_call_vector_stores_scenarios( | |
| object_permissions, vector_store_ids, should_raise, error_type | |
| ): | |
| """Test _can_object_call_vector_stores with various permission scenarios""" | |
| # Convert dict to object if not None | |
| if object_permissions is not None: | |
| mock_permissions = MagicMock() | |
| mock_permissions.vector_stores = object_permissions["vector_stores"] | |
| object_permissions = mock_permissions | |
| object_type = ( | |
| "key" | |
| if error_type == ProxyErrorTypes.key_vector_store_access_denied | |
| else "team" | |
| ) | |
| if should_raise: | |
| with pytest.raises(ProxyException) as exc_info: | |
| _can_object_call_vector_stores( | |
| object_type=object_type, | |
| vector_store_ids_to_run=vector_store_ids, | |
| object_permissions=object_permissions, | |
| ) | |
| assert exc_info.value.type == error_type | |
| else: | |
| result = _can_object_call_vector_stores( | |
| object_type=object_type, | |
| vector_store_ids_to_run=vector_store_ids, | |
| object_permissions=object_permissions, | |
| ) | |
| assert result is True | |
| async def test_vector_store_access_check_with_permissions(): | |
| """Test vector_store_access_check with actual permission checking""" | |
| request_body = {"tools": [{"type": "function", "function": {"name": "test"}}]} | |
| # Test with valid token that has access | |
| valid_token = UserAPIKeyAuth( | |
| token="test-token", | |
| object_permission_id="perm-123", | |
| models=["gpt-4"], | |
| max_budget=100.0, | |
| ) | |
| mock_prisma_client = MagicMock() | |
| mock_permissions = MagicMock() | |
| mock_permissions.vector_stores = ["store-1", "store-2"] | |
| mock_prisma_client.db.litellm_objectpermissiontable.find_unique = AsyncMock( | |
| return_value=mock_permissions | |
| ) | |
| mock_vector_store_registry = MagicMock() | |
| mock_vector_store_registry.get_vector_store_ids_to_run.return_value = ["store-1"] | |
| with patch("litellm.proxy.proxy_server.prisma_client", mock_prisma_client), patch( | |
| "litellm.vector_store_registry", mock_vector_store_registry | |
| ): | |
| result = await vector_store_access_check( | |
| request_body=request_body, | |
| team_object=None, | |
| valid_token=valid_token, | |
| ) | |
| assert result is True | |
| # Test with denied access | |
| mock_vector_store_registry.get_vector_store_ids_to_run.return_value = ["store-3"] | |
| with patch("litellm.proxy.proxy_server.prisma_client", mock_prisma_client), patch( | |
| "litellm.vector_store_registry", mock_vector_store_registry | |
| ): | |
| with pytest.raises(ProxyException) as exc_info: | |
| await vector_store_access_check( | |
| request_body=request_body, | |
| team_object=None, | |
| valid_token=valid_token, | |
| ) | |
| assert exc_info.value.type == ProxyErrorTypes.key_vector_store_access_denied | |