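# NOTE(review): "Spaces: Paused" appears to be page-header text captured when this file was copied from a web view; it is not part of the test module.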
import asyncio
import json
import os
import sys
import uuid
from typing import Optional, cast
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from fastapi import HTTPException
from fastapi.testclient import TestClient

sys.path.insert(
    0, os.path.abspath("../../../")
)  # Adds the parent directory to the system path

from litellm.proxy._types import UserAPIKeyAuth  # Import UserAPIKeyAuth
from litellm.proxy._types import (
    LiteLLM_OrganizationTable,
    LiteLLM_TeamTable,
    LitellmUserRoles,
)
from litellm.proxy.management_endpoints.team_endpoints import (
    user_api_key_auth,  # Assuming this dependency is needed
)
from litellm.proxy.management_endpoints.team_endpoints import (
    GetTeamMemberPermissionsResponse,
    UpdateTeamMemberPermissionsRequest,
    router,
    validate_team_org_change,
)
from litellm.proxy.management_helpers.team_member_permission_checks import (
    TeamMemberPermissionChecks,
)
from litellm.proxy.proxy_server import app
from litellm.router import Router
# Setup TestClient
# Shared FastAPI test client bound to the proxy `app`; the endpoint tests in this
# module issue their HTTP requests through it.
client = TestClient(app)

# Mock prisma_client
# Module-level stand-in for the Prisma client. `mock_db_client` patches it into
# `litellm.proxy.proxy_server.prisma_client` and resets its call records afterwards.
mock_prisma_client = MagicMock()
# Fixture to provide the mock prisma client
@pytest.fixture
def mock_db_client():
    """Yield the shared mock Prisma client, patched into the proxy server module.

    While the requesting test runs, `litellm.proxy.proxy_server.prisma_client`
    is replaced with the module-level `mock_prisma_client`. Once the test is
    done the patch is undone and the mock's call records are reset so that
    call-count assertions in later tests start from zero.

    Yields:
        The module-level `mock_prisma_client` object.
    """
    with patch(
        "litellm.proxy.proxy_server.prisma_client", mock_prisma_client
    ):  # Mock in both places if necessary
        yield mock_prisma_client
    mock_prisma_client.reset_mock()
# Fixture to provide a mock admin user auth object
@pytest.fixture
def mock_admin_auth():
    """Return a `UserAPIKeyAuth` whose role is `LitellmUserRoles.PROXY_ADMIN`."""
    return UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN)
# Test for validate_team_org_change when organization IDs match
@pytest.mark.asyncio
async def test_validate_team_org_change_same_org_id():
    """
    Test that validate_team_org_change returns True without performing any checks
    when the team and organization have the same organization_id.

    This covers a user-reported issue: a user was editing their team and this
    function raised an exception even though they were not changing the
    organization.
    """
    # Create mock team and organization with same org ID
    org_id = "test-org-123"

    # Mock team
    team = MagicMock(spec=LiteLLM_TeamTable)
    team.organization_id = org_id
    team.models = ["gpt-4", "claude-2"]
    team.max_budget = 100.0
    team.tpm_limit = 1000
    team.rpm_limit = 100
    team.members_with_roles = []

    # Mock organization. Its limits are deliberately lower than the team's so
    # that the test only passes if the limit checks are skipped.
    organization = MagicMock(spec=LiteLLM_OrganizationTable)
    organization.organization_id = org_id
    organization.models = []
    organization.litellm_budget_table = MagicMock()
    organization.litellm_budget_table.max_budget = (
        50.0  # This would normally fail validation
    )
    organization.litellm_budget_table.tpm_limit = (
        500  # This would normally fail validation
    )
    organization.litellm_budget_table.rpm_limit = (
        50  # This would normally fail validation
    )
    organization.users = []

    # Mock Router
    mock_router = MagicMock(spec=Router)

    # Use patch to ensure the model access check is never called
    with patch(
        "litellm.proxy.management_endpoints.team_endpoints.can_org_access_model"
    ) as mock_access_check:
        result = validate_team_org_change(
            team=team, organization=organization, llm_router=mock_router
        )

        # Assert the function returns True without checking anything
        assert result is True
        mock_access_check.assert_not_called()  # Ensure access check wasn't called
# Test for /team/permissions_list endpoint (GET)
@pytest.mark.asyncio
async def test_get_team_permissions_list_success(mock_db_client, mock_admin_auth):
    """
    Test successful retrieval of team member permissions.

    `get_team_object` is patched to return a mock team row, the auth dependency
    is overridden with an admin user, and the response is checked for the team
    id, the team's permissions, and the full list of available permissions.
    """
    test_team_id = "test-team-123"
    permissions = ["/key/generate", "/key/update"]
    mock_team_data = {
        "team_id": test_team_id,
        "team_alias": "Test Team",
        "team_member_permissions": permissions,
        "spend": 0.0,
    }
    mock_team_row = MagicMock()
    mock_team_row.model_dump.return_value = mock_team_data

    # Set attributes directly on the mock object
    mock_team_row.team_id = test_team_id
    mock_team_row.team_alias = "Test Team"
    mock_team_row.team_member_permissions = permissions
    mock_team_row.spend = 0.0

    # Mock the get_team_object function used in the endpoint
    with patch(
        "litellm.proxy.management_endpoints.team_endpoints.get_team_object",
        new_callable=AsyncMock,
        return_value=mock_team_row,
    ):
        # Override the dependency for this test
        app.dependency_overrides[user_api_key_auth] = lambda: mock_admin_auth
        try:
            response = client.get(f"/team/permissions_list?team_id={test_team_id}")

            assert response.status_code == 200
            response_data = response.json()
            assert response_data["team_id"] == test_team_id
            assert (
                response_data["team_member_permissions"]
                == mock_team_data["team_member_permissions"]
            )
            assert (
                response_data["all_available_permissions"]
                == TeamMemberPermissionChecks.get_all_available_team_member_permissions()
            )
        finally:
            # Clean up dependency override. Done in `finally` so a failing
            # assertion cannot leak the admin override into later tests.
            app.dependency_overrides = {}
# Test for /team/permissions_update endpoint (POST)
@pytest.mark.asyncio
async def test_update_team_permissions_success(mock_db_client, mock_admin_auth):
    """
    Test successful update of team member permissions by an admin.

    `get_team_object` is patched to return the existing team row and the
    team-table `update` call on the mock Prisma client returns the updated row.
    The test checks the HTTP response body and that `update` was awaited exactly
    once with the new permissions.
    """
    test_team_id = "test-team-456"
    update_permissions = ["/key/generate", "/key/update"]
    update_payload = {
        "team_id": test_team_id,
        "team_member_permissions": update_permissions,
    }

    existing_permissions = ["/key/list"]
    mock_existing_team_data = {
        "team_id": test_team_id,
        "team_alias": "Existing Team",
        "team_member_permissions": existing_permissions,
        "spend": 0.0,
        "models": [],
    }
    mock_updated_team_data = {
        **mock_existing_team_data,
        "team_member_permissions": update_payload["team_member_permissions"],
    }

    mock_existing_team_row = MagicMock(spec=LiteLLM_TeamTable)
    mock_existing_team_row.model_dump.return_value = mock_existing_team_data
    # Set attributes directly on the existing team mock
    mock_existing_team_row.team_id = test_team_id
    mock_existing_team_row.team_alias = "Existing Team"
    mock_existing_team_row.team_member_permissions = existing_permissions
    mock_existing_team_row.spend = 0.0
    mock_existing_team_row.models = []

    mock_updated_team_row = MagicMock(spec=LiteLLM_TeamTable)
    mock_updated_team_row.model_dump.return_value = mock_updated_team_data
    # Set attributes directly on the updated team mock
    mock_updated_team_row.team_id = test_team_id
    mock_updated_team_row.team_alias = "Existing Team"
    mock_updated_team_row.team_member_permissions = update_permissions
    mock_updated_team_row.spend = 0.0
    mock_updated_team_row.models = []

    # Mock the get_team_object function used in the endpoint
    with patch(
        "litellm.proxy.management_endpoints.team_endpoints.get_team_object",
        new_callable=AsyncMock,
        return_value=mock_existing_team_row,
    ):
        # Mock the database update function
        mock_db_client.db.litellm_teamtable.update = AsyncMock(
            return_value=mock_updated_team_row
        )

        # Override the dependency for this test
        app.dependency_overrides[user_api_key_auth] = lambda: mock_admin_auth
        try:
            response = client.post("/team/permissions_update", json=update_payload)

            assert response.status_code == 200
            response_data = response.json()

            # Use model_dump for comparison if the endpoint returns the Prisma model directly
            assert response_data == mock_updated_team_row.model_dump()

            mock_db_client.db.litellm_teamtable.update.assert_awaited_once_with(
                where={"team_id": test_team_id},
                data={
                    "team_member_permissions": update_payload[
                        "team_member_permissions"
                    ]
                },
            )
        finally:
            # Clean up dependency override. Done in `finally` so a failing
            # assertion cannot leak the admin override into later tests.
            app.dependency_overrides = {}
@pytest.mark.asyncio
async def test_new_team_with_object_permission(mock_db_client, mock_admin_auth):
    """Ensure /team/new correctly handles `object_permission` by

    1. Creating a record in litellm_objectpermissiontable
    2. Passing the returned `object_permission_id` into the team insert payload
    """
    # --- Configure mocked prisma client ---
    # Helper identity converters used by team logic
    mock_db_client.jsonify_team_object = lambda db_data: db_data  # type: ignore
    mock_db_client.get_data = AsyncMock(return_value=None)
    mock_db_client.update_data = AsyncMock(return_value=MagicMock())

    # Mock DB structure under prisma_client.db
    mock_db_client.db = MagicMock()

    # 1. Mock object permission table creation
    mock_object_perm_create = AsyncMock(
        return_value=MagicMock(object_permission_id="objperm123")
    )
    mock_db_client.db.litellm_objectpermissiontable = MagicMock()
    mock_db_client.db.litellm_objectpermissiontable.create = mock_object_perm_create

    # 2. Mock model table creation (may be skipped but provided for safety)
    mock_db_client.db.litellm_modeltable = MagicMock()
    mock_db_client.db.litellm_modeltable.create = AsyncMock(
        return_value=MagicMock(id="model123")
    )

    # 3. Capture team table creation
    team_create_result = MagicMock(
        team_id="team-456",
        object_permission_id="objperm123",
    )
    team_create_result.model_dump.return_value = {
        "team_id": "team-456",
        "object_permission_id": "objperm123",
    }
    mock_team_create = AsyncMock(return_value=team_create_result)
    mock_db_client.db.litellm_teamtable = MagicMock()
    mock_db_client.db.litellm_teamtable.create = mock_team_create

    # 4. Mock user table update behaviour (called for each member)
    mock_db_client.db.litellm_usertable = MagicMock()
    mock_db_client.db.litellm_usertable.update = AsyncMock(return_value=MagicMock())

    # --- Import after mocks applied ---
    from fastapi import Request

    from litellm.proxy._types import LiteLLM_ObjectPermissionBase, NewTeamRequest
    from litellm.proxy.management_endpoints.team_endpoints import new_team

    # Build request objects
    team_request = NewTeamRequest(
        team_alias="my-team",
        object_permission=LiteLLM_ObjectPermissionBase(vector_stores=["my-vector"]),
    )

    # Pass a dummy FastAPI Request object
    dummy_request = MagicMock(spec=Request)

    # Execute the endpoint function
    await new_team(
        data=team_request,
        http_request=dummy_request,
        user_api_key_dict=mock_admin_auth,
    )

    # --- Assertions ---
    # 1. Object permission creation should be called exactly once
    mock_object_perm_create.assert_awaited_once()

    # 2. Team creation payload should include the generated object_permission_id
    assert mock_team_create.call_count == 1
    created_team_kwargs = mock_team_create.call_args.kwargs
    assert created_team_kwargs["data"].get("object_permission_id") == "objperm123"
@pytest.mark.asyncio
async def test_team_update_object_permissions_existing_permission(monkeypatch):
    """
    Test updating object permissions when a team already has an existing object_permission_id.

    This test verifies that when updating vector stores for a team that already has an
    object_permission_id, the existing LiteLLM_ObjectPermissionTable record is updated
    with the new permissions and the object_permission_id remains the same.
    """
    from litellm.proxy._types import LiteLLM_ObjectPermissionBase, LiteLLM_TeamTable
    from litellm.proxy.management_endpoints.team_endpoints import (
        handle_update_object_permission,
    )

    # Mock prisma client (named differently from the module-level mock to avoid
    # shadowing it).
    prisma_mock = AsyncMock()
    monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", prisma_mock)
    permission_table = prisma_mock.db.litellm_objectpermissiontable

    # Mock existing team with object_permission_id
    existing_team_row = LiteLLM_TeamTable(
        team_id="test_team_id",
        object_permission_id="existing_perm_id_123",
        team_alias="test_team",
    )

    # Mock existing object permission record
    existing_object_permission = MagicMock()
    existing_object_permission.model_dump.return_value = {
        "object_permission_id": "existing_perm_id_123",
        "vector_stores": ["old_store_1", "old_store_2"],
    }
    permission_table.find_unique = AsyncMock(return_value=existing_object_permission)

    # Mock upsert operation
    updated_permission = MagicMock()
    updated_permission.object_permission_id = "existing_perm_id_123"
    permission_table.upsert = AsyncMock(return_value=updated_permission)

    # Test data with new object permission
    data_json = {
        "object_permission": LiteLLM_ObjectPermissionBase(
            vector_stores=["new_store_1", "new_store_2", "new_store_3"]
        ).model_dump(exclude_unset=True, exclude_none=True),
        "team_alias": "updated_team",
    }

    # Call the function
    result = await handle_update_object_permission(
        data_json=data_json,
        existing_team_row=existing_team_row,
    )

    # Verify the object_permission was removed from data_json and object_permission_id was set
    assert "object_permission" not in result
    assert result["object_permission_id"] == "existing_perm_id_123"

    # Verify database operations were called correctly
    permission_table.find_unique.assert_called_once_with(
        where={"object_permission_id": "existing_perm_id_123"}
    )
    permission_table.upsert.assert_called_once()
@pytest.mark.asyncio
async def test_team_update_object_permissions_no_existing_permission(monkeypatch):
    """
    Test creating object permissions when a team has no existing object_permission_id.

    This test verifies that when updating object permissions for a team that has
    object_permission_id set to None, a new entry is created in the
    LiteLLM_ObjectPermissionTable and the team is updated with the new object_permission_id.
    """
    from litellm.proxy._types import LiteLLM_ObjectPermissionBase, LiteLLM_TeamTable
    from litellm.proxy.management_endpoints.team_endpoints import (
        handle_update_object_permission,
    )

    # Mock prisma client (named differently from the module-level mock to avoid
    # shadowing it).
    prisma_mock = AsyncMock()
    monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", prisma_mock)
    permission_table = prisma_mock.db.litellm_objectpermissiontable

    existing_team_row_no_perm = LiteLLM_TeamTable(
        team_id="test_team_id_2",
        object_permission_id=None,
        team_alias="test_team_2",
    )

    # Mock find_unique to return None (no existing permission)
    permission_table.find_unique = AsyncMock(return_value=None)

    # Mock upsert to create new record
    new_permission = MagicMock()
    new_permission.object_permission_id = "new_perm_id_456"
    permission_table.upsert = AsyncMock(return_value=new_permission)

    data_json = {
        "object_permission": LiteLLM_ObjectPermissionBase(
            vector_stores=["brand_new_store"]
        ).model_dump(exclude_unset=True, exclude_none=True),
        "team_alias": "updated_team_2",
    }

    result = await handle_update_object_permission(
        data_json=data_json,
        existing_team_row=existing_team_row_no_perm,
    )

    # Verify new object_permission_id was set
    assert "object_permission" not in result
    assert result["object_permission_id"] == "new_perm_id_456"

    # Verify upsert was called to create new record
    permission_table.upsert.assert_called_once()
@pytest.mark.asyncio
async def test_team_update_object_permissions_missing_permission_record(monkeypatch):
    """
    Test creating object permissions when existing object_permission_id record is not found.

    This test verifies that when updating object permissions for a team that has an
    object_permission_id but the corresponding record cannot be found in the database,
    a new entry is created in the LiteLLM_ObjectPermissionTable with the new permissions.
    """
    from litellm.proxy._types import LiteLLM_ObjectPermissionBase, LiteLLM_TeamTable
    from litellm.proxy.management_endpoints.team_endpoints import (
        handle_update_object_permission,
    )

    # Mock prisma client (named differently from the module-level mock to avoid
    # shadowing it).
    prisma_mock = AsyncMock()
    monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", prisma_mock)
    permission_table = prisma_mock.db.litellm_objectpermissiontable

    existing_team_row_missing_perm = LiteLLM_TeamTable(
        team_id="test_team_id_3",
        object_permission_id="missing_perm_id_789",
        team_alias="test_team_3",
    )

    # Mock find_unique to return None (permission record not found)
    permission_table.find_unique = AsyncMock(return_value=None)

    # Mock upsert to create new record
    new_permission = MagicMock()
    new_permission.object_permission_id = "recreated_perm_id_789"
    permission_table.upsert = AsyncMock(return_value=new_permission)

    data_json = {
        "object_permission": LiteLLM_ObjectPermissionBase(
            vector_stores=["recreated_store"]
        ).model_dump(exclude_unset=True, exclude_none=True),
        "team_alias": "updated_team_3",
    }

    result = await handle_update_object_permission(
        data_json=data_json,
        existing_team_row=existing_team_row_missing_perm,
    )

    # Verify new object_permission_id was set
    assert "object_permission" not in result
    assert result["object_permission_id"] == "recreated_perm_id_789"

    # Verify find_unique was called with the missing permission ID
    permission_table.find_unique.assert_called_once_with(
        where={"object_permission_id": "missing_perm_id_789"}
    )

    # Verify upsert was called to create new record
    permission_table.upsert.assert_called_once()