Spaces:
Build error
Build error
| from typing import Optional | |
| import io | |
| import base64 | |
| import json | |
| import asyncio | |
| import logging | |
| import posixpath | |
| from urllib.parse import unquote | |
| from open_webui.models.groups import Groups | |
| from open_webui.models.models import ( | |
| ModelForm, | |
| ModelMeta, | |
| ModelModel, | |
| ModelParams, | |
| ModelResponse, | |
| ModelListResponse, | |
| ModelAccessListResponse, | |
| ModelAccessResponse, | |
| Models, | |
| ) | |
| from open_webui.models.access_grants import AccessGrants | |
| from pydantic import BaseModel | |
| from open_webui.constants import ERROR_MESSAGES | |
| from fastapi import ( | |
| APIRouter, | |
| Depends, | |
| HTTPException, | |
| Request, | |
| status, | |
| Response, | |
| ) | |
| from fastapi.responses import RedirectResponse, StreamingResponse | |
| from open_webui.utils.auth import get_admin_user, get_verified_user | |
| from open_webui.utils.access_control import has_permission, filter_allowed_access_grants | |
| from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL | |
| from open_webui.internal.db import get_async_session | |
| from sqlalchemy.ext.asyncio import AsyncSession | |
# Module-level logger used by the handlers in this file.
log = logging.getLogger(__name__)
# APIRouter for this module's model endpoints (presumably the handlers below
# are registered on it and the router is mounted elsewhere — verify).
router = APIRouter()
| def _safe_static_redirect_path(url: str) -> Optional[str]: | |
| """ | |
| If url is a same-origin static asset path, return a normalized path safe for | |
| RedirectResponse Location. Otherwise None (caller should fall back to default). | |
| Rejects traversal (..), encoded dots, query/fragment, and non-/static targets. | |
| """ | |
| if not url or not isinstance(url, str): | |
| return None | |
| path = url.split('?', 1)[0].split('#', 1)[0].strip() | |
| for _ in range(2): | |
| decoded = unquote(path) | |
| if decoded == path: | |
| break | |
| path = decoded | |
| if '\x00' in path or '\\' in path: | |
| return None | |
| if not path.startswith('/'): | |
| return None | |
| normalized = posixpath.normpath(path) | |
| if normalized in ('.', '/'): | |
| return None | |
| if not (normalized == '/static' or normalized.startswith('/static/')): | |
| return None | |
| if normalized == '/static': | |
| return '/static/' | |
| return normalized | |
def is_valid_model_id(model_id: str, max_length: int = 256) -> bool:
    """
    Return True when ``model_id`` is a non-empty string of at most
    ``max_length`` characters.

    Always returns a real bool: non-string input (e.g. an int id from an
    imported JSON payload) is rejected instead of raising TypeError, and
    empty/None input yields False rather than the falsy value itself.
    """
    return isinstance(model_id, str) and 0 < len(model_id) <= max_length
| ########################### | |
| # GetModels | |
| # Let each model here be judged by what it does and not | |
| # by what it claims. The house deserves honest servants. | |
| ########################### | |
# Page size used by get_models: each page returns at most this many models.
PAGE_ITEM_COUNT = 30
| # do NOT use "/" as path, conflicts with main.py | |
async def get_models(
    query: Optional[str] = None,
    view_option: Optional[str] = None,
    tag: Optional[str] = None,
    order_by: Optional[str] = None,
    direction: Optional[str] = None,
    page: Optional[int] = 1,
    user=Depends(get_verified_user),
    db: AsyncSession = Depends(get_async_session),
):
    """
    Return one page (``PAGE_ITEM_COUNT`` items) of models visible to ``user``.

    The optional ``query``/``view_option``/``tag``/``order_by``/``direction``
    values are forwarded to ``Models.search_models`` as filter keys when set.
    Unless the user is an admin with ``BYPASS_ADMIN_ACCESS_CONTROL`` enabled,
    the search is scoped with the user's id and group ids.

    Each returned item carries a ``write_access`` flag and has
    ``meta.profile_image_url`` removed (images are served by the separate
    profile-image endpoint).
    """
    limit = PAGE_ITEM_COUNT
    # ``page`` is Optional: treat None (and anything below 1) as the first page
    # instead of letting max() raise TypeError on None.
    page = max(1, page or 1)
    skip = (page - 1) * limit

    # Forward only the filter keys the caller actually supplied.
    search_filter = {
        key: value
        for key, value in (
            ('query', query),
            ('view_option', view_option),
            ('tag', tag),
            ('order_by', order_by),
            ('direction', direction),
        )
        if value
    }

    admin_bypass = user.role == 'admin' and BYPASS_ADMIN_ACCESS_CONTROL

    # Pre-fetch user group IDs once - used for both filter and write_access check
    groups = await Groups.get_groups_by_member_id(user.id, db=db)
    user_group_ids = {group.id for group in groups}

    if not admin_bypass:
        if groups:
            search_filter['group_ids'] = [group.id for group in groups]
        search_filter['user_id'] = user.id

    result = await Models.search_models(user.id, filter=search_filter, skip=skip, limit=limit, db=db)

    # Batch-fetch writable model IDs in a single query instead of N has_access calls
    writable_model_ids = await AccessGrants.get_accessible_resource_ids(
        user_id=user.id,
        resource_type='model',
        resource_ids=[model.id for model in result.items],
        permission='write',
        user_group_ids=user_group_ids,
        db=db,
    )

    # Strip profile_image_url from meta — images are served via /model/profile/image.
    items = []
    for model in result.items:
        data = model.model_dump()
        if data.get('meta'):
            data['meta'].pop('profile_image_url', None)
        items.append(
            ModelAccessResponse(
                **data,
                write_access=(
                    admin_bypass
                    or user.id == model.user_id
                    or model.id in writable_model_ids
                ),
            )
        )

    return ModelAccessListResponse(
        items=items,
        total=result.total,
    )
| ########################### | |
| # GetBaseModels | |
| ########################### | |
async def get_base_models(
    user=Depends(get_admin_user),
    db: AsyncSession = Depends(get_async_session),
):
    """
    Return whatever ``Models.get_base_models`` yields.

    Access is gated by the ``get_admin_user`` dependency (presumably admin-only).
    """
    base_models = await Models.get_base_models(db=db)
    return base_models
| ########################### | |
| # GetModelTags | |
| ########################### | |
async def get_model_tags(user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)):
    """
    Return the model tags available to ``user``, sorted ascending.

    Admins see all tags only when ``BYPASS_ADMIN_ACCESS_CONTROL`` is enabled;
    that combined flag is passed to ``Models.get_all_tags`` as ``is_admin``.
    """
    bypass_access_control = user.role == 'admin' and BYPASS_ADMIN_ACCESS_CONTROL
    all_tags = await Models.get_all_tags(user_id=user.id, is_admin=bypass_access_control, db=db)
    return sorted(all_tags)
| ############################ | |
| # CreateNewModel | |
| ############################ | |
async def create_new_model(
    request: Request,
    form_data: ModelForm,
    user=Depends(get_verified_user),
    db: AsyncSession = Depends(get_async_session),
):
    """
    Create a new workspace model owned by ``user``.

    ``form_data.access_grants`` is filtered through
    ``filter_allowed_access_grants`` (``sharing.public_models`` permission)
    before insertion.

    Raises:
        HTTPException 401: the caller is not an admin and lacks the
            ``workspace.models`` permission; the id is already taken; or the
            insert did not return a model.
        HTTPException 400: ``form_data.id`` fails ``is_valid_model_id``.
    """
    if user.role != 'admin' and not await has_permission(
        user.id, 'workspace.models', request.app.state.config.USER_PERMISSIONS, db=db
    ):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.UNAUTHORIZED,
        )

    # Validate the id before touching the database: an empty or over-long id
    # can never be created, so looking it up first is wasted work.
    if not is_valid_model_id(form_data.id):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.MODEL_ID_TOO_LONG,
        )

    if await Models.get_model_by_id(form_data.id, db=db):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.MODEL_ID_TAKEN,
        )

    form_data.access_grants = await filter_allowed_access_grants(
        request.app.state.config.USER_PERMISSIONS,
        user.id,
        user.role,
        form_data.access_grants,
        'sharing.public_models',
    )

    model = await Models.insert_new_model(form_data, user.id, db=db)
    if not model:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.DEFAULT(),
        )
    return model
| ############################ | |
| # ExportModels | |
| ############################ | |
async def export_models(
    request: Request,
    user=Depends(get_verified_user),
    db: AsyncSession = Depends(get_async_session),
):
    """
    Return the models ``user`` may export.

    Non-admins need the ``workspace.models_export`` permission (401 otherwise).
    Admins with ``BYPASS_ADMIN_ACCESS_CONTROL`` get every model; everyone else
    gets only the models returned by ``Models.get_models_by_user_id``.
    """
    is_admin = user.role == 'admin'

    if not is_admin:
        may_export = await has_permission(
            user.id,
            'workspace.models_export',
            request.app.state.config.USER_PERMISSIONS,
            db=db,
        )
        if not may_export:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail=ERROR_MESSAGES.UNAUTHORIZED,
            )

    if is_admin and BYPASS_ADMIN_ACCESS_CONTROL:
        return await Models.get_models(db=db)
    return await Models.get_models_by_user_id(user.id, db=db)
| ############################ | |
| # ImportModels | |
| ############################ | |
class ModelsImportForm(BaseModel):
    """Request body for ``import_models``: a list of raw model dictionaries."""

    # Each entry is a plain dict; import_models reads its 'id', 'meta',
    # 'params' and 'access_grants' keys before building a ModelForm.
    models: list[dict]
async def import_models(
    request: Request,
    user=Depends(get_verified_user),
    form_data: ModelsImportForm = (...),
    db: AsyncSession = Depends(get_async_session),
):
    """
    Import (create or overwrite) models from a list of model dicts.

    Entries without a valid ``id`` are skipped. An existing model is only
    overwritten when the caller is an admin, its owner, or holds a ``write``
    access grant; otherwise the entry is skipped with a warning. Access grants
    are filtered through ``filter_allowed_access_grants`` for new models, and
    for updates only when the entry explicitly contains ``access_grants``.

    Returns True when every entry has been processed.

    Raises:
        HTTPException 401: caller lacks the ``workspace.models_import`` permission.
        HTTPException 400: ``form_data.models`` is not a list.
        HTTPException 500: any other error while importing. This function does
            not itself undo entries processed before the failure.
    """
    if user.role != 'admin' and not await has_permission(
        user.id,
        'workspace.models_import',
        request.app.state.config.USER_PERMISSIONS,
        db=db,
    ):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.UNAUTHORIZED,
        )

    try:
        data = form_data.models
        if not isinstance(data, list):
            raise HTTPException(status_code=400, detail='Invalid JSON format')

        # Batch-fetch all existing models in one query to avoid N+1
        model_ids = [
            model_data.get('id')
            for model_data in data
            if model_data.get('id') and is_valid_model_id(model_data.get('id'))
        ]
        existing_models = {
            model.id: model for model in (await Models.get_models_by_ids(model_ids, db=db) if model_ids else [])
        }

        # Batch-resolve write permissions in one query instead of
        # per-model has_access calls (N+1 avoidance).
        existing_model_ids = list(existing_models)
        if user.role != 'admin' and existing_model_ids:
            groups = await Groups.get_groups_by_member_id(user.id, db=db)
            writable_model_ids = await AccessGrants.get_accessible_resource_ids(
                user_id=user.id,
                resource_type='model',
                resource_ids=existing_model_ids,
                permission='write',
                user_group_ids={group.id for group in groups},
                db=db,
            )
        else:
            writable_model_ids = set(existing_model_ids)

        for model_data in data:
            model_id = model_data.get('id')
            if not (model_id and is_valid_model_id(model_id)):
                continue

            # Work on a copy so the request payload is not mutated, and treat an
            # explicit null meta/params the same as a missing key.
            payload = {
                **model_data,
                'meta': model_data.get('meta') or {},
                'params': model_data.get('params') or {},
            }

            existing_model = existing_models.get(model_id)
            if existing_model:
                # Enforce ownership/write-access before allowing overwrite
                if (
                    user.role != 'admin'
                    and existing_model.user_id != user.id
                    and model_id not in writable_model_ids
                ):
                    log.warning(
                        'import_models: user %s skipped model %s (no write access)',
                        user.id,
                        model_id,
                    )
                    continue

                # Update existing model
                updated_model = ModelForm(**{**existing_model.model_dump(), **payload})
                # Only filter access_grants when explicitly provided in the
                # payload to avoid altering existing ACLs on metadata-only imports.
                if 'access_grants' in model_data:
                    updated_model.access_grants = await filter_allowed_access_grants(
                        request.app.state.config.USER_PERMISSIONS,
                        user.id,
                        user.role,
                        updated_model.access_grants,
                        'sharing.public_models',
                    )
                await Models.update_model_by_id(model_id, updated_model, db=db)
            else:
                # Insert new model
                new_model = ModelForm(**payload)
                new_model.access_grants = await filter_allowed_access_grants(
                    request.app.state.config.USER_PERMISSIONS,
                    user.id,
                    user.role,
                    new_model.access_grants,
                    'sharing.public_models',
                )
                await Models.insert_new_model(user_id=user.id, form_data=new_model, db=db)

        return True
    except HTTPException:
        # Deliberate HTTP errors (e.g. the 400 above) must reach the client
        # unchanged instead of being rewrapped as a 500 below.
        raise
    except Exception as e:
        log.exception(e)
        raise HTTPException(status_code=500, detail=str(e)) from e
| ############################ | |
| # SyncModels | |
| ############################ | |
class SyncModelsForm(BaseModel):
    """Request body for ``sync_models``."""

    # Models handed to Models.sync_models; an empty list when omitted.
    models: list[ModelModel] = []
async def sync_models(
    request: Request,
    form_data: SyncModelsForm,
    user=Depends(get_admin_user),
    db: AsyncSession = Depends(get_async_session),
):
    """Pass the submitted models to ``Models.sync_models`` on behalf of ``user`` and return its result."""
    submitted_models = form_data.models
    synced = await Models.sync_models(user.id, submitted_models, db=db)
    return synced
| ########################### | |
| # GetModelById | |
| ########################### | |
class ModelIdForm(BaseModel):
    """Request body carrying a single model id (used by ``delete_model_by_id``)."""

    # Model id to act on.
    id: str
| # Note: We're not using the typical url path param here, but instead using a query parameter to allow '/' in the id | |
async def get_model_by_id(id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)):
    """
    Return a single model, annotated with ``write_access``, if ``user`` may read it.

    ``id`` is taken as a query parameter rather than a path parameter so that
    it may contain '/'.

    Raises:
        HTTPException 404: no model has this id.
        HTTPException 401: the user has no read access to the model.
    """
    model = await Models.get_model_by_id(id, db=db)
    if not model:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    admin_bypass = user.role == 'admin' and BYPASS_ADMIN_ACCESS_CONTROL
    is_owner = model.user_id == user.id

    can_read = (
        admin_bypass
        or is_owner
        or await AccessGrants.has_access(
            user_id=user.id,
            resource_type='model',
            resource_id=model.id,
            permission='read',
            db=db,
        )
    )
    if not can_read:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )

    can_write = (
        admin_bypass
        or is_owner
        or await AccessGrants.has_access(
            user_id=user.id,
            resource_type='model',
            resource_id=model.id,
            permission='write',
            db=db,
        )
    )
    return ModelAccessResponse(**model.model_dump(), write_access=can_write)
| ########################### | |
| # GetModelById | |
| ########################### | |
async def get_model_profile_image(
    id: str,
    user=Depends(get_verified_user),
    db: AsyncSession = Depends(get_async_session),
):
    """
    Serve the image referenced by a model's ``meta.profile_image_url``.

    - URLs starting with ``http``: 302 redirect to that URL.
    - ``data:image...`` URLs: base64-decoded and streamed inline, with an ETag
      derived from the model's ``updated_at`` when present.
    - anything else: redirected only if ``_safe_static_redirect_path`` accepts it.

    Falls back to a 302 redirect to ``/static/favicon.png`` when the model is
    missing, has no usable image URL, or the data URL cannot be decoded.
    """
    model_meta = await Models.get_model_meta_by_id(id, db=db)
    if model_meta:
        meta, updated_at = model_meta
        profile_image_url = (meta or {}).get('profile_image_url')
        # Guard against non-string values stored in meta (they would crash .startswith).
        if isinstance(profile_image_url, str) and profile_image_url:
            if profile_image_url.startswith('http'):
                return Response(
                    status_code=status.HTTP_302_FOUND,
                    headers={'Location': profile_image_url},
                )
            elif profile_image_url.startswith('data:image'):
                try:
                    header, base64_data = profile_image_url.split(',', 1)
                    image_data = base64.b64decode(base64_data)
                except ValueError:
                    # Missing comma or invalid base64 (binascii.Error is a
                    # ValueError subclass): fall back to the default image.
                    log.debug('Could not decode data-URL profile image for model %s', id)
                else:
                    # removeprefix drops the literal 'data:' prefix; lstrip('data:')
                    # would strip any leading run of the characters d/a/t/: instead.
                    media_type = header.split(';', 1)[0].removeprefix('data:')
                    headers = {
                        'Content-Disposition': 'inline',
                        # The data URL is user-supplied (it may be an SVG). Served
                        # inline from this origin it could run script if opened
                        # directly, so forbid sniffing and sandbox the document.
                        'X-Content-Type-Options': 'nosniff',
                        'Content-Security-Policy': "default-src 'none'; style-src 'unsafe-inline'; sandbox",
                    }
                    if updated_at:
                        headers['ETag'] = f'"{updated_at}"'
                    return StreamingResponse(
                        io.BytesIO(image_data),
                        media_type=media_type,
                        headers=headers,
                    )
            else:
                safe_static = _safe_static_redirect_path(profile_image_url)
                if safe_static:
                    return RedirectResponse(
                        url=safe_static,
                        status_code=status.HTTP_302_FOUND,
                    )

    return RedirectResponse(
        url='/static/favicon.png',
        status_code=status.HTTP_302_FOUND,
    )
| ############################ | |
| # ToggleModelById | |
| ############################ | |
async def toggle_model_by_id(id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)):
    """
    Toggle a model via ``Models.toggle_model_by_id`` and return the updated model.

    Allowed for admins, the model's owner, or users with a ``write`` access grant.

    Raises:
        HTTPException 401: the model does not exist (reported as 401, like the
            update/delete handlers in this module) or the user may not write it.
        HTTPException 400: the toggle did not return a model.
    """
    model = await Models.get_model_by_id(id, db=db)
    if not model:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    can_write = (
        user.role == 'admin'
        or model.user_id == user.id
        or await AccessGrants.has_access(
            user_id=user.id,
            resource_type='model',
            resource_id=model.id,
            permission='write',
            db=db,
        )
    )
    if not can_write:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.UNAUTHORIZED,
        )

    toggled = await Models.toggle_model_by_id(id, db=db)
    if not toggled:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            # This endpoint updates a model, not a function.
            detail=ERROR_MESSAGES.DEFAULT('Error updating model'),
        )
    return toggled
| ############################ | |
| # UpdateModelById | |
| ############################ | |
async def update_model_by_id(
    request: Request,
    form_data: ModelForm,
    user=Depends(get_verified_user),
    db: AsyncSession = Depends(get_async_session),
):
    """
    Overwrite the model identified by ``form_data.id`` and return the result
    of ``Models.update_model_by_id``.

    Allowed for admins, the model's owner, or users with a ``write`` access
    grant. ``form_data.access_grants`` is filtered through
    ``filter_allowed_access_grants`` before saving.

    Raises:
        HTTPException 401: no model has this id.
        HTTPException 400: the user may not write the model.
    """
    model = await Models.get_model_by_id(form_data.id, db=db)
    if not model:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    # Cheap role/owner checks first so admins and owners never trigger the
    # AccessGrants query (same ordering as delete_model_by_id).
    if (
        user.role != 'admin'
        and model.user_id != user.id
        and not await AccessGrants.has_access(
            user_id=user.id,
            resource_type='model',
            resource_id=model.id,
            permission='write',
            db=db,
        )
    ):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )

    form_data.access_grants = await filter_allowed_access_grants(
        request.app.state.config.USER_PERMISSIONS,
        user.id,
        user.role,
        form_data.access_grants,
        'sharing.public_models',
    )
    model = await Models.update_model_by_id(form_data.id, ModelForm(**form_data.model_dump()), db=db)
    return model
| ############################ | |
| # UpdateModelAccessById | |
| ############################ | |
class ModelAccessGrantsForm(BaseModel):
    """Request body for ``update_model_access_by_id``."""

    # Id of the model whose access grants are replaced.
    id: str
    # Display name used only when a DB entry has to be created for the model;
    # falls back to ``id`` when omitted.
    name: Optional[str] = None
    # New access grants (filtered before being stored).
    access_grants: list[dict]
async def update_model_access_by_id(
    request: Request,
    form_data: ModelAccessGrantsForm,
    user=Depends(get_verified_user),
    db: AsyncSession = Depends(get_async_session),
):
    """
    Replace the access grants of the model ``form_data.id`` and return the
    refreshed model.

    If no DB entry exists yet (e.g. a model that comes directly from a
    provider), an admin may create a minimal entry so that grants can be stored.

    Raises:
        HTTPException 403: the model has no DB entry and the user is not an admin.
        HTTPException 400: the id fails ``is_valid_model_id`` (when creating an
            entry), or the user may not write the model.
        HTTPException 500: creating the minimal entry failed.
    """
    model = await Models.get_model_by_id(form_data.id, db=db)

    # Non-preset models (e.g. direct Ollama/OpenAI models) may not have a DB
    # entry yet. Create a minimal one so access grants can be stored.
    if not model:
        if user.role != 'admin':
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
            )
        # Apply the same id validation create_new_model enforces before a row
        # is created here.
        if not is_valid_model_id(form_data.id):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.MODEL_ID_TOO_LONG,
            )
        model = await Models.insert_new_model(
            ModelForm(
                id=form_data.id,
                name=form_data.name or form_data.id,
                meta=ModelMeta(),
                params=ModelParams(),
            ),
            user.id,
            db=db,
        )
        if not model:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=ERROR_MESSAGES.DEFAULT('Error creating model entry'),
            )

    # Cheap role/owner checks first so admins and owners skip the AccessGrants query.
    if (
        user.role != 'admin'
        and model.user_id != user.id
        and not await AccessGrants.has_access(
            user_id=user.id,
            resource_type='model',
            resource_id=model.id,
            permission='write',
            db=db,
        )
    ):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )

    form_data.access_grants = await filter_allowed_access_grants(
        request.app.state.config.USER_PERMISSIONS,
        user.id,
        user.role,
        form_data.access_grants,
        'sharing.public_models',
    )
    await AccessGrants.set_access_grants('model', form_data.id, form_data.access_grants, db=db)
    await Models.update_model_updated_at_by_id(form_data.id, db=db)
    return await Models.get_model_by_id(form_data.id, db=db)
| ############################ | |
| # DeleteModelById | |
| ############################ | |
async def delete_model_by_id(
    form_data: ModelIdForm,
    user=Depends(get_verified_user),
    db: AsyncSession = Depends(get_async_session),
):
    """
    Delete the model ``form_data.id`` and return ``Models.delete_model_by_id``'s result.

    Allowed for admins, the model's owner, or users with a ``write`` access grant.

    Raises:
        HTTPException 401: the model does not exist, or the user may not delete it.
    """
    target_id = form_data.id
    model = await Models.get_model_by_id(target_id, db=db)
    if not model:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    # Admins and owners are allowed outright; only other users need a grant lookup.
    if user.role == 'admin' or model.user_id == user.id:
        allowed = True
    else:
        allowed = await AccessGrants.has_access(
            user_id=user.id,
            resource_type='model',
            resource_id=model.id,
            permission='write',
            db=db,
        )
    if not allowed:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.UNAUTHORIZED,
        )

    return await Models.delete_model_by_id(target_id, db=db)
async def delete_all_models(
    user=Depends(get_admin_user),
    db: AsyncSession = Depends(get_async_session),
):
    """
    Delete all models via ``Models.delete_all_models`` and return its result.

    Access is gated by the ``get_admin_user`` dependency (presumably admin-only).
    """
    return await Models.delete_all_models(db=db)