# NOTE(review): the three lines below are GitHub page residue (branch name,
# commit message, short SHA) pasted into the source; commented out so the
# module parses as valid Python.
# Master
# Add application file
# 0c527d5
from contextlib import asynccontextmanager
import asyncio
import malaya_speech
from fastapi import FastAPI, UploadFile, File, HTTPException, Request
from fastapi.responses import StreamingResponse
import io
import logging.config
import uuid
import soundfile as sf
from contextvars import ContextVar
from logging_config import LOGGING_CONFIG
import torch
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
import librosa
from model import TTSRequest, STTResponse
# --- Context Variable and Logging Setup ---
# Per-request trace id; set/reset by the logging middleware for each request.
trace_id_var: ContextVar[str] = ContextVar("trace_id", default="NO_ID")
logging.config.dictConfig(LOGGING_CONFIG)
# NOTE(review): basicConfig() after dictConfig() is typically a no-op once the
# root logger already has handlers — confirm this call is intentional.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("speech_service")
# Define a dictionary to store application state
# (not referenced anywhere in the visible file — possibly vestigial)
app_state: dict = {}
# --- Global Dictionary to Hold Models ---
# Populated by load_model() during startup; cleared in lifespan() on shutdown.
tts_models: dict = {}      # keyed by gender: 'female' / 'male'
vocoder_models: dict = {}  # keyed by gender: 'female' / 'male'
stt_models: dict = {}      # keys: 'processor', 'model'
def load_model():
    """
    Synchronously load the STT (Whisper) and TTS/vocoder models into the
    module-level dictionaries. Intended to run exactly once at startup
    (invoked from the lifespan handler in a worker thread).
    """
    target_device = "cuda:0" if torch.cuda.is_available() else "cpu"

    logger.info("Loading Whisper model and processor...")
    # Processor bundles the tokenizer and feature extractor.
    whisper_processor = AutoProcessor.from_pretrained("mesolitica/malaysian-whisper-base")
    whisper_model = AutoModelForSpeechSeq2Seq.from_pretrained("mesolitica/malaysian-whisper-base")
    whisper_model.to(target_device)  # place on GPU when one is available
    stt_models['processor'] = whisper_processor
    stt_models['model'] = whisper_model
    logger.info("Whisper model and processor loaded successfully.")

    # --- Load TTS FastSpeech2 models (one speaker per gender) ---
    voices = (("female", "yasmin"), ("male", "osman"))
    logger.info("Loading TTS models...")
    for gender, speaker in voices:
        tts_models[gender] = malaya_speech.tts.fastspeech2(model=speaker)
    logger.info("TTS models loaded successfully.")

    logger.info("Loading Vocoder models...")
    for gender, speaker in voices:
        vocoder_models[gender] = malaya_speech.vocoder.melgan(model=speaker)
    logger.info("Vocoder models loaded successfully.")
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    FastAPI lifespan handler.

    Loads all ML models before the application starts accepting requests
    and clears them on shutdown.

    Args:
        app: The FastAPI application instance (unused, required by the
            lifespan protocol).
    """
    logger.info("Application startup: Loading ML models...")
    # load_model() is blocking (network downloads + CPU-bound init), so push
    # it to a worker thread. asyncio.to_thread replaces the deprecated
    # get_event_loop()/run_in_executor pattern and is equivalent here.
    await asyncio.to_thread(load_model)
    logger.info("ML models loaded successfully. Application is ready.")
    # This 'yield' is the point where the application starts accepting requests.
    # It will only be reached AFTER the models are loaded.
    yield
    # --- Code to run on application shutdown ---
    # Drop model references so their memory can be reclaimed.
    logger.info("Application shutdown: Clearing models...")
    stt_models.clear()
    tts_models.clear()
    vocoder_models.clear()
# --- Initialize FastAPI app ---
# The lifespan handler above defers request serving until models are loaded.
app = FastAPI(
    title="Malaya Speech Service",
    description="A service for Text-to-Speech and Speech-to-Text using Malaya-Speech.",
    version="1.0.0",
    lifespan=lifespan
)
# --- Health Check Endpoint ---
@app.get("/")
def read_root():
    """Liveness probe: confirm the service process is up and responding."""
    payload = {"status": "Malaya Speech Service is running"}
    return payload
# --- Request Interceptor Middleware Setup ---
@app.middleware("http")
async def logging_middleware(request: Request, call_next):
    """
    Attach a per-request trace id to the logging context and echo it back.

    The id comes from the incoming X-Trace-Id header when present (and
    non-empty); otherwise a fresh UUID4 is generated. The context variable
    is always reset afterwards, even if the handler raises.
    """
    incoming = request.headers.get("X-Trace-Id")
    trace_id = incoming or str(uuid.uuid4())
    token = trace_id_var.set(trace_id)
    logger.info(f"Request started: {request.method} {request.url.path}")
    try:
        response = await call_next(request)
        response.headers["X-Trace-Id"] = trace_id
        logger.info(f"Request finished with status: {response.status_code}")
        return response
    finally:
        # Restore the previous context value regardless of outcome.
        trace_id_var.reset(token)
# --- Text-to-Speech Endpoint ---
@app.post("/tts", response_class=StreamingResponse)
async def text_to_speech(request: TTSRequest):
    """
    Converts text to speech using a gender-specific model and returns the audio as a WAV file.

    Raises:
        HTTPException 400: if no model is loaded for the requested gender.
        HTTPException 500: on any unexpected synthesis failure.
    """
    try:
        logger.info(f"Received TTS request for gender: '{request.gender}' with text: '{request.text[:30]}'")
        # An unknown gender previously raised KeyError and surfaced as an
        # opaque 500; report it as a client error instead.
        tts = tts_models.get(request.gender)
        vocoder = vocoder_models.get(request.gender)
        if tts is None or vocoder is None:
            raise HTTPException(
                status_code=400,
                detail=f"Unsupported gender '{request.gender}'. Expected one of: {sorted(tts_models)}.",
            )
        mel_spectrogram = tts.predict(request.text)
        y_ = vocoder(mel_spectrogram['mel-output'])
        # Write audio data to an in-memory buffer as a WAV file
        buffer = io.BytesIO()
        sf.write(buffer, y_, samplerate=22050, format="WAV")
        # Move buffer's cursor back to beginning to be read by response
        buffer.seek(0)
        # Stream content back to client
        return StreamingResponse(buffer, media_type="audio/wav")
    except HTTPException:
        # Let deliberate HTTP errors (the 400 above) pass through instead of
        # being swallowed by the generic handler below.
        raise
    except Exception as e:
        logger.error(f"An error occurred during TTS processing: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail="Failed to process text-to-speech request.")
# --- Speech-to-Text Endpoint ---
@app.post("/stt", response_model=STTResponse)
async def speech_to_text(file: UploadFile = File(...)):
    """
    Converts speech from an audio file to text using the Whisper model.

    Raises:
        HTTPException 400: if the upload is not an audio file.
        HTTPException 500: on any unexpected transcription failure.
    """
    try:
        if not file.content_type.startswith("audio/"):
            raise HTTPException(status_code=400, detail="Invalid file type. Please upload an audio file.")
        logger.info(f"Received STT request for file: {file.filename}")
        # Read uploaded file content into memory
        audio_bytes = await file.read()
        # 'y' is the waveform (numpy array, shape (frames,) or
        # (frames, channels)); 'sr' is its native sample rate.
        y, sr = sf.read(io.BytesIO(audio_bytes))
        # BUG FIX: soundfile returns (frames, channels) for multi-channel
        # audio, but librosa expects time on the last axis — resampling a 2-D
        # array here would operate along the wrong axis. Downmix to mono
        # (average channels) BEFORE resampling.
        if y.ndim > 1:
            y = y.mean(axis=1)
        # Resample to 16 kHz, which is what Whisper expects.
        if sr != 16000:
            y = librosa.resample(y=y, orig_sr=sr, target_sr=16000)
        processor = stt_models['processor']
        model = stt_models['model']
        # Process the audio array to get the input features
        inputs = processor(y, sampling_rate=16000, return_tensors="pt")
        input_features = inputs.input_features
        attention_mask = torch.ones(input_features.shape, dtype=torch.long)
        # Move tensors to wherever the model actually lives (set in
        # load_model) instead of hard-coding "cuda:0".
        device = model.device
        input_features = input_features.to(device)
        attention_mask = attention_mask.to(device)
        predicted_ids = model.generate(
            input_features,
            attention_mask=attention_mask,
            language='ms'
        )
        # batch_decode handles batches and strips special tokens.
        transcribed_text = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
        logger.info(f"Whisper transcription result: '{transcribed_text}'")
        return STTResponse(text=transcribed_text)
    except HTTPException:
        # BUG FIX: the 400 raised above was previously caught by the generic
        # handler and re-raised as a 500; re-raise deliberate HTTP errors.
        raise
    except Exception as e:
        logger.error(f"Error during STT with Whisper: {e}", exc_info=True)
        # Do not leak internal exception details to clients.
        raise HTTPException(status_code=500, detail="An error occurred during transcription.")
# --- Speech-to-Text Endpoint ---
# @app.post("/stt", response_model=STTResponse)
# async def speech_to_text(file: UploadFile = File(...)):
# """
# Converts speech from an audio file to text.
# """
# try:
# if not file.content_type.startswith("audio/"):
# raise HTTPException(status_code=400, detail="Invalid file type. Please upload an audio file.")
#
# print(f"Received STT request for file: {file.filename}")
# # Read the audio file content
# audio_bytes = await file.read()
#
# y, sr = sf.read(io.BytesIO(audio_bytes))
# transcribed_text = small_model.beam_decoder([y])[0]
#
# print(f"Transcription result: '{transcribed_text}'")
# return STTResponse(text=transcribed_text)
# except Exception as e:
# print(f"Error during STT: {e}")
# raise HTTPException(status_code=500, detail=str(e))
# # --- Text-to-Speech Endpoint ---
# @app.post("/tts", response_class=StreamingResponse)
# async def text_to_speech(request: TTSRequest):
# """
# Converts text to speech and returns the audio as a WAV file.
# """
# try:
# logger.info(f"Received TTS request: '{request}'")
#
# mel_spectrogram = tts_model.predict(request.text)
# y_ = vocoder_female(mel_spectrogram['mel-output'])
# buffer = io.BytesIO()
# sf.write(buffer, y_, samplerate=22050, format="WAV")
# # Move the cursor to the beginning
# buffer.seek(0)
# return StreamingResponse(buffer, media_type="audio/wav")
#
# except Exception as e:
# print(f"Error during TTS: {e}")
# raise HTTPException(status_code=500, detail=str(e))