Spaces:
Build error
Build error
| from fastapi import APIRouter, Depends, HTTPException, Response, status, Request | |
| from fastapi.responses import JSONResponse, RedirectResponse | |
| from pydantic import BaseModel | |
| from typing import Optional | |
| import logging | |
| import re | |
| from open_webui.utils.chat import generate_chat_completion | |
| from open_webui.utils.task import ( | |
| title_generation_template, | |
| follow_up_generation_template, | |
| query_generation_template, | |
| image_prompt_generation_template, | |
| autocomplete_generation_template, | |
| tags_generation_template, | |
| emoji_generation_template, | |
| moa_response_generation_template, | |
| ) | |
| from open_webui.utils.auth import get_admin_user, get_verified_user | |
| from open_webui.constants import ERROR_MESSAGES, TASKS | |
| from open_webui.routers.pipelines import process_pipeline_inlet_filter | |
| from open_webui.utils.task import get_task_model_id | |
| from open_webui.config import ( | |
| DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE, | |
| DEFAULT_FOLLOW_UP_GENERATION_PROMPT_TEMPLATE, | |
| DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE, | |
| DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE, | |
| DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE, | |
| DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE, | |
| DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE, | |
| DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE, | |
| DEFAULT_VOICE_MODE_PROMPT_TEMPLATE, | |
| ) | |
# Logger named after this module; used by the endpoints for debug and error output.
log = logging.getLogger(__name__)

# APIRouter instance for this module.
router = APIRouter()
| ################################## | |
| # | |
| # Task Endpoints | |
| # | |
| ################################## | |
class ActiveChatsForm(BaseModel):
    # Request body for check_active_chats.
    # (Documented with comments rather than a docstring so the generated
    # schema description stays exactly as it was.)

    # Chat identifiers passed on to get_active_chat_ids.
    chat_ids: list[str]
async def check_active_chats(request: Request, form_data: ActiveChatsForm, user=Depends(get_verified_user)):
    """Check which chat IDs have active tasks.

    Args:
        request: Incoming request; ``request.app.state.redis`` is handed to the lookup.
        form_data: Body whose ``chat_ids`` are the candidates to check.
        user: Verified user resolved by ``get_verified_user``; not used in the body.

    Returns:
        ``{'active_chat_ids': ...}`` holding whatever ``get_active_chat_ids`` returns.
    """
    # Function-scope import kept as in the original
    # (presumably avoids an import cycle -- confirm).
    from open_webui.tasks import get_active_chat_ids

    redis_client = request.app.state.redis
    active_ids = await get_active_chat_ids(redis_client, form_data.chat_ids)
    return {'active_chat_ids': active_ids}
async def get_task_config(request: Request, user=Depends(get_verified_user)):
    """Return the current task-related settings from ``request.app.state.config``.

    Args:
        request: Incoming request used to reach the application config.
        user: Verified user resolved by ``get_verified_user``; not used in the body.

    Returns:
        A dict mapping each exposed setting name to its current config value.
    """
    # Order matters only for the key order of the response body.
    exposed_keys = (
        'TASK_MODEL',
        'TASK_MODEL_EXTERNAL',
        'TITLE_GENERATION_PROMPT_TEMPLATE',
        'IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE',
        'ENABLE_AUTOCOMPLETE_GENERATION',
        'AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH',
        'TAGS_GENERATION_PROMPT_TEMPLATE',
        'FOLLOW_UP_GENERATION_PROMPT_TEMPLATE',
        'ENABLE_FOLLOW_UP_GENERATION',
        'ENABLE_TAGS_GENERATION',
        'ENABLE_TITLE_GENERATION',
        'ENABLE_SEARCH_QUERY_GENERATION',
        'ENABLE_RETRIEVAL_QUERY_GENERATION',
        'QUERY_GENERATION_PROMPT_TEMPLATE',
        'TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE',
        'VOICE_MODE_PROMPT_TEMPLATE',
    )
    config = request.app.state.config
    return {key: getattr(config, key) for key in exposed_keys}
class TaskConfigForm(BaseModel):
    # Request body for update_task_config: one field per task setting that the
    # endpoint copies onto request.app.state.config.
    # (Documented with comments rather than a docstring so the generated
    # schema description stays exactly as it was.)

    # Task model identifiers.
    TASK_MODEL: Optional[str]
    TASK_MODEL_EXTERNAL: Optional[str]

    # Title generation.
    ENABLE_TITLE_GENERATION: bool
    TITLE_GENERATION_PROMPT_TEMPLATE: str

    # Image prompt generation.
    IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE: str

    # Autocomplete generation.
    ENABLE_AUTOCOMPLETE_GENERATION: bool
    AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: int

    # Tags and follow-up generation.
    TAGS_GENERATION_PROMPT_TEMPLATE: str
    FOLLOW_UP_GENERATION_PROMPT_TEMPLATE: str
    ENABLE_FOLLOW_UP_GENERATION: bool
    ENABLE_TAGS_GENERATION: bool

    # Query generation (web search and retrieval).
    ENABLE_SEARCH_QUERY_GENERATION: bool
    ENABLE_RETRIEVAL_QUERY_GENERATION: bool
    QUERY_GENERATION_PROMPT_TEMPLATE: str

    # Remaining prompt templates.
    TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str
    VOICE_MODE_PROMPT_TEMPLATE: Optional[str]
async def update_task_config(request: Request, form_data: TaskConfigForm, user=Depends(get_admin_user)):
    """Copy every field of the submitted form onto the app config and echo the result.

    Args:
        request: Incoming request used to reach ``request.app.state.config``.
        form_data: New values for all task settings.
        user: Admin user resolved by ``get_admin_user``; not used in the body.

    Returns:
        A dict of the task settings as read back from the config after the update.
    """
    config = request.app.state.config

    # Settings are written in the same sequence as the original assignments.
    write_order = (
        'TASK_MODEL',
        'TASK_MODEL_EXTERNAL',
        'ENABLE_TITLE_GENERATION',
        'TITLE_GENERATION_PROMPT_TEMPLATE',
        'ENABLE_FOLLOW_UP_GENERATION',
        'FOLLOW_UP_GENERATION_PROMPT_TEMPLATE',
        'IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE',
        'ENABLE_AUTOCOMPLETE_GENERATION',
        'AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH',
        'TAGS_GENERATION_PROMPT_TEMPLATE',
        'ENABLE_TAGS_GENERATION',
        'ENABLE_SEARCH_QUERY_GENERATION',
        'ENABLE_RETRIEVAL_QUERY_GENERATION',
        'QUERY_GENERATION_PROMPT_TEMPLATE',
        'TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE',
        'VOICE_MODE_PROMPT_TEMPLATE',
    )
    for setting in write_order:
        setattr(config, setting, getattr(form_data, setting))

    # Key order of the response body (differs slightly from the write order).
    response_order = (
        'TASK_MODEL',
        'TASK_MODEL_EXTERNAL',
        'ENABLE_TITLE_GENERATION',
        'TITLE_GENERATION_PROMPT_TEMPLATE',
        'IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE',
        'ENABLE_AUTOCOMPLETE_GENERATION',
        'AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH',
        'TAGS_GENERATION_PROMPT_TEMPLATE',
        'ENABLE_TAGS_GENERATION',
        'ENABLE_FOLLOW_UP_GENERATION',
        'FOLLOW_UP_GENERATION_PROMPT_TEMPLATE',
        'ENABLE_SEARCH_QUERY_GENERATION',
        'ENABLE_RETRIEVAL_QUERY_GENERATION',
        'QUERY_GENERATION_PROMPT_TEMPLATE',
        'TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE',
        'VOICE_MODE_PROMPT_TEMPLATE',
    )
    return {setting: getattr(config, setting) for setting in response_order}
async def generate_title(request: Request, form_data: dict, user=Depends(get_verified_user)):
    """Generate a chat title by sending a templated prompt to the task model.

    Args:
        request: Incoming request; reads ``app.state.config``, ``app.state.MODELS``
            and, when present, ``state.direct``, ``state.model`` and ``state.metadata``.
        form_data: Request body. ``model`` and ``messages`` are read by key;
            ``chat_id`` is optional.
        user: Verified user resolved by ``get_verified_user``.

    Returns:
        The result of ``generate_chat_completion``; a 200 ``JSONResponse`` with a
        notice when title generation is disabled; or a 400 ``JSONResponse`` when
        the completion call raises.

    Raises:
        HTTPException: 404 when ``form_data['model']`` is not among the models.
    """
    config = request.app.state.config
    if not config.ENABLE_TITLE_GENERATION:
        return JSONResponse(
            status_code=status.HTTP_200_OK,
            content={'detail': 'Title generation is disabled'},
        )

    # A "direct" request carries its own model; otherwise use the app-wide models.
    is_direct = getattr(request.state, 'direct', False) and hasattr(request.state, 'model')
    models = {request.state.model['id']: request.state.model} if is_direct else request.app.state.MODELS

    model_id = form_data['model']
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=ERROR_MESSAGES.MODEL_NOT_FOUND(),
        )

    # Resolve the model that runs the task from the requested model and the
    # configured TASK_MODEL / TASK_MODEL_EXTERNAL.
    task_model_id = get_task_model_id(model_id, config.TASK_MODEL, config.TASK_MODEL_EXTERNAL, models)
    log.debug(f'generating chat title using model {task_model_id} for user {user.email} ')

    # An empty configured template falls back to the default one.
    template = config.TITLE_GENERATION_PROMPT_TEMPLATE
    if template == '':
        template = DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE
    content = title_generation_template(template, form_data['messages'], user)

    task_model = models[task_model_id]
    max_tokens = task_model.get('info', {}).get('params', {}).get('max_tokens', 1000)
    # Models owned by 'ollama' get 'max_tokens'; all others 'max_completion_tokens'.
    token_limit_key = 'max_tokens' if task_model.get('owned_by') == 'ollama' else 'max_completion_tokens'

    payload = {
        'model': task_model_id,
        'messages': [{'role': 'user', 'content': content}],
        'stream': False,
        token_limit_key: max_tokens,
        'metadata': {
            **getattr(request.state, 'metadata', {}),
            'task': str(TASKS.TITLE_GENERATION),
            'task_body': form_data,
            'chat_id': form_data.get('chat_id', None),
        },
    }

    # Process the payload through the pipeline; its errors propagate unchanged.
    payload = await process_pipeline_inlet_filter(request, payload, user, models)

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception:
        log.exception('Exception occurred')
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={'detail': 'An internal error has occurred.'},
        )
async def generate_follow_ups(request: Request, form_data: dict, user=Depends(get_verified_user)):
    """Generate follow-up suggestions by sending a templated prompt to the task model.

    Args:
        request: Incoming request; reads ``app.state.config``, ``app.state.MODELS``
            and, when present, ``state.direct``, ``state.model`` and ``state.metadata``.
        form_data: Request body. ``model`` and ``messages`` are read by key;
            ``chat_id`` is optional.
        user: Verified user resolved by ``get_verified_user``.

    Returns:
        The result of ``generate_chat_completion``; a 200 ``JSONResponse`` with a
        notice when follow-up generation is disabled; or a 400 ``JSONResponse``
        when the completion call raises.

    Raises:
        HTTPException: 404 when ``form_data['model']`` is not among the models.
    """
    config = request.app.state.config
    if not config.ENABLE_FOLLOW_UP_GENERATION:
        return JSONResponse(
            status_code=status.HTTP_200_OK,
            content={'detail': 'Follow-up generation is disabled'},
        )

    # A "direct" request carries its own model; otherwise use the app-wide models.
    if getattr(request.state, 'direct', False) and hasattr(request.state, 'model'):
        models = {request.state.model['id']: request.state.model}
    else:
        models = request.app.state.MODELS

    model_id = form_data['model']
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=ERROR_MESSAGES.MODEL_NOT_FOUND(),
        )

    # Resolve the model that runs the task from the requested model and the
    # configured TASK_MODEL / TASK_MODEL_EXTERNAL.
    task_model_id = get_task_model_id(model_id, config.TASK_MODEL, config.TASK_MODEL_EXTERNAL, models)

    # Fixed: this message used to say "chat title" (copy-paste from generate_title).
    log.debug(f'generating follow-ups using model {task_model_id} for user {user.email} ')

    # An empty configured template falls back to the default one.
    template = config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE
    if template == '':
        template = DEFAULT_FOLLOW_UP_GENERATION_PROMPT_TEMPLATE
    content = follow_up_generation_template(template, form_data['messages'], user)

    payload = {
        'model': task_model_id,
        'messages': [{'role': 'user', 'content': content}],
        'stream': False,
        'metadata': {
            **(request.state.metadata if hasattr(request.state, 'metadata') else {}),
            'task': str(TASKS.FOLLOW_UP_GENERATION),
            'task_body': form_data,
            'chat_id': form_data.get('chat_id', None),
        },
    }

    # Process the payload through the pipeline; its errors propagate unchanged.
    payload = await process_pipeline_inlet_filter(request, payload, user, models)

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception:
        log.error('Exception occurred', exc_info=True)
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={'detail': 'An internal error has occurred.'},
        )
async def generate_chat_tags(request: Request, form_data: dict, user=Depends(get_verified_user)):
    """Generate chat tags by sending a templated prompt to the task model.

    Args:
        request: Incoming request; reads ``app.state.config``, ``app.state.MODELS``
            and, when present, ``state.direct``, ``state.model`` and ``state.metadata``.
        form_data: Request body. ``model`` and ``messages`` are read by key;
            ``chat_id`` is optional.
        user: Verified user resolved by ``get_verified_user``.

    Returns:
        The result of ``generate_chat_completion``; a 200 ``JSONResponse`` with a
        notice when tags generation is disabled; or a 500 ``JSONResponse`` when
        the completion call raises.

    Raises:
        HTTPException: 404 when ``form_data['model']`` is not among the models.
    """
    config = request.app.state.config
    if not config.ENABLE_TAGS_GENERATION:
        return JSONResponse(
            status_code=status.HTTP_200_OK,
            content={'detail': 'Tags generation is disabled'},
        )

    # A "direct" request carries its own model; otherwise use the app-wide models.
    use_request_model = getattr(request.state, 'direct', False) and hasattr(request.state, 'model')
    models = {request.state.model['id']: request.state.model} if use_request_model else request.app.state.MODELS

    model_id = form_data['model']
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=ERROR_MESSAGES.MODEL_NOT_FOUND(),
        )

    # Resolve the model that runs the task from the requested model and the
    # configured TASK_MODEL / TASK_MODEL_EXTERNAL.
    task_model_id = get_task_model_id(model_id, config.TASK_MODEL, config.TASK_MODEL_EXTERNAL, models)
    log.debug(f'generating chat tags using model {task_model_id} for user {user.email} ')

    # An empty configured template falls back to the default one.
    template = config.TAGS_GENERATION_PROMPT_TEMPLATE
    if template == '':
        template = DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE
    content = tags_generation_template(template, form_data['messages'], user)

    metadata = {
        **getattr(request.state, 'metadata', {}),
        'task': str(TASKS.TAGS_GENERATION),
        'task_body': form_data,
        'chat_id': form_data.get('chat_id', None),
    }
    payload = {
        'model': task_model_id,
        'messages': [{'role': 'user', 'content': content}],
        'stream': False,
        'metadata': metadata,
    }

    # Process the payload through the pipeline; its errors propagate unchanged.
    payload = await process_pipeline_inlet_filter(request, payload, user, models)

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:
        log.error(f'Error generating chat completion: {e}')
        return JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content={'detail': 'An internal error has occurred.'},
        )
async def generate_image_prompt(request: Request, form_data: dict, user=Depends(get_verified_user)):
    """Generate an image prompt by sending a templated prompt to the task model.

    Args:
        request: Incoming request; reads ``app.state.config``, ``app.state.MODELS``
            and, when present, ``state.direct``, ``state.model`` and ``state.metadata``.
        form_data: Request body. ``model`` and ``messages`` are read by key;
            ``chat_id`` is optional.
        user: Verified user resolved by ``get_verified_user``.

    Returns:
        The result of ``generate_chat_completion``, or a 400 ``JSONResponse``
        when the completion call raises.

    Raises:
        HTTPException: 404 when ``form_data['model']`` is not among the models.
    """
    # A "direct" request carries its own model; otherwise use the app-wide models.
    is_direct = getattr(request.state, 'direct', False) and hasattr(request.state, 'model')
    models = {request.state.model['id']: request.state.model} if is_direct else request.app.state.MODELS

    model_id = form_data['model']
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=ERROR_MESSAGES.MODEL_NOT_FOUND(),
        )

    config = request.app.state.config

    # Resolve the model that runs the task from the requested model and the
    # configured TASK_MODEL / TASK_MODEL_EXTERNAL.
    task_model_id = get_task_model_id(model_id, config.TASK_MODEL, config.TASK_MODEL_EXTERNAL, models)
    log.debug(f'generating image prompt using model {task_model_id} for user {user.email} ')

    # An empty configured template falls back to the default one.
    template = config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
    if template == '':
        template = DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
    content = image_prompt_generation_template(template, form_data['messages'], user)

    payload = {
        'model': task_model_id,
        'messages': [{'role': 'user', 'content': content}],
        'stream': False,
        'metadata': {
            **getattr(request.state, 'metadata', {}),
            'task': str(TASKS.IMAGE_PROMPT_GENERATION),
            'task_body': form_data,
            'chat_id': form_data.get('chat_id', None),
        },
    }

    # Process the payload through the pipeline; its errors propagate unchanged.
    payload = await process_pipeline_inlet_filter(request, payload, user, models)

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception:
        log.exception('Exception occurred')
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={'detail': 'An internal error has occurred.'},
        )
async def generate_queries(request: Request, form_data: dict, user=Depends(get_verified_user)):
    """Generate search or retrieval queries by sending a templated prompt to the task model.

    Args:
        request: Incoming request; reads ``app.state.config``, ``app.state.MODELS``
            and, when present, ``state.cached_queries``, ``state.direct``,
            ``state.model`` and ``state.metadata``.
        form_data: Request body. ``model`` and ``messages`` are read by key;
            ``type`` (``'web_search'`` or ``'retrieval'``) and ``chat_id`` are optional.
        user: Verified user resolved by ``get_verified_user``.

    Returns:
        ``request.state.cached_queries`` when it is set and truthy; otherwise the
        result of ``generate_chat_completion``; or a 400 ``JSONResponse`` with a
        generic message when the completion call raises.

    Raises:
        HTTPException: 400 when the requested query type is disabled in the
            config; 404 when ``form_data['model']`` is not among the models.
    """
    config = request.app.state.config

    # Named query_type so the builtin ``type`` is not shadowed.
    query_type = form_data.get('type')
    if query_type == 'web_search' and not config.ENABLE_SEARCH_QUERY_GENERATION:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.FEATURE_DISABLED('Search query generation'),
        )
    if query_type == 'retrieval' and not config.ENABLE_RETRIEVAL_QUERY_GENERATION:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.FEATURE_DISABLED('Query generation'),
        )

    # Queries already attached to the request are returned without a model call.
    cached_queries = getattr(request.state, 'cached_queries', None)
    if cached_queries:
        log.info(f'Reusing cached queries: {cached_queries}')
        return cached_queries

    # A "direct" request carries its own model; otherwise use the app-wide models.
    if getattr(request.state, 'direct', False) and hasattr(request.state, 'model'):
        models = {request.state.model['id']: request.state.model}
    else:
        models = request.app.state.MODELS

    model_id = form_data['model']
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=ERROR_MESSAGES.MODEL_NOT_FOUND(),
        )

    # Resolve the model that runs the task from the requested model and the
    # configured TASK_MODEL / TASK_MODEL_EXTERNAL.
    task_model_id = get_task_model_id(model_id, config.TASK_MODEL, config.TASK_MODEL_EXTERNAL, models)
    log.debug(f'generating {query_type} queries using model {task_model_id} for user {user.email}')

    # A blank (whitespace-only) configured template falls back to the default one.
    template = config.QUERY_GENERATION_PROMPT_TEMPLATE
    if template.strip() == '':
        template = DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE
    content = query_generation_template(template, form_data['messages'], user)

    payload = {
        'model': task_model_id,
        'messages': [{'role': 'user', 'content': content}],
        'stream': False,
        'metadata': {
            **(request.state.metadata if hasattr(request.state, 'metadata') else {}),
            'task': str(TASKS.QUERY_GENERATION),
            'task_body': form_data,
            'chat_id': form_data.get('chat_id', None),
        },
    }

    # Process the payload through the pipeline; its errors propagate unchanged.
    payload = await process_pipeline_inlet_filter(request, payload, user, models)

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception:
        # Previously the exception was not logged and str(e) went to the client.
        # Log the traceback and return the same generic message the sibling
        # endpoints use; the status code is unchanged.
        log.error('Exception occurred', exc_info=True)
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={'detail': 'An internal error has occurred.'},
        )
async def generate_autocompletion(request: Request, form_data: dict, user=Depends(get_verified_user)):
    """Generate an autocompletion by sending a templated prompt to the task model.

    Args:
        request: Incoming request; reads ``app.state.config``, ``app.state.MODELS``
            and, when present, ``state.direct``, ``state.model`` and ``state.metadata``.
        form_data: Request body. ``model`` is read by key; ``type``, ``prompt``,
            ``messages`` and ``chat_id`` are optional.
        user: Verified user resolved by ``get_verified_user``.

    Returns:
        The result of ``generate_chat_completion``, or a 500 ``JSONResponse``
        when the completion call raises.

    Raises:
        HTTPException: 400 when autocompletion is disabled or the prompt exceeds
            ``AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH``; 404 when
            ``form_data['model']`` is not among the models.
    """
    config = request.app.state.config
    if not config.ENABLE_AUTOCOMPLETE_GENERATION:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.FEATURE_DISABLED('Autocompletion generation'),
        )

    # Named completion_type so the builtin ``type`` is not shadowed.
    completion_type = form_data.get('type')
    prompt = form_data.get('prompt')
    messages = form_data.get('messages')

    # A limit of 0 or less disables the length check.
    max_length = config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH
    if max_length > 0:
        # A missing prompt counts as length 0; len(None) used to raise TypeError here.
        if len(prompt or '') > max_length:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.INPUT_TOO_LONG(max_length),
            )

    # A "direct" request carries its own model; otherwise use the app-wide models.
    if getattr(request.state, 'direct', False) and hasattr(request.state, 'model'):
        models = {request.state.model['id']: request.state.model}
    else:
        models = request.app.state.MODELS

    model_id = form_data['model']
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=ERROR_MESSAGES.MODEL_NOT_FOUND(),
        )

    # Resolve the model that runs the task from the requested model and the
    # configured TASK_MODEL / TASK_MODEL_EXTERNAL.
    task_model_id = get_task_model_id(model_id, config.TASK_MODEL, config.TASK_MODEL_EXTERNAL, models)
    log.debug(f'generating autocompletion using model {task_model_id} for user {user.email}')

    # A blank (whitespace-only) configured template falls back to the default one.
    template = config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE
    if template.strip() == '':
        template = DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE
    content = autocomplete_generation_template(template, prompt, messages, completion_type, user)

    payload = {
        'model': task_model_id,
        'messages': [{'role': 'user', 'content': content}],
        'stream': False,
        'metadata': {
            **(request.state.metadata if hasattr(request.state, 'metadata') else {}),
            'task': str(TASKS.AUTOCOMPLETE_GENERATION),
            'task_body': form_data,
            'chat_id': form_data.get('chat_id', None),
        },
    }

    # Process the payload through the pipeline; its errors propagate unchanged.
    payload = await process_pipeline_inlet_filter(request, payload, user, models)

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:
        # exc_info keeps the traceback in the log; the message is unchanged.
        log.error(f'Error generating chat completion: {e}', exc_info=True)
        return JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content={'detail': 'An internal error has occurred.'},
        )
async def generate_emoji(request: Request, form_data: dict, user=Depends(get_verified_user)):
    """Generate an emoji for a prompt by sending a templated prompt to the task model.

    Args:
        request: Incoming request; reads ``app.state.config``, ``app.state.MODELS``
            and, when present, ``state.direct``, ``state.model`` and ``state.metadata``.
        form_data: Request body. ``model`` and ``prompt`` are read by key;
            ``chat_id`` is optional.
        user: Verified user resolved by ``get_verified_user``.

    Returns:
        The result of ``generate_chat_completion``, or a 400 ``JSONResponse``
        with a generic message when the completion call raises.

    Raises:
        HTTPException: 404 when ``form_data['model']`` is not among the models.
    """
    # A "direct" request carries its own model; otherwise use the app-wide models.
    if getattr(request.state, 'direct', False) and hasattr(request.state, 'model'):
        models = {request.state.model['id']: request.state.model}
    else:
        models = request.app.state.MODELS

    model_id = form_data['model']
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=ERROR_MESSAGES.MODEL_NOT_FOUND(),
        )

    # Resolve the model that runs the task from the requested model and the
    # configured TASK_MODEL / TASK_MODEL_EXTERNAL.
    config = request.app.state.config
    task_model_id = get_task_model_id(model_id, config.TASK_MODEL, config.TASK_MODEL_EXTERNAL, models)
    log.debug(f'generating emoji using model {task_model_id} for user {user.email} ')

    # This endpoint always uses the default template.
    template = DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE
    content = emoji_generation_template(template, form_data['prompt'], user)

    # The limit is fixed at 4 tokens. Models owned by 'ollama' get it as
    # 'max_tokens'; all others as 'max_completion_tokens'.
    if models[task_model_id].get('owned_by') == 'ollama':
        token_limit = {'max_tokens': 4}
    else:
        token_limit = {'max_completion_tokens': 4}

    payload = {
        'model': task_model_id,
        'messages': [{'role': 'user', 'content': content}],
        'stream': False,
        **token_limit,
        'metadata': {
            **(request.state.metadata if hasattr(request.state, 'metadata') else {}),
            'task': str(TASKS.EMOJI_GENERATION),
            'task_body': form_data,
            'chat_id': form_data.get('chat_id', None),
        },
    }

    # Process the payload through the pipeline; its errors propagate unchanged.
    payload = await process_pipeline_inlet_filter(request, payload, user, models)

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception:
        # Previously the exception was not logged and str(e) went to the client.
        # Log the traceback and return the same generic message the sibling
        # endpoints use; the status code is unchanged.
        log.error('Exception occurred', exc_info=True)
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={'detail': 'An internal error has occurred.'},
        )
async def generate_moa_response(request: Request, form_data: dict, user=Depends(get_verified_user)):
    """Generate a combined answer from several responses using the requested model.

    Unlike the other task endpoints in this file, no task model is resolved:
    the payload is sent to ``form_data['model']`` itself, and streaming follows
    ``form_data['stream']``.

    Args:
        request: Incoming request; reads ``app.state.MODELS`` and, when present,
            ``state.direct``, ``state.model`` and ``state.metadata``.
        form_data: Request body. ``model``, ``prompt`` and ``responses`` are read
            by key; ``stream`` (default ``False``) and ``chat_id`` are optional.
        user: Verified user resolved by ``get_verified_user``.

    Returns:
        The result of ``generate_chat_completion``, or a 400 ``JSONResponse``
        with a generic message when the completion call raises.

    Raises:
        HTTPException: 404 when ``form_data['model']`` is not among the models.
    """
    # A "direct" request carries its own model; otherwise use the app-wide models.
    if getattr(request.state, 'direct', False) and hasattr(request.state, 'model'):
        models = {request.state.model['id']: request.state.model}
    else:
        models = request.app.state.MODELS

    model_id = form_data['model']
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=ERROR_MESSAGES.MODEL_NOT_FOUND(),
        )

    # This endpoint always uses the default template.
    template = DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE
    content = moa_response_generation_template(
        template,
        form_data['prompt'],
        form_data['responses'],
    )

    payload = {
        'model': model_id,
        'messages': [{'role': 'user', 'content': content}],
        'stream': form_data.get('stream', False),
        'metadata': {
            **(request.state.metadata if hasattr(request.state, 'metadata') else {}),
            'chat_id': form_data.get('chat_id', None),
            'task': str(TASKS.MOA_RESPONSE_GENERATION),
            'task_body': form_data,
        },
    }

    # Process the payload through the pipeline; its errors propagate unchanged.
    payload = await process_pipeline_inlet_filter(request, payload, user, models)

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception:
        # Previously the exception was not logged and str(e) went to the client.
        # Log the traceback and return the same generic message the sibling
        # endpoints use; the status code is unchanged.
        log.error('Exception occurred', exc_info=True)
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={'detail': 'An internal error has occurred.'},
        )