diarization-chunks / handler.py
erik-svensson-cm's picture
Update handler.py
e6b4d67 verified
import base64
import gc
import logging
import os
import tempfile
from concurrent.futures import ThreadPoolExecutor
from typing import List
import torch
from faster_whisper import WhisperModel
from pydantic import ValidationError
from starlette.exceptions import HTTPException
from alignment import load_align_model, align
from config import InferenceConfig, model_settings
from diarize import DiarizationPipeline, assign_word_speakers
from schema import SingleSegment
from utils import load_audio
# Append the pip-installed cuDNN library directory to LD_LIBRARY_PATH.
# NOTE(review): this runs after the imports above; the dynamic loader normally
# reads LD_LIBRARY_PATH at process start-up, so this may only affect child
# processes -- confirm that it has the intended effect in this deployment.
original = os.environ.get("LD_LIBRARY_PATH", "")
cudnn_path = "/opt/conda/lib/python3.11/site-packages/nvidia/cudnn/lib/"
# Only add the separator when there is an existing value: a leading ":" leaves
# an empty entry, which the dynamic linker interprets as the current working
# directory.
os.environ['LD_LIBRARY_PATH'] = f"{original}:{cudnn_path}" if original else cudnn_path

logger = logging.getLogger(__name__)
class EndpointHandler():
    """Endpoint handler that transcribes, aligns and diarizes one audio payload.

    The three models are loaded once in ``__init__``. Each call decodes a
    base64 audio payload, runs ASR, then runs alignment and speaker diarization
    concurrently and merges their outputs into a single response dict.
    """

    # Single source of truth for the device used by every model and helper.
    _DEVICE = "cuda"

    def __init__(self, path=""):
        """Load the ASR, alignment and diarization models.

        Args:
            path: Accepted for interface compatibility; not used by this
                handler.
        """
        device = self._DEVICE
        self.asr_pipeline = WhisperModel(
            model_settings.asr_model,
            device=device,
            compute_type="float16",
            download_root="cache",
        )
        model_a, metadata = load_align_model(
            language_code=model_settings.language,
            device=device,
            model_name=model_settings.align_model,
            model_dir="cache",
        )
        self.align_model = model_a
        self.align_metadata = metadata
        self.diarize_model = DiarizationPipeline(
            token=model_settings.hf_token,
            device=device,
        )

    def __call__(self, inputs):
        """Run the full pipeline on one request.

        Args:
            inputs: Mapping with an ``"inputs"`` entry holding base64-encoded
                audio and an optional ``"parameters"`` entry whose items are
                passed to ``InferenceConfig``. Both entries are popped from
                the mapping.

        Returns:
            dict with keys ``"result"`` (segments), ``"full_transcription"``,
            ``"diarization"`` (``"<speaker>: <text>"`` lines), ``"asr_model"``
            and ``"speaker_embeddings"``.

        Raises:
            HTTPException: 400 for a missing/undecodable payload, invalid
                parameters or a ``RuntimeError`` during inference; 500 for any
                other inference failure.
        """
        raw_audio = self._decode_payload(inputs)
        parameters = self._parse_parameters(inputs)
        logger.info(f"inference parameters: {parameters}")
        audio = self._load_audio_bytes(raw_audio)

        # 1. Run ASR sequentially (heaviest operation).
        segments, _ = self._run_asr(audio, parameters)

        # 2. Clear VRAM to make room for parallel execution.
        gc.collect()
        torch.cuda.empty_cache()

        # Exactly two tasks are submitted, so two workers are sufficient.
        with ThreadPoolExecutor(max_workers=2) as executor:
            align_future = executor.submit(self._run_alignment, segments, audio)
            diarization_future = executor.submit(
                self._run_diarization, audio, parameters
            )
            aligned = align_future.result()
            diarization_output, embeddings = diarization_future.result()

        if diarization_output is not None and aligned:
            result = assign_word_speakers(diarization_output, aligned, embeddings)
        else:
            # No speaker information available: fall back to the aligned
            # transcript instead of failing on an empty placeholder.
            # NOTE(review): assumes ``align`` returns a dict with a
            # "segments" key, like ``assign_word_speakers`` -- confirm.
            result = aligned

        # Final cleanup.
        del diarization_output, segments, audio
        gc.collect()
        torch.cuda.empty_cache()

        return self._build_response(result, embeddings, model_settings.asr_model)

    @staticmethod
    def _decode_payload(inputs):
        """Pop and base64-decode the ``"inputs"`` entry, or raise a 400."""
        try:
            return base64.b64decode(inputs.pop("inputs"))
        except KeyError as e:
            message = 'Missing required "inputs" field'
            logger.error(message)
            raise HTTPException(status_code=400, detail=message) from e
        except (TypeError, ValueError) as e:
            # binascii.Error (malformed base64) is a ValueError subclass.
            message = f"Error decoding base64 audio: {e}"
            logger.error(message)
            raise HTTPException(status_code=400, detail=message) from e

    @staticmethod
    def _parse_parameters(inputs):
        """Pop ``"parameters"`` and validate it into an ``InferenceConfig``."""
        # ``or {}`` also covers an explicit null value for "parameters".
        raw_parameters = inputs.pop("parameters", None) or {}
        try:
            return InferenceConfig(**raw_parameters)
        except ValidationError as e:
            logger.error(f"Error validating parameters: {e}")
            raise HTTPException(
                status_code=400, detail=f"Error validating parameters: {e}"
            ) from e

    @staticmethod
    def _load_audio_bytes(raw_audio):
        """Write the bytes to a temporary ``.wav`` file and load it by path."""
        # delete=False because the file is closed first and then re-opened by
        # name inside load_audio; it is removed explicitly below.
        tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
        try:
            with tmp:
                tmp.write(raw_audio)
            return load_audio(tmp.name)
        finally:
            # Always remove the file, even when writing or loading fails.
            os.remove(tmp.name)

    def _run_asr(self, audio, parameters):
        """Transcribe ``audio``; return ``(segments, info)``.

        ``segments`` is a list of ``{"start", "end", "text"}`` dicts.
        """
        try:
            raw_segments, info = self.asr_pipeline.transcribe(
                audio,
                language=model_settings.language,
                condition_on_previous_text=False,
                word_timestamps=False,
                vad_filter=parameters.vad_filter,
            )
            # Keep the iteration inside the try block: the segments may be
            # produced lazily, in which case decoding errors surface here.
            segments = [
                {"start": seg.start, "end": seg.end, "text": seg.text}
                for seg in raw_segments
            ]
            return segments, info
        except RuntimeError as e:
            logger.error(f"ASR inference error: {str(e)}")
            raise HTTPException(
                status_code=400, detail=f"ASR inference error: {str(e)}"
            ) from e
        except Exception as e:
            logger.error(f"Unknown error during ASR inference: {str(e)}")
            raise HTTPException(
                status_code=500,
                detail=f"Unknown error during ASR inference: {str(e)}",
            ) from e

    def _run_alignment(self, segments, audio):
        """Align ``segments`` against ``audio`` on a dedicated CUDA stream."""
        stream = torch.cuda.Stream()
        with torch.cuda.stream(stream):
            try:
                return align(
                    segments,
                    self.align_model,
                    self.align_metadata,
                    audio,
                    self._DEVICE,
                )
            except RuntimeError as e:
                logger.error(f"Alignment inference error: {str(e)}")
                raise HTTPException(
                    status_code=400,
                    detail=f"Alignment inference error: {str(e)}",
                ) from e
            except Exception as e:
                logger.error(f"Unknown error during alignment: {str(e)}")
                raise HTTPException(
                    status_code=500,
                    detail=f"Unknown error during alignment: {str(e)}",
                ) from e

    def _run_diarization(self, audio, parameters):
        """Diarize ``audio``; return ``(diarize_segments, embeddings)``.

        Returns ``(None, None)`` when no diarization model is configured so the
        caller can always unpack two values.
        """
        if not self.diarize_model:
            return None, None
        stream = torch.cuda.Stream()
        with torch.cuda.stream(stream):
            try:
                diarize_segments, embeddings = self.diarize_model(
                    audio,
                    min_speakers=parameters.min_speakers,
                    max_speakers=parameters.max_speakers,
                    num_speakers=parameters.num_speakers,
                    return_embeddings=True,
                )
                return diarize_segments, embeddings
            except RuntimeError as e:
                logger.error(f"Diarization inference error: {str(e)}")
                raise HTTPException(
                    status_code=400,
                    detail=f"Diarization inference error: {str(e)}",
                ) from e
            except Exception as e:
                logger.error(f"Unknown error during diarization: {str(e)}")
                raise HTTPException(
                    status_code=500,
                    detail=f"Unknown error during diarization: {str(e)}",
                ) from e

    @staticmethod
    def _build_response(result, embeddings, asr_model):
        """Shape the response dict; an empty or ``None`` result yields no segments."""
        segments = (result or {}).get("segments") or []
        texts = [seg.get("text", "").strip() for seg in segments]
        diarization = [
            f'{seg.get("speaker", "UNKNOWN")}: {text}'
            for seg, text in zip(segments, texts)
        ]
        return {
            "result": segments,
            "full_transcription": " ".join(texts),
            "diarization": diarization,
            "asr_model": asr_model,
            "speaker_embeddings": embeddings,
        }