| """ |
| Stuttering Classification Demo |
| Wav2Vec2-XLSR-53 Fine-Tuned for Stuttering Detection |
| Built by Vocametrix — vocametrix.com |
| """ |
|
|
| import gradio as gr |
|
|
| |
| import gradio_client.utils as _gc_utils |
|
|
# Keep a reference to gradio_client's own converter so the wrapper can delegate
# every non-boolean schema to it.
_orig_json_schema_fn = _gc_utils._json_schema_to_python_type


def _patched_json_schema_fn(schema, defs):
    """Convert a JSON schema to a Python type string, tolerating boolean schemas.

    JSON Schema permits the literals ``true``/``false`` as whole schemas; these
    are mapped to ``"Any"`` and ``"None"`` respectively. Any other schema is
    handed unchanged to the original gradio_client implementation.
    """
    if not isinstance(schema, bool):
        return _orig_json_schema_fn(schema, defs)
    return {True: "Any", False: "None"}[schema]


# Install the wrapper on the gradio_client module so its callers pick it up.
_gc_utils._json_schema_to_python_type = _patched_json_schema_fn
| |
|
|
| import numpy as np |
| import librosa |
| import torch |
| import torch.nn.functional as F |
| from transformers import AutoModelForAudioClassification, AutoFeatureExtractor |
|
|
| |
| |
| |
|
|
# Hugging Face Hub id of the fine-tuned classifier loaded at startup.
MODEL_NAME = "vocametrix/wav2vec2-xlsr-53-stuttering-classification"
# Target sample rate in Hz; process_audio resamples all input to this rate.
SAMPLE_RATE = 16000
# NOTE(review): the two defaults below are not referenced in this file's
# visible code (the UI sliders hard-code 4 s and 50 %); DEFAULT_OVERLAP is a
# fraction whereas the slider/overlap_pct use percent.
DEFAULT_CHUNK_DURATION = 4.0
DEFAULT_OVERLAP = 0.50

# Presentation metadata per model label, one row per class:
# (model label, swatch color, display name, description).
_LABEL_TABLE = (
    ("fluent", "#10B981", "Fluent", "Fluent speech — no disfluency detected"),
    ("block", "#EF4444", "Block", "Silent or audible block in speech flow"),
    ("prolongation", "#F59E0B", "Prolongation", 'Prolonged sounds (e.g., "ssssnake")'),
    (
        "Wordrepetition",
        "#8B5CF6",
        "Word Repetition",
        'Repetition of whole words (e.g., "I-I-I want")',
    ),
    (
        "Soundrepetition",
        "#EC4899",
        "Sound Repetition",
        'Repetition of individual sounds (e.g., "b-b-ball")',
    ),
    ("interjection", "#3B82F6", "Interjection", 'Filler words or sounds (e.g., "um", "uh")'),
)

LABEL_COLORS = {key: color for key, color, _name, _desc in _LABEL_TABLE}
LABEL_DISPLAY = {key: name for key, _color, name, _desc in _LABEL_TABLE}
LABEL_DESCRIPTIONS = {key: desc for key, _color, _name, desc in _LABEL_TABLE}

# Color used for any label missing from LABEL_COLORS.
FALLBACK_COLOR = "#9CA3AF"
|
|
| |
| |
| |
|
|
# Load the feature extractor and classifier once at import time; every request
# reuses these module-level objects.
print("Loading model from Hugging Face Hub...")
feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)
model = AutoModelForAudioClassification.from_pretrained(MODEL_NAME)

# Prefer the GPU when one is visible to torch.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# nn.Module.to() / .eval() act in place; switch to inference mode on the device.
model.to(device).eval()
print(f"Model loaded on {device}")
|
|
|
|
| |
| |
| |
|
|
|
|
def classify_chunk(audio_chunk, sr=SAMPLE_RATE):
    """Run the stuttering classifier on a single window of audio.

    Args:
        audio_chunk: Samples for one analysis window.
        sr: Sample rate of ``audio_chunk`` in Hz (defaults to SAMPLE_RATE).

    Returns:
        ``(label, confidence, all_probs)``: the arg-max label name, its softmax
        probability, and a dict mapping every label name to its probability.
    """
    features = feature_extractor(
        audio_chunk, sampling_rate=sr, return_tensors="pt", padding=True
    )
    batch = {name: tensor.to(device) for name, tensor in features.items()}

    with torch.no_grad():
        scores = F.softmax(model(**batch).logits, dim=-1)

    # A single window was passed in, so the batch has one row.
    id2label = model.config.id2label
    best = torch.argmax(scores, dim=-1).item()
    per_label = scores[0].tolist()

    all_probs = {id2label[idx]: float(p) for idx, p in enumerate(per_label)}
    return id2label[best], per_label[best], all_probs
|
|
|
|
def chunked_classification(audio_data, sr, chunk_duration, overlap_pct, *, classify_fn=None):
    """Split audio into overlapping fixed-length windows and classify each one.

    Args:
        audio_data: 1-D sample array (process_audio passes mono float32).
        sr: Sample rate of ``audio_data`` in Hz.
        chunk_duration: Window length in seconds.
        overlap_pct: Overlap between consecutive windows, in percent.
        classify_fn: Optional ``(chunk, sr) -> (label, confidence, probs)``
            callable; defaults to ``classify_chunk``.

    Returns:
        A list of dicts with keys ``start``/``end`` (seconds, ``end`` clipped
        to the audio length), ``label``, ``confidence`` and ``probs``, in
        ascending ``start`` order. Windowing stops at the first window that
        reaches the end of the audio.

    Raises:
        ValueError: if ``chunk_duration * sr`` is shorter than one sample
            (the window could never advance).
    """
    if classify_fn is None:
        classify_fn = classify_chunk

    chunk_samples = int(chunk_duration * sr)
    if chunk_samples < 1:
        raise ValueError(
            f"chunk_duration={chunk_duration!r} at sr={sr!r} gives an empty window"
        )

    step_samples = int(chunk_samples * (1 - overlap_pct / 100))
    if step_samples < 1:
        # >= 100 % overlap would never advance; fall back to back-to-back windows.
        step_samples = chunk_samples

    n_samples = len(audio_data)
    results = []
    pos = 0
    while pos < n_samples:
        end = min(pos + chunk_samples, n_samples)
        chunk = audio_data[pos:end]

        if len(chunk) < chunk_samples:
            # Zero-pad the trailing short window so every input has equal length.
            chunk = np.pad(chunk, (0, chunk_samples - len(chunk)))

        label, confidence, all_probs = classify_fn(chunk, sr)
        results.append(
            {
                "start": pos / sr,
                "end": end / sr,
                "label": label,
                "confidence": confidence,
                "probs": all_probs,
            }
        )

        if end >= n_samples:
            # This window already covers the end of the audio. Any later window
            # would lie entirely inside it and be mostly zero-padding, adding
            # redundant (and padding-dominated) chunks to the statistics.
            break
        pos += step_samples

    return results
|
|
|
|
| |
| |
| |
|
|
|
|
def _summary_cards_html(n, avg_confidence, stuttering_pct, total_duration):
    """Return the row of four headline stat cards (chunks, confidence, disfluent %, duration)."""
    return f"""
    <div style="display: grid; grid-template-columns: repeat(4, 1fr); gap: 12px; margin-bottom: 24px;">
        <div style="background: linear-gradient(135deg, #f0fdf4, #dcfce7); border: 1px solid #86efac; border-radius: 12px; padding: 16px; text-align: center;">
            <div style="font-size: 2rem; font-weight: 800; color: #16a34a;">{n}</div>
            <div style="font-size: 0.8rem; color: #4b5563; font-weight: 600;">Total Chunks</div>
        </div>
        <div style="background: linear-gradient(135deg, #faf5ff, #f3e8ff); border: 1px solid #c4b5fd; border-radius: 12px; padding: 16px; text-align: center;">
            <div style="font-size: 2rem; font-weight: 800; color: #7c3aed;">{avg_confidence:.0%}</div>
            <div style="font-size: 0.8rem; color: #4b5563; font-weight: 600;">Avg Confidence</div>
        </div>
        <div style="background: linear-gradient(135deg, #fef2f2, #fee2e2); border: 1px solid #fca5a5; border-radius: 12px; padding: 16px; text-align: center;">
            <div style="font-size: 2rem; font-weight: 800; color: #dc2626;">{stuttering_pct:.0f}%</div>
            <div style="font-size: 0.8rem; color: #4b5563; font-weight: 600;">Disfluent</div>
        </div>
        <div style="background: linear-gradient(135deg, #eff6ff, #dbeafe); border: 1px solid #93c5fd; border-radius: 12px; padding: 16px; text-align: center;">
            <div style="font-size: 2rem; font-weight: 800; color: #2563eb;">{total_duration:.1f}s</div>
            <div style="font-size: 0.8rem; color: #4b5563; font-weight: 600;">Duration</div>
        </div>
    </div>
    """


def _legend_html(label_counts):
    """Return the color legend: one swatch per predicted label with its chunk count."""
    items = []
    for lbl in sorted(label_counts):
        color = LABEL_COLORS.get(lbl, FALLBACK_COLOR)
        display = LABEL_DISPLAY.get(lbl, lbl)
        items.append(
            f'<div style="display:flex; align-items:center; gap:6px;">'
            f'<div style="width:14px; height:14px; border-radius:4px; background:{color};"></div>'
            f'<span style="font-size:0.8rem; font-weight:600; color:#374151;">{display}</span>'
            f'<span style="font-size:0.75rem; color:#6b7280;">({label_counts[lbl]})</span>'
            f"</div>"
        )
    return (
        f'<div style="display:flex; flex-wrap:wrap; gap:16px; margin-bottom:16px; '
        f'padding:12px 16px; background:#f9fafb; border-radius:10px; border:1px solid #e5e7eb;">'
        f'{"".join(items)}</div>'
    )


def _chunk_tooltip_html(r, color, display, desc):
    """Return the hover tooltip for one heatmap cell.

    The tooltip is rendered on the dark ``.chunk-tooltip`` background
    (#1f2937 in CUSTOM_CSS), so every text color here is a light shade;
    dark text would be unreadable on it.
    """
    top3 = sorted(r["probs"].items(), key=lambda x: -x[1])[:3]
    prob_rows = []
    for lbl_name, prob in top3:
        bar_color = LABEL_COLORS.get(lbl_name, FALLBACK_COLOR)
        bar_display = LABEL_DISPLAY.get(lbl_name, lbl_name)
        bar_width = max(2, prob * 100)  # keep tiny probabilities visible
        prob_rows.append(
            f'<div style="display:flex; align-items:center; gap:4px; font-size:0.7rem;">'
            f'<span style="width:80px; color:#d1d5db; text-align:right;">{bar_display}</span>'
            f'<div style="flex:1; background:#4b5563; border-radius:3px; height:8px; min-width:60px;">'
            f'<div style="width:{bar_width}%; background:{bar_color}; height:100%; border-radius:3px;"></div>'
            f"</div>"
            f'<span style="width:35px; color:#f9fafb; font-weight:600;">{prob:.0%}</span>'
            f"</div>"
        )
    return (
        f'<div class="chunk-tooltip">'
        f'<div style="display:flex; align-items:center; gap:6px; margin-bottom:6px;">'
        f'<div style="width:10px; height:10px; border-radius:3px; background:{color};"></div>'
        f'<span style="font-weight:700; color:#f9fafb;">{display}</span>'
        f'<span style="color:#d1d5db; font-size:0.75rem;">{r["confidence"]:.0%}</span>'
        f"</div>"
        f'<div style="font-size:0.7rem; color:#d1d5db; margin-bottom:6px;">'
        f'{r["start"]:.1f}s — {r["end"]:.1f}s</div>'
        f'<div style="font-size:0.7rem; color:#9ca3af; font-style:italic; margin-bottom:8px;">{desc}</div>'
        f'{"".join(prob_rows)}'
        f"</div>"
    )


def _heatmap_html(results):
    """Return the grid of numbered, label-colored chunk squares (max 15 per row)."""
    cols = min(len(results), 15)
    cells = []
    for i, r in enumerate(results, start=1):
        color = LABEL_COLORS.get(r["label"], FALLBACK_COLOR)
        display = LABEL_DISPLAY.get(r["label"], r["label"])
        desc = LABEL_DESCRIPTIONS.get(r["label"], "")
        # Fade low-confidence chunks: opacity runs from 0.5 (conf 0) to 1.0 (conf 1).
        opacity = 0.5 + r["confidence"] * 0.5
        cells.append(
            f'<div class="chunk-cell" style="position:relative;">'
            f'<div style="aspect-ratio:1; border-radius:6px; background:{color}; opacity:{opacity}; '
            f"cursor:pointer; transition:all 0.15s ease; border:2px solid transparent; "
            f'display:flex; align-items:center; justify-content:center; font-size:0.65rem; '
            f'color:white; font-weight:700; text-shadow:0 1px 2px rgba(0,0,0,0.3);" '
            f'class="chunk-square" '
            f'title="{display} ({r["confidence"]:.0%}) — {r["start"]:.1f}s-{r["end"]:.1f}s">'
            f"{i}"
            f"</div>"
            f"{_chunk_tooltip_html(r, color, display, desc)}"
            f"</div>"
        )
    return (
        f'<div style="display:grid; grid-template-columns:repeat({cols}, minmax(0,1fr)); gap:6px;">'
        f'{"".join(cells)}'
        f"</div>"
    )


def _chunk_table_html(results):
    """Return the collapsible per-chunk table (index, time range, label, confidence bar)."""
    rows = []
    for i, r in enumerate(results, start=1):
        color = LABEL_COLORS.get(r["label"], FALLBACK_COLOR)
        display = LABEL_DISPLAY.get(r["label"], r["label"])
        conf_bar_width = r["confidence"] * 100
        rows.append(
            f"<tr>"
            f'<td style="padding:8px 12px; font-size:0.8rem; color:#6b7280; text-align:center;">{i}</td>'
            f'<td style="padding:8px 12px; font-size:0.8rem; color:#6b7280; font-family:monospace;">'
            f'{r["start"]:.1f}s — {r["end"]:.1f}s</td>'
            f'<td style="padding:8px 12px;">'
            f'<div style="display:flex; align-items:center; gap:6px;">'
            f'<div style="width:10px; height:10px; border-radius:3px; background:{color};"></div>'
            f'<span style="font-size:0.8rem; font-weight:600; color:#111827;">{display}</span>'
            f"</div></td>"
            f'<td style="padding:8px 12px;">'
            f'<div style="display:flex; align-items:center; gap:6px;">'
            f'<div style="width:60px; background:#e5e7eb; border-radius:3px; height:6px;">'
            f'<div style="width:{conf_bar_width}%; background:{color}; height:100%; border-radius:3px;"></div>'
            f"</div>"
            f'<span style="font-size:0.8rem; font-weight:600; color:#374151;">{r["confidence"]:.0%}</span>'
            f"</div></td>"
            f"</tr>"
        )
    return (
        f'<details style="margin-top:20px;">'
        f'<summary style="cursor:pointer; font-weight:700; color:#7c3aed; font-size:0.9rem; padding:8px 0;">'
        f"📋 Detailed Chunk Table ({len(results)} chunks)</summary>"
        f'<div style="overflow-x:auto; margin-top:8px;">'
        f'<table style="width:100%; border-collapse:collapse; border:1px solid #e5e7eb; border-radius:8px; overflow:hidden;">'
        f'<thead><tr style="background:#f9fafb;">'
        f'<th style="padding:10px 12px; text-align:center; font-size:0.75rem; color:#6b7280; font-weight:600; text-transform:uppercase;">#</th>'
        f'<th style="padding:10px 12px; text-align:left; font-size:0.75rem; color:#6b7280; font-weight:600; text-transform:uppercase;">Time Range</th>'
        f'<th style="padding:10px 12px; text-align:left; font-size:0.75rem; color:#6b7280; font-weight:600; text-transform:uppercase;">Label</th>'
        f'<th style="padding:10px 12px; text-align:left; font-size:0.75rem; color:#6b7280; font-weight:600; text-transform:uppercase;">Confidence</th>'
        f"</tr></thead>"
        f'<tbody>{"".join(rows)}</tbody>'
        f"</table></div></details>"
    )


def generate_results_html(results, total_duration):
    """Render the complete results panel for a list of chunk classifications.

    Args:
        results: Chunk dicts as built by ``chunked_classification`` (keys
            ``start``, ``end``, ``label``, ``confidence``, ``probs``).
        total_duration: Length of the analysed audio in seconds.

    Returns:
        One HTML string containing summary cards, legend + heatmap, donut
        chart, timeline and a collapsible detail table; a "No results"
        placeholder when ``results`` is empty.
    """
    if not results:
        return '<div style="text-align:center; color:#9ca3af; padding:40px;">No results</div>'

    label_counts = {}
    for r in results:
        label_counts[r["label"]] = label_counts.get(r["label"], 0) + 1

    n = len(results)
    avg_confidence = sum(r["confidence"] for r in results) / n
    # Every chunk whose label is not "fluent" counts as disfluent.
    stuttering_pct = (n - label_counts.get("fluent", 0)) / n * 100

    summary_html = _summary_cards_html(n, avg_confidence, stuttering_pct, total_duration)
    legend_html = _legend_html(label_counts)
    heatmap_html = _heatmap_html(results)
    donut_html = generate_donut_html(label_counts, n)
    timeline_html = generate_timeline_html(results, total_duration)
    table_html = _chunk_table_html(results)

    return f"""
    <div style="font-family: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif;">
        {summary_html}

        <div style="display:grid; grid-template-columns: 1fr 240px; gap:20px; margin-bottom:20px; align-items:start;">
            <div>
                <h3 style="font-size:1rem; font-weight:700; color:#111827; margin-bottom:12px;">
                    🧱 Classification Heatmap
                </h3>
                {legend_html}
                {heatmap_html}
            </div>
            <div>
                <h3 style="font-size:1rem; font-weight:700; color:#111827; margin-bottom:12px;">
                    📊 Distribution
                </h3>
                {donut_html}
            </div>
        </div>

        <h3 style="font-size:1rem; font-weight:700; color:#111827; margin-bottom:12px;">
            📈 Timeline
        </h3>
        {timeline_html}

        {table_html}
    </div>
    """
|
|
|
|
def generate_donut_html(label_counts, total):
    """Render the label distribution as an SVG donut chart plus a text legend.

    Args:
        label_counts: Mapping of label -> number of chunks with that label.
        total: Total number of chunks (the denominator for every slice).

    Returns:
        HTML with a 200x200 SVG donut (slices in descending count order,
        starting at 12 o'clock and running clockwise, chunk total in the
        centre) followed by a legend with counts and percentages; ``""`` when
        ``total`` is 0.
    """
    size = 200
    cx, cy = size // 2, size // 2
    outer_r = 80
    inner_r = 50

    if total == 0:
        return ""

    def _point(radius, angle_rad):
        """SVG (x, y) of the point at ``angle_rad`` on a circle of ``radius`` around the centre."""
        return cx + radius * np.cos(angle_rad), cy + radius * np.sin(angle_rad)

    sorted_labels = sorted(label_counts.items(), key=lambda x: -x[1])

    slices = []
    angle_offset = -90.0  # SVG y grows downward, so -90 degrees is 12 o'clock
    for lbl, count in sorted_labels:
        pct = count / total
        if pct <= 0:
            continue

        color = LABEL_COLORS.get(lbl, FALLBACK_COLOR)
        angle_span = pct * 360

        if angle_span >= 360:
            # A 360-degree arc starts and ends on the same point, and SVG omits
            # arcs with coincident endpoints -- a single label covering every
            # chunk would render as an empty chart. Draw a full ring instead.
            ring_r = (outer_r + inner_r) / 2
            slices.append(
                f'<circle cx="{cx}" cy="{cy}" r="{ring_r}" fill="none" '
                f'stroke="{color}" stroke-width="{outer_r - inner_r}" />'
            )
            angle_offset += angle_span
            continue

        sa = np.radians(angle_offset)
        ea = np.radians(angle_offset + angle_span)

        x1_o, y1_o = _point(outer_r, sa)
        x2_o, y2_o = _point(outer_r, ea)
        x1_i, y1_i = _point(inner_r, ea)
        x2_i, y2_i = _point(inner_r, sa)

        large_arc = 1 if angle_span > 180 else 0

        # Outer arc clockwise, straight line inward, inner arc back
        # counter-clockwise, then close the ring segment.
        path = (
            f"M {x1_o:.3f},{y1_o:.3f} "
            f"A {outer_r},{outer_r} 0 {large_arc} 1 {x2_o:.3f},{y2_o:.3f} "
            f"L {x1_i:.3f},{y1_i:.3f} "
            f"A {inner_r},{inner_r} 0 {large_arc} 0 {x2_i:.3f},{y2_i:.3f} Z"
        )
        slices.append(f'<path d="{path}" fill="{color}" stroke="white" stroke-width="2" />')
        angle_offset += angle_span

    center_svg = (
        f'<text x="{cx}" y="{cy - 6}" text-anchor="middle" fill="#111827" '
        f'font-size="24" font-weight="800" font-family="Inter, sans-serif">{total}</text>'
        f'<text x="{cx}" y="{cy + 14}" text-anchor="middle" fill="#6b7280" '
        f'font-size="10" font-weight="500" font-family="Inter, sans-serif">chunks</text>'
    )

    svg = (
        f'<svg width="{size}" height="{size}" viewBox="0 0 {size} {size}" '
        f'style="display:block; margin:0 auto;">'
        f'{"".join(slices)}{center_svg}</svg>'
    )

    legend_items = []
    for lbl, count in sorted_labels:
        color = LABEL_COLORS.get(lbl, FALLBACK_COLOR)
        display = LABEL_DISPLAY.get(lbl, lbl)
        pct = count / total * 100
        legend_items.append(
            f'<div style="display:flex; align-items:center; gap:6px; font-size:0.75rem;">'
            f'<div style="width:10px; height:10px; border-radius:3px; background:{color}; flex-shrink:0;"></div>'
            f'<span style="color:#374151; white-space:nowrap;">{display}</span>'
            f'<span style="color:#9ca3af; margin-left:auto;">{count} ({pct:.0f}%)</span>'
            f"</div>"
        )

    return (
        f'<div style="text-align:center;">{svg}'
        f'<div style="display:flex; flex-direction:column; gap:4px; margin-top:12px;">'
        f'{"".join(legend_items)}</div></div>'
    )
|
|
|
|
def generate_timeline_html(results, total_duration):
    """Render a horizontal timeline of chunk labels with a seconds axis beneath.

    Chunks may overlap (see the overlap slider), so each bar covers only the
    span its chunk "owns": from its own start to the next chunk's start (or
    its own end, if that comes first); the last chunk runs to its end. Bars
    are absolutely positioned by start time, using the same
    ``time / total_duration`` scale as the axis ticks, so they stay aligned
    with the axis even when windows overlap or leave gaps.

    Args:
        results: Chunk dicts in ascending ``start`` order (as produced by
            ``chunked_classification``).
        total_duration: Length of the audio in seconds.

    Returns:
        An HTML string, or ``""`` for no results / non-positive duration.
    """
    if not results or total_duration <= 0:
        return ""

    next_starts = [r["start"] for r in results[1:]] + [None]

    bars = []
    for r, next_start in zip(results, next_starts):
        seg_start = r["start"]
        seg_end = r["end"] if next_start is None else min(r["end"], next_start)
        left_pct = seg_start / total_duration * 100
        width_pct = max(0.0, seg_end - seg_start) / total_duration * 100

        color = LABEL_COLORS.get(r["label"], FALLBACK_COLOR)
        display = LABEL_DISPLAY.get(r["label"], r["label"])
        opacity = 0.5 + r["confidence"] * 0.5
        # Only label bars wide enough to fit the text.
        time_label = "" if width_pct < 5 else f"{seg_start:.0f}s"

        bars.append(
            f'<div style="position:absolute; top:0; left:{left_pct}%; width:{width_pct}%; height:100%; '
            f'background:{color}; opacity:{opacity}; '
            f'display:flex; align-items:center; justify-content:center; font-size:0.6rem; '
            f'color:white; font-weight:700; text-shadow:0 1px 2px rgba(0,0,0,0.3); overflow:hidden; '
            f'white-space:nowrap; min-width:2px;" '
            f'title="{display} ({r["confidence"]:.0%}) — {r["start"]:.1f}s-{r["end"]:.1f}s">'
            f"{time_label}"
            f"</div>"
        )

    # Roughly ten whole-second ticks across the clip, at least 1 s apart.
    tick_step = max(1, int(total_duration / 10))
    ticks = [
        f'<span style="position:absolute; left:{t / total_duration * 100}%; transform:translateX(-50%); '
        f'font-size:0.7rem; color:#9ca3af; font-family:monospace;">{t}s</span>'
        for t in range(0, int(total_duration) + 1, tick_step)
    ]

    return (
        f'<div style="margin-bottom:20px;">'
        f'<div style="position:relative; height:32px; border-radius:8px; overflow:hidden; border:1px solid #e5e7eb;">'
        f'{"".join(bars)}</div>'
        f'<div style="position:relative; height:20px; margin-top:4px;">{"".join(ticks)}</div>'
        f"</div>"
    )
|
|
|
|
| |
| |
| |
|
|
|
|
def process_audio(audio_input, chunk_duration, overlap_pct):
    """Gradio click handler: classify an audio clip and return the results HTML.

    Args:
        audio_input: ``None``, a ``(sample_rate, samples)`` tuple, or a path
            string that librosa can load.
        chunk_duration: Analysis window length in seconds.
        overlap_pct: Overlap between consecutive windows, in percent.

    Returns:
        The results panel HTML, or a short HTML message for missing, empty,
        unsupported or too-short input, or for any exception raised while
        processing (the traceback is printed, the message is HTML-escaped).
    """
    if audio_input is None:
        return '<div style="text-align:center; color:#9ca3af; padding:40px; font-size:1.1rem;">Upload or record an audio file to get started.</div>'

    try:
        if isinstance(audio_input, tuple):
            sr, audio_data = audio_input
            if audio_data is None or len(audio_data) == 0:
                return '<div style="color:#ef4444; padding:20px;">Empty audio data.</div>'

            # Integer PCM -> float32 in [-1, 1).
            if np.issubdtype(audio_data.dtype, np.signedinteger):
                # Full scale is 2**(bits-1): 32768 for int16, 2**31 for int32.
                full_scale = float(-np.iinfo(audio_data.dtype).min)
                audio_data = audio_data.astype(np.float32) / full_scale
            elif np.issubdtype(audio_data.dtype, np.unsignedinteger):
                # Unsigned PCM (e.g. 8-bit WAV) is centred on the mid code.
                midpoint = float(np.iinfo(audio_data.dtype).max + 1) / 2
                audio_data = (audio_data.astype(np.float32) - midpoint) / midpoint
            else:
                audio_data = audio_data.astype(np.float32)

            # Axis 1 is treated as channels; average them down to mono.
            if audio_data.ndim > 1:
                audio_data = audio_data.mean(axis=1)
        elif isinstance(audio_input, str):
            audio_data, sr = librosa.load(audio_input, sr=SAMPLE_RATE, mono=True)
        else:
            return '<div style="color:#ef4444; padding:20px;">Unsupported audio format.</div>'

        if sr != SAMPLE_RATE:
            audio_data = librosa.resample(audio_data, orig_sr=sr, target_sr=SAMPLE_RATE)
            sr = SAMPLE_RATE

        total_duration = len(audio_data) / sr
        if total_duration < 0.5:
            return '<div style="color:#ef4444; padding:20px;">Audio too short. Please provide at least 0.5 seconds.</div>'

        results = chunked_classification(audio_data, sr, chunk_duration, overlap_pct)
        if not results:
            return '<div style="color:#ef4444; padding:20px;">No chunks could be classified.</div>'

        return generate_results_html(results, total_duration)

    except Exception as e:
        # Top-level UI boundary: log the full traceback, show a short message.
        import html
        import traceback

        traceback.print_exc()
        # Exception text can echo input (paths, values); escape it so it is
        # displayed as text rather than interpreted as markup.
        return f'<div style="color:#ef4444; padding:20px;">Error: {html.escape(str(e))}</div>'
|
|
|
|
| |
| |
| |
|
|
# Stylesheet passed to gr.Blocks(css=...): loads the Inter font, constrains the
# page width, and defines the classes used by the HTML snippets in this file
# (.main-title, .subtitle, .info-box, .footer) plus the hover tooltip shown on
# heatmap cells (.chunk-cell / .chunk-tooltip / .chunk-square). The bare
# `footer` rule hides the page's default <footer> element.
CUSTOM_CSS = """
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700;800&display=swap');

* { font-family: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif; }

.gradio-container {
max-width: 900px !important;
margin: auto !important;
background: #ffffff !important;
}

.main-title {
text-align: center;
font-size: 2rem;
font-weight: 800;
background: linear-gradient(to right, #9333ea, #db2777);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
background-clip: text;
margin-bottom: 0.25rem;
padding-bottom: 0.25rem;
}

.subtitle {
text-align: center;
color: #6b7280;
font-size: 1rem;
margin-bottom: 1.5rem;
}

.info-box {
background: linear-gradient(135deg, #faf5ff, #f3e8ff);
border: 1px solid #e9d5ff;
border-radius: 12px;
padding: 16px 20px;
margin-bottom: 16px;
font-size: 0.85rem;
color: #374151;
line-height: 1.6;
}

.info-box a {
color: #7e22ce;
text-decoration: none;
font-weight: 600;
}

.info-box a:hover {
text-decoration: underline;
}

.footer {
text-align: center;
padding: 16px;
color: #9ca3af;
font-size: 0.8rem;
border-top: 1px solid #e5e7eb;
margin-top: 24px;
}

.footer a {
color: #9333ea;
text-decoration: none;
font-weight: 500;
}

.footer a:hover {
text-decoration: underline;
}

footer { display: none !important; }

/* Chunk tooltip */
.chunk-cell { position: relative; }
.chunk-tooltip {
display: none;
position: absolute;
bottom: calc(100% + 8px);
left: 50%;
transform: translateX(-50%);
background: #1f2937;
color: #f9fafb;
padding: 10px 14px;
border-radius: 10px;
font-size: 0.75rem;
z-index: 50;
min-width: 200px;
box-shadow: 0 10px 25px rgba(0,0,0,0.3);
pointer-events: none;
}
.chunk-tooltip::after {
content: '';
position: absolute;
top: 100%;
left: 50%;
transform: translateX(-50%);
border: 6px solid transparent;
border-top-color: #1f2937;
}
.chunk-cell:hover .chunk-tooltip { display: block; }
.chunk-cell:hover .chunk-square {
transform: scale(1.15);
border-color: white !important;
box-shadow: 0 4px 12px rgba(0,0,0,0.3);
}
"""
|
|
# Gradio "Soft" theme with purple primary/secondary hues and gray neutrals;
# passed to gr.Blocks(theme=...).
_theme = gr.themes.Soft(
    neutral_hue="gray",
    secondary_hue="purple",
    primary_hue="purple",
)
|
|
# Build the Gradio page. Components appear in the order they are created inside
# this context, so statement order here is the page layout.
with gr.Blocks(css=CUSTOM_CSS, theme=_theme, title="Stuttering Classification — Vocametrix") as demo:

    # Header: title and subtitle (classes styled in CUSTOM_CSS).
    gr.HTML('<div class="main-title">Stuttering Classification</div>')
    gr.HTML(
        '<div class="subtitle">Upload speech audio and visualize stuttering patterns chunk by chunk</div>'
    )

    # Static model / classes / paper summary.
    gr.HTML("""
<div class="info-box">
<strong>Model:</strong>
<a href="https://huggingface.co/vocametrix/wav2vec2-xlsr-53-stuttering-classification" target="_blank">
Wav2Vec2-XLSR-53
</a> fine-tuned end-to-end on SEP-28k-Extended (~28k 3-second clips) for 6-class stuttering detection.
<br>
<strong>Classes:</strong> Fluent · Block · Prolongation · Word Repetition · Sound Repetition · Interjection
<br>
<strong>Paper:</strong> <em>"SpeechTherapyAgent: A Clinician-in-the-Loop AI Virtual Speech Therapist"</em>
— Sheikh, Marmaroli et al. (under review)
</div>
""")

    # Input row: audio on the left (2/3 width), analysis controls on the right.
    with gr.Row():
        with gr.Column(scale=2):
            # process_audio unpacks the numpy-mode value as (sr, samples).
            audio_input = gr.Audio(
                sources=["upload", "microphone"],
                type="numpy",
                label="Speech Audio",
            )
        with gr.Column(scale=1):
            # Window length in seconds -> process_audio's chunk_duration.
            chunk_duration = gr.Slider(
                minimum=2,
                maximum=6,
                value=4,
                step=1,
                label="Chunk Duration (seconds)",
                info="Duration of each analysis window",
            )
            # Overlap in percent (0-75, steps of 25) -> process_audio's overlap_pct.
            overlap_pct = gr.Slider(
                minimum=0,
                maximum=75,
                value=50,
                step=25,
                label="Chunk Overlap (%)",
                info="Overlap between consecutive chunks",
            )
            classify_btn = gr.Button(
                "🧠 Classify Stuttering",
                variant="primary",
                size="lg",
            )

    # Results panel: placeholder text until process_audio's HTML replaces it.
    results_output = gr.HTML(
        value='<div style="text-align:center; color:#9ca3af; padding:40px; font-size:1.1rem;">'
        "Upload or record an audio file, then click <strong>Classify Stuttering</strong>.</div>"
    )

    # Button wiring: (audio, chunk_duration, overlap_pct) -> results HTML.
    classify_btn.click(
        fn=process_audio,
        inputs=[audio_input, chunk_duration, overlap_pct],
        outputs=[results_output],
    )

    # Footer links.
    gr.HTML("""
<div class="footer">
Built by <a href="https://www.vocametrix.com" target="_blank">Vocametrix</a>
·
<a href="https://huggingface.co/vocametrix/wav2vec2-xlsr-53-stuttering-classification" target="_blank">Model Card</a>
·
<a href="mailto:info@vocametrix.com">info@vocametrix.com</a>
</div>
""")
|
|
|
|
# Launch the Gradio server only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    demo.launch()
|
|