sbompolas commited on
Commit
75153af
Β·
verified Β·
1 Parent(s): 7b908d7

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +422 -0
app.py ADDED
@@ -0,0 +1,422 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import logging
4
+ import gc
5
+ import time
6
+ from pathlib import Path
7
+ from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor
8
+ from transformers.utils import is_flash_attn_2_available
9
+ import librosa
10
+
11
+ # Set up logging
12
+ logging.basicConfig(level=logging.INFO)
13
+ logger = logging.getLogger(__name__)
14
+
15
+ class OptimizedWhisperApp:
16
+ def __init__(self):
17
+ self.pipe = None
18
+ self.current_model = None
19
+ self.available_models = [
20
+ "openai/whisper-tiny",
21
+ "openai/whisper-base",
22
+ "openai/whisper-small",
23
+ "openai/whisper-medium", # Often the sweet spot
24
+ "openai/whisper-large-v2",
25
+ "openai/whisper-large-v3",
26
+ "distil-whisper/distil-medium.en",
27
+ "distil-whisper/distil-large-v2",
28
+ "ilsp/whisper_greek_dialect_of_lesbos" # Your specialized model
29
+ ]
30
+
31
+ def create_pipe(self, model_name, use_flash_attention=True):
32
+ """Create pipeline like the successful space"""
33
+ try:
34
+ # Device selection
35
+ if torch.cuda.is_available():
36
+ device = "cuda:0"
37
+ torch_dtype = torch.float16
38
+ else:
39
+ device = "cpu"
40
+ torch_dtype = torch.float32
41
+
42
+ logger.info(f"Loading {model_name} on {device} with {torch_dtype}")
43
+
44
+ # Attention implementation
45
+ if use_flash_attention and is_flash_attn_2_available() and torch.cuda.is_available():
46
+ attn_implementation = "flash_attention_2"
47
+ logger.info("Using Flash Attention 2")
48
+ else:
49
+ attn_implementation = "sdpa" # Scaled Dot Product Attention
50
+ logger.info(f"Using {attn_implementation}")
51
+
52
+ # Load model directly (like the successful space)
53
+ model = AutoModelForSpeechSeq2Seq.from_pretrained(
54
+ model_name,
55
+ torch_dtype=torch_dtype,
56
+ low_cpu_mem_usage=True,
57
+ use_safetensors=True,
58
+ attn_implementation=attn_implementation,
59
+ cache_dir="./cache"
60
+ )
61
+ model.to(device)
62
+
63
+ # Load processor
64
+ processor = AutoProcessor.from_pretrained(model_name)
65
+
66
+ # Create pipeline manually (like the successful space)
67
+ pipe = pipeline(
68
+ "automatic-speech-recognition",
69
+ model=model,
70
+ tokenizer=processor.tokenizer,
71
+ feature_extractor=processor.feature_extractor,
72
+ torch_dtype=torch_dtype,
73
+ device=device,
74
+ )
75
+
76
+ logger.info("Pipeline created successfully!")
77
+ return pipe
78
+
79
+ except Exception as e:
80
+ logger.error(f"Failed to create pipeline: {e}")
81
+ return None
82
+
83
+ def load_model(self, model_name, use_flash_attention=True):
84
+ """Load model if different from current"""
85
+ if self.current_model != model_name or self.pipe is None:
86
+ logger.info(f"Loading new model: {model_name}")
87
+
88
+ # Clear previous model
89
+ if self.pipe is not None:
90
+ del self.pipe
91
+ if torch.cuda.is_available():
92
+ torch.cuda.empty_cache()
93
+ gc.collect()
94
+
95
+ # Create new pipeline
96
+ self.pipe = self.create_pipe(model_name, use_flash_attention)
97
+ self.current_model = model_name if self.pipe else None
98
+
99
+ return self.pipe is not None
100
+ else:
101
+ logger.info("Model already loaded")
102
+ return True
103
+
104
+ def transcribe_audio(self, audio_file, model_name="openai/whisper-medium",
105
+ language="Automatic Detection", task="transcribe",
106
+ chunk_length_s=30, batch_size=16, use_flash_attention=True,
107
+ return_timestamps=True):
108
+ """Transcribe using the optimized approach"""
109
+
110
+ if audio_file is None:
111
+ return "Please upload an audio file", "", ""
112
+
113
+ try:
114
+ start_time = time.time()
115
+
116
+ # Load model if needed
117
+ success = self.load_model(model_name, use_flash_attention)
118
+ if not success:
119
+ return "Failed to load model", "", ""
120
+
121
+ logger.info(f"Processing: {audio_file}")
122
+ logger.info(f"Settings: {model_name}, {language}, {task}")
123
+ logger.info(f"Chunk length: {chunk_length_s}s, Batch size: {batch_size}")
124
+
125
+ # Prepare generation kwargs (like the successful space)
126
+ generate_kwargs = {}
127
+
128
+ # Only set language if not auto-detection and model supports multilingual
129
+ if language != "Automatic Detection" and not model_name.endswith(".en"):
130
+ # Map common language names
131
+ language_map = {
132
+ "Greek": "greek",
133
+ "English": "english",
134
+ "Spanish": "spanish",
135
+ "French": "french",
136
+ "German": "german",
137
+ "Italian": "italian"
138
+ }
139
+ lang_code = language_map.get(language, language.lower())
140
+ generate_kwargs["language"] = lang_code
141
+ logger.info(f"Set language: {lang_code}")
142
+
143
+ # Set task if model supports it
144
+ if not model_name.endswith(".en"):
145
+ generate_kwargs["task"] = task
146
+ logger.info(f"Set task: {task}")
147
+
148
+ # Transcribe (like the successful space approach)
149
+ logger.info("Starting transcription...")
150
+ outputs = self.pipe(
151
+ audio_file,
152
+ chunk_length_s=chunk_length_s,
153
+ batch_size=batch_size,
154
+ generate_kwargs=generate_kwargs,
155
+ return_timestamps=return_timestamps,
156
+ )
157
+
158
+ transcription_time = time.time() - start_time
159
+ logger.info(f"Transcription completed in {transcription_time:.2f} seconds")
160
+
161
+ # Extract results
162
+ transcription = outputs.get("text", "")
163
+ chunks = outputs.get("chunks", [])
164
+
165
+ # Format timestamps
166
+ timestamp_text = ""
167
+ if chunks:
168
+ timestamp_text = self._format_timestamps(chunks)
169
+
170
+ # Create detailed output
171
+ detailed_output = self._format_detailed_output(
172
+ transcription, model_name, language, task,
173
+ transcription_time, chunk_length_s, batch_size,
174
+ use_flash_attention, len(chunks)
175
+ )
176
+
177
+ return transcription.strip(), timestamp_text, detailed_output
178
+
179
+ except Exception as e:
180
+ error_msg = f"Transcription error: {str(e)}"
181
+ logger.error(error_msg)
182
+ return error_msg, "", error_msg
183
+
184
+ def _format_timestamps(self, chunks):
185
+ """Format timestamp information"""
186
+ timestamp_text = "=== TIMESTAMPS ===\n"
187
+ for i, chunk in enumerate(chunks):
188
+ timestamp = chunk.get('timestamp', [0, 0])
189
+ text = chunk.get('text', '')
190
+ start, end = timestamp[0], timestamp[1]
191
+ timestamp_text += f"[{start:.1f}s - {end:.1f}s]: {text}\n"
192
+ return timestamp_text
193
+
194
+ def _format_detailed_output(self, transcription, model_name, language, task,
195
+ transcription_time, chunk_length_s, batch_size,
196
+ use_flash_attention, num_chunks):
197
+ """Format detailed information"""
198
+ output = "=== TRANSCRIPTION ===\n"
199
+ output += f"{transcription}\n\n"
200
+
201
+ output += "=== MODEL INFORMATION ===\n"
202
+ output += f"Model: {model_name}\n"
203
+ output += f"Language: {language}\n"
204
+ output += f"Task: {task}\n"
205
+ output += f"Processing time: {transcription_time:.2f} seconds\n"
206
+ output += f"Chunks processed: {num_chunks}\n"
207
+
208
+ output += "\n=== PROCESSING SETTINGS ===\n"
209
+ output += f"Chunk length: {chunk_length_s} seconds\n"
210
+ output += f"Batch size: {batch_size}\n"
211
+ output += f"Flash Attention: {'Enabled' if use_flash_attention else 'Disabled'}\n"
212
+
213
+ if self.pipe:
214
+ device = next(self.pipe.model.parameters()).device
215
+ dtype = next(self.pipe.model.parameters()).dtype
216
+ output += f"Device: {device}\n"
217
+ output += f"Data type: {dtype}\n"
218
+
219
+ output += f"Flash Attention 2 available: {is_flash_attn_2_available()}\n"
220
+
221
+ output += "\n=== OPTIMIZATIONS ===\n"
222
+ output += "β€’ Direct model loading (not pipeline abstraction)\n"
223
+ output += "β€’ Manual pipeline construction\n"
224
+ output += "β€’ Optimized attention mechanism\n"
225
+ output += "β€’ Batch processing\n"
226
+ output += "β€’ Conservative language handling\n"
227
+ output += "β€’ Proper memory management\n"
228
+
229
+ return output
230
+
231
+ def get_model_info(self):
232
+ """Get current model information"""
233
+ if self.pipe is None:
234
+ return "No model loaded"
235
+
236
+ device = next(self.pipe.model.parameters()).device
237
+ dtype = next(self.pipe.model.parameters()).dtype
238
+
239
+ return f"βœ… {self.current_model} loaded on {device} ({dtype})"
240
+
241
+ # Initialize the app
242
+ logger.info("Initializing Optimized Whisper App...")
243
+ whisper_app = OptimizedWhisperApp()
244
+
245
+ def transcribe_wrapper(audio, model_name, language, task, chunk_length_s,
246
+ batch_size, use_flash_attention, return_timestamps):
247
+ """Wrapper for Gradio interface"""
248
+ return whisper_app.transcribe_audio(
249
+ audio, model_name, language, task,
250
+ chunk_length_s, batch_size, use_flash_attention, return_timestamps
251
+ )
252
+
253
+ def get_model_status():
254
+ """Get current model status"""
255
+ return whisper_app.get_model_info()
256
+
257
+ # Create the interface
258
+ def create_interface():
259
+ with gr.Blocks(title="Optimized Whisper Transcription", theme=gr.themes.Soft()) as interface:
260
+
261
+ gr.Markdown(
262
+ """
263
+ # πŸš€ Optimized Whisper Transcription
264
+
265
+ **High-Performance Speech-to-Text Based on Successful Implementation**
266
+
267
+ Uses the same optimizations as high-performing Whisper spaces:
268
+ - Direct model loading for better control
269
+ - Flash Attention 2 support
270
+ - Optimized chunking and batching
271
+ - Conservative parameter handling
272
+ """
273
+ )
274
+
275
+ # Model status
276
+ model_status = gr.Textbox(
277
+ value=get_model_status(),
278
+ label="πŸ”§ Current Model Status",
279
+ interactive=False
280
+ )
281
+
282
+ # Main interface
283
+ with gr.Row():
284
+ with gr.Column():
285
+ # Audio input
286
+ audio_input = gr.Audio(
287
+ label="🎡 Upload Audio File",
288
+ type="filepath",
289
+ waveform_options=gr.WaveformOptions(
290
+ waveform_color="#01C6FF",
291
+ waveform_progress_color="#0066B4",
292
+ skip_length=2,
293
+ show_controls=True,
294
+ )
295
+ )
296
+
297
+ # Model selection
298
+ model_dropdown = gr.Dropdown(
299
+ choices=whisper_app.available_models,
300
+ value="openai/whisper-medium",
301
+ label="Model",
302
+ info="Medium often works best for real-world usage"
303
+ )
304
+
305
+ # Basic settings
306
+ with gr.Row():
307
+ language_dropdown = gr.Dropdown(
308
+ choices=["Automatic Detection", "Greek", "English", "Spanish", "French", "German", "Italian"],
309
+ value="Automatic Detection",
310
+ label="Language"
311
+ )
312
+
313
+ task_dropdown = gr.Dropdown(
314
+ choices=["transcribe", "translate"],
315
+ value="transcribe",
316
+ label="Task"
317
+ )
318
+
319
+ # Advanced settings
320
+ with gr.Accordion("Advanced Settings", open=False):
321
+ chunk_length_s = gr.Slider(
322
+ minimum=10,
323
+ maximum=60,
324
+ value=30,
325
+ step=5,
326
+ label="Chunk Length (seconds)",
327
+ info="30s is optimal for most cases"
328
+ )
329
+
330
+ batch_size = gr.Slider(
331
+ minimum=1,
332
+ maximum=32,
333
+ value=16,
334
+ step=1,
335
+ label="Batch Size",
336
+ info="Higher = faster, more memory"
337
+ )
338
+
339
+ use_flash_attention = gr.Checkbox(
340
+ label="Flash Attention 2",
341
+ value=True,
342
+ info="Faster processing (requires compatible GPU)"
343
+ )
344
+
345
+ return_timestamps = gr.Checkbox(
346
+ label="Return Timestamps",
347
+ value=True
348
+ )
349
+
350
+ transcribe_btn = gr.Button(
351
+ "πŸš€ Transcribe",
352
+ variant="primary",
353
+ size="lg"
354
+ )
355
+
356
+ with gr.Column():
357
+ # Results
358
+ transcription_output = gr.Textbox(
359
+ label="Transcription",
360
+ lines=8,
361
+ show_copy_button=True
362
+ )
363
+
364
+ with gr.Accordion("Timestamps", open=False):
365
+ timestamps_output = gr.Textbox(
366
+ label="Timestamp Information",
367
+ lines=10,
368
+ show_copy_button=True
369
+ )
370
+
371
+ with gr.Accordion("Detailed Information", open=False):
372
+ detailed_output = gr.Textbox(
373
+ label="Processing Details & Model Info",
374
+ lines=15,
375
+ show_copy_button=True
376
+ )
377
+
378
+ # Event handlers
379
+ transcribe_btn.click(
380
+ fn=transcribe_wrapper,
381
+ inputs=[audio_input, model_dropdown, language_dropdown, task_dropdown,
382
+ chunk_length_s, batch_size, use_flash_attention, return_timestamps],
383
+ outputs=[transcription_output, timestamps_output, detailed_output],
384
+ show_progress=True
385
+ )
386
+
387
+ # Update model status when model changes
388
+ model_dropdown.change(
389
+ fn=lambda: "Model will be loaded on next transcription",
390
+ outputs=[model_status]
391
+ )
392
+
393
+ # Footer
394
+ gr.Markdown(
395
+ """
396
+ ### 🎯 Model Recommendations
397
+
398
+ **For Greek dialect of Lesbos:**
399
+ - `ilsp/whisper_greek_dialect_of_lesbos` - Specialized but may have issues
400
+ - `openai/whisper-medium` - Often better for real-world usage
401
+ - `openai/whisper-large-v2` - More accurate but slower
402
+
403
+ **General recommendations:**
404
+ - **Medium model** often provides the best balance
405
+ - **30-second chunks** work well for most audio
406
+ - **Flash Attention** speeds up processing significantly
407
+ - **Automatic language detection** usually works well
408
+
409
+ ### ⚑ Performance Tips
410
+ - GPU with Flash Attention 2 = Fastest
411
+ - Batch size 16-24 optimal for most GPUs
412
+ - Lower chunk length for very noisy audio
413
+ - Use English-only models (.en) for English-only content
414
+ """
415
+ )
416
+
417
+ return interface
418
+
419
+ # Launch the app
420
+ if __name__ == "__main__":
421
+ interface = create_interface()
422
+ interface.launch(share=True)