# Repository metadata (scraped from the hosting UI, converted to a comment so
# the file parses): commit eb74f8a by DineshJ96 — "app & req file updated".
# app.py
# A 100% OPEN-SOURCE audio processing application.
# - Local Whisper for Transcription
# - Local Pyannote for Diarization
# - Local Helsinki-NLP for Translation
import os
import torch
import gradio as gr
import numpy as np
import soundfile as sf
import torchaudio
from transformers import pipeline as hf_pipeline
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import tempfile
import logging
import warnings
from pyannote.audio import Pipeline as PyannotePipeline
from langdetect import detect, LangDetectException
# --- 1. Initial Setup & Configuration ---
# Silence the recurring torch.nn.functional UserWarnings emitted during inference.
warnings.filterwarnings("ignore", category=UserWarning, module='torch.nn.functional')
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Language name mapping
# Maps language codes (as produced by langdetect in process_audio) to
# human-readable names for display in the UI and the downloadable report.
# NOTE(review): "zh-cn" deviates from the plain two-letter pattern — langdetect
# reports simplified Chinese with a region suffix; verify other region codes
# (e.g. "zh-tw") are not needed by callers.
LANGUAGE_NAME_MAPPING = {
    "en": "English", "zh-cn": "Chinese", "de": "German", "es": "Spanish", "ru": "Russian",
    "ko": "Korean", "fr": "French", "ja": "Japanese", "pt": "Portuguese", "tr": "Turkish",
    "pl": "Polish", "ca": "Catalan", "nl": "Dutch", "ar": "Arabic", "sv": "Swedish",
    "it": "Italian", "id": "Indonesian", "hi": "Hindi", "fi": "Finnish", "vi": "Vietnamese",
    "he": "Hebrew", "uk": "Ukrainian", "el": "Greek", "ms": "Malay", "cs": "Czech",
    "ro": "Romanian", "da": "Danish", "hu": "Hungarian", "ta": "Tamil", "no": "Norwegian",
    "th": "Thai", "ur": "Urdu", "hr": "Croatian", "bg": "Bulgarian", "lt": "Lithuanian", "la": "Latin",
    "mi": "Maori", "ml": "Malayalam", "cy": "Welsh", "sk": "Slovak", "te": "Telugu", "pa": "Punjabi",
    "lv": "Latvian", "bn": "Bengali", "sr": "Serbian", "az": "Azerbaijani", "sl": "Slovenian",
    "kn": "Kannada", "et": "Estonian", "mk": "Macedonian", "br": "Breton", "eu": "Basque",
    "is": "Icelandic", "hy": "Armenian", "ne": "Nepali", "mn": "Mongolian", "bs": "Bosnian",
    "kk": "Kazakh", "sq": "Albanian", "sw": "Swahili", "gl": "Galician", "mr": "Marathi",
    "si": "Sinhala", "am": "Amharic", "yo": "Yoruba", "uz": "Uzbek", "af": "Afrikaans",
    "oc": "Occitan", "ka": "Georgian", "be": "Belarusian", "tg": "Tajik", "sd": "Sindhi",
    "gu": "Gujarati", "so": "Somali", "lo": "Lao", "yi": "Yiddish", "ky": "Kyrgyz",
    "tk": "Turkmen", "ht": "Haitian Creole", "ps": "Pashto", "as": "Assamese", "tt": "Tatar",
    "ha": "Hausa", "ba": "Bashkir", "jw": "Javanese", "su": "Sundanese"
}
def get_hf_token_instructions():
    """Return the Markdown help text for configuring the ``HF_TOKEN`` secret.

    Displayed in a collapsed accordion in the UI so users can enable
    pyannote speaker diarization on a Hugging Face Space.
    """
    instructions = """
**IMPORTANT: Authentication Required for Speaker Identification**
This feature uses the `pyannote/speaker-diarization-3.1` model, which requires a Hugging Face access token.
**How to Add Your Token:**
1. **Accept the model license:** Visit [pyannote/speaker-diarization-3.1](https://huggingface.co/pyannote/speaker-diarization-3.1) and agree to the terms.
2. **Get your token:** Find it in your Hugging Face account settings: [huggingface.co/settings/tokens](https://huggingface.co/settings/tokens).
3. **Add the token to this Space:** Go to the **Settings** tab, find **Repository secrets**, click **New secret**, and add a secret named `HF_TOKEN` with your token as the value. Restart the Space after saving.
"""
    return instructions
# --- 2. Global Model Loading (All Local) ---
# Pick GPU + fp16 when CUDA is available; otherwise CPU + fp32.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TORCH_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
logging.info(f"Using device: {DEVICE} with data type: {TORCH_DTYPE}")
# ASR Pipeline (Local Whisper)
# Loaded once at import time; process_audio() raises a gr.Error if this stays None.
ASR_PIPELINE = None
try:
    logging.info("Loading ASR pipeline (Whisper)...")
    ASR_PIPELINE = hf_pipeline(
        "automatic-speech-recognition",
        model="openai/whisper-large-v3",
        torch_dtype=TORCH_DTYPE,
        device=DEVICE,
    )
    logging.info("ASR pipeline loaded successfully.")
except Exception as e:
    # Logged rather than raised so the UI can still start and report the problem.
    logging.error(f"Fatal error: Could not load ASR pipeline. {e}")
# Speaker Diarization Pipeline (Local Pyannote)
# Gated model: requires an HF_TOKEN secret; without one, diarization is skipped
# and process_audio() labels everything SPEAKER_00.
HF_TOKEN = os.environ.get("HF_TOKEN")
DIARIZATION_PIPELINE = None
if HF_TOKEN:
    try:
        logging.info("Loading Speaker Diarization pipeline (v3.1)...")
        # NOTE(review): `use_auth_token` is the older parameter name; newer
        # pyannote/huggingface_hub releases prefer `token` — confirm against
        # the pinned pyannote.audio version before changing.
        DIARIZATION_PIPELINE = PyannotePipeline.from_pretrained("pyannote/speaker-diarization-3.1", use_auth_token=HF_TOKEN)
        DIARIZATION_PIPELINE.to(torch.device(DEVICE))
        logging.info("Speaker Diarization pipeline loaded successfully.")
    except Exception as e:
        logging.error(f"Failed to load Diarization pipeline. Check HF_TOKEN. Error: {e}")
else:
    logging.warning("HF_TOKEN not set. Speaker diarization will be disabled.")
# Translation Model Cache (Local Helsinki-NLP)
# Keyed by model name; populated lazily in process_audio() on first use of a
# given source language.
TRANSLATION_MODELS = {}
logging.info("Translation model cache initialized.")
# --- 3. Core Processing Functions ---
def load_and_resample_audio(audio_path, target_sr=16000):
    """Load an audio file as a mono 1-D numpy waveform at *target_sr* Hz.

    Parameters
    ----------
    audio_path : str
        Path to any audio file torchaudio can decode.
    target_sr : int, optional
        Output sample rate in Hz. Defaults to 16000, the rate the Whisper
        ASR pipeline expects.

    Returns
    -------
    numpy.ndarray
        1-D float waveform resampled to *target_sr*.

    Raises
    ------
    IOError
        If the file cannot be read or decoded; the original exception is
        chained so the root cause stays visible in tracebacks.
    """
    try:
        waveform, sample_rate = torchaudio.load(audio_path, channels_first=True)
        # Downmix multi-channel audio to mono by averaging across channels.
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        # Resample only when the source rate differs from the target.
        if sample_rate != target_sr:
            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)
            waveform = resampler(waveform)
        return waveform.squeeze(0).numpy()
    except Exception as e:
        raise IOError(f"Error processing audio file {audio_path}: {e}") from e
def _safe_word_span(word_info):
    """Return a ``(start, end)`` pair for one Whisper word chunk.

    Whisper's word-level timestamps occasionally report ``None`` for the end
    (and, rarely, the start) of the final word; fall back to sensible values
    so downstream comparisons and ``:.2f`` formatting do not crash.
    """
    start, end = word_info['timestamp']
    if start is None:
        start = 0.0
    if end is None:
        end = start
    return start, end


def _align_words_to_speakers(word_timestamps, diarization):
    """Assign a speaker label to every transcribed word.

    *diarization* is a pyannote annotation (or ``None`` when diarization is
    unavailable). A word fully contained in a speaker turn gets that turn's
    label; otherwise "Unknown". With no diarization every word is labeled
    "SPEAKER_00". Returns a list of ``{'start','end','text','speaker'}`` dicts.
    """
    segments = []
    if diarization:
        speaker_map = [
            {'start': turn.start, 'end': turn.end, 'speaker': speaker}
            for turn, _, speaker in diarization.itertracks(yield_label=True)
        ]
        for word_info in word_timestamps:
            word_start, word_end = _safe_word_span(word_info)
            assigned_speaker = next(
                (seg['speaker'] for seg in speaker_map
                 if word_start >= seg['start'] and word_end <= seg['end']),
                "Unknown",
            )
            segments.append({'start': word_start, 'end': word_end,
                             'text': word_info['text'], 'speaker': assigned_speaker})
    else:
        for word_info in word_timestamps:
            word_start, word_end = _safe_word_span(word_info)
            segments.append({'start': word_start, 'end': word_end,
                             'text': word_info['text'], 'speaker': 'SPEAKER_00'})
    return segments


def _merge_word_segments(word_segments, max_gap=0.5):
    """Merge consecutive same-speaker words into larger segments.

    Words are merged while the speaker matches and the gap between them is
    under *max_gap* seconds. Input dicts are copied, not mutated.
    """
    merged = []
    if word_segments:
        current = dict(word_segments[0])
        for nxt in word_segments[1:]:
            if nxt['speaker'] == current['speaker'] and (nxt['start'] - current['end'] < max_gap):
                current['text'] += " " + nxt['text']
                current['end'] = nxt['end']
            else:
                merged.append(current)
                current = dict(nxt)
        merged.append(current)
    return merged


def process_audio(audio_input):
    """Transcribe, diarize, and translate an audio input end-to-end.

    Parameters
    ----------
    audio_input : str | tuple | None
        Either a filepath (Gradio ``type="filepath"``) or a
        ``(sample_rate, samples)`` tuple from a raw microphone recording.

    Returns
    -------
    tuple
        ``(language_label, diarized_text, translation_text, gr.update(...))``
        matching the four Gradio output components wired in the UI.

    Raises
    ------
    gr.Error
        On missing input, unavailable ASR pipeline, or any processing failure.
    """
    if audio_input is None:
        raise gr.Error("Please provide an audio file or record audio.")
    temp_audio_path = None
    created_temp_file = False  # only delete a temp file this function created
    try:
        # Step 1: Handle audio input — raw recordings are written to a temp WAV.
        if isinstance(audio_input, tuple):
            sample_rate, audio_data = audio_input
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
                temp_audio_path = temp_file.name
                sf.write(temp_audio_path, audio_data, sample_rate)
            created_temp_file = True
        else:
            temp_audio_path = audio_input
        logging.info("Standardizing audio...")
        audio_waveform_16k = load_and_resample_audio(temp_audio_path)
        # Step 2: ASR with local Whisper pipeline
        logging.info("Starting ASR with local Whisper pipeline...")
        if not ASR_PIPELINE:
            raise gr.Error("ASR pipeline not available. The application cannot proceed.")
        asr_output = ASR_PIPELINE(
            audio_waveform_16k,
            chunk_length_s=30,
            batch_size=8,
            return_timestamps="word"
        )
        word_timestamps = asr_output.get("chunks", [])
        full_text = asr_output.get("text", "").strip()
        # Step 3: Language Detection on the transcription text.
        detected_language_code = "en"
        if full_text:
            try:
                detected_language_code = detect(full_text)
            except LangDetectException:
                logging.warning("Language detection failed, defaulting to English.")
        detected_language_name = LANGUAGE_NAME_MAPPING.get(detected_language_code, "Unknown")
        logging.info(f"Transcription complete. Language: {detected_language_name}")
        # Step 4: Speaker Diarization (best-effort; failure falls back to one speaker).
        diarization = None
        if DIARIZATION_PIPELINE:
            logging.info("Performing speaker diarization...")
            try:
                diarization = DIARIZATION_PIPELINE(temp_audio_path)
            except Exception as e:
                logging.error(f"Diarization failed: {e}")
        # Step 5: Align ASR and Diarization results, then merge same-speaker runs.
        logging.info("Aligning transcription with speaker segments...")
        final_segments = _merge_word_segments(_align_words_to_speakers(word_timestamps, diarization))
        diarized_text = "\n".join(
            f"[{seg['start']:.2f}s - {seg['end']:.2f}s] {seg['speaker']}: {seg['text'].strip()}"
            for seg in final_segments
        )
        # Step 6: Translation with local Helsinki-NLP models (lazily cached).
        translation_output = "Source language is English. No translation needed."
        if detected_language_code != 'en':
            # Tamil uses a three-letter code in the OPUS-MT naming scheme.
            model_name = 'Helsinki-NLP/opus-mt-tam-en' if detected_language_code == 'ta' else f'Helsinki-NLP/opus-mt-{detected_language_code}-en'
            try:
                if model_name not in TRANSLATION_MODELS:
                    logging.info(f"Loading translation model: {model_name}")
                    tokenizer = AutoTokenizer.from_pretrained(model_name)
                    model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(DEVICE)
                    TRANSLATION_MODELS[model_name] = (tokenizer, model)
                tokenizer, model = TRANSLATION_MODELS[model_name]
                texts_to_translate = [seg['text'] for seg in final_segments]
                inputs = tokenizer(texts_to_translate, return_tensors="pt", padding=True, truncation=True, max_length=512).to(DEVICE)
                translated_ids = model.generate(**inputs)
                translated_texts = tokenizer.batch_decode(translated_ids, skip_special_tokens=True)
                # Reconstruct translated output with speaker and timing info.
                translation_lines = [
                    f"[{segment['start']:.2f}s - {segment['end']:.2f}s] {segment['speaker']}: {translated_texts[i]}"
                    for i, segment in enumerate(final_segments)
                ]
                translation_output = "\n".join(translation_lines)
            except Exception as e:
                # Best-effort: report the failure in the UI instead of crashing.
                translation_output = f"Translation failed for '{detected_language_name}'. Model may not be available. Error: {e}"
        # Step 7: Generate Report (.txt left on disk for Gradio to serve).
        report_content = f"# Audio Processing Report\n\n## Detected Language\n{detected_language_name} ({detected_language_code})\n\n---\n\n## Diarized Transcription\n{diarized_text}\n\n---\n\n## English Translation\n{translation_output}"
        with tempfile.NamedTemporaryFile(mode="w+", suffix=".txt", delete=False, encoding='utf-8') as report_file:
            report_file.write(report_content)
            report_path = report_file.name
        return (f"{detected_language_name} ({detected_language_code})", diarized_text, translation_output, gr.update(value=report_path, visible=True))
    except gr.Error:
        # Don't double-wrap errors we raised deliberately (e.g. missing ASR pipeline).
        raise
    except Exception as e:
        logging.error(f"An unexpected error occurred: {e}", exc_info=True)
        raise gr.Error(f"An error occurred: {str(e)}")
    finally:
        # Delete only the temp WAV this function wrote. Gradio-uploaded files
        # also live in the temp dir and must survive for repeat processing.
        if created_temp_file and temp_audio_path:
            try:
                os.remove(temp_audio_path)
            except OSError:
                pass
        if DEVICE == "cuda":
            torch.cuda.empty_cache()
# --- 4. Gradio User Interface ---
# Layout: input controls on the left, results (language, tabs, report) on the right.
with gr.Blocks(theme=gr.themes.Soft(), title="Advanced Audio Processor") as app:
    gr.Markdown("# Advanced Open-Source Audio Processor")
    gr.Markdown("A 100% cost-free tool for transcribing, identifying speakers, and translating audio.")
    with gr.Row():
        # Left column: audio input and the trigger button.
        with gr.Column(scale=1):
            gr.Markdown("### 1. Provide Audio")
            audio_input = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Upload or Record Audio")
            process_button = gr.Button("Process Audio", variant="primary")
            # Collapsed help for configuring the HF_TOKEN secret (speaker ID).
            with gr.Accordion("Authentication Instructions (Required for Speaker ID)", open=False):
                gr.Markdown(get_hf_token_instructions())
        # Right column: processing results.
        with gr.Column(scale=2):
            gr.Markdown("### 2. Processing Results")
            detected_language_output = gr.Textbox(label="Detected Language", interactive=False)
            with gr.Tabs():
                with gr.TabItem("Diarized Transcription"):
                    diarized_transcription_output = gr.Textbox(label="Full Transcription (with speaker labels)", lines=15, interactive=False, show_copy_button=True)
                with gr.TabItem("Translation (to English)"):
                    translation_output = gr.Textbox(label="Full Translation (with speaker labels)", lines=15, interactive=False, show_copy_button=True)
            gr.Markdown("### 3. Download Full Report")
            # Hidden until process_audio() returns gr.update(value=..., visible=True).
            download_report_button = gr.File(label="Download Report (.txt)", visible=False, interactive=False)
    # Wire the button to the pipeline; outputs map 1:1 to process_audio's return tuple.
    process_button.click(
        fn=process_audio,
        inputs=[audio_input],
        outputs=[detected_language_output, diarized_transcription_output, translation_output, download_report_button],
        api_name="process_audio"
    )
# Launch the Gradio server only when run as a script (not on import).
if __name__ == "__main__":
    app.launch()