import gradio as gr
import torch
import logging
import gc
import time
from transformers import (
    pipeline,
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
    WhisperProcessor,
    WhisperForConditionalGeneration,
)

# Try to import the flash attention helper, but don't fail if it is unavailable
try:
    from transformers.utils import is_flash_attn_2_available
    FLASH_ATTN_AVAILABLE = True
except ImportError:
    FLASH_ATTN_AVAILABLE = False

    def is_flash_attn_2_available():
        return False
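# If the helper cannot be imported (older transformers releases), the stub above
# keeps the module import-safe and simply reports Flash Attention 2 as unavailable.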
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class OptimizedWhisperApp:
    def __init__(self):
        self.pipe = None
        self.current_model = None
        self.available_models = [
            "openai/whisper-tiny",
            "openai/whisper-base",
            "openai/whisper-small",
            "openai/whisper-medium",
            "openai/whisper-large-v2",
            "openai/whisper-large-v3",
            "ilsp/whisper_greek_dialect_of_lesbos",
            "ilsp/xls-r-greek-cretan",
        ]

    def is_fine_tuned_model(self, model_name):
        """Check if this is a fine-tuned model that might need special handling"""
        fine_tuned_indicators = [
            "ilsp/",
            "fine",
            "dialect",
            "custom",
        ]
        return any(indicator in model_name.lower() for indicator in fine_tuned_indicators)
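    # This is a substring heuristic on the repo id, for example:
    #   is_fine_tuned_model("ilsp/whisper_greek_dialect_of_lesbos") -> True  ("ilsp/", "dialect")
    #   is_fine_tuned_model("openai/whisper-small")                 -> False
    # Any model id containing one of the indicators is routed through the
    # conservative fine-tuned loading path below.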
    def create_pipe_for_fine_tuned(self, model_name):
        """Special handling for fine-tuned models"""
        try:
            logger.info(f"Loading fine-tuned model: {model_name}")
            # Device selection - be more conservative for fine-tuned models
            if torch.cuda.is_available():
                device = "cuda:0"
                torch_dtype = torch.float32  # Use float32 for stability
            else:
                device = "cpu"
                torch_dtype = torch.float32
            logger.info(f"Using device: {device}, dtype: {torch_dtype}")
            # Try to load as a Whisper model first
            try:
                logger.info("Attempting to load as WhisperForConditionalGeneration...")
                model = WhisperForConditionalGeneration.from_pretrained(
                    model_name,
                    torch_dtype=torch_dtype,
                    low_cpu_mem_usage=True,
                    cache_dir="./cache",
                )
                processor = WhisperProcessor.from_pretrained(model_name)
                logger.info("Successfully loaded as Whisper model")
            except Exception as e:
                logger.info(f"Whisper loading failed: {e}, trying AutoModel...")
                model = AutoModelForSpeechSeq2Seq.from_pretrained(
                    model_name,
                    torch_dtype=torch_dtype,
                    low_cpu_mem_usage=True,
                    use_safetensors=False,  # Fine-tuned models might not ship safetensors
                    cache_dir="./cache",
                )
                processor = AutoProcessor.from_pretrained(model_name)
            model.to(device)
            logger.info("Model moved to device")
            # Create pipeline with conservative settings
            pipe = pipeline(
                "automatic-speech-recognition",
                model=model,
                tokenizer=processor.tokenizer,
                feature_extractor=processor.feature_extractor,
                torch_dtype=torch_dtype,
                device=device,
                chunk_length_s=30,  # Fixed chunk length for fine-tuned models
            )
            logger.info("Fine-tuned model pipeline created successfully!")
            return pipe
        except Exception as e:
            logger.error(f"Failed to create fine-tuned model pipeline: {e}")
            import traceback
            logger.error(traceback.format_exc())
            return None
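    # Note: the AutoModelForSpeechSeq2Seq fallback covers non-Whisper seq2seq
    # checkpoints. XLS-R checkpoints are usually CTC models, so a repo such as
    # "ilsp/xls-r-greek-cretan" may fail both branches above; in that case the
    # error is logged and load_model() reports the failure (assumption, untested).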
| def create_pipe(self, model_name, use_flash_attention=True): | |
| """Create pipeline with special handling for fine-tuned models""" | |
| # Use special handling for fine-tuned models | |
| if self.is_fine_tuned_model(model_name): | |
| return self.create_pipe_for_fine_tuned(model_name) | |
| try: | |
| logger.info(f"Loading standard model: {model_name}") | |
| # Device selection | |
| if torch.cuda.is_available(): | |
| device = "cuda:0" | |
| torch_dtype = torch.float16 | |
| else: | |
| device = "cpu" | |
| torch_dtype = torch.float32 | |
| # Attention implementation - disable for fine-tuned models | |
| attn_implementation = "eager" | |
| if use_flash_attention and FLASH_ATTN_AVAILABLE and is_flash_attn_2_available() and torch.cuda.is_available(): | |
| try: | |
| attn_implementation = "flash_attention_2" | |
| logger.info("Using Flash Attention 2") | |
| except: | |
| attn_implementation = "eager" | |
| logger.info("Flash Attention 2 failed, using eager") | |
| # Load model | |
| model = AutoModelForSpeechSeq2Seq.from_pretrained( | |
| model_name, | |
| torch_dtype=torch_dtype, | |
| low_cpu_mem_usage=True, | |
| use_safetensors=True, | |
| attn_implementation=attn_implementation, | |
| cache_dir="./cache" | |
| ) | |
| model.to(device) | |
| # Load processor | |
| processor = AutoProcessor.from_pretrained(model_name) | |
| # Create pipeline | |
| pipe = pipeline( | |
| "automatic-speech-recognition", | |
| model=model, | |
| tokenizer=processor.tokenizer, | |
| feature_extractor=processor.feature_extractor, | |
| torch_dtype=torch_dtype, | |
| device=device, | |
| ) | |
| logger.info("Standard model pipeline created successfully!") | |
| return pipe | |
| except Exception as e: | |
| logger.error(f"Failed to create standard model pipeline: {e}") | |
| import traceback | |
| logger.error(traceback.format_exc()) | |
| return None | |
    def load_model(self, model_name, use_flash_attention=True):
        """Load a model, clearing any previously loaded pipeline first"""
        if self.current_model != model_name or self.pipe is None:
            logger.info(f"Loading new model: {model_name}")
            # Clear previous model
            if self.pipe is not None:
                logger.info("Clearing previous model...")
                del self.pipe
                self.pipe = None  # keep the attribute defined even if reloading fails
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                gc.collect()
            try:
                # Disable flash attention for fine-tuned models
                if self.is_fine_tuned_model(model_name):
                    use_flash_attention = False
                    logger.info("Disabled flash attention for fine-tuned model")
                self.pipe = self.create_pipe(model_name, use_flash_attention)
                self.current_model = model_name if self.pipe else None
                if self.pipe:
                    logger.info(f"Model {model_name} loaded successfully")
                    return True
                else:
                    logger.error(f"Failed to load model {model_name}")
                    return False
            except Exception as e:
                logger.error(f"Error loading model: {e}")
                return False
        else:
            logger.info("Model already loaded")
            return True
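    # Repeated calls with the same model name reuse the cached pipeline, so only
    # the first transcription after a model switch pays the load cost.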
    def transcribe_audio_fine_tuned(self, audio_file, chunk_length_s=30, batch_size=1):
        """Special transcription method for fine-tuned models with conservative settings"""
        try:
            logger.info("Using fine-tuned model transcription method")
            # Use very conservative settings for fine-tuned models
            outputs = self.pipe(
                audio_file,
                chunk_length_s=min(chunk_length_s, 30),  # Cap at 30 seconds
                batch_size=min(batch_size, 2),  # Cap batch size at 2
                return_timestamps=True,
                generate_kwargs={
                    "task": "transcribe",
                    "do_sample": False,  # Deterministic output
                    "num_beams": 1,  # No beam search
                    "max_length": 448,  # Conservative max length
                },
            )
            return outputs
        except Exception as e:
            logger.error(f"Fine-tuned transcription failed: {e}")
            raise
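    # With return_timestamps=True the ASR pipeline returns a dict shaped like
    #   {"text": "...", "chunks": [{"text": "...", "timestamp": (0.0, 5.2)}, ...]},
    # which is what transcribe_audio() and _format_timestamps() consume below.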
| def transcribe_audio(self, audio_file, model_name="openai/whisper-medium", | |
| language="Automatic Detection", task="transcribe", | |
| chunk_length_s=30, batch_size=16, use_flash_attention=True, | |
| return_timestamps=True): | |
| """Transcribe with special handling for fine-tuned models""" | |
| if audio_file is None: | |
| return "Please upload an audio file", "", "" | |
| try: | |
| logger.info("=== Starting transcription ===") | |
| start_time = time.time() | |
| # Load model | |
| success = self.load_model(model_name, use_flash_attention) | |
| if not success: | |
| return "Failed to load model", "", "" | |
| logger.info(f"Processing: {audio_file}") | |
| logger.info(f"Settings: {model_name}, {language}, {task}") | |
| # Check if this is a fine-tuned model | |
| is_fine_tuned = self.is_fine_tuned_model(model_name) | |
| if is_fine_tuned: | |
| logger.info("Using fine-tuned model optimizations") | |
| # Use special method for fine-tuned models | |
| outputs = self.transcribe_audio_fine_tuned( | |
| audio_file, chunk_length_s, batch_size | |
| ) | |
| else: | |
| # Standard transcription for regular models | |
| logger.info("Using standard model transcription") | |
| # Prepare generation kwargs | |
| generate_kwargs = {} | |
| # Language handling | |
| if language != "Automatic Detection" and not model_name.endswith(".en"): | |
| language_map = { | |
| "Greek": "greek", | |
| "English": "english", | |
| "Spanish": "spanish", | |
| "French": "french", | |
| "German": "german", | |
| "Italian": "italian" | |
| } | |
| lang_code = language_map.get(language, language.lower()) | |
| generate_kwargs["language"] = lang_code | |
| logger.info(f"Set language: {lang_code}") | |
| # Task handling | |
| if not model_name.endswith(".en"): | |
| generate_kwargs["task"] = task | |
| outputs = self.pipe( | |
| audio_file, | |
| chunk_length_s=chunk_length_s, | |
| batch_size=batch_size, | |
| generate_kwargs=generate_kwargs, | |
| return_timestamps=return_timestamps, | |
| ) | |
| transcription_time = time.time() - start_time | |
| logger.info(f"Transcription completed in {transcription_time:.2f} seconds") | |
| # Extract results | |
| transcription = outputs.get("text", "") if outputs else "" | |
| chunks = outputs.get("chunks", []) if outputs else [] | |
| # Handle timestamps | |
| timestamp_text = "" | |
| if return_timestamps: | |
| try: | |
| if chunks: | |
| timestamp_text = self._format_timestamps(chunks) | |
| else: | |
| timestamp_text = "=== TIMESTAMPS ===\nNo chunks returned.\n" | |
| except Exception as ts_error: | |
| logger.warning(f"Timestamp formatting error: {ts_error}") | |
| timestamp_text = f"=== TIMESTAMPS ===\nError: {str(ts_error)}\n" | |
| else: | |
| timestamp_text = "=== TIMESTAMPS ===\nDisabled.\n" | |
| # Create detailed output | |
| detailed_output = self._format_detailed_output( | |
| transcription, model_name, language, task, | |
| transcription_time, chunk_length_s, batch_size, | |
| use_flash_attention, len(chunks), is_fine_tuned | |
| ) | |
| return transcription.strip(), timestamp_text, detailed_output | |
| except Exception as e: | |
| error_msg = f"Transcription error: {str(e)}" | |
| logger.error(error_msg) | |
| import traceback | |
| logger.error(traceback.format_exc()) | |
| return error_msg, "", error_msg | |
    def _format_timestamps(self, chunks):
        """Format timestamp information"""
        timestamp_text = "=== TIMESTAMPS ===\n"
        if not chunks:
            return timestamp_text + "No chunks available.\n"
        for i, chunk in enumerate(chunks):
            try:
                timestamp = chunk.get('timestamp', None)
                text = chunk.get('text', '')
                if timestamp is None:
                    timestamp_text += f"[No timestamp]: {text}\n"
                elif isinstance(timestamp, (list, tuple)) and len(timestamp) >= 2:
                    start, end = timestamp[0], timestamp[1]
                    if start is None or end is None:
                        timestamp_text += f"[Invalid]: {text}\n"
                    else:
                        try:
                            start_f = float(start)
                            end_f = float(end)
                            timestamp_text += f"[{start_f:.1f}s - {end_f:.1f}s]: {text}\n"
                        except (ValueError, TypeError):
                            timestamp_text += f"[Format error]: {text}\n"
                else:
                    timestamp_text += f"[Unexpected format]: {text}\n"
            except Exception as e:
                timestamp_text += f"[Chunk {i} error]: {str(e)}\n"
        return timestamp_text
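    # Example: a chunk {"text": "hello", "timestamp": (0.0, 2.4)} is rendered
    # as the line "[0.0s - 2.4s]: hello".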
    def _format_detailed_output(self, transcription, model_name, language, task,
                                transcription_time, chunk_length_s, batch_size,
                                use_flash_attention, num_chunks, is_fine_tuned=False):
        """Format detailed information"""
        output = "=== TRANSCRIPTION ===\n"
        output += f"{transcription}\n\n"
        output += "=== MODEL INFORMATION ===\n"
        output += f"Model: {model_name}\n"
        output += f"Model Type: {'Fine-tuned' if is_fine_tuned else 'Standard'}\n"
        output += f"Language: {language}\n"
        output += f"Task: {task}\n"
        output += f"Processing time: {transcription_time:.2f} seconds\n"
        output += f"Chunks processed: {num_chunks}\n"
        output += "\n=== PROCESSING SETTINGS ===\n"
        output += f"Chunk length: {chunk_length_s} seconds\n"
        output += f"Batch size: {batch_size}\n"
        output += f"Flash Attention: {'Enabled' if use_flash_attention and not is_fine_tuned else 'Disabled'}\n"
        if is_fine_tuned:
            output += "\n=== FINE-TUNED MODEL OPTIMIZATIONS ===\n"
            output += "• Conservative batch size (max 2)\n"
            output += "• Float32 precision for stability\n"
            output += "• Disabled flash attention\n"
            output += "• Deterministic generation\n"
            output += "• No beam search\n"
        return output
    def get_model_info(self):
        """Get current model information"""
        if self.pipe is None:
            return "No model loaded"
        try:
            device = next(self.pipe.model.parameters()).device
            dtype = next(self.pipe.model.parameters()).dtype
            model_type = "Fine-tuned" if self.is_fine_tuned_model(self.current_model) else "Standard"
            return f"✅ {self.current_model} ({model_type}) - {device} ({dtype})"
        except Exception:
            return f"✅ {self.current_model} loaded"

# Initialize the app
logger.info("Initializing Optimized Whisper App...")
whisper_app = OptimizedWhisperApp()

def transcribe_wrapper(audio, model_name, language, task, chunk_length_s,
                       batch_size, use_flash_attention, return_timestamps):
    """Wrapper for the Gradio interface"""
    try:
        return whisper_app.transcribe_audio(
            audio, model_name, language, task,
            chunk_length_s, batch_size, use_flash_attention, return_timestamps
        )
    except Exception as e:
        error_msg = f"Wrapper error: {str(e)}"
        logger.error(error_msg)
        return error_msg, "", error_msg

def get_model_status():
    """Get current model status"""
    return whisper_app.get_model_info()

def update_settings_for_model(model_name):
    """Return a status message and recommended settings for the selected model"""
    if whisper_app.is_fine_tuned_model(model_name):
        status = "Model will be loaded on next transcription (Fine-tuned model)"
        return status, gr.update(value=1, maximum=2), gr.update(value=False)
    status = "Model will be loaded on next transcription (Standard model)"
    return status, gr.update(value=4, maximum=16), gr.update(value=False)
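# gr.update(...) patches properties of an existing component (here the slider's
# value/maximum and the checkbox value) without rebuilding the interface; the
# tuple order must match the `outputs` list of the change event below.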
# Create the interface
def create_interface():
    with gr.Blocks(title="Optimized Whisper Transcription", theme=gr.themes.Soft()) as interface:
        gr.Markdown(
            """
            # 🚀 ASR Fine-tuned Model for the Greek Dialect of Lesbos
            **Enhanced for Fine-tuned Models**
            Features:
            - Special handling for fine-tuned models (such as the Greek dialect models)
            - Automatic optimization based on model type
            - Conservative settings for stability
            - Enhanced error handling
            """
        )
        # Model status
        model_status = gr.Textbox(
            value=get_model_status(),
            label="🔧 Current Model Status",
            interactive=False
        )
        # Main interface
        with gr.Row():
            with gr.Column():
                # Audio input
                audio_input = gr.Audio(
                    label="🎵 Upload Audio File",
                    type="filepath"
                )
                # Model selection
                model_dropdown = gr.Dropdown(
                    choices=whisper_app.available_models,
                    value="openai/whisper-small",
                    label="Model",
                    info="Auto-optimizes settings for fine-tuned models"
                )
                # Basic settings
                with gr.Row():
                    language_dropdown = gr.Dropdown(
                        choices=["Automatic Detection", "Greek", "English", "Spanish", "French", "German", "Italian"],
                        value="Automatic Detection",
                        label="Language"
                    )
                    task_dropdown = gr.Dropdown(
                        choices=["transcribe", "translate"],
                        value="transcribe",
                        label="Task"
                    )
                # Advanced settings
                with gr.Accordion("Advanced Settings", open=False):
                    chunk_length_s = gr.Slider(
                        minimum=10,
                        maximum=60,
                        value=30,
                        step=5,
                        label="Chunk Length (seconds)"
                    )
                    batch_size = gr.Slider(
                        minimum=1,
                        maximum=16,
                        value=4,
                        step=1,
                        label="Batch Size",
                        info="Auto-adjusted for fine-tuned models"
                    )
                    use_flash_attention = gr.Checkbox(
                        label="Flash Attention 2",
                        value=False,
                        info="Auto-disabled for fine-tuned models"
                    )
                    return_timestamps = gr.Checkbox(
                        label="Return Timestamps",
                        value=True
                    )
                transcribe_btn = gr.Button(
                    "🚀 Transcribe",
                    variant="primary",
                    size="lg"
                )
            with gr.Column():
                # Results
                transcription_output = gr.Textbox(
                    label="Transcription",
                    lines=8,
                    show_copy_button=True
                )
                with gr.Accordion("Timestamps", open=False):
                    timestamps_output = gr.Textbox(
                        label="Timestamp Information",
                        lines=10,
                        show_copy_button=True
                    )
                with gr.Accordion("Detailed Information", open=False):
                    detailed_output = gr.Textbox(
                        label="Processing Details & Model Info",
                        lines=15,
                        show_copy_button=True
                    )
        # Event handlers
        transcribe_btn.click(
            fn=transcribe_wrapper,
            inputs=[audio_input, model_dropdown, language_dropdown, task_dropdown,
                    chunk_length_s, batch_size, use_flash_attention, return_timestamps],
            outputs=[transcription_output, timestamps_output, detailed_output],
            show_progress=True
        )
        # Auto-adjust settings when the model changes
        model_dropdown.change(
            fn=update_settings_for_model,
            inputs=[model_dropdown],
            outputs=[model_status, batch_size, use_flash_attention]
        )
        # Footer
        gr.Markdown(
            """
            ### 🎯 Fine-tuned Model Optimizations
            **Automatic optimizations for fine-tuned models:**
            - Batch size limited to 1-2 for stability
            - Flash Attention automatically disabled
            - Float32 precision for better compatibility
            - Conservative generation settings
            - Enhanced error handling
            **For the Greek dialect models specifically:**
            - Use batch size 1
            - Keep chunk length at 30 seconds
            - Language detection usually works well
            """
        )
    return interface

# Launch the app
if __name__ == "__main__":
    interface = create_interface()
    interface.launch(share=True)
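# To run locally (assuming this file is saved as app.py):
#   pip install gradio torch transformers accelerate
#   python app.py
# Note: share=True requests a temporary public gradio.live link when run
# locally; on Hugging Face Spaces the app is already hosted, so the flag is
# unnecessary there.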