# Malaya Speech Service — FastAPI application exposing text-to-speech (/tts)
# and speech-to-text (/stt) endpoints backed by Malaya-Speech and Whisper.
| from contextlib import asynccontextmanager | |
| import asyncio | |
| import malaya_speech | |
| from fastapi import FastAPI, UploadFile, File, HTTPException, Request | |
| from fastapi.responses import StreamingResponse | |
| import io | |
| import logging.config | |
| import uuid | |
| import soundfile as sf | |
| from contextvars import ContextVar | |
| from logging_config import LOGGING_CONFIG | |
| import torch | |
| from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq | |
| import librosa | |
| from model import TTSRequest, STTResponse | |
# --- Context Variable and Logging Setup ---
# Per-request trace id; set by the logging middleware and read by the logging
# configuration so every log line can carry the request's trace id.
trace_id_var: ContextVar[str] = ContextVar("trace_id", default="NO_ID")
logging.config.dictConfig(LOGGING_CONFIG)
# NOTE(review): basicConfig after dictConfig is a no-op when the root logger
# already has handlers — presumably kept as a fallback; confirm it is needed.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("speech_service")
# Define a dictionary to store application state
app_state = {}  # NOTE(review): appears unused in this file
# --- Global Dictionary to Hold Models ---
# Populated by load_model() during startup and cleared on shutdown (lifespan).
tts_models = {}
vocoder_models = {}
stt_models = {}
def load_model():
    """Load all ML models into the module-level dictionaries.

    Fills ``stt_models`` with the Whisper processor/model (moved to GPU when
    available), then ``tts_models`` and ``vocoder_models`` with one
    FastSpeech2/MelGAN pair per supported voice gender. Synchronous and
    CPU-bound — intended to run in an executor at startup.
    """
    logger.info("Loading Whisper model and processor...")
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    # Processor bundles the tokenizer and feature extractor for Whisper.
    stt_models['processor'] = AutoProcessor.from_pretrained("mesolitica/malaysian-whisper-base")
    stt_models['model'] = AutoModelForSpeechSeq2Seq.from_pretrained("mesolitica/malaysian-whisper-base")
    stt_models['model'].to(device)  # Move model to GPU if available
    logger.info("Whisper model and processor loaded successfully.")

    # Voice gender -> Malaya-Speech speaker name, shared by TTS and vocoder.
    speakers = {'female': 'yasmin', 'male': 'osman'}

    # --- Load TTS FastSpeech2 models ---
    logger.info("Loading TTS models...")
    for gender, speaker in speakers.items():
        tts_models[gender] = malaya_speech.tts.fastspeech2(model=speaker)
    logger.info("TTS models loaded successfully.")

    logger.info("Loading Vocoder models...")
    for gender, speaker in speakers.items():
        vocoder_models[gender] = malaya_speech.vocoder.melgan(model=speaker)
    logger.info("Vocoder models loaded successfully.")
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Handle application startup and shutdown.

    Loads the ML models before the app starts accepting requests, and clears
    them on shutdown. Must be decorated with ``@asynccontextmanager`` — FastAPI
    expects the ``lifespan`` argument to be an async context manager factory,
    not a bare async generator function.
    """
    logger.info("Application startup: Loading ML models...")
    # get_running_loop() is the modern, non-deprecated way to reach the loop
    # from inside a coroutine (get_event_loop() is deprecated here since 3.10).
    loop = asyncio.get_running_loop()
    # Run the synchronous, CPU-bound model loading in a worker thread so the
    # event loop is not blocked during startup.
    await loop.run_in_executor(None, load_model)
    logger.info("ML models loaded successfully. Application is ready.")
    # Requests are only accepted after this yield — i.e. after models loaded.
    yield
    # --- Code to run on application shutdown ---
    logger.info("Application shutdown: Clearing models...")
    stt_models.clear()
    tts_models.clear()
    vocoder_models.clear()
# --- Initialize FastAPI app ---
# The lifespan hook defers readiness until all models are loaded.
app = FastAPI(
    title="Malaya Speech Service",
    description="A service for Text-to-Speech and Speech-to-Text using Malaya-Speech.",
    version="1.0.0",
    lifespan=lifespan
)
# --- Health Check Endpoint ---
@app.get("/")
def read_root():
    """Health check: report that the service is up.

    Registered at ``GET /`` — without the route decorator this handler was
    never reachable through the app.
    """
    return {"status": "Malaya Speech Service is running"}
# --- Request Interceptor Middleware Setup ---
@app.middleware("http")
async def logging_middleware(request: Request, call_next):
    """Attach a trace id to every request for correlated logging.

    Reuses an incoming ``X-Trace-Id`` header or generates a UUID, stores it in
    ``trace_id_var`` for the duration of the request, and echoes it back in the
    response headers. Must be registered via ``@app.middleware("http")`` —
    without the decorator FastAPI never invokes it.
    """
    trace_id = request.headers.get("X-Trace-Id")
    if not trace_id:
        trace_id = str(uuid.uuid4())
    token = trace_id_var.set(trace_id)
    logger.info(f"Request started: {request.method} {request.url.path}")
    try:
        response = await call_next(request)
        response.headers["X-Trace-Id"] = trace_id
        logger.info(f"Request finished with status: {response.status_code}")
        return response
    finally:
        # Always restore the previous context value, even if the handler raised.
        trace_id_var.reset(token)
# --- Text-to-Speech Endpoint ---
@app.post("/tts", response_class=StreamingResponse)
async def text_to_speech(request: TTSRequest):
    """Convert text to speech with a gender-specific model; return WAV audio.

    Raises:
        HTTPException 400: unknown ``request.gender`` (not a loaded voice).
        HTTPException 500: any failure during synthesis.
    """
    # Validate the voice selection up front: an unknown gender previously
    # raised KeyError and surfaced as an opaque 500 instead of a 400.
    tts = tts_models.get(request.gender)
    vocoder = vocoder_models.get(request.gender)
    if tts is None or vocoder is None:
        raise HTTPException(
            status_code=400,
            detail=f"Unknown gender '{request.gender}'. Available: {sorted(tts_models)}.",
        )
    try:
        logger.info(f"Received TTS request for gender: '{request.gender}' with text: '{request.text[:30]}'")
        mel_spectrogram = tts.predict(request.text)
        y_ = vocoder(mel_spectrogram['mel-output'])
        # Write audio data to an in-memory buffer as a WAV file
        buffer = io.BytesIO()
        sf.write(buffer, y_, samplerate=22050, format="WAV")
        # Move buffer's cursor back to beginning to be read by response
        buffer.seek(0)
        # Stream content back to client
        return StreamingResponse(buffer, media_type="audio/wav")
    except HTTPException:
        # Let deliberate HTTP errors pass through unchanged.
        raise
    except Exception as e:
        logger.error(f"An error occurred during TTS processing: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail="Failed to process text-to-speech request.")
# --- Speech-to-Text Endpoint ---
@app.post("/stt", response_model=STTResponse)
async def speech_to_text(file: UploadFile = File(...)):
    """Transcribe an uploaded audio file to text with the Whisper model.

    Raises:
        HTTPException 400: uploaded file is not an audio content type.
        HTTPException 500: any failure during decoding or transcription.
    """
    try:
        if not file.content_type or not file.content_type.startswith("audio/"):
            raise HTTPException(status_code=400, detail="Invalid file type. Please upload an audio file.")
        logger.info(f"Received STT request for file: {file.filename}")
        # Read uploaded file content into memory
        audio_bytes = await file.read()
        # 'y' is the waveform (numpy array, shape (frames,) or (frames, channels)),
        # 'sr' is its sample rate.
        y, sr = sf.read(io.BytesIO(audio_bytes))
        # Down-mix to mono BEFORE resampling. soundfile returns (frames, channels),
        # while librosa expects (channels, frames) — hence the transpose. The old
        # order also handed a 2-D array to resample with the wrong axis layout.
        if y.ndim > 1:
            y = librosa.to_mono(y.T)
        # Resample to 16 kHz, which is what Whisper expects.
        if sr != 16000:
            y = librosa.resample(y=y, orig_sr=sr, target_sr=16000)
        processor = stt_models['processor']
        model = stt_models['model']
        # Turn the waveform into Whisper input features.
        inputs = processor(y, sampling_rate=16000, return_tensors="pt")
        input_features = inputs.input_features
        attention_mask = torch.ones(input_features.shape, dtype=torch.long)
        if torch.cuda.is_available():
            input_features = input_features.to("cuda:0")
            attention_mask = attention_mask.to("cuda:0")
        predicted_ids = model.generate(
            input_features,
            attention_mask=attention_mask,
            language='ms'
        )
        # batch_decode handles batches and strips special tokens.
        transcribed_text = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
        logger.info(f"Whisper transcription result: '{transcribed_text}'")
        return STTResponse(text=transcribed_text)
    except HTTPException:
        # Previously the 400 above was caught by the broad handler below and
        # re-raised as a 500; re-raise HTTP errors untouched instead.
        raise
    except Exception as e:
        logger.error(f"Error during STT with Whisper: {e}", exc_info=True)
        # Do not leak internal exception details to the client.
        raise HTTPException(status_code=500, detail="An error occurred during transcription.")
| # --- Speech-to-Text Endpoint --- | |
| # @app.post("/stt", response_model=STTResponse) | |
| # async def speech_to_text(file: UploadFile = File(...)): | |
| # """ | |
| # Converts speech from an audio file to text. | |
| # """ | |
| # try: | |
| # if not file.content_type.startswith("audio/"): | |
| # raise HTTPException(status_code=400, detail="Invalid file type. Please upload an audio file.") | |
| # | |
| # print(f"Received STT request for file: {file.filename}") | |
| # # Read the audio file content | |
| # audio_bytes = await file.read() | |
| # | |
| # y, sr = sf.read(io.BytesIO(audio_bytes)) | |
| # transcribed_text = small_model.beam_decoder([y])[0] | |
| # | |
| # print(f"Transcription result: '{transcribed_text}'") | |
| # return STTResponse(text=transcribed_text) | |
| # except Exception as e: | |
| # print(f"Error during STT: {e}") | |
| # raise HTTPException(status_code=500, detail=str(e)) | |
| # # --- Text-to-Speech Endpoint --- | |
| # @app.post("/tts", response_class=StreamingResponse) | |
| # async def text_to_speech(request: TTSRequest): | |
| # """ | |
| # Converts text to speech and returns the audio as a WAV file. | |
| # """ | |
| # try: | |
| # logger.info(f"Received TTS request: '{request}'") | |
| # | |
| # mel_spectrogram = tts_model.predict(request.text) | |
| # y_ = vocoder_female(mel_spectrogram['mel-output']) | |
| # buffer = io.BytesIO() | |
| # sf.write(buffer, y_, samplerate=22050, format="WAV") | |
| # # Move the cursor to the beginning | |
| # buffer.seek(0) | |
| # return StreamingResponse(buffer, media_type="audio/wav") | |
| # | |
| # except Exception as e: | |
| # print(f"Error during TTS: {e}") | |
| # raise HTTPException(status_code=500, detail=str(e)) | |