# (HuggingFace Spaces page header — "Spaces: Sleeping" — kept as a comment)
import gradio as gr
import os
import tempfile
from typing import Optional, Tuple
import logging

from utils.audio_processor import AudioProcessor
from utils.downloader import MediaDownloader
from utils.transcription import WhisperTranscriber
from utils.formatters import SubtitleFormatter
from utils.diarization import SpeakerDiarizer
# Module-level logging: INFO level, named logger for this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class WhisperTranscriberApp:
    """Main application class for Whisper Transcriber.

    Caches the loaded Whisper model (and the diarization pipeline) across
    requests so repeated runs with the same model size do not reload weights.
    """

    def __init__(self):
        # Heavy resources are loaded lazily on first use.
        self.transcriber = None     # WhisperTranscriber instance, or None
        self.diarizer = None        # SpeakerDiarizer instance, or None
        self.current_model = None   # model size string currently loaded, or None

    def process_media(
        self,
        file_input,
        url_input: str,
        model_size: str,
        language: str,
        enable_diarization: bool,
        progress=gr.Progress()
    ) -> Tuple[str, str, str, str, str]:
        """Main processing function for transcription.

        Args:
            file_input: Uploaded file from gr.File (tempfile wrapper or path);
                may be None when a URL is supplied instead.
            url_input: Media URL; takes precedence over file_input when set.
            model_size: Whisper model size to load (e.g. 'tiny').
            language: Language code, or 'auto' for detection.
            enable_diarization: Attempt speaker diarization (best-effort).
            progress: Gradio progress reporter.

        Returns:
            (preview_markdown, srt_path, vtt_path, txt_path, json_path)

        Raises:
            gr.Error: With a user-facing message when any step fails.
        """
        temp_files = []
        try:
            # Step 1: Resolve the input to a local media file.
            progress(0.05, desc="Processing input...")
            if url_input and url_input.strip():
                audio_file, source_type = MediaDownloader.download_media(
                    url_input,
                    progress_callback=lambda msg: progress(0.1, desc=msg)
                )
                temp_files.append(audio_file)
            elif file_input is not None:
                # gr.File may hand back a tempfile wrapper (with .name) or a
                # plain filepath string depending on the Gradio version.
                audio_file = getattr(file_input, 'name', file_input)
            else:
                raise ValueError("Please provide either a file or a URL")

            # Step 2: Extract/normalize the audio track to WAV.
            progress(0.15, desc="Extracting audio...")
            processed_audio = AudioProcessor.extract_audio(
                audio_file,
                output_format='wav',
                progress_callback=lambda msg: progress(0.2, desc=msg)
            )
            temp_files.append(processed_audio)
            duration = AudioProcessor.get_audio_duration(processed_audio)

            # Step 3: Load the Whisper model, reusing the cached one when the
            # requested size matches what is already loaded.
            if self.transcriber is None or self.current_model != model_size:
                progress(0.25, desc=f"Loading Whisper {model_size} model...")
                self.transcriber = WhisperTranscriber(model_size=model_size)
                self.transcriber.load_model(
                    progress_callback=lambda msg: progress(0.3, desc=msg)
                )
                self.current_model = model_size

            # Step 4: Split long audio into chunks for transcription.
            progress(0.35, desc="Preparing audio...")
            chunks = AudioProcessor.chunk_audio(
                processed_audio,
                progress_callback=lambda msg: progress(0.4, desc=msg)
            )
            for chunk_file, _ in chunks:
                # The unchunked case returns the source file itself; don't
                # register it twice for cleanup.
                if chunk_file != processed_audio:
                    temp_files.append(chunk_file)

            # Step 5: Transcribe (single file or chunked path).
            progress(0.45, desc="Transcribing audio...")
            if len(chunks) == 1:
                transcription_result = self.transcriber.transcribe(
                    chunks[0][0],
                    language=language,
                    progress_callback=lambda msg: progress(0.65, desc=msg)
                )
            else:
                transcription_result = self.transcriber.transcribe_chunks(
                    chunks,
                    language=language,
                    progress_callback=lambda msg: progress(0.65, desc=msg)
                )
            progress(0.70, desc="Transcription complete!")

            # Step 6: Optional speaker diarization. Failures here are logged
            # and ignored so the transcription result is still returned.
            speaker_labels = None
            if enable_diarization:
                progress(0.75, desc="Performing speaker diarization...")
                if not SpeakerDiarizer.is_available():
                    progress(0.75, desc="Skipping diarization (HF_TOKEN not set)")
                else:
                    try:
                        if self.diarizer is None:
                            self.diarizer = SpeakerDiarizer()
                        diarization_result = self.diarizer.diarize(
                            processed_audio,
                            progress_callback=lambda msg: progress(0.85, desc=msg)
                        )
                        speaker_labels = self.diarizer.align_with_transcription(
                            diarization_result,
                            transcription_result,
                            progress_callback=lambda msg: progress(0.9, desc=msg)
                        )
                    except Exception as e:
                        # Diarization is best-effort; continue without labels.
                        logger.error(f"Diarization failed: {e}")

            # Step 7: Generate subtitle/transcript files. Use mkdtemp instead
            # of the deprecated, race-prone tempfile.mktemp.
            progress(0.92, desc="Generating output files...")
            output_prefix = os.path.join(
                tempfile.mkdtemp(prefix="whisper_output_"), "whisper_output_"
            )
            outputs = SubtitleFormatter.generate_all_formats(
                transcription_result,
                output_prefix,
                speaker_labels
            )

            preview_text = f"""**Transcription Complete!**
**Language:** {transcription_result['language']}
**Duration:** {duration:.2f} seconds
**Model Used:** {model_size}
**Preview:**
{transcription_result['text'][:500]}..."""
            progress(1.0, desc="Done!")
            return (
                preview_text,
                outputs['srt'],
                outputs['vtt'],
                outputs['txt'],
                outputs['json']
            )
        except Exception as e:
            # logger.exception records the full traceback, not just the message.
            logger.exception("Processing failed")
            raise gr.Error(f"Processing failed: {str(e)}") from e
        finally:
            # Single cleanup path for both success and failure (replaces the
            # original duplicated cleanup calls).
            AudioProcessor.cleanup_temp_files(*temp_files)
# Create a single shared app instance; the model cache lives on it, so all
# requests served by this process reuse loaded weights.
app = WhisperTranscriberApp()

# Query the available model sizes and languages from the transcriber backend.
model_choices = WhisperTranscriber.get_available_models()
language_choices = WhisperTranscriber.get_language_list()
| # Create interface | |
# UI layout: inputs and the run button on the left, the preview plus the four
# downloadable output files on the right.
with gr.Blocks(title="Whisper Transcriber") as demo:
    gr.Markdown("# 🎤 Whisper Transcriber\nGenerate subtitles from audio/video using OpenAI Whisper")

    with gr.Row():
        with gr.Column():
            file_input = gr.File(label="Upload Audio/Video File")
            url_input = gr.Textbox(label="Or Paste URL", placeholder="YouTube or direct link")
            model_size = gr.Dropdown(choices=model_choices, value='tiny', label="Model Size")
            # Dropdown entries are (display label, value-code) pairs built from
            # the {code: name} mapping returned by the transcriber backend.
            language_options = [
                (f"{name} ({code})", code) for code, name in language_choices.items()
            ]
            language = gr.Dropdown(
                choices=language_options,
                value='auto',
                label="Language"
            )
            enable_diarization = gr.Checkbox(label="Enable Speaker Diarization", value=False)
            btn = gr.Button("Generate Transcription", variant="primary")

        with gr.Column():
            preview = gr.Markdown(label="Preview")
            srt_file = gr.File(label="SRT File")
            vtt_file = gr.File(label="VTT File")
            txt_file = gr.File(label="TXT File")
            json_file = gr.File(label="JSON File")

    # Wire the button to the processing pipeline.
    btn.click(
        fn=app.process_media,
        inputs=[file_input, url_input, model_size, language, enable_diarization],
        outputs=[preview, srt_file, vtt_file, txt_file, json_file],
    )
if __name__ == "__main__":
    # Enable the request queue before launching — presumably required for the
    # gr.Progress updates used in process_media; confirm against the Gradio
    # version pinned by this project.
    demo.queue()
    demo.launch()