File size: 6,131 Bytes
994a3aa 21d573e 994a3aa 21d573e 994a3aa e6b4d67 994a3aa bb2e16d 994a3aa bb2e16d bfbc67e bb2e16d 994a3aa bb2e16d 994a3aa bb2e16d 994a3aa 303bfa4 994a3aa 21d573e 994a3aa | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 | import base64
import gc
import logging
import os
import tempfile
from concurrent.futures import ThreadPoolExecutor
from typing import List
import torch
from faster_whisper import WhisperModel
from pydantic import ValidationError
from starlette.exceptions import HTTPException
from alignment import load_align_model, align
from config import InferenceConfig, model_settings
from diarize import DiarizationPipeline, assign_word_speakers
from schema import SingleSegment
from utils import load_audio
# Append the cuDNN library directory to LD_LIBRARY_PATH.
original = os.environ.get("LD_LIBRARY_PATH", "")
cudnn_path = "/opt/conda/lib/python3.11/site-packages/nvidia/cudnn/lib/"
# Only insert the ":" separator when there is an existing value. A leading ":"
# creates an empty search-path entry, which the dynamic linker interprets as
# the current working directory.
os.environ['LD_LIBRARY_PATH'] = f"{original}:{cudnn_path}" if original else cudnn_path
logger = logging.getLogger(__name__)
class EndpointHandler():
    """Callable speech pipeline: transcription, alignment and diarization.

    A call takes a request dict carrying base64-encoded audio under
    ``"inputs"`` and optional settings under ``"parameters"``. It returns a
    dict with the resulting segments, the joined transcription, per-segment
    speaker lines, the ASR model name and the speaker embeddings.
    """

    # Device used for every model loaded by this handler.
    _DEVICE = "cuda"

    def __init__(self, path=""):
        """Load the ASR, alignment and diarization models.

        Args:
            path: Unused by this method.
                NOTE(review): presumably kept to match the serving
                framework's handler signature -- confirm.
        """
        device = self._DEVICE
        self.asr_pipeline = WhisperModel(
            model_settings.asr_model,
            device=device,
            compute_type="float16",
            download_root="cache",
        )
        model_a, metadata = load_align_model(
            language_code=model_settings.language,
            device=device,
            model_name=model_settings.align_model,
            model_dir="cache",
        )
        self.align_model = model_a
        self.align_metadata = metadata
        self.diarize_model = DiarizationPipeline(
            token=model_settings.hf_token,
            device=device,
        )

    def __call__(self, inputs):
        """Run the full pipeline on one request.

        Args:
            inputs: Dict with ``"inputs"`` (base64-encoded audio) and an
                optional ``"parameters"`` mapping. Both keys are popped from
                the dict.

        Returns:
            Dict with keys ``"result"``, ``"full_transcription"``,
            ``"diarization"``, ``"asr_model"`` and ``"speaker_embeddings"``.

        Raises:
            HTTPException: 400 for a missing/undecodable payload, invalid
                parameters or a ``RuntimeError`` during inference; 500 for
                any other failure during ASR or diarization.
        """
        audio_bytes = self._decode_payload(inputs)

        # An explicit ``"parameters": null`` is treated like an absent key.
        raw_parameters = inputs.pop("parameters", None) or {}
        try:
            parameters = InferenceConfig(**raw_parameters)
        except ValidationError as e:
            logger.error(f"Error validating parameters: {e}")
            raise HTTPException(status_code=400, detail=f"Error validating parameters: {e}") from e
        logger.info(f"inference parameters: {parameters}")

        audio = self._load_audio_from_bytes(audio_bytes)

        # 1. Run ASR sequentially (heaviest operation).
        segments = self._run_asr(audio, parameters)

        # 2. Clear VRAM to make room for parallel execution.
        gc.collect()
        torch.cuda.empty_cache()

        # 3. Alignment and diarization run concurrently, one thread each.
        with ThreadPoolExecutor(max_workers=2) as executor:
            align_future = executor.submit(self._run_alignment, segments, audio)
            diarization_future = executor.submit(self._run_diarization, audio, parameters)
            aligned = align_future.result()
            diarization_output, embeddings = diarization_future.result()

        diarized = diarization_output is not None and bool(aligned)
        if diarized:
            result = assign_word_speakers(diarization_output, aligned, embeddings)
        else:
            # Without speaker information, fall back to the aligned output
            # instead of failing when the response is built.
            # NOTE(review): assumes `align` returns a dict with a "segments"
            # key, like the `assign_word_speakers` result -- confirm.
            result = aligned or {}

        # Final cleanup.
        del diarization_output, segments, audio
        gc.collect()
        torch.cuda.empty_cache()

        return self._build_response(
            result.get("segments", []),
            diarized,
            embeddings,
            model_settings.asr_model,
        )

    @staticmethod
    def _decode_payload(inputs):
        """Pop ``"inputs"`` from the request and base64-decode it, or raise a 400."""
        try:
            encoded = inputs.pop("inputs")
        except KeyError:
            logger.error("Missing required field: inputs")
            raise HTTPException(status_code=400, detail="Missing required field: inputs") from None
        try:
            return base64.b64decode(encoded)
        except (ValueError, TypeError) as e:
            # binascii.Error (malformed base64) is a ValueError subclass.
            logger.error(f"Error decoding audio payload: {e}")
            raise HTTPException(status_code=400, detail=f"Error decoding audio payload: {e}") from e

    @staticmethod
    def _load_audio_from_bytes(audio_bytes):
        """Write the bytes to a temporary ``.wav`` file and load it with ``load_audio``."""
        tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
        try:
            # Close the file before loading so the data is flushed to disk.
            with tmp:
                tmp.write(audio_bytes)
            return load_audio(tmp.name)
        finally:
            # Always remove the file, even when writing or loading fails.
            os.remove(tmp.name)

    def _run_asr(self, audio, parameters):
        """Transcribe ``audio`` and return a list of start/end/text dicts."""
        try:
            raw_segments, _info = self.asr_pipeline.transcribe(
                audio,
                language=model_settings.language,
                condition_on_previous_text=False,
                word_timestamps=False,
                vad_filter=parameters.vad_filter,
            )
            # NOTE(review): faster-whisper documents `segments` as a lazy
            # generator, so the iteration stays inside the try block.
            return [
                {"start": seg.start, "end": seg.end, "text": seg.text}
                for seg in raw_segments
            ]
        except RuntimeError as e:
            logger.error(f"ASR inference error: {str(e)}")
            raise HTTPException(status_code=400, detail=f"ASR inference error: {str(e)}") from e
        except Exception as e:
            logger.error(f"Unknown error during ASR inference: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Unknown error during ASR inference: {str(e)}") from e

    def _run_alignment(self, segments: List[dict], audio):
        """Align the transcribed segments against the audio on a dedicated CUDA stream."""
        stream = torch.cuda.Stream()
        with torch.cuda.stream(stream):
            return align(
                segments,
                self.align_model,
                self.align_metadata,
                audio,
                self._DEVICE,
            )

    def _run_diarization(self, audio, parameters):
        """Diarize ``audio`` on a dedicated CUDA stream.

        Returns:
            ``(diarize_segments, embeddings)``, or ``(None, None)`` when no
            diarization model is configured, so callers can always unpack
            two values.
        """
        if not self.diarize_model:
            return None, None
        stream = torch.cuda.Stream()
        with torch.cuda.stream(stream):
            try:
                diarize_segments, embeddings = self.diarize_model(
                    audio,
                    min_speakers=parameters.min_speakers,
                    max_speakers=parameters.max_speakers,
                    num_speakers=parameters.num_speakers,
                    return_embeddings=True,
                )
                return diarize_segments, embeddings
            except RuntimeError as e:
                logger.error(f"Diarization inference error: {str(e)}")
                raise HTTPException(status_code=400, detail=f"Diarization inference error: {str(e)}") from e
            except Exception as e:
                logger.error(f"Unknown error during diarization: {str(e)}")
                raise HTTPException(status_code=500, detail=f"Unknown error during diarization: {str(e)}") from e

    @staticmethod
    def _build_response(segments, diarized, embeddings, asr_model):
        """Assemble the response dict from the final segments.

        Args:
            segments: List of segment dicts; ``"text"`` and ``"speaker"`` are
                read with defaults of ``""`` and ``"UNKNOWN"``.
            diarized: Whether speakers were assigned. When false the
                ``"diarization"`` list is left empty.
            embeddings: Speaker embeddings, passed through unchanged.
            asr_model: Name of the ASR model, passed through unchanged.
        """
        texts = [seg.get("text", "").strip() for seg in segments]
        diarization = []
        if diarized:
            diarization = [
                f'{seg.get("speaker", "UNKNOWN")}: {text}'
                for seg, text in zip(segments, texts)
            ]
        return {
            "result": segments,
            "full_transcription": " ".join(texts),
            "diarization": diarization,
            "asr_model": asr_model,
            "speaker_embeddings": embeddings,
        }
|