import gradio as gr
import os
import tempfile
from typing import Optional, Tuple
import logging
from utils.audio_processor import AudioProcessor
from utils.downloader import MediaDownloader
from utils.transcription import WhisperTranscriber
from utils.formatters import SubtitleFormatter
from utils.diarization import SpeakerDiarizer
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class WhisperTranscriberApp:
    """Main application class for Whisper Transcriber.

    Holds a lazily-loaded Whisper model and, optionally, a speaker
    diarization pipeline. Both are cached on the instance so repeated
    requests reuse them; the model is reloaded only when the requested
    size changes.
    """

    def __init__(self):
        # Lazily-created WhisperTranscriber instance (loaded on first use).
        self.transcriber = None
        # Lazily-created SpeakerDiarizer (only if diarization is requested).
        self.diarizer = None
        # Model size currently loaded; used to detect when a reload is needed.
        self.current_model = None

    def process_media(
        self,
        file_input,
        url_input: str,
        model_size: str,
        language: str,
        enable_diarization: bool,
        progress=gr.Progress()
    ) -> Tuple[str, str, str, str, str]:
        """Run the full transcription pipeline for a single request.

        Args:
            file_input: Uploaded file object (must expose ``.name``) or None.
            url_input: Media URL to download; takes precedence over file_input.
            model_size: Whisper model size (e.g. 'tiny', 'base').
            language: Language code, or 'auto' for automatic detection.
            enable_diarization: Attempt speaker diarization when True.
            progress: Gradio progress reporter (gr.Progress default is the
                framework's required idiom, not a shared mutable default).

        Returns:
            Tuple of (markdown preview, SRT path, VTT path, TXT path, JSON path).

        Raises:
            gr.Error: Wrapping any underlying failure for display in the UI.
        """
        temp_files = []
        try:
            # Step 1: resolve the input source (a non-empty URL wins over an
            # uploaded file).
            progress(0.05, desc="Processing input...")
            if url_input and url_input.strip():
                audio_file, _ = MediaDownloader.download_media(
                    url_input,
                    progress_callback=lambda msg: progress(0.1, desc=msg)
                )
                temp_files.append(audio_file)
            elif file_input is not None:
                audio_file = file_input.name
            else:
                raise ValueError("Please provide either a file or a URL")

            # Step 2: normalize the input to WAV for downstream processing.
            progress(0.15, desc="Extracting audio...")
            processed_audio = AudioProcessor.extract_audio(
                audio_file,
                output_format='wav',
                progress_callback=lambda msg: progress(0.2, desc=msg)
            )
            temp_files.append(processed_audio)
            duration = AudioProcessor.get_audio_duration(processed_audio)

            # Step 3: (re)load the Whisper model only when the size changed.
            if self.transcriber is None or self.current_model != model_size:
                progress(0.25, desc=f"Loading Whisper {model_size} model...")
                self.transcriber = WhisperTranscriber(model_size=model_size)
                self.transcriber.load_model(
                    progress_callback=lambda msg: progress(0.3, desc=msg)
                )
                self.current_model = model_size

            # Step 4: split long audio into chunks; track chunk temp files
            # (the original file may be returned as its own single chunk).
            progress(0.35, desc="Preparing audio...")
            chunks = AudioProcessor.chunk_audio(
                processed_audio,
                progress_callback=lambda msg: progress(0.4, desc=msg)
            )
            for chunk_file, _ in chunks:
                if chunk_file != processed_audio:
                    temp_files.append(chunk_file)

            # Step 5: transcribe (single-shot vs chunked path).
            progress(0.45, desc="Transcribing audio...")
            if len(chunks) == 1:
                transcription_result = self.transcriber.transcribe(
                    chunks[0][0],
                    language=language,
                    progress_callback=lambda msg: progress(0.65, desc=msg)
                )
            else:
                transcription_result = self.transcriber.transcribe_chunks(
                    chunks,
                    language=language,
                    progress_callback=lambda msg: progress(0.65, desc=msg)
                )
            progress(0.70, desc="Transcription complete!")

            # Step 6: optional, best-effort speaker diarization. Failures are
            # logged but never abort the transcription result.
            speaker_labels = None
            if enable_diarization:
                progress(0.75, desc="Performing speaker diarization...")
                if not SpeakerDiarizer.is_available():
                    progress(0.75, desc="Skipping diarization (HF_TOKEN not set)")
                else:
                    try:
                        if self.diarizer is None:
                            self.diarizer = SpeakerDiarizer()
                        diarization_result = self.diarizer.diarize(
                            processed_audio,
                            progress_callback=lambda msg: progress(0.85, desc=msg)
                        )
                        speaker_labels = self.diarizer.align_with_transcription(
                            diarization_result,
                            transcription_result,
                            progress_callback=lambda msg: progress(0.9, desc=msg)
                        )
                    except Exception:
                        # Deliberate best-effort: keep the transcription even
                        # when diarization fails.
                        logger.exception("Diarization failed")

            # Step 7: write all output formats. Use mkdtemp instead of the
            # deprecated, race-prone tempfile.mktemp to get a unique prefix.
            progress(0.92, desc="Generating output files...")
            output_dir = tempfile.mkdtemp(prefix="whisper_output_")
            output_prefix = os.path.join(output_dir, "transcript")
            outputs = SubtitleFormatter.generate_all_formats(
                transcription_result,
                output_prefix,
                speaker_labels
            )

            preview_text = f"""**Transcription Complete!**
**Language:** {transcription_result['language']}
**Duration:** {duration:.2f} seconds
**Model Used:** {model_size}
**Preview:**
{transcription_result['text'][:500]}..."""
            progress(1.0, desc="Done!")
            return (
                preview_text,
                outputs['srt'],
                outputs['vtt'],
                outputs['txt'],
                outputs['json']
            )
        except Exception as e:
            logger.error("Processing failed: %s", e)
            # Preserve the cause chain so the real error is debuggable.
            raise gr.Error(f"Processing failed: {str(e)}") from e
        finally:
            # Runs on both success and failure paths (replaces the original
            # duplicated cleanup calls).
            AudioProcessor.cleanup_temp_files(*temp_files)
# Create app instance
# Single shared instance for the whole Gradio process; it caches the loaded
# Whisper model and diarizer across requests.
app = WhisperTranscriberApp()

# Get available options
# Choices for the UI dropdowns, queried once at startup from the
# transcriber's class-level helpers.
model_choices = WhisperTranscriber.get_available_models()
language_choices = WhisperTranscriber.get_language_list()
# Create interface
# Assemble the Gradio UI: inputs on the left column, outputs on the right.
with gr.Blocks(title="Whisper Transcriber") as demo:
    gr.Markdown("# 🎤 Whisper Transcriber\nGenerate subtitles from audio/video using OpenAI Whisper")

    with gr.Row():
        # Left column: every input the pipeline needs plus the trigger button.
        with gr.Column():
            upload_box = gr.File(label="Upload Audio/Video File")
            url_box = gr.Textbox(label="Or Paste URL", placeholder="YouTube or direct link")
            model_dropdown = gr.Dropdown(choices=model_choices, value='tiny', label="Model Size")
            # Show "Name (code)" to the user while passing the code as value.
            language_options = [
                (f"{display_name} ({code})", code)
                for code, display_name in language_choices.items()
            ]
            language_dropdown = gr.Dropdown(
                choices=language_options,
                value='auto',
                label="Language"
            )
            diarization_toggle = gr.Checkbox(label="Enable Speaker Diarization", value=False)
            run_button = gr.Button("Generate Transcription", variant="primary")

        # Right column: markdown preview plus one download slot per format.
        with gr.Column():
            preview_md = gr.Markdown(label="Preview")
            srt_out = gr.File(label="SRT File")
            vtt_out = gr.File(label="VTT File")
            txt_out = gr.File(label="TXT File")
            json_out = gr.File(label="JSON File")

    # Wire the button to the shared app instance's pipeline.
    run_button.click(
        fn=app.process_media,
        inputs=[upload_box, url_box, model_dropdown, language_dropdown, diarization_toggle],
        outputs=[preview_md, srt_out, vtt_out, txt_out, json_out]
    )
if __name__ == "__main__":
    # Enable request queueing (long transcriptions would otherwise time out
    # or block concurrent users), then start the server.
    demo.queue()
    demo.launch()