| import base64 |
| import gc |
| import logging |
| import os |
| import tempfile |
| from concurrent.futures import ThreadPoolExecutor |
| from typing import List |
|
|
| import torch |
| from faster_whisper import WhisperModel |
| from pydantic import ValidationError |
| from starlette.exceptions import HTTPException |
|
|
| from alignment import load_align_model, align |
| from config import InferenceConfig, model_settings |
| from diarize import DiarizationPipeline, assign_word_speakers |
| from schema import SingleSegment |
| from utils import load_audio |
|
|
| |
# Make the pip-installed cuDNN shared libraries visible on LD_LIBRARY_PATH.
# NOTE(review): the dynamic loader reads LD_LIBRARY_PATH when the process starts,
# and torch / faster_whisper are imported above this line, so this assignment
# mostly affects child processes spawned later — confirm it is still needed.
original = os.environ.get("LD_LIBRARY_PATH", "")
cudnn_path = "/opt/conda/lib/python3.11/site-packages/nvidia/cudnn/lib/"

if cudnn_path not in original.split(":"):
    # Join only non-empty parts: when LD_LIBRARY_PATH was unset, naive
    # concatenation yields a leading ":" (an empty entry), which the loader
    # treats as the current working directory.
    os.environ["LD_LIBRARY_PATH"] = ":".join(
        part for part in (original, cudnn_path) if part
    )

logger = logging.getLogger(__name__)
|
|
|
|
class EndpointHandler:
    """Speech-to-text endpoint: Whisper ASR, forced alignment and speaker diarization.

    The three models are loaded once in ``__init__`` on the CUDA device and are
    reused by every ``__call__``.
    """

    def __init__(self, path=""):
        """Load the ASR, alignment and diarization models onto the GPU.

        Args:
            path: Not used by this handler. Model names come from
                ``model_settings``, and downloads are cached under ``"cache"``.
        """
        device = "cuda"
        self.asr_pipeline = WhisperModel(
            model_settings.asr_model,
            device=device,
            compute_type="float16",
            download_root="cache",
        )

        self.align_model, self.align_metadata = load_align_model(
            language_code=model_settings.language,
            device=device,
            model_name=model_settings.align_model,
            model_dir="cache",
        )

        self.diarize_model = DiarizationPipeline(
            token=model_settings.hf_token,
            device=device,
        )

    def __call__(self, inputs):
        """Transcribe, align and diarize one base64-encoded audio request.

        Args:
            inputs: Request dict. ``inputs["inputs"]`` holds the base64-encoded
                audio. The optional ``inputs["parameters"]`` mapping is validated
                against ``InferenceConfig``. Both keys are popped from the dict.

        Returns:
            A dict with the following keys:

            - ``result``: the final segment dicts.
            - ``full_transcription``: the space-joined segment texts.
            - ``diarization``: ``"SPEAKER: text"`` lines, empty when speakers
              were not assigned.
            - ``asr_model``: the ASR model name.
            - ``speaker_embeddings``: the embeddings from the diarization
              model, or ``None`` when diarization did not run.

        Raises:
            HTTPException: 400 for a malformed request, invalid parameters, or a
                ``RuntimeError`` raised by a model. 500 for any other inference
                failure.
        """
        audio_bytes, parameters = self._parse_request(inputs)
        logger.info(f"inference parameters: {parameters}")
        audio = self._load_audio_bytes(audio_bytes)

        segments, _info = self._run_asr(audio, parameters)

        # Release memory left over from transcription before alignment and
        # diarization run side by side.
        gc.collect()
        torch.cuda.empty_cache()

        with ThreadPoolExecutor() as executor:
            align_future = executor.submit(self._run_alignment, segments, audio)
            diarization_future = executor.submit(self._run_diarization, audio, parameters)
        aligned = align_future.result()
        diarization = diarization_future.result()

        # _run_diarization returns None when no diarization model is loaded,
        # so it must not be unpacked blindly.
        diarization_output, embeddings = (None, None) if diarization is None else diarization

        speakers_assigned = diarization_output is not None and bool(aligned)
        if speakers_assigned:
            result = assign_word_speakers(diarization_output, aligned, embeddings)
        else:
            # Without speaker assignment, return the aligned transcript unlabeled.
            result = aligned

        del diarization, diarization_output, segments, audio
        gc.collect()
        torch.cuda.empty_cache()

        # NOTE(review): assumes align() / assign_word_speakers() return a dict with a
        # "segments" list of dicts — confirm against alignment.py / diarize.py.
        result_segments = result.get("segments", []) if isinstance(result, dict) else []
        return self._format_response(result_segments, embeddings, speakers_assigned)

    @staticmethod
    def _parse_request(inputs):
        """Pop the decoded audio bytes and a validated ``InferenceConfig`` out of *inputs*.

        Raises:
            HTTPException: 400 if the audio field is missing or is not valid
                base64, or if the parameters fail validation.
        """
        if "inputs" not in inputs:
            logger.error("Request payload has no 'inputs' field")
            raise HTTPException(status_code=400, detail="Request payload has no 'inputs' field")
        try:
            audio_bytes = base64.b64decode(inputs.pop("inputs"))
        except (TypeError, ValueError) as e:  # binascii.Error subclasses ValueError
            logger.error(f"Error decoding base64 audio: {e}")
            raise HTTPException(status_code=400, detail=f"Error decoding base64 audio: {e}") from e

        # A None value is treated like an omitted "parameters" field.
        raw_parameters = inputs.pop("parameters", None) or {}
        try:
            parameters = InferenceConfig(**raw_parameters)
        except (ValidationError, TypeError) as e:
            logger.error(f"Error validating parameters: {e}")
            raise HTTPException(status_code=400, detail=f"Error validating parameters: {e}") from e
        return audio_bytes, parameters

    @staticmethod
    def _load_audio_bytes(audio_bytes):
        """Decode raw audio bytes with ``load_audio`` via a temporary ``.wav`` file."""
        # load_audio takes a file path, so spill the bytes to disk. The file is
        # removed even when writing or decoding fails.
        tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
        try:
            with tmp:
                tmp.write(audio_bytes)
            return load_audio(tmp.name)
        finally:
            os.remove(tmp.name)

    def _run_asr(self, audio, parameters):
        """Transcribe *audio* with Whisper and return ``(segments, info)``.

        Each segment is a plain ``{"start", "end", "text"}`` dict. Word timestamps
        are disabled here; word-level timing is left to the alignment stage.
        """
        try:
            raw_segments, info = self.asr_pipeline.transcribe(
                audio,
                language=model_settings.language,
                condition_on_previous_text=False,
                word_timestamps=False,
                vad_filter=parameters.vad_filter,
            )
            # transcribe() may return a lazy generator, so consume it inside the
            # try: decoding errors can surface during iteration.
            segments: List[SingleSegment] = [
                {"start": seg.start, "end": seg.end, "text": seg.text}
                for seg in raw_segments
            ]
        except RuntimeError as e:
            logger.error(f"ASR inference error: {str(e)}")
            raise HTTPException(status_code=400, detail=f"ASR inference error: {str(e)}") from e
        except Exception as e:
            logger.error(f"Unknown error during ASR inference: {str(e)}")
            raise HTTPException(
                status_code=500, detail=f"Unknown error during ASR inference: {str(e)}"
            ) from e
        return segments, info

    def _run_alignment(self, segments, audio):
        """Align *segments* against *audio* on a dedicated CUDA stream.

        Raises:
            HTTPException: 400 for ``RuntimeError``, 500 for anything else
                (same convention as the ASR and diarization stages).
        """
        stream = torch.cuda.Stream()
        with torch.cuda.stream(stream):
            try:
                return align(segments, self.align_model, self.align_metadata, audio, "cuda")
            except RuntimeError as e:
                logger.error(f"Alignment error: {str(e)}")
                raise HTTPException(status_code=400, detail=f"Alignment error: {str(e)}") from e
            except Exception as e:
                logger.error(f"Unknown error during alignment: {str(e)}")
                raise HTTPException(
                    status_code=500, detail=f"Unknown error during alignment: {str(e)}"
                ) from e

    def _run_diarization(self, audio, parameters):
        """Diarize *audio*, returning ``(segments, embeddings)``, or ``None`` without a model.

        Raises:
            HTTPException: 400 for ``RuntimeError``, 500 for anything else.
        """
        if not self.diarize_model:
            return None
        # Own CUDA stream, presumably so GPU work can overlap with alignment
        # running in the other worker thread.
        stream = torch.cuda.Stream()
        with torch.cuda.stream(stream):
            try:
                diarize_segments, embeddings = self.diarize_model(
                    audio,
                    min_speakers=parameters.min_speakers,
                    max_speakers=parameters.max_speakers,
                    num_speakers=parameters.num_speakers,
                    return_embeddings=True,
                )
                return diarize_segments, embeddings
            except RuntimeError as e:
                logger.error(f"Diarization inference error: {str(e)}")
                raise HTTPException(
                    status_code=400, detail=f"Diarization inference error: {str(e)}"
                ) from e
            except Exception as e:
                logger.error(f"Unknown error during diarization: {str(e)}")
                raise HTTPException(
                    status_code=500, detail=f"Unknown error during diarization: {str(e)}"
                ) from e

    @staticmethod
    def _format_response(segments, embeddings, speakers_assigned):
        """Build the response payload from the final list of segment dicts."""
        texts = [seg.get("text", "").strip() for seg in segments]
        diarization = (
            [f'{seg.get("speaker", "UNKNOWN")}: {text}' for seg, text in zip(segments, texts)]
            if speakers_assigned
            else []
        )
        return {
            "result": segments,
            "full_transcription": " ".join(texts),
            "diarization": diarization,
            "asr_model": model_settings.asr_model,
            "speaker_embeddings": embeddings,
        }
|
|