import base64
import gc
import logging
import os
import tempfile
from concurrent.futures import ThreadPoolExecutor
from typing import List

import torch
from faster_whisper import WhisperModel
from pydantic import ValidationError
from starlette.exceptions import HTTPException

from alignment import load_align_model, align
from config import InferenceConfig, model_settings
from diarize import DiarizationPipeline, assign_word_speakers
from schema import SingleSegment
from utils import load_audio

# Get current LD_LIBRARY_PATH
original = os.environ.get("LD_LIBRARY_PATH", "")
cudnn_path = "/opt/conda/lib/python3.11/site-packages/nvidia/cudnn/lib/"
# Only insert the separator when there is an existing value, so an unset
# variable does not end up with a leading ":" (an empty search-path entry).
os.environ['LD_LIBRARY_PATH'] = (
    original + ":" + cudnn_path if original else cudnn_path
)

logger = logging.getLogger(__name__)


class EndpointHandler():
    """Callable handler that transcribes, aligns and diarizes base64 audio.

    The ASR model, the alignment model and the diarization pipeline are
    created once in ``__init__`` and reused for every call.
    """

    def __init__(self, path=""):
        """Load the ASR, alignment and diarization models on the CUDA device.

        Args:
            path: Not used by this handler; kept so the constructor signature
                stays unchanged for existing callers.
        """
        device = "cuda"
        self.asr_pipeline = WhisperModel(
            model_settings.asr_model,
            device=device,
            compute_type="float16",
            download_root="cache",
        )
        model_a, metadata = load_align_model(
            language_code=model_settings.language,
            device=device,
            model_name=model_settings.align_model,
            model_dir="cache",
        )
        self.align_model = model_a
        self.align_metadata = metadata
        self.diarize_model = DiarizationPipeline(
            token=model_settings.hf_token,
            device=device,
        )

    def __call__(self, inputs):
        """Run the full pipeline on one request payload.

        Args:
            inputs: Mapping with an ``"inputs"`` entry holding the
                base64-encoded audio and an optional ``"parameters"`` mapping
                validated through ``InferenceConfig``. Both entries are popped
                from the mapping.

        Returns:
            A dict with the keys ``"result"``, ``"full_transcription"``,
            ``"diarization"``, ``"asr_model"`` and ``"speaker_embeddings"``.

        Raises:
            HTTPException: 400 when the payload or parameters are invalid or
                a model raises ``RuntimeError``; 500 for any other failure
                during ASR or diarization.
        """
        try:
            encoded_audio = inputs.pop("inputs")
        except KeyError as e:
            logger.error("Request is missing the 'inputs' field")
            raise HTTPException(
                status_code=400, detail="Request is missing the 'inputs' field"
            ) from e

        try:
            file = base64.b64decode(encoded_audio)
        except (ValueError, TypeError) as e:
            # binascii.Error (bad padding/characters) is a ValueError subclass.
            logger.error(f"Error decoding base64 audio: {e}")
            raise HTTPException(
                status_code=400, detail=f"Error decoding base64 audio: {e}"
            ) from e

        parameters = inputs.pop("parameters", {})
        try:
            parameters = InferenceConfig(**parameters)
        except ValidationError as e:
            logger.error(f"Error validating parameters: {e}")
            raise HTTPException(
                status_code=400, detail=f"Error validating parameters: {e}"
            ) from e

        logger.info(f"inference parameters: {parameters}")

        # The audio loader takes a path, so the bytes go through a temp file.
        # The file is removed even when writing or loading fails.
        tmp_path = None
        try:
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
                tmp_path = tmp.name
                tmp.write(file)
            audio = load_audio(tmp_path)
        finally:
            if tmp_path is not None:
                os.remove(tmp_path)

        # 1. Run ASR Sequentially (Heaviest operation)
        segments, _info = self._run_asr(audio, parameters)

        # 2. Clear VRAM to make room for parallel execution
        gc.collect()
        torch.cuda.empty_cache()

        with ThreadPoolExecutor() as executor:
            align_future = executor.submit(self._run_alignment, segments, audio)
            diarization_future = executor.submit(
                self._run_diarization, audio, parameters
            )

            aligned = align_future.result()
            diarization_output, embeddings = diarization_future.result()

        speakers_assigned = diarization_output is not None and bool(aligned)
        if speakers_assigned:
            result = assign_word_speakers(diarization_output, aligned, embeddings)
        else:
            # Without diarization output fall back to the aligned transcript
            # instead of failing on an empty result.
            # NOTE(review): assumes align() returns a mapping with a
            # "segments" key, like assign_word_speakers does — confirm.
            result = aligned

        # Final cleanup
        del diarization_output, segments, audio
        gc.collect()
        torch.cuda.empty_cache()

        output_segments = result.get("segments", []) if result else []

        diarization = []
        if speakers_assigned:
            diarization = [
                f'{seg.get("speaker", "UNKNOWN")}: {seg.get("text", "").strip()}'
                for seg in output_segments
            ]

        return {
            "result": output_segments,
            "full_transcription": " ".join(
                seg.get("text", "").strip() for seg in output_segments
            ),
            "diarization": diarization,
            "asr_model": model_settings.asr_model,
            "speaker_embeddings": embeddings,
        }

    def _run_asr(self, _audio, parameters):
        """Transcribe the audio and return ``(segments, info)``.

        Each segment is reduced to a dict with ``start``, ``end`` and ``text``.

        Raises:
            HTTPException: 400 on ``RuntimeError``, 500 on any other error.
        """
        try:
            _segments, _info = self.asr_pipeline.transcribe(
                _audio,
                language=model_settings.language,
                condition_on_previous_text=False,
                word_timestamps=False,
                vad_filter=parameters.vad_filter,
            )
            # The segments are consumed inside the try block so that errors
            # raised while iterating them are mapped to HTTP errors as well.
            align_segments: List[SingleSegment] = [
                {"start": seg.start, "end": seg.end, "text": seg.text}
                for seg in _segments
            ]
            return align_segments, _info
        except RuntimeError as e:
            logger.error(f"ASR inference error: {str(e)}")
            raise HTTPException(
                status_code=400, detail=f"ASR inference error: {str(e)}"
            ) from e
        except Exception as e:
            logger.error(f"Unknown error during ASR inference: {str(e)}")
            raise HTTPException(
                status_code=500,
                detail=f"Unknown error during ASR inference: {str(e)}",
            ) from e

    def _run_alignment(self, _segments: List[SingleSegment], _audio):
        """Align the ASR segments against the audio on its own CUDA stream."""
        stream = torch.cuda.Stream()
        with torch.cuda.stream(stream):
            return align(
                _segments,
                self.align_model,
                self.align_metadata,
                _audio,
                "cuda",
            )

    def _run_diarization(self, _audio, parameters):
        """Diarize the audio on its own CUDA stream.

        Returns:
            ``(diarize_segments, embeddings)``, or ``(None, None)`` when no
            diarization model is set, so the caller can always unpack two
            values.

        Raises:
            HTTPException: 400 on ``RuntimeError``, 500 on any other error.
        """
        if not self.diarize_model:
            return None, None

        stream = torch.cuda.Stream()
        with torch.cuda.stream(stream):
            try:
                diarize_segments, _embeddings = self.diarize_model(
                    _audio,
                    min_speakers=parameters.min_speakers,
                    max_speakers=parameters.max_speakers,
                    num_speakers=parameters.num_speakers,
                    return_embeddings=True,
                )
                return diarize_segments, _embeddings
            except RuntimeError as e:
                logger.error(f"Diarization inference error: {str(e)}")
                raise HTTPException(
                    status_code=400,
                    detail=f"Diarization inference error: {str(e)}",
                ) from e
            except Exception as e:
                logger.error(f"Unknown error during diarization: {str(e)}")
                raise HTTPException(
                    status_code=500,
                    detail=f"Unknown error during diarization: {str(e)}",
                ) from e