clementBE committed on
Commit
7b256c2
·
verified ·
1 Parent(s): 6e566c7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +117 -82
app.py CHANGED
@@ -14,14 +14,15 @@ MODEL_SIZES = {
14
  "Base (Faster)": "openai/whisper-base",
15
  "Small (Balanced)": "openai/whisper-small",
16
  "Distil-Large-v3 (General Purpose)": "distil-whisper/distil-large-v3",
17
- # FIX: The model 'distil-whisper/distil-large-v3-fr' does not exist.
18
- # We use the general distil-large-v3 and rely on the code below to force French.
19
  "Distil-Large-v3-FR (French-Specific)": "distil-whisper/distil-large-v3"
20
  }
21
 
22
  # Use a dictionary to cache loaded models
23
  model_cache = {}
24
 
 
 
 
25
  def get_model_pipeline(model_name, progress):
26
  """
27
  Initializes and caches the ASR pipeline for a given model name.
@@ -29,7 +30,6 @@ def get_model_pipeline(model_name, progress):
29
  if model_name not in model_cache:
30
  progress(0, desc="πŸš€ Initializing ZeroGPU instance...")
31
  model_id = MODEL_SIZES[model_name]
32
- # Use GPU if available, otherwise fallback to CPU
33
  device = 0 if torch.cuda.is_available() else "cpu"
34
 
35
  progress(0.1, desc=f"⏳ Loading {model_name} model...")
@@ -37,12 +37,15 @@ def get_model_pipeline(model_name, progress):
37
  "automatic-speech-recognition",
38
  model=model_id,
39
  device=device,
40
- # Set max_new_tokens for generation, common for ASR
41
  max_new_tokens=128
42
  )
43
  progress(0.5, desc="βœ… Model loaded successfully!")
44
  return model_cache[model_name]
45
 
 
 
 
 
46
  def create_vtt(segments, file_path):
47
  """
48
  Creates a WebVTT (.vtt) file from transcription segments.
@@ -50,7 +53,6 @@ def create_vtt(segments, file_path):
50
  with open(file_path, "w", encoding="utf-8") as f:
51
  f.write("WEBVTT\n\n")
52
  for i, segment in enumerate(segments):
53
- # Calculate time strings in "HH:MM:SS.mmm" format
54
  start_ms = int(segment.get('start', 0) * 1000)
55
  end_ms = int(segment.get('end', 0) * 1000)
56
 
@@ -77,9 +79,8 @@ def create_docx(segments, file_path, with_timestamps):
77
  if with_timestamps:
78
  for segment in segments:
79
  text = segment.get('text', '').strip()
80
- # Format time as HH:MM:SS for DOCX
81
- start = str(datetime.timedelta(seconds=int(segment.get('start', 0))))
82
- end = str(datetime.timedelta(seconds=int(segment.get('end', 0))))
83
  document.add_paragraph(f"[{start} - {end}] {text}")
84
  else:
85
  full_text = " ".join([segment.get('text', '').strip() for segment in segments])
@@ -87,101 +88,123 @@ def create_docx(segments, file_path, with_timestamps):
87
 
88
  document.save(file_path)
89
 
90
- @spaces.GPU
91
- def transcribe_and_export(audio_file, model_size, vtt_output, docx_timestamp_output, docx_no_timestamp_output, sequence_5_min, progress=gr.Progress()):
92
  """
93
- Main function to transcribe audio and export to selected formats.
94
- Added logic for 5-minute sequencing.
95
  """
96
  if audio_file is None:
97
- return (None, None, None, "Please upload an audio file.")
98
-
99
- start_time = time.time()
100
-
101
- pipe = get_model_pipeline(model_size, progress)
102
-
103
- # Define generation arguments
104
- generate_kwargs = {}
105
- if model_size == "Distil-Large-v3-FR (French-Specific)":
106
- # Force French for this specific option
107
- generate_kwargs["language"] = "fr"
108
 
109
- full_segments = []
110
- full_text_list = []
111
-
112
- # --- New 5-Minute Sequencing Logic ---
113
- if sequence_5_min:
114
- progress(0.70, desc="βœ‚οΈ Splitting audio into 5-minute chunks...")
115
  audio = AudioSegment.from_file(audio_file)
116
- chunk_length_ms = 5 * 60 * 1000 # 5 minutes in milliseconds
117
  total_duration_ms = len(audio)
118
- num_chunks = (total_duration_ms + chunk_length_ms - 1) // chunk_length_ms # Ceiling division
119
 
 
120
  for i in range(num_chunks):
121
- start_ms = i * chunk_length_ms
122
- end_ms = min((i + 1) * chunk_length_ms, total_duration_ms)
123
 
124
- progress_val = 0.70 + (i / num_chunks) * 0.15
125
- progress(progress_val, desc=f"🎀 Transcribing chunk {i+1}/{num_chunks}...")
126
 
127
- chunk = audio[start_ms:end_ms]
128
- temp_chunk_path = f"/tmp/chunk_{i}.mp3" # Save as a temp file for the pipeline
129
- chunk.export(temp_chunk_path, format="mp3")
130
 
131
- # Transcribe the chunk
132
- chunk_output = pipe(
133
- temp_chunk_path,
134
- return_timestamps="word",
135
- generate_kwargs=generate_kwargs
136
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
 
138
- # Adjust timestamps for the full file
139
- offset = start_ms / 1000.0
140
- chunk_segments = chunk_output.get("chunks", [])
141
- for segment in chunk_segments:
142
- segment['start'] = segment.get('start', 0.0) + offset
143
- segment['end'] = segment.get('end', 0.0) + offset
144
- full_segments.append(segment)
145
 
146
- full_text_list.append(chunk_output.get('text', ''))
 
147
 
148
- os.remove(temp_chunk_path) # Clean up temp file
 
 
 
149
 
150
- transcribed_text = " ".join(full_text_list).strip()
 
151
 
152
- else:
153
- # Standard transcription for the whole file at once
154
- progress(0.75, desc="🎀 Transcribing full audio file...")
155
- raw_output = pipe(
156
- audio_file,
157
- return_timestamps="word",
158
- generate_kwargs=generate_kwargs
159
- )
160
- full_segments = raw_output.get("chunks", [])
161
- transcribed_text = raw_output.get('text', '').strip()
 
 
 
 
 
 
 
 
 
162
 
163
- # Ensure segments is not empty
164
- if not full_segments and transcribed_text:
165
- # Create a single segment from the full text if chunks were not generated for some reason
166
- full_segments = [{'text': transcribed_text, 'start': 0.0, 'end': 0.0}]
 
 
 
 
 
 
167
 
 
168
  outputs = {}
169
-
170
  progress(0.85, desc="πŸ“ Generating output files...")
171
 
172
- # Generate VTT
173
  if vtt_output:
174
  vtt_path = "transcription.vtt"
175
  create_vtt(full_segments, vtt_path)
176
  outputs["VTT"] = vtt_path
177
 
178
- # Generate DOCX with timestamps
179
  if docx_timestamp_output:
180
  docx_ts_path = "transcription_with_timestamps.docx"
181
  create_docx(full_segments, docx_ts_path, with_timestamps=True)
182
  outputs["DOCX (with timestamps)"] = docx_ts_path
183
 
184
- # Generate DOCX without timestamps
185
  if docx_no_timestamp_output:
186
  docx_no_ts_path = "transcription_without_timestamps.docx"
187
  create_docx(full_segments, docx_no_ts_path, with_timestamps=False)
@@ -195,14 +218,14 @@ def transcribe_and_export(audio_file, model_size, vtt_output, docx_timestamp_out
195
  return (
196
  transcribed_text,
197
  gr.Files(value=downloadable_files, label="Download Transcripts"),
198
- gr.Audio(value=None), # Clear the audio input
199
  status_message
200
  )
201
 
202
  # --- Gradio UI ---
203
  with gr.Blocks(title="Whisper ZeroGPU Transcription") as demo:
204
  gr.Markdown("# πŸŽ™οΈ Whisper ZeroGPU Transcription")
205
- gr.Markdown("Transcribe audio with timestamps and choose your output format. The first run may take up to a minute due to cold start.")
206
 
207
  with gr.Row():
208
  audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Audio File")
@@ -213,13 +236,18 @@ with gr.Blocks(title="Whisper ZeroGPU Transcription") as demo:
213
  choices=list(MODEL_SIZES.keys()),
214
  value="Distil-Large-v3-FR (French-Specific)"
215
  )
216
- gr.Markdown("### Processing Options")
217
- # NEW CHECKBOX for 5-minute sequencing
218
- sequence_checkbox = gr.Checkbox(
219
- label="Process in 5-minute sequences (Recommended for files > 30 min or to prevent memory errors)",
220
- value=False
221
- )
222
 
 
 
 
 
 
 
 
 
 
 
 
223
  gr.Markdown("### Choose Output Formats")
224
  with gr.Row():
225
  vtt_checkbox = gr.Checkbox(label="VTT", value=True)
@@ -232,10 +260,17 @@ with gr.Blocks(title="Whisper ZeroGPU Transcription") as demo:
232
  transcription_output = gr.Textbox(label="Full Transcription", lines=10)
233
  downloadable_files_output = gr.Files(label="Download Transcripts")
234
 
 
 
 
 
 
 
 
 
235
  transcribe_btn.click(
236
  fn=transcribe_and_export,
237
- # UPDATED INPUTS list to include the new checkbox
238
- inputs=[audio_input, model_selector, vtt_checkbox, docx_ts_checkbox, docx_no_ts_checkbox, sequence_checkbox],
239
  outputs=[transcription_output, downloadable_files_output, audio_input, status_text]
240
  )
241
 
 
14
  "Base (Faster)": "openai/whisper-base",
15
  "Small (Balanced)": "openai/whisper-small",
16
  "Distil-Large-v3 (General Purpose)": "distil-whisper/distil-large-v3",
 
 
17
  "Distil-Large-v3-FR (French-Specific)": "distil-whisper/distil-large-v3"
18
  }
19
 
20
  # Use a dictionary to cache loaded models
21
  model_cache = {}
22
 
23
+ # Define the fixed chunk length (5 minutes in milliseconds)
24
+ CHUNK_LENGTH_MS = 5 * 60 * 1000
25
+
26
  def get_model_pipeline(model_name, progress):
27
  """
28
  Initializes and caches the ASR pipeline for a given model name.
 
30
  if model_name not in model_cache:
31
  progress(0, desc="πŸš€ Initializing ZeroGPU instance...")
32
  model_id = MODEL_SIZES[model_name]
 
33
  device = 0 if torch.cuda.is_available() else "cpu"
34
 
35
  progress(0.1, desc=f"⏳ Loading {model_name} model...")
 
37
  "automatic-speech-recognition",
38
  model=model_id,
39
  device=device,
 
40
  max_new_tokens=128
41
  )
42
  progress(0.5, desc="βœ… Model loaded successfully!")
43
  return model_cache[model_name]
44
 
45
+ # Helper function to format seconds to HH:MM:SS string
46
+ def format_seconds(seconds):
47
+ return str(datetime.timedelta(seconds=int(seconds)))
48
+
49
  def create_vtt(segments, file_path):
50
  """
51
  Creates a WebVTT (.vtt) file from transcription segments.
 
53
  with open(file_path, "w", encoding="utf-8") as f:
54
  f.write("WEBVTT\n\n")
55
  for i, segment in enumerate(segments):
 
56
  start_ms = int(segment.get('start', 0) * 1000)
57
  end_ms = int(segment.get('end', 0) * 1000)
58
 
 
79
  if with_timestamps:
80
  for segment in segments:
81
  text = segment.get('text', '').strip()
82
+ start = format_seconds(segment.get('start', 0))
83
+ end = format_seconds(segment.get('end', 0))
 
84
  document.add_paragraph(f"[{start} - {end}] {text}")
85
  else:
86
  full_text = " ".join([segment.get('text', '').strip() for segment in segments])
 
88
 
89
  document.save(file_path)
90
 
91
+ # --- NEW FUNCTION: Analyze Audio and Populate Dropdown ---
92
+ def analyze_audio_and_get_chunks(audio_file):
93
  """
94
+ Reads the audio file and generates chunk options for the dropdown.
 
95
  """
96
  if audio_file is None:
97
+ return gr.Dropdown(choices=["Full Audio"], value="Full Audio", interactive=False), "Please upload an audio file first."
 
 
 
 
 
 
 
 
 
 
98
 
99
+ try:
 
 
 
 
 
100
  audio = AudioSegment.from_file(audio_file)
 
101
  total_duration_ms = len(audio)
102
+ num_chunks = (total_duration_ms + CHUNK_LENGTH_MS - 1) // CHUNK_LENGTH_MS
103
 
104
+ chunk_options = ["Full Audio"]
105
  for i in range(num_chunks):
106
+ start_ms = i * CHUNK_LENGTH_MS
107
+ end_ms = min((i + 1) * CHUNK_LENGTH_MS, total_duration_ms)
108
 
109
+ start_sec = start_ms / 1000
110
+ end_sec = end_ms / 1000
111
 
112
+ start_time_str = format_seconds(start_sec).split('.')[0]
113
+ end_time_str = format_seconds(end_sec).split('.')[0]
 
114
 
115
+ option_name = f"Chunk {i+1} ({start_time_str} - {end_time_str})"
116
+ chunk_options.append(option_name)
117
+
118
+ status = f"Audio analyzed. Duration: {format_seconds(total_duration_ms/1000.0)}. Found {num_chunks} chunks."
119
+ return gr.Dropdown(choices=chunk_options, value="Full Audio", interactive=True), status
120
+
121
+ except Exception as e:
122
+ error_msg = f"Error analyzing audio: {e}"
123
+ return gr.Dropdown(choices=["Full Audio"], value="Full Audio", interactive=False), error_msg
124
+ # --------------------------------------------------------
125
+
126
+
127
+ @spaces.GPU
128
+ def transcribe_and_export(audio_file, model_size, chunk_choice, vtt_output, docx_timestamp_output, docx_no_timestamp_output, progress=gr.Progress()):
129
+ """
130
+ Main function to transcribe audio and export to selected formats.
131
+ Modified to process a single selected chunk or the full audio.
132
+ """
133
+ if audio_file is None:
134
+ return (None, None, None, "Please upload an audio file.")
135
+
136
+ start_time = time.time()
137
+ pipe = get_model_pipeline(model_size, progress)
138
+
139
+ # 1. Determine which segment to process
140
+ audio_segment_to_process = audio_file
141
+ offset = 0.0 # Time offset for segment timestamps
142
+
143
+ if chunk_choice != "Full Audio":
144
+ progress(0.70, desc="βœ‚οΈ Preparing audio segment...")
145
+ try:
146
+ # Parse chunk number from choice string (e.g., "Chunk 2 (5:00:00 - 10:00:00)")
147
+ chunk_num = int(chunk_choice.split(' ')[1]) - 1
148
 
149
+ full_audio = AudioSegment.from_file(audio_file)
150
+ total_duration_ms = len(full_audio)
 
 
 
 
 
151
 
152
+ start_ms = chunk_num * CHUNK_LENGTH_MS
153
+ end_ms = min((chunk_num + 1) * CHUNK_LENGTH_MS, total_duration_ms)
154
 
155
+ # Slice the audio
156
+ chunk = full_audio[start_ms:end_ms]
157
+ temp_chunk_path = "/tmp/selected_chunk.mp3"
158
+ chunk.export(temp_chunk_path, format="mp3")
159
 
160
+ audio_segment_to_process = temp_chunk_path
161
+ offset = start_ms / 1000.0 # Offset is the start time of the chunk in seconds
162
 
163
+ except Exception as e:
164
+ return (None, None, None, f"Error preparing audio chunk: {e}")
165
+
166
+ # 2. Define generation arguments (Language)
167
+ generate_kwargs = {}
168
+ if model_size == "Distil-Large-v3-FR (French-Specific)":
169
+ generate_kwargs["language"] = "fr"
170
+
171
+ # 3. Transcribe the segment
172
+ progress(0.75, desc=f"🎀 Transcribing {chunk_choice}...")
173
+ raw_output = pipe(
174
+ audio_segment_to_process,
175
+ return_timestamps="word",
176
+ generate_kwargs=generate_kwargs
177
+ )
178
+
179
+ # 4. Process and adjust segments
180
+ full_segments = raw_output.get("chunks", [])
181
+ transcribed_text = raw_output.get('text', '').strip()
182
 
183
+ # Adjust timestamps if a chunk was processed
184
+ if chunk_choice != "Full Audio":
185
+ for segment in full_segments:
186
+ # Add the offset to the segment start and end times
187
+ segment['start'] = segment.get('start', 0.0) + offset
188
+ segment['end'] = segment.get('end', 0.0) + offset
189
+
190
+ # Clean up the temporary file
191
+ if os.path.exists(audio_segment_to_process):
192
+ os.remove(audio_segment_to_process)
193
 
194
+ # 5. Generate output files
195
  outputs = {}
 
196
  progress(0.85, desc="πŸ“ Generating output files...")
197
 
 
198
  if vtt_output:
199
  vtt_path = "transcription.vtt"
200
  create_vtt(full_segments, vtt_path)
201
  outputs["VTT"] = vtt_path
202
 
 
203
  if docx_timestamp_output:
204
  docx_ts_path = "transcription_with_timestamps.docx"
205
  create_docx(full_segments, docx_ts_path, with_timestamps=True)
206
  outputs["DOCX (with timestamps)"] = docx_ts_path
207
 
 
208
  if docx_no_timestamp_output:
209
  docx_no_ts_path = "transcription_without_timestamps.docx"
210
  create_docx(full_segments, docx_no_ts_path, with_timestamps=False)
 
218
  return (
219
  transcribed_text,
220
  gr.Files(value=downloadable_files, label="Download Transcripts"),
221
+ gr.Audio(value=None),
222
  status_message
223
  )
224
 
225
  # --- Gradio UI ---
226
  with gr.Blocks(title="Whisper ZeroGPU Transcription") as demo:
227
  gr.Markdown("# πŸŽ™οΈ Whisper ZeroGPU Transcription")
228
+ gr.Markdown("1. **Upload** an audio file. 2. Click **'Analyze Audio'** to load the 5-minute chunks. 3. Select a chunk or **'Full Audio'** and click **'Transcribe'**.")
229
 
230
  with gr.Row():
231
  audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Audio File")
 
236
  choices=list(MODEL_SIZES.keys()),
237
  value="Distil-Large-v3-FR (French-Specific)"
238
  )
 
 
 
 
 
 
239
 
240
+ # NEW: Button to analyze audio and populate chunk options
241
+ analyze_btn = gr.Button("Analyze Audio πŸ”Ž", variant="secondary")
242
+
243
+ # NEW: Dropdown for chunk selection
244
+ chunk_selector = gr.Dropdown(
245
+ label="Select Audio Segment (5-minute chunks)",
246
+ choices=["Full Audio"],
247
+ value="Full Audio",
248
+ interactive=False # Disabled until audio is uploaded and analyzed
249
+ )
250
+
251
  gr.Markdown("### Choose Output Formats")
252
  with gr.Row():
253
  vtt_checkbox = gr.Checkbox(label="VTT", value=True)
 
260
  transcription_output = gr.Textbox(label="Full Transcription", lines=10)
261
  downloadable_files_output = gr.Files(label="Download Transcripts")
262
 
263
+ # NEW: Link the analyze button to the analysis function
264
+ analyze_btn.click(
265
+ fn=analyze_audio_and_get_chunks,
266
+ inputs=[audio_input],
267
+ outputs=[chunk_selector, status_text]
268
+ )
269
+
270
+ # UPDATED: Link the transcribe button to the transcription function
271
  transcribe_btn.click(
272
  fn=transcribe_and_export,
273
+ inputs=[audio_input, model_selector, chunk_selector, vtt_checkbox, docx_ts_checkbox, docx_no_ts_checkbox],
 
274
  outputs=[transcription_output, downloadable_files_output, audio_input, status_text]
275
  )
276