File size: 14,271 Bytes
2b96b70
ded0018
 
 
 
 
 
475571f
6e566c7
715fdfb
13d1791
ded0018
 
 
 
 
3351f3d
46499c0
 
ded0018
 
13d1791
 
ded0018
7b256c2
 
13d1791
 
 
7b256c2
13d1791
 
 
 
 
 
 
 
 
 
4091834
13d1791
4091834
13d1791
 
 
 
 
 
 
 
 
 
ded0018
 
13d1791
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46499c0
7b256c2
13d1791
7b256c2
 
ded0018
13d1791
ded0018
46499c0
ded0018
4091834
 
 
 
 
 
 
 
 
 
 
 
ded0018
 
 
 
 
13d1791
ded0018
 
 
 
 
 
7b256c2
 
ded0018
 
 
 
4091834
ded0018
 
7b256c2
13d1791
ded0018
7b256c2
6e566c7
7b256c2
6e566c7
 
7b256c2
6e566c7
7b256c2
6e566c7
7b256c2
 
6e566c7
7b256c2
 
6e566c7
7b256c2
 
6e566c7
7b256c2
 
 
 
3351f3d
 
13d1791
 
7b256c2
 
 
 
 
 
3351f3d
a85e37e
 
13d1791
 
 
3351f3d
 
a85e37e
3351f3d
a85e37e
 
 
 
3351f3d
a85e37e
 
13d1791
a85e37e
13d1791
 
 
 
 
 
 
 
3351f3d
7b256c2
 
13d1791
7b256c2
3351f3d
 
7b256c2
 
13d1791
7b256c2
 
13d1791
7b256c2
 
 
13d1791
7b256c2
 
 
 
 
 
 
6e566c7
7b256c2
 
6e566c7
7b256c2
 
 
6e566c7
7b256c2
13d1791
6e566c7
7b256c2
13d1791
7b256c2
13d1791
7b256c2
13d1791
a85e37e
13d1791
 
 
3351f3d
a85e37e
 
7b256c2
 
 
 
 
3351f3d
7b256c2
 
 
 
 
 
4091834
7b256c2
 
 
 
 
 
 
4091834
13d1791
 
 
a602b66
a85e37e
13d1791
 
 
 
ded0018
13d1791
ded0018
 
 
6e566c7
ded0018
 
 
 
6e566c7
ded0018
 
 
 
6e566c7
ded0018
 
 
 
 
 
 
 
 
13d1791
ded0018
7b256c2
ded0018
 
46499c0
ded0018
13d1791
 
 
 
3351f3d
13d1791
 
 
 
 
 
ded0018
46499c0
ded0018
 
 
 
 
 
13d1791
 
 
3351f3d
13d1791
3351f3d
13d1791
3351f3d
13d1791
ded0018
6e566c7
7b256c2
 
 
 
 
 
13d1791
7b256c2
 
13d1791
ded0018
13d1791
ded0018
13d1791
 
ded0018
 
 
 
 
 
 
13d1791
ded0018
 
7b256c2
 
 
a85e37e
7b256c2
 
ded0018
 
13d1791
 
715fdfb
32e69c2
39d879d
ded0018
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
import gradio as gr
import spaces
import torch
import os
import datetime
import time
from transformers import pipeline
from docx import Document
from pydub import AudioSegment

# Available ASR models: display label -> Hugging Face model id.
MODEL_SIZES = {
    "Tiny (Fastest)": "openai/whisper-tiny",
    "Base (Faster)": "openai/whisper-base",
    "Small (Balanced)": "openai/whisper-small",
    "Distil-Large-v3 (General Purpose)": "distil-whisper/distil-large-v3",
    # NOTE(review): this maps to the SAME checkpoint as the general-purpose
    # entry above, so the "French-Specific" choice currently has no effect —
    # likely meant to point at a French-finetuned distil checkpoint; confirm
    # the intended model id.
    "Distil-Large-v3-FR (French-Specific)": "distil-whisper/distil-large-v3" 
}

# Cache of loaded ASR pipelines, keyed by model id (avoids reloading per call).
model_cache = {}
# Separate cache for the summarization pipeline, same keying scheme.
summarizer_cache = {}

# Fixed chunk length used for splitting long audio (5 minutes in milliseconds).
CHUNK_LENGTH_MS = 5 * 60 * 1000
# Token-length bounds passed to the summarization pipeline.
SUMMARY_MIN_LENGTH = 30
SUMMARY_MAX_LENGTH = 150

# UI language label -> Whisper language code (forced via generate_kwargs).
LANGUAGE_MAP = {
    "French": "fr",
    "English": "en",
    "Spanish": "es",
    "German": "de"
}
# Dropdown choices: "Auto-Detect" (no forced language) plus the labels above.
LANGUAGE_CHOICES = ["Auto-Detect"] + list(LANGUAGE_MAP.keys())

def get_model_pipeline(model_name, pipeline_type, progress):
    """
    Initialize and cache an ASR or Summarization pipeline.

    Args:
        model_name: For "asr", a display key of MODEL_SIZES; for
            "summarization", a Hugging Face model id (e.g. "t5-small").
        pipeline_type: Either "asr" or "summarization".
        progress: Gradio progress callback used to report loading status.

    Returns:
        The cached (or newly created) transformers pipeline.

    Raises:
        ValueError: If pipeline_type is unknown, or if an ASR model_name is
            not a key of MODEL_SIZES.
    """
    # Fail fast on bad arguments instead of hitting a NameError on `pipe`
    # (unknown pipeline_type) or passing model=None into pipeline() below.
    if pipeline_type not in ("asr", "summarization"):
        raise ValueError(f"Unknown pipeline_type: {pipeline_type!r}")

    cache = model_cache if pipeline_type == "asr" else summarizer_cache
    model_id = MODEL_SIZES.get(model_name) if pipeline_type == "asr" else model_name
    if model_id is None:
        raise ValueError(f"Unknown ASR model: {model_name!r}")
    
    if model_id not in cache:
        # Progress milestones differ: ASR loading dominates the run (0.0-0.5),
        # the summarizer loads late in the pipeline (0.90-0.95).
        progress_start = 0.0 if pipeline_type == "asr" else 0.90
        progress_end = 0.50 if pipeline_type == "asr" else 0.95
        desc = f"⏳ Loading {model_name} model..." if pipeline_type == "asr" else "🧠 Loading Summarization Model..."

        progress(progress_start, desc="🚀 Initializing ZeroGPU instance..." if pipeline_type == "asr" else desc)
        
        # transformers convention: device=0 selects the first CUDA GPU.
        device = 0 if torch.cuda.is_available() else "cpu"
        
        if pipeline_type == "asr":
            pipe = pipeline(
                "automatic-speech-recognition",
                model=model_id,
                device=device,
                max_new_tokens=128
            )
        else:  # pipeline_type == "summarization"
            pipe = pipeline(
                "summarization",
                model=model_id,
                device=device
            )
        
        cache[model_id] = pipe
        progress(progress_end, desc="✅ Model loaded successfully!" if pipeline_type == "asr" else "✅ Summarization Model loaded!")
        
    return cache[model_id]

def format_seconds(seconds):
    """Render a duration in seconds as an H:MM:SS string (str of timedelta)."""
    whole_seconds = int(seconds)
    return str(datetime.timedelta(seconds=whole_seconds))

def create_vtt(segments, file_path):
    """Create a WebVTT (.vtt) subtitle file from transcription segments.

    Args:
        segments: Iterable of dicts with optional 'start'/'end' (seconds,
            float) and 'text' keys, as produced by the ASR pipeline.
        file_path: Destination path for the .vtt file.
    """
    # Hoisted out of the loop: the original redefined this helper on every
    # iteration for no benefit.
    def format_time(ms):
        """Render a millisecond offset as HH:MM:SS.mmm (WebVTT timestamp)."""
        hours, remainder = divmod(ms, 3600000)
        minutes, remainder = divmod(remainder, 60000)
        seconds, milliseconds = divmod(remainder, 1000)
        return f"{int(hours):02}:{int(minutes):02}:{int(seconds):02}.{int(milliseconds):03}"

    with open(file_path, "w", encoding="utf-8") as f:
        f.write("WEBVTT\n\n")
        for i, segment in enumerate(segments):
            start = format_time(int(segment.get('start', 0) * 1000))
            end = format_time(int(segment.get('end', 0) * 1000))

            # Cue identifier, timing line, then the cue text + blank separator.
            f.write(f"{i+1}\n")
            f.write(f"{start} --> {end}\n")
            f.write(f"{segment.get('text', '').strip()}\n\n")

def create_docx(segments, file_path, with_timestamps):
    """Write transcription segments to a Word (.docx) document.

    With `with_timestamps` true, each segment becomes its own paragraph
    prefixed by its "[start - end]" range; otherwise all segment texts are
    joined into one paragraph.
    """
    doc = Document()
    doc.add_heading("Transcription", 0)

    if not with_timestamps:
        doc.add_paragraph(" ".join(seg.get('text', '').strip() for seg in segments))
    else:
        for seg in segments:
            window = f"[{format_seconds(seg.get('start', 0))} - {format_seconds(seg.get('end', 0))}]"
            doc.add_paragraph(f"{window} {seg.get('text', '').strip()}")

    doc.save(file_path)

def analyze_audio_and_get_chunks(audio_file):
    """Inspect the uploaded audio and build chunk-selector options.

    Returns a (Dropdown update, status message) pair. On success the
    dropdown lists "Full Audio" plus one entry per 5-minute chunk; with no
    file or on error it is reset to a disabled "Full Audio"-only state.
    """
    if audio_file is None:
        return gr.Dropdown(choices=["Full Audio"], value="Full Audio", interactive=False), "Please upload an audio file first."

    try:
        total_duration_ms = len(AudioSegment.from_file(audio_file))
        # Ceiling division: count of 5-minute chunks covering the whole file.
        num_chunks = (total_duration_ms + CHUNK_LENGTH_MS - 1) // CHUNK_LENGTH_MS

        chunk_options = ["Full Audio"]
        for idx in range(num_chunks):
            begin_ms = idx * CHUNK_LENGTH_MS
            finish_ms = min(begin_ms + CHUNK_LENGTH_MS, total_duration_ms)
            # Drop any fractional-second suffix from the H:MM:SS labels.
            begin_label = format_seconds(begin_ms / 1000).split('.')[0]
            finish_label = format_seconds(finish_ms / 1000).split('.')[0]
            chunk_options.append(f"Chunk {idx+1} ({begin_label} - {finish_label})")

        status = f"Audio analyzed. Duration: {format_seconds(total_duration_ms/1000.0)}. Found {num_chunks} chunks."
        # Warn the user when the file is long enough to risk a GPU memory crash.
        if num_chunks > 6: # More than 30 minutes
             status += " ⚠️ **Recommendation:** Select a single chunk to process to avoid GPU memory crash."

        return gr.Dropdown(choices=chunk_options, value="Full Audio", interactive=True), status

    except Exception as e:
        return gr.Dropdown(choices=["Full Audio"], value="Full Audio", interactive=False), f"Error analyzing audio: {e}"

# --- MODIFIED: generate_summary to force output language ---
def generate_summary(text, target_language_code, progress):
    """Generate an abstractive summary of `text` with the t5-small model.

    T5 is steered with a task prefix; a language-specific prefix is used for
    French ("fr") and Spanish ("es") to nudge the output toward that
    language. Any other code (including None from auto-detect) falls back to
    the English "summarize:" prompt. Returns the summary string, or an error
    message string on failure.
    """
    try:
        summarizer = get_model_pipeline("t5-small", "summarization", progress)

        # Task prefixes per target language; default is the English prompt.
        prefixes = {"fr": "résumer", "es": "resumir"}
        task = prefixes.get(target_language_code, "summarize")
        prompt = f"{task}: {text}"

        result = summarizer(
            prompt,
            max_length=SUMMARY_MAX_LENGTH,
            min_length=SUMMARY_MIN_LENGTH,
            do_sample=False
        )
        return result[0]['summary_text']
    except Exception as e:
        return f"Error during summarization: {e}"
# -----------------------------------------------------------

@spaces.GPU
def transcribe_and_export(audio_file, model_size, chunk_choice, selected_language, vtt_output, docx_timestamp_output, docx_no_timestamp_output, summarize_output, progress=gr.Progress()):
    """
    Transcribe the uploaded audio (or one selected 5-minute chunk of it),
    optionally summarize, and export the requested transcript files.

    The selected_language forces Whisper's output language via
    generate_kwargs, fixing auto-detect failures on French audio.

    Returns the 5-tuple expected by the Gradio outputs:
        (transcribed text, summary text, gr.Files of downloads,
         cleared audio widget, status message).
    """
    if audio_file is None:
        return (None, "", None, gr.Audio(value=None), "Please upload an audio file.")
    
    start_time = time.time()
    pipe = get_model_pipeline(model_size, "asr", progress)
    
    # 1. Determine which segment to process (full file by default).
    audio_segment_to_process = audio_file
    # Seconds to add back to segment timestamps when only a chunk is transcribed.
    offset = 0.0
    
    if chunk_choice != "Full Audio":
        progress(0.70, desc="✂️ Preparing audio segment...")
        try:
            # Chunk labels look like "Chunk N (start - end)"; parse N (1-based).
            chunk_num = int(chunk_choice.split(' ')[1]) - 1
            full_audio = AudioSegment.from_file(audio_file)
            total_duration_ms = len(full_audio)
            
            start_ms = chunk_num * CHUNK_LENGTH_MS
            end_ms = min((chunk_num + 1) * CHUNK_LENGTH_MS, total_duration_ms)
            
            # Export the slice to a temp file so the ASR pipeline can read it.
            chunk = full_audio[start_ms:end_ms]
            temp_chunk_path = "/tmp/selected_chunk.mp3"
            chunk.export(temp_chunk_path, format="mp3")
            
            audio_segment_to_process = temp_chunk_path
            offset = start_ms / 1000.0

        except Exception as e:
            return (None, "", None, gr.Audio(value=None), f"Error preparing audio chunk: {e}")

    # 2. Define generation arguments (language forcing implemented here).
    generate_kwargs = {}
    
    lang_code = None
    if selected_language != "Auto-Detect":
        lang_code = LANGUAGE_MAP.get(selected_language, None)
        if lang_code:
            # Crucial for the French fix: pass the language code to Whisper
            # so it does not mis-detect the spoken language.
            generate_kwargs["language"] = lang_code
            
    # 3. Transcribe the segment (word-level timestamps feed the VTT/DOCX export).
    progress(0.75, desc=f"🎤 Transcribing {chunk_choice}...")
    raw_output = pipe(
        audio_segment_to_process, 
        return_timestamps="word",
        # Pass the refined generate_kwargs
        generate_kwargs=generate_kwargs
    )

    # 4. Process segments; shift chunk-relative timestamps back to file time.
    full_segments = raw_output.get("chunks", [])
    transcribed_text = raw_output.get('text', '').strip()
    
    if chunk_choice != "Full Audio":
        for segment in full_segments:
            segment['start'] = segment.get('start', 0.0) + offset
            segment['end'] = segment.get('end', 0.0) + offset
        
        # Clean up the temporary chunk file exported in step 1.
        if os.path.exists(audio_segment_to_process):
            os.remove(audio_segment_to_process)

    # 5. Generate summary (if requested and there is text to summarize).
    summary_text = ""
    if summarize_output and transcribed_text:
        # Pass the language code so the summarizer prompts in that language.
        summary_text = generate_summary(transcribed_text, lang_code, progress)
    elif summarize_output and not transcribed_text:
        summary_text = "Transcription failed or was empty, cannot generate summary."

    # 6. Generate the requested output files in the working directory.
    outputs = {}
    progress(0.95, desc="📝 Generating output files...")
    
    if vtt_output:
        vtt_path = "transcription.vtt"
        create_vtt(full_segments, vtt_path)
        outputs["VTT"] = vtt_path

    if docx_timestamp_output:
        docx_ts_path = "transcription_with_timestamps.docx"
        create_docx(full_segments, docx_ts_path, with_timestamps=True)
        outputs["DOCX (with timestamps)"] = docx_ts_path
    
    if docx_no_timestamp_output:
        docx_no_ts_path = "transcription_without_timestamps.docx"
        create_docx(full_segments, docx_no_ts_path, with_timestamps=False)
        outputs["DOCX (without timestamps)"] = docx_no_ts_path

    end_time = time.time()
    total_time = end_time - start_time
    downloadable_files = [path for path in outputs.values()]
    status_message = f"✅ Transcription complete! Total time: {total_time:.2f} seconds."

    # gr.Audio(value=None) clears the input widget after a successful run.
    return (
        transcribed_text, 
        summary_text,
        gr.Files(value=downloadable_files, label="Download Transcripts"), 
        gr.Audio(value=None), 
        status_message
    )

# --- Gradio UI ---
# Declarative Blocks layout; component creation order defines the page layout.
with gr.Blocks(title="Whisper ZeroGPU Transcription & Summarization") as demo:
    gr.Markdown("# 🎙️ Whisper ZeroGPU Transcription & Summarization")
    gr.Markdown("1. **Upload** audio. 2. **Analyze** for chunks. 3. Select **Model**, **Chunk**, and **Language**, then **Transcribe**.")
    
    # Crucial warning for large files (full-audio runs can exhaust GPU memory).
    gr.Markdown(
        """
        ⚠️ **GPU Memory Warning:** For files longer than **30 minutes** (approx. 6 chunks), 
        it's highly recommended to select a single **Chunk** to process instead of **'Full Audio'** to prevent a GPU memory crash on the platform.
        """
    )
    
    with gr.Row():
        audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Audio File")
        
        with gr.Column(scale=2):
            model_selector = gr.Dropdown(
                label="Choose Whisper Model Size",
                choices=list(MODEL_SIZES.keys()),
                value="Distil-Large-v3 (General Purpose)"
            )
            
            # Language fix: selector to explicitly set the expected language,
            # forwarded to Whisper via generate_kwargs in transcribe_and_export.
            language_selector = gr.Dropdown(
                label="Select Expected Language (Crucial for French/Non-English)",
                choices=LANGUAGE_CHOICES,
                value="French", # Default to French
                interactive=True
            )
            
            analyze_btn = gr.Button("Analyze Audio 🔎", variant="secondary")
            
            # Disabled until analyze_audio_and_get_chunks populates it.
            chunk_selector = gr.Dropdown(
                label="Select Audio Segment (5-minute chunks)",
                choices=["Full Audio"],
                value="Full Audio",
                interactive=False 
            )

            gr.Markdown("### Output Options")
            with gr.Row():
                summarize_checkbox = gr.Checkbox(label="Generate Summary", value=False)
                vtt_checkbox = gr.Checkbox(label="VTT", value=True)
            
            with gr.Row():
                docx_ts_checkbox = gr.Checkbox(label="DOCX (with timestamps)", value=False)
                docx_no_ts_checkbox = gr.Checkbox(label="DOCX (without timestamps)", value=True)
            
            transcribe_btn = gr.Button("Transcribe", variant="primary")
            status_text = gr.Textbox(label="Status", interactive=False)

    transcription_output = gr.Textbox(label="Full Transcription", lines=10)
    summary_output = gr.Textbox(label="Summary (Abstractive)", lines=3)
    downloadable_files_output = gr.Files(label="Download Transcripts")
    
    analyze_btn.click(
        fn=analyze_audio_and_get_chunks,
        inputs=[audio_input],
        outputs=[chunk_selector, status_text]
    )
    
    # Note: audio_input is also listed as an output so the handler can clear
    # the widget (gr.Audio(value=None)) after transcription.
    transcribe_btn.click(
        fn=transcribe_and_export,
        inputs=[audio_input, model_selector, chunk_selector, language_selector, vtt_checkbox, docx_ts_checkbox, docx_no_ts_checkbox, summarize_checkbox],
        outputs=[transcription_output, summary_output, downloadable_files_output, audio_input, status_text]
    )

if __name__ == "__main__":
    demo.launch()