import gradio as gr
import torch
import logging
import gc
import time
from pathlib import Path
from transformers import (
    pipeline,
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
    WhisperProcessor,
    WhisperForConditionalGeneration,
)
import librosa

# Try to import flash attention, but don't fail if not available
try:
    from transformers.utils import is_flash_attn_2_available
    FLASH_ATTN_AVAILABLE = True
except ImportError:
    FLASH_ATTN_AVAILABLE = False

    def is_flash_attn_2_available():
        return False

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class OptimizedWhisperApp:
    def __init__(self):
        self.pipe = None
        self.current_model = None
        self.available_models = [
            "openai/whisper-tiny",
            "openai/whisper-base",
            "openai/whisper-small",
            "openai/whisper-medium",
            "openai/whisper-large-v2",
            "openai/whisper-large-v3",
            "ilsp/whisper_greek_dialect_of_lesbos",
            "ilsp/xls-r-greek-cretan",
        ]

    def is_fine_tuned_model(self, model_name):
        """Check if this is a fine-tuned model that might need special handling."""
        fine_tuned_indicators = [
            "ilsp/",
            "fine",
            "dialect",
            "custom",
        ]
        return any(indicator in model_name.lower() for indicator in fine_tuned_indicators)

    def create_pipe_for_fine_tuned(self, model_name):
        """Special handling for fine-tuned models."""
        try:
            logger.info(f"Loading fine-tuned model: {model_name}")

            # Device selection - be more conservative for fine-tuned models
            if torch.cuda.is_available():
                device = "cuda:0"
                torch_dtype = torch.float32  # Use float32 for stability
            else:
                device = "cpu"
                torch_dtype = torch.float32

            logger.info(f"Using device: {device}, dtype: {torch_dtype}")

            # Try to load as a Whisper model first
            try:
                logger.info("Attempting to load as WhisperForConditionalGeneration...")
                model = WhisperForConditionalGeneration.from_pretrained(
                    model_name,
                    torch_dtype=torch_dtype,
                    low_cpu_mem_usage=True,
                    cache_dir="./cache",
                )
                processor = WhisperProcessor.from_pretrained(model_name)
                logger.info("Successfully loaded as Whisper model")
            except Exception as e:
                logger.info(f"Whisper loading failed: {e}, trying AutoModel...")
                model = AutoModelForSpeechSeq2Seq.from_pretrained(
                    model_name,
                    torch_dtype=torch_dtype,
                    low_cpu_mem_usage=True,
                    use_safetensors=False,  # Fine-tuned models might not have safetensors
                    cache_dir="./cache",
                )
                processor = AutoProcessor.from_pretrained(model_name)

            model.to(device)
            logger.info("Model moved to device")

            # Create pipeline with conservative settings
            pipe = pipeline(
                "automatic-speech-recognition",
                model=model,
                tokenizer=processor.tokenizer,
                feature_extractor=processor.feature_extractor,
                torch_dtype=torch_dtype,
                device=device,
                chunk_length_s=30,  # Fixed chunk length for fine-tuned models
            )

            logger.info("Fine-tuned model pipeline created successfully!")
            return pipe

        except Exception as e:
            logger.error(f"Failed to create fine-tuned model pipeline: {e}")
            import traceback
            logger.error(traceback.format_exc())
            return None

    def create_pipe(self, model_name, use_flash_attention=True):
        """Create pipeline with special handling for fine-tuned models."""
        # Use special handling for fine-tuned models
        if self.is_fine_tuned_model(model_name):
            return self.create_pipe_for_fine_tuned(model_name)

        try:
            logger.info(f"Loading standard model: {model_name}")

            # Device selection
            if torch.cuda.is_available():
                device = "cuda:0"
                torch_dtype = torch.float16
            else:
                device = "cpu"
                torch_dtype = torch.float32

            # Attention implementation - disable for fine-tuned models
            attn_implementation = "eager"
            if use_flash_attention and FLASH_ATTN_AVAILABLE and is_flash_attn_2_available() and torch.cuda.is_available():
                try:
                    attn_implementation = "flash_attention_2"
                    logger.info("Using Flash Attention 2")
                except Exception:
                    attn_implementation = "eager"
                    logger.info("Flash Attention 2 failed, using eager")

            # Load model
            model = AutoModelForSpeechSeq2Seq.from_pretrained(
                model_name,
                torch_dtype=torch_dtype,
                low_cpu_mem_usage=True,
                use_safetensors=True,
                attn_implementation=attn_implementation,
                cache_dir="./cache",
            )
            model.to(device)

            # Load processor
            processor = AutoProcessor.from_pretrained(model_name)

            # Create pipeline
            pipe = pipeline(
                "automatic-speech-recognition",
                model=model,
                tokenizer=processor.tokenizer,
                feature_extractor=processor.feature_extractor,
                torch_dtype=torch_dtype,
                device=device,
            )

            logger.info("Standard model pipeline created successfully!")
            return pipe

        except Exception as e:
            logger.error(f"Failed to create standard model pipeline: {e}")
            import traceback
            logger.error(traceback.format_exc())
            return None

    def load_model(self, model_name, use_flash_attention=True):
        """Load a model, clearing any previously loaded one first."""
        if self.current_model != model_name or self.pipe is None:
            logger.info(f"Loading new model: {model_name}")

            # Clear previous model
            if self.pipe is not None:
                logger.info("Clearing previous model...")
                del self.pipe
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                gc.collect()

            try:
                # Disable flash attention for fine-tuned models
                if self.is_fine_tuned_model(model_name):
                    use_flash_attention = False
                    logger.info("Disabled flash attention for fine-tuned model")

                self.pipe = self.create_pipe(model_name, use_flash_attention)
                self.current_model = model_name if self.pipe else None

                if self.pipe:
                    logger.info(f"Model {model_name} loaded successfully")
                    return True
                else:
                    logger.error(f"Failed to load model {model_name}")
                    return False

            except Exception as e:
                logger.error(f"Error loading model: {e}")
                return False
        else:
            logger.info("Model already loaded")
            return True

    def transcribe_audio_fine_tuned(self, audio_file, chunk_length_s=30, batch_size=1):
        """Transcription method for fine-tuned models using conservative settings."""
        try:
            logger.info("Using fine-tuned model transcription method")

            # Use very conservative settings for fine-tuned models
            outputs = self.pipe(
                audio_file,
                chunk_length_s=min(chunk_length_s, 30),  # Max 30 seconds
                batch_size=min(batch_size, 2),           # Max batch size 2
                return_timestamps=True,
                generate_kwargs={
                    "task": "transcribe",
                    "do_sample": False,   # Deterministic output
                    "num_beams": 1,       # No beam search
                    "max_length": 448,    # Conservative max length
                },
            )

            return outputs

        except Exception as e:
            logger.error(f"Fine-tuned transcription failed: {e}")
            raise

    def transcribe_audio(self, audio_file, model_name="openai/whisper-medium",
                         language="Automatic Detection", task="transcribe",
                         chunk_length_s=30, batch_size=16,
                         use_flash_attention=True, return_timestamps=True):
        """Transcribe an audio file, with special handling for fine-tuned models."""
        if audio_file is None:
            return "Please upload an audio file", "", ""

        try:
            logger.info("=== Starting transcription ===")
            start_time = time.time()

            # Load model
            success = self.load_model(model_name, use_flash_attention)
            if not success:
                return "Failed to load model", "", ""

            logger.info(f"Processing: {audio_file}")
            logger.info(f"Settings: {model_name}, {language}, {task}")

            # Check if this is a fine-tuned model
            is_fine_tuned = self.is_fine_tuned_model(model_name)

            if is_fine_tuned:
                logger.info("Using fine-tuned model optimizations")
                # Use special method for fine-tuned models
                outputs = self.transcribe_audio_fine_tuned(
                    audio_file, chunk_length_s, batch_size
                )
            else:
                # Standard transcription for regular models
                logger.info("Using standard model transcription")

                # Prepare generation kwargs
                generate_kwargs = {}

                # Language handling
                if language != "Automatic Detection" and not model_name.endswith(".en"):
                    language_map = {
                        "Greek": "greek",
                        "English": "english",
                        "Spanish": "spanish",
                        "French": "french",
                        "German": "german",
                        "Italian": "italian",
                    }
                    lang_code = language_map.get(language, language.lower())
                    generate_kwargs["language"] = lang_code
                    logger.info(f"Set language: {lang_code}")

                # Task handling
                if not model_name.endswith(".en"):
                    generate_kwargs["task"] = task

                outputs = self.pipe(
                    audio_file,
                    chunk_length_s=chunk_length_s,
                    batch_size=batch_size,
                    generate_kwargs=generate_kwargs,
                    return_timestamps=return_timestamps,
                )

            transcription_time = time.time() - start_time
            logger.info(f"Transcription completed in {transcription_time:.2f} seconds")

            # Extract results
            transcription = outputs.get("text", "") if outputs else ""
            chunks = outputs.get("chunks", []) if outputs else []

            # Handle timestamps
            timestamp_text = ""
            if return_timestamps:
                try:
                    if chunks:
                        timestamp_text = self._format_timestamps(chunks)
                    else:
                        timestamp_text = "=== TIMESTAMPS ===\nNo chunks returned.\n"
                except Exception as ts_error:
                    logger.warning(f"Timestamp formatting error: {ts_error}")
                    timestamp_text = f"=== TIMESTAMPS ===\nError: {str(ts_error)}\n"
            else:
                timestamp_text = "=== TIMESTAMPS ===\nDisabled.\n"

            # Create detailed output
            detailed_output = self._format_detailed_output(
                transcription, model_name, language, task, transcription_time,
                chunk_length_s, batch_size, use_flash_attention, len(chunks), is_fine_tuned
            )

            return transcription.strip(), timestamp_text, detailed_output

        except Exception as e:
            error_msg = f"Transcription error: {str(e)}"
            logger.error(error_msg)
            import traceback
            logger.error(traceback.format_exc())
            return error_msg, "", error_msg

    def _format_timestamps(self, chunks):
        """Format timestamp information."""
        timestamp_text = "=== TIMESTAMPS ===\n"

        if not chunks:
            return timestamp_text + "No chunks available.\n"

        for i, chunk in enumerate(chunks):
            try:
                timestamp = chunk.get('timestamp', None)
                text = chunk.get('text', '')

                if timestamp is None:
                    timestamp_text += f"[No timestamp]: {text}\n"
                elif isinstance(timestamp, (list, tuple)) and len(timestamp) >= 2:
                    start, end = timestamp[0], timestamp[1]
                    if start is None or end is None:
                        timestamp_text += f"[Invalid]: {text}\n"
                    else:
                        try:
                            start_f = float(start)
                            end_f = float(end)
                            timestamp_text += f"[{start_f:.1f}s - {end_f:.1f}s]: {text}\n"
                        except (ValueError, TypeError):
                            timestamp_text += f"[Format error]: {text}\n"
                else:
                    timestamp_text += f"[Unexpected format]: {text}\n"
            except Exception as e:
                timestamp_text += f"[Chunk {i} error]: {str(e)}\n"

        return timestamp_text

    def _format_detailed_output(self, transcription, model_name, language, task,
                                transcription_time, chunk_length_s, batch_size,
                                use_flash_attention, num_chunks, is_fine_tuned=False):
        """Format detailed information."""
        output = "=== TRANSCRIPTION ===\n"
        output += f"{transcription}\n\n"

        output += "=== MODEL INFORMATION ===\n"
        output += f"Model: {model_name}\n"
        output += f"Model Type: {'Fine-tuned' if is_fine_tuned else 'Standard'}\n"
        output += f"Language: {language}\n"
        output += f"Task: {task}\n"
        output += f"Processing time: {transcription_time:.2f} seconds\n"
        output += f"Chunks processed: {num_chunks}\n"

        output += "\n=== PROCESSING SETTINGS ===\n"
        output += f"Chunk length: {chunk_length_s} seconds\n"
        output += f"Batch size: {batch_size}\n"
        output += f"Flash Attention: {'Enabled' if use_flash_attention and not is_fine_tuned else 'Disabled'}\n"
        if is_fine_tuned:
            output += "\n=== FINE-TUNED MODEL OPTIMIZATIONS ===\n"
            output += "• Conservative batch size (max 2)\n"
            output += "• Float32 precision for stability\n"
            output += "• Disabled flash attention\n"
            output += "• Deterministic generation\n"
            output += "• No beam search\n"

        return output

    def get_model_info(self):
        """Get current model information."""
        if self.pipe is None:
            return "No model loaded"
        try:
            device = next(self.pipe.model.parameters()).device
            dtype = next(self.pipe.model.parameters()).dtype
            model_type = "Fine-tuned" if self.is_fine_tuned_model(self.current_model) else "Standard"
            return f"✅ {self.current_model} ({model_type}) - {device} ({dtype})"
        except Exception:
            return f"✅ {self.current_model} loaded"


# Initialize the app
logger.info("Initializing Optimized Whisper App...")
whisper_app = OptimizedWhisperApp()


def transcribe_wrapper(audio, model_name, language, task, chunk_length_s,
                       batch_size, use_flash_attention, return_timestamps):
    """Wrapper for the Gradio interface."""
    try:
        return whisper_app.transcribe_audio(
            audio, model_name, language, task, chunk_length_s,
            batch_size, use_flash_attention, return_timestamps
        )
    except Exception as e:
        error_msg = f"Wrapper error: {str(e)}"
        logger.error(error_msg)
        return error_msg, "", error_msg


def get_model_status():
    """Get current model status."""
    return whisper_app.get_model_info()


def update_settings_for_model(model_name):
    """Update recommended settings based on model type."""
    is_fine_tuned = whisper_app.is_fine_tuned_model(model_name)

    if is_fine_tuned:
        return {
            "batch_size": gr.update(value=1, maximum=2),
            "use_flash_attention": gr.update(value=False),
            "chunk_length_s": gr.update(value=30),
        }
    else:
        return {
            "batch_size": gr.update(value=4, maximum=16),
            "use_flash_attention": gr.update(value=False),
            "chunk_length_s": gr.update(value=30),
        }


# Create the interface
def create_interface():
    with gr.Blocks(title="Optimized Whisper Transcription", theme=gr.themes.Soft()) as interface:
        gr.Markdown(
            """
            # 🚀 ASR Fine-tuned Model for the Greek Dialect of Lesbos

            **Enhanced for Fine-tuned Models**

            Features:
            - Special handling for fine-tuned models (such as the Greek dialect models)
            - Automatic optimization based on model type
            - Conservative settings for stability
            - Enhanced error handling
            """
        )

        # Model status
        model_status = gr.Textbox(
            value=get_model_status(),
            label="🔧 Current Model Status",
            interactive=False
        )

        # Main interface
        with gr.Row():
            with gr.Column():
                # Audio input
                audio_input = gr.Audio(
                    label="🎵 Upload Audio File",
                    type="filepath"
                )

                # Model selection
                model_dropdown = gr.Dropdown(
                    choices=whisper_app.available_models,
                    value="openai/whisper-small",
                    label="Model",
                    info="Auto-optimizes settings for fine-tuned models"
                )

                # Basic settings
                with gr.Row():
                    language_dropdown = gr.Dropdown(
                        choices=["Automatic Detection", "Greek", "English", "Spanish",
                                 "French", "German", "Italian"],
                        value="Automatic Detection",
                        label="Language"
                    )
                    task_dropdown = gr.Dropdown(
                        choices=["transcribe", "translate"],
                        value="transcribe",
                        label="Task"
                    )

                # Advanced settings
                with gr.Accordion("Advanced Settings", open=False):
                    chunk_length_s = gr.Slider(
                        minimum=10, maximum=60, value=30, step=5,
                        label="Chunk Length (seconds)"
                    )
                    batch_size = gr.Slider(
                        minimum=1, maximum=16, value=4, step=1,
                        label="Batch Size",
                        info="Auto-adjusted for fine-tuned models"
                    )
                    use_flash_attention = gr.Checkbox(
                        label="Flash Attention 2",
                        value=False,
                        info="Auto-disabled for fine-tuned models"
                    )
                    return_timestamps = gr.Checkbox(
                        label="Return Timestamps",
                        value=True
                    )

                transcribe_btn = gr.Button(
                    "🚀 Transcribe",
                    variant="primary",
                    size="lg"
                )
Transcribe", variant="primary", size="lg" ) with gr.Column(): # Results transcription_output = gr.Textbox( label="Transcription", lines=8, show_copy_button=True ) with gr.Accordion("Timestamps", open=False): timestamps_output = gr.Textbox( label="Timestamp Information", lines=10, show_copy_button=True ) with gr.Accordion("Detailed Information", open=False): detailed_output = gr.Textbox( label="Processing Details & Model Info", lines=15, show_copy_button=True ) # Event handlers transcribe_btn.click( fn=transcribe_wrapper, inputs=[audio_input, model_dropdown, language_dropdown, task_dropdown, chunk_length_s, batch_size, use_flash_attention, return_timestamps], outputs=[transcription_output, timestamps_output, detailed_output], show_progress=True ) # Auto-adjust settings when model changes model_dropdown.change( fn=lambda model: ( f"Model will be loaded on next transcription ({'Fine-tuned' if whisper_app.is_fine_tuned_model(model) else 'Standard'} model)", 1 if whisper_app.is_fine_tuned_model(model) else 4, False ), inputs=[model_dropdown], outputs=[model_status, batch_size, use_flash_attention] ) # Footer gr.Markdown( """ ### 🎯 Fine-tuned Model Optimizations **Automatic optimizations for fine-tuned models:** - Batch size limited to 1-2 for stability - Flash Attention automatically disabled - Float32 precision for better compatibility - Conservative generation settings - Enhanced error handling **For Greek dialect model specifically:** - Use batch size 1 - Keep chunk length at 30 seconds - Language detection usually works well """ ) return interface # Launch the app if __name__ == "__main__": interface = create_interface() interface.launch(share=True)