| # app.py | |
| # A 100% OPEN-SOURCE audio processing application. | |
| # - Local Whisper for Transcription | |
| # - Local Pyannote for Diarization | |
| # - Local Helsinki-NLP for Translation | |
| import os | |
| import torch | |
| import gradio as gr | |
| import numpy as np | |
| import soundfile as sf | |
| import torchaudio | |
| from transformers import pipeline as hf_pipeline | |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
| import tempfile | |
| import logging | |
| import warnings | |
| from pyannote.audio import Pipeline as PyannotePipeline | |
| from langdetect import detect, LangDetectException | |
# --- 1. Initial Setup & Configuration ---
# Suppress UserWarnings originating in torch.nn.functional (presumably noise
# from the audio models' internals — confirm before relying on silence here);
# all other warnings stay visible.
warnings.filterwarnings("ignore", category=UserWarning, module='torch.nn.functional')
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Language name mapping: langdetect-style ISO codes -> human-readable names.
# Used for display in the UI and the report; the raw code (not the name) is
# what drives translation-model selection in process_audio.
LANGUAGE_NAME_MAPPING = {
    "en": "English", "zh-cn": "Chinese", "de": "German", "es": "Spanish", "ru": "Russian",
    "ko": "Korean", "fr": "French", "ja": "Japanese", "pt": "Portuguese", "tr": "Turkish",
    "pl": "Polish", "ca": "Catalan", "nl": "Dutch", "ar": "Arabic", "sv": "Swedish",
    "it": "Italian", "id": "Indonesian", "hi": "Hindi", "fi": "Finnish", "vi": "Vietnamese",
    "he": "Hebrew", "uk": "Ukrainian", "el": "Greek", "ms": "Malay", "cs": "Czech",
    "ro": "Romanian", "da": "Danish", "hu": "Hungarian", "ta": "Tamil", "no": "Norwegian",
    "th": "Thai", "ur": "Urdu", "hr": "Croatian", "bg": "Bulgarian", "lt": "Lithuanian", "la": "Latin",
    "mi": "Maori", "ml": "Malayalam", "cy": "Welsh", "sk": "Slovak", "te": "Telugu", "pa": "Punjabi",
    "lv": "Latvian", "bn": "Bengali", "sr": "Serbian", "az": "Azerbaijani", "sl": "Slovenian",
    "kn": "Kannada", "et": "Estonian", "mk": "Macedonian", "br": "Breton", "eu": "Basque",
    "is": "Icelandic", "hy": "Armenian", "ne": "Nepali", "mn": "Mongolian", "bs": "Bosnian",
    "kk": "Kazakh", "sq": "Albanian", "sw": "Swahili", "gl": "Galician", "mr": "Marathi",
    "si": "Sinhala", "am": "Amharic", "yo": "Yoruba", "uz": "Uzbek", "af": "Afrikaans",
    "oc": "Occitan", "ka": "Georgian", "be": "Belarusian", "tg": "Tajik", "sd": "Sindhi",
    "gu": "Gujarati", "so": "Somali", "lo": "Lao", "yi": "Yiddish", "ky": "Kyrgyz",
    "tk": "Turkmen", "ht": "Haitian Creole", "ps": "Pashto", "as": "Assamese", "tt": "Tatar",
    "ha": "Hausa", "ba": "Bashkir", "jw": "Javanese", "su": "Sundanese"
}
def get_hf_token_instructions():
    """Return the markdown help text explaining how to configure the HF_TOKEN secret.

    Shown in the UI accordion; diarization stays disabled until the token is set.
    """
    instructions = """
**IMPORTANT: Authentication Required for Speaker Identification**
This feature uses the `pyannote/speaker-diarization-3.1` model, which requires a Hugging Face access token.
**How to Add Your Token:**
1. **Accept the model license:** Visit [pyannote/speaker-diarization-3.1](https://huggingface.co/pyannote/speaker-diarization-3.1) and agree to the terms.
2. **Get your token:** Find it in your Hugging Face account settings: [huggingface.co/settings/tokens](https://huggingface.co/settings/tokens).
3. **Add the token to this Space:** Go to the **Settings** tab, find **Repository secrets**, click **New secret**, and add a secret named `HF_TOKEN` with your token as the value. Restart the Space after saving.
"""
    return instructions
# --- 2. Global Model Loading (All Local) ---
# Prefer GPU + fp16 when CUDA is available, otherwise CPU + fp32.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TORCH_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
logging.info(f"Using device: {DEVICE} with data type: {TORCH_DTYPE}")

# ASR Pipeline (Local Whisper) — loaded once at import time. A value of None
# signals that transcription is unavailable; process_audio checks for this
# and aborts with a user-facing error.
ASR_PIPELINE = None
try:
    logging.info("Loading ASR pipeline (Whisper)...")
    ASR_PIPELINE = hf_pipeline(
        "automatic-speech-recognition",
        model="openai/whisper-large-v3",
        torch_dtype=TORCH_DTYPE,
        device=DEVICE,
    )
    logging.info("ASR pipeline loaded successfully.")
except Exception as e:
    # Swallow the failure so the UI can still start; the error is surfaced
    # later when the user actually tries to process audio.
    logging.error(f"Fatal error: Could not load ASR pipeline. {e}")

# Speaker Diarization Pipeline (Local Pyannote) — the model is gated, so an
# HF_TOKEN environment secret is required. Left as None (feature disabled)
# when the token is missing or loading fails.
HF_TOKEN = os.environ.get("HF_TOKEN")
DIARIZATION_PIPELINE = None
if HF_TOKEN:
    try:
        logging.info("Loading Speaker Diarization pipeline (v3.1)...")
        DIARIZATION_PIPELINE = PyannotePipeline.from_pretrained("pyannote/speaker-diarization-3.1", use_auth_token=HF_TOKEN)
        DIARIZATION_PIPELINE.to(torch.device(DEVICE))
        logging.info("Speaker Diarization pipeline loaded successfully.")
    except Exception as e:
        logging.error(f"Failed to load Diarization pipeline. Check HF_TOKEN. Error: {e}")
else:
    logging.warning("HF_TOKEN not set. Speaker diarization will be disabled.")

# Translation Model Cache (Local Helsinki-NLP) — maps a model name to its
# (tokenizer, model) pair; populated lazily in process_audio, one entry per
# source language encountered.
TRANSLATION_MODELS = {}
logging.info("Translation model cache initialized.")
| # --- 3. Core Processing Functions --- | |
def load_and_resample_audio(audio_path):
    """Load an audio file, downmix to mono, resample to 16 kHz.

    Returns a 1-D numpy array of samples, the input format Whisper expects.
    Raises IOError if the file cannot be read or converted.
    """
    target_rate = 16000
    try:
        signal, source_rate = torchaudio.load(audio_path, channels_first=True)
        # Collapse multi-channel audio to a single averaged channel.
        mono = signal if signal.shape[0] <= 1 else torch.mean(signal, dim=0, keepdim=True)
        if source_rate != target_rate:
            resample = torchaudio.transforms.Resample(orig_freq=source_rate, new_freq=target_rate)
            mono = resample(mono)
        return mono.squeeze(0).numpy()
    except Exception as e:
        raise IOError(f"Error processing audio file {audio_path}: {e}")
# langdetect codes whose Helsinki-NLP/opus-mt model id differs from the raw
# code: langdetect reports Chinese as "zh-cn"/"zh-tw" but the model is
# "opus-mt-zh-en", and Tamil uses the three-letter code "tam".
_OPUS_MT_CODE_OVERRIDES = {"ta": "tam", "zh-cn": "zh", "zh-tw": "zh"}


def _assign_speakers(word_timestamps, diarization):
    """Attribute each ASR word to a diarization speaker.

    Returns a list of {'start', 'end', 'text', 'speaker'} dicts. A word gets
    the speaker whose turn fully contains it, 'Unknown' if none does, or
    'SPEAKER_00' for every word when no diarization result is available.
    """
    speaker_map = None
    if diarization:
        speaker_map = [
            {'start': turn.start, 'end': turn.end, 'speaker': speaker}
            for turn, _, speaker in diarization.itertracks(yield_label=True)
        ]
    segments = []
    for word_info in word_timestamps:
        word_start, word_end = word_info['timestamp']
        # Whisper occasionally emits None for the final word's end timestamp;
        # fall back to the start time so downstream "{:.2f}" formatting and
        # gap arithmetic never crash.
        if word_end is None:
            word_end = word_start
        if speaker_map is None:
            assigned_speaker = 'SPEAKER_00'
        else:
            assigned_speaker = next(
                (seg['speaker'] for seg in speaker_map
                 if word_start >= seg['start'] and word_end <= seg['end']),
                "Unknown",
            )
        segments.append({'start': word_start, 'end': word_end,
                         'text': word_info['text'], 'speaker': assigned_speaker})
    return segments


def _merge_consecutive(word_segments):
    """Merge consecutive same-speaker words separated by a gap under 0.5 s."""
    if not word_segments:
        return []
    merged = []
    current = dict(word_segments[0])
    for nxt in word_segments[1:]:
        if nxt['speaker'] == current['speaker'] and (nxt['start'] - current['end'] < 0.5):
            current['text'] += " " + nxt['text']
            current['end'] = nxt['end']
        else:
            merged.append(current)
            current = dict(nxt)
    merged.append(current)
    return merged


def _translate_segments(final_segments, lang_code, lang_name):
    """Translate merged segments to English via a cached Helsinki-NLP model.

    Returns formatted translation text, or an explanatory message when the
    source is already English, there is nothing to translate, or the model
    for this language is unavailable.
    """
    if lang_code == 'en':
        return "Source language is English. No translation needed."
    if not final_segments:
        # Guard: calling the tokenizer with an empty list raises.
        return "No transcribed segments to translate."
    model_code = _OPUS_MT_CODE_OVERRIDES.get(lang_code, lang_code)
    model_name = f'Helsinki-NLP/opus-mt-{model_code}-en'
    try:
        if model_name not in TRANSLATION_MODELS:
            logging.info(f"Loading translation model: {model_name}")
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(DEVICE)
            TRANSLATION_MODELS[model_name] = (tokenizer, model)
        tokenizer, model = TRANSLATION_MODELS[model_name]
        texts_to_translate = [seg['text'] for seg in final_segments]
        inputs = tokenizer(texts_to_translate, return_tensors="pt", padding=True, truncation=True, max_length=512).to(DEVICE)
        translated_ids = model.generate(**inputs)
        translated_texts = tokenizer.batch_decode(translated_ids, skip_special_tokens=True)
        # Reconstruct the translated output with speaker and timing info.
        return "\n".join(
            f"[{seg['start']:.2f}s - {seg['end']:.2f}s] {seg['speaker']}: {text}"
            for seg, text in zip(final_segments, translated_texts)
        )
    except Exception as e:
        return f"Translation failed for '{lang_name}'. Model may not be available. Error: {e}"


def process_audio(audio_input):
    """Transcribe, diarize, and translate one audio input end to end.

    Args:
        audio_input: Either a filepath string or a (sample_rate, ndarray)
            tuple, as produced by gr.Audio with type="filepath" / microphone.

    Returns:
        A 4-tuple for the Gradio outputs: (detected-language label, diarized
        transcript, English translation, file-download component update).

    Raises:
        gr.Error: when input is missing, the ASR pipeline failed to load at
            startup, or any unexpected processing error occurs.
    """
    if audio_input is None:
        raise gr.Error("Please provide an audio file or record audio.")
    temp_audio_path = None
    try:
        # Step 1: normalize the input to a file on disk.
        if isinstance(audio_input, tuple):
            sample_rate, audio_data = audio_input
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
                temp_audio_path = temp_file.name
                sf.write(temp_audio_path, audio_data, sample_rate)
        else:
            temp_audio_path = audio_input
        logging.info("Standardizing audio...")
        audio_waveform_16k = load_and_resample_audio(temp_audio_path)

        # Step 2: ASR with the local Whisper pipeline.
        logging.info("Starting ASR with local Whisper pipeline...")
        if ASR_PIPELINE is None:
            raise gr.Error("ASR pipeline not available. The application cannot proceed.")
        asr_output = ASR_PIPELINE(
            audio_waveform_16k,
            chunk_length_s=30,
            batch_size=8,
            return_timestamps="word"
        )
        word_timestamps = asr_output.get("chunks", [])
        full_text = asr_output.get("text", "").strip()

        # Step 3: language detection (defaults to English on failure/silence).
        detected_language_code = "en"
        if full_text:
            try:
                detected_language_code = detect(full_text)
            except LangDetectException:
                logging.warning("Language detection failed, defaulting to English.")
        detected_language_name = LANGUAGE_NAME_MAPPING.get(detected_language_code, "Unknown")
        logging.info(f"Transcription complete. Language: {detected_language_name}")

        # Step 4: speaker diarization (optional; needs HF_TOKEN at startup).
        diarization = None
        if DIARIZATION_PIPELINE:
            logging.info("Performing speaker diarization...")
            try:
                diarization = DIARIZATION_PIPELINE(temp_audio_path)
            except Exception as e:
                # Best-effort: fall back to single-speaker labelling.
                logging.error(f"Diarization failed: {e}")

        # Step 5: align words with speakers, then merge into utterances.
        logging.info("Aligning transcription with speaker segments...")
        final_segments = _merge_consecutive(_assign_speakers(word_timestamps, diarization))
        diarized_text = "\n".join(
            f"[{seg['start']:.2f}s - {seg['end']:.2f}s] {seg['speaker']}: {seg['text'].strip()}"
            for seg in final_segments
        )

        # Step 6: translation with local Helsinki-NLP models.
        translation_output = _translate_segments(final_segments, detected_language_code, detected_language_name)

        # Step 7: write the downloadable report.
        report_content = f"# Audio Processing Report\n\n## Detected Language\n{detected_language_name} ({detected_language_code})\n\n---\n\n## Diarized Transcription\n{diarized_text}\n\n---\n\n## English Translation\n{translation_output}"
        with tempfile.NamedTemporaryFile(mode="w+", suffix=".txt", delete=False, encoding='utf-8') as report_file:
            report_file.write(report_content)
            report_path = report_file.name
        return (f"{detected_language_name} ({detected_language_code})", diarized_text, translation_output, gr.update(value=report_path, visible=True))
    except gr.Error:
        # Already a user-facing error; don't double-wrap it below.
        raise
    except Exception as e:
        logging.error(f"An unexpected error occurred: {e}", exc_info=True)
        raise gr.Error(f"An error occurred: {str(e)}")
    finally:
        # Only delete files *we* created in the temp dir; tolerate races so
        # cleanup can never mask the real result or error.
        if temp_audio_path and temp_audio_path.startswith(tempfile.gettempdir()):
            try:
                os.remove(temp_audio_path)
            except OSError:
                pass
        if DEVICE == "cuda":
            torch.cuda.empty_cache()
# --- 4. Gradio User Interface ---
# Two-column layout: input controls on the left (1/3 width), results on the
# right (2/3 width), with the report download link below.
with gr.Blocks(theme=gr.themes.Soft(), title="Advanced Audio Processor") as app:
    gr.Markdown("# Advanced Open-Source Audio Processor")
    gr.Markdown("A 100% cost-free tool for transcribing, identifying speakers, and translating audio.")
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 1. Provide Audio")
            # type="filepath" means process_audio receives a path string for
            # uploads; microphone recordings may arrive as (rate, data) tuples.
            audio_input = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Upload or Record Audio")
            process_button = gr.Button("Process Audio", variant="primary")
            # Collapsed help for configuring the HF_TOKEN secret (diarization).
            with gr.Accordion("Authentication Instructions (Required for Speaker ID)", open=False):
                gr.Markdown(get_hf_token_instructions())
        with gr.Column(scale=2):
            gr.Markdown("### 2. Processing Results")
            detected_language_output = gr.Textbox(label="Detected Language", interactive=False)
            with gr.Tabs():
                with gr.TabItem("Diarized Transcription"):
                    diarized_transcription_output = gr.Textbox(label="Full Transcription (with speaker labels)", lines=15, interactive=False, show_copy_button=True)
                with gr.TabItem("Translation (to English)"):
                    translation_output = gr.Textbox(label="Full Translation (with speaker labels)", lines=15, interactive=False, show_copy_button=True)
            gr.Markdown("### 3. Download Full Report")
            # Hidden until process_audio returns a gr.update making it visible.
            download_report_button = gr.File(label="Download Report (.txt)", visible=False, interactive=False)
    # Wire the button to the processing pipeline; outputs map 1:1 to the
    # 4-tuple returned by process_audio.
    process_button.click(
        fn=process_audio,
        inputs=[audio_input],
        outputs=[detected_language_output, diarized_transcription_output, translation_output, download_report_button],
        api_name="process_audio"
    )

if __name__ == "__main__":
    app.launch()