| |
| |
| |
|
|
| import asyncio |
| import json |
| import logging |
| import os |
| import random |
| import re |
| import time |
| from typing import Optional, Union |
| from urllib.parse import urlparse |
|
|
| import aiohttp |
| from aiocache import cached |
|
|
| import requests |
|
|
| from fastapi import ( |
| Depends, |
| FastAPI, |
| File, |
| HTTPException, |
| Request, |
| UploadFile, |
| APIRouter, |
| ) |
| from fastapi.middleware.cors import CORSMiddleware |
| from fastapi.responses import StreamingResponse |
| from pydantic import BaseModel, ConfigDict |
| from starlette.background import BackgroundTask |
|
|
|
|
| from open_webui.models.models import Models |
| from open_webui.utils.misc import ( |
| calculate_sha256, |
| ) |
| from open_webui.utils.payload import ( |
| apply_model_params_to_body_ollama, |
| apply_model_params_to_body_openai, |
| apply_model_system_prompt_to_body, |
| ) |
| from open_webui.utils.auth import get_admin_user, get_verified_user |
| from open_webui.utils.access_control import has_access |
|
|
|
|
| from open_webui.config import ( |
| UPLOAD_DIR, |
| ) |
| from open_webui.env import ( |
| ENV, |
| SRC_LOG_LEVELS, |
| AIOHTTP_CLIENT_TIMEOUT, |
| AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST, |
| BYPASS_MODEL_ACCESS_CONTROL, |
| ) |
| from open_webui.constants import ERROR_MESSAGES |
|
|
# Module-level logger for this router, using the project-configured
# verbosity for the "OLLAMA" subsystem.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
|
|
|
|
| |
| |
| |
| |
| |
|
|
|
|
async def send_get_request(url, key=None):
    """GET `url` and return the decoded JSON body, or None on any failure.

    A bearer `key`, when given, is sent as an Authorization header. Errors
    are logged and swallowed so callers can treat a down node as "no data".
    """
    headers = {"Authorization": f"Bearer {key}"} if key else {}
    timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
    try:
        async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
            async with session.get(url, headers=headers) as response:
                return await response.json()
    except Exception as e:
        log.error(f"Connection error: {e}")
        return None
|
|
|
|
async def cleanup_response(
    response: Optional[aiohttp.ClientResponse],
    session: Optional[aiohttp.ClientSession],
):
    """Release an aiohttp response and its owning session; None args are skipped."""
    if response:
        response.close()
    if session:
        # Session close is a coroutine in aiohttp.
        await session.close()
|
|
|
|
async def send_post_request(
    url: str,
    payload: Union[str, bytes],
    stream: bool = True,
    key: Optional[str] = None,
    content_type: Optional[str] = None,
):
    """POST `payload` to an Ollama endpoint and relay the reply.

    When `stream` is True, returns a StreamingResponse whose cleanup (closing
    the response and session) is deferred to a background task; otherwise the
    JSON body is read eagerly and the connection is closed before returning.

    Raises:
        HTTPException: with the upstream status (or 500) and the upstream
        "error" field when one is present.
    """
    r = None
    # FIX: initialize so the except path never hits an unbound name if
    # ClientSession construction itself raises.
    session = None
    try:
        session = aiohttp.ClientSession(
            trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
        )

        r = await session.post(
            url,
            data=payload,
            headers={
                "Content-Type": "application/json",
                **({"Authorization": f"Bearer {key}"} if key else {}),
            },
        )
        r.raise_for_status()

        if stream:
            response_headers = dict(r.headers)

            if content_type:
                response_headers["Content-Type"] = content_type

            # The session stays open while the client consumes the stream;
            # the background task closes it once streaming finishes.
            return StreamingResponse(
                r.content,
                status_code=r.status,
                headers=response_headers,
                background=BackgroundTask(
                    cleanup_response, response=r, session=session
                ),
            )
        else:
            res = await r.json()
            await cleanup_response(r, session)
            return res

    except Exception as e:
        detail = None

        if r is not None:
            try:
                res = await r.json()
                if "error" in res:
                    detail = f"Ollama: {res.get('error', 'Unknown error')}"
            except Exception:
                detail = f"Ollama: {e}"

        status_code = r.status if r else 500
        # FIX: the error path previously leaked the response and session.
        await cleanup_response(r, session)

        raise HTTPException(
            status_code=status_code,
            detail=detail if detail else "Open WebUI: Server Connection Error",
        )
|
|
|
|
def get_api_key(idx, url, configs):
    """Return the API key for a connection, or None if none is configured.

    The per-connection config is looked up by stringified index first; if
    absent, by the URL's scheme://netloc base.
    """
    parsed = urlparse(url)
    fallback = configs.get(f"{parsed.scheme}://{parsed.netloc}", {})
    entry = configs.get(str(idx), fallback)
    return entry.get("key", None)
|
|
|
|
| |
| |
| |
| |
| |
|
|
# All Ollama proxy endpoints below are mounted on this router.
router = APIRouter()
|
|
|
|
@router.head("/")
@router.get("/")
async def get_status():
    """Lightweight liveness probe for the Ollama proxy router."""
    return {"status": True}
|
|
|
|
class ConnectionVerificationForm(BaseModel):
    """Request body for /verify: an Ollama base URL plus optional bearer key."""

    url: str
    # Optional bearer token sent as an Authorization header.
    key: Optional[str] = None
|
|
|
|
@router.post("/verify")
async def verify_connection(
    form_data: ConnectionVerificationForm, user=Depends(get_admin_user)
):
    """Verify an Ollama endpoint is reachable by querying its /api/version.

    Returns the upstream version payload on success; otherwise raises
    HTTPException(500) with a descriptive detail (admin only).
    """
    url = form_data.url
    key = form_data.key

    async with aiohttp.ClientSession(
        timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
    ) as session:
        try:
            async with session.get(
                f"{url}/api/version",
                headers={**({"Authorization": f"Bearer {key}"} if key else {})},
            ) as r:
                if r.status != 200:
                    detail = f"HTTP Error: {r.status}"
                    # FIX: a non-JSON error body used to make `r.json()` raise
                    # here, masking the HTTP status with a generic "Unexpected
                    # error". Parse best-effort and keep the status otherwise.
                    try:
                        res = await r.json()
                        if "error" in res:
                            detail = f"External Error: {res['error']}"
                    except Exception:
                        pass
                    raise Exception(detail)

                data = await r.json()
                return data
        except aiohttp.ClientError as e:
            # Network-level failures (DNS, refused connection, timeout).
            log.exception(f"Client error: {str(e)}")
            raise HTTPException(
                status_code=500, detail="Open WebUI: Server Connection Error"
            )
        except Exception as e:
            log.exception(f"Unexpected error: {e}")
            error_detail = f"Unexpected error: {str(e)}"
            raise HTTPException(status_code=500, detail=error_detail)
|
|
|
|
@router.get("/config")
async def get_config(request: Request, user=Depends(get_admin_user)):
    """Return the current Ollama connection settings (admin only)."""
    config = request.app.state.config
    return {
        "ENABLE_OLLAMA_API": config.ENABLE_OLLAMA_API,
        "OLLAMA_BASE_URLS": config.OLLAMA_BASE_URLS,
        "OLLAMA_API_CONFIGS": config.OLLAMA_API_CONFIGS,
    }
|
|
|
|
class OllamaConfigForm(BaseModel):
    """Request body for /config/update."""

    ENABLE_OLLAMA_API: Optional[bool] = None
    # One entry per configured Ollama connection, index-aligned with configs.
    OLLAMA_BASE_URLS: list[str]
    # Per-connection settings keyed by stringified index or by base URL.
    OLLAMA_API_CONFIGS: dict
|
|
|
|
@router.post("/config/update")
async def update_config(
    request: Request, form_data: OllamaConfigForm, user=Depends(get_admin_user)
):
    """Replace the Ollama connection settings and prune stale per-connection
    configs (admin only). Returns the settings as stored."""
    request.app.state.config.ENABLE_OLLAMA_API = form_data.ENABLE_OLLAMA_API

    request.app.state.config.OLLAMA_BASE_URLS = form_data.OLLAMA_BASE_URLS
    request.app.state.config.OLLAMA_API_CONFIGS = form_data.OLLAMA_API_CONFIGS

    # Drop config entries that no longer correspond to a configured
    # connection. Configs may be keyed either by stringified index
    # ("0", "1", ...) or by URL — the lookup paths elsewhere in this file
    # support both — so keep both kinds of valid keys.
    # FIX: the previous pruning kept only index keys and silently discarded
    # every URL-keyed config.
    urls = request.app.state.config.OLLAMA_BASE_URLS
    valid_keys = set(map(str, range(len(urls)))) | set(urls)
    request.app.state.config.OLLAMA_API_CONFIGS = {
        key: value
        for key, value in request.app.state.config.OLLAMA_API_CONFIGS.items()
        if key in valid_keys
    }

    return {
        "ENABLE_OLLAMA_API": request.app.state.config.ENABLE_OLLAMA_API,
        "OLLAMA_BASE_URLS": request.app.state.config.OLLAMA_BASE_URLS,
        "OLLAMA_API_CONFIGS": request.app.state.config.OLLAMA_API_CONFIGS,
    }
|
|
|
|
# Cached briefly (3s) to absorb bursts of model-list requests without
# hammering every configured node.
@cached(ttl=3)
async def get_all_models(request: Request):
    """Fetch /api/tags from every enabled Ollama connection and merge the
    results into a single {"models": [...]} payload.

    Side effect: stores the merged list on `request.app.state.OLLAMA_MODELS`,
    keyed by model name, with each entry carrying a "urls" list of the
    connection indices that serve it.
    """
    log.info("get_all_models()")
    if request.app.state.config.ENABLE_OLLAMA_API:
        request_tasks = []
        for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS):
            if (str(idx) not in request.app.state.config.OLLAMA_API_CONFIGS) and (
                url not in request.app.state.config.OLLAMA_API_CONFIGS
            ):
                # No per-connection config: plain unauthenticated fetch.
                request_tasks.append(send_get_request(f"{url}/api/tags"))
            else:
                api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
                    str(idx),
                    request.app.state.config.OLLAMA_API_CONFIGS.get(
                        url, {}
                    ),
                )

                enable = api_config.get("enable", True)
                key = api_config.get("key", None)

                if enable:
                    request_tasks.append(send_get_request(f"{url}/api/tags", key))
                else:
                    # Disabled connection: placeholder task yielding None so
                    # `responses` stays index-aligned with OLLAMA_BASE_URLS.
                    request_tasks.append(asyncio.ensure_future(asyncio.sleep(0, None)))

        responses = await asyncio.gather(*request_tasks)

        # Post-process each node's response: apply the optional model_ids
        # allow-list, then the optional prefix_id namespace.
        for idx, response in enumerate(responses):
            if response:
                url = request.app.state.config.OLLAMA_BASE_URLS[idx]
                api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
                    str(idx),
                    request.app.state.config.OLLAMA_API_CONFIGS.get(
                        url, {}
                    ),
                )

                prefix_id = api_config.get("prefix_id", None)
                model_ids = api_config.get("model_ids", [])

                if len(model_ids) != 0 and "models" in response:
                    response["models"] = list(
                        filter(
                            lambda model: model["model"] in model_ids,
                            response["models"],
                        )
                    )

                if prefix_id:
                    for model in response.get("models", []):
                        model["model"] = f"{prefix_id}.{model['model']}"

        def merge_models_lists(model_lists):
            # Deduplicate by model name; a model served by several nodes gets
            # all their indices accumulated in its "urls" list.
            merged_models = {}

            for idx, model_list in enumerate(model_lists):
                if model_list is not None:
                    for model in model_list:
                        id = model["model"]
                        if id not in merged_models:
                            model["urls"] = [idx]
                            merged_models[id] = model
                        else:
                            merged_models[id]["urls"].append(idx)

            return list(merged_models.values())

        models = {
            "models": merge_models_lists(
                map(
                    lambda response: response.get("models", []) if response else None,
                    responses,
                )
            )
        }

    else:
        models = {"models": []}

    # Cache the name->model mapping for the routing helpers (get_ollama_url etc.).
    request.app.state.OLLAMA_MODELS = {
        model["model"]: model for model in models["models"]
    }
    return models
|
|
|
|
async def get_filtered_models(models, user):
    """Return only the models from `models["models"]` that `user` may read.

    A model is visible when the user owns it or has read access via its
    access-control list; models without a DB record are hidden.
    """
    visible = []
    for entry in models.get("models", []):
        info = Models.get_model_by_id(entry["model"])
        if not info:
            continue
        is_owner = user.id == info.user_id
        if is_owner or has_access(
            user.id, type="read", access_control=info.access_control
        ):
            visible.append(entry)
    return visible
|
|
|
|
@router.get("/api/tags")
@router.get("/api/tags/{url_idx}")
async def get_ollama_tags(
    request: Request, url_idx: Optional[int] = None, user=Depends(get_verified_user)
):
    """List available models: merged across all connections when `url_idx`
    is omitted, otherwise from that single connection.

    Non-admin users get the list filtered by model access control unless
    BYPASS_MODEL_ACCESS_CONTROL is set.
    """
    models = []

    if url_idx is None:
        models = await get_all_models(request)
    else:
        url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
        key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)

        # NOTE(review): this uses the blocking `requests` library inside an
        # async endpoint, so it ties up the event loop for the duration of
        # the call — consistent with the other single-node paths in this file.
        r = None
        try:
            r = requests.request(
                method="GET",
                url=f"{url}/api/tags",
                headers={**({"Authorization": f"Bearer {key}"} if key else {})},
            )
            r.raise_for_status()

            models = r.json()
        except Exception as e:
            log.exception(e)

            # Prefer the upstream "error" field in the failure detail;
            # fall back to the exception text, then to a generic message.
            detail = None
            if r is not None:
                try:
                    res = r.json()
                    if "error" in res:
                        detail = f"Ollama: {res['error']}"
                except Exception:
                    detail = f"Ollama: {e}"

            raise HTTPException(
                status_code=r.status_code if r else 500,
                detail=detail if detail else "Open WebUI: Server Connection Error",
            )

    if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
        models["models"] = await get_filtered_models(models, user)

    return models
|
|
|
|
@router.get("/api/version")
@router.get("/api/version/{url_idx}")
async def get_ollama_versions(request: Request, url_idx: Optional[int] = None):
    """Report the Ollama server version.

    Without `url_idx`, queries every connection and returns the LOWEST
    version seen (the conservative capability bound across the cluster);
    with `url_idx`, returns that node's version. Returns {"version": False}
    when the Ollama API is disabled.
    """
    if request.app.state.config.ENABLE_OLLAMA_API:
        if url_idx is None:
            request_tasks = [
                send_get_request(
                    f"{url}/api/version",
                    request.app.state.config.OLLAMA_API_CONFIGS.get(
                        str(idx),
                        request.app.state.config.OLLAMA_API_CONFIGS.get(
                            url, {}
                        ),
                    ).get("key", None),
                )
                for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS)
            ]
            responses = await asyncio.gather(*request_tasks)
            # Unreachable nodes come back as None; drop them.
            responses = list(filter(lambda x: x is not None, responses))

            if len(responses) > 0:
                # Compare versions numerically after stripping a leading "v"
                # and any "-suffix" (e.g. "v0.1.2-rc1" -> (0, 1, 2)).
                lowest_version = min(
                    responses,
                    key=lambda x: tuple(
                        map(int, re.sub(r"^v|-.*", "", x["version"]).split("."))
                    ),
                )

                return {"version": lowest_version["version"]}
            else:
                raise HTTPException(
                    status_code=500,
                    detail=ERROR_MESSAGES.OLLAMA_NOT_FOUND,
                )
        else:
            url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]

            r = None
            try:
                r = requests.request(method="GET", url=f"{url}/api/version")
                r.raise_for_status()

                return r.json()
            except Exception as e:
                log.exception(e)

                # Surface the upstream "error" field when available.
                detail = None
                if r is not None:
                    try:
                        res = r.json()
                        if "error" in res:
                            detail = f"Ollama: {res['error']}"
                    except Exception:
                        detail = f"Ollama: {e}"

                raise HTTPException(
                    status_code=r.status_code if r else 500,
                    detail=detail if detail else "Open WebUI: Server Connection Error",
                )
    else:
        return {"version": False}
|
|
|
|
@router.get("/api/ps")
async def get_ollama_loaded_models(request: Request, user=Depends(get_verified_user)):
    """
    List models that are currently loaded into Ollama memory, and which node they are loaded on.
    """
    config = request.app.state.config
    if not config.ENABLE_OLLAMA_API:
        return {}

    # One /api/ps probe per configured connection, authenticated when a
    # per-connection key exists (looked up by index, then by URL).
    tasks = []
    for idx, url in enumerate(config.OLLAMA_BASE_URLS):
        api_config = config.OLLAMA_API_CONFIGS.get(
            str(idx), config.OLLAMA_API_CONFIGS.get(url, {})
        )
        tasks.append(send_get_request(f"{url}/api/ps", api_config.get("key", None)))

    responses = await asyncio.gather(*tasks)
    # Map each base URL to its node's response (None for unreachable nodes).
    return dict(zip(config.OLLAMA_BASE_URLS, responses))
|
|
|
|
class ModelNameForm(BaseModel):
    """Request body carrying a single Ollama model name (pull/push/delete/show)."""

    name: str
|
|
|
|
@router.post("/api/pull")
@router.post("/api/pull/{url_idx}")
async def pull_model(
    request: Request,
    form_data: ModelNameForm,
    url_idx: int = 0,
    user=Depends(get_admin_user),
):
    """Pull a model onto the selected Ollama node, streaming progress back
    (admin only)."""
    url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
    log.info(f"url: {url}")

    # Always force insecure=True so pulls from registries without verified
    # TLS still work.
    payload = form_data.model_dump(exclude_none=True)
    payload["insecure"] = True

    return await send_post_request(
        url=f"{url}/api/pull",
        payload=json.dumps(payload),
        key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
    )
|
|
|
|
class PushModelForm(BaseModel):
    """Request body for /api/push, mirroring Ollama's push parameters."""

    name: str
    insecure: Optional[bool] = None
    stream: Optional[bool] = None
|
|
|
|
@router.delete("/api/push")
@router.delete("/api/push/{url_idx}")
async def push_model(
    request: Request,
    form_data: PushModelForm,
    url_idx: Optional[int] = None,
    user=Depends(get_admin_user),
):
    """Push a local model to a registry through one Ollama node (admin only).

    When no `url_idx` is given, the first node known to serve the model is
    used; unknown models yield a 400.
    """
    if url_idx is None:
        await get_all_models(request)
        models = request.app.state.OLLAMA_MODELS

        if form_data.name not in models:
            raise HTTPException(
                status_code=400,
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
            )
        url_idx = models[form_data.name]["urls"][0]

    url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
    log.debug(f"url: {url}")

    return await send_post_request(
        url=f"{url}/api/push",
        payload=form_data.model_dump_json(exclude_none=True).encode(),
        key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
    )
|
|
|
|
class CreateModelForm(BaseModel):
    """Request body for /api/create; extra Ollama fields pass through."""

    model: Optional[str] = None
    stream: Optional[bool] = None
    path: Optional[str] = None

    # Allow arbitrary additional fields (modelfile, quantize, etc.) to be
    # forwarded to Ollama untouched.
    model_config = ConfigDict(extra="allow")
|
|
|
|
@router.post("/api/create")
@router.post("/api/create/{url_idx}")
async def create_model(
    request: Request,
    form_data: CreateModelForm,
    url_idx: int = 0,
    user=Depends(get_admin_user),
):
    """Create a model on the selected Ollama node, relaying the streamed
    progress (admin only)."""
    log.debug(f"form_data: {form_data}")

    base_url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
    api_key = get_api_key(
        url_idx, base_url, request.app.state.config.OLLAMA_API_CONFIGS
    )
    return await send_post_request(
        url=f"{base_url}/api/create",
        payload=form_data.model_dump_json(exclude_none=True).encode(),
        key=api_key,
    )
|
|
|
|
class CopyModelForm(BaseModel):
    """Request body for /api/copy: duplicate `source` under `destination`."""

    source: str
    destination: str
|
|
|
|
@router.post("/api/copy")
@router.post("/api/copy/{url_idx}")
async def copy_model(
    request: Request,
    form_data: CopyModelForm,
    url_idx: Optional[int] = None,
    user=Depends(get_admin_user),
):
    """Copy a model under a new name on one Ollama node (admin only).

    When `url_idx` is omitted, the first node serving the source model is
    used; an unknown source yields a 400. Returns True on success.
    """
    if url_idx is None:
        await get_all_models(request)
        models = request.app.state.OLLAMA_MODELS

        if form_data.source in models:
            url_idx = models[form_data.source]["urls"][0]
        else:
            raise HTTPException(
                status_code=400,
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.source),
            )

    url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
    key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)

    # FIX: initialize `r` so a connection failure inside requests.request
    # doesn't turn into a NameError in the except block below.
    r = None
    try:
        r = requests.request(
            method="POST",
            url=f"{url}/api/copy",
            headers={
                "Content-Type": "application/json",
                **({"Authorization": f"Bearer {key}"} if key else {}),
            },
            data=form_data.model_dump_json(exclude_none=True).encode(),
        )
        r.raise_for_status()

        log.debug(f"r.text: {r.text}")
        return True
    except Exception as e:
        log.exception(e)

        # Prefer the upstream "error" field in the failure detail.
        detail = None
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    detail = f"Ollama: {res['error']}"
            except Exception:
                detail = f"Ollama: {e}"

        raise HTTPException(
            status_code=r.status_code if r else 500,
            detail=detail if detail else "Open WebUI: Server Connection Error",
        )
|
|
|
|
@router.delete("/api/delete")
@router.delete("/api/delete/{url_idx}")
async def delete_model(
    request: Request,
    form_data: ModelNameForm,
    url_idx: Optional[int] = None,
    user=Depends(get_admin_user),
):
    """Delete a model from one Ollama node (admin only).

    When `url_idx` is omitted, the first node serving the model is used; an
    unknown model yields a 400. Returns True on success.
    """
    if url_idx is None:
        await get_all_models(request)
        models = request.app.state.OLLAMA_MODELS

        if form_data.name in models:
            url_idx = models[form_data.name]["urls"][0]
        else:
            raise HTTPException(
                status_code=400,
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
            )

    url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
    key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)

    # FIX: initialize `r` so a connection failure inside requests.request
    # doesn't turn into a NameError in the except block below.
    r = None
    try:
        r = requests.request(
            method="DELETE",
            url=f"{url}/api/delete",
            data=form_data.model_dump_json(exclude_none=True).encode(),
            headers={
                "Content-Type": "application/json",
                **({"Authorization": f"Bearer {key}"} if key else {}),
            },
        )
        r.raise_for_status()

        log.debug(f"r.text: {r.text}")
        return True
    except Exception as e:
        log.exception(e)

        # Prefer the upstream "error" field in the failure detail.
        detail = None
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    detail = f"Ollama: {res['error']}"
            except Exception:
                detail = f"Ollama: {e}"

        raise HTTPException(
            status_code=r.status_code if r else 500,
            detail=detail if detail else "Open WebUI: Server Connection Error",
        )
|
|
|
|
@router.post("/api/show")
async def show_model_info(
    request: Request, form_data: ModelNameForm, user=Depends(get_verified_user)
):
    """Fetch Ollama's /api/show details for a model from a randomly chosen
    node that serves it; an unknown model yields a 400."""
    await get_all_models(request)
    models = request.app.state.OLLAMA_MODELS

    if form_data.name not in models:
        raise HTTPException(
            status_code=400,
            detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
        )

    # Cheap load spreading across the nodes hosting this model.
    url_idx = random.choice(models[form_data.name]["urls"])

    url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
    key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)

    # FIX: initialize `r` so a connection failure inside requests.request
    # doesn't turn into a NameError in the except block below.
    r = None
    try:
        r = requests.request(
            method="POST",
            url=f"{url}/api/show",
            headers={
                "Content-Type": "application/json",
                **({"Authorization": f"Bearer {key}"} if key else {}),
            },
            data=form_data.model_dump_json(exclude_none=True).encode(),
        )
        r.raise_for_status()

        return r.json()
    except Exception as e:
        log.exception(e)

        # Prefer the upstream "error" field in the failure detail.
        detail = None
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    detail = f"Ollama: {res['error']}"
            except Exception:
                detail = f"Ollama: {e}"

        raise HTTPException(
            status_code=r.status_code if r else 500,
            detail=detail if detail else "Open WebUI: Server Connection Error",
        )
|
|
|
|
class GenerateEmbedForm(BaseModel):
    """Request body for /api/embed (batch embeddings)."""

    model: str
    # A single string or a batch of strings to embed.
    input: list[str] | str
    truncate: Optional[bool] = None
    options: Optional[dict] = None
    keep_alive: Optional[Union[int, str]] = None
|
|
|
|
@router.post("/api/embed")
@router.post("/api/embed/{url_idx}")
async def embed(
    request: Request,
    form_data: GenerateEmbedForm,
    url_idx: Optional[int] = None,
    user=Depends(get_verified_user),
):
    """Proxy a batch-embedding request (/api/embed) to one Ollama node.

    When `url_idx` is omitted, a random node serving the model is chosen;
    an unknown model yields a 400.
    """
    log.info(f"generate_ollama_batch_embeddings {form_data}")

    if url_idx is None:
        await get_all_models(request)
        models = request.app.state.OLLAMA_MODELS

        model = form_data.model

        # Ollama model names are always tagged; default to ":latest".
        if ":" not in model:
            model = f"{model}:latest"

        if model in models:
            url_idx = random.choice(models[model]["urls"])
        else:
            raise HTTPException(
                status_code=400,
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
            )

    url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
    key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)

    # FIX: initialize `r` so a connection failure inside requests.request
    # doesn't turn into a NameError in the except block below.
    r = None
    try:
        r = requests.request(
            method="POST",
            url=f"{url}/api/embed",
            headers={
                "Content-Type": "application/json",
                **({"Authorization": f"Bearer {key}"} if key else {}),
            },
            data=form_data.model_dump_json(exclude_none=True).encode(),
        )
        r.raise_for_status()

        data = r.json()
        return data
    except Exception as e:
        log.exception(e)

        # Prefer the upstream "error" field in the failure detail.
        detail = None
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    detail = f"Ollama: {res['error']}"
            except Exception:
                detail = f"Ollama: {e}"

        raise HTTPException(
            status_code=r.status_code if r else 500,
            detail=detail if detail else "Open WebUI: Server Connection Error",
        )
|
|
|
|
class GenerateEmbeddingsForm(BaseModel):
    """Request body for the legacy single-prompt /api/embeddings endpoint."""

    model: str
    prompt: str
    options: Optional[dict] = None
    keep_alive: Optional[Union[int, str]] = None
|
|
|
|
@router.post("/api/embeddings")
@router.post("/api/embeddings/{url_idx}")
async def embeddings(
    request: Request,
    form_data: GenerateEmbeddingsForm,
    url_idx: Optional[int] = None,
    user=Depends(get_verified_user),
):
    """Proxy a legacy single-prompt embedding request (/api/embeddings) to
    one Ollama node.

    When `url_idx` is omitted, a random node serving the model is chosen;
    an unknown model yields a 400.
    """
    log.info(f"generate_ollama_embeddings {form_data}")

    if url_idx is None:
        await get_all_models(request)
        models = request.app.state.OLLAMA_MODELS

        model = form_data.model

        # Ollama model names are always tagged; default to ":latest".
        if ":" not in model:
            model = f"{model}:latest"

        if model in models:
            url_idx = random.choice(models[model]["urls"])
        else:
            raise HTTPException(
                status_code=400,
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
            )

    url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
    key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)

    # FIX: initialize `r` so a connection failure inside requests.request
    # doesn't turn into a NameError in the except block below.
    r = None
    try:
        r = requests.request(
            method="POST",
            url=f"{url}/api/embeddings",
            headers={
                "Content-Type": "application/json",
                **({"Authorization": f"Bearer {key}"} if key else {}),
            },
            data=form_data.model_dump_json(exclude_none=True).encode(),
        )
        r.raise_for_status()

        data = r.json()
        return data
    except Exception as e:
        log.exception(e)

        # Prefer the upstream "error" field in the failure detail.
        detail = None
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    detail = f"Ollama: {res['error']}"
            except Exception:
                detail = f"Ollama: {e}"

        raise HTTPException(
            status_code=r.status_code if r else 500,
            detail=detail if detail else "Open WebUI: Server Connection Error",
        )
|
|
|
|
class GenerateCompletionForm(BaseModel):
    """Request body for /api/generate, mirroring Ollama's generate API."""

    model: str
    prompt: str
    suffix: Optional[str] = None
    # Base64-encoded images for multimodal models.
    images: Optional[list[str]] = None
    format: Optional[str] = None
    options: Optional[dict] = None
    system: Optional[str] = None
    template: Optional[str] = None
    # Token context returned by a previous generate call.
    context: Optional[list[int]] = None
    stream: Optional[bool] = True
    raw: Optional[bool] = None
    keep_alive: Optional[Union[int, str]] = None
|
|
|
|
@router.post("/api/generate")
@router.post("/api/generate/{url_idx}")
async def generate_completion(
    request: Request,
    form_data: GenerateCompletionForm,
    url_idx: Optional[int] = None,
    user=Depends(get_verified_user),
):
    """Forward a raw /api/generate request to one Ollama node, streaming
    the reply back.

    When `url_idx` is omitted, a random node serving the model is chosen;
    an unknown model yields a 400.
    """
    if url_idx is None:
        await get_all_models(request)
        models = request.app.state.OLLAMA_MODELS

        model = form_data.model
        # Ollama model names are always tagged; default to ":latest".
        if ":" not in model:
            model = f"{model}:latest"

        if model not in models:
            raise HTTPException(
                status_code=400,
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
            )
        url_idx = random.choice(models[model]["urls"])

    url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
    api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
        str(url_idx),
        request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}),
    )

    # Strip the connection-specific namespace prefix before forwarding.
    prefix_id = api_config.get("prefix_id", None)
    if prefix_id:
        form_data.model = form_data.model.replace(f"{prefix_id}.", "")

    return await send_post_request(
        url=f"{url}/api/generate",
        payload=form_data.model_dump_json(exclude_none=True).encode(),
        key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
    )
|
|
|
|
class ChatMessage(BaseModel):
    """A single Ollama chat message."""

    role: str
    content: str
    # Base64-encoded images for multimodal models.
    images: Optional[list[str]] = None
|
|
|
|
class GenerateChatCompletionForm(BaseModel):
    """Request body for /api/chat, mirroring Ollama's chat API."""

    model: str
    messages: list[ChatMessage]
    format: Optional[dict] = None
    options: Optional[dict] = None
    template: Optional[str] = None
    stream: Optional[bool] = True
    keep_alive: Optional[Union[int, str]] = None
|
|
|
|
async def get_ollama_url(request: Request, model: str, url_idx: Optional[int] = None):
    """Resolve which Ollama connection should serve `model`.

    Returns a (base_url, url_idx) tuple. When `url_idx` is given it is used
    directly; otherwise a random serving node is picked from the cached
    model registry.

    Raises:
        HTTPException: 400 when the model is unknown or has no serving URL.
    """
    if url_idx is None:
        models = request.app.state.OLLAMA_MODELS
        if model not in models:
            raise HTTPException(
                status_code=400,
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model),
            )
        urls = models[model].get("urls", [])
        if not urls:
            # FIX: random.choice([]) raised IndexError (surfacing as a 500);
            # an unservable model is a client-visible 400 like the miss above.
            raise HTTPException(
                status_code=400,
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model),
            )
        url_idx = random.choice(urls)
    url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
    return url, url_idx
|
|
|
|
@router.post("/api/chat")
@router.post("/api/chat/{url_idx}")
async def generate_chat_completion(
    request: Request,
    form_data: dict,
    url_idx: Optional[int] = None,
    user=Depends(get_verified_user),
    bypass_filter: Optional[bool] = False,
):
    """Proxy an Ollama /api/chat request.

    Applies per-model overrides from the Open WebUI model registry (base
    model remap, option params, system prompt), enforces model access
    control for non-admin users unless bypassed, resolves the serving node,
    and streams the reply back as NDJSON.
    """
    if BYPASS_MODEL_ACCESS_CONTROL:
        bypass_filter = True

    # `metadata` is Open WebUI-internal and must not be forwarded to Ollama;
    # it is consumed by apply_model_system_prompt_to_body below.
    metadata = form_data.pop("metadata", None)
    try:
        form_data = GenerateChatCompletionForm(**form_data)
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=400,
            detail=str(e),
        )

    payload = {**form_data.model_dump(exclude_none=True)}

    model_id = payload["model"]
    model_info = Models.get_model_by_id(model_id)

    if model_info:
        # A registry entry may alias a base model; forward the real name.
        if model_info.base_model_id:
            payload["model"] = model_info.base_model_id

        params = model_info.params.model_dump()

        if params:
            if payload.get("options") is None:
                payload["options"] = {}

            payload["options"] = apply_model_params_to_body_ollama(
                params, payload["options"]
            )
            payload = apply_model_system_prompt_to_body(params, payload, metadata)

        # Access control: plain users need ownership or read access.
        # The 403 detail deliberately says "Model not found" to avoid
        # leaking the model's existence.
        if not bypass_filter and user.role == "user":
            if not (
                user.id == model_info.user_id
                or has_access(
                    user.id, type="read", access_control=model_info.access_control
                )
            ):
                raise HTTPException(
                    status_code=403,
                    detail="Model not found",
                )
    elif not bypass_filter:
        # Unregistered models are admin-only.
        if user.role != "admin":
            raise HTTPException(
                status_code=403,
                detail="Model not found",
            )

    # Ollama model names are always tagged; default to ":latest".
    if ":" not in payload["model"]:
        payload["model"] = f"{payload['model']}:latest"

    url, url_idx = await get_ollama_url(request, payload["model"], url_idx)
    api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
        str(url_idx),
        request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}),
    )

    # Strip the connection-specific namespace prefix before forwarding.
    prefix_id = api_config.get("prefix_id", None)
    if prefix_id:
        payload["model"] = payload["model"].replace(f"{prefix_id}.", "")

    return await send_post_request(
        url=f"{url}/api/chat",
        payload=json.dumps(payload),
        stream=form_data.stream,
        key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
        content_type="application/x-ndjson",
    )
|
|
|
|
| |
class OpenAIChatMessageContent(BaseModel):
    """One item of OpenAI-style structured message content (text, image, ...)."""

    type: str
    # Pass through type-specific fields (e.g. "text", "image_url") untouched.
    model_config = ConfigDict(extra="allow")
|
|
|
|
class OpenAIChatMessage(BaseModel):
    """An OpenAI-style chat message; content may be plain text or structured."""

    role: str
    content: Union[str, list[OpenAIChatMessageContent]]

    model_config = ConfigDict(extra="allow")
|
|
|
|
class OpenAIChatCompletionForm(BaseModel):
    """Request body for the OpenAI-compatible /v1/chat/completions endpoint."""

    model: str
    messages: list[OpenAIChatMessage]

    # Extra OpenAI parameters (temperature, stream, ...) pass through.
    model_config = ConfigDict(extra="allow")
|
|
|
|
class OpenAICompletionForm(BaseModel):
    """Request body for the OpenAI-compatible /v1/completions endpoint."""

    model: str
    prompt: str

    # Extra OpenAI parameters (max_tokens, stream, ...) pass through.
    model_config = ConfigDict(extra="allow")
|
|
|
|
@router.post("/v1/completions")
@router.post("/v1/completions/{url_idx}")
async def generate_openai_completion(
    request: Request,
    form_data: dict,
    url_idx: Optional[int] = None,
    user=Depends(get_verified_user),
):
    """Proxy an OpenAI-compatible /v1/completions request to one Ollama node.

    Applies registry model overrides and access control, then forwards the
    payload; streams when the client asked for streaming.
    """
    try:
        form_data = OpenAICompletionForm(**form_data)
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=400,
            detail=str(e),
        )

    # Exclude Open WebUI-internal metadata from the forwarded payload
    # (the explicit del below is belt-and-braces for the exclude above).
    payload = {**form_data.model_dump(exclude_none=True, exclude=["metadata"])}
    if "metadata" in payload:
        del payload["metadata"]

    # Ollama model names are always tagged; default to ":latest" for lookup.
    model_id = form_data.model
    if ":" not in model_id:
        model_id = f"{model_id}:latest"

    model_info = Models.get_model_by_id(model_id)
    if model_info:
        # A registry entry may alias a base model; forward the real name.
        if model_info.base_model_id:
            payload["model"] = model_info.base_model_id
        params = model_info.params.model_dump()

        if params:
            payload = apply_model_params_to_body_openai(params, payload)

        # Access control: plain users need ownership or read access. The
        # 403 detail says "Model not found" to avoid leaking existence.
        if user.role == "user":
            if not (
                user.id == model_info.user_id
                or has_access(
                    user.id, type="read", access_control=model_info.access_control
                )
            ):
                raise HTTPException(
                    status_code=403,
                    detail="Model not found",
                )
    else:
        # Unregistered models are admin-only.
        if user.role != "admin":
            raise HTTPException(
                status_code=403,
                detail="Model not found",
            )

    if ":" not in payload["model"]:
        payload["model"] = f"{payload['model']}:latest"

    url, url_idx = await get_ollama_url(request, payload["model"], url_idx)
    api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
        str(url_idx),
        request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}),
    )

    # Strip the connection-specific namespace prefix before forwarding.
    prefix_id = api_config.get("prefix_id", None)

    if prefix_id:
        payload["model"] = payload["model"].replace(f"{prefix_id}.", "")

    return await send_post_request(
        url=f"{url}/v1/completions",
        payload=json.dumps(payload),
        stream=payload.get("stream", False),
        key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
    )
|
|
|
|
@router.post("/v1/chat/completions")
@router.post("/v1/chat/completions/{url_idx}")
async def generate_openai_chat_completion(
    request: Request,
    form_data: dict,
    url_idx: Optional[int] = None,
    user=Depends(get_verified_user),
):
    """Proxy an OpenAI-compatible /v1/chat/completions request to one Ollama
    node, applying registry model overrides (base model remap, params,
    system prompt) and access control before forwarding."""
    try:
        completion_form = OpenAIChatCompletionForm(**form_data)
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=400,
            detail=str(e),
        )

    # Exclude Open WebUI-internal metadata from the forwarded payload
    # (the explicit del below is belt-and-braces for the exclude above).
    payload = {**completion_form.model_dump(exclude_none=True, exclude=["metadata"])}
    if "metadata" in payload:
        del payload["metadata"]

    # Ollama model names are always tagged; default to ":latest" for lookup.
    model_id = completion_form.model
    if ":" not in model_id:
        model_id = f"{model_id}:latest"

    model_info = Models.get_model_by_id(model_id)
    if model_info:
        # A registry entry may alias a base model; forward the real name.
        if model_info.base_model_id:
            payload["model"] = model_info.base_model_id

        params = model_info.params.model_dump()

        if params:
            payload = apply_model_params_to_body_openai(params, payload)
            payload = apply_model_system_prompt_to_body(params, payload, user)

        # Access control: plain users need ownership or read access. The
        # 403 detail says "Model not found" to avoid leaking existence.
        if user.role == "user":
            if not (
                user.id == model_info.user_id
                or has_access(
                    user.id, type="read", access_control=model_info.access_control
                )
            ):
                raise HTTPException(
                    status_code=403,
                    detail="Model not found",
                )
    else:
        # Unregistered models are admin-only.
        if user.role != "admin":
            raise HTTPException(
                status_code=403,
                detail="Model not found",
            )

    if ":" not in payload["model"]:
        payload["model"] = f"{payload['model']}:latest"

    url, url_idx = await get_ollama_url(request, payload["model"], url_idx)
    api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
        str(url_idx),
        request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}),
    )

    # Strip the connection-specific namespace prefix before forwarding.
    prefix_id = api_config.get("prefix_id", None)
    if prefix_id:
        payload["model"] = payload["model"].replace(f"{prefix_id}.", "")

    return await send_post_request(
        url=f"{url}/v1/chat/completions",
        payload=json.dumps(payload),
        stream=payload.get("stream", False),
        key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
    )
|
|
|
|
@router.get("/v1/models")
@router.get("/v1/models/{url_idx}")
async def get_openai_models(
    request: Request,
    url_idx: Optional[int] = None,
    user=Depends(get_verified_user),
):
    """List available models in OpenAI `/v1/models` format.

    With no url_idx, aggregates models across all configured Ollama
    instances; otherwise queries a single instance's `/api/tags`. Non-admin
    users see only models they own or have read access to (unless access
    control is bypassed globally).

    Raises:
        HTTPException: with the upstream status (or 500) when the single
            instance cannot be reached or returns an error.
    """
    models = []

    if url_idx is None:
        model_list = await get_all_models(request)
        models = [
            {
                "id": model["model"],
                "object": "model",
                "created": int(time.time()),
                "owned_by": "openai",
            }
            for model in model_list["models"]
        ]
    else:
        url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
        r = None  # pre-bind so the except block can safely inspect it
        try:
            r = requests.request(method="GET", url=f"{url}/api/tags")
            r.raise_for_status()

            model_list = r.json()

            models = [
                {
                    "id": model["model"],
                    "object": "model",
                    "created": int(time.time()),
                    "owned_by": "openai",
                }
                # Fix: iterate the parsed response, not the (empty) `models`
                # list — the old code raised TypeError here every time.
                for model in model_list["models"]
            ]
        except Exception as e:
            log.exception(e)
            error_detail = "Open WebUI: Server Connection Error"
            if r is not None:
                try:
                    res = r.json()
                    if "error" in res:
                        error_detail = f"Ollama: {res['error']}"
                except Exception:
                    error_detail = f"Ollama: {e}"

            raise HTTPException(
                status_code=r.status_code if r else 500,
                detail=error_detail,
            )

    # Filter out models the requesting user may not read.
    if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
        filtered_models = []
        for model in models:
            model_info = Models.get_model_by_id(model["id"])
            if model_info:
                if user.id == model_info.user_id or has_access(
                    user.id, type="read", access_control=model_info.access_control
                ):
                    filtered_models.append(model)
        models = filtered_models

    return {
        "data": models,
        "object": "list",
    }
|
|
|
|
class UrlForm(BaseModel):
    """Request body carrying a single URL (e.g. a model file to download)."""

    url: str  # source URL; endpoints validate it against an allow-list
|
|
|
|
class UploadBlobForm(BaseModel):
    """Request body naming a file to be uploaded as an Ollama blob."""

    filename: str  # name of the file to register as a blob
|
|
|
|
def parse_huggingface_url(hf_url):
    """Return the last path segment (the target file name) of *hf_url*.

    Returns None when the URL cannot be parsed at all; an empty string when
    the URL has no path (callers treat both as falsy).
    """
    try:
        url_path = urlparse(hf_url).path
    except ValueError:
        # urlparse rejects some malformed URLs (e.g. bad IPv6 literals).
        return None

    # Everything after the final "/" is the file name.
    return url_path.rsplit("/", 1)[-1]
|
|
|
|
async def download_file_stream(
    ollama_url, file_url, file_path, file_name, chunk_size=1024 * 1024
):
    """Download *file_url* into *file_path*, yielding SSE progress events.

    Resumes a partial download via an HTTP Range header when *file_path*
    already exists. On completion, uploads the file as a blob to the Ollama
    server at *ollama_url* and yields a final event carrying the blob digest
    and *file_name*; the local file is then removed.

    Raises:
        Exception: if the blob upload to Ollama fails.
    """
    done = False

    # Resume support: continue from the size of any partial file on disk.
    if os.path.exists(file_path):
        current_size = os.path.getsize(file_path)
    else:
        current_size = 0

    headers = {"Range": f"bytes={current_size}-"} if current_size > 0 else {}

    timeout = aiohttp.ClientTimeout(total=600)  # cap the whole download at 10 min

    async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
        async with session.get(file_url, headers=headers) as response:
            total_size = int(response.headers.get("content-length", 0)) + current_size

            with open(file_path, "ab+") as file:
                async for data in response.content.iter_chunked(chunk_size):
                    current_size += len(data)
                    file.write(data)

                    done = current_size == total_size
                    # Guard against a missing/zero Content-Length header.
                    progress = (
                        round((current_size / total_size) * 100, 2)
                        if total_size
                        else 0
                    )

                    yield f'data: {{"progress": {progress}, "completed": {current_size}, "total": {total_size}}}\n\n'

                if done:
                    file.seek(0)
                    hashed = calculate_sha256(file)
                    file.seek(0)

                    url = f"{ollama_url}/api/blobs/sha256:{hashed}"
                    # NOTE(review): blocking `requests` call inside an async
                    # generator — blocks the event loop for the duration of
                    # the upload; consider the aiohttp session instead.
                    response = requests.post(url, data=file)

                    if response.ok:
                        res = {
                            "done": done,
                            "blob": f"sha256:{hashed}",
                            "name": file_name,
                        }
                        os.remove(file_path)

                        yield f"data: {json.dumps(res)}\n\n"
                    else:
                        # Fix: `raise "<str>"` is a TypeError in Python 3;
                        # raise a real exception instead.
                        raise Exception(
                            "Ollama: Could not create blob, Please try again."
                        )
|
|
|
|
| |
@router.post("/models/download")
@router.post("/models/download/{url_idx}")
async def download_model(
    request: Request,
    form_data: UrlForm,
    url_idx: Optional[int] = None,
    user=Depends(get_admin_user),
):
    """Download a model file from an allowed host onto the server, streaming
    progress events back to the client.

    Admin-only. Returns a StreamingResponse of SSE progress events, or None
    when no file name can be derived from the URL.

    Raises:
        HTTPException: 400 when the URL is not from an allowed host.
    """
    # Only accept downloads from trusted hosts (prefix match on the URL).
    allowed_hosts = ["https://huggingface.co/", "https://github.com/"]

    if not any(form_data.url.startswith(host) for host in allowed_hosts):
        raise HTTPException(
            status_code=400,
            detail="Invalid file_url. Only URLs from allowed hosts are permitted.",
        )

    # Default to the first configured Ollama instance.
    if url_idx is None:
        url_idx = 0
    url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]

    # Last path segment of the URL becomes the local file name.
    file_name = parse_huggingface_url(form_data.url)

    if file_name:
        # NOTE(review): file_name comes from the URL path; a ".." segment
        # would point outside UPLOAD_DIR — consider sanitizing.
        file_path = f"{UPLOAD_DIR}/{file_name}"

        return StreamingResponse(
            download_file_stream(url, form_data.url, file_path, file_name),
        )
    else:
        # No derivable file name: respond with an empty (null) body.
        return None
|
|
|
|
@router.post("/models/upload")
@router.post("/models/upload/{url_idx}")
def upload_model(
    request: Request,
    file: UploadFile = File(...),
    url_idx: Optional[int] = None,
    user=Depends(get_admin_user),
):
    """Save an uploaded model file to UPLOAD_DIR, then stream SSE progress
    events while registering it as a blob on the selected Ollama server.

    Admin-only. Returns a text/event-stream StreamingResponse; the final
    event carries the blob digest and original file name, or an "error" key
    if anything failed.
    """
    if url_idx is None:
        url_idx = 0
    ollama_url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]

    # NOTE(review): file.filename is client-controlled; a name containing
    # path separators could escape UPLOAD_DIR — consider sanitizing.
    file_path = f"{UPLOAD_DIR}/{file.filename}"

    chunk_size = 1024 * 1024  # 1 MiB

    # Save the uploaded file in fixed-size binary chunks. (Fix: the old code
    # iterated `file.file` directly, which splits on newlines — wrong and
    # slow for binary model files.)
    with open(file_path, "wb+") as f:
        while True:
            chunk = file.file.read(chunk_size)
            if not chunk:
                break
            f.write(chunk)

    def file_process_stream():
        # Re-read the saved file, emitting progress, then push it to Ollama.
        nonlocal ollama_url
        total_size = os.path.getsize(file_path)
        try:
            with open(file_path, "rb") as f:
                total = 0
                done = False

                while not done:
                    chunk = f.read(chunk_size)
                    if not chunk:
                        done = True
                        continue

                    total += len(chunk)
                    progress = round((total / total_size) * 100, 2)

                    res = {
                        "progress": progress,
                        "total": total_size,
                        "completed": total,
                    }
                    yield f"data: {json.dumps(res)}\n\n"

                if done:
                    f.seek(0)
                    hashed = calculate_sha256(f)
                    f.seek(0)

                    # Register the file as a blob on the Ollama server,
                    # keyed by its SHA-256 digest.
                    url = f"{ollama_url}/api/blobs/sha256:{hashed}"
                    response = requests.post(url, data=f)

                    if response.ok:
                        res = {
                            "done": done,
                            "blob": f"sha256:{hashed}",
                            "name": file.filename,
                        }
                        os.remove(file_path)
                        yield f"data: {json.dumps(res)}\n\n"
                    else:
                        raise Exception(
                            "Ollama: Could not create blob, Please try again."
                        )

        except Exception as e:
            # Surface failures to the client as a final SSE error event.
            res = {"error": str(e)}
            yield f"data: {json.dumps(res)}\n\n"

    return StreamingResponse(file_process_stream(), media_type="text/event-stream")
|
|