# app.py — Whisper ZeroGPU transcription & summarization Space.
# Source: Hugging Face Space by clementBE ("Update app.py", commit 3351f3d).
import gradio as gr
import spaces
import torch
import os
import datetime
import time
from transformers import pipeline
from docx import Document
from pydub import AudioSegment
# Define the available ASR models (UI label -> Hugging Face model id)
MODEL_SIZES = {
    "Tiny (Fastest)": "openai/whisper-tiny",
    "Base (Faster)": "openai/whisper-base",
    "Small (Balanced)": "openai/whisper-small",
    "Distil-Large-v3 (General Purpose)": "distil-whisper/distil-large-v3",
    # NOTE(review): this is the SAME checkpoint as the entry above, so the
    # "French-Specific" label is misleading — a dedicated FR checkpoint was
    # presumably intended here; verify the correct model id.
    "Distil-Large-v3-FR (French-Specific)": "distil-whisper/distil-large-v3"
}
# Use a dictionary to cache loaded models
model_cache = {}
# Use a separate cache for the summarization model
summarizer_cache = {}
# Define the fixed chunk length (5 minutes in milliseconds)
CHUNK_LENGTH_MS = 5 * 60 * 1000
# Define summarization parameters (bounds passed to the summarizer pipeline)
SUMMARY_MIN_LENGTH = 30
SUMMARY_MAX_LENGTH = 150
# Language mapping for Whisper (UI label -> Whisper language code)
LANGUAGE_MAP = {
    "French": "fr",
    "English": "en",
    "Spanish": "es",
    "German": "de"
}
# Dropdown choices: auto-detect plus the explicit languages above.
LANGUAGE_CHOICES = ["Auto-Detect"] + list(LANGUAGE_MAP.keys())
def get_model_pipeline(model_name, pipeline_type, progress):
    """
    Initializes and caches an ASR or Summarization pipeline.

    Args:
        model_name: For "asr", a key of MODEL_SIZES; for "summarization",
            a Hugging Face model id (e.g. "t5-small").
        pipeline_type: Either "asr" or "summarization".
        progress: Gradio progress callback used to report loading status.

    Returns:
        The cached transformers pipeline for the resolved model id.

    Raises:
        ValueError: If pipeline_type is unknown, or model_name is not a
            valid MODEL_SIZES key for an ASR request.
    """
    if pipeline_type not in ("asr", "summarization"):
        # Previously an unknown type fell through both branches and raised
        # NameError on the unbound local `pipe`; fail fast with a clear error.
        raise ValueError(f"Unknown pipeline_type: {pipeline_type!r}")
    cache = model_cache if pipeline_type == "asr" else summarizer_cache
    model_id = MODEL_SIZES.get(model_name) if pipeline_type == "asr" else model_name
    if model_id is None:
        # MODEL_SIZES.get() silently returned None for a bad label before.
        raise ValueError(f"Unknown ASR model size: {model_name!r}")
    if model_id not in cache:
        # Progress windows differ: ASR loads at the start of a run (0.0-0.50),
        # the summarizer only near the end (0.90-0.95).
        progress_start = 0.0 if pipeline_type == "asr" else 0.90
        progress_end = 0.50 if pipeline_type == "asr" else 0.95
        desc = f"⏳ Loading {model_name} model..." if pipeline_type == "asr" else "🧠 Loading Summarization Model..."
        progress(progress_start, desc="πŸš€ Initializing ZeroGPU instance..." if pipeline_type == "asr" else desc)
        # transformers accepts a CUDA device index or the string "cpu".
        device = 0 if torch.cuda.is_available() else "cpu"
        if pipeline_type == "asr":
            pipe = pipeline(
                "automatic-speech-recognition",
                model=model_id,
                device=device,
                max_new_tokens=128
            )
        else:
            pipe = pipeline(
                "summarization",
                model=model_id,
                device=device
            )
        cache[model_id] = pipe
        progress(progress_end, desc="βœ… Model loaded successfully!" if pipeline_type == "asr" else "βœ… Summarization Model loaded!")
    return cache[model_id]
def format_seconds(seconds):
    """Render a duration given in seconds as an H:MM:SS string."""
    whole_seconds = int(seconds)
    return str(datetime.timedelta(seconds=whole_seconds))
def _vtt_timestamp(ms):
    """Format a millisecond offset as a WebVTT HH:MM:SS.mmm timestamp."""
    hours, remainder = divmod(int(ms), 3600000)
    minutes, remainder = divmod(remainder, 60000)
    seconds, milliseconds = divmod(remainder, 1000)
    return f"{hours:02}:{minutes:02}:{seconds:02}.{milliseconds:03}"

def create_vtt(segments, file_path):
    """Creates a WebVTT (.vtt) file from transcription segments.

    Each segment may carry explicit 'start'/'end' seconds, or a
    'timestamp' (start, end) tuple as emitted by the transformers ASR
    pipeline. Missing or None bounds fall back to 0 — the pipeline's
    final word chunk can legitimately have a None end timestamp.
    """
    with open(file_path, "w", encoding="utf-8") as f:
        f.write("WEBVTT\n\n")
        for i, segment in enumerate(segments):
            start_s = segment.get('start')
            end_s = segment.get('end')
            if start_s is None or end_s is None:
                # Fall back to the pipeline-native (start, end) tuple; the
                # original code only looked at 'start'/'end' and crashed on
                # None values (None * 1000 -> TypeError).
                ts = segment.get('timestamp') or (None, None)
                start_s = ts[0] if start_s is None else start_s
                end_s = ts[1] if end_s is None else end_s
            start_ms = int((start_s or 0) * 1000)
            end_ms = int((end_s or 0) * 1000)
            # Cue identifier, cue timing line, then the cue payload text.
            f.write(f"{i+1}\n")
            f.write(f"{_vtt_timestamp(start_ms)} --> {_vtt_timestamp(end_ms)}\n")
            f.write(f"{segment.get('text', '').strip()}\n\n")
def create_docx(segments, file_path, with_timestamps):
    """Write transcription segments to a DOCX (.docx) document.

    When with_timestamps is True each segment becomes its own
    "[start - end] text" paragraph; otherwise all segment texts are
    joined into a single paragraph.
    """
    doc = Document()
    doc.add_heading("Transcription", 0)
    if not with_timestamps:
        combined = " ".join(seg.get('text', '').strip() for seg in segments)
        doc.add_paragraph(combined)
    else:
        for seg in segments:
            begin = format_seconds(seg.get('start', 0))
            finish = format_seconds(seg.get('end', 0))
            body = seg.get('text', '').strip()
            doc.add_paragraph(f"[{begin} - {finish}] {body}")
    doc.save(file_path)
def analyze_audio_and_get_chunks(audio_file):
    """Reads the audio file and generates chunk options for the dropdown."""
    if audio_file is None:
        return gr.Dropdown(choices=["Full Audio"], value="Full Audio", interactive=False), "Please upload an audio file first."
    try:
        loaded = AudioSegment.from_file(audio_file)
        duration_ms = len(loaded)
        # Ceiling division: any partial trailing chunk still counts.
        num_chunks = (duration_ms + CHUNK_LENGTH_MS - 1) // CHUNK_LENGTH_MS
        options = ["Full Audio"]
        for idx in range(num_chunks):
            begin_ms = idx * CHUNK_LENGTH_MS
            finish_ms = min(begin_ms + CHUNK_LENGTH_MS, duration_ms)
            begin_label = format_seconds(begin_ms / 1000).split('.')[0]
            finish_label = format_seconds(finish_ms / 1000).split('.')[0]
            options.append(f"Chunk {idx+1} ({begin_label} - {finish_label})")
        status = f"Audio analyzed. Duration: {format_seconds(duration_ms/1000.0)}. Found {num_chunks} chunks."
        # More than 6 chunks means more than 30 minutes of audio.
        if num_chunks > 6:
            status += " ⚠️ **Recommendation:** Select a single chunk to process to avoid GPU memory crash."
        return gr.Dropdown(choices=options, value="Full Audio", interactive=True), status
    except Exception as e:
        error_msg = f"Error analyzing audio: {e}"
        return gr.Dropdown(choices=["Full Audio"], value="Full Audio", interactive=False), error_msg
# --- Summarization: prompts the model for the target output language ---
def generate_summary(text, target_language_code, progress):
    """Generates an abstractive summary using a pre-trained T5 model, prompting for the target language.

    Args:
        text: The transcription text to summarize.
        target_language_code: Whisper language code ("fr", "es", ...) or
            None when the language was auto-detected.
        progress: Gradio progress callback, forwarded to model loading.

    Returns:
        The summary string, or an "Error during summarization: ..." message
        on failure (callers display it directly, so no exception escapes).
    """
    try:
        summarizer = get_model_pipeline("t5-small", "summarization", progress)
        # T5-Small is multilingual but often defaults to English, so the task
        # prefix is localized to steer the output language.
        # NOTE(review): t5-small's documented task prefix is the English
        # "summarize:"; the localized prefixes are best-effort — verify quality.
        if target_language_code == "fr":
            # Fixed mojibake in the original prompt literal ("rΓ©sumer").
            prompt = f"résumer: {text}"
        elif target_language_code == "es":
            prompt = f"resumir: {text}"
        else:
            # Default English prompt (also used for auto-detect).
            prompt = f"summarize: {text}"
        summary = summarizer(
            prompt,
            max_length=SUMMARY_MAX_LENGTH,
            min_length=SUMMARY_MIN_LENGTH,
            do_sample=False,
            # Long transcripts exceed the model's max input length; without
            # truncation the tokenizer/model can error out on big inputs.
            truncation=True
        )[0]['summary_text']
        return summary
    except Exception as e:
        return f"Error during summarization: {e}"
@spaces.GPU
def transcribe_and_export(audio_file, model_size, chunk_choice, selected_language, vtt_output, docx_timestamp_output, docx_no_timestamp_output, summarize_output, progress=gr.Progress()):
    """
    Main function to transcribe audio and export. Uses selected_language to force
    the transcription language, fixing the French issue.

    Args:
        audio_file: Filepath of the uploaded/recorded audio, or None.
        model_size: Key of MODEL_SIZES selecting the Whisper checkpoint.
        chunk_choice: "Full Audio" or a "Chunk N (...)" dropdown option.
        selected_language: "Auto-Detect" or a key of LANGUAGE_MAP.
        vtt_output: Whether to write a .vtt subtitle file.
        docx_timestamp_output: Whether to write a timestamped .docx.
        docx_no_timestamp_output: Whether to write a plain-text .docx.
        summarize_output: Whether to also produce an abstractive summary.
        progress: Gradio progress callback.

    Returns:
        A 5-tuple: (transcribed text, summary text, gr.Files of downloads,
        gr.Audio reset value, status message). Error paths return the same
        shape with the message in the last slot.
    """
    if audio_file is None:
        return (None, "", None, gr.Audio(value=None), "Please upload an audio file.")
    start_time = time.time()
    pipe = get_model_pipeline(model_size, "asr", progress)
    # 1. Determine which segment to process
    audio_segment_to_process = audio_file
    offset = 0.0  # seconds to add back onto chunk-relative timestamps
    if chunk_choice != "Full Audio":
        progress(0.70, desc="βœ‚οΈ Preparing audio segment...")
        try:
            # Labels look like "Chunk 3 (...)"; recover the 0-based index.
            chunk_num = int(chunk_choice.split(' ')[1]) - 1
            full_audio = AudioSegment.from_file(audio_file)
            total_duration_ms = len(full_audio)
            start_ms = chunk_num * CHUNK_LENGTH_MS
            end_ms = min((chunk_num + 1) * CHUNK_LENGTH_MS, total_duration_ms)
            chunk = full_audio[start_ms:end_ms]
            # NOTE(review): fixed temp path — concurrent requests on the same
            # instance would clobber each other's chunk file; verify isolation.
            temp_chunk_path = "/tmp/selected_chunk.mp3"
            chunk.export(temp_chunk_path, format="mp3")
            audio_segment_to_process = temp_chunk_path
            offset = start_ms / 1000.0
        except Exception as e:
            return (None, "", None, gr.Audio(value=None), f"Error preparing audio chunk: {e}")
    # 2. Define generation arguments (Language fix implemented here)
    generate_kwargs = {}
    lang_code = None
    if selected_language != "Auto-Detect":
        lang_code = LANGUAGE_MAP.get(selected_language, None)
    if lang_code:
        # Crucial for French fix: Pass the language code to Whisper
        generate_kwargs["language"] = lang_code
    # 3. Transcribe the segment (word-level timestamps for VTT/DOCX export)
    progress(0.75, desc=f"🎀 Transcribing {chunk_choice}...")
    raw_output = pipe(
        audio_segment_to_process,
        return_timestamps="word",
        # Pass the refined generate_kwargs
        generate_kwargs=generate_kwargs
    )
    # 4. Process and adjust segments
    full_segments = raw_output.get("chunks", [])
    transcribed_text = raw_output.get('text', '').strip()
    # NOTE(review): pipeline chunks appear to carry a 'timestamp' (start, end)
    # tuple rather than 'start'/'end' keys; on the "Full Audio" path those keys
    # are never set, so exporters may see default (0) times — verify.
    if chunk_choice != "Full Audio":
        for segment in full_segments:
            # Shift chunk-relative times back to absolute file positions.
            segment['start'] = segment.get('start', 0.0) + offset
            segment['end'] = segment.get('end', 0.0) + offset
        # Clean up the temporary chunk file (only ever set on this branch).
        if os.path.exists(audio_segment_to_process):
            os.remove(audio_segment_to_process)
    # 5. Generate Summary (if requested)
    summary_text = ""
    if summarize_output and transcribed_text:
        # Pass the language code to the summary function for explicit prompting
        summary_text = generate_summary(transcribed_text, lang_code, progress)
    elif summarize_output and not transcribed_text:
        summary_text = "Transcription failed or was empty, cannot generate summary."
    # 6. Generate output files (paths are relative to the working directory)
    outputs = {}
    progress(0.95, desc="πŸ“ Generating output files...")
    if vtt_output:
        vtt_path = "transcription.vtt"
        create_vtt(full_segments, vtt_path)
        outputs["VTT"] = vtt_path
    if docx_timestamp_output:
        docx_ts_path = "transcription_with_timestamps.docx"
        create_docx(full_segments, docx_ts_path, with_timestamps=True)
        outputs["DOCX (with timestamps)"] = docx_ts_path
    if docx_no_timestamp_output:
        docx_no_ts_path = "transcription_without_timestamps.docx"
        create_docx(full_segments, docx_no_ts_path, with_timestamps=False)
        outputs["DOCX (without timestamps)"] = docx_no_ts_path
    end_time = time.time()
    total_time = end_time - start_time
    downloadable_files = [path for path in outputs.values()]
    status_message = f"βœ… Transcription complete! Total time: {total_time:.2f} seconds."
    return (
        transcribed_text,
        summary_text,
        gr.Files(value=downloadable_files, label="Download Transcripts"),
        gr.Audio(value=None),  # clear the audio input after a successful run
        status_message
    )
# --- Gradio UI ---
# Layout: audio input + model/language selectors, chunk analyzer, output
# checkboxes, then transcription/summary/download panels.
with gr.Blocks(title="Whisper ZeroGPU Transcription & Summarization") as demo:
    gr.Markdown("# πŸŽ™οΈ Whisper ZeroGPU Transcription & Summarization")
    gr.Markdown("1. **Upload** audio. 2. **Analyze** for chunks. 3. Select **Model**, **Chunk**, and **Language**, then **Transcribe**.")
    # Crucial warning for large files: full-audio runs can exhaust GPU memory.
    gr.Markdown(
        """
        ⚠️ **GPU Memory Warning:** For files longer than **30 minutes** (approx. 6 chunks),
        it's highly recommended to select a single **Chunk** to process instead of **'Full Audio'** to prevent a GPU memory crash on the platform.
        """
    )
    with gr.Row():
        audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Audio File")
        with gr.Column(scale=2):
            model_selector = gr.Dropdown(
                label="Choose Whisper Model Size",
                choices=list(MODEL_SIZES.keys()),
                value="Distil-Large-v3 (General Purpose)"
            )
            # LANGUAGE FIX: Selector to explicitly set the expected language
            language_selector = gr.Dropdown(
                label="Select Expected Language (Crucial for French/Non-English)",
                choices=LANGUAGE_CHOICES,
                value="French",  # Default to French
                interactive=True
            )
    analyze_btn = gr.Button("Analyze Audio πŸ”Ž", variant="secondary")
    # Populated by analyze_audio_and_get_chunks; disabled until analysis runs.
    chunk_selector = gr.Dropdown(
        label="Select Audio Segment (5-minute chunks)",
        choices=["Full Audio"],
        value="Full Audio",
        interactive=False
    )
    gr.Markdown("### Output Options")
    with gr.Row():
        summarize_checkbox = gr.Checkbox(label="Generate Summary", value=False)
        vtt_checkbox = gr.Checkbox(label="VTT", value=True)
    with gr.Row():
        docx_ts_checkbox = gr.Checkbox(label="DOCX (with timestamps)", value=False)
        docx_no_ts_checkbox = gr.Checkbox(label="DOCX (without timestamps)", value=True)
    transcribe_btn = gr.Button("Transcribe", variant="primary")
    status_text = gr.Textbox(label="Status", interactive=False)
    transcription_output = gr.Textbox(label="Full Transcription", lines=10)
    summary_output = gr.Textbox(label="Summary (Abstractive)", lines=3)
    downloadable_files_output = gr.Files(label="Download Transcripts")
    # Wire events: Analyze populates the chunk dropdown; Transcribe runs the job.
    analyze_btn.click(
        fn=analyze_audio_and_get_chunks,
        inputs=[audio_input],
        outputs=[chunk_selector, status_text]
    )
    transcribe_btn.click(
        fn=transcribe_and_export,
        inputs=[audio_input, model_selector, chunk_selector, language_selector, vtt_checkbox, docx_ts_checkbox, docx_no_ts_checkbox, summarize_checkbox],
        outputs=[transcription_output, summary_output, downloadable_files_output, audio_input, status_text]
    )

# Launch the app when run as a script (HF Spaces executes this module directly).
if __name__ == "__main__":
    demo.launch()