| | from typing import Optional |
| | import io |
| | import base64 |
| | import json |
| | import asyncio |
| | import logging |
| |
|
| | from open_webui.models.groups import Groups |
| | from open_webui.models.models import ( |
| | ModelForm, |
| | ModelMeta, |
| | ModelModel, |
| | ModelParams, |
| | ModelResponse, |
| | ModelListResponse, |
| | ModelAccessListResponse, |
| | ModelAccessResponse, |
| | Models, |
| | ) |
| | from open_webui.models.access_grants import AccessGrants, has_public_read_access_grant |
| |
|
| | from pydantic import BaseModel |
| | from open_webui.constants import ERROR_MESSAGES |
| | from fastapi import ( |
| | APIRouter, |
| | Depends, |
| | HTTPException, |
| | Request, |
| | status, |
| | Response, |
| | ) |
| | from fastapi.responses import FileResponse, StreamingResponse |
| |
|
| |
|
| | from open_webui.utils.auth import get_admin_user, get_verified_user |
| | from open_webui.utils.access_control import has_permission |
| | from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL, STATIC_DIR |
| | from open_webui.internal.db import get_session |
| | from sqlalchemy.orm import Session |
| |
|
# Module-level logger, named after this module.
log = logging.getLogger(name=__name__)

# Router on which every endpoint in this module is registered.
router = APIRouter()
| |
|
| |
|
def is_valid_model_id(model_id: str, max_length: int = 256) -> bool:
    """Return True if *model_id* is a non-empty string of at most *max_length* chars.

    Always returns a real ``bool``. Non-string values (for example an ``id``
    of another JSON type in imported data) are rejected instead of raising
    ``TypeError`` from ``len``.

    Args:
        model_id: Candidate model identifier.
        max_length: Maximum allowed length (inclusive); defaults to 256.
    """
    return isinstance(model_id, str) and 0 < len(model_id) <= max_length
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
# Number of models returned per page by GET /list.
PAGE_ITEM_COUNT = 30


@router.get("/list", response_model=ModelAccessListResponse)
async def get_models(
    query: Optional[str] = None,
    view_option: Optional[str] = None,
    tag: Optional[str] = None,
    order_by: Optional[str] = None,
    direction: Optional[str] = None,
    page: Optional[int] = 1,
    user=Depends(get_verified_user),
    db: Session = Depends(get_session),
):
    """List models visible to the caller, ``PAGE_ITEM_COUNT`` per page.

    Only the filters the caller supplied are forwarded to
    ``Models.search_models``. Unless the caller is an admin with
    ``BYPASS_ADMIN_ACCESS_CONTROL`` enabled, the search is also given the
    caller's user id and group ids. Each returned item carries
    ``write_access``: true for bypassing admins, the model owner, or users
    holding a "write" grant on the model.
    """
    limit = PAGE_ITEM_COUNT
    # `page` is Optional: treat None (and anything < 1) as the first page
    # instead of letting max(1, None) raise TypeError.
    page = max(1, page or 1)
    skip = (page - 1) * limit

    # Forward only non-empty filter values (same keys/order as before).
    filters = {
        key: value
        for key, value in (
            ("query", query),
            ("view_option", view_option),
            ("tag", tag),
            ("order_by", order_by),
            ("direction", direction),
        )
        if value
    }

    groups = Groups.get_groups_by_member_id(user.id, db=db)
    user_group_ids = {group.id for group in groups}
    admin_bypass = user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL

    if not admin_bypass:
        if groups:
            filters["group_ids"] = [group.id for group in groups]
        filters["user_id"] = user.id

    result = Models.search_models(
        user.id, filter=filters, skip=skip, limit=limit, db=db
    )

    # Resolve write grants for the whole page in one call, not per model.
    model_ids = [model.id for model in result.items]
    writable_model_ids = AccessGrants.get_accessible_resource_ids(
        user_id=user.id,
        resource_type="model",
        resource_ids=model_ids,
        permission="write",
        user_group_ids=user_group_ids,
        db=db,
    )

    return ModelAccessListResponse(
        items=[
            ModelAccessResponse(
                **model.model_dump(),
                write_access=(
                    admin_bypass
                    or user.id == model.user_id
                    or model.id in writable_model_ids
                ),
            )
            for model in result.items
        ],
        total=result.total,
    )
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
@router.get("/base", response_model=list[ModelResponse])
async def get_base_models(
    user=Depends(get_admin_user),
    db: Session = Depends(get_session),
):
    """Return the result of ``Models.get_base_models``.

    Access is gated by the ``get_admin_user`` dependency.
    """
    base_models = Models.get_base_models(db=db)
    return base_models
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
@router.get("/tags", response_model=list[str])
async def get_model_tags(
    user=Depends(get_verified_user), db: Session = Depends(get_session)
):
    """Return the sorted, de-duplicated tag names of the models visible to the caller.

    Admins with ``BYPASS_ADMIN_ACCESS_CONTROL`` see tags from every model.
    Other users see tags from ``Models.get_models_by_user_id``. Tags without
    a string ``name`` are ignored.
    """
    if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
        models = Models.get_models(db=db)
    else:
        models = Models.get_models_by_user_id(user.id, db=db)

    tag_names = set()
    for model in models:
        if not model.meta:
            continue
        # `tags` may be missing or stored as null; treat both as empty.
        for tag in model.meta.model_dump().get("tags") or []:
            name = tag.get("name") if isinstance(tag, dict) else None
            # Skip unnamed/non-string entries: a None mixed with str would make
            # sorted() raise TypeError and break response_model=list[str].
            if isinstance(name, str):
                tag_names.add(name)

    return sorted(tag_names)
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
@router.post("/create", response_model=Optional[ModelModel])
async def create_new_model(
    request: Request,
    form_data: ModelForm,
    user=Depends(get_verified_user),
    db: Session = Depends(get_session),
):
    """Create a new model owned by the caller.

    Raises:
        HTTPException 401: the caller is not an admin and lacks the
            ``workspace.models`` permission; the id is already taken; or the
            insert returned nothing.
        HTTPException 400: the id fails ``is_valid_model_id``.
    """
    if user.role != "admin" and not has_permission(
        user.id, "workspace.models", request.app.state.config.USER_PERMISSIONS, db=db
    ):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.UNAUTHORIZED,
        )

    # Validate the id before touching the database so empty/oversized ids
    # never reach the lookup.
    if not is_valid_model_id(form_data.id):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.MODEL_ID_TOO_LONG,
        )

    # NOTE(review): 401 for a taken id is unusual (409 would fit better);
    # kept unchanged so existing clients see the same status code.
    if Models.get_model_by_id(form_data.id, db=db):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.MODEL_ID_TAKEN,
        )

    model = Models.insert_new_model(form_data, user.id, db=db)
    if not model:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.DEFAULT(),
        )
    return model
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
@router.get("/export", response_model=list[ModelModel])
async def export_models(
    request: Request,
    user=Depends(get_verified_user),
    db: Session = Depends(get_session),
):
    """Return the models the caller may export.

    Non-admins need the ``workspace.models_export`` permission (401
    otherwise). Admins with ``BYPASS_ADMIN_ACCESS_CONTROL`` get every model;
    everyone else gets ``Models.get_models_by_user_id``.
    """
    is_admin = user.role == "admin"
    may_export = is_admin or has_permission(
        user.id,
        "workspace.models_export",
        request.app.state.config.USER_PERMISSIONS,
        db=db,
    )
    if not may_export:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.UNAUTHORIZED,
        )

    if is_admin and BYPASS_ADMIN_ACCESS_CONTROL:
        return Models.get_models(db=db)
    return Models.get_models_by_user_id(user.id, db=db)
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
class ModelsImportForm(BaseModel):
    """Request body for POST /import: raw model dicts to create or update."""

    # Each entry is a loosely-typed model dict; it is turned into a ModelForm
    # by the import endpoint.
    models: list[dict]
| |
|
| |
|
@router.post("/import", response_model=bool)
async def import_models(
    request: Request,
    user=Depends(get_verified_user),
    form_data: ModelsImportForm = (...),
    db: Session = Depends(get_session),
):
    """Create or update models from a list of model dicts.

    - Entries without a valid ``id`` are skipped.
    - An entry whose id already exists is merged over the stored model, but
      only if the caller could edit it through ``/model/update`` (admin,
      owner, or holder of a "write" grant). Otherwise it is skipped and a
      warning is logged.
    - All other entries are inserted as new models owned by the caller.

    Missing ``meta``/``params`` default to ``{}``.

    Returns:
        True once every entry has been processed.

    Raises:
        HTTPException 401: non-admin without ``workspace.models_import``.
        HTTPException 400: the payload is not a list.
        HTTPException 500: any other failure (logged).
    """
    if user.role != "admin" and not has_permission(
        user.id,
        "workspace.models_import",
        request.app.state.config.USER_PERMISSIONS,
        db=db,
    ):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.UNAUTHORIZED,
        )

    try:
        data = form_data.models
        if not isinstance(data, list):
            raise HTTPException(status_code=400, detail="Invalid JSON format")

        # Fetch all already-existing models in one query.
        model_ids = [
            model_data.get("id")
            for model_data in data
            if model_data.get("id") and is_valid_model_id(model_data.get("id"))
        ]
        existing_models = {
            model.id: model
            for model in (
                Models.get_models_by_ids(model_ids, db=db) if model_ids else []
            )
        }

        for model_data in data:
            model_id = model_data.get("id")
            if not (model_id and is_valid_model_id(model_id)):
                continue

            model_data["meta"] = model_data.get("meta", {})
            model_data["params"] = model_data.get("params", {})

            existing_model = existing_models.get(model_id)
            if existing_model is None:
                Models.insert_new_model(
                    user_id=user.id, form_data=ModelForm(**model_data), db=db
                )
                continue

            # Same rule as /model/update: without it, import would let any
            # user with the import permission overwrite other users' models.
            can_write = (
                user.role == "admin"
                or existing_model.user_id == user.id
                or AccessGrants.has_access(
                    user_id=user.id,
                    resource_type="model",
                    resource_id=model_id,
                    permission="write",
                    db=db,
                )
            )
            if not can_write:
                log.warning(
                    "Skipping import of model %s: user %s lacks write access",
                    model_id,
                    user.id,
                )
                continue

            updated_model = ModelForm(**{**existing_model.model_dump(), **model_data})
            Models.update_model_by_id(model_id, updated_model, db=db)

        return True
    except HTTPException:
        # Deliberate HTTP errors (e.g. the 400 above) must not be re-wrapped
        # as a 500 by the generic handler below.
        raise
    except Exception as e:
        log.exception(e)
        raise HTTPException(status_code=500, detail=str(e))
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
class SyncModelsForm(BaseModel):
    """Request body for POST /sync: the full list of models to synchronise."""

    # Pydantic copies mutable field defaults per instance, so this [] literal
    # is not shared between requests.
    models: list[ModelModel] = []
| |
|
| |
|
@router.post("/sync", response_model=list[ModelModel])
async def sync_models(
    request: Request,
    form_data: SyncModelsForm,
    user=Depends(get_admin_user),
    db: Session = Depends(get_session),
):
    """Pass ``form_data.models`` to ``Models.sync_models`` and return its result.

    Gated by the ``get_admin_user`` dependency. ``request`` is unused but
    kept in the signature.
    """
    synced_models = Models.sync_models(user.id, form_data.models, db=db)
    return synced_models
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
class ModelIdForm(BaseModel):
    """Request body that identifies a single model by its id."""

    # Identifier of the target model.
    id: str
| |
|
| |
|
| | |
@router.get("/model", response_model=Optional[ModelAccessResponse])
async def get_model_by_id(
    id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
):
    """Return one model plus a ``write_access`` flag for the caller.

    Raises 404 when no model has this id, and 401 when the caller is neither
    a bypassing admin, the owner, nor the holder of a "read" grant.
    """
    model = Models.get_model_by_id(id, db=db)
    if not model:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    # Admin bypass or ownership grants both read and write access.
    owner_or_bypass = (
        user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL
    ) or user.id == model.user_id

    def has_grant(permission):
        return AccessGrants.has_access(
            user_id=user.id,
            resource_type="model",
            resource_id=model.id,
            permission=permission,
            db=db,
        )

    if not (owner_or_bypass or has_grant("read")):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )

    return ModelAccessResponse(
        **model.model_dump(),
        write_access=owner_or_bypass or has_grant("write"),
    )
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
@router.get("/model/profile/image")
def get_model_profile_image(id: str, user=Depends(get_verified_user)):
    """Serve a model's profile image.

    - ``http...`` URLs are answered with a 302 redirect to that URL.
    - ``data:image...;base64,...`` URLs are decoded and streamed inline, with
      an ``ETag`` derived from ``model.updated_at`` when present.
    - Anything else (unknown model, no URL, undecodable data URL) falls back
      to ``STATIC_DIR/favicon.png``.
    """
    default_image = f"{STATIC_DIR}/favicon.png"

    model = Models.get_model_by_id(id)
    if not model:
        return FileResponse(default_image)

    profile_image_url = model.meta.profile_image_url if model.meta else None
    if profile_image_url:
        if profile_image_url.startswith("http"):
            return Response(
                status_code=status.HTTP_302_FOUND,
                headers={"Location": profile_image_url},
            )

        if profile_image_url.startswith("data:image"):
            try:
                header, base64_data = profile_image_url.split(",", 1)
                image_data = base64.b64decode(base64_data)
            except ValueError as e:
                # Covers a missing comma and binascii.Error (a ValueError
                # subclass). Best effort: fall back to the default image.
                log.warning("Invalid data URL profile image for model %s: %s", id, e)
            else:
                # removeprefix drops the literal "data:" prefix; the original
                # lstrip("data:") stripped a *character set*.
                media_type = header.split(";")[0].removeprefix("data:")

                headers = {"Content-Disposition": "inline"}
                if model.updated_at:
                    headers["ETag"] = f'"{model.updated_at}"'

                return StreamingResponse(
                    io.BytesIO(image_data),
                    media_type=media_type,
                    headers=headers,
                )

    return FileResponse(default_image)
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
@router.post("/model/toggle", response_model=Optional[ModelResponse])
async def toggle_model_by_id(
    id: str, user=Depends(get_verified_user), db: Session = Depends(get_session)
):
    """Toggle a model via ``Models.toggle_model_by_id`` and return the result.

    Allowed for admins, the owner, or holders of a "write" grant.

    Raises:
        HTTPException 401: unknown id (NOT_FOUND) or insufficient rights.
        HTTPException 400: the toggle returned nothing.
    """
    model = Models.get_model_by_id(id, db=db)
    if not model:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    can_write = (
        user.role == "admin"
        or model.user_id == user.id
        or AccessGrants.has_access(
            user_id=user.id,
            resource_type="model",
            resource_id=model.id,
            permission="write",
            db=db,
        )
    )
    if not can_write:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.UNAUTHORIZED,
        )

    model = Models.toggle_model_by_id(id, db=db)
    if not model:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            # Was "Error updating function" (copy-paste): this endpoint
            # updates a model.
            detail=ERROR_MESSAGES.DEFAULT("Error updating model"),
        )
    return model
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
@router.post("/model/update", response_model=Optional[ModelModel])
async def update_model_by_id(
    form_data: ModelForm,
    user=Depends(get_verified_user),
    db: Session = Depends(get_session),
):
    """Overwrite an existing model with ``form_data`` and return the stored result.

    Raises 401 (NOT_FOUND) for an unknown id. Raises 400 (ACCESS_PROHIBITED)
    unless the caller owns the model, holds a "write" grant, or is an admin.
    """
    existing = Models.get_model_by_id(form_data.id, db=db)
    if not existing:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    can_write = (
        existing.user_id == user.id
        or AccessGrants.has_access(
            user_id=user.id,
            resource_type="model",
            resource_id=existing.id,
            permission="write",
            db=db,
        )
        or user.role == "admin"
    )
    if not can_write:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )

    # Re-build a fresh ModelForm from the request payload before persisting.
    updated = Models.update_model_by_id(
        form_data.id, ModelForm(**form_data.model_dump()), db=db
    )
    return updated
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
class ModelAccessGrantsForm(BaseModel):
    """Request body for POST /model/access/update."""

    # Id of the model whose grants are replaced.
    id: str
    # Grant dicts; the endpoint inspects the "principal_type" and
    # "principal_id" keys.
    access_grants: list[dict]
| |
|
| |
|
@router.post("/model/access/update", response_model=Optional[ModelModel])
async def update_model_access_by_id(
    request: Request,
    form_data: ModelAccessGrantsForm,
    user=Depends(get_verified_user),
    db: Session = Depends(get_session),
):
    """Replace a model's access grants and return the stored model.

    If no model with ``form_data.id`` exists, admins get a placeholder entry
    (name = id, default meta/params) so grants can be attached; non-admins
    get 403.

    Non-admins without the ``sharing.public_models`` permission have any
    wildcard user grants (``principal_type == "user"`` and
    ``principal_id == "*"``) removed before saving.

    Raises:
        HTTPException 403: non-admin targeting a missing model.
        HTTPException 400: invalid id on the create path, or the caller is
            not owner / "write" grantee / admin.
        HTTPException 500: the placeholder insert failed.
    """
    model = Models.get_model_by_id(form_data.id, db=db)

    if not model:
        if user.role != "admin":
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
            )
        # Apply the same id rule as /create before persisting a new row.
        if not is_valid_model_id(form_data.id):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.MODEL_ID_TOO_LONG,
            )
        model = Models.insert_new_model(
            ModelForm(
                id=form_data.id,
                name=form_data.id,
                meta=ModelMeta(),
                params=ModelParams(),
            ),
            user.id,
            db=db,
        )
        if not model:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=ERROR_MESSAGES.DEFAULT("Error creating model entry"),
            )

    if (
        model.user_id != user.id
        and not AccessGrants.has_access(
            user_id=user.id,
            resource_type="model",
            resource_id=model.id,
            permission="write",
            db=db,
        )
        and user.role != "admin"
    ):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )

    if (
        user.role != "admin"
        and has_public_read_access_grant(form_data.access_grants)
        and not has_permission(
            user.id,
            "sharing.public_models",
            request.app.state.config.USER_PERMISSIONS,
            db=db,  # reuse the request session, as every other call here does
        )
    ):
        form_data.access_grants = [
            grant
            for grant in form_data.access_grants
            if not (
                grant.get("principal_type") == "user"
                and grant.get("principal_id") == "*"
            )
        ]

    AccessGrants.set_access_grants(
        "model", form_data.id, form_data.access_grants, db=db
    )

    return Models.get_model_by_id(form_data.id, db=db)
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
@router.post("/model/delete", response_model=bool)
async def delete_model_by_id(
    form_data: ModelIdForm,
    user=Depends(get_verified_user),
    db: Session = Depends(get_session),
):
    """Delete one model and return the result of ``Models.delete_model_by_id``.

    Raises 401 (NOT_FOUND) for an unknown id. Raises 401 (UNAUTHORIZED)
    unless the caller is an admin, the owner, or holds a "write" grant.
    """
    target = Models.get_model_by_id(form_data.id, db=db)
    if not target:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    allowed = (
        user.role == "admin"
        or target.user_id == user.id
        or AccessGrants.has_access(
            user_id=user.id,
            resource_type="model",
            resource_id=target.id,
            permission="write",
            db=db,
        )
    )
    if not allowed:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.UNAUTHORIZED,
        )

    return Models.delete_model_by_id(form_data.id, db=db)
| |
|
| |
|
@router.delete("/delete/all", response_model=bool)
async def delete_all_models(
    user=Depends(get_admin_user),
    db: Session = Depends(get_session),
):
    """Delete every stored model; gated by the ``get_admin_user`` dependency."""
    return Models.delete_all_models(db=db)
| |
|