# pmarmaroli's picture
# Stuttering classification demo: Wav2Vec2-XLSR-53 heatmap + donut + timeline
# 140caa0
"""
Stuttering Classification Demo
Wav2Vec2-XLSR-53 Fine-Tuned for Stuttering Detection
Built by Vocametrix — vocametrix.com
"""
import gradio as gr
# ── Monkey-patch: fix gradio_client crash when JSON Schema uses boolean values ──
import gradio_client.utils as _gc_utils
# Keep a handle on gradio_client's own converter so the wrapper can delegate to it.
_orig_json_schema_fn = _gc_utils._json_schema_to_python_type


def _patched_json_schema_fn(schema, defs):
    """Convert a JSON Schema node to a type string, accepting boolean schemas.

    A boolean node is answered here (``True`` -> ``"Any"``, ``False`` ->
    ``"None"``); every other node is forwarded unchanged, together with
    ``defs``, to the original gradio_client implementation.
    """
    if not isinstance(schema, bool):
        return _orig_json_schema_fn(schema, defs)
    if schema:
        return "Any"
    return "None"


# Install the wrapper in place of the library function.
_gc_utils._json_schema_to_python_type = _patched_json_schema_fn
# ── End monkey-patch ──
import numpy as np
import librosa
import torch
import torch.nn.functional as F
from transformers import AutoModelForAudioClassification, AutoFeatureExtractor
# ═══════════════════════════════════════════════════════════════
# CONFIGURATION
# ═══════════════════════════════════════════════════════════════
MODEL_NAME = "vocametrix/wav2vec2-xlsr-53-stuttering-classification"
SAMPLE_RATE = 16000  # Hz; all audio is resampled to this rate before inference
DEFAULT_CHUNK_DURATION = 4.0  # seconds
# NOTE: expressed as a fraction (0.50 == 50%), whereas the UI slider and
# chunked_classification() work in percent (0-100).
DEFAULT_OVERLAP = 0.50

# Single source of truth for the per-class presentation data.
# Columns: model label, swatch colour, display name, description.
_LABEL_TABLE = (
    ("fluent", "#10B981", "Fluent", "Fluent speech — no disfluency detected"),
    ("block", "#EF4444", "Block", "Silent or audible block in speech flow"),
    ("prolongation", "#F59E0B", "Prolongation", 'Prolonged sounds (e.g., "ssssnake")'),
    (
        "Wordrepetition",
        "#8B5CF6",
        "Word Repetition",
        'Repetition of whole words (e.g., "I-I-I want")',
    ),
    (
        "Soundrepetition",
        "#EC4899",
        "Sound Repetition",
        'Repetition of individual sounds (e.g., "b-b-ball")',
    ),
    ("interjection", "#3B82F6", "Interjection", 'Filler words or sounds (e.g., "um", "uh")'),
)

# Lookup tables keyed by model label, derived from the table above.
LABEL_COLORS = {label: color for label, color, _, _ in _LABEL_TABLE}
LABEL_DISPLAY = {label: display for label, _, display, _ in _LABEL_TABLE}
LABEL_DESCRIPTIONS = {label: description for label, _, _, description in _LABEL_TABLE}

# Colour used for any label that is missing from LABEL_COLORS.
FALLBACK_COLOR = "#9CA3AF"
# ═══════════════════════════════════════════════════════════════
# LOAD MODEL (cached at startup)
# ═══════════════════════════════════════════════════════════════
print("Loading model from Hugging Face Hub...")
feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)
model = AutoModelForAudioClassification.from_pretrained(MODEL_NAME)
# Inference only: switch the module to evaluation mode before any forward pass.
model.eval()
# Use the GPU when PyTorch reports one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model.to(device)
print(f"Model loaded on {device}")
# ═══════════════════════════════════════════════════════════════
# INFERENCE
# ═══════════════════════════════════════════════════════════════
def classify_chunk(audio_chunk, sr=SAMPLE_RATE):
    """Run the classifier on one audio chunk.

    Args:
        audio_chunk: Samples of a single chunk, passed as-is to the feature
            extractor (presumably a 1-D float waveform — confirm with callers).
        sr: Sampling rate handed to the feature extractor.

    Returns:
        A ``(label, confidence, all_probs)`` tuple: the predicted label name,
        its softmax probability, and a dict mapping every label name to its
        probability.
    """
    encoded = feature_extractor(
        audio_chunk, sampling_rate=sr, return_tensors="pt", padding=True
    )
    # Move every tensor produced by the extractor onto the model's device.
    model_inputs = {name: tensor.to(device) for name, tensor in encoded.items()}

    with torch.no_grad():
        probs = F.softmax(model(**model_inputs).logits, dim=-1)

    id2label = model.config.id2label
    pred_id = torch.argmax(probs, dim=-1).item()
    first_row = probs[0]  # only the first (and expected only) batch entry is read
    all_probs = {
        id2label[idx]: float(first_row[idx]) for idx in range(probs.shape[1])
    }
    return id2label[pred_id], first_row[pred_id].item(), all_probs
def chunked_classification(audio_data, sr, chunk_duration, overlap_pct, *, classify_fn=None):
    """Split audio into fixed-length, optionally overlapping windows and classify each.

    Args:
        audio_data: 1-D array of samples; it is sliced and zero-padded with NumPy.
        sr: Sample rate of ``audio_data`` in Hz.
        chunk_duration: Window length in seconds.
        overlap_pct: Overlap between consecutive windows, in percent (0-100).
            Values that would give a step shorter than one sample fall back
            to non-overlapping windows.
        classify_fn: Optional callable ``(chunk, sr) -> (label, confidence,
            probs)``. Defaults to the module's ``classify_chunk``.

    Returns:
        A list with one dict per window, holding ``start`` and ``end``
        (seconds, clipped to the audio length), ``label``, ``confidence``
        and ``probs``.

    Raises:
        ValueError: If ``chunk_duration * sr`` is shorter than one sample.
    """
    if classify_fn is None:
        classify_fn = classify_chunk

    chunk_samples = int(chunk_duration * sr)
    if chunk_samples < 1:
        # A zero or negative window would make the step non-positive and the
        # loop below would never terminate.
        raise ValueError(
            f"chunk_duration must cover at least one sample (got {chunk_duration!r} s at {sr!r} Hz)"
        )

    step_samples = int(chunk_samples * (1 - overlap_pct / 100))
    if step_samples < 1:
        step_samples = chunk_samples

    total_samples = len(audio_data)
    results = []
    pos = 0
    while pos < total_samples:
        end = min(pos + chunk_samples, total_samples)
        chunk = audio_data[pos:end]
        # Zero-pad a short final window so every chunk has the same length.
        if len(chunk) < chunk_samples:
            chunk = np.pad(chunk, (0, chunk_samples - len(chunk)))
        label, confidence, all_probs = classify_fn(chunk, sr)
        results.append(
            {
                "start": pos / sr,
                "end": end / sr,
                "label": label,
                "confidence": confidence,
                "probs": all_probs,
            }
        )
        # Once a window reaches the end of the audio, any later window would
        # lie entirely inside it and consist mostly of padding, so stop here.
        if end >= total_samples:
            break
        pos += step_samples
    return results
# ═══════════════════════════════════════════════════════════════
# VISUALIZATION (HTML/SVG)
# ═══════════════════════════════════════════════════════════════
# Colours used inside the hover tooltip. The ``.chunk-tooltip`` rule in
# CUSTOM_CSS gives the tooltip a dark (#1f2937) background, so its text and
# bar tracks must be light/mid-tone to stay readable.
_TOOLTIP_TEXT = "#f9fafb"
_TOOLTIP_MUTED = "#d1d5db"
_TOOLTIP_TRACK = "#4b5563"


def _summary_cards_html(n, avg_confidence, stuttering_pct, total_duration):
    """Render the four headline metric cards as a single grid row."""
    # Columns: gradient start, gradient end, border, value colour, value, caption.
    cards = (
        ("#f0fdf4", "#dcfce7", "#86efac", "#16a34a", f"{n}", "Total Chunks"),
        ("#faf5ff", "#f3e8ff", "#c4b5fd", "#7c3aed", f"{avg_confidence:.0%}", "Avg Confidence"),
        ("#fef2f2", "#fee2e2", "#fca5a5", "#dc2626", f"{stuttering_pct:.0f}%", "Disfluent"),
        ("#eff6ff", "#dbeafe", "#93c5fd", "#2563eb", f"{total_duration:.1f}s", "Duration"),
    )
    cards_html = "".join(
        f'<div style="background: linear-gradient(135deg, {grad_from}, {grad_to}); '
        f'border: 1px solid {border}; border-radius: 12px; padding: 16px; text-align: center;">'
        f'<div style="font-size: 2rem; font-weight: 800; color: {accent};">{value}</div>'
        f'<div style="font-size: 0.8rem; color: #4b5563; font-weight: 600;">{caption}</div>'
        f"</div>"
        for grad_from, grad_to, border, accent, value, caption in cards
    )
    return (
        '<div style="display: grid; grid-template-columns: repeat(4, 1fr); gap: 12px; '
        'margin-bottom: 24px;">'
        f"{cards_html}</div>"
    )


def _legend_html(label_counts):
    """Render the colour legend for the labels that actually occur, sorted by name."""
    items = "".join(
        f'<div style="display:flex; align-items:center; gap:6px;">'
        f'<div style="width:14px; height:14px; border-radius:4px; '
        f'background:{LABEL_COLORS.get(lbl, FALLBACK_COLOR)};"></div>'
        f'<span style="font-size:0.8rem; font-weight:600; color:#374151;">'
        f"{LABEL_DISPLAY.get(lbl, lbl)}</span>"
        f'<span style="font-size:0.75rem; color:#6b7280;">({label_counts[lbl]})</span>'
        f"</div>"
        for lbl in sorted(label_counts)
    )
    return (
        '<div style="display:flex; flex-wrap:wrap; gap:16px; margin-bottom:16px; '
        'padding:12px 16px; background:#f9fafb; border-radius:10px; border:1px solid #e5e7eb;">'
        f"{items}</div>"
    )


def _tooltip_html(r, color, display):
    """Render the hover tooltip for one chunk, including its three most likely classes."""
    desc = LABEL_DESCRIPTIONS.get(r["label"], "")
    top_probs = sorted(r["probs"].items(), key=lambda item: -item[1])[:3]
    prob_bars = "".join(
        f'<div style="display:flex; align-items:center; gap:4px; font-size:0.7rem;">'
        f'<span style="width:80px; color:{_TOOLTIP_MUTED}; text-align:right;">'
        f"{LABEL_DISPLAY.get(name, name)}</span>"
        f'<div style="flex:1; background:{_TOOLTIP_TRACK}; border-radius:3px; height:8px; min-width:60px;">'
        # Keep a sliver visible even for near-zero probabilities.
        f'<div style="width:{max(2, prob * 100)}%; '
        f'background:{LABEL_COLORS.get(name, FALLBACK_COLOR)}; height:100%; border-radius:3px;"></div>'
        f"</div>"
        f'<span style="width:35px; color:{_TOOLTIP_TEXT}; font-weight:600;">{prob:.0%}</span>'
        f"</div>"
        for name, prob in top_probs
    )
    return (
        f'<div class="chunk-tooltip">'
        f'<div style="display:flex; align-items:center; gap:6px; margin-bottom:6px;">'
        f'<div style="width:10px; height:10px; border-radius:3px; background:{color};"></div>'
        f'<span style="font-weight:700; color:{_TOOLTIP_TEXT};">{display}</span>'
        f'<span style="color:{_TOOLTIP_MUTED}; font-size:0.75rem;">{r["confidence"]:.0%}</span>'
        f"</div>"
        f'<div style="font-size:0.7rem; color:{_TOOLTIP_MUTED}; margin-bottom:6px;">'
        f'{r["start"]:.1f}s — {r["end"]:.1f}s</div>'
        f'<div style="font-size:0.7rem; color:#9ca3af; font-style:italic; margin-bottom:8px;">{desc}</div>'
        f"{prob_bars}"
        f"</div>"
    )


def _heatmap_html(results):
    """Render one numbered, colour-coded square per chunk in a grid of at most 15 columns."""
    cells = []
    for number, r in enumerate(results, start=1):
        color = LABEL_COLORS.get(r["label"], FALLBACK_COLOR)
        display = LABEL_DISPLAY.get(r["label"], r["label"])
        # Confidence 0..1 maps to opacity 0.5..1.0 so weak predictions look faded.
        opacity = 0.5 + r["confidence"] * 0.5
        cells.append(
            f'<div class="chunk-cell" style="position:relative;">'
            f'<div style="aspect-ratio:1; border-radius:6px; background:{color}; opacity:{opacity}; '
            f"cursor:pointer; transition:all 0.15s ease; border:2px solid transparent; "
            f"display:flex; align-items:center; justify-content:center; font-size:0.65rem; "
            f'color:white; font-weight:700; text-shadow:0 1px 2px rgba(0,0,0,0.3);" '
            f'class="chunk-square" '
            f'title="{display} ({r["confidence"]:.0%}) — {r["start"]:.1f}s-{r["end"]:.1f}s">'
            f"{number}"
            f"</div>"
            f"{_tooltip_html(r, color, display)}"
            f"</div>"
        )
    cols = min(len(results), 15)
    return (
        f'<div style="display:grid; grid-template-columns:repeat({cols}, minmax(0,1fr)); gap:6px;">'
        f'{"".join(cells)}'
        f"</div>"
    )


def _chunk_table_html(results):
    """Render the collapsible per-chunk table (index, time range, label, confidence)."""
    rows = []
    for number, r in enumerate(results, start=1):
        color = LABEL_COLORS.get(r["label"], FALLBACK_COLOR)
        display = LABEL_DISPLAY.get(r["label"], r["label"])
        conf_bar_width = r["confidence"] * 100
        rows.append(
            f"<tr>"
            f'<td style="padding:8px 12px; font-size:0.8rem; color:#6b7280; text-align:center;">{number}</td>'
            f'<td style="padding:8px 12px; font-size:0.8rem; color:#6b7280; font-family:monospace;">'
            f'{r["start"]:.1f}s — {r["end"]:.1f}s</td>'
            f'<td style="padding:8px 12px;">'
            f'<div style="display:flex; align-items:center; gap:6px;">'
            f'<div style="width:10px; height:10px; border-radius:3px; background:{color};"></div>'
            f'<span style="font-size:0.8rem; font-weight:600; color:#111827;">{display}</span>'
            f"</div></td>"
            f'<td style="padding:8px 12px;">'
            f'<div style="display:flex; align-items:center; gap:6px;">'
            f'<div style="width:60px; background:#e5e7eb; border-radius:3px; height:6px;">'
            f'<div style="width:{conf_bar_width}%; background:{color}; height:100%; border-radius:3px;"></div>'
            f"</div>"
            f'<span style="font-size:0.8rem; font-weight:600; color:#374151;">{r["confidence"]:.0%}</span>'
            f"</div></td>"
            f"</tr>"
        )
    header_style = (
        "font-size:0.75rem; color:#6b7280; font-weight:600; text-transform:uppercase;"
    )
    return (
        f'<details style="margin-top:20px;">'
        f'<summary style="cursor:pointer; font-weight:700; color:#7c3aed; font-size:0.9rem; padding:8px 0;">'
        f"📋 Detailed Chunk Table ({len(results)} chunks)</summary>"
        f'<div style="overflow-x:auto; margin-top:8px;">'
        f'<table style="width:100%; border-collapse:collapse; border:1px solid #e5e7eb; border-radius:8px; overflow:hidden;">'
        f'<thead><tr style="background:#f9fafb;">'
        f'<th style="padding:10px 12px; text-align:center; {header_style}">#</th>'
        f'<th style="padding:10px 12px; text-align:left; {header_style}">Time Range</th>'
        f'<th style="padding:10px 12px; text-align:left; {header_style}">Label</th>'
        f'<th style="padding:10px 12px; text-align:left; {header_style}">Confidence</th>'
        f"</tr></thead>"
        f'<tbody>{"".join(rows)}</tbody>'
        f"</table></div></details>"
    )


def generate_results_html(results, total_duration):
    """Build the complete results view as one HTML string.

    The view contains summary cards, a colour legend, a per-chunk heatmap
    with hover tooltips, a donut chart, a timeline and a collapsible table.

    Args:
        results: List of per-chunk dicts with ``start``, ``end``, ``label``,
            ``confidence`` and ``probs`` keys.
        total_duration: Length of the analysed audio in seconds.

    Returns:
        HTML markup; a short "No results" placeholder when ``results`` is empty.
    """
    if not results:
        return '<div style="text-align:center; color:#9ca3af; padding:40px;">No results</div>'

    # ── Summary statistics ── (results is non-empty here, so n >= 1)
    n = len(results)
    label_counts = {}
    for r in results:
        label_counts[r["label"]] = label_counts.get(r["label"], 0) + 1
    avg_confidence = sum(r["confidence"] for r in results) / n
    # Every chunk not labelled "fluent" counts as disfluent.
    stuttering_pct = (n - label_counts.get("fluent", 0)) / n * 100

    summary_html = _summary_cards_html(n, avg_confidence, stuttering_pct, total_duration)
    legend_html = _legend_html(label_counts)
    heatmap_html = _heatmap_html(results)
    donut_html = generate_donut_html(label_counts, n)
    timeline_html = generate_timeline_html(results, total_duration)
    table_html = _chunk_table_html(results)

    # ── Assemble ──
    return f"""
<div style="font-family: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif;">
{summary_html}
<div style="display:grid; grid-template-columns: 1fr 240px; gap:20px; margin-bottom:20px; align-items:start;">
<div>
<h3 style="font-size:1rem; font-weight:700; color:#111827; margin-bottom:12px;">
🧱 Classification Heatmap
</h3>
{legend_html}
{heatmap_html}
</div>
<div>
<h3 style="font-size:1rem; font-weight:700; color:#111827; margin-bottom:12px;">
📊 Distribution
</h3>
{donut_html}
</div>
</div>
<h3 style="font-size:1rem; font-weight:700; color:#111827; margin-bottom:12px;">
📈 Timeline
</h3>
{timeline_html}
{table_html}
</div>
"""
def generate_donut_html(label_counts, total):
    """Render the label distribution as an SVG donut chart with a legend below it.

    Args:
        label_counts: Mapping of label name to number of chunks.
        total: Total number of chunks; used as the denominator and shown in
            the centre of the donut.

    Returns:
        HTML markup, or an empty string when ``total`` is 0.
    """
    if total == 0:
        return ""

    size = 200
    cx, cy = size // 2, size // 2
    outer_r = 80
    inner_r = 50

    # Largest share first; ties keep the mapping's own order.
    sorted_labels = sorted(label_counts.items(), key=lambda item: -item[1])

    slices = []
    angle_offset = -90  # start from the top of the circle
    for lbl, count in sorted_labels:
        pct = count / total
        if pct <= 0:
            continue
        color = LABEL_COLORS.get(lbl, FALLBACK_COLOR)
        angle_span = pct * 360
        end_angle = angle_offset + angle_span

        if angle_span >= 360:
            # A single label owns the whole chart. An arc whose start and end
            # points coincide is omitted by SVG renderers, so draw the ring as
            # a stroked circle centred between the inner and outer radii.
            ring_r = (outer_r + inner_r) / 2
            slices.append(
                f'<circle cx="{cx}" cy="{cy}" r="{ring_r}" fill="none" '
                f'stroke="{color}" stroke-width="{outer_r - inner_r}" />'
            )
            angle_offset = end_angle
            continue

        sa = np.radians(angle_offset)
        ea = np.radians(end_angle)
        # Outer arc runs start -> end; inner arc runs back end -> start.
        x1_o = cx + outer_r * np.cos(sa)
        y1_o = cy + outer_r * np.sin(sa)
        x2_o = cx + outer_r * np.cos(ea)
        y2_o = cy + outer_r * np.sin(ea)
        x1_i = cx + inner_r * np.cos(ea)
        y1_i = cy + inner_r * np.sin(ea)
        x2_i = cx + inner_r * np.cos(sa)
        y2_i = cy + inner_r * np.sin(sa)
        large_arc = 1 if angle_span > 180 else 0
        path = (
            f"M {x1_o},{y1_o} "
            f"A {outer_r},{outer_r} 0 {large_arc} 1 {x2_o},{y2_o} "
            f"L {x1_i},{y1_i} "
            f"A {inner_r},{inner_r} 0 {large_arc} 0 {x2_i},{y2_i} Z"
        )
        slices.append(f'<path d="{path}" fill="{color}" stroke="white" stroke-width="2" />')
        angle_offset = end_angle

    # Chunk total in the hole of the donut.
    center_svg = (
        f'<text x="{cx}" y="{cy - 6}" text-anchor="middle" fill="#111827" '
        f'font-size="24" font-weight="800" font-family="Inter, sans-serif">{total}</text>'
        f'<text x="{cx}" y="{cy + 14}" text-anchor="middle" fill="#6b7280" '
        f'font-size="10" font-weight="500" font-family="Inter, sans-serif">chunks</text>'
    )
    svg = (
        f'<svg width="{size}" height="{size}" viewBox="0 0 {size} {size}" '
        f'style="display:block; margin:0 auto;">'
        f'{"".join(slices)}{center_svg}</svg>'
    )

    # Legend below the donut: swatch, display name, count and share.
    legend_items = "".join(
        f'<div style="display:flex; align-items:center; gap:6px; font-size:0.75rem;">'
        f'<div style="width:10px; height:10px; border-radius:3px; '
        f'background:{LABEL_COLORS.get(lbl, FALLBACK_COLOR)}; flex-shrink:0;"></div>'
        f'<span style="color:#374151; white-space:nowrap;">{LABEL_DISPLAY.get(lbl, lbl)}</span>'
        f'<span style="color:#9ca3af; margin-left:auto;">{count} ({count / total * 100:.0f}%)</span>'
        f"</div>"
        for lbl, count in sorted_labels
    )
    return (
        f'<div style="text-align:center;">{svg}'
        f'<div style="display:flex; flex-direction:column; gap:4px; margin-top:12px;">'
        f"{legend_items}</div></div>"
    )
def generate_timeline_html(results, total_duration):
    """Render a horizontal bar of chunk labels laid out along the time axis.

    Chunks may overlap. Each bar therefore covers only the time from its own
    start to the next chunk's start (the last bar runs to its own end), so the
    bars tile the axis instead of summing to more than 100% of the width. The
    hover title still reports each chunk's full window.

    Args:
        results: Per-chunk dicts with ``start``, ``end``, ``label`` and
            ``confidence`` keys, in chronological order.
        total_duration: Length of the audio in seconds.

    Returns:
        HTML markup, or an empty string when there is nothing to draw.
    """
    if not results or total_duration <= 0:
        return ""

    bars = []
    last_index = len(results) - 1
    for idx, r in enumerate(results):
        color = LABEL_COLORS.get(r["label"], FALLBACK_COLOR)
        display = LABEL_DISPLAY.get(r["label"], r["label"])
        # Clip the drawn span at the next chunk's start so overlaps are not
        # counted twice; non-overlapping chunks are unaffected.
        visible_end = r["end"]
        if idx < last_index:
            visible_end = min(visible_end, results[idx + 1]["start"])
        width_pct = max(0.0, visible_end - r["start"]) / total_duration * 100
        opacity = 0.5 + r["confidence"] * 0.5
        # Bars narrower than 5% of the track are too small for a text label.
        time_label = "" if width_pct < 5 else f'{r["start"]:.0f}s'
        bars.append(
            f'<div style="width:{width_pct}%; height:32px; background:{color}; opacity:{opacity}; '
            f"display:flex; align-items:center; justify-content:center; font-size:0.6rem; "
            f"color:white; font-weight:700; text-shadow:0 1px 2px rgba(0,0,0,0.3); overflow:hidden; "
            f'white-space:nowrap; min-width:2px;" '
            f'title="{display} ({r["confidence"]:.0%}) — {r["start"]:.1f}s-{r["end"]:.1f}s">'
            f"{time_label}"
            f"</div>"
        )

    # Axis ticks: roughly ten evenly spaced labels, at least one second apart.
    step = max(1, int(total_duration / 10))
    time_labels = "".join(
        f'<span style="position:absolute; left:{t / total_duration * 100}%; transform:translateX(-50%); '
        f'font-size:0.7rem; color:#9ca3af; font-family:monospace;">{t}s</span>'
        for t in range(0, int(total_duration) + 1, step)
    )

    return (
        f'<div style="margin-bottom:20px;">'
        f'<div style="display:flex; border-radius:8px; overflow:hidden; border:1px solid #e5e7eb;">'
        f'{"".join(bars)}</div>'
        f'<div style="position:relative; height:20px; margin-top:4px;">{time_labels}</div>'
        f"</div>"
    )
# ═══════════════════════════════════════════════════════════════
# MAIN HANDLER
# ═══════════════════════════════════════════════════════════════
def process_audio(audio_input, chunk_duration, overlap_pct):
    """Gradio handler: normalise the audio, classify it in chunks, return HTML.

    Args:
        audio_input: ``None``, a ``(sample_rate, samples)`` tuple, or a path
            to an audio file.
        chunk_duration: Analysis window length in seconds.
        overlap_pct: Overlap between consecutive windows, in percent.

    Returns:
        An HTML string: the results view on success, otherwise a short
        placeholder or error message. Exceptions are not propagated; they are
        printed to stderr and reported in the returned HTML.
    """
    if audio_input is None:
        return '<div style="text-align:center; color:#9ca3af; padding:40px; font-size:1.1rem;">Upload or record an audio file to get started.</div>'
    try:
        # Parse audio input
        if isinstance(audio_input, tuple):
            sr, audio_data = audio_input
            if audio_data is None or len(audio_data) == 0:
                return '<div style="color:#ef4444; padding:20px;">Empty audio data.</div>'
            audio_data = np.asarray(audio_data)
            # Convert to float32, scaling integer PCM of any width into [-1, 1).
            kind = audio_data.dtype.kind
            if kind in "iu":
                # Half the integer range: 32768 for int16, 2147483648 for int32, ...
                half_range = float(2 ** (8 * audio_data.dtype.itemsize - 1))
                audio_data = audio_data.astype(np.float32)
                if kind == "u":
                    # Unsigned PCM is centred on the middle of its range.
                    audio_data = audio_data - half_range
                audio_data = audio_data / half_range
            else:
                audio_data = audio_data.astype(np.float32)
            # Mono: average across axis 1 (assumes a samples x channels layout).
            if audio_data.ndim > 1:
                audio_data = audio_data.mean(axis=1)
        elif isinstance(audio_input, str):
            audio_data, sr = librosa.load(audio_input, sr=SAMPLE_RATE, mono=True)
        else:
            return '<div style="color:#ef4444; padding:20px;">Unsupported audio format.</div>'

        # Resample to 16 kHz
        if sr != SAMPLE_RATE:
            audio_data = librosa.resample(audio_data, orig_sr=sr, target_sr=SAMPLE_RATE)
            sr = SAMPLE_RATE

        total_duration = len(audio_data) / sr
        if total_duration < 0.5:
            return '<div style="color:#ef4444; padding:20px;">Audio too short. Please provide at least 0.5 seconds.</div>'

        # Run chunked classification
        results = chunked_classification(audio_data, sr, chunk_duration, overlap_pct)
        if not results:
            return '<div style="color:#ef4444; padding:20px;">No chunks could be classified.</div>'
        return generate_results_html(results, total_duration)
    except Exception as e:
        # Top-level UI boundary: log the traceback and show the message
        # instead of crashing the request.
        import html
        import traceback

        traceback.print_exc()
        # Escape the message so characters such as "<" or "&" in it cannot
        # break or inject markup into the page.
        return f'<div style="color:#ef4444; padding:20px;">Error: {html.escape(str(e))}</div>'
# ═══════════════════════════════════════════════════════════════
# GRADIO UI
# ═══════════════════════════════════════════════════════════════
# Stylesheet passed to gr.Blocks(css=...). It covers, in order:
#   * the Inter web font and the page container width/background,
#   * the gradient page title (.main-title) and .subtitle,
#   * the .info-box and the custom .footer (the built-in <footer> is hidden),
#   * the heatmap hover behaviour: .chunk-tooltip is hidden by default and
#     shown, on a dark #1f2937 background, while its .chunk-cell is hovered;
#     the hovered .chunk-square is scaled up and outlined.
CUSTOM_CSS = """
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700;800&display=swap');
* { font-family: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif; }
.gradio-container {
max-width: 900px !important;
margin: auto !important;
background: #ffffff !important;
}
.main-title {
text-align: center;
font-size: 2rem;
font-weight: 800;
background: linear-gradient(to right, #9333ea, #db2777);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
background-clip: text;
margin-bottom: 0.25rem;
padding-bottom: 0.25rem;
}
.subtitle {
text-align: center;
color: #6b7280;
font-size: 1rem;
margin-bottom: 1.5rem;
}
.info-box {
background: linear-gradient(135deg, #faf5ff, #f3e8ff);
border: 1px solid #e9d5ff;
border-radius: 12px;
padding: 16px 20px;
margin-bottom: 16px;
font-size: 0.85rem;
color: #374151;
line-height: 1.6;
}
.info-box a {
color: #7e22ce;
text-decoration: none;
font-weight: 600;
}
.info-box a:hover {
text-decoration: underline;
}
.footer {
text-align: center;
padding: 16px;
color: #9ca3af;
font-size: 0.8rem;
border-top: 1px solid #e5e7eb;
margin-top: 24px;
}
.footer a {
color: #9333ea;
text-decoration: none;
font-weight: 500;
}
.footer a:hover {
text-decoration: underline;
}
footer { display: none !important; }
/* Chunk tooltip */
.chunk-cell { position: relative; }
.chunk-tooltip {
display: none;
position: absolute;
bottom: calc(100% + 8px);
left: 50%;
transform: translateX(-50%);
background: #1f2937;
color: #f9fafb;
padding: 10px 14px;
border-radius: 10px;
font-size: 0.75rem;
z-index: 50;
min-width: 200px;
box-shadow: 0 10px 25px rgba(0,0,0,0.3);
pointer-events: none;
}
.chunk-tooltip::after {
content: '';
position: absolute;
top: 100%;
left: 50%;
transform: translateX(-50%);
border: 6px solid transparent;
border-top-color: #1f2937;
}
.chunk-cell:hover .chunk-tooltip { display: block; }
.chunk-cell:hover .chunk-square {
transform: scale(1.15);
border-color: white !important;
box-shadow: 0 4px 12px rgba(0,0,0,0.3);
}
"""
# Purple-accented Soft theme shared by the whole app.
_theme = gr.themes.Soft(
    primary_hue="purple",
    secondary_hue="purple",
    neutral_hue="gray",
)
# Build the page: header, info box, audio input + chunking sliders, the
# classify button, the HTML results area and a footer.
with gr.Blocks(css=CUSTOM_CSS, theme=_theme, title="Stuttering Classification — Vocametrix") as demo:
    # Header
    gr.HTML('<div class="main-title">Stuttering Classification</div>')
    gr.HTML(
        '<div class="subtitle">Upload speech audio and visualize stuttering patterns chunk by chunk</div>'
    )
    # Info box
    gr.HTML("""
<div class="info-box">
<strong>Model:</strong>
<a href="https://huggingface.co/vocametrix/wav2vec2-xlsr-53-stuttering-classification" target="_blank">
Wav2Vec2-XLSR-53
</a> fine-tuned end-to-end on SEP-28k-Extended (~28k 3-second clips) for 6-class stuttering detection.
<br>
<strong>Classes:</strong> Fluent · Block · Prolongation · Word Repetition · Sound Repetition · Interjection
<br>
<strong>Paper:</strong> <em>"SpeechTherapyAgent: A Clinician-in-the-Loop AI Virtual Speech Therapist"</em>
— Sheikh, Marmaroli et al. (under review)
</div>
""")
    # Inputs: audio on the left (wider column), chunking controls on the right.
    with gr.Row():
        with gr.Column(scale=2):
            # type="numpy" requests the audio as NumPy data rather than a file
            # path; process_audio() handles it in its tuple branch.
            audio_input = gr.Audio(
                sources=["upload", "microphone"],
                type="numpy",
                label="Speech Audio",
            )
        with gr.Column(scale=1):
            chunk_duration = gr.Slider(
                minimum=2,
                maximum=6,
                value=4,
                step=1,
                label="Chunk Duration (seconds)",
                info="Duration of each analysis window",
            )
            # Expressed in percent; chunked_classification() divides by 100.
            overlap_pct = gr.Slider(
                minimum=0,
                maximum=75,
                value=50,
                step=25,
                label="Chunk Overlap (%)",
                info="Overlap between consecutive chunks",
            )
    # Presumably placed below the input row, at the top level of the page —
    # TODO confirm the nesting against the original layout.
    classify_btn = gr.Button(
        "🧠 Classify Stuttering",
        variant="primary",
        size="lg",
    )
    # Results output
    results_output = gr.HTML(
        value='<div style="text-align:center; color:#9ca3af; padding:40px; font-size:1.1rem;">'
        "Upload or record an audio file, then click <strong>Classify Stuttering</strong>.</div>"
    )
    # Event handlers
    # The input order must match process_audio's positional parameters.
    classify_btn.click(
        fn=process_audio,
        inputs=[audio_input, chunk_duration, overlap_pct],
        outputs=[results_output],
    )
    # Footer
    gr.HTML("""
<div class="footer">
Built by <a href="https://www.vocametrix.com" target="_blank">Vocametrix</a>
&nbsp;·&nbsp;
<a href="https://huggingface.co/vocametrix/wav2vec2-xlsr-53-stuttering-classification" target="_blank">Model Card</a>
&nbsp;·&nbsp;
<a href="mailto:info@vocametrix.com">info@vocametrix.com</a>
</div>
""")
# Start the Gradio server only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()