import gradio as gr
import torch
import librosa
import numpy as np
import json
import os
import tempfile
import time
from datetime import datetime
from transformers import WhisperProcessor, WhisperForConditionalGeneration
import warnings

warnings.filterwarnings("ignore")

# =============================================================================
# MODEL LOADING AND CONFIGURATION
# =============================================================================

# Configure your model path - UPDATE THIS with your actual model name
MODEL_NAME = "AfroLogicInsect/whisper-finetuned-float32"  # Replace with your HF model

# Global variables for model and processor
model = None
processor = None
model_dtype = None

def load_model():
    """Load the Whisper model and processor."""
    global model, processor, model_dtype
    try:
        print(f"🔄 Loading model: {MODEL_NAME}")

        # Load processor
        processor = WhisperProcessor.from_pretrained(MODEL_NAME)

        # Load model in float32 for numerical stability. The model is moved
        # to the GPU manually below; combining device_map="auto" with an
        # explicit .cuda() call can conflict, so device_map is not used here.
        model = WhisperForConditionalGeneration.from_pretrained(
            MODEL_NAME,
            torch_dtype=torch.float32
        )
        model_dtype = torch.float32

        # Move to GPU if available
        if torch.cuda.is_available():
            model = model.cuda()
            print(f"✅ Model loaded on GPU: {torch.cuda.get_device_name()}")
        else:
            print("✅ Model loaded on CPU")
        return True

    except Exception as e:
        print(f"❌ Error loading model: {e}")
        # Fall back to the base Whisper model
        try:
            print("🔄 Falling back to base Whisper model...")
            fallback_model = "openai/whisper-small"
            processor = WhisperProcessor.from_pretrained(fallback_model)
            model = WhisperForConditionalGeneration.from_pretrained(
                fallback_model,
                torch_dtype=torch.float32
            )
            model_dtype = torch.float32
            if torch.cuda.is_available():
                model = model.cuda()
            print(f"✅ Fallback model loaded: {fallback_model}")
            return True
        except Exception as e2:
            print(f"❌ Fallback model loading failed: {e2}")
            return False


# Load model on startup
print("🚀 Initializing Whisper Transcription Service...")
model_loaded = load_model()

# =============================================================================
# CORE TRANSCRIPTION FUNCTIONS
# =============================================================================

def transcribe_audio_chunk(audio_chunk, sr=16000):
    """Transcribe a single audio chunk."""
    try:
        # Extract log-mel input features with the processor
        inputs = processor(
            audio_chunk,
            sampling_rate=sr,
            return_tensors="pt"
        )
        input_features = inputs.input_features

        # Match the model's dtype
        if model_dtype == torch.float16:
            input_features = input_features.half()
        else:
            input_features = input_features.float()

        # Move to the same device as the model
        input_features = input_features.to(model.device)

        # Generate transcription
        with torch.no_grad():
            try:
                predicted_ids = model.generate(
                    input_features,
                    language="en",
                    task="transcribe",
                    max_length=448,
                    num_beams=1,
                    do_sample=False,
                    use_cache=True,
                    no_repeat_ngram_size=2
                )
                transcription = processor.batch_decode(
                    predicted_ids,
                    skip_special_tokens=True
                )[0]
                return transcription
            except RuntimeError as gen_error:
                if "Input type" in str(gen_error) and "bias type" in str(gen_error):
                    # Recover from a dtype mismatch by forcing float32
                    model.float()
                    input_features = input_features.float()
                    predicted_ids = model.generate(
                        input_features,
                        language="en",
                        task="transcribe",
                        max_length=448,
                        num_beams=1,
                        do_sample=False,
                        no_repeat_ngram_size=2
                    )
                    transcription = processor.batch_decode(
                        predicted_ids,
                        skip_special_tokens=True
                    )[0]
                    return transcription
                else:
                    raise gen_error
    except Exception as e:
        print(f"❌ Chunk transcription failed: {e}")
        return None
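
# Minimal standalone usage sketch, kept commented out so it does not run on
# import. Assumes a local file "sample.wav" exists -- the filename is
# illustrative, not part of this app:
#
#   audio, _ = librosa.load("sample.wav", sr=16000)
#   print(transcribe_audio_chunk(audio, sr=16000))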

def process_audio_with_timestamps(audio_array, sr=16000, chunk_length=15):
    """Process audio with timestamps using robust chunking."""
    try:
        processing_start = time.time()
        total_duration = len(audio_array) / sr

        # Enforce the duration limit (3 minutes = 180 seconds)
        if total_duration > 180:
            return {
                "error": f"⚠️ Audio too long ({total_duration:.1f}s). Maximum allowed: 3 minutes (180s)",
                "success": False
            }

        chunk_samples = chunk_length * sr
        overlap_samples = int(2 * sr)  # 2-second overlap
        all_segments = []
        start = 0
        chunk_index = 0
        progress_updates = []

        while start < len(audio_array):
            # Define chunk boundaries
            end = min(start + chunk_samples, len(audio_array))

            # Pad the chunk on both sides so words at the boundary are heard in full
            chunk_start_with_overlap = max(0, start - overlap_samples // 2)
            chunk_end_with_overlap = min(len(audio_array), end + overlap_samples // 2)
            chunk_audio = audio_array[chunk_start_with_overlap:chunk_end_with_overlap]

            # Time boundaries of the un-padded chunk
            start_time = start / sr
            end_time = end / sr

            # Track progress
            progress = (chunk_index + 1) / max(1, int(np.ceil(len(audio_array) / chunk_samples))) * 100
            progress_updates.append(
                f"Processing chunk {chunk_index + 1}: {start_time:.1f}s - {end_time:.1f}s ({progress:.0f}%)"
            )

            # Transcribe chunk
            transcription = transcribe_audio_chunk(chunk_audio, sr)
            if transcription and transcription.strip():
                all_segments.append({
                    "start": round(start_time, 2),
                    "end": round(end_time, 2),
                    "text": transcription.strip(),
                    "duration": round(end_time - start_time, 2)
                })

            # Move to the next chunk
            start = end
            chunk_index += 1

        # Remove words duplicated by the chunk overlap
        cleaned_segments = remove_segment_overlaps(all_segments)

        if cleaned_segments:
            full_text = " ".join(seg["text"] for seg in cleaned_segments)
            return {
                "success": True,
                "text": full_text,
                "segments": cleaned_segments,
                "metadata": {
                    "total_duration": round(total_duration, 2),
                    "num_segments": len(cleaned_segments),
                    "chunk_length": chunk_length,
                    "processing_time": round(time.time() - processing_start, 2)
                }
            }
        else:
            return {
                "error": "❌ No transcription could be generated",
                "success": False
            }
    except Exception as e:
        return {
            "error": f"❌ Processing failed: {str(e)}",
            "success": False
        }
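
# Shape of a successful result (a sketch; the values shown are illustrative):
#
#   {
#       "success": True,
#       "text": "full transcript ...",
#       "segments": [{"start": 0.0, "end": 15.0, "text": "...", "duration": 15.0}],
#       "metadata": {"total_duration": 42.5, "num_segments": 3,
#                    "chunk_length": 15, "processing_time": 4.2}
#   }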

def remove_segment_overlaps(segments):
    """Remove overlapping text between consecutive segments."""
    if len(segments) <= 1:
        return segments

    cleaned_segments = [segments[0]]
    for i in range(1, len(segments)):
        current_segment = segments[i].copy()
        previous_text = cleaned_segments[-1]["text"]
        current_text = current_segment["text"]

        # Find the longest suffix of the previous segment that matches a
        # prefix of the current one (checking up to 8 words)
        prev_words = previous_text.lower().split()
        curr_words = current_text.lower().split()
        overlap_length = 0
        max_check = min(8, len(prev_words), len(curr_words))
        for j in range(1, max_check + 1):
            if prev_words[-j:] == curr_words[:j]:
                overlap_length = j

        if overlap_length > 0:
            # Drop the duplicated words; skip the segment entirely if
            # nothing remains after trimming
            remaining_words = current_text.split()[overlap_length:]
            if remaining_words:
                current_segment["text"] = " ".join(remaining_words)
                cleaned_segments.append(current_segment)
        else:
            cleaned_segments.append(current_segment)

    return cleaned_segments
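
# Worked example (illustrative): for consecutive segment texts
# "the quick brown fox" and "brown fox jumps over", the two-word overlap
# "brown fox" is detected and the second segment is trimmed to "jumps over".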

# =============================================================================
# GRADIO INTERFACE FUNCTIONS
# =============================================================================

def transcribe_file(audio_file):
    """Handle file upload transcription."""
    if not model_loaded:
        return "❌ Model not loaded. Please refresh the page.", None, None
    if audio_file is None:
        return "⚠️ Please upload an audio file.", None, None

    try:
        # Load audio at 16 kHz (Whisper's expected sampling rate)
        audio_array, sr = librosa.load(audio_file, sr=16000)

        # Check duration
        duration = len(audio_array) / sr
        if duration > 180:  # 3 minutes
            return f"⚠️ Audio too long ({duration:.1f}s). Maximum allowed: 3 minutes.", None, None

        # Process with timestamps
        result = process_audio_with_timestamps(audio_array, sr)

        if result["success"]:
            # Format output and create downloadable files
            formatted_text = format_transcription_output(result)
            json_file = create_json_download(result, audio_file)
            srt_file = create_srt_download(result, audio_file)
            return formatted_text, json_file, srt_file
        else:
            return result["error"], None, None
    except Exception as e:
        return f"❌ Error processing file: {str(e)}", None, None

def transcribe_microphone(audio_data):
    """Handle microphone recording transcription."""
    if not model_loaded:
        return "❌ Model not loaded. Please refresh the page.", None, None
    if audio_data is None:
        return "⚠️ No audio recorded. Please record something first.", None, None

    try:
        # Gradio returns (sample_rate, numpy_array) for microphone input
        sr, audio_array = audio_data

        # Convert to float32 and normalize int16 PCM to [-1.0, 1.0];
        # use the absolute peak so negative-going samples are caught too
        if audio_array.dtype != np.float32:
            audio_array = audio_array.astype(np.float32)
        if np.abs(audio_array).max() > 1.0:
            audio_array = audio_array / 32768.0

        # Resample to 16 kHz if needed
        if sr != 16000:
            audio_array = librosa.resample(audio_array, orig_sr=sr, target_sr=16000)
            sr = 16000

        # Check duration
        duration = len(audio_array) / sr
        if duration > 180:  # 3 minutes
            return f"⚠️ Recording too long ({duration:.1f}s). Maximum allowed: 3 minutes.", None, None
        if duration < 0.5:
            return "⚠️ Recording too short. Please record for at least 0.5 seconds.", None, None

        # Process with timestamps
        result = process_audio_with_timestamps(audio_array, sr)

        if result["success"]:
            # Format output and create downloadable files
            formatted_text = format_transcription_output(result)
            json_file = create_json_download(result, "microphone_recording")
            srt_file = create_srt_download(result, "microphone_recording")
            return formatted_text, json_file, srt_file
        else:
            return result["error"], None, None
    except Exception as e:
        return f"❌ Error processing recording: {str(e)}", None, None

def format_transcription_output(result):
    """Format a transcription result for display."""
    output = []

    # Header
    output.append("🎯 TRANSCRIPTION RESULTS")
    output.append("=" * 50)

    # Metadata
    metadata = result["metadata"]
    output.append(f"⏱️ Duration: {metadata['total_duration']}s")
    output.append(f"📊 Segments: {metadata['num_segments']}")
    output.append("")

    # Full text
    output.append("📝 FULL TRANSCRIPT:")
    output.append("-" * 30)
    output.append(result["text"])
    output.append("")

    # Timestamped segments
    output.append("⏰ TIMESTAMPED SEGMENTS:")
    output.append("-" * 30)
    for i, segment in enumerate(result["segments"], 1):
        start_min = int(segment["start"] // 60)
        start_sec = int(segment["start"] % 60)
        end_min = int(segment["end"] // 60)
        end_sec = int(segment["end"] % 60)
        time_str = f"{start_min:02d}:{start_sec:02d} - {end_min:02d}:{end_sec:02d}"
        output.append(f"{i:2d}. [{time_str}] {segment['text']}")

    return "\n".join(output)

def create_json_download(result, source_name):
    """Create a JSON file for download."""
    try:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

        # Add metadata
        result["metadata"]["source"] = os.path.basename(str(source_name))
        result["metadata"]["generated_at"] = datetime.now().isoformat()
        result["metadata"]["model"] = MODEL_NAME

        # Use the timestamp as the filename prefix so downloads are identifiable
        with tempfile.NamedTemporaryFile(
            mode='w', prefix=f"transcription_{timestamp}_", suffix='.json',
            delete=False, encoding='utf-8'
        ) as f:
            json.dump(result, f, indent=2, ensure_ascii=False)
            return f.name
    except Exception as e:
        print(f"Error creating JSON download: {e}")
        return None

def create_srt_download(result, source_name):
    """Create an SRT subtitle file for download."""
    try:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

        srt_content = []
        for i, segment in enumerate(result["segments"], 1):
            start_time = format_time_srt(segment["start"])
            end_time = format_time_srt(segment["end"])
            srt_content.extend([
                str(i),
                f"{start_time} --> {end_time}",
                segment["text"],
                ""
            ])

        # Use the timestamp as the filename prefix so downloads are identifiable
        with tempfile.NamedTemporaryFile(
            mode='w', prefix=f"subtitles_{timestamp}_", suffix='.srt',
            delete=False, encoding='utf-8'
        ) as f:
            f.write("\n".join(srt_content))
            return f.name
    except Exception as e:
        print(f"Error creating SRT download: {e}")
        return None

def format_time_srt(seconds):
    """Format seconds to the SRT time format (HH:MM:SS,mmm)."""
    hours = int(seconds // 3600)
    minutes = int((seconds % 3600) // 60)
    secs = int(seconds % 60)
    millis = int((seconds % 1) * 1000)
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"
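
# Example: format_time_srt(75.5) -> "00:01:15,500"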

# =============================================================================
# GRADIO INTERFACE
# =============================================================================

def create_gradio_interface():
    """Create the Gradio interface."""

    # Custom CSS for better styling
    css = """
    .gradio-container {
        font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
    }
    .title {
        text-align: center;
        color: #2d3748;
        margin-bottom: 2rem;
    }
    .subtitle {
        text-align: center;
        color: #4a5568;
        margin-bottom: 1rem;
    }
    .output-text {
        font-family: 'Courier New', monospace;
        background-color: #f7fafc;
        padding: 1rem;
        border-radius: 8px;
        border: 1px solid #e2e8f0;
    }
    .warning {
        background-color: #fff3cd;
        border: 1px solid #ffeaa7;
        color: #856404;
        padding: 10px;
        border-radius: 4px;
        margin: 10px 0;
    }
    """

    with gr.Blocks(css=css, title="🎙️ Whisper Speech Transcription") as interface:
        # Header
        gr.HTML("""
        <div class="title">
            <h1>🎙️ Whisper Speech Transcription</h1>
            <p class="subtitle">Upload an audio file or record your voice to get an AI-powered transcription with timestamps</p>
        </div>
        """)

        # Warning about limits
        gr.HTML("""
        <div class="warning">
            <strong>⚠️ Important:</strong> Maximum audio length is 3 minutes (180 seconds).
            Longer files will be rejected to ensure fair usage for all users.
        </div>
        """)

        # Model status
        status_color = "green" if model_loaded else "red"
        status_text = "✅ Model loaded and ready" if model_loaded else "❌ Model loading failed"
        gr.HTML(f'<p style="color: {status_color}; text-align: center;"><strong>{status_text}</strong></p>')

        with gr.Tabs():
            # Tab 1: File Upload
            with gr.TabItem("📁 Upload Audio File"):
                with gr.Row():
                    with gr.Column():
                        audio_file_input = gr.Audio(
                            label="Upload Audio File",
                            type="filepath",
                            sources=["upload"]
                        )
                        file_transcribe_btn = gr.Button(
                            "🚀 Transcribe File",
                            variant="primary",
                            size="lg"
                        )
                with gr.Row():
                    file_output = gr.Textbox(
                        label="Transcription Results",
                        lines=15,
                        placeholder="Your transcription will appear here...",
                        elem_classes=["output-text"]
                    )
                with gr.Row():
                    with gr.Column():
                        json_download = gr.File(
                            label="📄 Download JSON",
                            visible=False
                        )
                    with gr.Column():
                        srt_download = gr.File(
                            label="🎬 Download SRT Subtitles",
                            visible=False
                        )

            # Tab 2: Voice Recording
            with gr.TabItem("🎤 Record Voice"):
                with gr.Row():
                    with gr.Column():
                        audio_mic_input = gr.Audio(
                            label="Record Your Voice",
                            sources=["microphone"],
                            type="numpy"
                        )
                        mic_transcribe_btn = gr.Button(
                            "🚀 Transcribe Recording",
                            variant="primary",
                            size="lg"
                        )
                with gr.Row():
                    mic_output = gr.Textbox(
                        label="Transcription Results",
                        lines=15,
                        placeholder="Your transcription will appear here...",
                        elem_classes=["output-text"]
                    )
                with gr.Row():
                    with gr.Column():
                        json_download_mic = gr.File(
                            label="📄 Download JSON",
                            visible=False
                        )
                    with gr.Column():
                        srt_download_mic = gr.File(
                            label="🎬 Download SRT Subtitles",
                            visible=False
                        )

        # Footer
        gr.HTML("""
        <div style="text-align: center; margin-top: 2rem; padding: 1rem; background-color: #f8f9fa; border-radius: 8px;">
            <h3>📋 Output Formats</h3>
            <p><strong>JSON:</strong> Complete transcription data with timestamps and metadata</p>
            <p><strong>SRT:</strong> Standard subtitle format for video players</p>
            <p><strong>Display:</strong> Formatted text with timestamped segments</p>
            <br>
            <p style="color: #6c757d; font-size: 0.9em;">
                Powered by Whisper AI | Maximum 3 minutes per audio | English language optimized
            </p>
        </div>
        """)

        # Event handlers: transcribe, then toggle download-file visibility
        def update_file_outputs(result_text, json_file, srt_file):
            json_visible = json_file is not None
            srt_visible = srt_file is not None
            return (
                result_text,
                gr.update(value=json_file, visible=json_visible),
                gr.update(value=srt_file, visible=srt_visible)
            )

        file_transcribe_btn.click(
            fn=transcribe_file,
            inputs=[audio_file_input],
            outputs=[file_output, json_download, srt_download]
        ).then(
            fn=update_file_outputs,
            inputs=[file_output, json_download, srt_download],
            outputs=[file_output, json_download, srt_download]
        )

        mic_transcribe_btn.click(
            fn=transcribe_microphone,
            inputs=[audio_mic_input],
            outputs=[mic_output, json_download_mic, srt_download_mic]
        ).then(
            fn=update_file_outputs,
            inputs=[mic_output, json_download_mic, srt_download_mic],
            outputs=[mic_output, json_download_mic, srt_download_mic]
        )

    return interface

# =============================================================================
# LAUNCH APPLICATION
# =============================================================================

if __name__ == "__main__":
    # Create and launch the interface
    interface = create_gradio_interface()

    # Launch configuration
    interface.launch(
        share=True,              # Creates a public URL (ignored on HF Spaces)
        server_name="0.0.0.0",   # Allows external access
        server_port=7860,        # Standard Gradio port
        show_error=True,
        # enable_queue=True,     # Removed in Gradio 4.x; call interface.queue() instead
        max_threads=10           # Limit concurrent processing
    )
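
# To run locally (assumed dependency set; adjust to your environment):
#   pip install gradio torch librosa transformers
#   python app.py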