import logging
import os
import base64

import torch
from pyannote.audio import Pipeline
from transformers import pipeline, AutoModelForCausalLM
from huggingface_hub import HfApi
from pydantic import ValidationError

logger = logging.getLogger(__name__)


class AudioHandler:
    """Bundle an ASR pipeline with optional assistant and diarization models.

    All models are loaded once in ``__init__`` and reused by every call to
    ``run_inference``.
    """

    def __init__(self, model_settings):
        """Load every model named in *model_settings* onto the chosen device.

        Args:
            model_settings: Settings object read for the attributes
                ``asr_model``, ``assistant_model``, ``diarization_model`` and
                ``hf_token``. A falsy ``assistant_model`` or
                ``diarization_model`` disables that component.
        """
        device = (
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        )
        logger.info("Using device: %s", device.type)
        # Half precision is only selected when running on the GPU.
        torch_dtype = torch.float32 if device.type == "cpu" else torch.float16

        self.device = device
        self.torch_dtype = torch_dtype

        # Load assistant model (optional; passed to generation when a request
        # sets ``parameters.assisted``).
        self.assistant_model = (
            AutoModelForCausalLM.from_pretrained(
                model_settings.assistant_model,
                torch_dtype=torch_dtype,
                low_cpu_mem_usage=True,
                use_safetensors=True,
            ).to(device)
            if model_settings.assistant_model
            else None
        )

        # Load ASR pipeline
        self.asr_pipeline = pipeline(
            "automatic-speech-recognition",
            model=model_settings.asr_model,
            torch_dtype=torch_dtype,
            device=device,
        )

        # Load diarization pipeline if available
        if model_settings.diarization_model:
            # NOTE(review): presumably a fail-fast check that the Hub accepts
            # the token before the checkpoint is fetched -- confirm.
            HfApi().whoami(model_settings.hf_token)
            self.diarization_pipeline = Pipeline.from_pretrained(
                checkpoint_path=model_settings.diarization_model,
                use_auth_token=model_settings.hf_token,
            ).to(device)
        else:
            self.diarization_pipeline = None

    def run_asr(self, file, parameters):
        """Run Automatic Speech Recognition (ASR).

        Args:
            file: Audio input forwarded unchanged to the ASR pipeline.
            parameters: Request parameters read for ``task``, ``language``,
                ``assisted``, ``chunk_length_s`` and ``batch_size``.

        Returns:
            Whatever the ASR pipeline returns when called with
            ``return_timestamps=True``.
        """
        generate_kwargs = {
            "task": parameters.task,
            "language": parameters.language,
            # The assistant model is only handed to generation on request.
            "assistant_model": self.assistant_model if parameters.assisted else None,
        }

        return self.asr_pipeline(
            file,
            chunk_length_s=parameters.chunk_length_s,
            batch_size=parameters.batch_size,
            generate_kwargs=generate_kwargs,
            return_timestamps=True,
        )

    def run_diarization(self, file, parameters, asr_outputs):
        """Run diarization if a diarization pipeline was loaded.

        Args:
            file: Audio input, forwarded to ``diarize``.
            parameters: Request parameters, forwarded to ``diarize``.
            asr_outputs: Result of ``run_asr``, forwarded to ``diarize``.

        Returns:
            An empty list when no diarization pipeline is configured,
            otherwise the result of ``diarize``.
        """
        if not self.diarization_pipeline:
            return []

        # Replace with actual diarization logic if required
        # NOTE(review): ``diarize`` is neither defined nor imported in the
        # visible part of this file; unless it is provided elsewhere this call
        # raises NameError (surfaced by run_inference as a RuntimeError).
        return diarize(self.diarization_pipeline, file, parameters, asr_outputs)

    def run_inference(self, file: bytes, parameters):
        """Run the complete inference process: ASR, then diarization.

        Args:
            file: Audio payload to transcribe.
            parameters: Request parameters forwarded to both stages.

        Returns:
            A dict with keys ``"speakers"`` (diarization result), ``"chunks"``
            and ``"text"`` (both taken from the ASR output).

        Raises:
            RuntimeError: If either stage fails. The original exception is
                attached as ``__cause__`` so the root traceback is preserved.
        """
        try:
            logger.info("Inference parameters: %s", parameters)
            asr_outputs = self.run_asr(file, parameters)
        except RuntimeError as e:
            message = f"ASR inference error: {str(e)}"
            logger.exception(message)
            raise RuntimeError(message) from e
        except Exception as e:
            message = f"Unknown error during ASR inference: {str(e)}"
            logger.exception(message)
            raise RuntimeError(message) from e

        try:
            transcript = self.run_diarization(file, parameters, asr_outputs)
        except RuntimeError as e:
            message = f"Diarization inference error: {str(e)}"
            logger.exception(message)
            raise RuntimeError(message) from e
        except Exception as e:
            message = f"Unknown error during diarization: {str(e)}"
            logger.exception(message)
            raise RuntimeError(message) from e

        return {
            "speakers": transcript,
            "chunks": asr_outputs["chunks"],
            "text": asr_outputs["text"],
        }