Spaces:
Sleeping
Sleeping
Fix transcription errors by improving audio segment handling

- Add checks for empty audio segments to avoid creating invalid files
- Pad very short audio segments to ensure Whisper compatibility
- Use explicit WAV format with PCM_16 subtype for better compatibility
- Add error handling around transcription to gracefully handle segment errors
5bff499
| import os | |
| import torch | |
| import gradio as gr | |
| from pathlib import Path | |
| import numpy as np | |
| from pyannote.audio import Pipeline | |
| from transformers import pipeline | |
| from datasets import Dataset | |
| import librosa | |
| import soundfile as sf | |
| from datetime import timedelta | |
| from transformers.utils import is_flash_attn_2_available | |
class DiarizationTranscriptionTranslation:
    """End-to-end pipeline: speaker diarization -> Whisper transcription ->
    translation of non-English segments into English."""

    def __init__(self):
        # Prefer GPU when available; every sub-pipeline follows this choice.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")
        # Load models for diarization, transcription, and translation.
        self.load_models()

    def load_models(self):
        """Load the ASR, translation, and (optionally) diarization models.

        Sets ``self.asr_pipeline``, ``self.translation_pipeline`` and
        ``self.diarization_pipeline`` (the latter may be None when the
        pyannote model cannot be fetched, triggering the fallback in
        ``diarize_audio``).
        """
        print("Loading models...")
        # whisper-base: much smaller and faster to download than the large
        # variants while keeping acceptable accuracy for short segments.
        model_name = "openai/whisper-base"
        self.asr_pipeline = pipeline(
            "automatic-speech-recognition",
            model=model_name,
            torch_dtype=torch.float32,  # float32 for best compatibility with the base model
            device="cuda:0" if self.device == "cuda" else "cpu",
        )
        # Multilingual -> English translation model.
        self.translation_pipeline = pipeline(
            "translation",
            model="Helsinki-NLP/opus-mt-mul-en",
            device=0 if self.device == "cuda" else -1,
        )
        # pyannote models normally require Hugging Face authentication; on
        # Spaces the token is injected automatically when configured in the
        # space settings.
        try:
            self.diarization_pipeline = Pipeline.from_pretrained(
                "pyannote/speaker-diarization@2.1",
                use_auth_token=True,
            )
            print("Diarization model loaded successfully!")
        except Exception as e:
            print(f"Could not load diarization model with auth: {e}")
            try:
                # Retry without auth (public model or warm local cache).
                self.diarization_pipeline = Pipeline.from_pretrained(
                    "pyannote/speaker-diarization@2.1"
                )
                print("Diarization model loaded successfully from cache!")
            except Exception as e2:
                print(f"Could not load diarization model: {e2}")
                print("Using fallback diarization method")
                self.diarization_pipeline = None
        print("Models loaded successfully!")

    def load_audio(self, file_path, sr=16000):
        """Load an audio file as a mono waveform, resampling when requested.

        Returns ``(waveform, effective_sample_rate)``.

        BUGFIX: the previous version always returned the *original* sample
        rate even after resampling to ``sr``, so callers doing sample-index
        arithmetic (e.g. ``process_audio``) sliced the resampled waveform
        with the wrong rate for any file not already at ``sr``.
        """
        audio, orig_sr = librosa.load(file_path, sr=None)
        if sr is not None and orig_sr != sr:
            audio = librosa.resample(audio, orig_sr=orig_sr, target_sr=sr)
            return audio, sr
        return audio, orig_sr

    def diarize_audio(self, file_path):
        """Return speaker segments as ``{"start", "end", "speaker"}`` dicts.

        Uses pyannote when the model loaded; otherwise falls back to fixed
        5-second windows with two alternating pseudo-speaker labels.
        """
        if self.diarization_pipeline is None:
            # Fallback: slice the file into <=5 s windows, alternating
            # between "Speaker 1" and "Speaker 2".
            duration = librosa.get_duration(path=file_path)
            segments = []
            current_time = 0
            segment_duration = min(5.0, duration)  # 5 second segments max
            speaker_id = 1
            while current_time < duration:
                end_time = min(current_time + segment_duration, duration)
                segments.append({
                    "start": current_time,
                    "end": end_time,
                    "speaker": f"Speaker {speaker_id}",
                })
                current_time = end_time
                speaker_id = 2 if speaker_id == 1 else 1  # alternate speakers
            return segments
        # Proper diarization via pyannote.audio.
        from pyannote.audio import Audio
        # Audio().read returns a (waveform, sample_rate) pair.
        audio = Audio().read(file_path)
        diarization = self.diarization_pipeline(
            {"waveform": audio[0], "sample_rate": audio[1]}
        )
        segments = []
        for segment, track, speaker in diarization.itertracks(yield_label=True):
            segments.append({
                "start": segment.start,
                "end": segment.end,
                "speaker": speaker,
            })
        return segments

    def transcribe_audio(self, file_path):
        """Transcribe an audio file with the chunked/batched Whisper pipeline.

        Returns the pipeline's result dict (``{"text": ...}``) on success,
        or a dict carrying an error message on failure — never raises.
        """
        try:
            # The ASR pipeline accepts file paths directly; chunking and a
            # large batch size trade memory for throughput.
            result = self.asr_pipeline(
                file_path,
                chunk_length_s=30,
                batch_size=24,  # larger batches for faster processing
                max_new_tokens=128,
                return_timestamps=True,
                generate_kwargs={"task": "transcribe"},  # explicit task for speed
            )
            return result
        except Exception as e:
            print(f"Error during transcription: {str(e)}")
            return {"text": f"Transcription error: {str(e)}"}

    def detect_language(self, text):
        """Best-effort language code detection; defaults to ``"en"``.

        BUGFIX: empty text previously hit a ZeroDivisionError inside the
        ImportError fallback (``non_english_chars / len(text)``); it is now
        short-circuited to English.
        """
        if not text:
            return "en"  # nothing to classify
        try:
            from langdetect import detect
            return detect(text)
        except ImportError:
            # Crude heuristic when langdetect is unavailable: a high ratio
            # of non-ASCII characters suggests non-English content. A real
            # language-ID model would be preferable.
            non_english_chars = sum(1 for c in text if ord(c) > 127)
            if non_english_chars / len(text) > 0.1:
                # Demo placeholder: report Spanish for any non-English text.
                return "es"
            return "en"
        except Exception:
            # langdetect raises on very short/ambiguous input; fall back to
            # English rather than failing the whole segment.
            return "en"

    def translate_text(self, text):
        """Translate *text* to English; return it unchanged if already English
        or if translation fails."""
        try:
            detected_lang = self.detect_language(text)
            if detected_lang != "en":
                result = self.translation_pipeline(text)
                return result[0]['translation_text']
            return text
        except Exception as e:
            print(f"Translation error: {str(e)}")
            # Best effort: surface the original text rather than crashing.
            return text

    def process_audio(self, file_path):
        """Main pipeline: diarize, then transcribe and translate each segment.

        Returns a formatted multi-line transcript string, or an error
        message string on failure.
        """
        if not file_path:
            return "Please upload an audio file."
        try:
            # Step 1: diarize the audio into speaker segments.
            diarization_segments = self.diarize_audio(file_path)
            # Load the waveform ONCE. (Previously it was reloaded inside the
            # loop for every segment — O(segments * file_size) — and sliced
            # with the pre-resampling rate, yielding wrong boundaries.)
            audio, sample_rate = self.load_audio(file_path)
            # Step 2: transcribe and translate each segment.
            results = []
            for segment in diarization_segments:
                start_sample = int(segment["start"] * sample_rate)
                end_sample = int(segment["end"] * sample_rate)
                segment_audio = audio[start_sample:end_sample]
                if len(segment_audio) == 0:
                    continue  # skip empty segments
                # Pad clips shorter than 0.1 s with silence; Whisper needs a
                # minimum amount of audio to work with.
                min_samples = int(sample_rate * 0.1)
                if len(segment_audio) < min_samples:
                    segment_audio = np.pad(
                        segment_audio,
                        (0, min_samples - len(segment_audio)),
                        mode='constant',
                    )
                # Write the segment to a temp WAV (explicit PCM_16 so every
                # downstream reader can open it).
                temp_file = f"temp_segment_{segment['start']}_{segment['end']}.wav"
                sf.write(temp_file, segment_audio, sample_rate,
                         format='WAV', subtype='PCM_16')
                try:
                    transcription_result = self.transcribe_audio(temp_file)
                    # Normalise the two possible return shapes.
                    if isinstance(transcription_result, dict) and "text" in transcription_result:
                        transcribed_text = transcription_result["text"]
                    elif isinstance(transcription_result, str):
                        transcribed_text = transcription_result
                    else:
                        transcribed_text = str(transcription_result)
                except Exception as e:
                    print(f"Error transcribing segment {temp_file}: {str(e)}")
                    # Continue with the error message as the transcription.
                    transcribed_text = f"Transcription error: {str(e)}"
                finally:
                    # Always clean up the temp file, even on error.
                    try:
                        os.remove(temp_file)
                    except OSError:
                        pass
                translated_text = self.translate_text(transcribed_text)
                results.append({
                    "start": segment["start"],
                    "end": segment["end"],
                    "speaker": segment["speaker"],
                    "transcription": transcribed_text,
                    "translation": translated_text,
                })
            # Step 3: format as a readable, timestamped transcript.
            transcript = []
            for result in results:
                start_time = str(timedelta(seconds=int(result["start"])))
                end_time = str(timedelta(seconds=int(result["end"])))
                # Show both original and translation only when they differ.
                if result["transcription"] != result["translation"]:
                    transcript.append(
                        f"[{start_time} - {end_time}] {result['speaker']}:\n"
                        f"  Original: {result['transcription']}\n"
                        f"  Translation: {result['translation']}\n"
                    )
                else:
                    transcript.append(
                        f"[{start_time} - {end_time}] {result['speaker']}: {result['translation']}"
                    )
            return "\n".join(transcript)
        except Exception as e:
            return f"Processing error: {str(e)}"
# Build the shared processor once at import time so all Gradio requests
# reuse the already-loaded models.
processor = DiarizationTranscriptionTranslation()


def process_audio_file(audio_file):
    """Gradio callback: validate the uploaded path and run the pipeline."""
    if audio_file is None:
        return "Please upload an audio file."
    return processor.process_audio(audio_file)
# Create the Gradio UI: a single audio-file upload in, a formatted
# speaker-labelled transcript out.
interface = gr.Interface(
    fn=process_audio_file,
    inputs=gr.Audio(type="filepath", label="Upload Audio File"),
    outputs=gr.Textbox(label="Transcript with Speaker Labels", lines=20),
    title="Speaker Diarization, Transcription & Translation",
    description="""
    This Space combines three powerful speech processing capabilities:
    1. Speaker Diarization - Distinguishes between different speakers in your audio
    2. Speech Transcription - Converts spoken words into text using Whisper
    3. Automatic Translation - Translates non-English content to English
    Upload an audio file (MP3, WAV, or other common formats) and get a timestamped transcript with speaker labels.
    """,
    examples=[
        # Add example files if available
    ],
    # No examples are shipped, so there is nothing to cache.
    cache_examples=False
)
if __name__ == "__main__":
    # Bind to all interfaces on port 7860 — the standard HF Spaces setup.
    interface.launch(server_name="0.0.0.0", server_port=7860)