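# Hugging Face Space page header (not part of the program):
# author avatar "DineshJ96's picture", commit message "transcript_file_updated",
# commit 275f6d2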
import gradio as gr
import os
import torch
import gc
import json
import whisperx
from pyannote.audio import Pipeline
from huggingface_hub import HfFolder
from transformers import pipeline
import numpy as np
import soundfile as sf
import io
import tempfile
# --- Configuration ---
# Hugging Face access token, read from the environment. It is passed to the
# diarization pipeline when the models are loaded.
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    print(
        "WARNING: HF_TOKEN environment variable not set. Please set it as a Space secret or directly for local testing.",
        "Visit https://huggingface.co/settings/tokens to create one and accept model conditions for pyannote/speaker-diarization, etc.",
        sep="\n",
    )

# Run on the GPU with half precision when CUDA is available; otherwise fall
# back to the CPU with 8-bit quantization.
device, compute_type = ("cuda", "float16") if torch.cuda.is_available() else ("cpu", "int8")

# Whisper checkpoint size. 'large-v2' is the most accurate but also the most
# resource intensive; 'small' or 'medium' are better suited to free tiers.
whisper_model_size = "medium"

# --- Global Models (loaded once) ---
# Placeholders that load_all_models() fills in.
whisper_model_global = diarize_pipeline_global = translation_pipeline_global = None
def load_all_models():
    """Load the WhisperX, diarization and translation models into module globals.

    Populates ``whisper_model_global``, ``diarize_pipeline_global`` and
    ``translation_pipeline_global``. The translation model is best-effort:
    if it cannot be loaded, ``translation_pipeline_global`` is set to ``None``
    and a message is printed instead of raising.

    Raises:
        ValueError: If ``HF_TOKEN`` is not set.
    """
    global whisper_model_global, diarize_pipeline_global, translation_pipeline_global

    # Fail fast: this function refuses to build the diarization pipeline
    # without a token, so check before spending time and memory on Whisper.
    if not HF_TOKEN:
        raise ValueError("Hugging Face token (HF_TOKEN) not set. Please set it as a Space secret.")

    print(f"Loading WhisperX model ({whisper_model_size})...")
    whisper_model_global = whisperx.load_model(whisper_model_size, device=device, compute_type=compute_type)

    print("Loading Pyannote Diarization Pipeline...")
    # Import the submodule explicitly: attribute access on the package
    # (``whisperx.diarize``) only works if something has already imported it.
    from whisperx.diarize import DiarizationPipeline
    diarize_pipeline_global = DiarizationPipeline(use_auth_token=HF_TOKEN, device=device)

    print("Loading translation model (Helsinki-NLP/opus-mt-ta-en)...")
    try:
        translation_pipeline_global = pipeline(
            "translation",
            model="Helsinki-NLP/opus-mt-ta-en",
            device=0 if device == "cuda" else -1,
        )
    except Exception as e:
        # Translation is optional; the app keeps working without it.
        print(f"Could not load translation model: {e}")
        translation_pipeline_global = None
# Load all models once, at import time, before the Gradio interface below is
# built. Raises ValueError here if HF_TOKEN is not set.
load_all_models()
_WHISPER_SAMPLE_RATE = 16000  # Sample rate (Hz) of every temp WAV written below.


def _resample_to_whisper_rate(waveform, sample_rate, source_label):
    """Return *waveform* resampled to 16 kHz (unchanged if already at 16 kHz)."""
    if sample_rate == _WHISPER_SAMPLE_RATE:
        return waveform
    print(f"Warning: {source_label} sample rate is {sample_rate}Hz. Resampling to 16kHz.")
    # Imported lazily: only needed when resampling is actually required.
    from torchaudio.transforms import Resample
    resampler = Resample(orig_freq=sample_rate, new_freq=_WHISPER_SAMPLE_RATE)
    waveform_tensor = torch.from_numpy(np.ascontiguousarray(waveform)).float()
    return resampler(waveform_tensor).numpy()


def _write_temp_wav(waveform):
    """Write *waveform* to a new 16 kHz PCM-16 temp WAV file and return its path."""
    fd, wav_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)  # soundfile opens the path itself; don't leak the descriptor.
    try:
        sf.write(wav_path, waveform, _WHISPER_SAMPLE_RATE, format='WAV', subtype='PCM_16')
    except Exception:
        # Don't leave an empty or partial temp file behind when writing fails.
        try:
            os.unlink(wav_path)
        except OSError:
            pass
        raise
    return wav_path


def convert_audio_for_whisper(audio_input):
    """Convert Gradio audio input to a 16 kHz mono WAV temp file.

    Args:
        audio_input: Either a file path (``gr.Audio(type="filepath")``) or a
            ``(sample_rate, numpy_array)`` tuple (``gr.Audio(type="numpy")``).

    Returns:
        Path of the temporary WAV file (the caller is responsible for deleting
        it), or ``None`` if the input type is unsupported, the audio is empty,
        or the conversion fails. Failures are printed, never raised.
    """
    if isinstance(audio_input, str):  # Filepath from gr.Audio(type="filepath")
        try:
            waveform, sample_rate = sf.read(audio_input)
            if waveform.size == 0:
                raise ValueError("audio contains no samples")
            if waveform.ndim > 1:
                waveform = waveform.mean(axis=1)  # Down-mix to mono.
            waveform = _resample_to_whisper_rate(waveform, sample_rate, "Audio")
            return _write_temp_wav(waveform)
        except Exception as e:
            print(f"Error converting uploaded audio: {e}")
            return None

    if isinstance(audio_input, tuple):  # (sr, numpy_array) from gr.Audio(type="numpy")
        try:
            sample_rate, samples = audio_input
            samples = np.asarray(samples)
            if samples.size == 0:
                raise ValueError("audio contains no samples")
            if samples.ndim > 1:
                samples = samples.mean(axis=1)  # Down-mix to mono.
            if samples.dtype != np.float32:
                # Cast first, then take the peak: np.abs on an integer array
                # overflows at the dtype minimum (e.g. int16 -32768).
                samples = samples.astype(np.float32)
                peak = float(np.max(np.abs(samples)))
                # Silent input has a zero peak; dividing would produce NaNs.
                if peak > 0:
                    samples = samples / peak
            samples = _resample_to_whisper_rate(samples, sample_rate, "Microphone audio")
            return _write_temp_wav(samples)
        except Exception as e:
            print(f"Error converting microphone audio: {e}")
            return None

    return None
def _contains_tamil(text):
    """Return True if *text* has a code point in the Tamil block (U+0B80-U+0BFF)."""
    return any(0x0B80 <= ord(char) <= 0x0BFF for char in text)


def _group_segments_by_speaker(segments):
    """Group diarized segments per speaker and build the chronological display lines."""
    speaker_transcripts = {}
    display_lines = []
    for segment in segments:
        speaker_id = segment.get("speaker", "UNKNOWN_SPEAKER")
        text = segment["text"].strip()
        start = segment["start"]
        end = segment["end"]
        speaker_transcripts.setdefault(speaker_id, []).append(
            {"start": start, "end": end, "text": text}
        )
        display_lines.append(f"[{start:.2f}s - {end:.2f}s] Speaker {speaker_id}: {text}")
    return speaker_transcripts, display_lines


def _translate_speaker_segments(speaker_transcripts, detected_language):
    """Translate Tamil segments to English; return (per-speaker data, display lines)."""
    translated_speaker_data = {}
    display_lines = []
    for speaker, segments in speaker_transcripts.items():
        translated_speaker_data[speaker] = []
        display_lines.append(f"\n--- Speaker {speaker} (Original & Translated) ---")
        for seg in segments:
            original_text = seg['text']
            # Fall back to the original text when no translation is attempted
            # or the translation fails.
            translated_text_output = original_text
            if original_text and (detected_language == 'ta' or _contains_tamil(original_text)):
                try:
                    translated_result = translation_pipeline_global(original_text, src_lang="ta", tgt_lang="en")
                    translated_text_output = translated_result[0]['translation_text']
                except Exception as e:
                    print(f"Error translating segment for speaker {speaker}: '{original_text}'. Error: {e}. Keeping original text.")
            translated_speaker_data[speaker].append({
                "start": seg['start'],
                "end": seg['end'],
                "original_text": original_text,
                "translated_text": translated_text_output,
            })
            display_lines.append(f"[{seg['start']:.2f}s - {seg['end']:.2f}s] Original: {original_text}")
            display_lines.append(f" Translated: {translated_text_output}")
    return translated_speaker_data, display_lines


def _write_transcript_file(speaker_transcripts, translated_speaker_data, detected_language):
    """Write the downloadable transcript to a temp .txt file and return its path.

    *translated_speaker_data* is ``None`` when no translation was performed.
    """
    with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False, encoding="utf-8") as f:
        f.write("--- Speaker-wise Original Transcription ---\n\n")
        for speaker, segments in speaker_transcripts.items():
            f.write(f"\n### Speaker {speaker} ###\n")
            for seg in segments:
                f.write(f"[{seg['start']:.2f}s - {seg['end']:.2f}s] {seg['text']}\n")
        f.write("\n\n--- Speaker-wise Translated Transcription (to English) ---\n\n")
        if translated_speaker_data is not None:
            for speaker, segments in translated_speaker_data.items():
                f.write(f"\n### Speaker {speaker} ###\n")
                for seg in segments:
                    f.write(f"[{seg['start']:.2f}s - {seg['end']:.2f}s] Original: {seg['original_text']}\n")
                    f.write(f" Translated: {seg['translated_text']}\n")
        else:
            f.write("Translation output not available or translation model not loaded.\n")
        f.write(f"\n\nOverall Detected Language: {detected_language}")
        return f.name


def process_audio_for_web(audio_input):
    """Diarize, transcribe and translate an uploaded or recorded audio clip.

    Args:
        audio_input: Value produced by the Gradio audio component (a file path
            or a ``(sample_rate, numpy_array)`` tuple), or ``None``.

    Returns:
        A 4-tuple ``(diarized_text, translated_text, language_message,
        transcript_path)``. On any failure the first element carries the
        error message, the next two are empty strings and the path is ``None``.
    """
    import traceback

    if audio_input is None:
        return "Please upload an audio file or record from microphone.", "", "", None
    audio_file_path = convert_audio_for_whisper(audio_input)
    if not audio_file_path:
        return "Error: Could not process audio input. Please ensure it's a valid audio format.", "", "", None
    print(f"Processing audio from temp file: {audio_file_path}")
    try:
        audio = whisperx.load_audio(audio_file_path)

        # 1. Transcribe
        print("Transcribing audio...")
        transcription_result = whisper_model_global.transcribe(audio, batch_size=1)
        detected_language = transcription_result["language"]
        print(f"Detected overall language: {detected_language}")

        # 2. Align
        print("Aligning transcription with audio...")
        try:
            align_model_local, align_metadata = whisperx.load_align_model(
                language_code=detected_language, device=device
            )
        except Exception as e:
            print(f"Error loading alignment model for language '{detected_language}': {e}")
            print(traceback.format_exc())
            return f"Error: Could not load alignment model for language '{detected_language}'. Alignment is typically supported for English, French, German, Spanish, Italian, Japanese, Chinese, Dutch, and Portuguese. Details: {e}", "", "", None
        # whisperx.align takes (segments, model, metadata, audio, device);
        # leaving out the metadata shifts every later argument by one.
        transcription_result = whisperx.align(
            transcription_result["segments"],
            align_model_local,
            align_metadata,
            audio,
            device,
            return_char_alignments=False,
        )
        # Drop the last reference so the collection below can actually free
        # the alignment model before diarization runs.
        del align_model_local
        gc.collect()
        if device == "cuda":
            torch.cuda.empty_cache()

        # 3. Diarize
        print("Performing speaker diarization...")
        diarize_segments = diarize_pipeline_global(audio_file_path)
        final_result = whisperx.assign_word_speakers(diarize_segments, transcription_result)
        speaker_transcripts_raw, diarized_display_lines = _group_segments_by_speaker(final_result["segments"])
        full_diarized_text_str = "\n".join(diarized_display_lines)

        # 4. Translate
        translated_speaker_data = None
        if translation_pipeline_global:
            translated_speaker_data, translated_display_lines = _translate_speaker_segments(
                speaker_transcripts_raw, detected_language
            )
            translated_output_str = "\n".join(translated_display_lines)
        else:
            translated_output_str = "Translation model not loaded. Skipping translation."

        # 5. Write the downloadable transcript
        output_filename = _write_transcript_file(
            speaker_transcripts_raw, translated_speaker_data, detected_language
        )
        return full_diarized_text_str, translated_output_str, f"Detected overall language: {detected_language}", output_filename
    except Exception as e:
        error_message = f"An error occurred: {e}\n{traceback.format_exc()}"
        print(error_message)
        return error_message, "", "", None
    finally:
        # Remove the temporary WAV on every exit path, including the early
        # return taken when the alignment model cannot be loaded.
        try:
            if os.path.exists(audio_file_path):
                os.unlink(audio_file_path)
        except OSError as cleanup_error:
            print(f"Warning: could not remove temp file {audio_file_path}: {cleanup_error}")
# --- Gradio Interface ---
# Layout: audio input, a "Process Audio" button, then the result widgets.
# NOTE(review): the nesting below is reconstructed from the statement order;
# confirm it matches the intended layout.
with gr.Blocks(title="Language-Agnostic Speaker Diarization, Transcription, and Translation") as demo:
    gr.Markdown(
        """
# Language-Agnostic Speaker Diarization, Transcription, and Translation
Upload an audio file (WAV, MP3, etc.) or record directly from your microphone.
The system will identify speakers, transcribe their speech (in detected language),
and provide an English translation for relevant segments.
"""
    )
    with gr.Row():
        audio_input = gr.Audio(
            type="filepath",
            sources=["upload", "microphone"],
            label="Upload Audio File or Record from Microphone",
        )
    with gr.Row():
        process_button = gr.Button("Process Audio", variant="primary")
    with gr.Column():
        detected_language_output = gr.Textbox(label="Detected Overall Language")
        # Chronological transcript with speaker labels.
        diarized_transcription_output = gr.Textbox(label="Diarized Transcription (Chronological with Speaker Labels)", lines=10, interactive=False)
        # Original and translated text, grouped per speaker.
        translated_transcription_output = gr.Textbox(label="Translated Transcription (to English, per Speaker)", lines=10, interactive=False)
        # Kept visible: returning a file path from the handler sets the value
        # but does not un-hide a component created with visible=False, so the
        # transcript could never be downloaded.
        download_button = gr.File(label="Download Transcription (.txt)", interactive=False)

    # Output order must match the 4-tuple returned by process_audio_for_web.
    result_outputs = [
        diarized_transcription_output,
        translated_transcription_output,
        detected_language_output,
        download_button,
    ]
    process_button.click(
        fn=process_audio_for_web,
        inputs=audio_input,
        outputs=result_outputs,
    )
    gr.Examples(
        [
            # Add paths to your example audio files here.
            # These files must be present in your Hugging Face Space repository.
            # For example, if you have 'sample_two_speakers.wav' in your repo:
            # "sample_two_speakers.wav"
        ],
        inputs=audio_input,
        outputs=result_outputs,
        fn=process_audio_for_web,
        cache_examples=False,
    )

demo.launch()