# diarization / app.py
# Commit 5bff499 (andrijdavid): Fix transcription errors by improving audio
# segment handling:
# - Add checks for empty audio segments to avoid creating invalid files
# - Pad very short audio segments to ensure Whisper compatibility
# - Use explicit WAV format with PCM_16 subtype for better compatibility
# - Add error handling around transcription to gracefully handle segment errors
import os
import torch
import gradio as gr
from pathlib import Path
import numpy as np
from pyannote.audio import Pipeline
from transformers import pipeline
from datasets import Dataset
import librosa
import soundfile as sf
from datetime import timedelta
from transformers.utils import is_flash_attn_2_available
class DiarizationTranscriptionTranslation:
    """End-to-end speech pipeline: speaker diarization, Whisper transcription
    of each speaker turn, and translation of non-English text to English."""

    def __init__(self):
        # Prefer GPU when available; every sub-pipeline follows this choice.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")
        # Load models for diarization, transcription, and translation
        self.load_models()

    def load_models(self):
        """Load the ASR, translation, and diarization models."""
        print("Loading models...")
        # whisper-base is much smaller than the large variants: faster download,
        # less disk usage, still reasonable quality for a demo Space.
        model_name = "openai/whisper-base"
        self.asr_pipeline = pipeline(
            "automatic-speech-recognition",
            model=model_name,
            torch_dtype=torch.float32,  # float32 for best compatibility with the base model
            device="cuda:0" if self.device == "cuda" else "cpu",
        )
        # Many-to-English translation model.
        self.translation_pipeline = pipeline(
            "translation",
            model="Helsinki-NLP/opus-mt-mul-en",
            device=0 if self.device == "cuda" else -1,
        )
        # pyannote models are gated on the Hub: try with an auth token first
        # (Spaces inject it automatically when configured), then from local
        # cache, and finally fall back to a naive fixed-window diarizer.
        try:
            self.diarization_pipeline = Pipeline.from_pretrained(
                "pyannote/speaker-diarization@2.1",
                use_auth_token=True,
            )
            print("Diarization model loaded successfully!")
        except Exception as e:
            print(f"Could not load diarization model with auth: {e}")
            try:
                self.diarization_pipeline = Pipeline.from_pretrained(
                    "pyannote/speaker-diarization@2.1"
                )
                print("Diarization model loaded successfully from cache!")
            except Exception as e2:
                print(f"Could not load diarization model: {e2}")
                print("Using fallback diarization method")
                self.diarization_pipeline = None
        print("Models loaded successfully!")

    def load_audio(self, file_path, sr=16000):
        """Load an audio file as a mono waveform.

        Returns ``(audio, sample_rate)`` where ``sample_rate`` is the rate the
        returned samples are actually at.

        BUGFIX: the original returned the file's *native* rate even after
        resampling the waveform to ``sr``, so every downstream sample-index
        computation and WAV write used the wrong rate for any file whose
        native rate was not already 16 kHz.
        """
        audio, orig_sr = librosa.load(file_path, sr=None)
        if sr is not None and orig_sr != sr:
            audio = librosa.resample(audio, orig_sr=orig_sr, target_sr=sr)
            return audio, sr
        return audio, orig_sr

    def diarize_audio(self, file_path):
        """Return a list of ``{"start", "end", "speaker"}`` dicts for file_path."""
        if self.diarization_pipeline is None:
            # Fallback: chop the file into <=5 s windows and alternate between
            # two pseudo-speakers.  Not real diarization, but it keeps the
            # demo usable when the pyannote model could not be loaded.
            duration = librosa.get_duration(path=file_path)
            segments = []
            current_time = 0.0
            segment_duration = min(5.0, duration)  # 5 second segments max
            speaker_id = 1
            while current_time < duration:
                end_time = min(current_time + segment_duration, duration)
                segments.append({
                    "start": current_time,
                    "end": end_time,
                    "speaker": f"Speaker {speaker_id}",
                })
                current_time = end_time
                speaker_id = 2 if speaker_id == 1 else 1  # Alternate speakers
            return segments
        # pyannote pipelines accept a file path directly and handle loading
        # themselves.  (The original called Audio().read(), which is not part
        # of the public pyannote.audio API.)
        diarization = self.diarization_pipeline(file_path)
        return [
            {"start": turn.start, "end": turn.end, "speaker": speaker}
            for turn, _track, speaker in diarization.itertracks(yield_label=True)
        ]

    def transcribe_audio(self, file_path):
        """Transcribe one audio file with the optimized Whisper pipeline.

        Returns the pipeline's result dict (always contains "text"), or a dict
        with an error message in "text" when transcription fails.
        """
        try:
            # The ASR pipeline works directly on file paths.  Chunking and a
            # large batch size follow the insanely-fast-whisper recipe.
            result = self.asr_pipeline(
                file_path,
                chunk_length_s=30,
                batch_size=24,  # Increased batch size for faster processing
                max_new_tokens=128,
                return_timestamps=True,
                generate_kwargs={"task": "transcribe"},  # Explicit task avoids language re-detection
            )
            return result
        except Exception as e:
            print(f"Error during transcription: {str(e)}")
            return {"text": f"Transcription error: {str(e)}"}

    def detect_language(self, text):
        """Best-effort language detection; returns an ISO code, "en" by default."""
        if not text:
            # Nothing to detect (also avoids a division by zero in the
            # heuristic below).
            return "en"
        try:
            from langdetect import detect
            return detect(text)
        except ImportError:
            # langdetect not installed: crude heuristic on the fraction of
            # non-ASCII characters.  A real deployment should install
            # langdetect or use a proper language-ID model.
            non_english_chars = sum(1 for c in text if ord(c) > 127)
            if non_english_chars / len(text) > 0.1:
                return "es"  # Spanish as example
            return "en"
        except Exception:
            # Detection failed (e.g. langdetect on very short text); default
            # to English so the text passes through untranslated.
            return "en"

    def translate_text(self, text):
        """Translate non-English text to English; pass English text through."""
        try:
            detected_lang = self.detect_language(text)
            if detected_lang != "en":
                result = self.translation_pipeline(text)
                return result[0]['translation_text']
            return text
        except Exception as e:
            print(f"Translation error: {str(e)}")
            # Best effort: fall back to the original text.
            return text

    def process_audio(self, file_path):
        """Main pipeline: diarization -> per-segment transcription -> translation.

        Returns a human-readable transcript string, or an error message.
        """
        if not file_path:
            return "Please upload an audio file."
        try:
            # Step 1: diarize the whole file.
            diarization_segments = self.diarize_audio(file_path)
            # Load the audio ONCE (the original reloaded the full file for
            # every segment, which was O(segments * file length)).
            audio, sample_rate = self.load_audio(file_path)
            min_samples = int(sample_rate * 0.1)  # Whisper needs >= ~0.1 s
            # Step 2: transcribe and translate each speaker turn.
            results = []
            for segment in diarization_segments:
                start_sample = int(segment["start"] * sample_rate)
                end_sample = int(segment["end"] * sample_rate)
                segment_audio = audio[start_sample:end_sample]
                if len(segment_audio) == 0:
                    continue  # Skip empty segments
                # Pad very short segments with trailing silence.
                if len(segment_audio) < min_samples:
                    segment_audio = np.pad(
                        segment_audio,
                        (0, min_samples - len(segment_audio)),
                        mode='constant',
                    )
                # Whisper consumes a file, so write the slice to a temp WAV;
                # explicit PCM_16 subtype for maximum decoder compatibility.
                temp_file = f"temp_segment_{segment['start']}_{segment['end']}.wav"
                sf.write(temp_file, segment_audio, sample_rate,
                         format='WAV', subtype='PCM_16')
                try:
                    transcription_result = self.transcribe_audio(temp_file)
                    if isinstance(transcription_result, dict) and "text" in transcription_result:
                        transcribed_text = transcription_result["text"]
                    elif isinstance(transcription_result, str):
                        transcribed_text = transcription_result
                    else:
                        transcribed_text = str(transcription_result)
                    translated_text = self.translate_text(transcribed_text)
                finally:
                    # Remove the temp file even if transcription/translation
                    # raised (the original leaked it on errors).
                    try:
                        os.remove(temp_file)
                    except OSError:
                        pass
                results.append({
                    "start": segment["start"],
                    "end": segment["end"],
                    "speaker": segment["speaker"],
                    "transcription": transcribed_text,
                    "translation": translated_text,
                })
            # Step 3: format as a readable, timestamped transcript.
            transcript = []
            for result in results:
                start_time = str(timedelta(seconds=int(result["start"])))
                end_time = str(timedelta(seconds=int(result["end"])))
                # Show the original alongside the translation when they differ.
                if result["transcription"] != result["translation"]:
                    transcript.append(
                        f"[{start_time} - {end_time}] {result['speaker']}:\n"
                        f" Original: {result['transcription']}\n"
                        f" Translation: {result['translation']}\n"
                    )
                else:
                    transcript.append(
                        f"[{start_time} - {end_time}] {result['speaker']}: {result['translation']}"
                    )
            return "\n".join(transcript)
        except Exception as e:
            return f"Processing error: {str(e)}"
# Initialize the processor
# NOTE: instantiated at import time, so all models are downloaded/loaded when
# the Space starts rather than on the first user request.
processor = DiarizationTranscriptionTranslation()
def process_audio_file(audio_file):
    """Gradio callback: validate the upload and delegate to the shared processor.

    ``audio_file`` is a filesystem path (or None when nothing was uploaded);
    returns the formatted transcript string.
    """
    return ("Please upload an audio file." if audio_file is None
            else processor.process_audio(audio_file))
# Create Gradio interface
# type="filepath" makes Gradio hand the callback a path on disk, which is the
# input processor.process_audio expects.
interface = gr.Interface(
    fn=process_audio_file,
    inputs=gr.Audio(type="filepath", label="Upload Audio File"),
    outputs=gr.Textbox(label="Transcript with Speaker Labels", lines=20),
    title="Speaker Diarization, Transcription & Translation",
    description="""
    This Space combines three powerful speech processing capabilities:
    1. Speaker Diarization - Distinguishes between different speakers in your audio
    2. Speech Transcription - Converts spoken words into text using Whisper
    3. Automatic Translation - Translates non-English content to English
    Upload an audio file (MP3, WAV, or other common formats) and get a timestamped transcript with speaker labels.
    """,
    examples=[
        # Add example files if available
    ],
    cache_examples=False
)
if __name__ == "__main__":
    # 0.0.0.0:7860 is the conventional host/port for Hugging Face Spaces.
    interface.launch(server_name="0.0.0.0", server_port=7860)