Spaces:
Runtime error
Runtime error
| import spaces | |
| import gradio as gr | |
| import torch | |
| import torchaudio | |
| import numpy as np | |
| import pandas as pd | |
| import time | |
| import datetime | |
| import re | |
| import subprocess | |
| import os | |
| import tempfile | |
| import spaces | |
| from transformers import pipeline | |
| from pyannote.audio import Pipeline | |
| import requests | |
| import base64 | |
# Choose the attention backend for Whisper.
#
# The runtime "pip install flash-attn" step that used to live here is disabled,
# so flash-attn is only present when the environment already ships it.
# Requesting "flash_attention_2" without the package makes the model load fail
# and takes the whole app down at import time, so probe for the package first
# and fall back to PyTorch's built-in scaled-dot-product attention.
try:
    import flash_attn  # noqa: F401  (imported only to probe availability)
    _ATTN_IMPLEMENTATION = "flash_attention_2"
except ImportError:
    print("Warning: flash-attn is not installed, falling back to default attention")
    _ATTN_IMPLEMENTATION = "sdpa"

# Create global pipeline (similar to working HuggingFace example)
pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-large-v3-turbo",
    torch_dtype=torch.float16,
    device="cuda",
    model_kwargs={"attn_implementation": _ATTN_IMPLEMENTATION},
    return_timestamps=True,
)
def comprehensive_flash_attention_verification():
    """Check whether flash attention is usable in this environment.

    Runs four checks, printing progress as it goes:

    1. the ``flash_attn`` and ``transformers`` packages can be imported,
    2. CUDA is available to PyTorch,
    3. ``flash_attn_func`` can be imported from ``flash_attn``,
    4. a tiny float16 forward pass through ``flash_attn_func`` succeeds on
       ``cuda:0`` (skipped, and recorded as failed, when 2 or 3 failed).

    A pass/fail summary is printed at the end.

    Returns:
        bool: True only if every recorded check passed, otherwise False.
    """
    print("🔍 Running Flash Attention Verification...")
    print("=" * 50)
    verification_results = {}
    flash_attn_func = None

    # Check 1: Package Installation
    print("🔍 Checking Python packages...")
    try:
        import flash_attn
        # Not every build exposes __version__; don't let that abort the report.
        print(f"✅ flash-attn: {getattr(flash_attn, '__version__', 'unknown')}")
        verification_results["flash_attn_installed"] = True
    except ImportError:
        print("❌ flash-attn: Not installed")
        verification_results["flash_attn_installed"] = False
    try:
        import transformers
        print(f"✅ transformers: {transformers.__version__}")
        verification_results["transformers_available"] = True
    except ImportError:
        print("❌ transformers: Not installed")
        verification_results["transformers_available"] = False

    # Check 2: CUDA Availability
    print("\n🔍 Checking CUDA availability...")
    cuda_available = torch.cuda.is_available()
    # The marker follows the actual value so a missing GPU is not shown as OK.
    print(f"{'✅' if cuda_available else '❌'} CUDA available: {cuda_available}")
    if cuda_available:
        print(f"✅ CUDA version: {torch.version.cuda}")
        print(f"✅ GPU count: {torch.cuda.device_count()}")
        for i in range(torch.cuda.device_count()):
            print(f"✅ GPU {i}: {torch.cuda.get_device_name(i)}")
    verification_results["cuda_available"] = cuda_available

    # Check 3: Flash Attention Import
    print("\n🔍 Testing flash attention imports...")
    try:
        from flash_attn import flash_attn_func
        print("✅ flash_attn_func imported successfully")
        if flash_attn_func is None:
            print("❌ flash_attn_func is None")
            verification_results["flash_attn_import"] = False
        else:
            print("✅ flash_attn_func is callable")
            verification_results["flash_attn_import"] = True
    except ImportError as e:
        print(f"❌ Import error: {e}")
        verification_results["flash_attn_import"] = False
    except Exception as e:
        print(f"❌ Unexpected error: {e}")
        verification_results["flash_attn_import"] = False

    # Check 4: Flash Attention Functionality Test
    print("\n🔍 Testing flash attention functionality...")
    if not cuda_available:
        print("⚠️ Skipping functionality test - CUDA not available")
        verification_results["flash_attn_functional"] = False
    elif not verification_results["flash_attn_import"]:
        print("⚠️ Skipping functionality test - Import failed")
        verification_results["flash_attn_functional"] = False
    else:
        try:
            # Create small dummy tensors; flash_attn_func was imported in check 3.
            batch_size, seq_len, num_heads, head_dim = 1, 16, 4, 32
            device = "cuda:0"
            dtype = torch.float16
            print(f"Creating tensors: batch={batch_size}, seq_len={seq_len}, heads={num_heads}, dim={head_dim}")
            shape = (batch_size, seq_len, num_heads, head_dim)
            q = torch.randn(*shape, dtype=dtype, device=device)
            k = torch.randn(*shape, dtype=dtype, device=device)
            v = torch.randn(*shape, dtype=dtype, device=device)
            print("✅ Tensors created successfully")
            # Test flash attention
            output = flash_attn_func(q, k, v, dropout_p=0.0, causal=False)
            print(f"✅ Flash attention output shape: {output.shape}")
            print("✅ Flash attention test passed!")
            verification_results["flash_attn_functional"] = True
        except Exception as e:
            print(f"❌ Flash attention test failed: {e}")
            import traceback
            traceback.print_exc()
            verification_results["flash_attn_functional"] = False

    # Summary
    print("\n" + "=" * 50)
    print("📊 VERIFICATION SUMMARY")
    print("=" * 50)
    for check_name, passed in verification_results.items():
        status = "✅ PASS" if passed else "❌ FAIL"
        print(f"{check_name}: {status}")

    if all(verification_results.values()):
        print("\n🎉 All checks passed! Flash attention should work.")
        return True

    print("\n⚠️ Some checks failed. Flash attention may not work properly.")
    print("\nRecommendations:")
    print("1. Try reinstalling flash-attn: pip uninstall flash-attn && pip install flash-attn --no-build-isolation")
    print("2. Check CUDA compatibility with your PyTorch version")
    print("3. Consider using default attention as fallback")
    return False
class WhisperTranscriber:
    """Whisper transcription with optional pyannote speaker diarization.

    Instance fields:
      pipe: the module-level Whisper ASR pipeline, shared by reference.
      diarization_model: the pyannote pipeline, or None until `setup_models`
        loads it (it stays None if loading fails).
    """

    def __init__(self):
        self.pipe = pipe  # Use global pipeline
        self.diarization_model = None

    # @spaces.GPU
    def setup_models(self):
        """Load whichever of the two models is still missing.

        A failure to load the diarization model is reported and leaves
        `diarization_model` as None; it does not raise.
        """
        if self.pipe is None:
            print("Loading Whisper model...")
            self.pipe = pipeline(
                "automatic-speech-recognition",
                model="openai/whisper-large-v3-turbo",
                torch_dtype=torch.float16,
                device="cuda:0",
                model_kwargs={"attn_implementation": "flash_attention_2"},
                return_timestamps=True,
            )
        if self.diarization_model is None:
            print("Loading diarization model...")
            # Note: You'll need to set up authentication for pyannote models
            # For demo purposes, we'll handle the case where it's not available
            try:
                self.diarization_model = Pipeline.from_pretrained(
                    "pyannote/speaker-diarization-3.1",
                    use_auth_token=os.getenv("HF_TOKEN"),
                ).to(torch.device("cuda"))
            except Exception as e:
                print(f"Could not load diarization model: {e}")
                self.diarization_model = None

    def convert_audio_format(self, audio_path):
        """Convert audio to a 16 kHz mono 16-bit PCM WAV file using ffmpeg.

        Args:
            audio_path: path of the audio file to convert.

        Returns:
            Path of a newly created temporary ``.wav`` file. The caller owns
            it and is responsible for deleting it.

        Raises:
            RuntimeError: if ffmpeg exits with a non-zero status.
        """
        temp_wav = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
        temp_wav_path = temp_wav.name
        temp_wav.close()
        try:
            subprocess.run(
                [
                    "ffmpeg", "-y", "-i", audio_path,
                    "-ar", "16000", "-ac", "1", "-c:a", "pcm_s16le",
                    temp_wav_path,
                ],
                check=True,
                capture_output=True,
            )
        except subprocess.CalledProcessError as e:
            # Don't leave the placeholder file behind when conversion fails.
            os.unlink(temp_wav_path)
            raise RuntimeError(f"Audio conversion failed: {e}") from e
        except Exception:
            # e.g. ffmpeg is not installed: clean up, then propagate unchanged.
            os.unlink(temp_wav_path)
            raise
        return temp_wav_path

    def transcribe_audio(self, audio_path, language=None, translate=False, prompt=None):
        """Transcribe audio with the Whisper pipeline.

        Args:
            audio_path: path of the audio file handed to the pipeline.
            language: language code to force, or None to let Whisper decide.
            translate: when true, request Whisper's "translate" task.
            prompt: optional vocabulary/context text used to prime decoding.

        Returns:
            A ``(segments, detected_language)`` tuple. ``segments`` is a list
            of dicts with ``start``, ``end`` (seconds) and ``text``.
            ``detected_language`` is the pipeline's reported language when the
            result carries one, else ``language``, else ``"unknown"``.
        """
        print("Starting transcription...")
        start_time = time.time()

        # Prepare generation kwargs
        generate_kwargs = {}
        if language:
            generate_kwargs["language"] = language
        if translate:
            generate_kwargs["task"] = "translate"
        if prompt:
            # Whisper's generate() expects prompt ids built by get_prompt_ids
            # (a rank-1 tensor); tokenizer.encode() adds the regular special
            # tokens and returns a plain list instead.
            prompt_ids = self.pipe.tokenizer.get_prompt_ids(prompt, return_tensors="pt")
            generate_kwargs["prompt_ids"] = prompt_ids.to(self.pipe.device)

        # Transcribe with timestamps
        result = self.pipe(
            audio_path,
            return_timestamps=True,
            generate_kwargs=generate_kwargs,
            chunk_length_s=30,
            batch_size=128,
        )
        transcription_time = time.time() - start_time
        print(f"Transcription completed in {transcription_time:.2f} seconds")

        # Extract segments and detected language
        if "chunks" in result:
            segments = []
            for chunk in result["chunks"]:
                chunk_start, chunk_end = chunk["timestamp"]
                start = float(chunk_start or 0)
                # A chunk can come back without an end time; falling back to
                # 0 would put the end before the start, so reuse the start.
                end = float(chunk_end) if chunk_end is not None else start
                segments.append({
                    "start": start,
                    "end": end,
                    "text": chunk["text"].strip(),
                })
        else:
            # Fallback for different result format
            segments = [{
                "start": 0.0,
                "end": 0.0,
                "text": result["text"],
            }]

        # The pipeline result is a dict, so the language has to be looked up
        # as a key rather than as an attribute.
        detected_language = result.get("language") or language or "unknown"
        transcription_time = time.time() - start_time
        print(f"Transcription parse completed in {transcription_time:.2f} seconds")
        return segments, detected_language

    def perform_diarization(self, audio_path, num_speakers=None):
        """Run speaker diarization on an audio file.

        Args:
            audio_path: path of the audio file, loaded with torchaudio.
            num_speakers: speaker count passed to the model, or None.

        Returns:
            A ``(segments, speaker_count)`` tuple. ``segments`` is a list of
            dicts with ``start``, ``end`` and ``speaker``. When no diarization
            model is loaded the result is ``([], 1)``.
        """
        if self.diarization_model is None:
            print("Diarization model not available, assigning single speaker")
            return [], 1

        print("Starting diarization...")
        start_time = time.time()

        # Load audio for diarization
        waveform, sample_rate = torchaudio.load(audio_path)

        # Perform diarization
        diarization = self.diarization_model(
            {"waveform": waveform, "sample_rate": sample_rate},
            num_speakers=num_speakers,
        )

        # Convert to list format
        diarize_segments = [
            {"start": turn.start, "end": turn.end, "speaker": speaker}
            for turn, _, speaker in diarization.itertracks(yield_label=True)
        ]
        detected_num_speakers = len({seg["speaker"] for seg in diarize_segments})

        diarization_time = time.time() - start_time
        print(f"Diarization completed in {diarization_time:.2f} seconds")
        return diarize_segments, detected_num_speakers

    def merge_transcription_and_diarization(self, transcription_segments, diarization_segments):
        """Attach a speaker label and a duration to each transcription segment.

        Each segment gets the speaker whose diarization turns overlap it the
        most in total; segments with no overlap, or all segments when there is
        no diarization at all, are labelled ``"SPEAKER_00"``. The segment dicts
        are updated in place.

        Args:
            transcription_segments: list of dicts with ``start``, ``end``, ``text``.
            diarization_segments: list of dicts with ``start``, ``end``, ``speaker``.

        Returns:
            The list of updated segment dicts.
        """
        if not diarization_segments:
            # No diarization available, assign single speaker
            for segment in transcription_segments:
                segment["speaker"] = "SPEAKER_00"
                # Keep the schema identical to the diarized path.
                segment["duration"] = segment["end"] - segment["start"]
            return transcription_segments

        print("Merging transcription and diarization...")
        diarize_df = pd.DataFrame(diarization_segments)
        final_segments = []
        for segment in transcription_segments:
            # Overlap, in seconds, between this segment and every speaker turn
            overlap = np.maximum(
                0,
                np.minimum(diarize_df["end"], segment["end"])
                - np.maximum(diarize_df["start"], segment["start"]),
            )
            overlapping = overlap > 0
            if overlapping.any():
                # Speaker with the largest total overlap
                speaker = (
                    overlap[overlapping]
                    .groupby(diarize_df.loc[overlapping, "speaker"])
                    .sum()
                    .idxmax()
                )
            else:
                speaker = "SPEAKER_00"
            segment["speaker"] = speaker
            segment["duration"] = segment["end"] - segment["start"]
            final_segments.append(segment)
        return final_segments

    def group_segments_by_speaker(self, segments, max_gap=1.0, max_duration=30.0):
        """Merge consecutive segments spoken by the same speaker.

        A segment is appended to the current group only when the speaker is
        unchanged, the pause before it is at most ``max_gap`` seconds, the
        group is still shorter than ``max_duration`` seconds, and the group's
        text does not already end a sentence (``.``, ``!`` or ``?``).

        Args:
            segments: list of dicts with ``start``, ``end``, ``text``, ``speaker``.
            max_gap: longest pause, in seconds, that may be bridged.
            max_duration: a group this long or longer is not extended.

        Returns:
            A new list of grouped segment dicts with whitespace normalised;
            the input list itself is returned when it is empty.
        """
        if not segments:
            return segments

        sentence_end = re.compile(r"[.!?]+\s*$")
        grouped_segments = []
        current_group = segments[0].copy()
        for segment in segments[1:]:
            time_gap = segment["start"] - current_group["end"]
            current_duration = current_group["end"] - current_group["start"]
            # Conditions for combining segments
            can_combine = (
                segment["speaker"] == current_group["speaker"]
                and time_gap <= max_gap
                and current_duration < max_duration
                and not sentence_end.search(current_group["text"])
            )
            if can_combine:
                # Merge segments
                current_group["end"] = segment["end"]
                current_group["text"] += " " + segment["text"]
                current_group["duration"] = current_group["end"] - current_group["start"]
            else:
                # Start new group
                grouped_segments.append(current_group)
                current_group = segment.copy()
        grouped_segments.append(current_group)

        # Clean up text: collapse whitespace, no space before punctuation
        for segment in grouped_segments:
            segment["text"] = re.sub(r"\s+", " ", segment["text"]).strip()
            segment["text"] = re.sub(r"\s+([.,!?])", r"\1", segment["text"])
        return grouped_segments

    def process_audio(self, audio_file, num_speakers=None, language=None,
                      translate=False, prompt=None, group_segments=True):
        """Transcribe, diarize, merge and optionally group one audio file.

        The input file is left in place: it belongs to the caller, and no
        temporary copy is created here.

        Args:
            audio_file: path of the audio file, or None.
            num_speakers: speaker count for diarization, or None.
            language: language code to force, or None.
            translate: when true, translate instead of transcribe.
            prompt: optional vocabulary/context prompt.
            group_segments: when true, merge consecutive same-speaker segments.

        Returns:
            On success a dict with ``segments``, ``language``, ``num_speakers``
            and ``total_segments``. On failure, or when ``audio_file`` is
            None, a dict with a single ``error`` key. Exceptions are reported
            this way rather than raised.
        """
        if audio_file is None:
            return {"error": "No audio file provided"}
        try:
            # Transcribe audio
            transcription_segments, detected_language = self.transcribe_audio(
                audio_file, language, translate, prompt
            )
            # Perform diarization
            diarization_segments, detected_num_speakers = self.perform_diarization(
                audio_file, num_speakers
            )
            # Merge transcription and diarization
            final_segments = self.merge_transcription_and_diarization(
                transcription_segments, diarization_segments
            )
            # Group segments if requested
            if group_segments:
                final_segments = self.group_segments_by_speaker(final_segments)
            return {
                "segments": final_segments,
                "language": detected_language,
                "num_speakers": detected_num_speakers or 1,
                "total_segments": len(final_segments),
            }
        except Exception as e:
            import traceback
            traceback.print_exc()
            return {"error": f"Processing failed: {str(e)}"}
# Initialize transcriber: a single module-level instance that the Gradio
# callback below (`process_audio_gradio`) uses for every request.
transcriber = WhisperTranscriber()
def format_segments_for_display(result):
    """Render a processing result as Markdown for the Gradio output panel.

    Args:
        result: dict produced by the transcriber. Either it has an ``error``
            key, or it may carry ``segments`` (list of dicts with ``start``,
            ``end``, ``text`` and optionally ``speaker``), ``language`` and
            ``num_speakers``.

    Returns:
        str: an error line when ``result`` has an ``error`` key, otherwise a
        summary header followed by one block per segment. Timestamps are
        truncated to whole seconds and shown as ``H:MM:SS``.
    """
    if "error" in result:
        return f"❌ Error: {result['error']}"

    segments = result.get("segments", [])
    language = result.get("language", "unknown")
    num_speakers = result.get("num_speakers", 1)

    # Collect the pieces and join once instead of growing a string with +=.
    parts = [
        "🎯 **Detection Results:**\n",
        f"- Language: {language}\n",
        f"- Speakers: {num_speakers}\n",
        f"- Segments: {len(segments)}\n\n",
        "📝 **Transcription:**\n\n",
    ]
    for segment in segments:
        start_time = str(datetime.timedelta(seconds=int(segment["start"])))
        end_time = str(datetime.timedelta(seconds=int(segment["end"])))
        speaker = segment.get("speaker", "SPEAKER_00")
        text = segment["text"]
        parts.append(f"**{speaker}** ({start_time} → {end_time})\n")
        parts.append(f"{text}\n\n")
    return "".join(parts)
def process_audio_gradio(audio_file, num_speakers, language, translate, prompt, group_segments):
    """Gradio callback: run the transcriber and return both output views.

    The UI's placeholder values are mapped to None before the call: a speaker
    count that is not above zero, the language ``"auto"``, and an empty or
    whitespace-only prompt.

    Returns:
        tuple: ``(markdown_text, raw_result_dict)``.
    """
    if num_speakers > 0:
        speaker_count = num_speakers
    else:
        speaker_count = None

    if language != "auto":
        forced_language = language
    else:
        forced_language = None

    if prompt and prompt.strip():
        vocabulary_prompt = prompt
    else:
        vocabulary_prompt = None

    result = transcriber.process_audio(
        audio_file=audio_file,
        num_speakers=speaker_count,
        language=forced_language,
        translate=translate,
        prompt=vocabulary_prompt,
        group_segments=group_segments,
    )
    return format_segments_for_display(result), result
# Markdown shown in the UI. Kept flush-left at module level so that the
# indentation of the surrounding `with` blocks cannot end up inside the text
# (leading spaces in Markdown would turn these lines into a code block).
_INTRO_MARKDOWN = """
# 🎙️ Advanced Audio Transcription & Speaker Diarization

Upload an audio file to get accurate transcription with speaker identification, powered by:
- **Whisper Large V3 Turbo** with Flash Attention for fast transcription
- **Pyannote 3.1** for speaker diarization
- **ZeroGPU** acceleration for optimal performance
"""

_TIPS_MARKDOWN = """
- **Supported formats**: MP3, WAV, M4A, FLAC, OGG, and more
- **Max duration**: Recommended under 10 minutes for optimal performance
- **Speaker detection**: Works best with clear, distinct voices
- **Languages**: Supports 100+ languages with auto-detection
- **Vocabulary**: Add names and technical terms in the prompt for better accuracy
"""

# Create Gradio interface
demo = gr.Blocks(
    title="🎙️ Whisper Transcription with Speaker Diarization",
    theme="default",
)

with demo:
    gr.Markdown(_INTRO_MARKDOWN)

    with gr.Row():
        # Left column: the upload widget, optional settings and the run button
        with gr.Column():
            audio_input = gr.Audio(
                label="🎵 Upload Audio File",
                type="filepath",
                # source="upload"
            )
            with gr.Accordion("⚙️ Advanced Settings", open=False):
                num_speakers = gr.Slider(
                    minimum=0,
                    maximum=20,
                    value=0,
                    step=1,
                    label="Number of Speakers (0 = auto-detect)",
                )
                language = gr.Dropdown(
                    choices=["auto", "en", "es", "fr", "de", "it", "pt", "ru", "ja", "ko", "zh"],
                    value="auto",
                    label="Language",
                )
                translate = gr.Checkbox(
                    label="Translate to English",
                    value=False,
                )
                prompt = gr.Textbox(
                    label="Vocabulary Prompt (names, acronyms, etc.)",
                    placeholder="Enter names, technical terms, or context...",
                    lines=2,
                )
                group_segments = gr.Checkbox(
                    label="Group segments by speaker",
                    value=True,
                )
            process_btn = gr.Button("🚀 Transcribe Audio", variant="primary")

        # Right column: the formatted transcript plus the (hidden) raw result
        with gr.Column():
            output_text = gr.Markdown(
                label="📝 Transcription Results",
                value="Upload an audio file and click 'Transcribe Audio' to get started!",
            )
            output_json = gr.JSON(
                label="🔧 Raw Output (JSON)",
                visible=False,
            )

    # Event handlers
    process_btn.click(
        fn=process_audio_gradio,
        inputs=[
            audio_input,
            num_speakers,
            language,
            translate,
            prompt,
            group_segments,
        ],
        outputs=[output_text, output_json],
    )

    # Examples
    gr.Markdown("### 📋 Usage Tips:")
    gr.Markdown(_TIPS_MARKDOWN)

if __name__ == "__main__":
    demo.launch(debug=True)