# DanishTTS2 / app.py
# Commit 6dbea10 — "Clean up and adjust vad parameters" (hlevring)
# Danish Speech-to-Text with Parakeet + PunctFixer v2
import gradio as gr
import nemo.collections.asr as nemo_asr
import torch
from punctfix import PunctFixer
import soundfile as sf
import numpy as np
import base64
import io
import json
import os
import re
import tempfile
import time
import traceback
from ten_vad import TenVad
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
# Load the Danish Parakeet RNNT model once at module import (downloads the
# checkpoint from the Hugging Face hub on first run).
asr_model = nemo_asr.models.ASRModel.from_pretrained(
    model_name="nvidia/parakeet-rnnt-110m-da-dk"
)
# PunctFixer restores punctuation and casing on the raw ASR output.
# Runs on GPU when available, otherwise falls back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
punct_fixer = PunctFixer(language="da", device=device)
def detect_silence_periods(audio_data, sample_rate: int, prob_threshold: float = 0.5, min_off_ms: int = 48, min_on_ms: int = 64) -> list:
    """Run TEN VAD to detect silence periods in audio.

    Frames the signal into 256-sample (16ms) hops, gets a per-frame
    voice/silence flag from TEN VAD, then applies hysteresis-style
    post-processing (fill short silences, drop short voice bursts) before
    converting the frame flags into silence intervals.

    Args:
        audio_data: numpy array of audio samples (float, mono, 16kHz)
        sample_rate: sample rate (must be 16000)
        prob_threshold: VAD probability threshold (0.0-1.0), higher = less sensitive
        min_off_ms: Minimum silence duration in ms - shorter silences are filled in as voice
        min_on_ms: Minimum voice duration in ms - shorter voice bursts are removed
    Returns:
        List of dicts with 'start' and 'end' times (seconds, rounded to 3
        decimals) for each silence period
    """
    TARGET_SR = 16000  # TEN VAD requires 16kHz
    HOP_SIZE = 256  # 16ms at 16kHz
    FRAME_MS = 16.0  # Each frame is 16ms
    print(f"[VAD] Settings: prob_threshold={prob_threshold}, min_off_ms={min_off_ms}, min_on_ms={min_on_ms}")
    if sample_rate != TARGET_SR:
        # Only warn — frame timing below still assumes 16kHz, so results
        # would be wrong for other rates. Callers resample beforehand.
        print(f"[VAD] Warning: Expected 16kHz audio, got {sample_rate}Hz")
    # Convert float audio to int16 (TEN VAD expects int16)
    if audio_data.dtype == np.float32 or audio_data.dtype == np.float64:
        audio_int16 = (audio_data * 32767).astype(np.int16)
    else:
        audio_int16 = audio_data.astype(np.int16)
    # Create VAD instance
    vad = TenVad(hop_size=HOP_SIZE, threshold=prob_threshold)
    frame_duration = HOP_SIZE / TARGET_SR  # 0.016s = 16ms
    # Process frame by frame and collect raw flags.
    # NOTE: a trailing partial frame (< HOP_SIZE samples) is dropped.
    num_frames = len(audio_int16) // HOP_SIZE
    # Use list for mutable flags (will be modified by post-processing)
    is_voice = [0] * num_frames
    for i in range(num_frames):
        frame_start = i * HOP_SIZE
        frame = audio_int16[frame_start:frame_start + HOP_SIZE]
        result = vad.process(frame)
        # TEN VAD returns tuple: (probability, flag) or has .flag attribute
        if isinstance(result, tuple):
            flag = result[1]  # (probability, flag)
        else:
            flag = result.flag
        is_voice[i] = flag
    # Convert ms thresholds to frame counts.
    # The +0.1 guards against float truncation (e.g. 48/16.0 = 2.9999...).
    min_off_frames = int(min_off_ms / FRAME_MS + 0.1)
    min_on_frames = int(min_on_ms / FRAME_MS + 0.1)
    # Post-processing loop (matches VadPipeline.cpp logic).
    # Iterates to a fixpoint: removing a short voice burst (pass 2) can merge
    # two silences into one gap that pass 1 must re-examine, so we repeat
    # until pass 2 makes no changes.
    while True:
        # Pass 1: Fill in short silence gaps (minOff)
        # If silence duration <= min_off_frames, convert to voice
        if min_off_frames > 0:
            start_off = -1  # index where the current silence run began, -1 = not in silence
            for i in range(num_frames):
                if is_voice[i]:  # Voice detected
                    if start_off >= 0 and (i - start_off) <= min_off_frames:
                        # Short silence gap - fill it in as voice
                        for j in range(start_off, i):
                            is_voice[j] = 1
                    start_off = -1
                elif start_off < 0:
                    start_off = i
        # Pass 2: Remove short voice bursts (minOn)
        # If voice duration <= min_on_frames, convert to silence
        changed = False
        if min_on_frames > 0:
            start_on = -1  # index where the current voice run began, -1 = not in voice
            for i in range(num_frames):
                if not is_voice[i]:  # Silence detected
                    if start_on >= 0 and (i - start_on) <= min_on_frames:
                        # Short voice burst - remove it
                        changed = True
                        for j in range(start_on, i):
                            is_voice[j] = 0
                    start_on = -1
                elif start_on < 0:
                    start_on = i
            # Handle case where audio ends with short voice burst
            if start_on >= 0 and (num_frames - start_on) <= min_on_frames:
                changed = True
                for j in range(start_on, num_frames):
                    is_voice[j] = 0
        # Exit loop if no changes or minOff is disabled
        if not changed or min_off_frames == 0:
            break
    # Convert frame flags to silence periods
    silence_periods = []
    in_silence = False
    silence_start = 0.0
    for i in range(num_frames):
        current_time = i * frame_duration
        if is_voice[i]:
            # Voice frame
            if in_silence:
                # End of silence period
                silence_periods.append({
                    'start': round(silence_start, 3),
                    'end': round(current_time, 3)
                })
                in_silence = False
        else:
            # Silence frame
            if not in_silence:
                # Start of silence period
                silence_start = current_time
                in_silence = True
    # Handle case where audio ends in silence
    if in_silence:
        silence_periods.append({
            'start': round(silence_start, 3),
            'end': round(num_frames * frame_duration, 3)
        })
    return silence_periods
def print_speech_silence_log(timestamps_data, silence_periods):
    """Print a chronological log interleaving speech words and silence periods.

    Args:
        timestamps_data: list of {'word', 'start', 'end'} dicts from ASR.
        silence_periods: list of {'start', 'end'} dicts from VAD.
    """
    # Tag every event with its type, then order the combined stream by onset.
    combined = [
        {'type': 'speech', 'start': w['start'], 'end': w['end'], 'word': w['word']}
        for w in timestamps_data
    ]
    combined.extend(
        {'type': 'silence', 'start': p['start'], 'end': p['end']}
        for p in silence_periods
    )
    combined.sort(key=lambda item: item['start'])

    print("\n=== SPEECH & SILENCE LOG ===")
    for item in combined:
        span = f"[{item['start']:.3f}-{item['end']:.3f}]"
        if item['type'] == 'speech':
            print(f"[Speech] {span} {item['word']}")
        else:
            gap_ms = int((item['end'] - item['start']) * 1000)
            print(f"[Silence] {span} [{gap_ms}ms]")

    # Summary line: counts plus total silence in seconds.
    total_silence = sum(p['end'] - p['start'] for p in silence_periods)
    print(f"\n=== SUMMARY ===")
    print(f"Words: {len(timestamps_data)}, Silence periods: {len(silence_periods)}, Total silence: {total_silence:.2f}s")
    print("=" * 30 + "\n")
def parse_transcript_file(file_path):
    """Parse a transcript JSON file and extract word timestamps.

    Three layouts are recognized, checked in this order:
      - Format 1: segments[].words[] with {start, end, word}
      - Format 3: segments[] with {start, end, text} (text treated as single word)
      - Format 2: top-level words[] with {start, end, word}

    Args:
        file_path: Path to the JSON transcript file
    Returns:
        Tuple of (full_text, timestamps_data) where timestamps_data is list of
        {word, start, end} dicts
    Raises:
        ValueError if format not recognized
    """
    with open(file_path, 'r', encoding='utf-8') as fh:
        data = json.load(fh)

    collected = []
    pieces = []

    def _append(label, entry):
        # Record one timestamped token and keep its text for the full transcript.
        collected.append({
            'word': label,
            'start': float(entry.get('start', 0)),
            'end': float(entry.get('end', 0))
        })
        pieces.append(label)

    if 'segments' in data and len(data['segments']) > 0:
        head = data['segments'][0]
        if 'words' in head and isinstance(head['words'], list):
            # Format 1: word entries nested inside each segment.
            print("[IMPORT] Detected Format 1: segments[].words[]")
            for seg in data['segments']:
                for entry in seg.get('words', []):
                    token = entry.get('word', '').strip()
                    if token:
                        _append(token, entry)
            return ' '.join(pieces), collected
        if 'text' in head and 'start' in head and 'end' in head:
            # Format 3: each segment's text becomes one "word".
            print("[IMPORT] Detected Format 3: segments[] with {start, end, text}")
            for seg in data['segments']:
                token = seg.get('text', '').strip()
                if token:
                    _append(token, seg)
            # Prefer the top-level text if the file provides one.
            return data.get('text', ' '.join(pieces)), collected

    if 'words' in data and isinstance(data['words'], list):
        # Format 2: flat word list at the top level.
        print("[IMPORT] Detected Format 2: words[]")
        for entry in data['words']:
            token = entry.get('word', '').strip()
            if token:
                _append(token, entry)
        # Prefer the top-level text if the file provides one.
        return data.get('text', ' '.join(pieces)), collected

    raise ValueError("Unrecognized transcript format. Expected segments[].words[], words[], or segments[] with {start, end, text}")
def load_transcript(audio, transcript_file, prob_threshold=0.5, min_off_ms=48, min_on_ms=64):
    """Load external transcript and run VAD on audio.

    Similar to transcribe_audio but skips ASR - uses timestamps from transcript file.

    Args:
        audio: Path to audio file
        transcript_file: Path to JSON transcript file
        prob_threshold: VAD probability threshold
        min_off_ms: Minimum silence duration in ms
        min_on_ms: Minimum voice duration in ms
    Returns:
        Same tuple format as transcribe_audio:
        (text, timestamps_data, audio_data_tuple, raw_text, export_metadata, silence_periods)
        On failure the first element is an error message and the rest are
        empty placeholders.
    """
    try:
        # Check if audio is provided
        if audio is None:
            return "No audio provided. Please upload audio first.", [], None, "", {}, []
        # Check if transcript file is provided
        if transcript_file is None:
            return "No transcript file provided.", [], None, "", {}, []
        # Parse the transcript file (its own error path so the message is specific)
        try:
            text, timestamps_data = parse_transcript_file(transcript_file)
            print(f"[IMPORT] Loaded {len(timestamps_data)} word timestamps from transcript")
        except Exception as e:
            return f"Error parsing transcript file: {str(e)}", [], None, "", {}, []
        # Preprocess audio: downmix any multi-channel signal to mono.
        audio_data, sample_rate = sf.read(audio)
        # FIX: the previous check only handled exactly 2 channels
        # (shape[1] == 2), so files with more channels (e.g. 5.1) were passed
        # through multi-channel and broke the resampler/VAD. Average across
        # channels for any 2-D input instead.
        if audio_data.ndim > 1:
            audio_data = np.mean(audio_data, axis=1)
        # Resample to 16kHz if needed (required by TEN VAD).
        # Linear interpolation is adequate for VAD purposes.
        TARGET_SR = 16000
        if sample_rate != TARGET_SR:
            duration = len(audio_data) / sample_rate
            new_length = int(duration * TARGET_SR)
            x_old = np.linspace(0, duration, len(audio_data), endpoint=False)
            x_new = np.linspace(0, duration, new_length, endpoint=False)
            audio_data = np.interp(x_new, x_old, audio_data).astype(np.float32)
            print(f"[AUDIO] Resampled from {sample_rate}Hz to {TARGET_SR}Hz")
            sample_rate = TARGET_SR
        # Run VAD to detect silence periods (best-effort: failures are logged
        # and the transcript is still returned without silence info).
        silence_periods = []
        try:
            silence_periods = detect_silence_periods(audio_data, sample_rate, prob_threshold, min_off_ms, min_on_ms)
            print_speech_silence_log(timestamps_data, silence_periods)
        except Exception as e:
            print(f"[VAD] Error during silence detection: {str(e)}\n{traceback.format_exc()}")
        # Calculate audio duration
        audio_duration = len(audio_data) / sample_rate
        # Build export metadata; ASR-specific fields are blank for imports.
        export_metadata = {
            'model': 'imported-transcript',
            'audio_duration': round(audio_duration, 2),
            'word_count': len(timestamps_data),
            'token_count': 0,
            'hypothesis_score': None,
            'frame_duration': None
        }
        # raw_text equals text for imports (no separate unpunctuated output).
        return text, timestamps_data, (audio_data, sample_rate), text, export_metadata, silence_periods
    except Exception as e:
        return f"Error loading transcript: {str(e)}\n{traceback.format_exc()}", [], None, "", {}, []
def transcribe_audio(audio, prob_threshold=0.5, min_off_ms=48, min_on_ms=64):
    """Transcribe an audio file with Parakeet, punctuate, and run VAD.

    Args:
        audio: Path to the audio file (any sample rate / channel count).
        prob_threshold: VAD probability threshold (0.0-1.0).
        min_off_ms: Minimum silence duration in ms for VAD post-processing.
        min_on_ms: Minimum voice duration in ms for VAD post-processing.
    Returns:
        Tuple of (punctuated_text, timestamps_data, (audio_data, sample_rate),
        raw_text, export_metadata, silence_periods). On failure the first
        element is an error message and the rest are empty placeholders.
    """
    try:
        # Check if audio is provided.
        # FIX: this early return previously yielded only 5 values while every
        # other path returns 6 (silence_periods was missing), which broke
        # callers unpacking the result.
        if audio is None:
            return "No audio provided. Please record or upload audio first.", [], None, "", {}, []
        # Preprocess audio: downmix any multi-channel signal to mono.
        audio_data, sample_rate = sf.read(audio)
        # FIX: previously only exactly-stereo audio (shape[1] == 2) was
        # downmixed; files with more channels were passed through unchanged.
        if audio_data.ndim > 1:
            audio_data = np.mean(audio_data, axis=1)
        # Resample to 16kHz if needed (required by both Parakeet and TEN VAD).
        TARGET_SR = 16000
        if sample_rate != TARGET_SR:
            duration = len(audio_data) / sample_rate
            new_length = int(duration * TARGET_SR)
            x_old = np.linspace(0, duration, len(audio_data), endpoint=False)
            x_new = np.linspace(0, duration, new_length, endpoint=False)
            audio_data = np.interp(x_new, x_old, audio_data).astype(np.float32)
            print(f"[AUDIO] Resampled from {sample_rate}Hz to {TARGET_SR}Hz")
            sample_rate = TARGET_SR
        # Save as a temporary mono 16kHz file for the ASR model.
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.wav')
        mono_audio_path = temp_file.name
        temp_file.close()
        try:
            sf.write(mono_audio_path, audio_data, sample_rate)
            # Get transcription with word timestamps.
            transcription = asr_model.transcribe(
                [mono_audio_path],
                return_hypotheses=True
            )
        finally:
            # FIX: cleanup now runs even if write/transcribe raises, and the
            # bare `except:` is narrowed to OSError (file already gone etc.).
            try:
                os.unlink(mono_audio_path)
            except OSError:
                pass
        if transcription:
            hyp = transcription[0]
            raw_text = hyp.text
            punctuated_text = punct_fixer.punctuate(text=raw_text)
            # Get hypothesis score (overall log-probability)
            hypothesis_score = getattr(hyp, 'score', None)
            # Build precise word timestamps from the RNNT token sequence.
            timestamps_data = []
            token_count = 0
            try:
                if (hasattr(hyp, 'y_sequence') and hyp.y_sequence is not None and
                        hasattr(hyp, 'timestamp') and hyp.timestamp is not None and
                        hasattr(hyp, 'words') and hyp.words is not None):
                    # Get token IDs and timestamps (tensor or plain sequence).
                    token_ids = hyp.y_sequence.tolist() if hasattr(hyp.y_sequence, 'tolist') else list(hyp.y_sequence)
                    timestamps = hyp.timestamp.tolist() if hasattr(hyp.timestamp, 'tolist') else list(hyp.timestamp)
                    words = hyp.words
                    tokenizer = asr_model.tokenizer
                    # Seconds per encoder frame; assumed by the model config
                    # (also recorded in export_metadata below).
                    frame_duration = 0.08
                    token_count = len(token_ids)
                    # Get full decoded text
                    full_decoded = tokenizer.ids_to_text(token_ids)
                    # Build token-to-character mapping by incrementally
                    # decoding prefixes. NOTE: O(n^2) in token count, but
                    # sequences are short for typical utterances.
                    token_char_ranges = []
                    prev_len = 0
                    for i in range(len(token_ids)):
                        # Decode up to this token
                        partial_text = tokenizer.ids_to_text(token_ids[:i + 1])
                        curr_len = len(partial_text)
                        token_char_ranges.append((prev_len, curr_len, timestamps[i] * frame_duration))
                        prev_len = curr_len
                    # Build character-to-time array
                    char_to_time = [0.0] * len(full_decoded)
                    for start_pos, end_pos, token_time in token_char_ranges:
                        for pos in range(start_pos, min(end_pos, len(char_to_time))):
                            char_to_time[pos] = token_time
                    # Find each word in the decoded text and read its timing.
                    search_pos = 0
                    for word in words:
                        word_pos = full_decoded.find(word, search_pos)
                        if word_pos >= 0 and word_pos < len(char_to_time):
                            word_end_pos = min(word_pos + len(word), len(char_to_time))
                            start_time = char_to_time[word_pos]
                            end_time = char_to_time[word_end_pos - 1] if word_end_pos > word_pos else start_time
                            timestamps_data.append({
                                'word': word,
                                'start': start_time,
                                'end': end_time
                            })
                            search_pos = word_end_pos
            except Exception as e:
                print(f"Timestamp extraction failed: {str(e)}\n{traceback.format_exc()}")
            # Map punctuated words back to timestamps_data
            # -----------------------------------------------------------------
            # The ASR model outputs raw lowercase text without punctuation.
            # PunctFixer adds punctuation and capitalization, but may occasionally:
            #   - Merge words (e.g., "i morgen" -> "imorgen")
            #   - Split contractions differently
            #   - Result in different word counts than the raw output
            # We handle this by:
            #   1. If word counts match: direct position-based mapping (common case)
            #   2. If counts differ: fuzzy matching with lookahead to realign
            # -----------------------------------------------------------------
            try:
                # Split punctuated text into words, keeping punctuation attached
                punctuated_words = punctuated_text.split()

                def strip_punct(s):
                    # Normalize for comparison: drop punctuation, lowercase.
                    return re.sub(r'[^\w]', '', s).lower()

                if len(punctuated_words) == len(timestamps_data):
                    # Same word count - direct mapping (most common case).
                    # Verify base word matches before replacing.
                    for i, pw in enumerate(punctuated_words):
                        if strip_punct(pw) == strip_punct(timestamps_data[i]['word']):
                            timestamps_data[i]['word'] = pw
                else:
                    # Different word counts - two-pointer approach with
                    # lookahead (up to 2 words) for realignment.
                    pi = 0  # punctuated index
                    for ti in range(len(timestamps_data)):
                        if pi >= len(punctuated_words):
                            break
                        raw_word = strip_punct(timestamps_data[ti]['word'])
                        punct_word = strip_punct(punctuated_words[pi])
                        if raw_word == punct_word:
                            timestamps_data[ti]['word'] = punctuated_words[pi]
                            pi += 1
                        else:
                            # Words don't match - skip ahead to find alignment
                            # (handles inserted/removed words from PunctFixer).
                            for look_ahead in range(1, min(3, len(punctuated_words) - pi)):
                                if strip_punct(punctuated_words[pi + look_ahead]) == raw_word:
                                    pi += look_ahead
                                    timestamps_data[ti]['word'] = punctuated_words[pi]
                                    pi += 1
                                    break
            except Exception as e:
                # Graceful fallback: keep original raw words if mapping fails
                print(f"Punctuation mapping failed: {str(e)}")
            # Run VAD to detect silence periods (best-effort; failures logged).
            silence_periods = []
            try:
                silence_periods = detect_silence_periods(audio_data, sample_rate, prob_threshold, min_off_ms, min_on_ms)
                print_speech_silence_log(timestamps_data, silence_periods)
            except Exception as e:
                print(f"[VAD] Error during silence detection: {str(e)}\n{traceback.format_exc()}")
            # Calculate audio duration
            audio_duration = len(audio_data) / sample_rate
            # Build export metadata
            export_metadata = {
                'model': 'nvidia/parakeet-rnnt-110m-da-dk',
                'audio_duration': round(audio_duration, 2),
                'word_count': len(timestamps_data),
                'token_count': token_count,
                'hypothesis_score': round(hypothesis_score, 4) if hypothesis_score is not None else None,
                'frame_duration': 0.08
            }
            # Return text, timestamps, audio data, raw_text, export metadata, and silence periods
            return punctuated_text, timestamps_data, (audio_data, sample_rate), raw_text, export_metadata, silence_periods
        return "No transcription available.", [], None, "", {}, []
    except Exception as e:
        return f"Error during transcription: {str(e)}\n{traceback.format_exc()}", [], None, "", {}, []
def extract_audio_segment(audio_state, start_time, end_time, current_window=None):
    """Fast audio extraction from memory with waveform visualization.

    Args:
        audio_state: Tuple of (audio_data, sample_rate)
        start_time: Start time of the interval to play
        end_time: End time of the interval to play
        current_window: Dict with 'start' and 'end' of current waveform window, or None
    Returns:
        Tuple of (html_output, new_window_state)
    """
    # Wrapper to ensure the HTML area never collapses between updates.
    def wrap_output(content, window_state=None):
        return f'<div style="min-height: 200px;">{content}</div>', window_state
    try:
        if audio_state is None:
            return wrap_output("<p style='color: red; padding: 20px;'>No audio loaded. Please transcribe audio first.</p>")
        audio_data, sample_rate = audio_state
        audio_duration = len(audio_data) / sample_rate
        # Default context padding is 160ms on each side.
        DEFAULT_PADDING = 0.16
        # Reuse the previous window boundaries when the new interval fits
        # inside them so the x-axis stays stable across clicks.
        # NOTE: the figure itself is always re-rendered; "redraw" here only
        # decides whether the window boundaries move.
        need_redraw = True
        if current_window is not None:
            if start_time >= current_window['start'] and end_time <= current_window['end']:
                need_redraw = False
                padded_start = current_window['start']
                padded_end = current_window['end']
        if need_redraw:
            # Calculate new window with +/-160ms padding, clamped to the audio.
            padded_start = max(0, start_time - DEFAULT_PADDING)
            padded_end = min(audio_duration, end_time + DEFAULT_PADDING)
        # Context shown on each side, computed ONCE here.
        # FIX: this block was previously duplicated verbatim (once for the
        # plot title, once for the info line under the player).
        left_context_ms = int((start_time - padded_start) * 1000)
        right_context_ms = int((padded_end - end_time) * 1000)
        if left_context_ms == right_context_ms:
            title_context = f'(±{left_context_ms}ms context)'
            context_info = f'±{left_context_ms}ms'
        else:
            title_context = f'(-{left_context_ms}ms context +{right_context_ms}ms context)'
            context_info = f'-{left_context_ms}ms / +{right_context_ms}ms'
        # Extract padded segment for waveform visualization
        start_sample_padded = int(padded_start * sample_rate)
        end_sample_padded = int(padded_end * sample_rate)
        segment_for_waveform = audio_data[start_sample_padded:end_sample_padded]
        # Extract ACTUAL segment for playback (no padding)
        start_sample = int(start_time * sample_rate)
        end_sample = int(end_time * sample_rate)
        segment_for_playback = audio_data[start_sample:end_sample]
        # Generate waveform visualization with padded segment (reduced height)
        fig, ax = plt.subplots(figsize=(12, 2.25))
        # Downsample for visualization using block averaging (more accurate
        # than sample skipping).
        max_points = 8000
        if len(segment_for_waveform) > max_points:
            # Reshape into blocks and take mean of each block.
            # Pad to a multiple of block_size to avoid losing end samples.
            block_size = len(segment_for_waveform) // max_points
            remainder = len(segment_for_waveform) % block_size
            if remainder > 0:
                # Pad with the last value to make it divisible
                padding_needed = block_size - remainder
                segment_padded = np.pad(segment_for_waveform, (0, padding_needed), mode='edge')
            else:
                segment_padded = segment_for_waveform
            segment_vis = segment_padded.reshape(-1, block_size).mean(axis=1)
            # Generate matching time points spanning the FULL padded range
            times_vis = np.linspace(padded_start, padded_end, len(segment_vis))
        else:
            segment_vis = segment_for_waveform
            times_vis = np.linspace(padded_start, padded_end, len(segment_for_waveform))
        ax.plot(times_vis, segment_vis, linewidth=0.5, color='#666')
        ax.fill_between(times_vis, segment_vis, alpha=0.3, color='#ccc')
        # Highlight the actual playback region (without padding)
        ax.axvspan(start_time, end_time, alpha=0.3, color='#4CAF50', label='Playback region')
        ax.axvspan(padded_start, start_time, alpha=0.1, color='#888', label='Context (not played)')
        ax.axvspan(end_time, padded_end, alpha=0.1, color='#888')
        ax.set_xlabel('Time (seconds)', fontsize=10)
        ax.set_ylabel('Amplitude', fontsize=10)
        ax.set_title(f'Audio Segment: {start_time:.3f}s – {end_time:.3f}s {title_context}', fontsize=11)
        ax.legend(fontsize=9)
        ax.grid(True, alpha=0.3)
        # Convert plot to base64 image
        buf = io.BytesIO()
        plt.tight_layout()
        plt.savefig(buf, format='png', dpi=100)
        buf.seek(0)
        img_base64 = base64.b64encode(buf.read()).decode()
        plt.close(fig)
        # Convert PLAYBACK segment (no padding) to base64 WAV
        audio_buf = io.BytesIO()
        sf.write(audio_buf, segment_for_playback, sample_rate, format='WAV')
        audio_buf.seek(0)
        audio_base64 = base64.b64encode(audio_buf.read()).decode()
        audio_data_url = f"data:audio/wav;base64,{audio_base64}"
        # Add unique ID to force Gradio to re-render (triggers autoplay)
        unique_id = int(time.time() * 1000)
        # Create HTML with waveform and native audio controls
        html_output = f'''
        <div style="margin: 10px 0;" data-render-id="{unique_id}">
            <img src="data:image/png;base64,{img_base64}" style="width: 100%; border-radius: 5px; box-shadow: 0 2px 4px rgba(0,0,0,0.1);">
            <div style="margin-top: 10px; display: flex; align-items: center; gap: 15px;">
                <audio id="segment-audio" controls autoplay style="flex: 1;">
                    <source src="{audio_data_url}" type="audio/wav">
                </audio>
            </div>
            <div style="margin-top: 8px; text-align: center;">
                <span style="font-size: 14px; font-weight: bold; color: #333;">
                    Segment: {start_time:.3f}s – {end_time:.3f}s
                </span>
                <span style="font-size: 12px; color: #666; margin-left: 15px;">
                    Duration: {(end_time - start_time)*1000:.0f}ms | Context shown: {context_info}
                </span>
            </div>
        </div>
        '''
        # Return HTML and new window state
        new_window = {'start': padded_start, 'end': padded_end}
        return wrap_output(html_output, new_window)
    except Exception as e:
        return wrap_output(f"<pre style='padding: 20px;'>Error: {str(e)}\n{traceback.format_exc()}</pre>", current_window)
def build_timestamps_iframe_html(entries_json, export_json_str):
    """Build the interactive word timestamps iframe HTML.

    The return value is a fully self-contained HTML document (CSS + JS)
    inlined via an `<iframe srcdoc="...">`. The embedded script:
      - merges consecutive silence entries before rendering,
      - renders clickable word/silence buttons with Click / Ctrl+Click /
        Shift+Click selection,
      - optionally extends each word's end time toward the midpoint of the
        gap to the next word (capped at 120ms) via a checkbox,
      - writes the selected time interval into the parent page's textbox
        whose placeholder contains 'start-end', and watches that textbox
        (blur / Enter) to highlight words back from a typed interval,
      - offers a client-side JSON download of the full export data.

    Args:
        entries_json: JSON string of word/silence entries
        export_json_str: JSON string of full export data for download
    Returns:
        Complete iframe HTML for embedding in Gradio
    """
    # NOTE: this is an f-string, so doubled braces ({{ }}) are literal braces
    # in the generated CSS/JS; single braces interpolate the two JSON payloads.
    iframe_html = f'''
    <!DOCTYPE html>
    <html>
    <head>
    <style>
        * {{ margin: 0; padding: 0; box-sizing: border-box; }}
        body {{ font-family: -apple-system, BlinkMacSystemFont, sans-serif; padding: 10px; background: #f9f9f9; }}
        h3 {{ margin-bottom: 8px; font-size: 16px; }}
        .help {{ font-size: 11px; color: #666; margin-bottom: 10px; }}
        .container {{ max-height: 180px; overflow-y: auto; background: #fff; border-radius: 8px; padding: 8px; border: 1px solid #ddd; }}
        .word-btn {{
            display: inline-block;
            background: #e8f4f8;
            padding: 5px 10px;
            margin: 3px;
            border-radius: 4px;
            cursor: pointer;
            border: 1px solid #cde;
            font-size: 13px;
            transition: all 0.15s;
        }}
        .word-btn:hover {{ background: #c5e5f5; }}
        .word-btn.selected {{ background: #4CAF50; color: white; border-color: #3a9; }}
        .silence-btn {{
            display: inline-block;
            background: #ffe4c4;
            padding: 5px 8px;
            margin: 3px;
            border-radius: 4px;
            cursor: pointer;
            border: 1px solid #dca;
            font-size: 11px;
            transition: all 0.15s;
        }}
        .silence-btn:hover {{ background: #ffd4a4; }}
        .silence-btn.selected {{ background: #ff9800; color: white; border-color: #e68a00; }}
        .checkbox-container {{
            display: inline-flex;
            align-items: center;
            margin-left: 15px;
            font-size: 12px;
            cursor: pointer;
        }}
        .checkbox-container input {{
            margin-right: 5px;
            cursor: pointer;
        }}
        .checkbox-container:hover {{
            color: #0066cc;
        }}
        .time {{ color: #0066cc; font-size: 10px; font-weight: bold; }}
        .silence-time {{ color: #996600; font-size: 10px; font-weight: bold; }}
        .duration {{ color: #666; font-size: 10px; margin-left: 3px; }}
        .word {{ margin-left: 4px; }}
    </style>
    </head>
    <body>
    <div style="display: flex; justify-content: space-between; align-items: center; margin-bottom: 8px;">
        <div style="display: flex; align-items: center;">
            <h3 style="margin: 0;">Word Timestamps</h3>
            <label class="checkbox-container" title="Extends word end times toward midpoint of gap to next word (max 120ms). Helps capture word endings that may be cut off.">
                <input type="checkbox" id="adjust-intervals">
                Apply Time Interval Adjustment
            </label>
        </div>
        <a href="#" id="download-json" style="font-size: 12px; color: #0066cc; text-decoration: none;">📥 Download JSON</a>
    </div>
    <script>var exportJsonStr = {json.dumps(export_json_str)};</script>
    <p class="help"><b>Click</b> = select &nbsp;|&nbsp; <b>Ctrl+Click</b> = toggle &nbsp;|&nbsp; <b>Shift+Click</b> = range &nbsp;&nbsp;&nbsp; <span style="background: #ffe4c4; padding: 2px 8px; border-radius: 3px; border: 1px solid #dca;"></span> = detected non speech</p>
    <div class="container" id="words"></div>
    <script>
    var entries = {entries_json};
    var container = document.getElementById('words');
    // Merge consecutive silence periods (no word between them)
    function mergeConsecutiveSilences(entryList) {{
        var merged = [];
        var pendingSilence = null;
        entryList.forEach(function(entry) {{
            if (entry.type === 'silence') {{
                if (pendingSilence === null) {{
                    pendingSilence = {{ type: 'silence', start: entry.start, end: entry.end }};
                }} else {{
                    pendingSilence.end = entry.end;
                }}
            }} else {{
                if (pendingSilence !== null) {{
                    merged.push(pendingSilence);
                    pendingSilence = null;
                }}
                merged.push(entry);
            }}
        }});
        if (pendingSilence !== null) {{
            merged.push(pendingSilence);
        }}
        return merged;
    }}
    entries = mergeConsecutiveSilences(entries);
    var words = entries.filter(function(e) {{ return e.type === 'word'; }});
    var silences = entries.filter(function(e) {{ return e.type === 'silence'; }});
    function calculateAdjustedEnd(wordIndex) {{
        var word = words[wordIndex];
        var nextWord = words[wordIndex + 1];
        if (!nextWord) return word.end;
        var gap = nextWord.start - word.end;
        var extension = Math.min(gap / 2, 0.12);
        return word.end + extension;
    }}
    var adjustedEnds = words.map(function(w, i) {{ return calculateAdjustedEnd(i); }});
    var lastClickedIndex = -1;
    function getAllButtons() {{
        return Array.from(container.querySelectorAll('.word-btn, .silence-btn'));
    }}
    function handleItemClick(btn, e) {{
        var allBtns = getAllButtons();
        var clickedIndex = allBtns.indexOf(btn);
        if (e.shiftKey && lastClickedIndex >= 0) {{
            var start = Math.min(lastClickedIndex, clickedIndex);
            var end = Math.max(lastClickedIndex, clickedIndex);
            allBtns.forEach(function(b, i) {{
                if (i >= start && i <= end) {{
                    b.classList.add('selected');
                }}
            }});
        }} else if (e.ctrlKey) {{
            btn.classList.toggle('selected');
        }} else {{
            allBtns.forEach(function(b) {{ b.classList.remove('selected'); }});
            btn.classList.add('selected');
        }}
        lastClickedIndex = clickedIndex;
        updateInterval();
    }}
    var wordIndex = 0;
    entries.forEach(function(entry, i) {{
        var btn = document.createElement('span');
        if (entry.type === 'word') {{
            var wi = wordIndex;
            btn.className = 'word-btn';
            btn.dataset.origS = entry.start;
            btn.dataset.origE = entry.end;
            btn.dataset.adjE = adjustedEnds[wi];
            btn.dataset.s = entry.start;
            btn.dataset.e = entry.end;
            btn.dataset.word = entry.word;
            btn.innerHTML = '<span class="time">[' + entry.start.toFixed(3) + '-' + entry.end.toFixed(3) + 's]</span><span class="word"> ' + entry.word + '</span>';
            btn.onclick = function(e) {{ handleItemClick(this, e); }};
            wordIndex++;
        }} else {{
            btn.className = 'silence-btn';
            btn.dataset.s = entry.start;
            btn.dataset.e = entry.end;
            var durationMs = Math.round((entry.end - entry.start) * 1000);
            btn.innerHTML = '<span class="silence-time">[' + entry.start.toFixed(3) + '-' + entry.end.toFixed(3) + 's]</span><span class="duration">' + durationMs + 'ms</span>';
            btn.onclick = function(e) {{ handleItemClick(this, e); }};
        }}
        container.appendChild(btn);
        // Add vertical space after sentence-ending punctuation
        if (entry.type === 'word') {{
            var lastChar = entry.word.slice(-1);
            if (lastChar === '.' || lastChar === '!' || lastChar === '?') {{
                var spacer = document.createElement('div');
                spacer.style.height = '15px';
                container.appendChild(spacer);
            }}
        }}
    }});
    function updateWordLabels() {{
        var adjusted = document.getElementById('adjust-intervals').checked;
        document.querySelectorAll('.word-btn').forEach(function(btn) {{
            var s = parseFloat(btn.dataset.origS);
            var e = adjusted ? parseFloat(btn.dataset.adjE) : parseFloat(btn.dataset.origE);
            btn.dataset.s = s;
            btn.dataset.e = e;
            btn.innerHTML = '<span class="time">[' + s.toFixed(3) + '-' + e.toFixed(3) + 's]</span><span class="word"> ' + btn.dataset.word + '</span>';
        }});
        updateInterval();
    }}
    document.getElementById('adjust-intervals').addEventListener('change', updateWordLabels);
    function updateInterval() {{
        var sel = document.querySelectorAll('.word-btn.selected, .silence-btn.selected');
        if (sel.length === 0) return;
        var minS = Infinity, maxE = 0;
        sel.forEach(function(b) {{
            minS = Math.min(minS, parseFloat(b.dataset.s));
            maxE = Math.max(maxE, parseFloat(b.dataset.e));
        }});
        var interval = minS.toFixed(3) + '-' + maxE.toFixed(3);
        try {{
            var boxes = parent.document.querySelectorAll('input[data-testid="textbox"], textarea');
            boxes.forEach(function(box) {{
                if (box.placeholder && box.placeholder.indexOf('start-end') !== -1) {{
                    box.value = interval;
                    box.dispatchEvent(new Event('input', {{bubbles: true}}));
                }}
            }});
        }} catch(err) {{ console.log('Could not update parent:', err); }}
    }}
    function highlightFromInterval(intervalStr) {{
        if (!intervalStr) return;
        var parts = intervalStr.replace(',', '-').split('-');
        if (parts.length !== 2) return;
        var s = parseFloat(parts[0]), e = parseFloat(parts[1]);
        if (isNaN(s) || isNaN(e)) return;
        document.querySelectorAll('.word-btn').forEach(function(btn) {{
            var ws = parseFloat(btn.dataset.s);
            var we = parseFloat(btn.dataset.e);
            var itemDuration = we - ws;
            var overlapStart = Math.max(ws, s);
            var overlapEnd = Math.min(we, e);
            var overlap = Math.max(0, overlapEnd - overlapStart);
            if (itemDuration > 0 && (overlap / itemDuration) > 0.5) {{
                btn.classList.add('selected');
            }} else {{
                btn.classList.remove('selected');
            }}
        }});
    }}
    function setupParentWatcher() {{
        try {{
            var boxes = parent.document.querySelectorAll('input[data-testid="textbox"], textarea');
            boxes.forEach(function(box) {{
                if (box.placeholder && box.placeholder.indexOf('start-end') !== -1) {{
                    box.addEventListener('blur', function() {{
                        highlightFromInterval(this.value);
                    }});
                    box.addEventListener('keydown', function(e) {{
                        if (e.key === 'Enter') {{
                            highlightFromInterval(this.value);
                        }}
                    }});
                }}
            }});
        }} catch(err) {{ console.log('Could not setup parent watcher:', err); }}
    }}
    setTimeout(setupParentWatcher, 500);
    document.getElementById('download-json').onclick = function(e) {{
        e.preventDefault();
        var dataUrl = 'data:application/json;charset=utf-8,' + encodeURIComponent(exportJsonStr);
        var a = document.createElement('a');
        a.href = dataUrl;
        a.download = 'transcript.json';
        document.body.appendChild(a);
        a.click();
        document.body.removeChild(a);
        return false;
    }};
    </script>
    </body>
    </html>
    '''
    # Escape double quotes so the whole document can be inlined in srcdoc.
    # NOTE(review): '&' is not escaped to '&amp;' first; browsers tolerate
    # this, but confirm the JSON payloads can never contain sequences that
    # would be re-interpreted as HTML entities inside srcdoc.
    iframe_srcdoc = iframe_html.replace('"', '&quot;')
    return f'''
    <iframe srcdoc="{iframe_srcdoc}" style="width: 100%; height: 250px; border: none; border-radius: 8px;"></iframe>
    '''
# Gradio UI definition. NOTE: component creation order inside each Row()
# context manager determines the on-screen layout — do not reorder.
with gr.Blocks() as demo:
    gr.Markdown(
        """# [Parakeet-RNNT-110M-Danish](https://huggingface.co/nvidia/parakeet-rnnt-110m-da-dk) Speech-to-Text Transcription Demo"""
    )
    gr.Markdown(
        """Upload audio and press Transcribe, or alternatively upload your own transcription in one of the supported formats. See [examples of supported formats here](https://drive.google.com/drive/folders/1qrjfHjfssAZQIvSLJi36rLjpuqw3eOjj)."""
    )
    # State to store audio data in memory for fast extraction
    audio_state = gr.State()
    timestamps_state = gr.State([])  # Store timestamps for dropdown
    audio_input = gr.Audio(
        type="filepath",
        label="Upload or record your audio",
        sources=["upload", "microphone"],
        format="wav"
    )
    # VAD Controls - inline labels with number inputs
    # (Markdown labels paired with unlabeled gr.Number widgets keep the row compact.)
    with gr.Row():
        gr.Markdown("**VAD: Probability Threshold**")
        # Higher = less sensitive voice detection; forwarded to detect_silence_periods.
        vad_prob_threshold = gr.Number(
            show_label=False,
            value=0.5,
            minimum=0.0,
            maximum=1.0,
            step=0.05,
            scale=0,
            min_width=80
        )
        gr.Markdown("**Min Voice Off (ms)**")
        # Silences shorter than this are filled in as voice; 16ms step = one VAD frame.
        vad_min_off = gr.Number(
            show_label=False,
            value=48,
            minimum=16,
            maximum=1000,
            step=16,
            scale=0,
            min_width=80
        )
        gr.Markdown("**Min Voice On (ms)**")
        # Voice bursts shorter than this are removed; 16ms step = one VAD frame.
        vad_min_on = gr.Number(
            show_label=False,
            value=64,
            minimum=16,
            maximum=1000,
            step=16,
            scale=0,
            min_width=80
        )
    with gr.Row():
        transcribe_button = gr.Button("Transcribe", scale=1)
        transcript_file_input = gr.File(
            label="Load Transcript (JSON)",
            file_types=[".json"],
            scale=1
        )
    transcription_output = gr.Textbox(label="Transcription", lines=5)
    timestamps_output = gr.HTML(label="Word Timestamps")
    # Time interval input - directly under timestamps (single row, no label).
    # The "start-end" placeholder text is also how the iframe JS locates this box
    # from the parent document — keep it in sync with the embedded script.
    with gr.Row():
        time_input = gr.Textbox(
            label="",
            show_label=False,
            container=False,
            placeholder="Time interval: start-end (e.g., 0.56-1.20)",
            scale=3,
            elem_id="time-interval-box"
        )
        play_interval_button = gr.Button("▶ Play Interval", scale=1)
    # Track last played interval for smart replay
    last_interval_state = gr.State("")
    # Track current waveform window boundaries for smart redraw
    waveform_window_state = gr.State(None)
    # Waveform player - below interval controls
    waveform_player = gr.HTML(label="Segment Player")
def load_transcript_and_setup(audio, transcript_file, prob_threshold, min_off_ms, min_on_ms):
    """Load external transcript and setup UI - mirrors transcribe_and_setup_audio.

    Args:
        audio: Audio file path from the gr.Audio component (may be None).
        transcript_file: Uploaded JSON transcript file from gr.File, or None.
        prob_threshold: VAD probability threshold (0.0-1.0).
        min_off_ms: Minimum silence duration in ms (coerced to int).
        min_on_ms: Minimum voice duration in ms (coerced to int).

    Returns:
        Tuple matching the Gradio outputs:
        (transcription text, timestamps iframe HTML, audio data for audio_state,
        timestamps list for timestamps_state, waveform player HTML).
    """
    if transcript_file is None:
        # No file selected: leave every output component unchanged.
        return gr.update(), gr.update(), gr.update(), gr.update(), gr.update()
    text, timestamps_data, audio_data, raw_text, export_metadata, silence_periods = load_transcript(
        audio, transcript_file, prob_threshold, int(min_off_ms), int(min_on_ms)
    )
    if audio_data is None:
        # Error case - `text` carries the error message; clear the dependent outputs.
        return text, "", None, [], gr.update()
    # Build combined entries (words + silence) sorted by start time for the iframe.
    entries = [
        {
            'type': 'word',
            'word': item['word'],
            'start': round(item['start'], 3),
            'end': round(item['end'], 3),
        }
        for item in timestamps_data
    ]
    entries.extend(
        {
            'type': 'silence',
            'start': round(item['start'], 3),
            'end': round(item['end'], 3),
        }
        for item in silence_periods
    )
    entries.sort(key=lambda x: x['start'])
    entries_json = json.dumps(entries)
    # NOTE: a dead `words_json` local was removed here — it was computed but never used.
    # Pre-generate full export JSON served by the iframe's download button.
    segments = [
        {
            'word': item['word'],
            'start': round(item['start'], 3),
            'end': round(item['end'], 3),
            'word_index': i,
        }
        for i, item in enumerate(timestamps_data)
    ]
    export_data = {
        'metadata': export_metadata,
        'text': text,
        'raw_text': raw_text,
        'segments': segments,
    }
    export_json_str = json.dumps(export_data, ensure_ascii=False, indent=2)
    # Build iframe HTML using helper function
    timestamps_html = build_timestamps_iframe_html(entries_json, export_json_str)
    # Placeholder so the waveform player area does not collapse before first use.
    initial_player = '''
<div style="padding: 20px; text-align: center; background: #f5f5f5; border-radius: 8px; color: #666;">
<p>Select words above and click <b>▶ Play Interval</b> to hear the segment</p>
</div>
'''
    return text, timestamps_html, audio_data, timestamps_data, initial_player
def transcribe_and_setup_audio(audio, prob_threshold, min_off_ms, min_on_ms):
    """Transcribe the uploaded audio and build all UI panel contents.

    Args:
        audio: Audio file path from the gr.Audio component.
        prob_threshold: VAD probability threshold (0.0-1.0).
        min_off_ms: Minimum silence duration in ms (coerced to int).
        min_on_ms: Minimum voice duration in ms (coerced to int).

    Returns:
        Tuple matching the Gradio outputs:
        (transcription text, timestamps iframe HTML, audio data for audio_state,
        timestamps list for timestamps_state, waveform player HTML).
    """
    text, timestamps_data, audio_data, raw_text, export_metadata, silence_periods = transcribe_audio(
        audio, prob_threshold, int(min_off_ms), int(min_on_ms)
    )
    # Build combined entries (words + silence) sorted by start time for the iframe.
    entries = [
        {
            'type': 'word',
            'word': item['word'],
            'start': round(item['start'], 3),
            'end': round(item['end'], 3),
        }
        for item in timestamps_data
    ]
    entries.extend(
        {
            'type': 'silence',
            'start': round(item['start'], 3),
            'end': round(item['end'], 3),
        }
        for item in silence_periods
    )
    entries.sort(key=lambda x: x['start'])
    entries_json = json.dumps(entries)
    # NOTE: a dead `words_json` local (commented "kept for backward compat") was
    # removed — an unused local variable has no compatibility effect.
    # Pre-generate full export JSON served by the iframe's download button.
    segments = [
        {
            'word': item['word'],
            'start': round(item['start'], 3),
            'end': round(item['end'], 3),
            'word_index': i,
        }
        for i, item in enumerate(timestamps_data)
    ]
    export_data = {
        'metadata': export_metadata,
        'text': text,
        'raw_text': raw_text,
        'segments': segments,
    }
    export_json_str = json.dumps(export_data, ensure_ascii=False, indent=2)
    # Build iframe HTML using helper function
    timestamps_html = build_timestamps_iframe_html(entries_json, export_json_str)
    # Initial placeholder for waveform player so it doesn't collapse
    initial_player = '''
<div style="padding: 20px; text-align: center; background: #f5f5f5; border-radius: 8px; color: #666;">
<p>Select words above and click <b>▶ Play Interval</b> to hear the segment</p>
</div>
'''
    return text, timestamps_html, audio_data, timestamps_data, initial_player
def play_time_interval_fast(audio_state, time_interval, last_interval, current_window):
    """Fast extraction using preloaded audio from memory.

    Parses the "start-end" interval string, validates it, and returns
    (player HTML, played interval, waveform window) for the Gradio outputs.
    On any failure a placeholder panel is shown and the previous interval /
    window state values are passed through unchanged.
    """
    def _fail(msg):
        # Fixed-height placeholder keeps the player area from collapsing.
        return (
            '<div style="min-height: 150px; padding: 20px; text-align: center; '
            'background: #f5f5f5; border-radius: 8px;">'
            f'<p style="color: #666;">{msg}</p></div>',
            last_interval,
            current_window,
        )
    try:
        if not time_interval or not audio_state:
            return _fail("No interval or audio loaded. Select words and try again.")
        # Accept "start,end" as an alias for "start-end".
        normalized = time_interval.strip().replace(',', '-')
        pieces = normalized.split('-')
        if len(pieces) != 2:
            return _fail("Invalid interval format. Use: start-end (e.g., 1.20-2.50)")
        begin, finish = (float(piece.strip()) for piece in pieces)
        if begin >= finish:
            return _fail("Start time must be before end time.")
        # Reload/redraw the segment; autoplay replays even if the interval is
        # unchanged. The current window enables smart-redraw inside the helper.
        player_html, updated_window = extract_audio_segment(audio_state, begin, finish, current_window)
        return player_html, normalized, updated_window
    except Exception as e:
        # Covers float() parse failures and extraction errors alike.
        return _fail(f"Error: {str(e)}")
# Event wiring: all three handlers share the same five output components so
# the transcription panel, timestamp iframe, and player stay in sync.
transcribe_button.click(
    fn=transcribe_and_setup_audio,
    inputs=[audio_input, vad_prob_threshold, vad_min_off, vad_min_on],
    outputs=[transcription_output, timestamps_output, audio_state, timestamps_state, waveform_player]
)
# Load transcript file input
transcript_file_input.change(
    fn=load_transcript_and_setup,
    inputs=[audio_input, transcript_file_input, vad_prob_threshold, vad_min_off, vad_min_on],
    outputs=[transcription_output, timestamps_output, audio_state, timestamps_state, waveform_player]
)
# Play interval button
play_interval_button.click(
    fn=play_time_interval_fast,
    inputs=[audio_state, time_input, last_interval_state, waveform_window_state],
    outputs=[waveform_player, last_interval_state, waveform_window_state]
)
demo.launch()