import base64
import io
import logging
import os
import secrets
import tempfile
import time
import traceback
from logging.handlers import RotatingFileHandler
from pathlib import Path
from typing import Optional

import boto3
import librosa
import magic
import numpy as np
import soundfile as sf
import torch
from botocore.exceptions import NoCredentialsError
from fastapi import FastAPI, HTTPException, File, UploadFile, Depends, Security, Form
from fastapi.responses import JSONResponse
from fastapi.security.api_key import APIKeyHeader, APIKey
from moviepy.editor import VideoFileClip
from pydantic import BaseModel
from pydub import AudioSegment
|
|
| |
| from asr import transcribe, ASR_LANGUAGES, ASR_SAMPLING_RATE |
| from tts import synthesize, TTS_LANGUAGES |
| from lid import identify |
|
|
| |
# Console logging: configure the root logger at INFO level.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)



# Persistent logging: rotate app.log at ~10 MB, keeping 5 backup files.
# The handler is attached to this module's logger only, not the root logger.
file_handler = RotatingFileHandler('app.log', maxBytes=10000000, backupCount=5)
file_handler.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
|
|
app = FastAPI(title="MMS: Scaling Speech Technology to 1000+ languages")


# S3 upload target for synthesized audio; all values come from the environment
# and may be None if unset (uploads would then fail at request time).
S3_BUCKET = os.environ.get("S3_BUCKET")
S3_REGION = os.environ.get("S3_REGION")
S3_ACCESS_KEY_ID = os.environ.get("AWS_ACCESS_KEY_ID")
S3_SECRET_ACCESS_KEY = os.environ.get("AWS_SECRET_ACCESS_KEY")


# Shared secret for the X-API-Key header. auto_error=False means a missing
# header reaches get_api_key() as None instead of short-circuiting with a 403.
API_KEY = os.environ.get("API_KEY")
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)


# Module-level S3 client, reused across all requests (boto3 clients are
# safe to share for these call patterns).
s3_client = boto3.client(
    's3',
    aws_access_key_id=S3_ACCESS_KEY_ID,
    aws_secret_access_key=S3_SECRET_ACCESS_KEY,
    region_name=S3_REGION
)
|
|
| |
class AudioRequest(BaseModel):
    # Base64-encoded bytes of an audio or video file (any format
    # extract_audio_from_file can decode).
    audio: str
    # Optional language hint; when None the language is auto-identified.
    language: Optional[str] = None
|
|
class TTSRequest(BaseModel):
    # Text to synthesize; must be non-empty (validated in the endpoint).
    text: str
    # Optional language; None defaults to "eng" in /synthesize. May contain
    # extra words after the code — only the first token is used.
    language: Optional[str] = None
    # Playback speed multiplier; the endpoint enforces 0.5 <= speed <= 2.0.
    speed: float = 1.0
|
|
class LanguageRequest(BaseModel):
    # Optional case-insensitive prefix filter; None or "" returns all languages.
    language: Optional[str] = None
|
|
async def get_api_key(api_key_header: str = Security(api_key_header)):
    """FastAPI dependency validating the X-API-Key request header.

    Returns the supplied key when it matches the configured API_KEY,
    otherwise raises HTTP 403.
    """
    # Guard both sides explicitly: with auto_error=False a missing header
    # arrives as None, and when the API_KEY env var is also unset the
    # original `api_key_header == API_KEY` check evaluated None == None,
    # silently granting access. Use a constant-time comparison so the key
    # cannot be probed via response timing.
    if API_KEY and api_key_header and secrets.compare_digest(str(api_key_header), str(API_KEY)):
        return api_key_header
    raise HTTPException(status_code=403, detail="Could not validate credentials")
|
|
def extract_audio_from_file(input_bytes):
    """Decode raw uploaded bytes into a mono audio array.

    Writes the bytes to a temp file, then tries three decoders in order:
    soundfile (wav/flac/ogg), moviepy (video containers with an audio
    track), and pydub/ffmpeg (mp3 and other compressed audio).

    Args:
        input_bytes: Raw file contents of an audio or video file.

    Returns:
        tuple: (audio_array, sample_rate) where audio_array is a 1-D numpy
        array of samples and sample_rate is in Hz.

    Raises:
        ValueError: If none of the decoders can read the data.
    """
    with tempfile.NamedTemporaryFile(delete=False, suffix='.tmp') as temp_file:
        temp_file.write(input_bytes)
        temp_file_path = temp_file.name

    try:
        file_info = magic.from_file(temp_file_path, mime=True)
        logger.info(f"Received file of type: {file_info}")

        # 1) Plain audio formats readable by libsndfile.
        try:
            audio_array, sample_rate = sf.read(temp_file_path)
            # Down-mix to mono for consistency with the other decode paths
            # (downstream resampling/ASR expects a 1-D array).
            if audio_array.ndim > 1:
                audio_array = audio_array.mean(axis=1)
            logger.info(f"Successfully read audio with soundfile. Shape: {audio_array.shape}, Sample rate: {sample_rate}")
            return audio_array, sample_rate
        except Exception as e:
            logger.info(f"Could not read with soundfile: {str(e)}")

        # 2) Video containers: extract the audio track with moviepy.
        try:
            video = VideoFileClip(temp_file_path)
            try:
                audio = video.audio
                if audio is not None:
                    audio_array = audio.to_soundarray()
                    sample_rate = audio.fps
                    if len(audio_array.shape) > 1 and audio_array.shape[1] > 1:
                        audio_array = audio_array.mean(axis=1)
                    audio_array = audio_array.astype(np.float32)
                    # Peak-normalize, guarding against an all-silent track
                    # (the original divided unconditionally -> div by zero).
                    peak = np.max(np.abs(audio_array))
                    if peak > 0:
                        audio_array /= peak
                    logger.info(f"Successfully extracted audio from video. Shape: {audio_array.shape}, Sample rate: {sample_rate}")
                    return audio_array, sample_rate
                else:
                    logger.info("Video file contains no audio")
            finally:
                # Always release ffmpeg resources; the original leaked the
                # clip when the file had no audio or extraction raised.
                video.close()
        except Exception as e:
            logger.info(f"Could not read as video: {str(e)}")

        # 3) Anything ffmpeg can decode (mp3, m4a, ...) via pydub.
        try:
            audio = AudioSegment.from_file(temp_file_path)
            audio_array = np.array(audio.get_array_of_samples())
            # Scale integer samples to [-1, 1] based on sample width
            # (16-bit vs 8-bit).
            audio_array = audio_array.astype(np.float32) / (2**15 if audio.sample_width == 2 else 2**7)
            # Interleaved stereo -> mono.
            audio_array = audio_array.reshape((-1, 2)).mean(axis=1) if audio.channels == 2 else audio_array
            logger.info(f"Successfully read audio with pydub. Shape: {audio_array.shape}, Sample rate: {audio.frame_rate}")
            return audio_array, audio.frame_rate
        except Exception as e:
            logger.info(f"Could not read with pydub: {str(e)}")

        raise ValueError(f"Unsupported file format: {file_info}")
    finally:
        os.unlink(temp_file_path)
|
|
@app.post("/transcribe")
async def transcribe_audio(request: AudioRequest, api_key: APIKey = Depends(get_api_key)):
    """Transcribe base64-encoded audio, auto-identifying the language when
    the request does not specify one."""
    started = time.time()
    try:
        raw_bytes = base64.b64decode(request.audio)
        audio_array, sample_rate = extract_audio_from_file(raw_bytes)

        # The ASR model expects float32 samples at ASR_SAMPLING_RATE.
        audio_array = audio_array.astype(np.float32)
        if sample_rate != ASR_SAMPLING_RATE:
            audio_array = librosa.resample(
                audio_array, orig_sr=sample_rate, target_sr=ASR_SAMPLING_RATE
            )

        # Fall back to language identification when no hint was supplied.
        language = request.language if request.language is not None else identify(audio_array)
        result = transcribe(audio_array, language)

        return JSONResponse(
            content={
                "transcription": result,
                "processing_time_seconds": time.time() - started,
            }
        )
    except Exception as e:
        logger.error(f"Error in transcribe_audio: {str(e)}", exc_info=True)
        return JSONResponse(
            status_code=500,
            content={
                "message": "An error occurred during transcription",
                "details": {"error": str(e), "traceback": traceback.format_exc()},
                "processing_time_seconds": time.time() - started,
            },
        )
|
|
@app.post("/transcribe_file")
async def transcribe_audio_file(
    file: UploadFile = File(...),
    language: Optional[str] = Form(None),
    api_key: APIKey = Depends(get_api_key)
):
    """Transcribe an uploaded audio/video file, auto-identifying the
    language when the form does not specify one."""
    started = time.time()
    try:
        payload = await file.read()
        audio_array, sample_rate = extract_audio_from_file(payload)

        # The ASR model expects float32 samples at ASR_SAMPLING_RATE.
        audio_array = audio_array.astype(np.float32)
        if sample_rate != ASR_SAMPLING_RATE:
            audio_array = librosa.resample(
                audio_array, orig_sr=sample_rate, target_sr=ASR_SAMPLING_RATE
            )

        # Fall back to language identification when no hint was supplied.
        chosen_language = language if language is not None else identify(audio_array)
        result = transcribe(audio_array, chosen_language)

        return JSONResponse(
            content={
                "transcription": result,
                "processing_time_seconds": time.time() - started,
            }
        )
    except Exception as e:
        logger.error(f"Error in transcribe_audio_file: {str(e)}", exc_info=True)
        return JSONResponse(
            status_code=500,
            content={
                "message": "An error occurred during transcription",
                "details": {"error": str(e), "traceback": traceback.format_exc()},
                "processing_time_seconds": time.time() - started,
            },
        )
|
|
@app.post("/synthesize")
async def synthesize_speech(request: TTSRequest, api_key: APIKey = Depends(get_api_key)):
    """Synthesize speech from text, upload the WAV to S3, and return its URL.

    Validates the request, runs the TTS model, normalizes the audio to
    int16 WAV, uploads it to S3, and responds with a public URL plus the
    processing time. Returns 400 for invalid input and 500 for synthesis
    or upload failures.
    """
    start_time = time.time()
    logger.info(f"Synthesize request received: text='{request.text}', language='{request.language}', speed={request.speed}")
    try:
        # Default to English; otherwise take only the first whitespace-separated
        # token of the language field (clients may send "eng English").
        if request.language is None:
            lang_code = "eng"
        else:
            lang_code = request.language.split()[0].strip()

        # Input validation -> ValueError -> 400 response below.
        if not request.text:
            raise ValueError("Text cannot be empty")
        if lang_code not in TTS_LANGUAGES:
            raise ValueError(f"Unsupported language: {lang_code}")
        if not 0.5 <= request.speed <= 2.0:
            raise ValueError(f"Speed must be between 0.5 and 2.0, got {request.speed}")

        logger.info(f"Calling synthesize function with lang_code: {lang_code}")
        result, filtered_text = synthesize(request.text, lang_code, request.speed)
        logger.info(f"Synthesize function completed. Filtered text: '{filtered_text}'")

        if result is None:
            logger.error("Synthesize function returned None")
            raise ValueError("Synthesis failed to produce audio")

        sample_rate, audio = result
        logger.info(f"Synthesis result: sample_rate={sample_rate}, audio_shape={audio.shape if isinstance(audio, np.ndarray) else 'not numpy array'}, audio_dtype={audio.dtype if isinstance(audio, np.ndarray) else type(audio)}")

        logger.info("Converting audio to numpy array")
        audio = np.array(audio, dtype=np.float32)
        logger.info(f"Converted audio shape: {audio.shape}, dtype: {audio.dtype}")

        # Peak-normalize to [-1, 1]; reject silent output explicitly.
        logger.info("Normalizing audio")
        max_value = np.max(np.abs(audio))
        if max_value == 0:
            logger.warning("Audio array is all zeros")
            raise ValueError("Generated audio is silent (all zeros)")
        audio = audio / max_value

        logger.info(f"Normalized audio range: [{audio.min()}, {audio.max()}]")

        logger.info("Converting to int16")
        audio = (audio * 32767).astype(np.int16)
        logger.info(f"Int16 audio shape: {audio.shape}, dtype: {audio.dtype}")

        logger.info("Writing audio to buffer")
        buffer = io.BytesIO()
        sf.write(buffer, audio, sample_rate, format='wav')
        buffer.seek(0)
        logger.info(f"Buffer size: {buffer.getbuffer().nbytes} bytes")

        # Timestamped object key so concurrent/successive requests don't
        # overwrite each other (collisions still possible within 1 second).
        filename = f"synthesized_audio_{int(time.time())}.wav"

        try:
            s3_client.upload_fileobj(
                buffer,
                S3_BUCKET,
                filename,
                ExtraArgs={'ContentType': 'audio/wav'}
            )
            # BUG FIX: the log line and URL previously embedded the literal
            # string "(unknown)" instead of the uploaded object key, so the
            # returned audio_url pointed at a nonexistent object.
            logger.info(f"File uploaded successfully to S3: {filename}")

            url = f"https://s3.{S3_REGION}.amazonaws.com/{S3_BUCKET}/{filename}"
            logger.info(f"Public URL generated: {url}")

            processing_time = time.time() - start_time
            return JSONResponse(content={"audio_url": url, "processing_time_seconds": processing_time})

        except NoCredentialsError:
            logger.error("AWS credentials not available or invalid")
            raise HTTPException(status_code=500, detail="Could not upload file to S3: Missing or invalid credentials")
        except Exception as e:
            logger.error(f"Failed to upload to S3: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Could not upload file to S3: {str(e)}")

    except HTTPException:
        # BUG FIX: the generic `except Exception` below used to swallow the
        # HTTPExceptions deliberately raised by the upload handlers, replacing
        # their detail with a generic message. Re-raise them unchanged.
        raise
    except ValueError as ve:
        logger.error(f"ValueError in synthesize_speech: {str(ve)}", exc_info=True)
        processing_time = time.time() - start_time
        return JSONResponse(
            status_code=400,
            content={"message": "Invalid input", "details": str(ve), "processing_time_seconds": processing_time}
        )
    except Exception as e:
        logger.error(f"Unexpected error in synthesize_speech: {str(e)}", exc_info=True)
        error_details = {
            "error": str(e),
            "type": type(e).__name__,
            "traceback": traceback.format_exc()
        }
        processing_time = time.time() - start_time
        return JSONResponse(
            status_code=500,
            content={"message": "An unexpected error occurred during speech synthesis", "details": error_details, "processing_time_seconds": processing_time}
        )
|
|
@app.post("/identify")
async def identify_language(request: AudioRequest, api_key: APIKey = Depends(get_api_key)):
    """Identify the spoken language in base64-encoded audio."""
    start_time = time.time()
    try:
        input_bytes = base64.b64decode(request.audio)
        audio_array, sample_rate = extract_audio_from_file(input_bytes)

        # CONSISTENCY FIX: /transcribe feeds identify() float32 audio
        # resampled to ASR_SAMPLING_RATE, but this endpoint previously passed
        # raw samples at the original rate. Apply the same preprocessing so
        # identification behaves identically on both paths.
        audio_array = audio_array.astype(np.float32)
        if sample_rate != ASR_SAMPLING_RATE:
            audio_array = librosa.resample(audio_array, orig_sr=sample_rate, target_sr=ASR_SAMPLING_RATE)

        result = identify(audio_array)
        processing_time = time.time() - start_time
        return JSONResponse(content={"language_identification": result, "processing_time_seconds": processing_time})
    except Exception as e:
        logger.error(f"Error in identify_language: {str(e)}", exc_info=True)
        error_details = {
            "error": str(e),
            "traceback": traceback.format_exc()
        }
        processing_time = time.time() - start_time
        return JSONResponse(
            status_code=500,
            content={"message": "An error occurred during language identification", "details": error_details, "processing_time_seconds": processing_time}
        )
|
|
@app.post("/identify_file")
async def identify_language_file(
    file: UploadFile = File(...),
    api_key: APIKey = Depends(get_api_key)
):
    """Identify the spoken language in an uploaded audio/video file."""
    start_time = time.time()
    try:
        contents = await file.read()
        audio_array, sample_rate = extract_audio_from_file(contents)

        # CONSISTENCY FIX: match the preprocessing /transcribe_file applies
        # before calling identify() — float32 samples resampled to
        # ASR_SAMPLING_RATE — instead of passing raw audio at its native rate.
        audio_array = audio_array.astype(np.float32)
        if sample_rate != ASR_SAMPLING_RATE:
            audio_array = librosa.resample(audio_array, orig_sr=sample_rate, target_sr=ASR_SAMPLING_RATE)

        result = identify(audio_array)
        processing_time = time.time() - start_time
        return JSONResponse(content={"language_identification": result, "processing_time_seconds": processing_time})
    except Exception as e:
        logger.error(f"Error in identify_language_file: {str(e)}", exc_info=True)
        error_details = {
            "error": str(e),
            "traceback": traceback.format_exc()
        }
        processing_time = time.time() - start_time
        return JSONResponse(
            status_code=500,
            content={"message": "An error occurred during language identification", "details": error_details, "processing_time_seconds": processing_time}
        )
|
|
@app.post("/asr_languages")
async def get_asr_languages(request: LanguageRequest, api_key: APIKey = Depends(get_api_key)):
    """List ASR languages, optionally filtered by a case-insensitive prefix."""
    started = time.time()
    try:
        prefix = (request.language or "").lower()
        if prefix:
            matching_languages = [
                lang for lang in ASR_LANGUAGES if lang.lower().startswith(prefix)
            ]
        else:
            # No filter supplied: return the full catalogue.
            matching_languages = ASR_LANGUAGES

        return JSONResponse(
            content={
                "languages": matching_languages,
                "processing_time_seconds": time.time() - started,
            }
        )
    except Exception as e:
        logger.error(f"Error in get_asr_languages: {str(e)}", exc_info=True)
        return JSONResponse(
            status_code=500,
            content={
                "message": "An error occurred while fetching ASR languages",
                "details": {"error": str(e), "traceback": traceback.format_exc()},
                "processing_time_seconds": time.time() - started,
            },
        )
|
|
@app.post("/tts_languages")
async def get_tts_languages(request: LanguageRequest, api_key: APIKey = Depends(get_api_key)):
    """List TTS languages, optionally filtered by a case-insensitive prefix."""
    started = time.time()
    try:
        prefix = (request.language or "").lower()
        if prefix:
            matching_languages = [
                lang for lang in TTS_LANGUAGES if lang.lower().startswith(prefix)
            ]
        else:
            # No filter supplied: return the full catalogue.
            matching_languages = TTS_LANGUAGES

        return JSONResponse(
            content={
                "languages": matching_languages,
                "processing_time_seconds": time.time() - started,
            }
        )
    except Exception as e:
        logger.error(f"Error in get_tts_languages: {str(e)}", exc_info=True)
        return JSONResponse(
            status_code=500,
            content={
                "message": "An error occurred while fetching TTS languages",
                "details": {"error": str(e), "traceback": traceback.format_exc()},
                "processing_time_seconds": time.time() - started,
            },
        )
|
|
@app.get("/health")
async def health_check():
    """Liveness probe: always reports the service as up."""
    return dict(status="ok")
|
|
@app.get("/")
async def root():
    """Describe the API and enumerate the available endpoints."""
    endpoint_paths = [
        "/transcribe",
        "/transcribe_file",
        "/synthesize",
        "/identify",
        "/identify_file",
        "/asr_languages",
        "/tts_languages",
        "/health",
    ]
    return {
        "message": "Welcome to the MMS Speech Technology API",
        "version": "1.0",
        "endpoints": endpoint_paths,
    }
|
|