Commit 95cb26e (author: malek-messaoudii): "Refactor audio processing to utilize free models and enhance logging; update TTS and STT services for improved functionality."
from fastapi import APIRouter, UploadFile, File, HTTPException
from fastapi.responses import StreamingResponse
import io
import logging
# Project configuration: accepted upload MIME types and the upload size cap.
from config import ALLOWED_AUDIO_TYPES, MAX_AUDIO_SIZE
# Free-model service layer: Whisper STT, gTTS, and DialoGPT chatbot wrappers.
from services.stt_service import speech_to_text, load_stt_model
from services.tts_service import generate_tts
from services.chatbot_service import get_chatbot_response, load_chatbot_model
# Pydantic request/response schemas for the audio endpoints.
from models.audio import STTResponse, TTSRequest, TTSResponse, ChatbotRequest, ChatbotResponse

# Module-level logger, named after this module per logging convention.
logger = logging.getLogger(__name__)

# All endpoints in this module are mounted under the /audio prefix.
router = APIRouter(prefix="/audio", tags=["Audio"])
# Pre-load models on router startup so the first request doesn't pay the
# model-loading cost.
# NOTE(review): the extracted source shows no decorator on this coroutine, and
# FastAPI never calls an undecorated function — registering it here via
# @router.on_event("startup") restores the documented intent. Confirm against
# the app factory; on recent FastAPI versions a lifespan handler is preferred.
@router.on_event("startup")
async def startup_event():
    """Load the free STT and chatbot models when the router starts.

    Delegates to the service layer's loaders; both calls are synchronous
    and idempotent from this module's point of view.
    """
    logger.info("Loading free STT and Chatbot models...")
    load_stt_model()
    load_chatbot_model()
@router.post("/tts")
async def tts(request: TTSRequest):
    """
    Convert text to speech and return an audio stream using free gTTS.

    Example:
    - POST /audio/tts
    - Body: {"text": "Hello, welcome to our system"}
    - Returns: MP3 audio file

    Raises:
        HTTPException: 500 if TTS generation fails for any reason.
    """
    try:
        # Lazy %-style args avoid formatting the message when INFO is disabled.
        logger.info("TTS request received for text: '%s'", request.text)
        audio_bytes = await generate_tts(request.text)
        # "audio/mpeg" is the IANA-registered MIME type for MP3;
        # "audio/mp3" is non-standard (though widely tolerated by clients).
        return StreamingResponse(io.BytesIO(audio_bytes), media_type="audio/mpeg")
    except Exception as e:
        # logger.exception records the traceback alongside the message.
        logger.exception("TTS error")
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.post("/stt")
async def stt(file: UploadFile = File(...)):
    """
    Convert an uploaded audio file to text using the free Whisper model.

    Example:
    - POST /audio/stt
    - File: audio.mp3 (or .wav, .m4a)
    - Returns: {"text": "transcribed text", "model_name": "whisper-small", ...}

    Raises:
        HTTPException: 400 for unsupported formats or oversized files,
            500 for transcription failures.
    """
    # Validate file type before reading any bytes.
    if file.content_type not in ALLOWED_AUDIO_TYPES:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported format: {file.content_type}. Supported: WAV, MP3, M4A"
        )
    try:
        logger.info("STT request received for file: %s", file.filename)
        audio_bytes = await file.read()
        # Enforce the configured upload size cap.
        if len(audio_bytes) > MAX_AUDIO_SIZE:
            raise HTTPException(
                status_code=400,
                detail=f"Audio file too large. Max size: {MAX_AUDIO_SIZE / 1024 / 1024}MB"
            )
        text = await speech_to_text(audio_bytes, file.filename)
        return STTResponse(
            text=text,
            model_name="whisper-small",
            language="en",
            duration_seconds=None
        )
    except HTTPException:
        # BUG FIX: the generic handler below used to swallow the deliberate
        # 400 raised for oversized files and re-emit it as a 500 — re-raise
        # HTTPExceptions untouched so clients see the intended status code.
        raise
    except Exception as e:
        logger.exception("STT error")
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.post("/chatbot")
async def chatbot_voice(file: UploadFile = File(...)):
    """
    Full voice chatbot flow using free models (Audio → Text → Response → Audio).

    Example:
    - POST /audio/chatbot
    - File: user_voice.mp3
    - Returns: Response audio file (MP3)

    Raises:
        HTTPException: 400 for unsupported formats or oversized files,
            500 if any pipeline stage fails.
    """
    # Validate file type before reading any bytes.
    if file.content_type not in ALLOWED_AUDIO_TYPES:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported format: {file.content_type}. Supported: WAV, MP3, M4A"
        )
    try:
        logger.info("Voice chatbot request received for file: %s", file.filename)
        # Step 1: Convert audio to text
        audio_bytes = await file.read()
        # Enforce the configured upload size cap.
        if len(audio_bytes) > MAX_AUDIO_SIZE:
            raise HTTPException(
                status_code=400,
                detail=f"Audio file too large. Max size: {MAX_AUDIO_SIZE / 1024 / 1024}MB"
            )
        user_text = await speech_to_text(audio_bytes, file.filename)
        logger.info("Step 1 - STT: %s", user_text)
        # Step 2: Generate chatbot response
        response_text = await get_chatbot_response(user_text)
        logger.info("Step 2 - Response: %s", response_text)
        # Step 3: Convert response to audio
        audio_response = await generate_tts(response_text)
        logger.info("Step 3 - TTS: Complete")
        # "audio/mpeg" is the IANA-registered MIME type for MP3.
        return StreamingResponse(io.BytesIO(audio_response), media_type="audio/mpeg")
    except HTTPException:
        # BUG FIX: re-raise deliberate 4xx responses instead of letting the
        # generic handler below convert them into misleading 500s.
        raise
    except Exception as e:
        logger.exception("Voice chatbot error")
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.post("/chatbot-text")
async def chatbot_text(request: ChatbotRequest):
    """
    Chatbot interaction with text input/output using the free DialoGPT model.

    Example:
    - POST /audio/chatbot-text
    - Body: {"text": "What is the capital of France?"}
    - Returns: {"user_input": "What is...", "bot_response": "The capital...", ...}

    Raises:
        HTTPException: 500 if response generation fails.
    """
    try:
        logger.info("Text chatbot request: %s", request.text)
        response_text = await get_chatbot_response(request.text)
        return ChatbotResponse(
            user_input=request.text,
            bot_response=response_text,
            model_name="DialoGPT-medium"
        )
    except Exception as e:
        logger.exception("Text chatbot error")
        raise HTTPException(status_code=500, detail=str(e)) from e