import io
import os
import sys
from contextlib import redirect_stdout
from time import time

import torch
import whisper
from dotenv import load_dotenv
from pyannote.audio import Pipeline
from tqdm import tqdm
from transformers import pipeline

from .transcription import Transcription
from .audio_processing import AudioProcessor

load_dotenv()


class Transcriptor:
    """
    A class for transcribing and diarizing audio files.

    This class uses a Whisper model for transcription and the PyAnnote speaker
    diarization pipeline for speaker identification.

    Attributes
    ----------
    model_size : str
        The size of the Whisper model to use for transcription. Available options:
        - 'tiny': fastest, lowest accuracy
        - 'base': fast, good accuracy for many use cases
        - 'small': balanced speed and accuracy
        - 'medium': high accuracy, slower than smaller models
        - 'large-v3': latest and most accurate version of the large model
        - 'large-v3-turbo': optimized version of large-v3 for faster processing
    model : whisper.model.Whisper
        The Whisper model used for transcription (a Hugging Face ASR pipeline
        when model_size is 'large-v3-turbo').
    pipeline : pyannote.audio.pipelines.SpeakerDiarization
        The PyAnnote speaker diarization pipeline.

    Usage:
    >>> transcriptor = Transcriptor(model_size="large-v3")
    >>> transcription = transcriptor.transcribe_audio("/path/to/audio.wav")
    >>> transcription.get_name_speakers()
    >>> transcription.save("/path/to/transcripts")

    Note:
    Larger models, especially 'large-v3', provide higher accuracy but require
    more computational resources and may be slower to process audio.
    """
| def __init__(self, model_size: str = "base"): | |
| self.model_size = model_size | |
| self.HF_TOKEN = os.getenv("HF_TOKEN") | |
| if not self.HF_TOKEN: | |
| raise ValueError("HF_TOKEN not found. Please store token in .env") | |
| self._setup() | |

    def _setup(self):
        """Initialize the Whisper model and the diarization pipeline."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")
        print("Initializing Whisper model...")
        if self.model_size == "large-v3-turbo":
            self.model = pipeline(
                task="automatic-speech-recognition",
                model="ylacombe/whisper-large-v3-turbo",
                chunk_length_s=30,
                device=self.device,
            )
        else:
            self.model = whisper.load_model(self.model_size, device=self.device)
        print("Building diarization pipeline...")
        self.pipeline = Pipeline.from_pretrained(
            "pyannote/speaker-diarization-3.1",
            use_auth_token=self.HF_TOKEN,
        ).to(torch.device(self.device))
        print("Setup completed successfully!")

    def transcribe_audio(self, audio_file_path: str, enhanced: bool = False, buffer_logs: bool = False):
        """
        Transcribe an audio file.

        Parameters
        ----------
        audio_file_path : str
            Path to the audio file to be transcribed.
        enhanced : bool, optional
            If True, applies audio enhancement techniques to improve transcription quality.
        buffer_logs : bool, optional
            If True, captures logs and returns them with the transcription.
            If False, prints them to the terminal.

        Returns
        -------
        Union[Transcription, Tuple[Transcription, str]]
            The Transcription object alone (buffer_logs=False), or a tuple of
            (Transcription, logs string) (buffer_logs=True).
        """
        if buffer_logs:
            logs_buffer = io.StringIO()
            with redirect_stdout(logs_buffer):
                transcription = self._perform_transcription(audio_file_path, enhanced)
            logs = logs_buffer.getvalue()
            return transcription, logs
        else:
            return self._perform_transcription(audio_file_path, enhanced)

    def _perform_transcription(self, audio_file_path: str, enhanced: bool = False):
        """Internal method that handles the actual transcription process."""
        try:
            print(f"Received audio_file_path: {audio_file_path} ({type(audio_file_path).__name__})")
            if audio_file_path is None:
                raise ValueError("No audio file was uploaded. Please upload an audio file.")
            if not isinstance(audio_file_path, (str, bytes, os.PathLike)):
                raise ValueError(f"Invalid audio file path type: {type(audio_file_path)}")
            if not os.path.exists(audio_file_path):
                raise FileNotFoundError(f"Audio file not found at path: {audio_file_path}")
            print("Processing audio file...")
            processed_audio = self.process_audio(audio_file_path, enhanced)
            audio_file_path = processed_audio.path
            audio = processed_audio.load_as_array()
            sr = processed_audio.sample_rate
            duration = processed_audio.duration
            print("Diarization in progress...")
            start_time = time()
            diarization = self.perform_diarization(audio_file_path)
            print(f"Diarization completed in {time() - start_time:.2f} seconds.")
            segments = list(diarization.itertracks(yield_label=True))
            transcriptions = self.transcribe_segments(audio, sr, duration, segments)
            return Transcription(audio_file_path, transcriptions, segments)
        except Exception as e:
            print(f"Error occurred: {str(e)}")
            raise RuntimeError(f"Failed to process the audio file: {str(e)}") from e

    def process_audio(self, audio_file_path: str, enhanced: bool = False) -> AudioProcessor:
        """
        Process the audio file so that it meets the requirements for transcription.

        Parameters
        ----------
        audio_file_path : str
            Path to the audio file to be processed.
        enhanced : bool, optional
            If True, applies audio enhancement techniques to improve audio quality:
            noise reduction, voice enhancement, and volume boosting, with
            parameters optimized from the audio's characteristics.

        Returns
        -------
        AudioProcessor
            An AudioProcessor object containing the processed audio file.
        """
        processed_audio = AudioProcessor(audio_file_path)
        if processed_audio.format != ".wav":
            processed_audio.convert_to_wav()
        if processed_audio.sample_rate != 16000:
            processed_audio.resample_wav()
        if enhanced:
            parameters = processed_audio.optimize_enhancement_parameters()
            processed_audio.enhance_audio(
                noise_reduce_strength=parameters[0],
                voice_enhance_strength=parameters[1],
                volume_boost=parameters[2],
            )
        processed_audio.display_changes()
        return processed_audio
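    # Standalone usage sketch (hypothetical file name): the processing step can
    # be run on its own to convert and resample a clip before transcription:
    #   processor = Transcriptor().process_audio("interview.mp3", enhanced=True)
    #   print(processor.path, processor.sample_rate)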

    def perform_diarization(self, audio_file_path: str):
        """Perform speaker diarization on the audio file."""
        with torch.no_grad():
            return self.pipeline(audio_file_path)
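    # The returned object is a pyannote Annotation; it yields
    # (turn, track, speaker) triples, as consumed in transcribe_segments below:
    #   for turn, _, speaker in diarization.itertracks(yield_label=True):
    #       print(f"{speaker}: {turn.start:.2f}s -> {turn.end:.2f}s")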

    def transcribe_segments(self, audio, sr, duration, segments):
        """Transcribe audio segments based on the diarization."""
        transcriptions = []
        audio_segments = []
        for turn, _, speaker in segments:
            start = turn.start
            end = min(turn.end, duration)
            segment = audio[int(start * sr):int(end * sr)]
            audio_segments.append((segment, speaker))
        with tqdm(
            total=len(audio_segments),
            desc="Transcribing segments",
            unit="segment",
            colour="green",
            file=sys.stdout,
            mininterval=0.1,
            dynamic_ncols=True,
            leave=True,
        ) as pbar:
            if self.device == "cuda" and self.model_size == "large-v3-turbo":
                # Batched GPU path: only the Hugging Face ASR pipeline accepts a
                # list of raw arrays; plain Whisper models are handled below.
                try:
                    total_memory = torch.cuda.get_device_properties(0).total_memory
                    reserved_memory = torch.cuda.memory_reserved(0)
                    allocated_memory = torch.cuda.memory_allocated(0)
                    free_memory = total_memory - reserved_memory - allocated_memory
                    memory_per_sample = 1024 * 1024 * 1024  # rough estimate: 1 GB per segment
                    batch_size = max(1, min(4, int((free_memory * 0.7) // memory_per_sample)))
                    print(f"Using batch size of {batch_size} for GPU processing")
                    for i in range(0, len(audio_segments), batch_size):
                        try:
                            batch = audio_segments[i:i + batch_size]
                            torch.cuda.empty_cache()
                            results = self.model([segment for segment, _ in batch])
                            for (_, speaker), result in zip(batch, results):
                                transcriptions.append((speaker, result["text"].strip()))
                            pbar.update(len(batch))
                        except RuntimeError as e:
                            if "out of memory" in str(e):
                                # Retry the failed batch one segment at a time.
                                torch.cuda.empty_cache()
                                for segment, speaker in batch:
                                    results = self.model([segment])
                                    transcriptions.append((speaker, results[0]["text"].strip()))
                                    pbar.update(1)
                            else:
                                raise
                except Exception as e:
                    print(f"GPU processing failed: {str(e)}. Falling back to CPU processing...")
                    self.device = "cpu"
                    self.model.model.to("cpu")
                    self.model.device = torch.device("cpu")
                    # Resume on CPU with the segments that were not transcribed yet.
                    for segment, speaker in audio_segments[len(transcriptions):]:
                        result = self.model(segment)
                        transcriptions.append((speaker, result["text"].strip()))
                        pbar.update(1)
            else:
                for segment, speaker in audio_segments:
                    if self.model_size == "large-v3-turbo":
                        result = self.model(segment)
                    else:
                        result = self.model.transcribe(segment, fp16=(self.device == "cuda"))
                    transcriptions.append((speaker, result["text"].strip()))
                    pbar.update(1)
        return transcriptions
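

# Minimal end-to-end sketch (hypothetical paths; assumes HF_TOKEN is set in a
# .env file and access to pyannote/speaker-diarization-3.1 has been granted on
# the Hugging Face Hub):
if __name__ == "__main__":
    transcriptor = Transcriptor(model_size="base")
    transcription = transcriptor.transcribe_audio("path/to/audio.wav")
    transcription.save("path/to/transcripts")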