| from fastapi import APIRouter, Depends, HTTPException, Response, status, Request |
| from fastapi.responses import JSONResponse, RedirectResponse |
|
|
| from pydantic import BaseModel |
| from typing import Optional |
| import logging |
|
|
| from open_webui.utils.chat import generate_chat_completion |
| from open_webui.utils.task import ( |
| title_generation_template, |
| query_generation_template, |
| image_prompt_generation_template, |
| autocomplete_generation_template, |
| tags_generation_template, |
| emoji_generation_template, |
| moa_response_generation_template, |
| ) |
| from open_webui.utils.auth import get_admin_user, get_verified_user |
| from open_webui.constants import TASKS |
|
|
| from open_webui.routers.pipelines import process_pipeline_inlet_filter |
| from open_webui.utils.task import get_task_model_id |
|
|
| from open_webui.config import ( |
| DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE, |
| DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE, |
| DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE, |
| DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE, |
| DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE, |
| DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE, |
| DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE, |
| ) |
| from open_webui.env import SRC_LOG_LEVELS |
|
|
|
|
# Module-level logger; its level comes from the "MODELS" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])


# Router on which the endpoints defined in this module are registered.
router = APIRouter()
|
|
|
|
| |
| |
| |
| |
| |
|
|
|
|
@router.get("/config")
async def get_task_config(request: Request, user=Depends(get_verified_user)):
    """Return the current task-related settings.

    Each listed setting is read from ``request.app.state.config`` and returned
    under a key of the same name. ``user`` is only present to apply the
    ``get_verified_user`` dependency; it is not otherwise used.
    """
    config = request.app.state.config

    # Settings exposed to the client, in response order.
    exposed_settings = (
        "TASK_MODEL",
        "TASK_MODEL_EXTERNAL",
        "TITLE_GENERATION_PROMPT_TEMPLATE",
        "IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE",
        "ENABLE_AUTOCOMPLETE_GENERATION",
        "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH",
        "TAGS_GENERATION_PROMPT_TEMPLATE",
        "ENABLE_TAGS_GENERATION",
        "ENABLE_SEARCH_QUERY_GENERATION",
        "ENABLE_RETRIEVAL_QUERY_GENERATION",
        "QUERY_GENERATION_PROMPT_TEMPLATE",
        "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
    )
    return {name: getattr(config, name) for name in exposed_settings}
|
|
|
|
class TaskConfigForm(BaseModel):
    # Request body accepted by update_task_config. The field names match the
    # keys returned by get_task_config one-to-one.

    # Task model identifiers; both may be null.
    TASK_MODEL: Optional[str]
    TASK_MODEL_EXTERNAL: Optional[str]

    # Prompt templates and feature switches. The generation endpoints in this
    # module fall back to the DEFAULT_* template when a stored template is empty.
    TITLE_GENERATION_PROMPT_TEMPLATE: str
    IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE: str
    ENABLE_AUTOCOMPLETE_GENERATION: bool
    # Values <= 0 disable the prompt length check in generate_autocompletion.
    AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: int
    TAGS_GENERATION_PROMPT_TEMPLATE: str
    ENABLE_TAGS_GENERATION: bool
    ENABLE_SEARCH_QUERY_GENERATION: bool
    ENABLE_RETRIEVAL_QUERY_GENERATION: bool
    QUERY_GENERATION_PROMPT_TEMPLATE: str
    TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str
|
|
|
|
@router.post("/config/update")
async def update_task_config(
    request: Request, form_data: TaskConfigForm, user=Depends(get_admin_user)
):
    """Store the submitted task settings and return the resulting values.

    Every field of ``form_data`` is copied onto ``request.app.state.config``.
    The response carries the same keys as ``get_task_config``, read back from
    the config object after the update. ``user`` is only present to apply the
    ``get_admin_user`` dependency.
    """
    config = request.app.state.config

    config.TASK_MODEL = form_data.TASK_MODEL
    config.TASK_MODEL_EXTERNAL = form_data.TASK_MODEL_EXTERNAL
    config.TITLE_GENERATION_PROMPT_TEMPLATE = (
        form_data.TITLE_GENERATION_PROMPT_TEMPLATE
    )

    # This field is required by the form and echoed in the response, so it must
    # be stored too; otherwise edits to the image-prompt template are dropped.
    config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE = (
        form_data.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
    )

    config.ENABLE_AUTOCOMPLETE_GENERATION = form_data.ENABLE_AUTOCOMPLETE_GENERATION
    config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH = (
        form_data.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH
    )

    config.TAGS_GENERATION_PROMPT_TEMPLATE = form_data.TAGS_GENERATION_PROMPT_TEMPLATE
    config.ENABLE_TAGS_GENERATION = form_data.ENABLE_TAGS_GENERATION
    config.ENABLE_SEARCH_QUERY_GENERATION = form_data.ENABLE_SEARCH_QUERY_GENERATION
    config.ENABLE_RETRIEVAL_QUERY_GENERATION = (
        form_data.ENABLE_RETRIEVAL_QUERY_GENERATION
    )

    config.QUERY_GENERATION_PROMPT_TEMPLATE = (
        form_data.QUERY_GENERATION_PROMPT_TEMPLATE
    )
    config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
        form_data.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
    )

    return {
        "TASK_MODEL": config.TASK_MODEL,
        "TASK_MODEL_EXTERNAL": config.TASK_MODEL_EXTERNAL,
        "TITLE_GENERATION_PROMPT_TEMPLATE": config.TITLE_GENERATION_PROMPT_TEMPLATE,
        "IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE": config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
        "ENABLE_AUTOCOMPLETE_GENERATION": config.ENABLE_AUTOCOMPLETE_GENERATION,
        "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
        "TAGS_GENERATION_PROMPT_TEMPLATE": config.TAGS_GENERATION_PROMPT_TEMPLATE,
        "ENABLE_TAGS_GENERATION": config.ENABLE_TAGS_GENERATION,
        "ENABLE_SEARCH_QUERY_GENERATION": config.ENABLE_SEARCH_QUERY_GENERATION,
        "ENABLE_RETRIEVAL_QUERY_GENERATION": config.ENABLE_RETRIEVAL_QUERY_GENERATION,
        "QUERY_GENERATION_PROMPT_TEMPLATE": config.QUERY_GENERATION_PROMPT_TEMPLATE,
        "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
    }
|
|
|
|
@router.post("/title/completions")
async def generate_title(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """Request a chat title from the task model.

    ``form_data`` must contain ``model`` and ``messages``; ``chat_id`` is
    optional. The prompt is built from the configured title template (or the
    default when the configured one is an empty string) and sent as a
    non-streaming completion capped at 50 tokens.

    Raises:
        HTTPException: 404 when ``form_data["model"]`` is not in the app's
            model registry.

    Returns:
        The result of ``generate_chat_completion``, or a 400 JSON response with
        a generic message when that call raises.
    """
    app_state = request.app.state
    config = app_state.config
    models = app_state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # The model actually used may differ from the requested one, depending on
    # the configured TASK_MODEL / TASK_MODEL_EXTERNAL.
    task_model_id = get_task_model_id(
        model_id,
        config.TASK_MODEL,
        config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(
        f"generating chat title using model {task_model_id} for user {user.email} "
    )

    template = config.TITLE_GENERATION_PROMPT_TEMPLATE
    if template == "":
        template = DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE

    user_context = {
        "name": user.name,
        "location": user.info.get("location") if user.info else None,
    }
    content = title_generation_template(
        template, form_data["messages"], user_context
    )

    # Models owned by "ollama" get max_tokens; all others get
    # max_completion_tokens.
    if models[task_model_id]["owned_by"] == "ollama":
        token_limit = {"max_tokens": 50}
    else:
        token_limit = {"max_completion_tokens": 50}

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        **token_limit,
        "metadata": {
            "task": str(TASKS.TITLE_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception:
        # Full traceback goes to the log; the client only sees a generic message.
        log.exception("Exception occurred")
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": "An internal error has occurred."},
        )
|
|
|
|
@router.post("/tags/completions")
async def generate_chat_tags(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """Request chat tags from the task model.

    When ``ENABLE_TAGS_GENERATION`` is falsy the endpoint answers 200 with a
    "disabled" detail instead of raising. Otherwise ``form_data`` must contain
    ``model`` and ``messages``; ``chat_id`` is optional.

    Raises:
        HTTPException: 404 when ``form_data["model"]`` is not in the app's
            model registry.

    Returns:
        The result of ``generate_chat_completion``, or a 500 JSON response with
        a generic message when that call raises.
    """
    app_state = request.app.state
    config = app_state.config

    if not config.ENABLE_TAGS_GENERATION:
        return JSONResponse(
            status_code=status.HTTP_200_OK,
            content={"detail": "Tags generation is disabled"},
        )

    models = app_state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # The model actually used may differ from the requested one, depending on
    # the configured TASK_MODEL / TASK_MODEL_EXTERNAL.
    task_model_id = get_task_model_id(
        model_id,
        config.TASK_MODEL,
        config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(
        f"generating chat tags using model {task_model_id} for user {user.email} "
    )

    template = config.TAGS_GENERATION_PROMPT_TEMPLATE
    if template == "":
        template = DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE

    content = tags_generation_template(
        template, form_data["messages"], {"name": user.name}
    )

    task_metadata = {
        "task": str(TASKS.TAGS_GENERATION),
        "task_body": form_data,
        "chat_id": form_data.get("chat_id", None),
    }
    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "metadata": task_metadata,
    }

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:
        log.error(f"Error generating chat completion: {e}")
        return JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content={"detail": "An internal error has occurred."},
        )
|
|
|
|
@router.post("/image_prompt/completions")
async def generate_image_prompt(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """Request an image-generation prompt from the task model.

    ``form_data`` must contain ``model`` and ``messages``; ``chat_id`` is
    optional. The configured image-prompt template is used unless it is an
    empty string, in which case the default template applies.

    Raises:
        HTTPException: 404 when ``form_data["model"]`` is not in the app's
            model registry.

    Returns:
        The result of ``generate_chat_completion``, or a 400 JSON response with
        a generic message when that call raises.
    """
    app_state = request.app.state
    config = app_state.config
    models = app_state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # The model actually used may differ from the requested one, depending on
    # the configured TASK_MODEL / TASK_MODEL_EXTERNAL.
    task_model_id = get_task_model_id(
        model_id,
        config.TASK_MODEL,
        config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(
        f"generating image prompt using model {task_model_id} for user {user.email} "
    )

    template = config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
    if template == "":
        template = DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE

    content = image_prompt_generation_template(
        template,
        form_data["messages"],
        user={"name": user.name},
    )

    task_metadata = {
        "task": str(TASKS.IMAGE_PROMPT_GENERATION),
        "task_body": form_data,
        "chat_id": form_data.get("chat_id", None),
    }
    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "metadata": task_metadata,
    }

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception:
        # Full traceback goes to the log; the client only sees a generic message.
        log.exception("Exception occurred")
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": "An internal error has occurred."},
        )
|
|
|
|
@router.post("/queries/completions")
async def generate_queries(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """Request search or retrieval queries from the task model.

    ``form_data`` must contain ``model`` and ``messages``; ``type`` and
    ``chat_id`` are optional. A ``type`` of ``"web_search"`` or
    ``"retrieval"`` is rejected when the matching feature flag is off; any
    other value skips the flag check.

    Raises:
        HTTPException: 400 when the requested query type is disabled, 404 when
            ``form_data["model"]`` is not in the app's model registry.

    Returns:
        The result of ``generate_chat_completion``, or a 400 JSON response
        carrying the exception text when that call raises.
    """
    app_state = request.app.state
    config = app_state.config

    # Named query_type to avoid shadowing the ``type`` builtin.
    query_type = form_data.get("type")
    if query_type == "web_search":
        if not config.ENABLE_SEARCH_QUERY_GENERATION:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Search query generation is disabled",
            )
    elif query_type == "retrieval":
        if not config.ENABLE_RETRIEVAL_QUERY_GENERATION:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Query generation is disabled",
            )

    models = app_state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # The model actually used may differ from the requested one, depending on
    # the configured TASK_MODEL / TASK_MODEL_EXTERNAL.
    task_model_id = get_task_model_id(
        model_id,
        config.TASK_MODEL,
        config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(
        f"generating {query_type} queries using model {task_model_id} for user {user.email}"
    )

    # A whitespace-only configured template counts as unset.
    if (config.QUERY_GENERATION_PROMPT_TEMPLATE).strip() != "":
        template = config.QUERY_GENERATION_PROMPT_TEMPLATE
    else:
        template = DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE

    content = query_generation_template(
        template, form_data["messages"], {"name": user.name}
    )

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "metadata": {
            "task": str(TASKS.QUERY_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:
        # Record the failure server-side, as the sibling endpoints do; the
        # response itself is unchanged.
        log.exception("Error generating queries")
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": str(e)},
        )
|
|
|
|
@router.post("/auto/completions")
async def generate_autocompletion(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """Request an autocompletion for the user's partial input from the task model.

    ``form_data`` must contain ``model``; ``type``, ``prompt``, ``messages``
    and ``chat_id`` are read with ``.get`` and may be absent.

    Raises:
        HTTPException: 400 when autocompletion is disabled or the prompt is
            longer than ``AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH`` (only
            enforced when that limit is > 0); 404 when ``form_data["model"]``
            is not in the app's model registry.

    Returns:
        The result of ``generate_chat_completion``, or a 500 JSON response with
        a generic message when that call raises.
    """
    app_state = request.app.state
    config = app_state.config

    if not config.ENABLE_AUTOCOMPLETE_GENERATION:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Autocompletion generation is disabled",
        )

    # Named completion_type to avoid shadowing the ``type`` builtin.
    completion_type = form_data.get("type")
    prompt = form_data.get("prompt")
    messages = form_data.get("messages")

    max_length = config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH
    # A missing prompt has no length to check; skipping it here avoids a
    # TypeError from len(None). The prompt is still passed on unchanged.
    if max_length > 0 and prompt is not None and len(prompt) > max_length:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Input prompt exceeds maximum length of {max_length}",
        )

    models = app_state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # The model actually used may differ from the requested one, depending on
    # the configured TASK_MODEL / TASK_MODEL_EXTERNAL.
    task_model_id = get_task_model_id(
        model_id,
        config.TASK_MODEL,
        config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(
        f"generating autocompletion using model {task_model_id} for user {user.email}"
    )

    # A whitespace-only configured template counts as unset.
    if (config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE).strip() != "":
        template = config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE
    else:
        template = DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE

    content = autocomplete_generation_template(
        template, prompt, messages, completion_type, {"name": user.name}
    )

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "metadata": {
            "task": str(TASKS.AUTOCOMPLETE_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:
        log.error(f"Error generating chat completion: {e}")
        return JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content={"detail": "An internal error has occurred."},
        )
|
|
|
|
@router.post("/emoji/completions")
async def generate_emoji(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """Request an emoji for the given prompt from the task model.

    ``form_data`` must contain ``model`` and ``prompt``; ``chat_id`` is
    optional. The default emoji template is always used, and the completion
    is non-streaming and capped at 4 tokens.

    Raises:
        HTTPException: 404 when ``form_data["model"]`` is not in the app's
            model registry.

    Returns:
        The result of ``generate_chat_completion``, or a 400 JSON response
        carrying the exception text when that call raises.
    """
    app_state = request.app.state
    config = app_state.config
    models = app_state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # The model actually used may differ from the requested one, depending on
    # the configured TASK_MODEL / TASK_MODEL_EXTERNAL.
    task_model_id = get_task_model_id(
        model_id,
        config.TASK_MODEL,
        config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(f"generating emoji using model {task_model_id} for user {user.email} ")

    template = DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE

    content = emoji_generation_template(
        template,
        form_data["prompt"],
        {
            "name": user.name,
            "location": user.info.get("location") if user.info else None,
        },
    )

    # Models owned by "ollama" get max_tokens; all others get
    # max_completion_tokens.
    if models[task_model_id]["owned_by"] == "ollama":
        token_limit = {"max_tokens": 4}
    else:
        token_limit = {"max_completion_tokens": 4}

    chat_id = form_data.get("chat_id", None)
    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        **token_limit,
        # Kept at the top level for backward compatibility.
        # NOTE(review): confirm whether anything downstream still reads it here.
        "chat_id": chat_id,
        "metadata": {
            "task": str(TASKS.EMOJI_GENERATION),
            "task_body": form_data,
            # Every other task endpoint in this module carries chat_id in
            # metadata; include it here as well.
            "chat_id": chat_id,
        },
    }

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:
        # Record the failure server-side, as the sibling endpoints do; the
        # response itself is unchanged.
        log.exception("Error generating emoji")
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": str(e)},
        )
|
|
|
|
@router.post("/moa/completions")
async def generate_moa_response(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """Request a combined answer built from several model responses.

    ``form_data`` must contain ``model``, ``prompt`` and ``responses``;
    ``stream`` (default ``False``) and ``chat_id`` are optional. The default
    MOA template is always used.

    Raises:
        HTTPException: 404 when ``form_data["model"]`` is not in the app's
            model registry.

    Returns:
        The result of ``generate_chat_completion``, or a 400 JSON response
        carrying the exception text when that call raises.
    """
    app_state = request.app.state
    config = app_state.config
    models = app_state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # The model actually used may differ from the requested one, depending on
    # the configured TASK_MODEL / TASK_MODEL_EXTERNAL.
    task_model_id = get_task_model_id(
        model_id,
        config.TASK_MODEL,
        config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(f"generating MOA model {task_model_id} for user {user.email} ")

    template = DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE

    content = moa_response_generation_template(
        template,
        form_data["prompt"],
        form_data["responses"],
    )

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        # Unlike the other task endpoints here, the caller may ask for streaming.
        "stream": form_data.get("stream", False),
        "metadata": {
            "chat_id": form_data.get("chat_id", None),
            "task": str(TASKS.MOA_RESPONSE_GENERATION),
            "task_body": form_data,
        },
    }

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:
        # Record the failure server-side, as the sibling endpoints do; the
        # response itself is unchanged.
        log.exception("Error generating MOA response")
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": str(e)},
        )
|
|