# Hugging Face Space: speaker diarization, transcription, and translation app.
import gradio as gr
import os
import torch
import gc
import json
import whisperx
from pyannote.audio import Pipeline
from huggingface_hub import HfFolder
from transformers import pipeline
import numpy as np
import soundfile as sf
import io
import tempfile
# --- Configuration ---
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    print("WARNING: HF_TOKEN environment variable not set. Please set it as a Space secret or directly for local testing.")
    print("Visit https://huggingface.co/settings/tokens to create one and accept model conditions for pyannote/speaker-diarization, etc.")

# Prefer the GPU when one is present; int8 keeps CPU inference affordable.
device = "cuda" if torch.cuda.is_available() else "cpu"
compute_type = "float16" if device == "cuda" else "int8"

# 'large-v2' is best but most resource intensive.
# 'small' or 'medium' are better for free tiers.
whisper_model_size = "medium"

# --- Global Models (loaded once) ---
whisper_model_global = None
diarize_pipeline_global = None
translation_pipeline_global = None
def load_all_models():
    """Load the WhisperX, diarization, and translation models into module globals.

    Raises:
        ValueError: if HF_TOKEN is unset, since the diarization pipeline
            requires an authenticated Hugging Face token.
    """
    global whisper_model_global, diarize_pipeline_global, translation_pipeline_global

    print(f"Loading WhisperX model ({whisper_model_size})...")
    whisper_model_global = whisperx.load_model(
        whisper_model_size, device=device, compute_type=compute_type
    )

    print("Loading Pyannote Diarization Pipeline...")
    if not HF_TOKEN:
        raise ValueError("Hugging Face token (HF_TOKEN) not set. Please set it as a Space secret.")
    diarize_pipeline_global = whisperx.diarize.DiarizationPipeline(
        use_auth_token=HF_TOKEN, device=device
    )

    print("Loading translation model (Helsinki-NLP/opus-mt-ta-en)...")
    try:
        translation_pipeline_global = pipeline(
            "translation",
            model="Helsinki-NLP/opus-mt-ta-en",
            device=0 if device == "cuda" else -1,
        )
    except Exception as err:
        # Translation is optional: the app degrades gracefully without it.
        print(f"Could not load translation model: {err}")
        translation_pipeline_global = None


# Load models when the Gradio app starts
load_all_models()
def convert_audio_for_whisper(audio_input):
    """
    Convert Gradio audio input (filepath or (sr, numpy_array)) to a 16kHz mono
    16-bit PCM WAV file that WhisperX expects.

    Parameters
    ----------
    audio_input : str | tuple[int, numpy.ndarray] | None
        Either a path from gr.Audio(type="filepath") or a
        (sample_rate, samples) tuple from microphone / numpy input.

    Returns
    -------
    str | None
        Path to a temporary WAV file (the caller is responsible for deleting
        it), or None if the input could not be converted.
    """
    if isinstance(audio_input, str):  # Filepath from gr.Audio(type="filepath")
        temp_wav_path = tempfile.NamedTemporaryFile(suffix=".wav", delete=False).name
        try:
            waveform, sample_rate = sf.read(audio_input)
            if waveform.ndim > 1:
                waveform = waveform.mean(axis=1)  # Convert to mono if stereo
            # Resample only if necessary
            if sample_rate != 16000:
                print(f"Warning: Audio sample rate is {sample_rate}Hz. Resampling to 16kHz.")
                # torchaudio gives higher-quality resampling than naive methods.
                from torchaudio.transforms import Resample
                waveform_tensor = torch.from_numpy(waveform).float()
                resampler = Resample(orig_freq=sample_rate, new_freq=16000)
                waveform = resampler(waveform_tensor).numpy()
                sample_rate = 16000  # Update sample rate after resampling
            sf.write(temp_wav_path, waveform, 16000, format='WAV', subtype='PCM_16')
            return temp_wav_path
        except Exception as e:
            print(f"Error converting uploaded audio: {e}")
            # BUGFIX: remove the temp file created with delete=False so a
            # failed conversion does not leak it on disk.
            if os.path.exists(temp_wav_path):
                os.unlink(temp_wav_path)
            return None
    elif isinstance(audio_input, tuple):  # (sr, numpy_array) from gr.Audio(type="numpy") or microphone
        sample_rate, numpy_array = audio_input
        # Ensure it's mono
        if numpy_array.ndim > 1:
            numpy_array = numpy_array.mean(axis=1)
        # Peak-normalize to float32 in [-1, 1] (soundfile expects float data).
        if numpy_array.dtype != np.float32:
            numpy_array = numpy_array.astype(np.float32)
            # BUGFIX: guard against all-zero (silent) input, which previously
            # divided by zero and produced NaNs in the written WAV.
            peak = np.max(np.abs(numpy_array)) if numpy_array.size else 0.0
            if peak > 0:
                numpy_array = numpy_array / peak
        # Resample only if necessary for microphone input as well
        if sample_rate != 16000:
            print(f"Warning: Microphone audio sample rate is {sample_rate}Hz. Resampling to 16kHz.")
            from torchaudio.transforms import Resample
            waveform_tensor = torch.from_numpy(numpy_array).float()
            resampler = Resample(orig_freq=sample_rate, new_freq=16000)
            numpy_array = resampler(waveform_tensor).numpy()
            sample_rate = 16000  # Update sample rate after resampling
        temp_wav_path = tempfile.NamedTemporaryFile(suffix=".wav", delete=False).name
        try:
            sf.write(temp_wav_path, numpy_array, 16000, format='WAV', subtype='PCM_16')  # Always write at 16kHz
            return temp_wav_path
        except Exception as e:
            print(f"Error writing microphone audio to temp file: {e}")
            # BUGFIX: same temp-file cleanup as the filepath branch.
            if os.path.exists(temp_wav_path):
                os.unlink(temp_wav_path)
            return None
    return None
def process_audio_for_web(audio_input):
    """
    Process one audio input (upload or microphone) end-to-end: transcription
    (WhisperX), word alignment, speaker diarization, and per-segment
    Tamil->English translation.

    Parameters
    ----------
    audio_input : str | tuple[int, numpy.ndarray] | None
        Gradio audio value (filepath or (sample_rate, samples) tuple).

    Returns
    -------
    tuple
        (diarized_transcription, translated_transcription, language_info,
        transcript_txt_path_or_None). On error the first element carries the
        error message and the last element is None.
    """
    if audio_input is None:
        return "Please upload an audio file or record from microphone.", "", "", None
    audio_file_path = convert_audio_for_whisper(audio_input)
    if not audio_file_path:
        return "Error: Could not process audio input. Please ensure it's a valid audio format.", "", "", None
    print(f"Processing audio from temp file: {audio_file_path}")
    try:
        audio = whisperx.load_audio(audio_file_path)

        # 1. Transcribe
        print("Transcribing audio...")
        transcription_result = whisper_model_global.transcribe(audio, batch_size=1)
        detected_language = transcription_result["language"]
        print(f"Detected overall language: {detected_language}")

        # 2. Align words to precise timestamps.
        print("Aligning transcription with audio...")
        align_model_local = None  # initialized so the outer except can't hit UnboundLocalError
        try:
            # The alignment model is language-specific; 'device' is passed here.
            align_model_local, metadata = whisperx.load_align_model(language_code=detected_language, device=device)
        except Exception as e:
            print(f"Error loading alignment model for language '{detected_language}': {e}")
            import traceback
            print(traceback.format_exc())
            return f"Error: Could not load alignment model for language '{detected_language}'. Alignment is typically supported for English, French, German, Spanish, Italian, Japanese, Chinese, Dutch, and Portuguese. Details: {e}", "", "", None

        # BUGFIX: whisperx.align takes (segments, model, align_model_metadata,
        # audio, device, ...). 'metadata' was previously omitted, shifting
        # 'audio' and 'device' into the wrong positional slots.
        transcription_result = whisperx.align(transcription_result["segments"], align_model_local, metadata, audio, device, return_char_alignments=False)
        gc.collect()
        if device == "cuda":
            torch.cuda.empty_cache()

        # 3. Diarize and attach speaker labels to aligned segments.
        print("Performing speaker diarization...")
        diarize_segments = diarize_pipeline_global(audio_file_path)
        final_result = whisperx.assign_word_speakers(diarize_segments, transcription_result)

        speaker_transcripts_raw = {}   # speaker -> list of {start, end, text}
        diarized_display_lines = []    # chronological view shown in the UI
        for segment in final_result["segments"]:
            speaker_id = segment.get("speaker", "UNKNOWN_SPEAKER")
            text = segment["text"].strip()
            start = segment["start"]
            end = segment["end"]
            speaker_transcripts_raw.setdefault(speaker_id, []).append({
                "start": start,
                "end": end,
                "text": text
            })
            diarized_display_lines.append(f"[{start:.2f}s - {end:.2f}s] Speaker {speaker_id}: {text}")
        full_diarized_text_str = "\n".join(diarized_display_lines)

        # 4. Translate Tamil segments to English, grouped per speaker.
        translated_display_lines = []
        translated_speaker_data = {}  # translated segments per speaker (for the .txt export)
        if translation_pipeline_global:
            for speaker, segments in speaker_transcripts_raw.items():
                translated_speaker_data[speaker] = []
                translated_display_lines.append(f"\n--- Speaker {speaker} (Original & Translated) ---")
                for seg in segments:
                    original_text = seg['text']
                    translated_text_output = original_text
                    # BUGFIX: the Tamil Unicode block is U+0B80..U+0BFF
                    # inclusive; the old strict comparisons excluded both
                    # block boundaries.
                    is_tamil_char_present = any(0x0B80 <= ord(char) <= 0x0BFF for char in original_text)
                    if original_text and (detected_language == 'ta' or is_tamil_char_present):
                        try:
                            translated_result = translation_pipeline_global(original_text, src_lang="ta", tgt_lang="en")
                            translated_text_output = translated_result[0]['translation_text']
                        except Exception as e:
                            # Best effort: keep the original text on failure.
                            print(f"Error translating segment for speaker {speaker}: '{original_text}'. Error: {e}. Keeping original text.")
                    translated_speaker_data[speaker].append({
                        "start": seg['start'],
                        "end": seg['end'],
                        "original_text": original_text,
                        "translated_text": translated_text_output
                    })
                    translated_display_lines.append(f"[{seg['start']:.2f}s - {seg['end']:.2f}s] Original: {original_text}")
                    translated_display_lines.append(f"  Translated: {translated_text_output}")
            translated_output_str = "\n".join(translated_display_lines)
        else:
            translated_output_str = "Translation model not loaded. Skipping translation."

        # 5. Write a downloadable .txt transcript (original + translated).
        output_filename = tempfile.NamedTemporaryFile(suffix=".txt", delete=False).name
        with open(output_filename, "w", encoding="utf-8") as f:
            f.write("--- Speaker-wise Original Transcription ---\n\n")
            for speaker, segments in speaker_transcripts_raw.items():
                f.write(f"\n### Speaker {speaker} ###\n")
                for seg in segments:
                    f.write(f"[{seg['start']:.2f}s - {seg['end']:.2f}s] {seg['text']}\n")
            f.write("\n\n--- Speaker-wise Translated Transcription (to English) ---\n\n")
            if translation_pipeline_global:
                for speaker, segments in translated_speaker_data.items():
                    f.write(f"\n### Speaker {speaker} ###\n")
                    for seg in segments:
                        f.write(f"[{seg['start']:.2f}s - {seg['end']:.2f}s] Original: {seg['original_text']}\n")
                        f.write(f"  Translated: {seg['translated_text']}\n")
            else:
                f.write("Translation output not available or translation model not loaded.\n")
            f.write(f"\n\nOverall Detected Language: {detected_language}")

        return full_diarized_text_str, translated_output_str, f"Detected overall language: {detected_language}", output_filename
    except Exception as e:
        import traceback
        error_message = f"An error occurred: {e}\n{traceback.format_exc()}"
        print(error_message)
        return error_message, "", "", None
    finally:
        # BUGFIX: always remove the temporary audio file — this also covers
        # the alignment-failure early return, which previously leaked it.
        if os.path.exists(audio_file_path):
            os.unlink(audio_file_path)
# --- Gradio Interface ---
with gr.Blocks(title="Language-Agnostic Speaker Diarization, Transcription, and Translation") as demo:
    # Static header describing what the app does.
    gr.Markdown(
        """
# Language-Agnostic Speaker Diarization, Transcription, and Translation
Upload an audio file (WAV, MP3, etc.) or record directly from your microphone.
The system will identify speakers, transcribe their speech (in detected language),
and provide an English translation for relevant segments.
"""
    )

    # Input: a single audio widget accepting both upload and microphone capture.
    with gr.Row():
        source_audio = gr.Audio(
            type="filepath",
            sources=["upload", "microphone"],
            label="Upload Audio File or Record from Microphone"
        )
    with gr.Row():
        run_button = gr.Button("Process Audio", variant="primary")

    # Outputs: language info, chronological diarized text, per-speaker
    # translation, and a hidden download slot for the .txt transcript.
    with gr.Column():
        language_box = gr.Textbox(label="Detected Overall Language")
        diarized_box = gr.Textbox(label="Diarized Transcription (Chronological with Speaker Labels)", lines=10, interactive=False)
        translated_box = gr.Textbox(label="Translated Transcription (to English, per Speaker)", lines=10, interactive=False)
        transcript_file = gr.File(label="Download Transcription (.txt)", interactive=False, visible=False)

    run_button.click(
        fn=process_audio_for_web,
        inputs=source_audio,
        outputs=[diarized_box, translated_box, language_box, transcript_file]
    )

    gr.Examples(
        [
            # Add paths to example audio files here; they must be present in
            # the Hugging Face Space repository, e.g. "sample_two_speakers.wav".
        ],
        inputs=source_audio,
        outputs=[diarized_box, translated_box, language_box, transcript_file],
        fn=process_audio_for_web,
        cache_examples=False
    )

demo.launch()