Spaces:
Sleeping
Sleeping
| # Danish Speech-to-Text with Parakeet + PunctFixer v2 | |
| import gradio as gr | |
| import nemo.collections.asr as nemo_asr | |
| import torch | |
| from punctfix import PunctFixer | |
| import soundfile as sf | |
| import numpy as np | |
| import base64 | |
| import io | |
| import json | |
| import os | |
| import re | |
| import tempfile | |
| import time | |
| import traceback | |
| from ten_vad import TenVad | |
| import matplotlib | |
| matplotlib.use('Agg') | |
| import matplotlib.pyplot as plt | |
# Load the Danish Parakeet RNNT ASR model once at module import time
# (the process stays alive between requests, so this is a one-time cost).
asr_model = nemo_asr.models.ASRModel.from_pretrained(
    model_name="nvidia/parakeet-rnnt-110m-da-dk"
)
# Prefer GPU when available. NOTE(review): only PunctFixer is explicitly
# given this device here; asr_model placement is left to NeMo's default —
# confirm it also lands on GPU when one is present.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Danish punctuation/capitalization restoration for the raw ASR output.
punct_fixer = PunctFixer(language="da", device=device)
def detect_silence_periods(audio_data, sample_rate, prob_threshold=0.5, min_off_ms=48, min_on_ms=64):
    """Run TEN VAD to detect silence periods in audio.

    Each 256-sample hop (16ms at 16kHz) is classified voice/silence by TEN
    VAD. The raw per-frame flags are then smoothed with a two-pass loop that
    mirrors a C++ reference (VadPipeline.cpp): short silence gaps are filled
    in as voice, short voice bursts are removed, and the passes repeat until
    stable. Finally the flags are converted to a list of silence intervals.

    Args:
        audio_data: numpy array of audio samples (float, mono, 16kHz)
        sample_rate: sample rate (must be 16000)
        prob_threshold: VAD probability threshold (0.0-1.0), higher = less sensitive
        min_off_ms: Minimum silence duration in ms - shorter silences are filled in as voice
        min_on_ms: Minimum voice duration in ms - shorter voice bursts are removed

    Returns:
        List of dicts with 'start' and 'end' times (seconds, rounded to 3
        decimals) for each silence period.
    """
    TARGET_SR = 16000  # TEN VAD requires 16kHz
    HOP_SIZE = 256  # 16ms at 16kHz
    FRAME_MS = 16.0  # Each frame is 16ms
    print(f"[VAD] Settings: prob_threshold={prob_threshold}, min_off_ms={min_off_ms}, min_on_ms={min_on_ms}")
    if sample_rate != TARGET_SR:
        # No resampling is done here; callers are expected to resample first,
        # otherwise the reported times will be off by the rate ratio.
        print(f"[VAD] Warning: Expected 16kHz audio, got {sample_rate}Hz")
    # Convert float audio to int16 (TEN VAD expects int16)
    if audio_data.dtype == np.float32 or audio_data.dtype == np.float64:
        audio_int16 = (audio_data * 32767).astype(np.int16)
    else:
        audio_int16 = audio_data.astype(np.int16)
    # Create VAD instance
    vad = TenVad(hop_size=HOP_SIZE, threshold=prob_threshold)
    frame_duration = HOP_SIZE / TARGET_SR  # 0.016s = 16ms
    # Process frame by frame and collect raw flags.
    # NOTE: trailing samples that don't fill a whole hop are dropped.
    num_frames = len(audio_int16) // HOP_SIZE
    # Use list for mutable flags (will be modified by post-processing)
    is_voice = [0] * num_frames
    for i in range(num_frames):
        frame_start = i * HOP_SIZE
        frame = audio_int16[frame_start:frame_start + HOP_SIZE]
        result = vad.process(frame)
        # TEN VAD returns tuple: (probability, flag) or has .flag attribute
        if isinstance(result, tuple):
            flag = result[1]  # (probability, flag)
        else:
            flag = result.flag
        is_voice[i] = flag
    # Convert ms thresholds to frame counts
    # (the +0.1 guards against float division truncating, e.g. 48/16.0)
    min_off_frames = int(min_off_ms / FRAME_MS + 0.1)
    min_on_frames = int(min_on_ms / FRAME_MS + 0.1)
    # Post-processing loop (matches VadPipeline.cpp logic). Pass 2 removals
    # can create new short silence gaps, so the passes repeat until stable.
    while True:
        # Pass 1: Fill in short silence gaps (minOff)
        # If silence duration <= min_off_frames, convert to voice
        if min_off_frames > 0:
            start_off = -1  # index where the current silence run began; -1 = not in silence
            for i in range(num_frames):
                if is_voice[i]:  # Voice detected
                    if start_off >= 0 and (i - start_off) <= min_off_frames:
                        # Short silence gap - fill it in as voice
                        for j in range(start_off, i):
                            is_voice[j] = 1
                    start_off = -1
                elif start_off < 0:
                    start_off = i
        # Pass 2: Remove short voice bursts (minOn)
        # If voice duration <= min_on_frames, convert to silence
        changed = False
        if min_on_frames > 0:
            start_on = -1  # index where the current voice run began; -1 = not in voice
            for i in range(num_frames):
                if not is_voice[i]:  # Silence detected
                    if start_on >= 0 and (i - start_on) <= min_on_frames:
                        # Short voice burst - remove it
                        changed = True
                        for j in range(start_on, i):
                            is_voice[j] = 0
                    start_on = -1
                elif start_on < 0:
                    start_on = i
            # Handle case where audio ends with short voice burst
            if start_on >= 0 and (num_frames - start_on) <= min_on_frames:
                changed = True
                for j in range(start_on, num_frames):
                    is_voice[j] = 0
        # Exit loop if no changes or minOff is disabled
        if not changed or min_off_frames == 0:
            break
    # Convert frame flags to silence periods
    silence_periods = []
    in_silence = False
    silence_start = 0.0
    for i in range(num_frames):
        current_time = i * frame_duration
        if is_voice[i]:
            # Voice frame
            if in_silence:
                # End of silence period
                silence_periods.append({
                    'start': round(silence_start, 3),
                    'end': round(current_time, 3)
                })
                in_silence = False
        else:
            # Silence frame
            if not in_silence:
                # Start of silence period
                silence_start = current_time
                in_silence = True
    # Handle case where audio ends in silence
    if in_silence:
        silence_periods.append({
            'start': round(silence_start, 3),
            'end': round(num_frames * frame_duration, 3)
        })
    return silence_periods
def print_speech_silence_log(timestamps_data, silence_periods):
    """Print interleaved speech and silence log sorted by start time."""
    # Tag every event with its kind so the two lists can be merged and sorted.
    events = [
        {'type': 'speech', 'start': w['start'], 'end': w['end'], 'word': w['word']}
        for w in timestamps_data
    ]
    events += [
        {'type': 'silence', 'start': s['start'], 'end': s['end']}
        for s in silence_periods
    ]
    events.sort(key=lambda ev: ev['start'])

    print("\n=== SPEECH & SILENCE LOG ===")
    for ev in events:
        if ev['type'] == 'speech':
            print(f"[Speech] [{ev['start']:.3f}-{ev['end']:.3f}] {ev['word']}")
        else:
            gap_ms = int((ev['end'] - ev['start']) * 1000)
            print(f"[Silence] [{ev['start']:.3f}-{ev['end']:.3f}] [{gap_ms}ms]")

    # Summary line: counts plus total silent time across all periods.
    total_silence = sum(s['end'] - s['start'] for s in silence_periods)
    print(f"\n=== SUMMARY ===")
    print(f"Words: {len(timestamps_data)}, Silence periods: {len(silence_periods)}, Total silence: {total_silence:.2f}s")
    print("=" * 30 + "\n")
def parse_transcript_file(file_path):
    """Parse a transcript JSON file and extract word timestamps.

    Supports three formats:
    - Format 1: segments[].words[] with {start, end, word}
    - Format 2: Top-level words[] with {start, end, word}
    - Format 3: segments[] with {start, end, text} (text treated as single word)

    Args:
        file_path: Path to the JSON transcript file

    Returns:
        Tuple of (full_text, timestamps_data) where timestamps_data is list of
        {word, start, end} dicts

    Raises:
        ValueError: if the top level is not a JSON object or no known
            format is recognized.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # Previously a non-object top level (e.g. a JSON array) raised a
    # confusing TypeError from `'segments' in data`; fail with the
    # documented ValueError instead.
    if not isinstance(data, dict):
        raise ValueError("Unrecognized transcript format. Expected a JSON object at top level")

    timestamps_data = []
    full_text_parts = []

    def _collect_word(word_entry):
        # Shared word-entry extraction for Formats 1 and 2 (was duplicated).
        word = word_entry.get('word', '').strip()
        if word:
            timestamps_data.append({
                'word': word,
                'start': float(word_entry.get('start', 0)),
                'end': float(word_entry.get('end', 0))
            })
            full_text_parts.append(word)

    # Guard: 'segments' must actually be a non-empty list before indexing it
    # (previously a dict value here raised instead of falling through).
    segments = data.get('segments')
    if isinstance(segments, list) and len(segments) > 0:
        first_segment = segments[0]

        # Try Format 1: segments[].words[] with {start, end, word}
        if 'words' in first_segment and isinstance(first_segment['words'], list):
            print("[IMPORT] Detected Format 1: segments[].words[]")
            for segment in segments:
                for word_entry in segment.get('words', []):
                    _collect_word(word_entry)
            return ' '.join(full_text_parts), timestamps_data

        # Try Format 3: segments[] with {start, end, text} (text as word)
        if 'text' in first_segment and 'start' in first_segment and 'end' in first_segment:
            print("[IMPORT] Detected Format 3: segments[] with {start, end, text}")
            for segment in segments:
                text = segment.get('text', '').strip()
                if text:
                    timestamps_data.append({
                        'word': text,
                        'start': float(segment.get('start', 0)),
                        'end': float(segment.get('end', 0))
                    })
                    full_text_parts.append(text)
            # Use top-level text if available, otherwise join segments
            full_text = data.get('text', ' '.join(full_text_parts))
            return full_text, timestamps_data

    # Try Format 2: Top-level words[] with {start, end, word}
    if 'words' in data and isinstance(data['words'], list):
        print("[IMPORT] Detected Format 2: words[]")
        for word_entry in data['words']:
            _collect_word(word_entry)
        # Use top-level text if available, otherwise join words
        full_text = data.get('text', ' '.join(full_text_parts))
        return full_text, timestamps_data

    raise ValueError("Unrecognized transcript format. Expected segments[].words[], words[], or segments[] with {start, end, text}")
def load_transcript(audio, transcript_file, prob_threshold=0.5, min_off_ms=48, min_on_ms=64):
    """Load external transcript and run VAD on audio.

    Similar to transcribe_audio but skips ASR - word timestamps come from
    the imported transcript file instead of the model.

    Args:
        audio: Path to audio file
        transcript_file: Path to JSON transcript file
        prob_threshold: VAD probability threshold
        min_off_ms: Minimum silence duration in ms
        min_on_ms: Minimum voice duration in ms

    Returns:
        Same tuple format as transcribe_audio:
        (text, timestamps_data, audio_data_tuple, raw_text, export_metadata, silence_periods).
        On error, the first element carries the message and the rest are
        empty placeholders.
    """
    try:
        # Check if audio is provided
        if audio is None:
            return "No audio provided. Please upload audio first.", [], None, "", {}, []
        # Check if transcript file is provided
        if transcript_file is None:
            return "No transcript file provided.", [], None, "", {}, []
        # Parse the transcript file
        try:
            text, timestamps_data = parse_transcript_file(transcript_file)
            print(f"[IMPORT] Loaded {len(timestamps_data)} word timestamps from transcript")
        except Exception as e:
            return f"Error parsing transcript file: {str(e)}", [], None, "", {}, []
        # Preprocess audio: downmix any multi-channel audio to mono by
        # averaging channels. (Fixed: the previous check only matched
        # exactly 2 channels, so e.g. 4- or 6-channel files stayed 2-D and
        # broke the downstream VAD int16 conversion.)
        audio_data, sample_rate = sf.read(audio)
        if audio_data.ndim > 1:
            audio_data = np.mean(audio_data, axis=1)
        # Resample to 16kHz if needed (required by TEN VAD); linear
        # interpolation is adequate for VAD purposes.
        TARGET_SR = 16000
        if sample_rate != TARGET_SR:
            duration = len(audio_data) / sample_rate
            new_length = int(duration * TARGET_SR)
            x_old = np.linspace(0, duration, len(audio_data), endpoint=False)
            x_new = np.linspace(0, duration, new_length, endpoint=False)
            audio_data = np.interp(x_new, x_old, audio_data).astype(np.float32)
            print(f"[AUDIO] Resampled from {sample_rate}Hz to {TARGET_SR}Hz")
            sample_rate = TARGET_SR
        # Run VAD to detect silence periods (best-effort: a VAD failure
        # should not block loading the transcript itself)
        silence_periods = []
        try:
            silence_periods = detect_silence_periods(audio_data, sample_rate, prob_threshold, min_off_ms, min_on_ms)
            print_speech_silence_log(timestamps_data, silence_periods)
        except Exception as e:
            print(f"[VAD] Error during silence detection: {str(e)}\n{traceback.format_exc()}")
        # Calculate audio duration
        audio_duration = len(audio_data) / sample_rate
        # Build export metadata; ASR-specific fields are placeholders since
        # the transcript was imported rather than produced by the model.
        export_metadata = {
            'model': 'imported-transcript',
            'audio_duration': round(audio_duration, 2),
            'word_count': len(timestamps_data),
            'token_count': 0,
            'hypothesis_score': None,
            'frame_duration': None
        }
        # raw_text equals text for imports (no separate unpunctuated form)
        return text, timestamps_data, (audio_data, sample_rate), text, export_metadata, silence_periods
    except Exception as e:
        return f"Error loading transcript: {str(e)}\n{traceback.format_exc()}", [], None, "", {}, []
def transcribe_audio(audio, prob_threshold=0.5, min_off_ms=48, min_on_ms=64):
    """Transcribe Danish audio with Parakeet, restore punctuation, run VAD.

    Pipeline: read + downmix + resample the audio, run ASR (deriving word
    timestamps from the RNNT token sequence), map PunctFixer's punctuated
    words back onto the timestamps, then detect silences with TEN VAD.

    Args:
        audio: Path to the audio file to transcribe.
        prob_threshold: VAD probability threshold (0.0-1.0).
        min_off_ms: Minimum silence duration in ms.
        min_on_ms: Minimum voice duration in ms.

    Returns:
        Tuple of (punctuated_text, timestamps_data, (audio_data, sample_rate),
        raw_text, export_metadata, silence_periods). On error the first
        element carries the message and the rest are empty placeholders.
    """
    try:
        # Check if audio is provided.
        # Fixed: this path previously returned only 5 values while every
        # other path returns 6, breaking callers that unpack silence_periods.
        if audio is None:
            return "No audio provided. Please record or upload audio first.", [], None, "", {}, []
        # Preprocess audio: downmix any multi-channel audio to mono by
        # averaging channels. (Fixed: the previous check only matched
        # exactly 2 channels, leaving >2-channel audio multi-channel.)
        audio_data, sample_rate = sf.read(audio)
        if audio_data.ndim > 1:
            audio_data = np.mean(audio_data, axis=1)
        # Resample to 16kHz if needed (required by both Parakeet and TEN VAD)
        TARGET_SR = 16000
        if sample_rate != TARGET_SR:
            duration = len(audio_data) / sample_rate
            new_length = int(duration * TARGET_SR)
            x_old = np.linspace(0, duration, len(audio_data), endpoint=False)
            x_new = np.linspace(0, duration, new_length, endpoint=False)
            audio_data = np.interp(x_new, x_old, audio_data).astype(np.float32)
            print(f"[AUDIO] Resampled from {sample_rate}Hz to {TARGET_SR}Hz")
            sample_rate = TARGET_SR
        # Save as temporary mono 16kHz file (the transcribe API takes paths)
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.wav')
        mono_audio_path = temp_file.name
        temp_file.close()
        try:
            sf.write(mono_audio_path, audio_data, sample_rate)
            # Get transcription with word timestamps
            transcription = asr_model.transcribe(
                [mono_audio_path],
                return_hypotheses=True
            )
        finally:
            # Fixed: cleanup now also runs when transcription raises (the
            # temp file used to leak), and the bare `except:` is narrowed
            # to OSError so real bugs aren't silenced.
            try:
                os.unlink(mono_audio_path)
            except OSError:
                pass
        if transcription:
            hyp = transcription[0]
            raw_text = hyp.text
            punctuated_text = punct_fixer.punctuate(text=raw_text)
            # Get hypothesis score (overall log-probability)
            hypothesis_score = getattr(hyp, 'score', None)
            # Build precise word timestamps from token sequence
            timestamps_data = []
            token_count = 0
            try:
                if (hasattr(hyp, 'y_sequence') and hyp.y_sequence is not None and
                        hasattr(hyp, 'timestamp') and hyp.timestamp is not None and
                        hasattr(hyp, 'words') and hyp.words is not None):
                    # Get token IDs and timestamps
                    token_ids = hyp.y_sequence.tolist() if hasattr(hyp.y_sequence, 'tolist') else list(hyp.y_sequence)
                    timestamps = hyp.timestamp.tolist() if hasattr(hyp.timestamp, 'tolist') else list(hyp.timestamp)
                    words = hyp.words
                    tokenizer = asr_model.tokenizer
                    frame_duration = 0.08  # seconds per encoder frame index
                    token_count = len(token_ids)
                    # Get full decoded text
                    full_decoded = tokenizer.ids_to_text(token_ids)
                    # Build token-to-character mapping by incrementally decoding
                    token_char_ranges = []
                    prev_len = 0
                    for i in range(len(token_ids)):
                        # Decode up to this token
                        partial_text = tokenizer.ids_to_text(token_ids[:i+1])
                        curr_len = len(partial_text)
                        token_char_ranges.append((prev_len, curr_len, timestamps[i] * frame_duration))
                        prev_len = curr_len
                    # Build character-to-time array
                    char_to_time = [0.0] * len(full_decoded)
                    for start_pos, end_pos, token_time in token_char_ranges:
                        for pos in range(start_pos, min(end_pos, len(char_to_time))):
                            char_to_time[pos] = token_time
                    # Find each word and get its timing
                    search_pos = 0
                    for word in words:
                        word_pos = full_decoded.find(word, search_pos)
                        if word_pos >= 0 and word_pos < len(char_to_time):
                            word_end_pos = min(word_pos + len(word), len(char_to_time))
                            start_time = char_to_time[word_pos]
                            end_time = char_to_time[word_end_pos - 1] if word_end_pos > word_pos else start_time
                            timestamps_data.append({
                                'word': word,
                                'start': start_time,
                                'end': end_time
                            })
                            search_pos = word_end_pos
            except Exception as e:
                print(f"Timestamp extraction failed: {str(e)}\n{traceback.format_exc()}")
            # Map punctuated words back to timestamps_data
            # -----------------------------------------------------------------
            # The ASR model outputs raw lowercase text without punctuation.
            # PunctFixer adds punctuation and capitalization, but may occasionally:
            # - Merge words (e.g., "i morgen" -> "imorgen")
            # - Split contractions differently
            # - Result in different word counts than the raw output
            #
            # We handle this by:
            # 1. If word counts match: direct position-based mapping (common case)
            # 2. If counts differ: fuzzy matching with lookahead to realign
            # -----------------------------------------------------------------
            try:
                # Split punctuated text into words, keeping punctuation attached
                punctuated_words = punctuated_text.split()
                # Helper to strip punctuation for comparison (normalize for matching)
                def strip_punct(s):
                    return re.sub(r'[^\w]', '', s).lower()
                # Align punctuated words to raw words
                if len(punctuated_words) == len(timestamps_data):
                    # Same word count - direct mapping (most common case)
                    # Verify base word matches before replacing to catch any edge cases
                    for i, pw in enumerate(punctuated_words):
                        if strip_punct(pw) == strip_punct(timestamps_data[i]['word']):
                            timestamps_data[i]['word'] = pw
                else:
                    # Different word counts - PunctFixer may have merged/split words
                    # Use two-pointer approach with lookahead for realignment
                    pi = 0  # punctuated index
                    for ti in range(len(timestamps_data)):
                        if pi >= len(punctuated_words):
                            break
                        raw_word = strip_punct(timestamps_data[ti]['word'])
                        punct_word = strip_punct(punctuated_words[pi])
                        if raw_word == punct_word:
                            timestamps_data[ti]['word'] = punctuated_words[pi]
                            pi += 1
                        else:
                            # Words don't match - try lookahead to find alignment
                            # This handles cases where PunctFixer inserted/removed words
                            for look_ahead in range(1, min(3, len(punctuated_words) - pi)):
                                if strip_punct(punctuated_words[pi + look_ahead]) == raw_word:
                                    pi += look_ahead
                                    timestamps_data[ti]['word'] = punctuated_words[pi]
                                    pi += 1
                                    break
            except Exception as e:
                # Graceful fallback: keep original raw words if mapping fails
                print(f"Punctuation mapping failed: {str(e)}")
            # Run VAD to detect silence periods (best-effort: a VAD failure
            # should not discard the transcription)
            silence_periods = []
            try:
                silence_periods = detect_silence_periods(audio_data, sample_rate, prob_threshold, min_off_ms, min_on_ms)
                print_speech_silence_log(timestamps_data, silence_periods)
            except Exception as e:
                print(f"[VAD] Error during silence detection: {str(e)}\n{traceback.format_exc()}")
            # Calculate audio duration
            audio_duration = len(audio_data) / sample_rate
            # Build export metadata
            export_metadata = {
                'model': 'nvidia/parakeet-rnnt-110m-da-dk',
                'audio_duration': round(audio_duration, 2),
                'word_count': len(timestamps_data),
                'token_count': token_count,
                'hypothesis_score': round(hypothesis_score, 4) if hypothesis_score is not None else None,
                'frame_duration': 0.08
            }
            # Return text, timestamps, audio data, raw_text, export metadata, and silence periods
            return punctuated_text, timestamps_data, (audio_data, sample_rate), raw_text, export_metadata, silence_periods
        return "No transcription available.", [], None, "", {}, []
    except Exception as e:
        return f"Error during transcription: {str(e)}\n{traceback.format_exc()}", [], None, "", {}, []
def extract_audio_segment(audio_state, start_time, end_time, current_window=None):
    """Fast audio extraction from memory with waveform visualization.

    Renders a matplotlib waveform of the requested interval plus ~160ms of
    context, and embeds it together with an autoplaying <audio> element
    (both as base64 data URLs) in an HTML snippet for Gradio. When the new
    interval still fits inside the previous window, the window boundaries
    are reused so the x-range stays stable between clicks — note the figure
    itself is still re-rendered on every call.

    Args:
        audio_state: Tuple of (audio_data, sample_rate)
        start_time: Start time of the interval to play
        end_time: End time of the interval to play
        current_window: Dict with 'start' and 'end' of current waveform window, or None

    Returns:
        Tuple of (html_output, new_window_state)
    """
    # Wrapper to ensure controls never collapse (fixed min-height container)
    def wrap_output(content, window_state=None):
        return f'<div style="min-height: 200px;">{content}</div>', window_state
    try:
        if audio_state is None:
            # NOTE: window_state defaults to None here, clearing any saved window.
            return wrap_output("<p style='color: red; padding: 20px;'>No audio loaded. Please transcribe audio first.</p>")
        audio_data, sample_rate = audio_state
        audio_duration = len(audio_data) / sample_rate
        # Default context padding is 160ms
        DEFAULT_PADDING = 0.16
        # Determine if we need to redraw the waveform or just update the shaded area
        need_redraw = True
        if current_window is not None:
            # Check if the new interval fits within the current window
            if start_time >= current_window['start'] and end_time <= current_window['end']:
                need_redraw = False
                # Reuse the current window boundaries
                padded_start = current_window['start']
                padded_end = current_window['end']
        if need_redraw:
            # Calculate new window with ±160ms padding, clamped to the audio bounds
            padded_start = max(0, start_time - DEFAULT_PADDING)
            padded_end = min(audio_duration, end_time + DEFAULT_PADDING)
        # Extract padded segment for waveform visualization
        start_sample_padded = int(padded_start * sample_rate)
        end_sample_padded = int(padded_end * sample_rate)
        segment_for_waveform = audio_data[start_sample_padded:end_sample_padded]
        # Extract ACTUAL segment for playback (no padding)
        start_sample = int(start_time * sample_rate)
        end_sample = int(end_time * sample_rate)
        segment_for_playback = audio_data[start_sample:end_sample]
        # Generate waveform visualization with padded segment (reduced height)
        fig, ax = plt.subplots(figsize=(12, 2.25))
        # Downsample for visualization using block averaging (more accurate than skipping)
        max_points = 8000
        if len(segment_for_waveform) > max_points:
            # Reshape into blocks and take mean of each block
            # Pad to multiple of block_size to avoid losing end samples
            block_size = len(segment_for_waveform) // max_points
            remainder = len(segment_for_waveform) % block_size
            if remainder > 0:
                # Pad with the last value to make it divisible
                padding_needed = block_size - remainder
                segment_padded = np.pad(segment_for_waveform, (0, padding_needed), mode='edge')
            else:
                segment_padded = segment_for_waveform
            segment_vis = segment_padded.reshape(-1, block_size).mean(axis=1)
            # Generate matching time points spanning the FULL padded range
            times_vis = np.linspace(padded_start, padded_end, len(segment_vis))
        else:
            segment_vis = segment_for_waveform
            times_vis = np.linspace(padded_start, padded_end, len(segment_for_waveform))
        ax.plot(times_vis, segment_vis, linewidth=0.5, color='#666')
        ax.fill_between(times_vis, segment_vis, alpha=0.3, color='#ccc')
        # Highlight the actual playback region (without padding)
        ax.axvspan(start_time, end_time, alpha=0.3, color='#4CAF50', label='Playback region')
        ax.axvspan(padded_start, start_time, alpha=0.1, color='#888', label='Context (not played)')
        ax.axvspan(end_time, padded_end, alpha=0.1, color='#888')
        ax.set_xlabel('Time (seconds)', fontsize=10)
        ax.set_ylabel('Amplitude', fontsize=10)
        # Calculate context on each side in ms
        left_context_ms = int((start_time - padded_start) * 1000)
        right_context_ms = int((padded_end - end_time) * 1000)
        # Format context string - symmetric or asymmetric
        if left_context_ms == right_context_ms:
            context_str = f'(±{left_context_ms}ms context)'
        else:
            context_str = f'(-{left_context_ms}ms context +{right_context_ms}ms context)'
        ax.set_title(f'Audio Segment: {start_time:.3f}s – {end_time:.3f}s {context_str}', fontsize=11)
        ax.legend(fontsize=9)
        ax.grid(True, alpha=0.3)
        # Convert plot to base64 image (PNG embedded as a data URL below)
        buf = io.BytesIO()
        plt.tight_layout()
        plt.savefig(buf, format='png', dpi=100)
        buf.seek(0)
        img_base64 = base64.b64encode(buf.read()).decode()
        plt.close(fig)
        # Convert PLAYBACK segment (no padding) to base64 WAV
        audio_buf = io.BytesIO()
        sf.write(audio_buf, segment_for_playback, sample_rate, format='WAV')
        audio_buf.seek(0)
        audio_base64 = base64.b64encode(audio_buf.read()).decode()
        audio_data_url = f"data:audio/wav;base64,{audio_base64}"
        # Add unique ID to force Gradio to re-render (triggers autoplay)
        unique_id = int(time.time() * 1000)
        # Calculate context on each side in ms for the info text
        left_context_ms = int((start_time - padded_start) * 1000)
        right_context_ms = int((padded_end - end_time) * 1000)
        # Format context string - symmetric or asymmetric
        if left_context_ms == right_context_ms:
            context_info = f'±{left_context_ms}ms'
        else:
            context_info = f'-{left_context_ms}ms / +{right_context_ms}ms'
        # Create HTML with waveform and native audio controls
        html_output = f'''
        <div style="margin: 10px 0;" data-render-id="{unique_id}">
            <img src="data:image/png;base64,{img_base64}" style="width: 100%; border-radius: 5px; box-shadow: 0 2px 4px rgba(0,0,0,0.1);">
            <div style="margin-top: 10px; display: flex; align-items: center; gap: 15px;">
                <audio id="segment-audio" controls autoplay style="flex: 1;">
                    <source src="{audio_data_url}" type="audio/wav">
                </audio>
            </div>
            <div style="margin-top: 8px; text-align: center;">
                <span style="font-size: 14px; font-weight: bold; color: #333;">
                    Segment: {start_time:.3f}s – {end_time:.3f}s
                </span>
                <span style="font-size: 12px; color: #666; margin-left: 15px;">
                    Duration: {(end_time - start_time)*1000:.0f}ms | Context shown: {context_info}
                </span>
            </div>
        </div>
        '''
        # Return HTML and new window state
        new_window = {'start': padded_start, 'end': padded_end}
        return wrap_output(html_output, new_window)
    except Exception as e:
        # Preserve the caller's window state on failure so the view doesn't reset.
        return wrap_output(f"<pre style='padding: 20px;'>Error: {str(e)}\n{traceback.format_exc()}</pre>", current_window)
| def build_timestamps_iframe_html(entries_json, export_json_str): | |
| """Build the interactive word timestamps iframe HTML. | |
| Args: | |
| entries_json: JSON string of word/silence entries | |
| export_json_str: JSON string of full export data for download | |
| Returns: | |
| Complete iframe HTML for embedding in Gradio | |
| """ | |
| iframe_html = f''' | |
| <!DOCTYPE html> | |
| <html> | |
| <head> | |
| <style> | |
| * {{ margin: 0; padding: 0; box-sizing: border-box; }} | |
| body {{ font-family: -apple-system, BlinkMacSystemFont, sans-serif; padding: 10px; background: #f9f9f9; }} | |
| h3 {{ margin-bottom: 8px; font-size: 16px; }} | |
| .help {{ font-size: 11px; color: #666; margin-bottom: 10px; }} | |
| .container {{ max-height: 180px; overflow-y: auto; background: #fff; border-radius: 8px; padding: 8px; border: 1px solid #ddd; }} | |
| .word-btn {{ | |
| display: inline-block; | |
| background: #e8f4f8; | |
| padding: 5px 10px; | |
| margin: 3px; | |
| border-radius: 4px; | |
| cursor: pointer; | |
| border: 1px solid #cde; | |
| font-size: 13px; | |
| transition: all 0.15s; | |
| }} | |
| .word-btn:hover {{ background: #c5e5f5; }} | |
| .word-btn.selected {{ background: #4CAF50; color: white; border-color: #3a9; }} | |
| .silence-btn {{ | |
| display: inline-block; | |
| background: #ffe4c4; | |
| padding: 5px 8px; | |
| margin: 3px; | |
| border-radius: 4px; | |
| cursor: pointer; | |
| border: 1px solid #dca; | |
| font-size: 11px; | |
| transition: all 0.15s; | |
| }} | |
| .silence-btn:hover {{ background: #ffd4a4; }} | |
| .silence-btn.selected {{ background: #ff9800; color: white; border-color: #e68a00; }} | |
| .checkbox-container {{ | |
| display: inline-flex; | |
| align-items: center; | |
| margin-left: 15px; | |
| font-size: 12px; | |
| cursor: pointer; | |
| }} | |
| .checkbox-container input {{ | |
| margin-right: 5px; | |
| cursor: pointer; | |
| }} | |
| .checkbox-container:hover {{ | |
| color: #0066cc; | |
| }} | |
| .time {{ color: #0066cc; font-size: 10px; font-weight: bold; }} | |
| .silence-time {{ color: #996600; font-size: 10px; font-weight: bold; }} | |
| .duration {{ color: #666; font-size: 10px; margin-left: 3px; }} | |
| .word {{ margin-left: 4px; }} | |
| </style> | |
| </head> | |
| <body> | |
| <div style="display: flex; justify-content: space-between; align-items: center; margin-bottom: 8px;"> | |
| <div style="display: flex; align-items: center;"> | |
| <h3 style="margin: 0;">Word Timestamps</h3> | |
| <label class="checkbox-container" title="Extends word end times toward midpoint of gap to next word (max 120ms). Helps capture word endings that may be cut off."> | |
| <input type="checkbox" id="adjust-intervals"> | |
| Apply Time Interval Adjustment | |
| </label> | |
| </div> | |
| <a href="#" id="download-json" style="font-size: 12px; color: #0066cc; text-decoration: none;">📥 Download JSON</a> | |
| </div> | |
| <script>var exportJsonStr = {json.dumps(export_json_str)};</script> | |
| <p class="help"><b>Click</b> = select | <b>Ctrl+Click</b> = toggle | <b>Shift+Click</b> = range <span style="background: #ffe4c4; padding: 2px 8px; border-radius: 3px; border: 1px solid #dca;"></span> = detected non speech</p> | |
| <div class="container" id="words"></div> | |
| <script> | |
| var entries = {entries_json}; | |
| var container = document.getElementById('words'); | |
| // Merge consecutive silence periods (no word between them) | |
| function mergeConsecutiveSilences(entryList) {{ | |
| var merged = []; | |
| var pendingSilence = null; | |
| entryList.forEach(function(entry) {{ | |
| if (entry.type === 'silence') {{ | |
| if (pendingSilence === null) {{ | |
| pendingSilence = {{ type: 'silence', start: entry.start, end: entry.end }}; | |
| }} else {{ | |
| pendingSilence.end = entry.end; | |
| }} | |
| }} else {{ | |
| if (pendingSilence !== null) {{ | |
| merged.push(pendingSilence); | |
| pendingSilence = null; | |
| }} | |
| merged.push(entry); | |
| }} | |
| }}); | |
| if (pendingSilence !== null) {{ | |
| merged.push(pendingSilence); | |
| }} | |
| return merged; | |
| }} | |
| entries = mergeConsecutiveSilences(entries); | |
| var words = entries.filter(function(e) {{ return e.type === 'word'; }}); | |
| var silences = entries.filter(function(e) {{ return e.type === 'silence'; }}); | |
| function calculateAdjustedEnd(wordIndex) {{ | |
| var word = words[wordIndex]; | |
| var nextWord = words[wordIndex + 1]; | |
| if (!nextWord) return word.end; | |
| var gap = nextWord.start - word.end; | |
| var extension = Math.min(gap / 2, 0.12); | |
| return word.end + extension; | |
| }} | |
| var adjustedEnds = words.map(function(w, i) {{ return calculateAdjustedEnd(i); }}); | |
| var lastClickedIndex = -1; | |
| function getAllButtons() {{ | |
| return Array.from(container.querySelectorAll('.word-btn, .silence-btn')); | |
| }} | |
| function handleItemClick(btn, e) {{ | |
| var allBtns = getAllButtons(); | |
| var clickedIndex = allBtns.indexOf(btn); | |
| if (e.shiftKey && lastClickedIndex >= 0) {{ | |
| var start = Math.min(lastClickedIndex, clickedIndex); | |
| var end = Math.max(lastClickedIndex, clickedIndex); | |
| allBtns.forEach(function(b, i) {{ | |
| if (i >= start && i <= end) {{ | |
| b.classList.add('selected'); | |
| }} | |
| }}); | |
| }} else if (e.ctrlKey) {{ | |
| btn.classList.toggle('selected'); | |
| }} else {{ | |
| allBtns.forEach(function(b) {{ b.classList.remove('selected'); }}); | |
| btn.classList.add('selected'); | |
| }} | |
| lastClickedIndex = clickedIndex; | |
| updateInterval(); | |
| }} | |
| var wordIndex = 0; | |
| entries.forEach(function(entry, i) {{ | |
| var btn = document.createElement('span'); | |
| if (entry.type === 'word') {{ | |
| var wi = wordIndex; | |
| btn.className = 'word-btn'; | |
| btn.dataset.origS = entry.start; | |
| btn.dataset.origE = entry.end; | |
| btn.dataset.adjE = adjustedEnds[wi]; | |
| btn.dataset.s = entry.start; | |
| btn.dataset.e = entry.end; | |
| btn.dataset.word = entry.word; | |
| btn.innerHTML = '<span class="time">[' + entry.start.toFixed(3) + '-' + entry.end.toFixed(3) + 's]</span><span class="word"> ' + entry.word + '</span>'; | |
| btn.onclick = function(e) {{ handleItemClick(this, e); }}; | |
| wordIndex++; | |
| }} else {{ | |
| btn.className = 'silence-btn'; | |
| btn.dataset.s = entry.start; | |
| btn.dataset.e = entry.end; | |
| var durationMs = Math.round((entry.end - entry.start) * 1000); | |
| btn.innerHTML = '<span class="silence-time">[' + entry.start.toFixed(3) + '-' + entry.end.toFixed(3) + 's]</span><span class="duration">' + durationMs + 'ms</span>'; | |
| btn.onclick = function(e) {{ handleItemClick(this, e); }}; | |
| }} | |
| container.appendChild(btn); | |
| // Add vertical space after sentence-ending punctuation | |
| if (entry.type === 'word') {{ | |
| var lastChar = entry.word.slice(-1); | |
| if (lastChar === '.' || lastChar === '!' || lastChar === '?') {{ | |
| var spacer = document.createElement('div'); | |
| spacer.style.height = '15px'; | |
| container.appendChild(spacer); | |
| }} | |
| }} | |
| }}); | |
| function updateWordLabels() {{ | |
| var adjusted = document.getElementById('adjust-intervals').checked; | |
| document.querySelectorAll('.word-btn').forEach(function(btn) {{ | |
| var s = parseFloat(btn.dataset.origS); | |
| var e = adjusted ? parseFloat(btn.dataset.adjE) : parseFloat(btn.dataset.origE); | |
| btn.dataset.s = s; | |
| btn.dataset.e = e; | |
| btn.innerHTML = '<span class="time">[' + s.toFixed(3) + '-' + e.toFixed(3) + 's]</span><span class="word"> ' + btn.dataset.word + '</span>'; | |
| }}); | |
| updateInterval(); | |
| }} | |
| document.getElementById('adjust-intervals').addEventListener('change', updateWordLabels); | |
| function updateInterval() {{ | |
| var sel = document.querySelectorAll('.word-btn.selected, .silence-btn.selected'); | |
| if (sel.length === 0) return; | |
| var minS = Infinity, maxE = 0; | |
| sel.forEach(function(b) {{ | |
| minS = Math.min(minS, parseFloat(b.dataset.s)); | |
| maxE = Math.max(maxE, parseFloat(b.dataset.e)); | |
| }}); | |
| var interval = minS.toFixed(3) + '-' + maxE.toFixed(3); | |
| try {{ | |
| var boxes = parent.document.querySelectorAll('input[data-testid="textbox"], textarea'); | |
| boxes.forEach(function(box) {{ | |
| if (box.placeholder && box.placeholder.indexOf('start-end') !== -1) {{ | |
| box.value = interval; | |
| box.dispatchEvent(new Event('input', {{bubbles: true}})); | |
| }} | |
| }}); | |
| }} catch(err) {{ console.log('Could not update parent:', err); }} | |
| }} | |
| function highlightFromInterval(intervalStr) {{ | |
| if (!intervalStr) return; | |
| var parts = intervalStr.replace(',', '-').split('-'); | |
| if (parts.length !== 2) return; | |
| var s = parseFloat(parts[0]), e = parseFloat(parts[1]); | |
| if (isNaN(s) || isNaN(e)) return; | |
| document.querySelectorAll('.word-btn').forEach(function(btn) {{ | |
| var ws = parseFloat(btn.dataset.s); | |
| var we = parseFloat(btn.dataset.e); | |
| var itemDuration = we - ws; | |
| var overlapStart = Math.max(ws, s); | |
| var overlapEnd = Math.min(we, e); | |
| var overlap = Math.max(0, overlapEnd - overlapStart); | |
| if (itemDuration > 0 && (overlap / itemDuration) > 0.5) {{ | |
| btn.classList.add('selected'); | |
| }} else {{ | |
| btn.classList.remove('selected'); | |
| }} | |
| }}); | |
| }} | |
| function setupParentWatcher() {{ | |
| try {{ | |
| var boxes = parent.document.querySelectorAll('input[data-testid="textbox"], textarea'); | |
| boxes.forEach(function(box) {{ | |
| if (box.placeholder && box.placeholder.indexOf('start-end') !== -1) {{ | |
| box.addEventListener('blur', function() {{ | |
| highlightFromInterval(this.value); | |
| }}); | |
| box.addEventListener('keydown', function(e) {{ | |
| if (e.key === 'Enter') {{ | |
| highlightFromInterval(this.value); | |
| }} | |
| }}); | |
| }} | |
| }}); | |
| }} catch(err) {{ console.log('Could not setup parent watcher:', err); }} | |
| }} | |
| setTimeout(setupParentWatcher, 500); | |
| document.getElementById('download-json').onclick = function(e) {{ | |
| e.preventDefault(); | |
| var dataUrl = 'data:application/json;charset=utf-8,' + encodeURIComponent(exportJsonStr); | |
| var a = document.createElement('a'); | |
| a.href = dataUrl; | |
| a.download = 'transcript.json'; | |
| document.body.appendChild(a); | |
| a.click(); | |
| document.body.removeChild(a); | |
| return false; | |
| }}; | |
| </script> | |
| </body> | |
| </html> | |
| ''' | |
| iframe_srcdoc = iframe_html.replace('"', '"') | |
| return f''' | |
| <iframe srcdoc="{iframe_srcdoc}" style="width: 100%; height: 250px; border: none; border-radius: 8px;"></iframe> | |
| ''' | |
with gr.Blocks() as demo:
    # --- Header / usage instructions ---
    gr.Markdown(
        """# [Parakeet-RNNT-110M-Danish](https://huggingface.co/nvidia/parakeet-rnnt-110m-da-dk) Speech-to-Text Transcription Demo"""
    )
    gr.Markdown(
        """Upload audio and press Transcribe, or alternatively upload your own transcription in one of the supported formats. See [examples of supported formats here](https://drive.google.com/drive/folders/1qrjfHjfssAZQIvSLJi36rLjpuqw3eOjj)."""
    )
    # State to store audio data in memory for fast extraction
    audio_state = gr.State()
    timestamps_state = gr.State([])  # Store timestamps for dropdown
    # Audio source: file upload or microphone recording, passed as a filepath.
    audio_input = gr.Audio(
        type="filepath",
        label="Upload or record your audio",
        sources=["upload", "microphone"],
        format="wav"
    )
    # VAD Controls - inline labels with number inputs
    with gr.Row():
        gr.Markdown("**VAD: Probability Threshold**")
        # VAD probability threshold (0.0-1.0); higher = less sensitive.
        vad_prob_threshold = gr.Number(
            show_label=False,
            value=0.5,
            minimum=0.0,
            maximum=1.0,
            step=0.05,
            scale=0,
            min_width=80
        )
        gr.Markdown("**Min Voice Off (ms)**")
        # Minimum silence duration in ms; shorter silences are filled in as
        # voice by the VAD. 16 ms steps match one VAD frame.
        vad_min_off = gr.Number(
            show_label=False,
            value=48,
            minimum=16,
            maximum=1000,
            step=16,
            scale=0,
            min_width=80
        )
        gr.Markdown("**Min Voice On (ms)**")
        # Minimum voice duration in ms; shorter voice bursts are removed.
        vad_min_on = gr.Number(
            show_label=False,
            value=64,
            minimum=16,
            maximum=1000,
            step=16,
            scale=0,
            min_width=80
        )
    # Two ways to produce a transcript: run ASR, or load an external JSON.
    with gr.Row():
        transcribe_button = gr.Button("Transcribe", scale=1)
        transcript_file_input = gr.File(
            label="Load Transcript (JSON)",
            file_types=[".json"],
            scale=1
        )
    transcription_output = gr.Textbox(label="Transcription", lines=5)
    # Rendered as an iframe with clickable word/silence buttons.
    timestamps_output = gr.HTML(label="Word Timestamps")
    # Time interval input - directly under timestamps (single row, no label)
    with gr.Row():
        # NOTE: the iframe JS locates this textbox via its placeholder text
        # (indexOf('start-end')) - keep the placeholder wording in sync.
        time_input = gr.Textbox(
            label="",
            show_label=False,
            container=False,
            placeholder="Time interval: start-end (e.g., 0.56-1.20)",
            scale=3,
            elem_id="time-interval-box"
        )
        play_interval_button = gr.Button("▶ Play Interval", scale=1)
    # Track last played interval for smart replay
    last_interval_state = gr.State("")
    # Track current waveform window boundaries for smart redraw
    waveform_window_state = gr.State(None)
    # Waveform player - below interval controls
    waveform_player = gr.HTML(label="Segment Player")
| def load_transcript_and_setup(audio, transcript_file, prob_threshold, min_off_ms, min_on_ms): | |
| """Load external transcript and setup UI - mirrors transcribe_and_setup_audio.""" | |
| if transcript_file is None: | |
| # Return empty/unchanged outputs if no file selected | |
| return gr.update(), gr.update(), gr.update(), gr.update(), gr.update() | |
| text, timestamps_data, audio_data, raw_text, export_metadata, silence_periods = load_transcript( | |
| audio, transcript_file, prob_threshold, int(min_off_ms), int(min_on_ms) | |
| ) | |
| # Check for errors | |
| if audio_data is None: | |
| # Error case - text contains error message | |
| return text, "", None, [], gr.update() | |
| # Build combined entries (words + silence) sorted by start time | |
| entries = [] | |
| for item in timestamps_data: | |
| entries.append({ | |
| 'type': 'word', | |
| 'word': item['word'], | |
| 'start': round(item['start'], 3), | |
| 'end': round(item['end'], 3) | |
| }) | |
| for item in silence_periods: | |
| entries.append({ | |
| 'type': 'silence', | |
| 'start': round(item['start'], 3), | |
| 'end': round(item['end'], 3) | |
| }) | |
| entries.sort(key=lambda x: x['start']) | |
| entries_json = json.dumps(entries) | |
| # Build word data as JSON for the iframe | |
| words_json = json.dumps([{ | |
| 'word': item['word'], | |
| 'start': round(item['start'], 3), | |
| 'end': round(item['end'], 3) | |
| } for item in timestamps_data]) | |
| # Pre-generate full export JSON | |
| segments = [{ | |
| 'word': item['word'], | |
| 'start': round(item['start'], 3), | |
| 'end': round(item['end'], 3), | |
| 'word_index': i | |
| } for i, item in enumerate(timestamps_data)] | |
| export_data = { | |
| 'metadata': export_metadata, | |
| 'text': text, | |
| 'raw_text': raw_text, | |
| 'segments': segments | |
| } | |
| export_json_str = json.dumps(export_data, ensure_ascii=False, indent=2) | |
| # Build iframe HTML using helper function | |
| timestamps_html = build_timestamps_iframe_html(entries_json, export_json_str) | |
| initial_player = ''' | |
| <div style="padding: 20px; text-align: center; background: #f5f5f5; border-radius: 8px; color: #666;"> | |
| <p>Select words above and click <b>▶ Play Interval</b> to hear the segment</p> | |
| </div> | |
| ''' | |
| return text, timestamps_html, audio_data, timestamps_data, initial_player | |
| def transcribe_and_setup_audio(audio, prob_threshold, min_off_ms, min_on_ms): | |
| text, timestamps_data, audio_data, raw_text, export_metadata, silence_periods = transcribe_audio( | |
| audio, prob_threshold, int(min_off_ms), int(min_on_ms) | |
| ) | |
| # Build combined entries (words + silence) sorted by start time | |
| entries = [] | |
| for item in timestamps_data: | |
| entries.append({ | |
| 'type': 'word', | |
| 'word': item['word'], | |
| 'start': round(item['start'], 3), | |
| 'end': round(item['end'], 3) | |
| }) | |
| for item in silence_periods: | |
| entries.append({ | |
| 'type': 'silence', | |
| 'start': round(item['start'], 3), | |
| 'end': round(item['end'], 3) | |
| }) | |
| entries.sort(key=lambda x: x['start']) | |
| entries_json = json.dumps(entries) | |
| # Build word data as JSON for the iframe (kept for backward compat) | |
| words_json = json.dumps([{ | |
| 'word': item['word'], | |
| 'start': round(item['start'], 3), | |
| 'end': round(item['end'], 3) | |
| } for item in timestamps_data]) | |
| # Pre-generate full export JSON | |
| segments = [{ | |
| 'word': item['word'], | |
| 'start': round(item['start'], 3), | |
| 'end': round(item['end'], 3), | |
| 'word_index': i | |
| } for i, item in enumerate(timestamps_data)] | |
| export_data = { | |
| 'metadata': export_metadata, | |
| 'text': text, | |
| 'raw_text': raw_text, | |
| 'segments': segments | |
| } | |
| export_json_str = json.dumps(export_data, ensure_ascii=False, indent=2) | |
| # Build iframe HTML using helper function | |
| timestamps_html = build_timestamps_iframe_html(entries_json, export_json_str) | |
| # Initial placeholder for waveform player so it doesn't collapse | |
| initial_player = ''' | |
| <div style="padding: 20px; text-align: center; background: #f5f5f5; border-radius: 8px; color: #666;"> | |
| <p>Select words above and click <b>▶ Play Interval</b> to hear the segment</p> | |
| </div> | |
| ''' | |
| return text, timestamps_html, audio_data, timestamps_data, initial_player | |
| def play_time_interval_fast(audio_state, time_interval, last_interval, current_window): | |
| """Fast extraction using preloaded audio from memory.""" | |
| def wrap_error(msg): | |
| return f'<div style="min-height: 150px; padding: 20px; text-align: center; background: #f5f5f5; border-radius: 8px;"><p style="color: #666;">{msg}</p></div>', last_interval, current_window | |
| try: | |
| if not time_interval or not audio_state: | |
| return wrap_error("No interval or audio loaded. Select words and try again.") | |
| # Parse the time interval | |
| time_interval = time_interval.strip().replace(',', '-') | |
| parts = time_interval.split('-') | |
| if len(parts) != 2: | |
| return wrap_error("Invalid interval format. Use: start-end (e.g., 1.20-2.50)") | |
| start_time = float(parts[0].strip()) | |
| end_time = float(parts[1].strip()) | |
| if start_time >= end_time: | |
| return wrap_error("Start time must be before end time.") | |
| # Load/reload audio segment (autoplay will replay even if same interval) | |
| # Pass current window state for smart redraw logic | |
| result_html, new_window = extract_audio_segment(audio_state, start_time, end_time, current_window) | |
| return result_html, time_interval, new_window | |
| except Exception as e: | |
| return wrap_error(f"Error: {str(e)}") | |
    # Transcribe button: run ASR + VAD and fill transcript, timestamps,
    # cached audio, and the waveform-player placeholder.
    transcribe_button.click(
        fn=transcribe_and_setup_audio,
        inputs=[audio_input, vad_prob_threshold, vad_min_off, vad_min_on],
        outputs=[transcription_output, timestamps_output, audio_state, timestamps_state, waveform_player]
    )
    # Load transcript file input: same outputs, but sourced from a JSON file.
    transcript_file_input.change(
        fn=load_transcript_and_setup,
        inputs=[audio_input, transcript_file_input, vad_prob_threshold, vad_min_off, vad_min_on],
        outputs=[transcription_output, timestamps_output, audio_state, timestamps_state, waveform_player]
    )
    # Play interval button: extract the requested segment from the in-memory
    # audio and update the player plus the replay/window state.
    play_interval_button.click(
        fn=play_time_interval_fast,
        inputs=[audio_state, time_input, last_interval_state, waveform_window_state],
        outputs=[waveform_player, last_interval_state, waveform_window_state]
    )
demo.launch()