sbompolas committed on
Commit
18bcd78
·
verified ·
1 Parent(s): e44ba63

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +233 -154
app.py CHANGED
@@ -3,11 +3,9 @@ import torch
3
  import logging
4
  import gc
5
  import time
6
- import signal
7
  from pathlib import Path
8
- from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor
9
  import librosa
10
- from functools import wraps
11
 
12
  # Try to import flash attention, but don't fail if not available
13
  try:
@@ -22,25 +20,6 @@ except ImportError:
22
  logging.basicConfig(level=logging.INFO)
23
  logger = logging.getLogger(__name__)
24
 
25
- def timeout_handler(signum, frame):
26
- raise TimeoutError("Operation timed out")
27
-
28
- def with_timeout(seconds):
29
- def decorator(func):
30
- @wraps(func)
31
- def wrapper(*args, **kwargs):
32
- # Set the signal handler and a timeout alarm
33
- old_handler = signal.signal(signal.SIGALRM, timeout_handler)
34
- signal.alarm(seconds)
35
- try:
36
- result = func(*args, **kwargs)
37
- finally:
38
- signal.alarm(0) # Disable the alarm
39
- signal.signal(signal.SIGALRM, old_handler)
40
- return result
41
- return wrapper
42
- return decorator
43
-
44
  class OptimizedWhisperApp:
45
  def __init__(self):
46
  self.pipe = None
@@ -57,35 +36,105 @@ class OptimizedWhisperApp:
57
  "ilsp/whisper_greek_dialect_of_lesbos"
58
  ]
59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  def create_pipe(self, model_name, use_flash_attention=True):
61
- """Create pipeline with better error handling"""
 
 
 
 
 
62
  try:
63
- logger.info(f"Starting to load model: {model_name}")
64
 
65
  # Device selection
66
  if torch.cuda.is_available():
67
  device = "cuda:0"
68
  torch_dtype = torch.float16
69
- logger.info("Using CUDA device")
70
  else:
71
  device = "cpu"
72
  torch_dtype = torch.float32
73
- logger.info("Using CPU device")
74
 
75
- # Simpler attention implementation selection
76
- attn_implementation = "eager" # Start with most compatible
77
  if use_flash_attention and FLASH_ATTN_AVAILABLE and is_flash_attn_2_available() and torch.cuda.is_available():
78
  try:
79
  attn_implementation = "flash_attention_2"
80
- logger.info("Attempting Flash Attention 2")
81
  except:
82
  attn_implementation = "eager"
83
  logger.info("Flash Attention 2 failed, using eager")
84
 
85
- logger.info(f"Using attention implementation: {attn_implementation}")
86
-
87
- # Load model with timeout protection
88
- logger.info("Loading model...")
89
  model = AutoModelForSpeechSeq2Seq.from_pretrained(
90
  model_name,
91
  torch_dtype=torch_dtype,
@@ -94,15 +143,12 @@ class OptimizedWhisperApp:
94
  attn_implementation=attn_implementation,
95
  cache_dir="./cache"
96
  )
97
- logger.info("Model loaded, moving to device...")
98
  model.to(device)
99
 
100
  # Load processor
101
- logger.info("Loading processor...")
102
  processor = AutoProcessor.from_pretrained(model_name)
103
 
104
  # Create pipeline
105
- logger.info("Creating pipeline...")
106
  pipe = pipeline(
107
  "automatic-speech-recognition",
108
  model=model,
@@ -112,11 +158,11 @@ class OptimizedWhisperApp:
112
  device=device,
113
  )
114
 
115
- logger.info("Pipeline created successfully!")
116
  return pipe
117
 
118
  except Exception as e:
119
- logger.error(f"Failed to create pipeline: {e}")
120
  import traceback
121
  logger.error(traceback.format_exc())
122
  return None
@@ -135,7 +181,11 @@ class OptimizedWhisperApp:
135
  gc.collect()
136
 
137
  try:
138
- # Create new pipeline with timeout
 
 
 
 
139
  self.pipe = self.create_pipe(model_name, use_flash_attention)
140
  self.current_model = model_name if self.pipe else None
141
 
@@ -153,11 +203,35 @@ class OptimizedWhisperApp:
153
  logger.info("Model already loaded")
154
  return True
155
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  def transcribe_audio(self, audio_file, model_name="openai/whisper-medium",
157
  language="Automatic Detection", task="transcribe",
158
  chunk_length_s=30, batch_size=16, use_flash_attention=True,
159
  return_timestamps=True):
160
- """Transcribe with comprehensive error handling"""
161
 
162
  if audio_file is None:
163
  return "Please upload an audio file", "", ""
@@ -167,40 +241,47 @@ class OptimizedWhisperApp:
167
  start_time = time.time()
168
 
169
  # Load model
170
- logger.info(f"Loading model: {model_name}")
171
  success = self.load_model(model_name, use_flash_attention)
172
  if not success:
173
- return "Failed to load model - check logs for details", "", ""
174
-
175
- logger.info(f"Processing audio file: {audio_file}")
176
- logger.info(f"Settings - Model: {model_name}, Language: {language}, Task: {task}")
177
- logger.info(f"Chunk length: {chunk_length_s}s, Batch size: {batch_size}")
178
-
179
- # Prepare generation kwargs
180
- generate_kwargs = {}
181
-
182
- # Language handling
183
- if language != "Automatic Detection" and not model_name.endswith(".en"):
184
- language_map = {
185
- "Greek": "greek",
186
- "English": "english",
187
- "Spanish": "spanish",
188
- "French": "french",
189
- "German": "german",
190
- "Italian": "italian"
191
- }
192
- lang_code = language_map.get(language, language.lower())
193
- generate_kwargs["language"] = lang_code
194
- logger.info(f"Set language: {lang_code}")
195
 
196
- # Task handling
197
- if not model_name.endswith(".en"):
198
- generate_kwargs["task"] = task
199
- logger.info(f"Set task: {task}")
200
 
201
- # Transcribe with error handling
202
- logger.info("Starting transcription process...")
203
- try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
  outputs = self.pipe(
205
  audio_file,
206
  chunk_length_s=chunk_length_s,
@@ -208,43 +289,35 @@ class OptimizedWhisperApp:
208
  generate_kwargs=generate_kwargs,
209
  return_timestamps=return_timestamps,
210
  )
211
- logger.info("Transcription completed")
212
- except Exception as transcribe_error:
213
- logger.error(f"Transcription failed: {transcribe_error}")
214
- return f"Transcription failed: {str(transcribe_error)}", "", ""
215
 
216
  transcription_time = time.time() - start_time
217
- logger.info(f"Total processing time: {transcription_time:.2f} seconds")
218
 
219
- # Extract and validate results
220
  transcription = outputs.get("text", "") if outputs else ""
221
  chunks = outputs.get("chunks", []) if outputs else []
222
 
223
- logger.info(f"Extracted transcription length: {len(transcription)} chars")
224
- logger.info(f"Number of chunks: {len(chunks)}")
225
-
226
- # Handle timestamps safely
227
  timestamp_text = ""
228
  if return_timestamps:
229
  try:
230
  if chunks:
231
  timestamp_text = self._format_timestamps(chunks)
232
  else:
233
- timestamp_text = "=== TIMESTAMPS ===\nNo chunks returned by the model.\n"
234
  except Exception as ts_error:
235
- logger.warning(f"Error formatting timestamps: {ts_error}")
236
- timestamp_text = f"=== TIMESTAMPS ===\nError formatting timestamps: {str(ts_error)}\n"
237
  else:
238
- timestamp_text = "=== TIMESTAMPS ===\nTimestamp output disabled.\n"
239
 
240
  # Create detailed output
241
  detailed_output = self._format_detailed_output(
242
  transcription, model_name, language, task,
243
  transcription_time, chunk_length_s, batch_size,
244
- use_flash_attention, len(chunks)
245
  )
246
 
247
- logger.info("=== Transcription completed successfully ===")
248
  return transcription.strip(), timestamp_text, detailed_output
249
 
250
  except Exception as e:
@@ -255,14 +328,11 @@ class OptimizedWhisperApp:
255
  return error_msg, "", error_msg
256
 
257
  def _format_timestamps(self, chunks):
258
- """Format timestamp information with comprehensive error handling"""
259
  timestamp_text = "=== TIMESTAMPS ===\n"
260
 
261
  if not chunks:
262
- timestamp_text += "No timestamp information available.\n"
263
- return timestamp_text
264
-
265
- logger.info(f"Formatting {len(chunks)} chunks")
266
 
267
  for i, chunk in enumerate(chunks):
268
  try:
@@ -274,32 +344,31 @@ class OptimizedWhisperApp:
274
  elif isinstance(timestamp, (list, tuple)) and len(timestamp) >= 2:
275
  start, end = timestamp[0], timestamp[1]
276
  if start is None or end is None:
277
- timestamp_text += f"[Invalid timestamp]: {text}\n"
278
  else:
279
  try:
280
- start_f = float(start) if start is not None else 0.0
281
- end_f = float(end) if end is not None else 0.0
282
  timestamp_text += f"[{start_f:.1f}s - {end_f:.1f}s]: {text}\n"
283
- except (ValueError, TypeError) as ve:
284
  timestamp_text += f"[Format error]: {text}\n"
285
  else:
286
  timestamp_text += f"[Unexpected format]: {text}\n"
287
-
288
- except Exception as chunk_error:
289
- logger.warning(f"Error processing chunk {i}: {chunk_error}")
290
- timestamp_text += f"[Chunk {i} error]: Processing failed\n"
291
 
292
  return timestamp_text
293
 
294
  def _format_detailed_output(self, transcription, model_name, language, task,
295
  transcription_time, chunk_length_s, batch_size,
296
- use_flash_attention, num_chunks):
297
  """Format detailed information"""
298
  output = "=== TRANSCRIPTION ===\n"
299
  output += f"{transcription}\n\n"
300
 
301
  output += "=== MODEL INFORMATION ===\n"
302
  output += f"Model: {model_name}\n"
 
303
  output += f"Language: {language}\n"
304
  output += f"Task: {task}\n"
305
  output += f"Processing time: {transcription_time:.2f} seconds\n"
@@ -308,26 +377,15 @@ class OptimizedWhisperApp:
308
  output += "\n=== PROCESSING SETTINGS ===\n"
309
  output += f"Chunk length: {chunk_length_s} seconds\n"
310
  output += f"Batch size: {batch_size}\n"
311
- output += f"Flash Attention: {'Enabled' if use_flash_attention else 'Disabled'}\n"
312
 
313
- if self.pipe and hasattr(self.pipe, 'model'):
314
- try:
315
- device = next(self.pipe.model.parameters()).device
316
- dtype = next(self.pipe.model.parameters()).dtype
317
- output += f"Device: {device}\n"
318
- output += f"Data type: {dtype}\n"
319
- except:
320
- output += "Device info: Unable to retrieve\n"
321
-
322
- output += f"Flash Attention 2 available: {FLASH_ATTN_AVAILABLE and is_flash_attn_2_available()}\n"
323
-
324
- output += "\n=== SYSTEM STATUS ===\n"
325
- if torch.cuda.is_available():
326
- output += f"GPU Available: Yes\n"
327
- output += f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f}GB\n"
328
- output += f"GPU Memory Used: {torch.cuda.memory_allocated() / 1e9:.1f}GB\n"
329
- else:
330
- output += "GPU Available: No\n"
331
 
332
  return output
333
 
@@ -339,9 +397,10 @@ class OptimizedWhisperApp:
339
  try:
340
  device = next(self.pipe.model.parameters()).device
341
  dtype = next(self.pipe.model.parameters()).dtype
342
- return f" {self.current_model} loaded on {device} ({dtype})"
 
343
  except:
344
- return f"✅ {self.current_model} loaded (device info unavailable)"
345
 
346
  # Initialize the app
347
  logger.info("Initializing Optimized Whisper App...")
@@ -349,9 +408,8 @@ whisper_app = OptimizedWhisperApp()
349
 
350
  def transcribe_wrapper(audio, model_name, language, task, chunk_length_s,
351
  batch_size, use_flash_attention, return_timestamps):
352
- """Wrapper for Gradio interface with additional safety"""
353
  try:
354
- logger.info(f"Transcribe wrapper called with model: {model_name}")
355
  return whisper_app.transcribe_audio(
356
  audio, model_name, language, task,
357
  chunk_length_s, batch_size, use_flash_attention, return_timestamps
@@ -365,6 +423,23 @@ def get_model_status():
365
  """Get current model status"""
366
  return whisper_app.get_model_info()
367
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
368
  # Create the interface
369
  def create_interface():
370
  with gr.Blocks(title="Optimized Whisper Transcription", theme=gr.themes.Soft()) as interface:
@@ -373,13 +448,13 @@ def create_interface():
373
  """
374
  # 🚀 Optimized Whisper Transcription
375
 
376
- **High-Performance Speech-to-Text with Enhanced Error Handling**
377
 
378
  Features:
379
- - Comprehensive error handling and timeout protection
380
- - Better logging for debugging
381
- - Graceful handling of model loading issues
382
- - Robust timestamp processing
383
  """
384
  )
385
 
@@ -402,9 +477,9 @@ def create_interface():
402
  # Model selection
403
  model_dropdown = gr.Dropdown(
404
  choices=whisper_app.available_models,
405
- value="openai/whisper-small", # Start with smaller model
406
  label="Model",
407
- info="Start with 'small' model for testing"
408
  )
409
 
410
  # Basic settings
@@ -428,23 +503,22 @@ def create_interface():
428
  maximum=60,
429
  value=30,
430
  step=5,
431
- label="Chunk Length (seconds)",
432
- info="30s is optimal for most cases"
433
  )
434
 
435
  batch_size = gr.Slider(
436
  minimum=1,
437
- maximum=8, # Reduced default
438
- value=4, # Smaller batch size
439
  step=1,
440
  label="Batch Size",
441
- info="Start with smaller batch size"
442
  )
443
 
444
  use_flash_attention = gr.Checkbox(
445
  label="Flash Attention 2",
446
- value=False, # Disabled by default for stability
447
- info="Enable only if you have compatible GPU and flash-attn installed"
448
  )
449
 
450
  return_timestamps = gr.Checkbox(
@@ -489,28 +563,33 @@ def create_interface():
489
  show_progress=True
490
  )
491
 
492
- # Update model status when model changes
493
  model_dropdown.change(
494
- fn=lambda: "Model will be loaded on next transcription",
495
- outputs=[model_status]
 
 
 
 
 
496
  )
497
 
498
  # Footer
499
  gr.Markdown(
500
  """
501
- ### 🔧 Troubleshooting
502
-
503
- **If processing hangs:**
504
- 1. Start with `whisper-small` model
505
- 2. Disable Flash Attention
506
- 3. Use smaller batch size (1-4)
507
- 4. Check the console logs for errors
508
-
509
- **For best results:**
510
- - Test with small model first
511
- - Gradually increase model size if needed
512
- - Monitor GPU memory usage
513
- - Check logs for any error messages
514
  """
515
  )
516
 
@@ -519,4 +598,4 @@ def create_interface():
519
  # Launch the app
520
  if __name__ == "__main__":
521
  interface = create_interface()
522
- interface.launch(share=True, debug=True)
 
3
  import logging
4
  import gc
5
  import time
 
6
  from pathlib import Path
7
+ from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor, WhisperProcessor, WhisperForConditionalGeneration
8
  import librosa
 
9
 
10
  # Try to import flash attention, but don't fail if not available
11
  try:
 
20
  logging.basicConfig(level=logging.INFO)
21
  logger = logging.getLogger(__name__)
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  class OptimizedWhisperApp:
24
  def __init__(self):
25
  self.pipe = None
 
36
  "ilsp/whisper_greek_dialect_of_lesbos"
37
  ]
38
 
39
def is_fine_tuned_model(self, model_name):
    """Return True when *model_name* looks like a fine-tuned checkpoint.

    Performs a case-insensitive substring match against a small set of
    markers: the "ilsp/" org prefix and the words "fine", "dialect" and
    "custom". Matching models get the conservative loading/transcription
    path elsewhere in this class.
    """
    markers = ("ilsp/", "fine", "dialect", "custom")
    lowered = model_name.lower()
    for marker in markers:
        if marker in lowered:
            return True
    return False
48
+
49
def create_pipe_for_fine_tuned(self, model_name):
    """Build an ASR pipeline for a fine-tuned checkpoint with conservative settings.

    Tries the Whisper-specific classes first, then falls back to the generic
    Auto* classes. Uses float32 even on CUDA and a fixed 30 s chunk length.

    Args:
        model_name: Hugging Face model id or local path of the checkpoint.

    Returns:
        A transformers ASR pipeline, or None if every loading strategy failed.
    """
    try:
        logger.info(f"Loading fine-tuned model: {model_name}")

        # Device selection - be more conservative for fine-tuned models
        if torch.cuda.is_available():
            device = "cuda:0"
            torch_dtype = torch.float32  # Use float32 for stability
        else:
            device = "cpu"
            torch_dtype = torch.float32

        logger.info(f"Using device: {device}, dtype: {torch_dtype}")

        # First attempt: the Whisper-specific model/processor classes.
        try:
            logger.info("Attempting to load as WhisperForConditionalGeneration...")
            model = WhisperForConditionalGeneration.from_pretrained(
                model_name,
                torch_dtype=torch_dtype,
                low_cpu_mem_usage=True,
                cache_dir="./cache"
            )
            processor = WhisperProcessor.from_pretrained(model_name)
            logger.info("Successfully loaded as Whisper model")
        except Exception as e:
            # Fallback: generic Auto* classes; also allow non-safetensors
            # weights, which some fine-tuned checkpoints ship.
            logger.info(f"Whisper loading failed: {e}, trying AutoModel...")
            model = AutoModelForSpeechSeq2Seq.from_pretrained(
                model_name,
                torch_dtype=torch_dtype,
                low_cpu_mem_usage=True,
                use_safetensors=False,  # Fine-tuned models might not have safetensors
                cache_dir="./cache"
            )
            processor = AutoProcessor.from_pretrained(model_name)

        model.to(device)
        logger.info("Model moved to device")

        # Create pipeline with conservative settings
        pipe = pipeline(
            "automatic-speech-recognition",
            model=model,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            torch_dtype=torch_dtype,
            device=device,
            chunk_length_s=30,  # Fixed chunk length for fine-tuned models
        )

        logger.info("Fine-tuned model pipeline created successfully!")
        return pipe

    except Exception as e:
        # Any failure (download, weights, device move, pipeline build) is
        # logged with a full traceback; callers treat None as "load failed".
        logger.error(f"Failed to create fine-tuned model pipeline: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return None
108
+
109
  def create_pipe(self, model_name, use_flash_attention=True):
110
+ """Create pipeline with special handling for fine-tuned models"""
111
+
112
+ # Use special handling for fine-tuned models
113
+ if self.is_fine_tuned_model(model_name):
114
+ return self.create_pipe_for_fine_tuned(model_name)
115
+
116
  try:
117
+ logger.info(f"Loading standard model: {model_name}")
118
 
119
  # Device selection
120
  if torch.cuda.is_available():
121
  device = "cuda:0"
122
  torch_dtype = torch.float16
 
123
  else:
124
  device = "cpu"
125
  torch_dtype = torch.float32
 
126
 
127
+ # Attention implementation - disable for fine-tuned models
128
+ attn_implementation = "eager"
129
  if use_flash_attention and FLASH_ATTN_AVAILABLE and is_flash_attn_2_available() and torch.cuda.is_available():
130
  try:
131
  attn_implementation = "flash_attention_2"
132
+ logger.info("Using Flash Attention 2")
133
  except:
134
  attn_implementation = "eager"
135
  logger.info("Flash Attention 2 failed, using eager")
136
 
137
+ # Load model
 
 
 
138
  model = AutoModelForSpeechSeq2Seq.from_pretrained(
139
  model_name,
140
  torch_dtype=torch_dtype,
 
143
  attn_implementation=attn_implementation,
144
  cache_dir="./cache"
145
  )
 
146
  model.to(device)
147
 
148
  # Load processor
 
149
  processor = AutoProcessor.from_pretrained(model_name)
150
 
151
  # Create pipeline
 
152
  pipe = pipeline(
153
  "automatic-speech-recognition",
154
  model=model,
 
158
  device=device,
159
  )
160
 
161
+ logger.info("Standard model pipeline created successfully!")
162
  return pipe
163
 
164
  except Exception as e:
165
+ logger.error(f"Failed to create standard model pipeline: {e}")
166
  import traceback
167
  logger.error(traceback.format_exc())
168
  return None
 
181
  gc.collect()
182
 
183
  try:
184
+ # Disable flash attention for fine-tuned models
185
+ if self.is_fine_tuned_model(model_name):
186
+ use_flash_attention = False
187
+ logger.info("Disabled flash attention for fine-tuned model")
188
+
189
  self.pipe = self.create_pipe(model_name, use_flash_attention)
190
  self.current_model = model_name if self.pipe else None
191
 
 
203
  logger.info("Model already loaded")
204
  return True
205
 
206
def transcribe_audio_fine_tuned(self, audio_file, chunk_length_s=30, batch_size=1):
    """Transcribe *audio_file* via self.pipe with conservative settings.

    Intended for fine-tuned models: caps the chunk length at 30 s and the
    batch size at 2, forces deterministic greedy decoding, and always
    requests timestamps.

    Args:
        audio_file: Path (or array input accepted by the pipeline) to transcribe.
        chunk_length_s: Requested chunk length in seconds; clamped to <= 30.
        batch_size: Requested batch size; clamped to <= 2.

    Returns:
        The raw pipeline output dict (text plus chunk timestamps).

    Raises:
        Exception: whatever the underlying pipeline raised, after logging.
    """
    try:
        logger.info("Using fine-tuned model transcription method")

        # Use very conservative settings for fine-tuned models
        outputs = self.pipe(
            audio_file,
            chunk_length_s=min(chunk_length_s, 30),  # Max 30 seconds
            batch_size=min(batch_size, 2),  # Max batch size 2
            return_timestamps=True,
            generate_kwargs={
                "task": "transcribe",
                "do_sample": False,  # Deterministic output
                "num_beams": 1,  # No beam search
                "max_length": 448,  # Conservative max length
            }
        )
        return outputs

    except Exception:
        # logger.exception records the traceback; bare `raise` re-raises the
        # active exception unchanged (unlike `raise e`, which rebinds it).
        logger.exception("Fine-tuned transcription failed")
        raise
229
+
230
  def transcribe_audio(self, audio_file, model_name="openai/whisper-medium",
231
  language="Automatic Detection", task="transcribe",
232
  chunk_length_s=30, batch_size=16, use_flash_attention=True,
233
  return_timestamps=True):
234
+ """Transcribe with special handling for fine-tuned models"""
235
 
236
  if audio_file is None:
237
  return "Please upload an audio file", "", ""
 
241
  start_time = time.time()
242
 
243
  # Load model
 
244
  success = self.load_model(model_name, use_flash_attention)
245
  if not success:
246
+ return "Failed to load model", "", ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
247
 
248
+ logger.info(f"Processing: {audio_file}")
249
+ logger.info(f"Settings: {model_name}, {language}, {task}")
 
 
250
 
251
+ # Check if this is a fine-tuned model
252
+ is_fine_tuned = self.is_fine_tuned_model(model_name)
253
+
254
+ if is_fine_tuned:
255
+ logger.info("Using fine-tuned model optimizations")
256
+ # Use special method for fine-tuned models
257
+ outputs = self.transcribe_audio_fine_tuned(
258
+ audio_file, chunk_length_s, batch_size
259
+ )
260
+ else:
261
+ # Standard transcription for regular models
262
+ logger.info("Using standard model transcription")
263
+
264
+ # Prepare generation kwargs
265
+ generate_kwargs = {}
266
+
267
+ # Language handling
268
+ if language != "Automatic Detection" and not model_name.endswith(".en"):
269
+ language_map = {
270
+ "Greek": "greek",
271
+ "English": "english",
272
+ "Spanish": "spanish",
273
+ "French": "french",
274
+ "German": "german",
275
+ "Italian": "italian"
276
+ }
277
+ lang_code = language_map.get(language, language.lower())
278
+ generate_kwargs["language"] = lang_code
279
+ logger.info(f"Set language: {lang_code}")
280
+
281
+ # Task handling
282
+ if not model_name.endswith(".en"):
283
+ generate_kwargs["task"] = task
284
+
285
  outputs = self.pipe(
286
  audio_file,
287
  chunk_length_s=chunk_length_s,
 
289
  generate_kwargs=generate_kwargs,
290
  return_timestamps=return_timestamps,
291
  )
 
 
 
 
292
 
293
  transcription_time = time.time() - start_time
294
+ logger.info(f"Transcription completed in {transcription_time:.2f} seconds")
295
 
296
+ # Extract results
297
  transcription = outputs.get("text", "") if outputs else ""
298
  chunks = outputs.get("chunks", []) if outputs else []
299
 
300
+ # Handle timestamps
 
 
 
301
  timestamp_text = ""
302
  if return_timestamps:
303
  try:
304
  if chunks:
305
  timestamp_text = self._format_timestamps(chunks)
306
  else:
307
+ timestamp_text = "=== TIMESTAMPS ===\nNo chunks returned.\n"
308
  except Exception as ts_error:
309
+ logger.warning(f"Timestamp formatting error: {ts_error}")
310
+ timestamp_text = f"=== TIMESTAMPS ===\nError: {str(ts_error)}\n"
311
  else:
312
+ timestamp_text = "=== TIMESTAMPS ===\nDisabled.\n"
313
 
314
  # Create detailed output
315
  detailed_output = self._format_detailed_output(
316
  transcription, model_name, language, task,
317
  transcription_time, chunk_length_s, batch_size,
318
+ use_flash_attention, len(chunks), is_fine_tuned
319
  )
320
 
 
321
  return transcription.strip(), timestamp_text, detailed_output
322
 
323
  except Exception as e:
 
328
  return error_msg, "", error_msg
329
 
330
  def _format_timestamps(self, chunks):
331
+ """Format timestamp information"""
332
  timestamp_text = "=== TIMESTAMPS ===\n"
333
 
334
  if not chunks:
335
+ return timestamp_text + "No chunks available.\n"
 
 
 
336
 
337
  for i, chunk in enumerate(chunks):
338
  try:
 
344
  elif isinstance(timestamp, (list, tuple)) and len(timestamp) >= 2:
345
  start, end = timestamp[0], timestamp[1]
346
  if start is None or end is None:
347
+ timestamp_text += f"[Invalid]: {text}\n"
348
  else:
349
  try:
350
+ start_f = float(start)
351
+ end_f = float(end)
352
  timestamp_text += f"[{start_f:.1f}s - {end_f:.1f}s]: {text}\n"
353
+ except (ValueError, TypeError):
354
  timestamp_text += f"[Format error]: {text}\n"
355
  else:
356
  timestamp_text += f"[Unexpected format]: {text}\n"
357
+ except Exception as e:
358
+ timestamp_text += f"[Chunk {i} error]: {str(e)}\n"
 
 
359
 
360
  return timestamp_text
361
 
362
  def _format_detailed_output(self, transcription, model_name, language, task,
363
  transcription_time, chunk_length_s, batch_size,
364
+ use_flash_attention, num_chunks, is_fine_tuned=False):
365
  """Format detailed information"""
366
  output = "=== TRANSCRIPTION ===\n"
367
  output += f"{transcription}\n\n"
368
 
369
  output += "=== MODEL INFORMATION ===\n"
370
  output += f"Model: {model_name}\n"
371
+ output += f"Model Type: {'Fine-tuned' if is_fine_tuned else 'Standard'}\n"
372
  output += f"Language: {language}\n"
373
  output += f"Task: {task}\n"
374
  output += f"Processing time: {transcription_time:.2f} seconds\n"
 
377
  output += "\n=== PROCESSING SETTINGS ===\n"
378
  output += f"Chunk length: {chunk_length_s} seconds\n"
379
  output += f"Batch size: {batch_size}\n"
380
+ output += f"Flash Attention: {'Enabled' if use_flash_attention and not is_fine_tuned else 'Disabled'}\n"
381
 
382
+ if is_fine_tuned:
383
+ output += "\n=== FINE-TUNED MODEL OPTIMIZATIONS ===\n"
384
+ output += "• Conservative batch size (max 2)\n"
385
+ output += "• Float32 precision for stability\n"
386
+ output += " Disabled flash attention\n"
387
+ output += " Deterministic generation\n"
388
+ output += "• No beam search\n"
 
 
 
 
 
 
 
 
 
 
 
389
 
390
  return output
391
 
 
397
  try:
398
  device = next(self.pipe.model.parameters()).device
399
  dtype = next(self.pipe.model.parameters()).dtype
400
+ model_type = "Fine-tuned" if self.is_fine_tuned_model(self.current_model) else "Standard"
401
+ return f"✅ {self.current_model} ({model_type}) - {device} ({dtype})"
402
  except:
403
+ return f"✅ {self.current_model} loaded"
404
 
405
  # Initialize the app
406
  logger.info("Initializing Optimized Whisper App...")
 
408
 
409
  def transcribe_wrapper(audio, model_name, language, task, chunk_length_s,
410
  batch_size, use_flash_attention, return_timestamps):
411
+ """Wrapper for Gradio interface"""
412
  try:
 
413
  return whisper_app.transcribe_audio(
414
  audio, model_name, language, task,
415
  chunk_length_s, batch_size, use_flash_attention, return_timestamps
 
423
  """Get current model status"""
424
  return whisper_app.get_model_info()
425
 
426
def update_settings_for_model(model_name):
    """Return recommended Gradio control updates for *model_name*.

    Fine-tuned models get a tightly capped batch size (value 1, max 2);
    standard models get the roomier defaults (value 4, max 16). Flash
    attention is always off and the chunk length always 30 s.
    """
    if whisper_app.is_fine_tuned_model(model_name):
        batch_update = gr.update(value=1, maximum=2)
    else:
        batch_update = gr.update(value=4, maximum=16)

    return {
        "batch_size": batch_update,
        "use_flash_attention": gr.update(value=False),
        "chunk_length_s": gr.update(value=30)
    }
442
+
443
  # Create the interface
444
  def create_interface():
445
  with gr.Blocks(title="Optimized Whisper Transcription", theme=gr.themes.Soft()) as interface:
 
448
  """
449
  # 🚀 Optimized Whisper Transcription
450
 
451
+ **Enhanced for Fine-tuned Models**
452
 
453
  Features:
454
+ - Special handling for fine-tuned models (like Greek dialect)
455
+ - Automatic optimization based on model type
456
+ - Conservative settings for stability
457
+ - Enhanced error handling
458
  """
459
  )
460
 
 
477
  # Model selection
478
  model_dropdown = gr.Dropdown(
479
  choices=whisper_app.available_models,
480
+ value="openai/whisper-small",
481
  label="Model",
482
+ info="Auto-optimizes settings for fine-tuned models"
483
  )
484
 
485
  # Basic settings
 
503
  maximum=60,
504
  value=30,
505
  step=5,
506
+ label="Chunk Length (seconds)"
 
507
  )
508
 
509
  batch_size = gr.Slider(
510
  minimum=1,
511
+ maximum=16,
512
+ value=4,
513
  step=1,
514
  label="Batch Size",
515
+ info="Auto-adjusted for fine-tuned models"
516
  )
517
 
518
  use_flash_attention = gr.Checkbox(
519
  label="Flash Attention 2",
520
+ value=False,
521
+ info="Auto-disabled for fine-tuned models"
522
  )
523
 
524
  return_timestamps = gr.Checkbox(
 
563
  show_progress=True
564
  )
565
 
566
+ # Auto-adjust settings when model changes
567
  model_dropdown.change(
568
+ fn=lambda model: (
569
+ f"Model will be loaded on next transcription ({'Fine-tuned' if whisper_app.is_fine_tuned_model(model) else 'Standard'} model)",
570
+ 1 if whisper_app.is_fine_tuned_model(model) else 4,
571
+ False
572
+ ),
573
+ inputs=[model_dropdown],
574
+ outputs=[model_status, batch_size, use_flash_attention]
575
  )
576
 
577
  # Footer
578
  gr.Markdown(
579
  """
580
+ ### 🎯 Fine-tuned Model Optimizations
581
+
582
+ **Automatic optimizations for fine-tuned models:**
583
+ - Batch size limited to 1-2 for stability
584
+ - Flash Attention automatically disabled
585
+ - Float32 precision for better compatibility
586
+ - Conservative generation settings
587
+ - Enhanced error handling
588
+
589
+ **For Greek dialect model specifically:**
590
+ - Use batch size 1
591
+ - Keep chunk length at 30 seconds
592
+ - Language detection usually works well
593
  """
594
  )
595
 
 
598
  # Launch the app
599
  if __name__ == "__main__":
600
  interface = create_interface()
601
+ interface.launch(share=True)