# audio-speaker-diarization / AudioHandler.py
# Last commit ea5aa75 ("files push") by Gokulavelan.
import logging
import torch
import os
import base64
from pyannote.audio import Pipeline
from transformers import pipeline, AutoModelForCausalLM
from huggingface_hub import HfApi
from pydantic import ValidationError
# Logger named after this module's import path, used by AudioHandler below.
logger = logging.getLogger(name=__name__)
class AudioHandler:
    """Run automatic speech recognition and, optionally, speaker diarization.

    Wraps a ``transformers`` automatic-speech-recognition pipeline, an
    optional causal-LM assistant model used for assisted generation, and an
    optional ``pyannote.audio`` diarization pipeline. Every model is placed
    on CUDA when it is available, otherwise on CPU.
    """

    def __init__(self, model_settings):
        """Load every model named in *model_settings*.

        Args:
            model_settings: Settings object read for ``asr_model``,
                ``assistant_model`` (falsy to skip), ``diarization_model``
                (falsy to skip) and ``hf_token``.

        Raises:
            RuntimeError: If a diarization model is configured but
                ``Pipeline.from_pretrained`` returns ``None``.
        """
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        logger.info("Using device: %s", device.type)
        # float16 is only used on GPU; CPU inference stays in float32.
        torch_dtype = torch.float32 if device.type == "cpu" else torch.float16
        self.device = device
        self.torch_dtype = torch_dtype

        # Optional assistant model, passed to generate() when a request sets
        # ``assisted``.
        self.assistant_model = None
        if model_settings.assistant_model:
            self.assistant_model = AutoModelForCausalLM.from_pretrained(
                model_settings.assistant_model,
                torch_dtype=torch_dtype,
                low_cpu_mem_usage=True,
                use_safetensors=True,
            ).to(device)

        self.asr_pipeline = pipeline(
            "automatic-speech-recognition",
            model=model_settings.asr_model,
            torch_dtype=torch_dtype,
            device=device,
        )

        self.diarization_pipeline = None
        if model_settings.diarization_model:
            # Called only to check the token up front; the result is discarded.
            HfApi().whoami(model_settings.hf_token)
            diarization_pipeline = Pipeline.from_pretrained(
                checkpoint_path=model_settings.diarization_model,
                use_auth_token=model_settings.hf_token,
            )
            # Some pyannote.audio versions return None (e.g. for a gated or
            # unreachable checkpoint) instead of raising; report that clearly
            # rather than failing with an AttributeError on ``.to``.
            if diarization_pipeline is None:
                raise RuntimeError(
                    "Could not load diarization pipeline "
                    f"'{model_settings.diarization_model}'"
                )
            self.diarization_pipeline = diarization_pipeline.to(device)

    def run_asr(self, file, parameters):
        """Run Automatic Speech Recognition (ASR) on *file*.

        Args:
            file: Audio input handed straight to the ASR pipeline
                (``run_inference`` passes raw bytes).
            parameters: Request parameters read for ``task``, ``language``,
                ``assisted``, ``chunk_length_s`` and ``batch_size``.

        Returns:
            The ASR pipeline output, requested with timestamps;
            ``run_inference`` reads its ``"text"`` and ``"chunks"`` keys.
        """
        if parameters.assisted and self.assistant_model is None:
            # Without an assistant model the request falls back to plain
            # decoding; make that fallback visible instead of silent.
            logger.warning(
                "Assisted generation requested but no assistant model is "
                "loaded; running without one"
            )
        generate_kwargs = {
            "task": parameters.task,
            "language": parameters.language,
            "assistant_model": self.assistant_model if parameters.assisted else None,
        }
        return self.asr_pipeline(
            file,
            chunk_length_s=parameters.chunk_length_s,
            batch_size=parameters.batch_size,
            generate_kwargs=generate_kwargs,
            return_timestamps=True,
        )

    def run_diarization(self, file, parameters, asr_outputs):
        """Run diarization on *file* if a diarization pipeline is loaded.

        Args:
            file: The same audio input given to ``run_asr``.
            parameters: Request parameters, forwarded to ``diarize``.
            asr_outputs: The result of ``run_asr``, forwarded to ``diarize``.

        Returns:
            Whatever ``diarize`` returns, or an empty list when no
            diarization pipeline is loaded.
        """
        if self.diarization_pipeline is None:
            return []
        # NOTE(review): ``diarize`` is not imported anywhere in this file;
        # presumably it lives in a sibling module. Add that import, otherwise
        # this raises NameError (reported by run_inference as
        # "Unknown error during diarization").
        return diarize(self.diarization_pipeline, file, parameters, asr_outputs)

    def run_inference(self, file: bytes, parameters):
        """Run ASR followed by diarization and merge the results.

        Args:
            file: Raw audio bytes.
            parameters: Request parameters (see ``run_asr``).

        Returns:
            dict with ``"speakers"`` (diarization result, ``[]`` when
            diarization is disabled) and ``"chunks"`` / ``"text"`` copied
            from the ASR output.

        Raises:
            RuntimeError: Wrapping any failure in either stage; the original
                exception is chained as ``__cause__``.
        """
        logger.info("Inference parameters: %s", parameters)
        try:
            asr_outputs = self.run_asr(file, parameters)
        except RuntimeError as e:
            logger.exception("ASR inference error: %s", e)
            raise RuntimeError(f"ASR inference error: {e}") from e
        except Exception as e:
            logger.exception("Unknown error during ASR inference: %s", e)
            raise RuntimeError(f"Unknown error during ASR inference: {e}") from e

        try:
            transcript = self.run_diarization(file, parameters, asr_outputs)
        except RuntimeError as e:
            logger.exception("Diarization inference error: %s", e)
            raise RuntimeError(f"Diarization inference error: {e}") from e
        except Exception as e:
            logger.exception("Unknown error during diarization: %s", e)
            raise RuntimeError(f"Unknown error during diarization: {e}") from e

        return {
            "speakers": transcript,
            "chunks": asr_outputs["chunks"],
            "text": asr_outputs["text"],
        }