# NOTE: import order is kept as in the original file; native-library loading
# (ctranslate2 / torch) can be order-sensitive on some platforms.
from faster_whisper import WhisperModel
from transformers import pipeline
from pydub import AudioSegment
import os
import torchaudio
import torch
import re
import time
import sys
from pathlib import Path
import glob
import ctypes
import numpy as np
from settings import DEBUG_MODE, MODEL_PATH_V2_FAST, MODEL_PATH_V1, LEFT_CHANNEL_TEMP_PATH, RIGHT_CHANNEL_TEMP_PATH, RESAMPLING_FREQ, BATCH_SIZE, TASK

# Patterns used by the transcription clean-up helpers, compiled once.
_TOKEN_RE = re.compile(r'\b\w+\'?\w*\b[.,!?]?')
_REPEATED_CHUNK_RE = re.compile(r"(\w{1,3})(\1{2,})")
_WHITESPACE_RE = re.compile(r'\s+')
_SPEAKER_TAG_RE = re.compile(r'(\[SPEAKER_\d{2}\])')


def load_cudnn():
    """Make the cuDNN shared libraries visible to the process.

    Does nothing when CUDA is unavailable. On Windows the ``lib`` directory
    next to the torch package is registered as a DLL directory. On Linux every
    ``libcudnn_cnn*.so*`` found under ``<site-packages>/nvidia/cudnn/lib`` is
    loaded with ``RTLD_GLOBAL``. Failures are only reported (when
    ``DEBUG_MODE`` is set), never raised.
    """
    if not torch.cuda.is_available():
        if DEBUG_MODE:
            print("[INFO] CUDA is not available, skipping cuDNN setup.")
        return

    if DEBUG_MODE:
        print(f"[INFO] sys.platform: {sys.platform}")

    if sys.platform == "win32":
        torch_lib_dir = Path(torch.__file__).parent / "lib"
        if torch_lib_dir.exists():
            os.add_dll_directory(str(torch_lib_dir))
            if DEBUG_MODE:
                print(f"[INFO] Added DLL directory: {torch_lib_dir}")
        elif DEBUG_MODE:
            print(f"[WARNING] Torch lib directory not found: {torch_lib_dir}")
    elif sys.platform == "linux":
        # The directory that contains the torch package (normally site-packages).
        site_packages = Path(torch.__file__).resolve().parents[1]
        cudnn_dir = site_packages / "nvidia" / "cudnn" / "lib"
        if not cudnn_dir.exists():
            if DEBUG_MODE:
                print(f"[ERROR] cudnn dir not found: {cudnn_dir}")
            return

        matching_files = sorted(glob.glob(str(cudnn_dir / "libcudnn_cnn*.so*")))
        if not matching_files:
            if DEBUG_MODE:
                print(f"[ERROR] No libcudnn_cnn*.so* found in {cudnn_dir}")
            return

        for so_path in matching_files:
            try:
                ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL)
            except OSError as e:
                # Best effort: one library failing must not stop the others.
                if DEBUG_MODE:
                    print(f"[WARNING] Failed to load {so_path}: {e}")
            else:
                if DEBUG_MODE:
                    print(f"[INFO] Loaded: {so_path}")
    elif DEBUG_MODE:
        print("[WARNING] sys.platform is not win32 or linux")


def get_settings():
    """Return ``(device, compute_type)`` for model loading.

    ``device`` is ``"cuda"`` when torch reports CUDA as available, otherwise
    ``"cpu"``. ``compute_type`` is always ``"default"``.
    """
    if DEBUG_MODE:
        print("Entering get_settings function...")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    compute_type = "default"

    if DEBUG_MODE:
        print(f"[SETTINGS] Device: {device}")
        print("Exited get_settings function.")
    return device, compute_type


def load_model(use_v2_fast, device, compute_type):
    """Load the transcription model.

    Args:
        use_v2_fast: When truthy, build a faster-whisper ``WhisperModel`` from
            ``MODEL_PATH_V2_FAST``; otherwise build a transformers
            ``automatic-speech-recognition`` pipeline from ``MODEL_PATH_V1``.
        device: Device string passed to the model constructor.
        compute_type: Compute type passed to ``WhisperModel`` (unused by the
            pipeline branch).

    Returns:
        The ``WhisperModel`` instance or the transformers pipeline.
    """
    if DEBUG_MODE:
        print("Entering load_model function...")
        print(f"[MODEL LOADING] use_v2_fast: {use_v2_fast}")

    if use_v2_fast:
        model = WhisperModel(
            MODEL_PATH_V2_FAST,
            device=device,
            compute_type=compute_type,
        )
    else:
        model = pipeline(
            task="automatic-speech-recognition",
            model=MODEL_PATH_V1,
            chunk_length_s=30,
            device=device,
            token=os.getenv("HF_TOKEN"),
        )

    if DEBUG_MODE:
        print("Exiting load_model function...")
    return model


def split_input_stereo_channels(audio_path):
    """Split a stereo ``.wav``/``.mp3`` file into two mono WAV temp files.

    The two channels are written to ``RIGHT_CHANNEL_TEMP_PATH`` and
    ``LEFT_CHANNEL_TEMP_PATH``.

    Raises:
        ValueError: If the extension is not ``.wav``/``.mp3`` or the audio
            does not have exactly two channels.
    """
    if DEBUG_MODE:
        print("Entering split_input_stereo_channels function...")

    ext = os.path.splitext(audio_path)[1].lower()
    if ext == ".wav":
        audio = AudioSegment.from_wav(audio_path)
    elif ext == ".mp3":
        audio = AudioSegment.from_file(audio_path, format="mp3")
    else:
        raise ValueError(f"[FORMAT AUDIO] Unsupported file format for: {audio_path}")

    channels = audio.split_to_mono()
    if len(channels) != 2:
        raise ValueError(f"[FORMAT AUDIO] Audio {audio_path} has {len(channels)} channels (instead of 2).")

    # NOTE(review): pydub documents split_to_mono() as returning the left
    # channel at index 0. The mapping below (index 0 -> RIGHT path) is kept
    # as-is because changing it would swap the speaker labels downstream;
    # confirm which mapping is intended.
    channels[0].export(RIGHT_CHANNEL_TEMP_PATH, format="wav")  # Right
    channels[1].export(LEFT_CHANNEL_TEMP_PATH, format="wav")  # Left

    if DEBUG_MODE:
        print("Exited split_input_stereo_channels function.")


def compute_type_to_audio_dtype(compute_type: str, device: str) -> np.dtype:
    """Pick the numpy dtype used for the audio array.

    Returns ``np.float16`` only when ``device`` starts with ``"cuda"`` and the
    (case-insensitive) ``compute_type`` contains ``"float16"`` or ``"int8"``;
    ``np.float32`` in every other case.
    """
    if DEBUG_MODE:
        print("Entering compute_type_to_audio_dtype function.")

    normalized = compute_type.lower()
    wants_half = "float16" in normalized or "int8" in normalized
    audio_np_dtype = np.float16 if device.startswith("cuda") and wants_half else np.float32

    if DEBUG_MODE:
        print("Exited compute_type_to_audio_dtype function.")
    return audio_np_dtype


def format_audio(audio_path: str, compute_type: str, device: str) -> np.ndarray:
    """Load an audio file as a mono numpy array resampled to ``RESAMPLING_FREQ``.

    Multi-channel input is averaged down to one channel, resampled, and cast
    to the dtype chosen by ``compute_type_to_audio_dtype``.

    Args:
        audio_path: Path of the file handed to ``torchaudio.load``.
        compute_type: Compute type used to choose the output dtype.
        device: Device string used to choose the output dtype.

    Returns:
        A 1-D numpy array of samples.
    """
    if DEBUG_MODE:
        print("Entering format_audio function...")

    input_audio, sample_rate = torchaudio.load(audio_path)

    # Down-mix any multi-channel input (not only stereo) so that the result
    # is always a single channel and the squeeze below yields a 1-D array.
    if input_audio.shape[0] > 1:
        input_audio = torch.mean(input_audio, dim=0, keepdim=True)

    resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=RESAMPLING_FREQ)
    input_audio = resampler(input_audio)
    # Drop only the channel axis; a bare squeeze() would also collapse a
    # one-sample signal to a 0-d tensor.
    input_audio = input_audio.squeeze(0)

    np_dtype = compute_type_to_audio_dtype(compute_type, device)
    input_audio = input_audio.numpy().astype(np_dtype)

    if DEBUG_MODE:
        print(f"[FORMAT AUDIO] Audio dtype for actual_compute_type: {input_audio.dtype}")
        print("Exited format_audio function.")
    return input_audio


def process_waveforms(device: str, compute_type: str):
    """Load both temp channel files and return ``(left_waveform, right_waveform)``."""
    if DEBUG_MODE:
        print("Entering process_waveforms function...")

    left_waveform = format_audio(LEFT_CHANNEL_TEMP_PATH, compute_type, device)
    right_waveform = format_audio(RIGHT_CHANNEL_TEMP_PATH, compute_type, device)

    if DEBUG_MODE:
        print("Exited process_waveforms function.")
    return left_waveform, right_waveform


def transcribe_pipeline(audio, model):
    """Run the transformers pipeline on ``audio`` and return its ``"text"`` field."""
    if DEBUG_MODE:
        print("Entering transcribe_pipeline function.")

    result = model(
        audio,
        batch_size=BATCH_SIZE,
        generate_kwargs={"task": TASK},
        return_timestamps=True,
    )
    text = result["text"]

    if DEBUG_MODE:
        print("Exited transcribe_pipeline function.")
    return text


def transcribe_channels(left_waveform, right_waveform, model):
    """Transcribe both channels with ``model.transcribe`` and return two segment lists.

    The first element of each ``transcribe`` result is materialised into a
    list; the second element is discarded.
    """
    if DEBUG_MODE:
        print("Entering transcribe_channels function...")

    # NOTE(review): the task is hard-coded here while transcribe_pipeline uses
    # the TASK setting; confirm whether that difference is intended.
    left_segments, _ = model.transcribe(left_waveform, beam_size=5, task="transcribe")
    right_segments, _ = model.transcribe(right_waveform, beam_size=5, task="transcribe")
    left_result = list(left_segments)
    right_result = list(right_segments)

    if DEBUG_MODE:
        print("Exited transcribe_channels function.")
    return left_result, right_result


# TODO refactor and rename this function
def post_process_transcription(transcription, max_repeats=2):
    """Tokenise a transcription and drop excessive consecutive repetitions.

    Args:
        transcription: Raw transcription text.
        max_repeats: Maximum number of identical consecutive tokens to keep.

    Returns:
        The kept tokens joined by single spaces, whitespace-normalised.
    """
    cleaned_tokens = []
    repetition_count = 0
    previous_token = None

    for token in _TOKEN_RE.findall(transcription):
        # NOTE(review): the whole repeated run (including its first
        # occurrence) is deleted, not collapsed to one copy; confirm intent.
        reduced_token = _REPEATED_CHUNK_RE.sub("", token)
        if reduced_token == previous_token:
            repetition_count += 1
            if repetition_count <= max_repeats:
                cleaned_tokens.append(reduced_token)
        else:
            repetition_count = 1
            cleaned_tokens.append(reduced_token)
        previous_token = reduced_token

    cleaned_transcription = " ".join(cleaned_tokens)
    return _WHITESPACE_RE.sub(' ', cleaned_transcription).strip()


# TODO not used right now, decide to use it or not
def post_merge_consecutive_segments_from_text(transcription_text: str) -> str:
    """Merge consecutive ``[SPEAKER_NN]`` segments spoken by the same speaker.

    The text is split on ``[SPEAKER_NN]`` tags; runs of segments with the same
    two-digit speaker id are joined with spaces into one
    ``[SPEAKER_NN] text`` line. Text before the first tag is ignored.
    """
    segments = _SPEAKER_TAG_RE.split(transcription_text)
    merged_lines = []
    current_speaker = None
    current_segment = []

    def flush():
        # Emit the run collected so far, if any.
        if current_speaker is not None:
            merged_lines.append(f"[SPEAKER_{current_speaker}] {' '.join(current_segment)}")

    # With a capturing group, split() yields [prefix, tag, text, tag, text, ...].
    for i in range(1, len(segments) - 1, 2):
        speaker = re.search(r'\d{2}', segments[i]).group()
        text = segments[i + 1].strip()
        if speaker == current_speaker:
            current_segment.append(text)
        else:
            flush()
            current_speaker = speaker
            current_segment = [text]
    flush()

    return "\n".join(merged_lines).strip()


def get_segments(result, speaker_label):
    """Convert transcription segments into ``(start, end, speaker_label, text)`` tuples.

    Segments whose ``text`` is falsy are skipped; the text of the others is
    stripped and passed through ``post_process_transcription``.
    """
    if DEBUG_MODE:
        print("Entering get_segments function...")

    final_segments = [
        (seg.start, seg.end, speaker_label, post_process_transcription(seg.text.strip()))
        for seg in result
        if seg.text
    ]

    if DEBUG_MODE:
        print("EXited get_segments function.")
    return final_segments


def post_process_transcripts(left_result, right_result):
    """Interleave both channels' segments by start time into one transcript.

    Left segments are labelled ``Speaker 1`` and right segments ``Speaker 2``.
    Segments with a ``None`` start time sort last.

    Returns:
        One ``[speaker]: text`` line per segment, newline separated.
    """
    if DEBUG_MODE:
        print("Entering post_process_transcripts function...")

    left_segs = get_segments(left_result, "Speaker 1")
    right_segs = get_segments(right_result, "Speaker 2")

    merged_transcript = sorted(
        left_segs + right_segs,
        key=lambda seg: float(seg[0]) if seg[0] is not None else float("inf"),
    )

    clean_output = "\n".join(
        f"[{speaker}]: {text}" for _start, _end, speaker, text in merged_transcript
    ).strip()

    if DEBUG_MODE:
        print("Exited post_process_transcripts function.")
    return clean_output


def cleanup_temp_files(*file_paths):
    """Remove every given path that is non-empty and currently exists."""
    if DEBUG_MODE:
        print("Entered cleanup_temp_files function...")

    for path in file_paths:
        if path and os.path.exists(path):
            if DEBUG_MODE:
                print(f"Removing path: {path}")
            os.remove(path)

    if DEBUG_MODE:
        print("Exited cleanup_temp_files function.")


def generate(audio_path, use_v2_fast):
    """Transcribe ``audio_path`` and return the cleaned transcript text.

    Args:
        audio_path: Path to the input audio file.
        use_v2_fast: When truthy, split the stereo file into two channels and
            transcribe each with faster-whisper, producing a speaker-labelled
            transcript. Otherwise transcribe the (down-mixed) file with the
            transformers pipeline.

    Returns:
        The post-processed transcript string.
    """
    if DEBUG_MODE:
        print("Entering generate function...")

    start = time.time()

    load_cudnn()
    device, requested_compute_type = get_settings()
    model = load_model(use_v2_fast, device, requested_compute_type)

    if use_v2_fast:
        actual_compute_type = model.model.compute_type
    else:
        actual_compute_type = "float32"  # HF pipeline safe default

    if DEBUG_MODE:
        print(f"[SETTINGS] Requested compute_type: {requested_compute_type}")
        print(f"[SETTINGS] Actual compute_type: {actual_compute_type}")

    if use_v2_fast:
        try:
            split_input_stereo_channels(audio_path)
            left_waveform, right_waveform = process_waveforms(device, actual_compute_type)
            left_result, right_result = transcribe_channels(left_waveform, right_waveform, model)
            output = post_process_transcripts(left_result, right_result)
        finally:
            # Always remove the temp channel files, even when a step above
            # raised; cleanup_temp_files skips paths that do not exist.
            cleanup_temp_files(LEFT_CHANNEL_TEMP_PATH, RIGHT_CHANNEL_TEMP_PATH)
    else:
        audio = format_audio(audio_path, actual_compute_type, device)
        merged_results = transcribe_pipeline(audio, model)
        output = post_process_transcription(merged_results)

    end = time.time()

    if DEBUG_MODE:
        # The timing metrics only feed debug output, so the metadata is read
        # once and only when it will actually be printed.
        elapsed = end - start
        print(f"[LATENCY]: {elapsed}")
        metadata = torchaudio.info(audio_path)
        if metadata.num_frames > 0 and metadata.sample_rate > 0:
            audio_duration = metadata.num_frames / metadata.sample_rate
            rtf = elapsed / audio_duration
            print(f"[RTF]: {rtf:.2f}")
        else:
            # A reported frame count of 0 would otherwise divide by zero.
            print("[RTF]: unavailable (audio duration reported as 0)")
        print("Exited generate function.")

    return output