Spaces: Build error
import whisperx
import torch
import numpy as np
from scipy.signal import resample
from pyannote.audio import Pipeline
import os
import logging
import time
from difflib import SequenceMatcher
from dotenv import load_dotenv
import spaces

# Load environment variables (e.g. HF_TOKEN) from a local .env file if present.
load_dotenv()

# Token for gated Hugging Face models (required by the pyannote diarization pipeline).
hf_token = os.getenv("HF_TOKEN")

# Chunking parameters in seconds; audio is handled at a 16 kHz sample rate throughout.
CHUNK_LENGTH = 30
OVERLAP = 2

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
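
# Sketch (an assumption about intent, not taken from this file): on a ZeroGPU
# Space, the `spaces` import above is usually paired with the @spaces.GPU
# decorator on the GPU-bound entry point, roughly:
#
#   @spaces.GPU
#   def process_audio(...):
#       ...
#
# The process_audio defined below does not use the decorator.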
def load_whisper_model(model_size="small"):
    """Load a WhisperX model, preferring GPU (float16) and falling back to CPU (int8)."""
    logger.info(f"Loading Whisper model (size: {model_size})...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    compute_type = "float16" if device == "cuda" else "int8"
    try:
        model = whisperx.load_model(model_size, device, compute_type=compute_type)
        logger.info(f"Whisper model loaded successfully on {device}")
        return model
    except RuntimeError as e:
        logger.warning(f"Failed to load Whisper model on {device}. Falling back to CPU. Error: {str(e)}")
        device = "cpu"
        compute_type = "int8"
        model = whisperx.load_model(model_size, device, compute_type=compute_type)
        logger.info("Whisper model loaded successfully on CPU")
        return model
def load_diarization_pipeline():
    """Load the pyannote speaker-diarization pipeline (a gated model; requires a valid HF_TOKEN)."""
    logger.info("Loading diarization pipeline...")
    try:
        pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token=hf_token)
        if torch.cuda.is_available():
            pipeline = pipeline.to(torch.device("cuda"))
        logger.info("Diarization pipeline loaded successfully")
        return pipeline
    except Exception as e:
        logger.warning(f"Diarization pipeline initialization failed: {str(e)}. Diarization will not be available.")
        return None
def preprocess_audio(audio, chunk_size=CHUNK_LENGTH*16000, overlap=OVERLAP*16000):
    """Split audio into fixed-size overlapping chunks, zero-padding the last one.

    Note: not currently called by process_audio, which uses split_audio instead.
    """
    chunks = []
    for i in range(0, len(audio), chunk_size - overlap):
        chunk = audio[i:i+chunk_size]
        if len(chunk) < chunk_size:
            chunk = np.pad(chunk, (0, chunk_size - len(chunk)))
        chunks.append(chunk)
    return chunks
def merge_nearby_segments(segments, time_threshold=0.5, similarity_threshold=0.7):
    """Merge a segment into the previous one when it starts within time_threshold seconds
    of the previous segment's end and most of its text repeats that segment's tail,
    keeping the overlapping words only once."""
    merged = []
    for segment in segments:
        if not merged or segment['start'] - merged[-1]['end'] > time_threshold:
            merged.append(segment)
        else:
            # Find the longest common substring between the previous text and this one.
            matcher = SequenceMatcher(None, merged[-1]['text'], segment['text'])
            match = matcher.find_longest_match(0, len(merged[-1]['text']), 0, len(segment['text']))
            # Guard against empty text to avoid a division by zero.
            if segment['text'] and match.size / len(segment['text']) > similarity_threshold:
                merged_text = merged[-1]['text'] + segment['text'][match.b + match.size:]
                merged_translated = merged[-1].get('translated', '') + segment.get('translated', '')[match.b + match.size:]
                merged[-1]['end'] = segment['end']
                merged[-1]['text'] = merged_text
                if 'translated' in segment:
                    merged[-1]['translated'] = merged_translated
            else:
                merged.append(segment)
    return merged
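
# Illustrative sketch (values made up, not from the original file): two adjacent
# segments whose texts overlap at the boundary. The second starts 0.2 s after the
# first ends and about 77% of its text repeats the first segment's tail, so the
# call below would return a single 0.0-8.0 s segment reading
# "welcome back to the show everyone today".
#
#   merge_nearby_segments([
#       {"start": 0.0, "end": 4.0, "text": "welcome back to the show everyone"},
#       {"start": 4.2, "end": 8.0, "text": "to the show everyone today"},
#   ])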
def get_most_common_speaker(diarization_result, start_time, end_time):
    """Return the speaker label whose diarization turns overlap [start_time, end_time] most often."""
    speakers = []
    for turn, _, speaker in diarization_result.itertracks(yield_label=True):
        if turn.start <= end_time and turn.end >= start_time:
            speakers.append(speaker)
    return max(set(speakers), key=speakers.count) if speakers else "Unknown"
def split_audio(audio, max_duration=30):
    """Split audio (16 kHz samples) into consecutive, non-overlapping chunks of at most max_duration seconds."""
    sample_rate = 16000
    max_samples = max_duration * sample_rate
    if len(audio) <= max_samples:
        return [audio]
    splits = []
    for i in range(0, len(audio), max_samples):
        splits.append(audio[i:i+max_samples])
    return splits
def process_audio(audio_file, translate=False, model_size="small", use_diarization=True):
    """Transcribe (and optionally translate) an audio file in 30-second splits,
    attaching a speaker label per segment when diarization is available.

    Returns (language_segments, merged_segments)."""
    logger.info(f"Starting audio processing: translate={translate}, model_size={model_size}, use_diarization={use_diarization}")
    start_time = time.time()
    try:
        whisper_model = load_whisper_model(model_size)
        audio = whisperx.load_audio(audio_file)
        audio_splits = split_audio(audio)

        # Run diarization once over the full recording, if requested and available.
        diarization_result = None
        if use_diarization:
            diarization_pipeline = load_diarization_pipeline()
            if diarization_pipeline is not None:
                try:
                    diarization_result = diarization_pipeline({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": 16000})
                except Exception as e:
                    logger.warning(f"Diarization failed: {str(e)}. Proceeding without diarization.")

        language_segments = []
        final_segments = []
        for i, audio_split in enumerate(audio_splits):
            logger.info(f"Processing split {i+1}/{len(audio_splits)}")
            result = whisper_model.transcribe(audio_split)
            lang = result["language"]
            for segment in result["segments"]:
                # Shift timestamps from split-relative to absolute (each split covers 30 s).
                segment_start = segment["start"] + (i * 30)
                segment_end = segment["end"] + (i * 30)
                speaker = "Unknown"
                if diarization_result is not None:
                    speaker = get_most_common_speaker(diarization_result, segment_start, segment_end)
                final_segment = {
                    "start": segment_start,
                    "end": segment_end,
                    "language": lang,
                    "speaker": speaker,
                    "text": segment["text"],
                }
                if translate:
                    # Re-run the segment's samples with task="translate"; WhisperX returns
                    # a dict of segments rather than a single "text" field, so join them.
                    translation = whisper_model.transcribe(audio_split[int(segment["start"]*16000):int(segment["end"]*16000)], task="translate")
                    final_segment["translated"] = " ".join(seg["text"].strip() for seg in translation["segments"])
                final_segments.append(final_segment)
            language_segments.append({
                "language": lang,
                "start": i * 30,
                "end": min((i + 1) * 30, len(audio) / 16000)
            })

        final_segments.sort(key=lambda x: x["start"])
        merged_segments = merge_nearby_segments(final_segments)
        end_time = time.time()
        logger.info(f"Total processing time: {end_time - start_time:.2f} seconds")
        return language_segments, merged_segments
    except Exception as e:
        logger.error(f"An error occurred during audio processing: {str(e)}")
        raise
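
# --- Usage sketch (added for illustration; "sample.wav" is a placeholder) ---
# Assumes a local audio file readable by ffmpeg and an HF_TOKEN with access to the
# gated pyannote/speaker-diarization model; this block is not part of the original Space.
if __name__ == "__main__":
    langs, segments = process_audio("sample.wav", translate=False, model_size="small", use_diarization=True)
    for seg in segments:
        print(f"[{seg['start']:7.2f}-{seg['end']:7.2f}] {seg['speaker']} ({seg['language']}): {seg['text'].strip()}")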