# Hugging Face Space: Whisper ZeroGPU transcription & summarization app
| import gradio as gr | |
| import spaces | |
| import torch | |
| import os | |
| import datetime | |
| import time | |
| from transformers import pipeline | |
| from docx import Document | |
| from pydub import AudioSegment | |
# --- Configuration constants ---

# ASR model choices shown in the UI, mapped to Hugging Face model ids.
# NOTE(review): the "General Purpose" and "French-Specific" entries point at
# the SAME checkpoint ("distil-whisper/distil-large-v3") — confirm whether the
# French option was meant to use a dedicated French repo id.
MODEL_SIZES = {
    "Tiny (Fastest)": "openai/whisper-tiny",
    "Base (Faster)": "openai/whisper-base",
    "Small (Balanced)": "openai/whisper-small",
    "Distil-Large-v3 (General Purpose)": "distil-whisper/distil-large-v3",
    "Distil-Large-v3-FR (French-Specific)": "distil-whisper/distil-large-v3"
}

# Per-process caches so each pipeline is constructed at most once.
model_cache = {}       # ASR pipelines, keyed by model id
summarizer_cache = {}  # summarization pipelines, keyed by model id

# Fixed audio chunk length: 5 minutes, expressed in milliseconds.
CHUNK_LENGTH_MS = 5 * 60 * 1000

# Token-length bounds for the abstractive summary.
SUMMARY_MIN_LENGTH = 30
SUMMARY_MAX_LENGTH = 150

# Human-readable language names -> Whisper language codes.
LANGUAGE_MAP = {
    "French": "fr",
    "English": "en",
    "Spanish": "es",
    "German": "de"
}

# Dropdown choices: auto-detection plus every explicitly supported language.
LANGUAGE_CHOICES = ["Auto-Detect", *LANGUAGE_MAP]
def get_model_pipeline(model_name, pipeline_type, progress):
    """
    Return a cached ASR or summarization pipeline, loading it on first use.

    Args:
        model_name: For "asr", a display name that must be a key of
            MODEL_SIZES; for "summarization", a Hugging Face model id.
        pipeline_type: Either "asr" or "summarization".
        progress: Gradio progress callback used to report loading status.

    Returns:
        A transformers pipeline object (cached across calls).

    Raises:
        ValueError: If pipeline_type is unrecognized or model_name is not a
            known ASR model. (Bug fix: the original left `pipe` unbound for
            an unknown pipeline_type and crashed with a NameError; an unknown
            ASR model name yielded model_id=None and loaded garbage.)
    """
    if pipeline_type not in ("asr", "summarization"):
        raise ValueError(f"Unknown pipeline_type: {pipeline_type!r}")
    cache = model_cache if pipeline_type == "asr" else summarizer_cache
    if pipeline_type == "asr":
        model_id = MODEL_SIZES.get(model_name)
        if model_id is None:
            raise ValueError(f"Unknown ASR model: {model_name!r}")
    else:
        model_id = model_name
    if model_id not in cache:
        # Progress windows: ASR loads early (0.0 -> 0.5); the summarizer loads
        # late in the overall transcription flow (0.90 -> 0.95).
        progress_start = 0.0 if pipeline_type == "asr" else 0.90
        progress_end = 0.50 if pipeline_type == "asr" else 0.95
        desc = f"⏳ Loading {model_name} model..." if pipeline_type == "asr" else "🧠 Loading Summarization Model..."
        progress(progress_start, desc="🚀 Initializing ZeroGPU instance..." if pipeline_type == "asr" else desc)
        # Use GPU 0 when available, otherwise fall back to CPU.
        device = 0 if torch.cuda.is_available() else "cpu"
        if pipeline_type == "asr":
            pipe = pipeline(
                "automatic-speech-recognition",
                model=model_id,
                device=device,
                max_new_tokens=128
            )
        else:
            pipe = pipeline(
                "summarization",
                model=model_id,
                device=device
            )
        cache[model_id] = pipe
        progress(progress_end, desc="✅ Model loaded successfully!" if pipeline_type == "asr" else "✅ Summarization Model loaded!")
    return cache[model_id]
def format_seconds(seconds):
    """Render a duration in seconds as H:MM:SS (timedelta's string form).

    Fractional seconds are truncated. Note the hours field is NOT
    zero-padded (e.g. 3661 -> "1:01:01"), matching str(timedelta).
    """
    whole_seconds = int(seconds)
    return str(datetime.timedelta(seconds=whole_seconds))
def _vtt_timestamp(ms):
    """Format a millisecond offset as a WebVTT timestamp (HH:MM:SS.mmm)."""
    hours, remainder = divmod(ms, 3600000)
    minutes, remainder = divmod(remainder, 60000)
    seconds, milliseconds = divmod(remainder, 1000)
    return f"{int(hours):02}:{int(minutes):02}:{int(seconds):02}.{int(milliseconds):03}"


def _vtt_segment_bounds(segment):
    """Return (start, end) in seconds for one transcription segment.

    Accepts both {'start': s, 'end': e} dicts and the Hugging Face ASR
    pipeline's native chunk shape {'timestamp': (s, e)}; missing or None
    values default to 0.
    """
    if 'timestamp' in segment:
        start, end = segment['timestamp']
        return start or 0, end or 0
    return segment.get('start', 0), segment.get('end', 0)


def create_vtt(segments, file_path):
    """Create a WebVTT (.vtt) subtitle file from transcription segments.

    Bug fix: Whisper pipeline chunks carry timestamps under a 'timestamp'
    (start, end) tuple rather than 'start'/'end' keys, so the original
    implementation always emitted 00:00:00.000 cues; both shapes are now
    handled. The per-iteration nested formatter was also hoisted out.
    """
    with open(file_path, "w", encoding="utf-8") as f:
        f.write("WEBVTT\n\n")
        for i, segment in enumerate(segments):
            start_sec, end_sec = _vtt_segment_bounds(segment)
            start = _vtt_timestamp(int(start_sec * 1000))
            end = _vtt_timestamp(int(end_sec * 1000))
            # Numeric cue identifier, then the cue timing line, then the text.
            f.write(f"{i+1}\n")
            f.write(f"{start} --> {end}\n")
            f.write(f"{segment.get('text', '').strip()}\n\n")
def create_docx(segments, file_path, with_timestamps):
    """Create a DOCX (.docx) file from transcription segments.

    Args:
        segments: Sequence of segment dicts; each has a 'text' key and
            timestamps either as 'start'/'end' seconds or as the Hugging
            Face pipeline's 'timestamp' (start, end) tuple.
        file_path: Destination path for the document.
        with_timestamps: When True, each segment becomes its own
            "[start - end] text" paragraph; otherwise all text is joined
            into a single paragraph.

    Bug fix: the original read only 'start'/'end' keys, which raw Whisper
    pipeline chunks do not have, so every timestamp rendered as 0:00:00.
    """
    def seg_bounds(segment):
        # Handle both pre-normalized segments and raw pipeline chunks;
        # None (open-ended last chunk) falls back to 0.
        if 'timestamp' in segment:
            start, end = segment['timestamp']
            return start or 0, end or 0
        return segment.get('start', 0), segment.get('end', 0)

    document = Document()
    document.add_heading("Transcription", 0)
    if with_timestamps:
        for segment in segments:
            text = segment.get('text', '').strip()
            start_sec, end_sec = seg_bounds(segment)
            document.add_paragraph(f"[{format_seconds(start_sec)} - {format_seconds(end_sec)}] {text}")
    else:
        full_text = " ".join(segment.get('text', '').strip() for segment in segments)
        document.add_paragraph(full_text)
    document.save(file_path)
def analyze_audio_and_get_chunks(audio_file):
    """Inspect the uploaded audio and build the chunk-selector options.

    Splits the duration into fixed CHUNK_LENGTH_MS windows and returns an
    updated gr.Dropdown (always offering "Full Audio" first) plus a status
    string. On missing input or a decode failure the dropdown is returned
    disabled with only the "Full Audio" choice.
    """
    if audio_file is None:
        return (gr.Dropdown(choices=["Full Audio"], value="Full Audio", interactive=False),
                "Please upload an audio file first.")
    try:
        total_duration_ms = len(AudioSegment.from_file(audio_file))
        # Ceiling division: every started window counts as a chunk.
        num_chunks = (total_duration_ms + CHUNK_LENGTH_MS - 1) // CHUNK_LENGTH_MS
        chunk_options = ["Full Audio"]
        for idx in range(num_chunks):
            window_start_ms = idx * CHUNK_LENGTH_MS
            window_end_ms = min(window_start_ms + CHUNK_LENGTH_MS, total_duration_ms)
            # Drop any fractional-second suffix from the H:MM:SS labels.
            label_start = format_seconds(window_start_ms / 1000).split('.')[0]
            label_end = format_seconds(window_end_ms / 1000).split('.')[0]
            chunk_options.append(f"Chunk {idx+1} ({label_start} - {label_end})")
        status = f"Audio analyzed. Duration: {format_seconds(total_duration_ms/1000.0)}. Found {num_chunks} chunks."
        # More than 6 chunks means the recording exceeds ~30 minutes.
        if num_chunks > 6:
            status += " ⚠️ **Recommendation:** Select a single chunk to process to avoid GPU memory crash."
        return gr.Dropdown(choices=chunk_options, value="Full Audio", interactive=True), status
    except Exception as e:
        return (gr.Dropdown(choices=["Full Audio"], value="Full Audio", interactive=False),
                f"Error analyzing audio: {e}")
def generate_summary(text, target_language_code, progress):
    """Generate an abstractive summary of *text* with a cached t5-small pipeline.

    Args:
        text: Transcript text to summarize.
        target_language_code: Whisper language code ("fr", "es", "de", ...)
            or None; selects the task-prefix prompt used to steer the model.
        progress: Gradio progress callback forwarded to the model loader.

    Returns:
        The summary string, or an "Error during summarization: ..." message
        on failure (best-effort: callers display whatever comes back).
    """
    try:
        summarizer = get_model_pipeline("t5-small", "summarization", progress)
        # t5-small tends to default to English output, so prefix a
        # summarization prompt in the requested language.
        # NOTE(review): t5-small was trained primarily on English task
        # prefixes — confirm these non-English prompts actually change output.
        prompts = {
            "fr": "résumer: ",
            "es": "resumir: ",
            "de": "zusammenfassen: ",  # added: LANGUAGE_MAP offers German but the original fell back to English here
        }
        prompt = prompts.get(target_language_code, "summarize: ") + text
        summary = summarizer(
            prompt,
            max_length=SUMMARY_MAX_LENGTH,
            min_length=SUMMARY_MIN_LENGTH,
            do_sample=False,
            # Fix: long transcripts exceed t5-small's input limit; truncate
            # instead of letting the tokenizer/model error out.
            truncation=True
        )[0]['summary_text']
        return summary
    except Exception as e:
        return f"Error during summarization: {e}"
| # ----------------------------------------------------------- | |
def _normalize_segments(chunks, offset):
    """Convert pipeline chunks into uniform {'start','end','text'} dicts.

    The Hugging Face ASR pipeline emits chunks shaped
    {'timestamp': (start, end), 'text': ...}. Bug fix: the original added
    the chunk offset to non-existent 'start'/'end' keys, so exported
    timestamps were always zero (or the bare offset). A None end (open-ended
    final chunk) falls back to 0 before the offset is applied.
    """
    normalized = []
    for chunk in chunks:
        if 'timestamp' in chunk:
            start, end = chunk['timestamp']
        else:
            start, end = chunk.get('start', 0.0), chunk.get('end', 0.0)
        normalized.append({
            'start': (start or 0.0) + offset,
            'end': (end or 0.0) + offset,
            'text': chunk.get('text', ''),
        })
    return normalized


def transcribe_and_export(audio_file, model_size, chunk_choice, selected_language, vtt_output, docx_timestamp_output, docx_no_timestamp_output, summarize_output, progress=gr.Progress()):
    """
    Transcribe the chosen audio (full file or one 5-minute chunk), optionally
    summarize it, and export the requested file formats.

    Returns a 5-tuple matching the Gradio outputs:
    (transcription text, summary text, downloadable files, cleared audio
    component, status message).

    NOTE(review): `spaces` is imported at module level but never used; ZeroGPU
    Spaces normally require the @spaces.GPU decorator on the GPU-bound entry
    point — confirm whether this function should be decorated.
    """
    if audio_file is None:
        return (None, "", None, gr.Audio(value=None), "Please upload an audio file.")
    start_time = time.time()
    pipe = get_model_pipeline(model_size, "asr", progress)

    # 1. Determine which segment to process. A selected chunk is exported to
    #    a temp mp3 and its start time remembered so timestamps can be
    #    shifted back into full-file coordinates.
    audio_segment_to_process = audio_file
    offset = 0.0
    temp_chunk_path = None
    if chunk_choice != "Full Audio":
        progress(0.70, desc="✂️ Preparing audio segment...")
        try:
            # Labels look like "Chunk 3 (0:10:00 - 0:15:00)" -> zero-based index 2.
            chunk_num = int(chunk_choice.split(' ')[1]) - 1
            full_audio = AudioSegment.from_file(audio_file)
            start_ms = chunk_num * CHUNK_LENGTH_MS
            end_ms = min(start_ms + CHUNK_LENGTH_MS, len(full_audio))
            temp_chunk_path = "/tmp/selected_chunk.mp3"
            full_audio[start_ms:end_ms].export(temp_chunk_path, format="mp3")
            audio_segment_to_process = temp_chunk_path
            offset = start_ms / 1000.0
        except Exception as e:
            return (None, "", None, gr.Audio(value=None), f"Error preparing audio chunk: {e}")

    # 2. Build generation arguments; forcing the language avoids Whisper
    #    mis-detecting (e.g. French) audio.
    generate_kwargs = {}
    lang_code = None
    if selected_language != "Auto-Detect":
        lang_code = LANGUAGE_MAP.get(selected_language)
        if lang_code:
            generate_kwargs["language"] = lang_code

    # 3. Transcribe. Fix: the temp chunk is removed in a finally block so it
    #    no longer leaks when the pipeline raises.
    progress(0.75, desc=f"🎤 Transcribing {chunk_choice}...")
    try:
        raw_output = pipe(
            audio_segment_to_process,
            return_timestamps="word",
            generate_kwargs=generate_kwargs
        )
    finally:
        if temp_chunk_path and os.path.exists(temp_chunk_path):
            os.remove(temp_chunk_path)

    # 4. Normalize segments into {'start','end','text'} form, shifted by the
    #    chunk offset, so the exporters see real timestamps.
    full_segments = _normalize_segments(raw_output.get("chunks", []), offset)
    transcribed_text = raw_output.get('text', '').strip()

    # 5. Optional summary.
    summary_text = ""
    if summarize_output:
        if transcribed_text:
            summary_text = generate_summary(transcribed_text, lang_code, progress)
        else:
            summary_text = "Transcription failed or was empty, cannot generate summary."

    # 6. Generate the requested export files.
    outputs = {}
    progress(0.95, desc="📝 Generating output files...")
    if vtt_output:
        vtt_path = "transcription.vtt"
        create_vtt(full_segments, vtt_path)
        outputs["VTT"] = vtt_path
    if docx_timestamp_output:
        docx_ts_path = "transcription_with_timestamps.docx"
        create_docx(full_segments, docx_ts_path, with_timestamps=True)
        outputs["DOCX (with timestamps)"] = docx_ts_path
    if docx_no_timestamp_output:
        docx_no_ts_path = "transcription_without_timestamps.docx"
        create_docx(full_segments, docx_no_ts_path, with_timestamps=False)
        outputs["DOCX (without timestamps)"] = docx_no_ts_path

    total_time = time.time() - start_time
    downloadable_files = list(outputs.values())
    status_message = f"✅ Transcription complete! Total time: {total_time:.2f} seconds."
    return (
        transcribed_text,
        summary_text,
        gr.Files(value=downloadable_files, label="Download Transcripts"),
        gr.Audio(value=None),
        status_message
    )
# --- Gradio UI ---
with gr.Blocks(title="Whisper ZeroGPU Transcription & Summarization") as demo:
    gr.Markdown("# 🎙️ Whisper ZeroGPU Transcription & Summarization")
    gr.Markdown("1. **Upload** audio. 2. **Analyze** for chunks. 3. Select **Model**, **Chunk**, and **Language**, then **Transcribe**.")
    # Crucial guidance for large files: processing the whole recording at
    # once can exhaust GPU memory on the hosting platform.
    gr.Markdown(
        """
        ⚠️ **GPU Memory Warning:** For files longer than **30 minutes** (approx. 6 chunks),
        it's highly recommended to select a single **Chunk** to process instead of **'Full Audio'** to prevent a GPU memory crash on the platform.
        """
    )
    with gr.Row():
        # Left: the audio source (mic or upload). Right: model/language/chunk controls.
        audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Audio File")
        with gr.Column(scale=2):
            model_selector = gr.Dropdown(
                label="Choose Whisper Model Size",
                choices=list(MODEL_SIZES.keys()),
                value="Distil-Large-v3 (General Purpose)"
            )
            # Language selector: explicitly setting the expected language is
            # passed through to Whisper's generate_kwargs in transcribe_and_export.
            language_selector = gr.Dropdown(
                label="Select Expected Language (Crucial for French/Non-English)",
                choices=LANGUAGE_CHOICES,
                value="French",  # Default to French
                interactive=True
            )
            analyze_btn = gr.Button("Analyze Audio 🔎", variant="secondary")
            # Starts disabled; analyze_audio_and_get_chunks enables it and
            # fills in the 5-minute chunk options after analysis.
            chunk_selector = gr.Dropdown(
                label="Select Audio Segment (5-minute chunks)",
                choices=["Full Audio"],
                value="Full Audio",
                interactive=False
            )
    gr.Markdown("### Output Options")
    with gr.Row():
        summarize_checkbox = gr.Checkbox(label="Generate Summary", value=False)
        vtt_checkbox = gr.Checkbox(label="VTT", value=True)
    with gr.Row():
        docx_ts_checkbox = gr.Checkbox(label="DOCX (with timestamps)", value=False)
        docx_no_ts_checkbox = gr.Checkbox(label="DOCX (without timestamps)", value=True)
    transcribe_btn = gr.Button("Transcribe", variant="primary")
    status_text = gr.Textbox(label="Status", interactive=False)
    transcription_output = gr.Textbox(label="Full Transcription", lines=10)
    summary_output = gr.Textbox(label="Summary (Abstractive)", lines=3)
    downloadable_files_output = gr.Files(label="Download Transcripts")
    # Analyze: populate the chunk dropdown and report audio duration.
    analyze_btn.click(
        fn=analyze_audio_and_get_chunks,
        inputs=[audio_input],
        outputs=[chunk_selector, status_text]
    )
    # Transcribe: input order must match transcribe_and_export's signature;
    # audio_input appears in outputs because the handler clears it.
    transcribe_btn.click(
        fn=transcribe_and_export,
        inputs=[audio_input, model_selector, chunk_selector, language_selector, vtt_checkbox, docx_ts_checkbox, docx_no_ts_checkbox, summarize_checkbox],
        outputs=[transcription_output, summary_output, downloadable_files_output, audio_input, status_text]
    )

if __name__ == "__main__":
    demo.launch()