Spaces:
Build error
Build error
| import hashlib | |
| import json | |
| import logging | |
| import os | |
| import uuid | |
| import html | |
| import base64 | |
| from pydub import AudioSegment | |
| from pydub.silence import split_on_silence | |
| from concurrent.futures import ThreadPoolExecutor | |
| from typing import Optional | |
| from fnmatch import fnmatch | |
| import aiohttp | |
| import aiofiles | |
| import requests | |
| import mimetypes | |
| from fastapi import ( | |
| Depends, | |
| FastAPI, | |
| File, | |
| Form, | |
| HTTPException, | |
| Request, | |
| UploadFile, | |
| status, | |
| APIRouter, | |
| ) | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import FileResponse | |
| from pydantic import BaseModel | |
| from open_webui.utils.misc import strict_match_mime_type | |
| from open_webui.utils.auth import get_admin_user, get_verified_user | |
| from open_webui.utils.access_control import has_permission | |
| from open_webui.utils.headers import include_user_info_headers | |
| from open_webui.config import ( | |
| WHISPER_MODEL_AUTO_UPDATE, | |
| WHISPER_COMPUTE_TYPE, | |
| WHISPER_MODEL_DIR, | |
| WHISPER_VAD_FILTER, | |
| CACHE_DIR, | |
| WHISPER_LANGUAGE, | |
| WHISPER_MULTILINGUAL, | |
| ELEVENLABS_API_BASE_URL, | |
| ) | |
| from open_webui.constants import ERROR_MESSAGES | |
| from open_webui.env import ( | |
| ENV, | |
| AIOHTTP_CLIENT_SESSION_SSL, | |
| AIOHTTP_CLIENT_TIMEOUT, | |
| AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST, | |
| BYPASS_PYDUB_PREPROCESSING, | |
| DEVICE_TYPE, | |
| ENABLE_FORWARD_USER_INFO_HEADERS, | |
| ) | |
| router = APIRouter() | |
| # Constants | |
| MAX_FILE_SIZE_MB = 20 | |
| MAX_FILE_SIZE = MAX_FILE_SIZE_MB * 1024 * 1024 # Convert MB to bytes | |
| AZURE_MAX_FILE_SIZE_MB = 200 | |
| AZURE_MAX_FILE_SIZE = AZURE_MAX_FILE_SIZE_MB * 1024 * 1024 # Convert MB to bytes | |
| log = logging.getLogger(__name__) | |
| SPEECH_CACHE_DIR = CACHE_DIR / 'audio' / 'speech' | |
| SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True) | |
| ########################################## | |
| # | |
| # Utility functions | |
| # Let what is spoken here be heard clearly, and let | |
| # no voice be reduced to noise along the way. | |
| # | |
| ########################################## | |
| from pydub import AudioSegment | |
| from pydub.utils import mediainfo | |
| def is_audio_conversion_required(file_path): | |
| """ | |
| Check if the given audio file needs conversion to mp3. | |
| """ | |
| SUPPORTED_FORMATS = {'flac', 'm4a', 'mp3', 'mp4', 'mpeg', 'wav', 'webm'} | |
| if not os.path.isfile(file_path): | |
| log.error(f'File not found: {file_path}') | |
| return False | |
| try: | |
| info = mediainfo(file_path) | |
| codec_name = info.get('codec_name', '').lower() | |
| codec_type = info.get('codec_type', '').lower() | |
| codec_tag_string = info.get('codec_tag_string', '').lower() | |
| if codec_name == 'aac' and codec_type == 'audio' and codec_tag_string == 'mp4a': | |
| # File is AAC/mp4a audio, recommend mp3 conversion | |
| return True | |
| # If the codec name is in the supported formats | |
| if codec_name in SUPPORTED_FORMATS: | |
| return False | |
| return True | |
| except Exception as e: | |
| log.error(f'Error getting audio format: {e}') | |
| return False | |
| def convert_audio_to_mp3(file_path): | |
| """Convert audio file to mp3 format.""" | |
| try: | |
| output_path = os.path.splitext(file_path)[0] + '.mp3' | |
| audio = AudioSegment.from_file(file_path) | |
| audio.export(output_path, format='mp3') | |
| log.info(f'Converted {file_path} to {output_path}') | |
| return output_path | |
| except Exception as e: | |
| log.error(f'Error converting audio file: {e}') | |
| return None | |
| def set_faster_whisper_model(model: str, auto_update: bool = False): | |
| whisper_model = None | |
| if model: | |
| from faster_whisper import WhisperModel | |
| faster_whisper_kwargs = { | |
| 'model_size_or_path': model, | |
| 'device': DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == 'cuda' else 'cpu', | |
| 'compute_type': WHISPER_COMPUTE_TYPE, | |
| 'download_root': WHISPER_MODEL_DIR, | |
| 'local_files_only': not auto_update, | |
| } | |
| try: | |
| whisper_model = WhisperModel(**faster_whisper_kwargs) | |
| except Exception: | |
| log.warning('WhisperModel initialization failed, attempting download with local_files_only=False') | |
| faster_whisper_kwargs['local_files_only'] = False | |
| whisper_model = WhisperModel(**faster_whisper_kwargs) | |
| return whisper_model | |
| ########################################## | |
| # | |
| # Audio API | |
| # | |
| ########################################## | |
| class TTSConfigForm(BaseModel): | |
| OPENAI_API_BASE_URL: str | |
| OPENAI_API_KEY: str | |
| OPENAI_PARAMS: Optional[dict] = None | |
| API_KEY: str | |
| ENGINE: str | |
| MODEL: str | |
| VOICE: str | |
| SPLIT_ON: str | |
| AZURE_SPEECH_REGION: str | |
| AZURE_SPEECH_BASE_URL: str | |
| AZURE_SPEECH_OUTPUT_FORMAT: str | |
| MISTRAL_API_KEY: str | |
| MISTRAL_API_BASE_URL: str | |
| class STTConfigForm(BaseModel): | |
| OPENAI_API_BASE_URL: str | |
| OPENAI_API_KEY: str | |
| ENGINE: str | |
| MODEL: str | |
| SUPPORTED_CONTENT_TYPES: list[str] = [] | |
| WHISPER_MODEL: str | |
| DEEPGRAM_API_KEY: str | |
| AZURE_API_KEY: str | |
| AZURE_REGION: str | |
| AZURE_LOCALES: str | |
| AZURE_BASE_URL: str | |
| AZURE_MAX_SPEAKERS: str | |
| MISTRAL_API_KEY: str | |
| MISTRAL_API_BASE_URL: str | |
| MISTRAL_USE_CHAT_COMPLETIONS: bool | |
| class AudioConfigUpdateForm(BaseModel): | |
| tts: TTSConfigForm | |
| stt: STTConfigForm | |
| async def get_audio_config(request: Request, user=Depends(get_admin_user)): | |
| return { | |
| 'tts': { | |
| 'OPENAI_API_BASE_URL': request.app.state.config.TTS_OPENAI_API_BASE_URL, | |
| 'OPENAI_API_KEY': request.app.state.config.TTS_OPENAI_API_KEY, | |
| 'OPENAI_PARAMS': request.app.state.config.TTS_OPENAI_PARAMS, | |
| 'API_KEY': request.app.state.config.TTS_API_KEY, | |
| 'ENGINE': request.app.state.config.TTS_ENGINE, | |
| 'MODEL': request.app.state.config.TTS_MODEL, | |
| 'VOICE': request.app.state.config.TTS_VOICE, | |
| 'SPLIT_ON': request.app.state.config.TTS_SPLIT_ON, | |
| 'AZURE_SPEECH_REGION': request.app.state.config.TTS_AZURE_SPEECH_REGION, | |
| 'AZURE_SPEECH_BASE_URL': request.app.state.config.TTS_AZURE_SPEECH_BASE_URL, | |
| 'AZURE_SPEECH_OUTPUT_FORMAT': request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT, | |
| 'MISTRAL_API_KEY': request.app.state.config.TTS_MISTRAL_API_KEY, | |
| 'MISTRAL_API_BASE_URL': request.app.state.config.TTS_MISTRAL_API_BASE_URL, | |
| }, | |
| 'stt': { | |
| 'OPENAI_API_BASE_URL': request.app.state.config.STT_OPENAI_API_BASE_URL, | |
| 'OPENAI_API_KEY': request.app.state.config.STT_OPENAI_API_KEY, | |
| 'ENGINE': request.app.state.config.STT_ENGINE, | |
| 'MODEL': request.app.state.config.STT_MODEL, | |
| 'SUPPORTED_CONTENT_TYPES': request.app.state.config.STT_SUPPORTED_CONTENT_TYPES, | |
| 'WHISPER_MODEL': request.app.state.config.WHISPER_MODEL, | |
| 'DEEPGRAM_API_KEY': request.app.state.config.DEEPGRAM_API_KEY, | |
| 'AZURE_API_KEY': request.app.state.config.AUDIO_STT_AZURE_API_KEY, | |
| 'AZURE_REGION': request.app.state.config.AUDIO_STT_AZURE_REGION, | |
| 'AZURE_LOCALES': request.app.state.config.AUDIO_STT_AZURE_LOCALES, | |
| 'AZURE_BASE_URL': request.app.state.config.AUDIO_STT_AZURE_BASE_URL, | |
| 'AZURE_MAX_SPEAKERS': request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS, | |
| 'MISTRAL_API_KEY': request.app.state.config.AUDIO_STT_MISTRAL_API_KEY, | |
| 'MISTRAL_API_BASE_URL': request.app.state.config.AUDIO_STT_MISTRAL_API_BASE_URL, | |
| 'MISTRAL_USE_CHAT_COMPLETIONS': request.app.state.config.AUDIO_STT_MISTRAL_USE_CHAT_COMPLETIONS, | |
| }, | |
| } | |
| async def update_audio_config(request: Request, form_data: AudioConfigUpdateForm, user=Depends(get_admin_user)): | |
| request.app.state.config.TTS_OPENAI_API_BASE_URL = form_data.tts.OPENAI_API_BASE_URL | |
| request.app.state.config.TTS_OPENAI_API_KEY = form_data.tts.OPENAI_API_KEY | |
| request.app.state.config.TTS_OPENAI_PARAMS = form_data.tts.OPENAI_PARAMS | |
| request.app.state.config.TTS_API_KEY = form_data.tts.API_KEY | |
| request.app.state.config.TTS_ENGINE = form_data.tts.ENGINE | |
| request.app.state.config.TTS_MODEL = form_data.tts.MODEL | |
| request.app.state.config.TTS_VOICE = form_data.tts.VOICE | |
| request.app.state.config.TTS_SPLIT_ON = form_data.tts.SPLIT_ON | |
| request.app.state.config.TTS_AZURE_SPEECH_REGION = form_data.tts.AZURE_SPEECH_REGION | |
| request.app.state.config.TTS_AZURE_SPEECH_BASE_URL = form_data.tts.AZURE_SPEECH_BASE_URL | |
| request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = form_data.tts.AZURE_SPEECH_OUTPUT_FORMAT | |
| request.app.state.config.TTS_MISTRAL_API_KEY = form_data.tts.MISTRAL_API_KEY | |
| request.app.state.config.TTS_MISTRAL_API_BASE_URL = form_data.tts.MISTRAL_API_BASE_URL | |
| request.app.state.config.STT_OPENAI_API_BASE_URL = form_data.stt.OPENAI_API_BASE_URL | |
| request.app.state.config.STT_OPENAI_API_KEY = form_data.stt.OPENAI_API_KEY | |
| request.app.state.config.STT_ENGINE = form_data.stt.ENGINE | |
| request.app.state.config.STT_MODEL = form_data.stt.MODEL | |
| request.app.state.config.STT_SUPPORTED_CONTENT_TYPES = form_data.stt.SUPPORTED_CONTENT_TYPES | |
| request.app.state.config.WHISPER_MODEL = form_data.stt.WHISPER_MODEL | |
| request.app.state.config.DEEPGRAM_API_KEY = form_data.stt.DEEPGRAM_API_KEY | |
| request.app.state.config.AUDIO_STT_AZURE_API_KEY = form_data.stt.AZURE_API_KEY | |
| request.app.state.config.AUDIO_STT_AZURE_REGION = form_data.stt.AZURE_REGION | |
| request.app.state.config.AUDIO_STT_AZURE_LOCALES = form_data.stt.AZURE_LOCALES | |
| request.app.state.config.AUDIO_STT_AZURE_BASE_URL = form_data.stt.AZURE_BASE_URL | |
| request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS = form_data.stt.AZURE_MAX_SPEAKERS | |
| request.app.state.config.AUDIO_STT_MISTRAL_API_KEY = form_data.stt.MISTRAL_API_KEY | |
| request.app.state.config.AUDIO_STT_MISTRAL_API_BASE_URL = form_data.stt.MISTRAL_API_BASE_URL | |
| request.app.state.config.AUDIO_STT_MISTRAL_USE_CHAT_COMPLETIONS = form_data.stt.MISTRAL_USE_CHAT_COMPLETIONS | |
| if request.app.state.config.STT_ENGINE == '': | |
| request.app.state.faster_whisper_model = set_faster_whisper_model( | |
| form_data.stt.WHISPER_MODEL, WHISPER_MODEL_AUTO_UPDATE | |
| ) | |
| else: | |
| request.app.state.faster_whisper_model = None | |
| return { | |
| 'tts': { | |
| 'ENGINE': request.app.state.config.TTS_ENGINE, | |
| 'MODEL': request.app.state.config.TTS_MODEL, | |
| 'VOICE': request.app.state.config.TTS_VOICE, | |
| 'OPENAI_API_BASE_URL': request.app.state.config.TTS_OPENAI_API_BASE_URL, | |
| 'OPENAI_API_KEY': request.app.state.config.TTS_OPENAI_API_KEY, | |
| 'OPENAI_PARAMS': request.app.state.config.TTS_OPENAI_PARAMS, | |
| 'API_KEY': request.app.state.config.TTS_API_KEY, | |
| 'SPLIT_ON': request.app.state.config.TTS_SPLIT_ON, | |
| 'AZURE_SPEECH_REGION': request.app.state.config.TTS_AZURE_SPEECH_REGION, | |
| 'AZURE_SPEECH_BASE_URL': request.app.state.config.TTS_AZURE_SPEECH_BASE_URL, | |
| 'AZURE_SPEECH_OUTPUT_FORMAT': request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT, | |
| 'MISTRAL_API_KEY': request.app.state.config.TTS_MISTRAL_API_KEY, | |
| 'MISTRAL_API_BASE_URL': request.app.state.config.TTS_MISTRAL_API_BASE_URL, | |
| }, | |
| 'stt': { | |
| 'OPENAI_API_BASE_URL': request.app.state.config.STT_OPENAI_API_BASE_URL, | |
| 'OPENAI_API_KEY': request.app.state.config.STT_OPENAI_API_KEY, | |
| 'ENGINE': request.app.state.config.STT_ENGINE, | |
| 'MODEL': request.app.state.config.STT_MODEL, | |
| 'SUPPORTED_CONTENT_TYPES': request.app.state.config.STT_SUPPORTED_CONTENT_TYPES, | |
| 'WHISPER_MODEL': request.app.state.config.WHISPER_MODEL, | |
| 'DEEPGRAM_API_KEY': request.app.state.config.DEEPGRAM_API_KEY, | |
| 'AZURE_API_KEY': request.app.state.config.AUDIO_STT_AZURE_API_KEY, | |
| 'AZURE_REGION': request.app.state.config.AUDIO_STT_AZURE_REGION, | |
| 'AZURE_LOCALES': request.app.state.config.AUDIO_STT_AZURE_LOCALES, | |
| 'AZURE_BASE_URL': request.app.state.config.AUDIO_STT_AZURE_BASE_URL, | |
| 'AZURE_MAX_SPEAKERS': request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS, | |
| 'MISTRAL_API_KEY': request.app.state.config.AUDIO_STT_MISTRAL_API_KEY, | |
| 'MISTRAL_API_BASE_URL': request.app.state.config.AUDIO_STT_MISTRAL_API_BASE_URL, | |
| 'MISTRAL_USE_CHAT_COMPLETIONS': request.app.state.config.AUDIO_STT_MISTRAL_USE_CHAT_COMPLETIONS, | |
| }, | |
| } | |
| def load_speech_pipeline(request): | |
| from transformers import pipeline | |
| from datasets import load_dataset | |
| if request.app.state.speech_synthesiser is None: | |
| request.app.state.speech_synthesiser = pipeline('text-to-speech', 'microsoft/speecht5_tts') | |
| if request.app.state.speech_speaker_embeddings_dataset is None: | |
| request.app.state.speech_speaker_embeddings_dataset = load_dataset( | |
| 'Matthijs/cmu-arctic-xvectors', split='validation' | |
| ) | |
| async def speech(request: Request, user=Depends(get_verified_user)): | |
| if request.app.state.config.TTS_ENGINE == '': | |
| raise HTTPException( | |
| status_code=status.HTTP_404_NOT_FOUND, | |
| detail=ERROR_MESSAGES.NOT_FOUND, | |
| ) | |
| if user.role != 'admin' and not await has_permission( | |
| user.id, 'chat.tts', request.app.state.config.USER_PERMISSIONS | |
| ): | |
| raise HTTPException( | |
| status_code=status.HTTP_403_FORBIDDEN, | |
| detail=ERROR_MESSAGES.ACCESS_PROHIBITED, | |
| ) | |
| body = await request.body() | |
| name = hashlib.sha256( | |
| body | |
| + str(request.app.state.config.TTS_ENGINE).encode('utf-8') | |
| + str(request.app.state.config.TTS_MODEL).encode('utf-8') | |
| ).hexdigest() | |
| file_path = SPEECH_CACHE_DIR.joinpath(f'{name}.mp3') | |
| file_body_path = SPEECH_CACHE_DIR.joinpath(f'{name}.json') | |
| # Check if the file already exists in the cache | |
| if file_path.is_file(): | |
| return FileResponse(file_path) | |
| payload = None | |
| try: | |
| payload = json.loads(body.decode('utf-8')) | |
| except Exception as e: | |
| log.exception(e) | |
| raise HTTPException(status_code=400, detail='Invalid JSON payload') | |
| r = None | |
| if request.app.state.config.TTS_ENGINE == 'openai': | |
| payload['model'] = request.app.state.config.TTS_MODEL | |
| try: | |
| timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT) | |
| async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: | |
| payload = { | |
| **payload, | |
| **(request.app.state.config.TTS_OPENAI_PARAMS or {}), | |
| } | |
| headers = { | |
| 'Content-Type': 'application/json', | |
| 'Authorization': f'Bearer {request.app.state.config.TTS_OPENAI_API_KEY}', | |
| } | |
| if ENABLE_FORWARD_USER_INFO_HEADERS: | |
| headers = include_user_info_headers(headers, user) | |
| r = await session.post( | |
| url=f'{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech', | |
| json=payload, | |
| headers=headers, | |
| ssl=AIOHTTP_CLIENT_SESSION_SSL, | |
| ) | |
| r.raise_for_status() | |
| async with aiofiles.open(file_path, 'wb') as f: | |
| await f.write(await r.read()) | |
| async with aiofiles.open(file_body_path, 'w') as f: | |
| await f.write(json.dumps(payload)) | |
| return FileResponse(file_path) | |
| except Exception as e: | |
| log.exception(e) | |
| detail = None | |
| status_code = 500 | |
| detail = f'Open WebUI: Server Connection Error' | |
| if r is not None: | |
| status_code = r.status | |
| try: | |
| res = await r.json() | |
| if 'error' in res: | |
| detail = f'External: {res["error"]}' | |
| except Exception: | |
| detail = f'External: {e}' | |
| raise HTTPException( | |
| status_code=status_code, | |
| detail=detail, | |
| ) | |
| elif request.app.state.config.TTS_ENGINE == 'elevenlabs': | |
| voice_id = payload.get('voice', '') | |
| if voice_id not in await get_available_voices(request): | |
| raise HTTPException( | |
| status_code=400, | |
| detail='Invalid voice id', | |
| ) | |
| try: | |
| timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT) | |
| async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: | |
| async with session.post( | |
| f'{ELEVENLABS_API_BASE_URL}/v1/text-to-speech/{voice_id}', | |
| json={ | |
| 'text': payload['input'], | |
| 'model_id': request.app.state.config.TTS_MODEL, | |
| 'voice_settings': {'stability': 0.5, 'similarity_boost': 0.5}, | |
| }, | |
| headers={ | |
| 'Accept': 'audio/mpeg', | |
| 'Content-Type': 'application/json', | |
| 'xi-api-key': request.app.state.config.TTS_API_KEY, | |
| }, | |
| ssl=AIOHTTP_CLIENT_SESSION_SSL, | |
| ) as r: | |
| r.raise_for_status() | |
| async with aiofiles.open(file_path, 'wb') as f: | |
| await f.write(await r.read()) | |
| async with aiofiles.open(file_body_path, 'w') as f: | |
| await f.write(json.dumps(payload)) | |
| return FileResponse(file_path) | |
| except Exception as e: | |
| log.exception(e) | |
| detail = None | |
| try: | |
| if r.status != 200: | |
| res = await r.json() | |
| if 'error' in res: | |
| detail = f'External: {res["error"].get("message", "")}' | |
| except Exception: | |
| detail = f'External: {e}' | |
| raise HTTPException( | |
| status_code=getattr(r, 'status', 500) if r else 500, | |
| detail=detail if detail else 'Open WebUI: Server Connection Error', | |
| ) | |
| elif request.app.state.config.TTS_ENGINE == 'azure': | |
| try: | |
| payload = json.loads(body.decode('utf-8')) | |
| except Exception as e: | |
| log.exception(e) | |
| raise HTTPException(status_code=400, detail='Invalid JSON payload') | |
| region = request.app.state.config.TTS_AZURE_SPEECH_REGION or 'eastus' | |
| base_url = request.app.state.config.TTS_AZURE_SPEECH_BASE_URL | |
| language = request.app.state.config.TTS_VOICE | |
| locale = '-'.join(request.app.state.config.TTS_VOICE.split('-')[:2]) | |
| output_format = request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT | |
| try: | |
| data = f"""<speak version="1.0" xmlns="http://www.w3.org/2001/10/synthesis" xml:lang="{locale}"> | |
| <voice name="{language}">{html.escape(payload['input'])}</voice> | |
| </speak>""" | |
| timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT) | |
| async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: | |
| async with session.post( | |
| (base_url or f'https://{region}.tts.speech.microsoft.com') + '/cognitiveservices/v1', | |
| headers={ | |
| 'Ocp-Apim-Subscription-Key': request.app.state.config.TTS_API_KEY, | |
| 'Content-Type': 'application/ssml+xml', | |
| 'X-Microsoft-OutputFormat': output_format, | |
| }, | |
| data=data, | |
| ssl=AIOHTTP_CLIENT_SESSION_SSL, | |
| ) as r: | |
| r.raise_for_status() | |
| async with aiofiles.open(file_path, 'wb') as f: | |
| await f.write(await r.read()) | |
| async with aiofiles.open(file_body_path, 'w') as f: | |
| await f.write(json.dumps(payload)) | |
| return FileResponse(file_path) | |
| except Exception as e: | |
| log.exception(e) | |
| detail = None | |
| try: | |
| if r.status != 200: | |
| res = await r.json() | |
| if 'error' in res: | |
| detail = f'External: {res["error"].get("message", "")}' | |
| except Exception: | |
| detail = f'External: {e}' | |
| raise HTTPException( | |
| status_code=getattr(r, 'status', 500) if r else 500, | |
| detail=detail if detail else 'Open WebUI: Server Connection Error', | |
| ) | |
| elif request.app.state.config.TTS_ENGINE == 'transformers': | |
| payload = None | |
| try: | |
| payload = json.loads(body.decode('utf-8')) | |
| except Exception as e: | |
| log.exception(e) | |
| raise HTTPException(status_code=400, detail='Invalid JSON payload') | |
| import torch | |
| import soundfile as sf | |
| load_speech_pipeline(request) | |
| embeddings_dataset = request.app.state.speech_speaker_embeddings_dataset | |
| speaker_index = 6799 | |
| try: | |
| speaker_index = embeddings_dataset['filename'].index(request.app.state.config.TTS_MODEL) | |
| except Exception: | |
| pass | |
| speaker_embedding = torch.tensor(embeddings_dataset[speaker_index]['xvector']).unsqueeze(0) | |
| speech = request.app.state.speech_synthesiser( | |
| payload['input'], | |
| forward_params={'speaker_embeddings': speaker_embedding}, | |
| ) | |
| sf.write(file_path, speech['audio'], samplerate=speech['sampling_rate']) | |
| async with aiofiles.open(file_body_path, 'w') as f: | |
| await f.write(json.dumps(payload)) | |
| return FileResponse(file_path) | |
| elif request.app.state.config.TTS_ENGINE == 'mistral': | |
| api_key = request.app.state.config.TTS_MISTRAL_API_KEY | |
| api_base_url = request.app.state.config.TTS_MISTRAL_API_BASE_URL or 'https://api.mistral.ai/v1' | |
| if not api_key: | |
| raise HTTPException( | |
| status_code=400, | |
| detail='Mistral API key is required for Mistral TTS', | |
| ) | |
| try: | |
| timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT) | |
| async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: | |
| mistral_payload = { | |
| 'input': payload.get('input', ''), | |
| 'model': request.app.state.config.TTS_MODEL or 'voxtral-mini-tts-2603', | |
| 'voice_id': payload.get('voice', ''), | |
| 'response_format': 'mp3', | |
| } | |
| r = await session.post( | |
| url=f'{api_base_url}/audio/speech', | |
| json=mistral_payload, | |
| headers={ | |
| 'Content-Type': 'application/json', | |
| 'Authorization': f'Bearer {api_key}', | |
| }, | |
| ssl=AIOHTTP_CLIENT_SESSION_SSL, | |
| ) | |
| r.raise_for_status() | |
| res = await r.json() | |
| audio_data = res.get('audio_data', '') | |
| if not audio_data: | |
| raise ValueError('No audio_data in Mistral TTS response') | |
| audio_bytes = base64.b64decode(audio_data) | |
| async with aiofiles.open(file_path, 'wb') as f: | |
| await f.write(audio_bytes) | |
| async with aiofiles.open(file_body_path, 'w') as f: | |
| await f.write(json.dumps(payload)) | |
| return FileResponse(file_path) | |
| except Exception as e: | |
| log.exception(e) | |
| detail = None | |
| status_code = 500 | |
| detail = 'Open WebUI: Server Connection Error' | |
| if r is not None: | |
| status_code = r.status | |
| try: | |
| res = await r.json() | |
| if 'error' in res: | |
| detail = f'External: {res["error"]}' | |
| elif 'message' in res: | |
| detail = f'External: {res["message"]}' | |
| except Exception: | |
| detail = f'External: {e}' | |
| raise HTTPException( | |
| status_code=status_code, | |
| detail=detail, | |
| ) | |
| def transcription_handler(request, file_path, metadata, user=None): | |
| filename = os.path.basename(file_path) | |
| file_dir = os.path.dirname(file_path) | |
| id = filename.split('.')[0] | |
| metadata = metadata or {} | |
| languages = [ | |
| metadata.get('language', None) if not WHISPER_LANGUAGE else WHISPER_LANGUAGE, | |
| None, # Always fallback to None in case transcription fails | |
| ] | |
| if request.app.state.config.STT_ENGINE == '': | |
| if request.app.state.faster_whisper_model is None: | |
| request.app.state.faster_whisper_model = set_faster_whisper_model(request.app.state.config.WHISPER_MODEL) | |
| model = request.app.state.faster_whisper_model | |
| segments, info = model.transcribe( | |
| file_path, | |
| beam_size=5, | |
| vad_filter=WHISPER_VAD_FILTER, | |
| language=languages[0], | |
| multilingual=WHISPER_MULTILINGUAL, | |
| ) | |
| log.info("Detected language '%s' with probability %f" % (info.language, info.language_probability)) | |
| transcript = ''.join([segment.text for segment in list(segments)]) | |
| data = {'text': transcript.strip()} | |
| # save the transcript to a json file | |
| transcript_file = os.path.join(file_dir, f'{id}.json') | |
| with open(transcript_file, 'w') as f: | |
| json.dump(data, f) | |
| log.debug(data) | |
| return data | |
| elif request.app.state.config.STT_ENGINE == 'openai': | |
| r = None | |
| try: | |
| for language in languages: | |
| payload = { | |
| 'model': request.app.state.config.STT_MODEL, | |
| } | |
| if language: | |
| payload['language'] = language | |
| headers = {'Authorization': f'Bearer {request.app.state.config.STT_OPENAI_API_KEY}'} | |
| if user and ENABLE_FORWARD_USER_INFO_HEADERS: | |
| headers = include_user_info_headers(headers, user) | |
| with open(file_path, 'rb') as audio_file: | |
| r = requests.post( | |
| url=f'{request.app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions', | |
| headers=headers, | |
| files={'file': (filename, audio_file)}, | |
| data=payload, | |
| timeout=AIOHTTP_CLIENT_TIMEOUT, | |
| ) | |
| if r.status_code == 200: | |
| # Successful transcription | |
| break | |
| r.raise_for_status() | |
| data = r.json() | |
| # save the transcript to a json file | |
| transcript_file = os.path.join(file_dir, f'{id}.json') | |
| with open(transcript_file, 'w') as f: | |
| json.dump(data, f) | |
| return data | |
| except Exception as e: | |
| log.exception(e) | |
| detail = None | |
| if r is not None: | |
| try: | |
| res = r.json() | |
| if 'error' in res: | |
| detail = f'External: {res["error"].get("message", "")}' | |
| except Exception: | |
| detail = f'External: {e}' | |
| raise Exception(detail if detail else 'Open WebUI: Server Connection Error') | |
| elif request.app.state.config.STT_ENGINE == 'deepgram': | |
| try: | |
| # Determine the MIME type of the file | |
| mime, _ = mimetypes.guess_type(file_path) | |
| if not mime: | |
| mime = 'audio/wav' # fallback to wav if undetectable | |
| # Read the audio file | |
| with open(file_path, 'rb') as f: | |
| file_data = f.read() | |
| # Build headers and parameters | |
| headers = { | |
| 'Authorization': f'Token {request.app.state.config.DEEPGRAM_API_KEY}', | |
| 'Content-Type': mime, | |
| } | |
| for language in languages: | |
| params = {} | |
| if request.app.state.config.STT_MODEL: | |
| params['model'] = request.app.state.config.STT_MODEL | |
| if language: | |
| params['language'] = language | |
| # Make request to Deepgram API | |
| r = requests.post( | |
| 'https://api.deepgram.com/v1/listen?smart_format=true', | |
| headers=headers, | |
| params=params, | |
| data=file_data, | |
| timeout=AIOHTTP_CLIENT_TIMEOUT, | |
| ) | |
| if r.status_code == 200: | |
| # Successful transcription | |
| break | |
| r.raise_for_status() | |
| response_data = r.json() | |
| # Extract transcript from Deepgram response | |
| try: | |
| transcript = response_data['results']['channels'][0]['alternatives'][0].get('transcript', '') | |
| except (KeyError, IndexError) as e: | |
| log.error(f'Malformed response from Deepgram: {str(e)}') | |
| raise Exception('Failed to parse Deepgram response - unexpected response format') | |
| data = {'text': transcript.strip()} | |
| # Save transcript | |
| transcript_file = os.path.join(file_dir, f'{id}.json') | |
| with open(transcript_file, 'w') as f: | |
| json.dump(data, f) | |
| return data | |
| except Exception as e: | |
| log.exception(e) | |
| detail = None | |
| if r is not None: | |
| try: | |
| res = r.json() | |
| if 'error' in res: | |
| detail = f'External: {res["error"].get("message", "")}' | |
| except Exception: | |
| detail = f'External: {e}' | |
| raise Exception(detail if detail else 'Open WebUI: Server Connection Error') | |
| elif request.app.state.config.STT_ENGINE == 'azure': | |
| # Check file exists and size | |
| if not os.path.exists(file_path): | |
| raise HTTPException(status_code=400, detail='Audio file not found') | |
| # Check file size (Azure has a larger limit of 200MB) | |
| file_size = os.path.getsize(file_path) | |
| if file_size > AZURE_MAX_FILE_SIZE: | |
| raise HTTPException( | |
| status_code=400, | |
| detail=f"File size exceeds Azure's limit of {AZURE_MAX_FILE_SIZE_MB}MB", | |
| ) | |
| api_key = request.app.state.config.AUDIO_STT_AZURE_API_KEY | |
| region = request.app.state.config.AUDIO_STT_AZURE_REGION or 'eastus' | |
| locales = request.app.state.config.AUDIO_STT_AZURE_LOCALES | |
| base_url = request.app.state.config.AUDIO_STT_AZURE_BASE_URL | |
| max_speakers = request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS or 3 | |
| # IF NO LOCALES, USE DEFAULTS | |
| if len(locales) < 2: | |
| locales = [ | |
| 'en-US', | |
| 'es-ES', | |
| 'es-MX', | |
| 'fr-FR', | |
| 'hi-IN', | |
| 'it-IT', | |
| 'de-DE', | |
| 'en-GB', | |
| 'en-IN', | |
| 'ja-JP', | |
| 'ko-KR', | |
| 'pt-BR', | |
| 'zh-CN', | |
| ] | |
| locales = ','.join(locales) | |
| if not api_key or not region: | |
| raise HTTPException( | |
| status_code=400, | |
| detail='Azure API key is required for Azure STT', | |
| ) | |
| r = None | |
| try: | |
| # Prepare the request | |
| data = { | |
| 'definition': json.dumps( | |
| { | |
| 'locales': locales.split(','), | |
| 'diarization': {'maxSpeakers': max_speakers, 'enabled': True}, | |
| } | |
| if locales | |
| else {} | |
| ) | |
| } | |
| url = ( | |
| base_url or f'https://{region}.api.cognitive.microsoft.com' | |
| ) + '/speechtotext/transcriptions:transcribe?api-version=2024-11-15' | |
| # Use context manager to ensure file is properly closed | |
| with open(file_path, 'rb') as audio_file: | |
| r = requests.post( | |
| url=url, | |
| files={'audio': audio_file}, | |
| data=data, | |
| headers={ | |
| 'Ocp-Apim-Subscription-Key': api_key, | |
| }, | |
| timeout=AIOHTTP_CLIENT_TIMEOUT, | |
| ) | |
| r.raise_for_status() | |
| response = r.json() | |
| # Extract transcript from response | |
| if not response.get('combinedPhrases'): | |
| raise ValueError('No transcription found in response') | |
| # Get the full transcript from combinedPhrases | |
| transcript = response['combinedPhrases'][0].get('text', '').strip() | |
| if not transcript: | |
| raise ValueError('Empty transcript in response') | |
| data = {'text': transcript} | |
| # Save transcript to json file (consistent with other providers) | |
| transcript_file = os.path.join(file_dir, f'{id}.json') | |
| with open(transcript_file, 'w') as f: | |
| json.dump(data, f) | |
| log.debug(data) | |
| return data | |
| except (KeyError, IndexError, ValueError) as e: | |
| log.exception('Error parsing Azure response') | |
| raise HTTPException( | |
| status_code=500, | |
| detail=f'Failed to parse Azure response: {str(e)}', | |
| ) | |
| except requests.exceptions.RequestException as e: | |
| log.exception(e) | |
| detail = None | |
| status_code = getattr(r, 'status_code', 500) if r else 500 | |
| try: | |
| if r is not None and r.status_code != 200: | |
| res = r.json() | |
| # Azure returns {"code": "...", "message": "...", "innerError": {...}} | |
| if 'code' in res and 'message' in res: | |
| azure_code = res.get('innerError', {}).get('code', res['code']) | |
| user_facing_codes = { | |
| 'EmptyAudioFile', | |
| 'AudioLengthLimitExceeded', | |
| 'NoLanguageIdentified', | |
| 'MultipleLanguagesIdentified', | |
| } | |
| if azure_code in user_facing_codes: | |
| detail = res['message'] | |
| else: | |
| log.error(f'Azure STT error [{azure_code}]: {res["message"]}') | |
| detail = 'An error occurred during transcription.' | |
| elif 'error' in res: | |
| detail = f'External: {res["error"].get("message", "")}' | |
| except Exception: | |
| detail = f'External: {e}' | |
| raise HTTPException( | |
| status_code=status_code, | |
| detail=detail if detail else 'Open WebUI: Server Connection Error', | |
| ) | |
| elif request.app.state.config.STT_ENGINE == 'mistral': | |
| # Check file exists | |
| if not os.path.exists(file_path): | |
| raise HTTPException(status_code=400, detail='Audio file not found') | |
| # Check file size | |
| file_size = os.path.getsize(file_path) | |
| if file_size > MAX_FILE_SIZE: | |
| raise HTTPException( | |
| status_code=400, | |
| detail=f'File size exceeds limit of {MAX_FILE_SIZE_MB}MB', | |
| ) | |
| api_key = request.app.state.config.AUDIO_STT_MISTRAL_API_KEY | |
| api_base_url = request.app.state.config.AUDIO_STT_MISTRAL_API_BASE_URL or 'https://api.mistral.ai/v1' | |
| use_chat_completions = request.app.state.config.AUDIO_STT_MISTRAL_USE_CHAT_COMPLETIONS | |
| if not api_key: | |
| raise HTTPException( | |
| status_code=400, | |
| detail='Mistral API key is required for Mistral STT', | |
| ) | |
| r = None | |
| try: | |
| # Use voxtral-mini-latest as the default model for transcription | |
| model = request.app.state.config.STT_MODEL or 'voxtral-mini-latest' | |
| log.info( | |
| f'Mistral STT - model: {model}, ' | |
| f'method: {"chat_completions" if use_chat_completions else "transcriptions"}' | |
| ) | |
| if use_chat_completions: | |
| # Use chat completions API with audio input | |
| # This method requires mp3 or wav format | |
| audio_file_to_use = file_path | |
| if is_audio_conversion_required(file_path): | |
| log.debug('Converting audio to mp3 for chat completions API') | |
| converted_path = convert_audio_to_mp3(file_path) | |
| if converted_path: | |
| audio_file_to_use = converted_path | |
| else: | |
| log.error('Audio conversion failed') | |
| raise HTTPException( | |
| status_code=500, | |
| detail='Audio conversion failed. Chat completions API requires mp3 or wav format.', | |
| ) | |
| # Read and encode audio file as base64 | |
| with open(audio_file_to_use, 'rb') as audio_file: | |
| audio_base64 = { | |
| 'data': base64.b64encode(audio_file.read()).decode('utf-8'), | |
| 'format': mimetypes.guess_extension(mimetypes.guess_type(audio_file_to_use)[0]).lstrip('.'), | |
| } | |
| # Prepare chat completions request | |
| url = f'{api_base_url}/chat/completions' | |
| # Add language instruction if specified | |
| language = metadata.get('language', None) if metadata else None | |
| if language: | |
| text_instruction = f'Transcribe this audio exactly as spoken in {language}. Do not translate it.' | |
| else: | |
| text_instruction = 'Transcribe this audio exactly as spoken in its original language. Do not translate it to another language.' | |
| payload = { | |
| 'model': model, | |
| 'messages': [ | |
| { | |
| 'role': 'user', | |
| 'content': [ | |
| { | |
| 'type': 'input_audio', | |
| 'input_audio': audio_base64, | |
| }, | |
| {'type': 'text', 'text': text_instruction}, | |
| ], | |
| } | |
| ], | |
| } | |
| r = requests.post( | |
| url=url, | |
| json=payload, | |
| headers={ | |
| 'Authorization': f'Bearer {api_key}', | |
| 'Content-Type': 'application/json', | |
| }, | |
| timeout=AIOHTTP_CLIENT_TIMEOUT, | |
| ) | |
| r.raise_for_status() | |
| response = r.json() | |
| # Extract transcript from chat completion response | |
| transcript = response.get('choices', [{}])[0].get('message', {}).get('content', '').strip() | |
| if not transcript: | |
| raise ValueError('Empty transcript in response') | |
| data = {'text': transcript} | |
| else: | |
| # Use dedicated transcriptions API | |
| url = f'{api_base_url}/audio/transcriptions' | |
| # Determine the MIME type | |
| mime_type, _ = mimetypes.guess_type(file_path) | |
| if not mime_type: | |
| mime_type = 'audio/webm' | |
| # Use context manager to ensure file is properly closed | |
| with open(file_path, 'rb') as audio_file: | |
| files = {'file': (filename, audio_file, mime_type)} | |
| data_form = {'model': model} | |
| # Add language if specified in metadata | |
| language = metadata.get('language', None) if metadata else None | |
| if language: | |
| data_form['language'] = language | |
| r = requests.post( | |
| url=url, | |
| files=files, | |
| data=data_form, | |
| headers={ | |
| 'Authorization': f'Bearer {api_key}', | |
| }, | |
| timeout=AIOHTTP_CLIENT_TIMEOUT, | |
| ) | |
| r.raise_for_status() | |
| response = r.json() | |
| # Extract transcript from response | |
| transcript = response.get('text', '').strip() | |
| if not transcript: | |
| raise ValueError('Empty transcript in response') | |
| data = {'text': transcript} | |
| # Save transcript to json file (consistent with other providers) | |
| transcript_file = os.path.join(file_dir, f'{id}.json') | |
| with open(transcript_file, 'w') as f: | |
| json.dump(data, f) | |
| log.debug(data) | |
| return data | |
| except ValueError as e: | |
| log.exception('Error parsing Mistral response') | |
| raise HTTPException( | |
| status_code=500, | |
| detail=f'Failed to parse Mistral response: {str(e)}', | |
| ) | |
| except requests.exceptions.RequestException as e: | |
| log.exception(e) | |
| detail = None | |
| try: | |
| if r is not None and r.status_code != 200: | |
| res = r.json() | |
| if 'error' in res: | |
| detail = f'External: {res["error"].get("message", "")}' | |
| else: | |
| detail = f'External: {r.text}' | |
| except Exception: | |
| detail = f'External: {e}' | |
| raise HTTPException( | |
| status_code=getattr(r, 'status_code', 500) if r else 500, | |
| detail=detail if detail else 'Open WebUI: Server Connection Error', | |
| ) | |
| def transcribe(request: Request, file_path: str, metadata: Optional[dict] = None, user=None): | |
| log.info(f'transcribe: {file_path} {metadata}') | |
| if BYPASS_PYDUB_PREPROCESSING: | |
| log.info('Bypassing pydub preprocessing (BYPASS_PYDUB_PREPROCESSING=true)') | |
| chunk_paths = [file_path] | |
| else: | |
| if is_audio_conversion_required(file_path): | |
| file_path = convert_audio_to_mp3(file_path) | |
| try: | |
| file_path = compress_audio(file_path) | |
| except Exception as e: | |
| log.exception(e) | |
| # Always produce a list of chunk paths (could be one entry if small) | |
| try: | |
| chunk_paths = split_audio(file_path, MAX_FILE_SIZE) | |
| print(f'Chunk paths: {chunk_paths}') | |
| except Exception as e: | |
| log.exception(e) | |
| raise HTTPException( | |
| status_code=status.HTTP_400_BAD_REQUEST, | |
| detail=ERROR_MESSAGES.DEFAULT(e), | |
| ) | |
| results = [] | |
| try: | |
| with ThreadPoolExecutor() as executor: | |
| # Submit tasks for each chunk_path | |
| futures = [ | |
| executor.submit(transcription_handler, request, chunk_path, metadata, user) | |
| for chunk_path in chunk_paths | |
| ] | |
| # Gather results as they complete | |
| for future in futures: | |
| try: | |
| results.append(future.result()) | |
| except HTTPException: | |
| raise | |
| except Exception as transcribe_exc: | |
| raise HTTPException( | |
| status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, | |
| detail=f'Error transcribing chunk: {transcribe_exc}', | |
| ) | |
| finally: | |
| # Clean up only the temporary chunks, never the original file | |
| for chunk_path in chunk_paths: | |
| if chunk_path != file_path and os.path.isfile(chunk_path): | |
| try: | |
| os.remove(chunk_path) | |
| except Exception: | |
| pass | |
| return { | |
| 'text': ' '.join([result['text'] for result in results]), | |
| } | |
| def compress_audio(file_path): | |
| if os.path.getsize(file_path) > MAX_FILE_SIZE: | |
| id = os.path.splitext(os.path.basename(file_path))[0] # Handles names with multiple dots | |
| file_dir = os.path.dirname(file_path) | |
| audio = AudioSegment.from_file(file_path) | |
| audio = audio.set_frame_rate(16000).set_channels(1) # Compress audio | |
| compressed_path = os.path.join(file_dir, f'{id}_compressed.mp3') | |
| audio.export(compressed_path, format='mp3', bitrate='32k') | |
| # log.debug(f"Compressed audio to {compressed_path}") # Uncomment if log is defined | |
| return compressed_path | |
| else: | |
| return file_path | |
| def split_audio(file_path, max_bytes, format='mp3', bitrate='32k'): | |
| """ | |
| Splits audio into chunks not exceeding max_bytes. | |
| Returns a list of chunk file paths. If audio fits, returns list with original path. | |
| """ | |
| file_size = os.path.getsize(file_path) | |
| if file_size <= max_bytes: | |
| return [file_path] # Nothing to split | |
| audio = AudioSegment.from_file(file_path) | |
| duration_ms = len(audio) | |
| orig_size = file_size | |
| approx_chunk_ms = max(int(duration_ms * (max_bytes / orig_size)) - 1000, 1000) | |
| chunks = [] | |
| start = 0 | |
| i = 0 | |
| base, _ = os.path.splitext(file_path) | |
| while start < duration_ms: | |
| end = min(start + approx_chunk_ms, duration_ms) | |
| chunk = audio[start:end] | |
| chunk_path = f'{base}_chunk_{i}.{format}' | |
| chunk.export(chunk_path, format=format, bitrate=bitrate) | |
| # Reduce chunk duration if still too large | |
| while os.path.getsize(chunk_path) > max_bytes and (end - start) > 5000: | |
| end = start + ((end - start) // 2) | |
| chunk = audio[start:end] | |
| chunk.export(chunk_path, format=format, bitrate=bitrate) | |
| if os.path.getsize(chunk_path) > max_bytes: | |
| os.remove(chunk_path) | |
| raise Exception('Audio chunk cannot be reduced below max file size.') | |
| chunks.append(chunk_path) | |
| start = end | |
| i += 1 | |
| return chunks | |
| async def transcription( | |
| request: Request, | |
| file: UploadFile = File(...), | |
| language: Optional[str] = Form(None), | |
| user=Depends(get_verified_user), | |
| ): | |
| if user.role != 'admin' and not await has_permission( | |
| user.id, 'chat.stt', request.app.state.config.USER_PERMISSIONS | |
| ): | |
| raise HTTPException( | |
| status_code=status.HTTP_403_FORBIDDEN, | |
| detail=ERROR_MESSAGES.ACCESS_PROHIBITED, | |
| ) | |
| log.info(f'file.content_type: {file.content_type}') | |
| stt_supported_content_types = getattr(request.app.state.config, 'STT_SUPPORTED_CONTENT_TYPES', []) | |
| if not strict_match_mime_type(stt_supported_content_types, file.content_type): | |
| raise HTTPException( | |
| status_code=status.HTTP_400_BAD_REQUEST, | |
| detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED, | |
| ) | |
| try: | |
| safe_name = os.path.basename(file.filename) if file.filename else '' | |
| ext = safe_name.rsplit('.', 1)[-1] if '.' in safe_name else '' | |
| id = uuid.uuid4() | |
| filename = f'{id}.{ext}' | |
| contents = file.file.read() | |
| file_dir = os.path.join(CACHE_DIR, 'audio', 'transcriptions') | |
| os.makedirs(file_dir, exist_ok=True) | |
| file_path = os.path.join(file_dir, filename) | |
| # Defense-in-depth: ensure resolved path stays within intended directory | |
| if not os.path.realpath(file_path).startswith(os.path.realpath(file_dir)): | |
| raise ValueError('Invalid file path detected') | |
| with open(file_path, 'wb') as f: | |
| f.write(contents) | |
| try: | |
| metadata = None | |
| if language: | |
| metadata = {'language': language} | |
| result = transcribe(request, file_path, metadata, user) | |
| return { | |
| **result, | |
| 'filename': os.path.basename(file_path), | |
| } | |
| except HTTPException: | |
| raise | |
| except Exception as e: | |
| log.exception(e) | |
| raise HTTPException( | |
| status_code=status.HTTP_400_BAD_REQUEST, | |
| detail='Transcription failed.', | |
| ) | |
| except HTTPException: | |
| raise | |
| except Exception as e: | |
| log.exception(e) | |
| raise HTTPException( | |
| status_code=status.HTTP_400_BAD_REQUEST, | |
| detail='Transcription failed.', | |
| ) | |
| async def get_available_models(request: Request) -> list[dict]: | |
| available_models = [] | |
| if request.app.state.config.TTS_ENGINE == 'openai': | |
| # Use custom endpoint if not using the official OpenAI API URL | |
| if not request.app.state.config.TTS_OPENAI_API_BASE_URL.startswith('https://api.openai.com'): | |
| timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST) | |
| async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: | |
| try: | |
| async with session.get( | |
| f'{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/models', | |
| ssl=AIOHTTP_CLIENT_SESSION_SSL, | |
| ) as response: | |
| response.raise_for_status() | |
| data = await response.json() | |
| available_models = data.get('models', []) | |
| except Exception as e: | |
| log.debug(f'/audio/models not available, trying /models fallback: {str(e)}') | |
| # Fallback to standard OpenAI-compatible /models endpoint | |
| # (used by KokoroTTS and similar custom TTS servers) | |
| try: | |
| async with session.get( | |
| f'{request.app.state.config.TTS_OPENAI_API_BASE_URL}/models', | |
| ssl=AIOHTTP_CLIENT_SESSION_SSL, | |
| ) as response: | |
| response.raise_for_status() | |
| data = await response.json() | |
| # OpenAI /models returns {"data": [...]}, /audio/models returns {"models": [...]} | |
| available_models = data.get('data', data.get('models', [])) | |
| except Exception as e2: | |
| log.error(f'Error fetching models from custom endpoint: {str(e2)}') | |
| available_models = [{'id': 'tts-1'}, {'id': 'tts-1-hd'}] | |
| else: | |
| available_models = [{'id': 'tts-1'}, {'id': 'tts-1-hd'}] | |
| elif request.app.state.config.TTS_ENGINE == 'elevenlabs': | |
| try: | |
| timeout = aiohttp.ClientTimeout(total=5) | |
| async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: | |
| async with session.get( | |
| f'{ELEVENLABS_API_BASE_URL}/v1/models', | |
| headers={ | |
| 'xi-api-key': request.app.state.config.TTS_API_KEY, | |
| 'Content-Type': 'application/json', | |
| }, | |
| ssl=AIOHTTP_CLIENT_SESSION_SSL, | |
| ) as response: | |
| response.raise_for_status() | |
| models = await response.json() | |
| available_models = [{'name': model['name'], 'id': model['model_id']} for model in models] | |
| except Exception as e: | |
| log.error(f'Error fetching models: {str(e)}') | |
| elif request.app.state.config.TTS_ENGINE == 'mistral': | |
| available_models = [{'id': 'voxtral-mini-tts-2603'}] | |
| return available_models | |
| async def get_models(request: Request, user=Depends(get_verified_user)): | |
| return {'models': await get_available_models(request)} | |
| async def get_available_voices(request) -> dict: | |
| """Returns {voice_id: voice_name} dict""" | |
| available_voices = {} | |
| if request.app.state.config.TTS_ENGINE == 'openai': | |
| # Use custom endpoint if not using the official OpenAI API URL | |
| if not request.app.state.config.TTS_OPENAI_API_BASE_URL.startswith('https://api.openai.com'): | |
| try: | |
| timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST) | |
| async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: | |
| async with session.get( | |
| f'{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/voices', | |
| ssl=AIOHTTP_CLIENT_SESSION_SSL, | |
| ) as response: | |
| response.raise_for_status() | |
| data = await response.json() | |
| voices_list = data.get('voices', []) | |
| available_voices = {voice['id']: voice['name'] for voice in voices_list} | |
| except Exception as e: | |
| log.error(f'Error fetching voices from custom endpoint: {str(e)}') | |
| available_voices = { | |
| 'alloy': 'alloy', | |
| 'echo': 'echo', | |
| 'fable': 'fable', | |
| 'onyx': 'onyx', | |
| 'nova': 'nova', | |
| 'shimmer': 'shimmer', | |
| } | |
| else: | |
| available_voices = { | |
| 'alloy': 'alloy', | |
| 'echo': 'echo', | |
| 'fable': 'fable', | |
| 'onyx': 'onyx', | |
| 'nova': 'nova', | |
| 'shimmer': 'shimmer', | |
| } | |
| elif request.app.state.config.TTS_ENGINE == 'elevenlabs': | |
| try: | |
| available_voices = await get_elevenlabs_voices(api_key=request.app.state.config.TTS_API_KEY) | |
| except Exception: | |
| # Avoided @lru_cache with exception | |
| pass | |
| elif request.app.state.config.TTS_ENGINE == 'azure': | |
| try: | |
| region = request.app.state.config.TTS_AZURE_SPEECH_REGION | |
| base_url = request.app.state.config.TTS_AZURE_SPEECH_BASE_URL | |
| url = (base_url or f'https://{region}.tts.speech.microsoft.com') + '/cognitiveservices/voices/list' | |
| headers = {'Ocp-Apim-Subscription-Key': request.app.state.config.TTS_API_KEY} | |
| timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST) | |
| async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: | |
| async with session.get(url, headers=headers, ssl=AIOHTTP_CLIENT_SESSION_SSL) as response: | |
| response.raise_for_status() | |
| voices = await response.json() | |
| for voice in voices: | |
| available_voices[voice['ShortName']] = f'{voice["DisplayName"]} ({voice["ShortName"]})' | |
| except Exception as e: | |
| log.error(f'Error fetching voices: {str(e)}') | |
| elif request.app.state.config.TTS_ENGINE == 'mistral': | |
| api_key = request.app.state.config.TTS_MISTRAL_API_KEY | |
| api_base_url = request.app.state.config.TTS_MISTRAL_API_BASE_URL or 'https://api.mistral.ai/v1' | |
| if api_key: | |
| try: | |
| timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST) | |
| async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: | |
| async with session.get( | |
| f'{api_base_url}/audio/voices', | |
| headers={ | |
| 'Authorization': f'Bearer {api_key}', | |
| }, | |
| ssl=AIOHTTP_CLIENT_SESSION_SSL, | |
| ) as response: | |
| response.raise_for_status() | |
| voices_data = await response.json() | |
| # Mistral returns a paginated response: {"items": [...], "page": ..., "total": ...} | |
| voices_list = voices_data.get('items', []) if isinstance(voices_data, dict) else voices_data | |
| for voice in voices_list: | |
| if isinstance(voice, dict): | |
| voice_id = voice.get('voice_id', voice.get('id', '')) | |
| voice_name = voice.get('name', voice_id) | |
| if voice_id: | |
| available_voices[voice_id] = voice_name | |
| except Exception as e: | |
| log.error(f'Error fetching Mistral voices: {str(e)}') | |
| return available_voices | |
| async def get_elevenlabs_voices(api_key: str) -> dict: | |
| """ | |
| Note, set the following in your .env file to use Elevenlabs: | |
| AUDIO_TTS_ENGINE=elevenlabs | |
| AUDIO_TTS_API_KEY=sk_... # Your Elevenlabs API key | |
| AUDIO_TTS_VOICE=EXAVITQu4vr4xnSDxMaL # From https://api.elevenlabs.io/v1/voices | |
| AUDIO_TTS_MODEL=eleven_multilingual_v2 | |
| """ | |
| try: | |
| # TODO: Add retries | |
| timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST) | |
| async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: | |
| async with session.get( | |
| f'{ELEVENLABS_API_BASE_URL}/v1/voices', | |
| headers={ | |
| 'xi-api-key': api_key, | |
| 'Content-Type': 'application/json', | |
| }, | |
| ssl=AIOHTTP_CLIENT_SESSION_SSL, | |
| ) as response: | |
| response.raise_for_status() | |
| voices_data = await response.json() | |
| voices = {} | |
| for voice in voices_data.get('voices', []): | |
| voices[voice['voice_id']] = voice['name'] | |
| except Exception as e: | |
| log.error(f'Error fetching voices: {str(e)}') | |
| raise RuntimeError(f'Error fetching voices: {str(e)}') | |
| return voices | |
| async def get_voices(request: Request, user=Depends(get_verified_user)): | |
| return {'voices': [{'id': k, 'name': v} for k, v in (await get_available_voices(request)).items()]} | |