|
|
""" |
|
|
Data Collection Tool for Speech Pathology Annotation |
|
|
|
|
|
This module provides a Gradio-based interface for collecting and annotating |
|
|
phoneme-level speech pathology data. Clinicians can record or upload audio, |
|
|
then annotate errors at the phoneme level with timestamps. |
|
|
|
|
|
Usage: |
|
|
python scripts/data_collection.py |
|
|
""" |
|
|
|
|
|
import logging |
|
|
import os |
|
|
import json |
|
|
import time |
|
|
import tempfile |
|
|
from pathlib import Path |
|
|
from typing import Optional, List, Dict, Any, Tuple |
|
|
from datetime import datetime |
|
|
import numpy as np |
|
|
|
|
|
import gradio as gr |
|
|
import librosa |
|
|
import soundfile as sf |
|
|
|
|
|
from models.phoneme_mapper import PhonemeMapper |
|
|
from models.error_taxonomy import ErrorType, SeverityLevel |
|
|
|
|
|
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Storage locations and audio framing parameters shared by the helpers below.
DATA_DIR = Path("data/raw")
ANNOTATIONS_FILE = Path("data/annotations.json")
SAMPLE_RATE = 16000  # Hz; save_audio_file resamples recordings to this rate
FRAME_DURATION_MS = 20  # length of one annotation frame, in milliseconds

# Make sure the output locations exist before anything tries to write there.
DATA_DIR.mkdir(parents=True, exist_ok=True)
ANNOTATIONS_FILE.parent.mkdir(parents=True, exist_ok=True)


def load_annotations(path: Path) -> List[Dict[str, Any]]:
    """Read the annotation database stored at *path*.

    Loading is best-effort: a missing, unreadable, or malformed file yields an
    empty list so the tool can still start.

    Args:
        path: Location of the JSON annotation file.

    Returns:
        The list stored in the file, or ``[]`` when the file does not exist,
        cannot be read or parsed, or does not contain a JSON list.
    """
    if not path.exists():
        return []

    try:
        with open(path, 'r', encoding='utf-8') as f:
            loaded = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
        logger.warning(f"⚠️ Could not load annotations: {e}")
        return []

    # The rest of the module appends to and iterates over this value, so any
    # other JSON root (object, string, number) has to be rejected here.
    if not isinstance(loaded, list):
        logger.warning(
            f"⚠️ Could not load annotations: expected a JSON list, "
            f"got {type(loaded).__name__}"
        )
        return []

    logger.info(f"✅ Loaded {len(loaded)} existing annotations")
    return loaded


# In-memory copy of the annotation database; save_annotation appends to it.
annotations_db: List[Dict[str, Any]] = load_annotations(ANNOTATIONS_FILE)
|
|
|
|
|
|
|
|
def save_audio_file(audio_data: Optional[Tuple[int, np.ndarray]], filename: str) -> Optional[str]:
    """Resample, peak-normalize, and write a recording into ``DATA_DIR``.

    Args:
        audio_data: ``(sample_rate, samples)`` pair, or ``None`` when no audio
            was provided. ``samples`` is assumed to be 1-D, or 2-D laid out as
            ``(frames, channels)`` — the layout ``sf.write`` accepts; confirm
            against the Gradio version in use.
        filename: File name (relative to ``DATA_DIR``) to write to.

    Returns:
        The path of the written file as a string, or ``None`` when there is
        nothing to save (no audio, or an empty sample array).
    """
    if audio_data is None:
        return None

    sample_rate, audio_array = audio_data

    # Work in float32 from the start: taking np.abs of raw integer PCM
    # overflows at the minimum value (e.g. int16 -32768).
    samples = np.asarray(audio_array, dtype=np.float32)
    if samples.size == 0:
        # np.max would raise on an empty array; treat it like missing audio.
        return None

    if sample_rate != SAMPLE_RATE:
        # axis=0 resamples along time for both 1-D and (frames, channels)
        # input; the default axis (-1) would resample across channels.
        samples = librosa.resample(
            samples,
            orig_sr=sample_rate,
            target_sr=SAMPLE_RATE,
            axis=0
        )
        sample_rate = SAMPLE_RATE

    # Peak-normalize to [-1, 1]; silence is left untouched to avoid 0/0.
    peak = np.max(np.abs(samples))
    if peak > 0:
        samples = samples / peak

    output_path = DATA_DIR / filename
    sf.write(str(output_path), samples, sample_rate)
    logger.info(f"✅ Saved audio to {output_path}")

    return str(output_path)
|
|
|
|
|
|
|
|
def get_phoneme_list(text: str) -> List[str]:
    """Run grapheme-to-phoneme conversion on *text*.

    A ``PhonemeMapper`` is built with the module's frame duration and sample
    rate, and its ``g2p.convert`` result is returned with blank
    (whitespace-only) entries removed.

    Returns:
        The non-blank phonemes, or an empty list when conversion produced
        nothing or raised; failures are logged rather than propagated.
    """
    try:
        converter = PhonemeMapper(
            frame_duration_ms=FRAME_DURATION_MS,
            sample_rate=SAMPLE_RATE
        ).g2p
        raw_symbols = converter.convert(text)
        if not raw_symbols:
            return []
        return [symbol for symbol in raw_symbols if symbol.strip()]
    except Exception as exc:
        logger.error(f"❌ G2P conversion failed: {exc}")
        return []
|
|
|
|
|
|
|
|
def calculate_frame_count(audio_path: str) -> int:
    """Return how many ``FRAME_DURATION_MS`` frames fit in an audio file.

    Returns:
        The whole number of frames covered by the file's duration, never less
        than 1 for a readable file; 0 when the duration could not be
        determined (the failure is logged).
    """
    try:
        seconds = librosa.get_duration(path=audio_path)
        frame_total = int((seconds * 1000) / FRAME_DURATION_MS)
        if frame_total < 1:
            # Even a very short clip counts as one frame.
            return 1
        return frame_total
    except Exception as exc:
        logger.error(f"❌ Could not calculate frames: {exc}")
        return 0
|
|
|
|
|
|
|
|
def save_annotation(
    audio_path: str,
    expected_text: str,
    phoneme_errors: List[Dict[str, Any]],
    annotator_name: str,
    notes: str
) -> Dict[str, Any]:
    """Build an annotation record and persist it to ``ANNOTATIONS_FILE``.

    The full database is serialized first, written to a temporary file next
    to ``ANNOTATIONS_FILE``, and then moved over it. The in-memory
    ``annotations_db`` is only extended once that succeeded, so a failed save
    leaves both the file on disk and the in-memory list as they were.

    Args:
        audio_path: Path of the saved audio file; its duration is read from it.
        expected_text: Transcript the speaker was expected to produce.
        phoneme_errors: Error records; each one's ``'error_type'`` value feeds
            the per-type counts.
        annotator_name: Name stored with the record.
        notes: Free-text notes stored with the record.

    Returns:
        On success, a dict with ``'status': 'success'``, ``'annotation_id'``,
        ``'total_errors'`` and ``'message'``. On any failure, a dict with
        ``'status': 'error'`` and ``'message'``; nothing is raised.
    """
    try:
        duration = librosa.get_duration(path=audio_path)

        # Tally the tracked categories in a single pass. Any other value
        # (e.g. 'normal') is not broken out but still counts in total_errors.
        error_types = {'substitution': 0, 'omission': 0, 'distortion': 0, 'stutter': 0}
        for entry in phoneme_errors:
            kind = entry.get('error_type')
            if isinstance(kind, str) and kind in error_types:
                error_types[kind] += 1

        annotation = {
            'id': f"annot_{int(time.time())}",
            'audio_file': audio_path,
            'expected_text': expected_text,
            'duration': float(duration),
            'annotator': annotator_name,
            'notes': notes,
            'created_at': datetime.utcnow().isoformat() + "Z",
            'phoneme_errors': phoneme_errors,
            'total_errors': len(phoneme_errors),
            'error_types': error_types,
        }

        # Serialize before touching the file: opening it with 'w' first would
        # truncate the existing database even if serialization then failed.
        pending_db = annotations_db + [annotation]
        serialized = json.dumps(pending_db, indent=2, ensure_ascii=False)

        tmp_path = ANNOTATIONS_FILE.with_name(ANNOTATIONS_FILE.name + ".tmp")
        try:
            with open(tmp_path, 'w', encoding='utf-8') as f:
                f.write(serialized)
            # Swap the finished file into place in one step.
            os.replace(tmp_path, ANNOTATIONS_FILE)
        except OSError:
            # Don't leave a half-written temporary file behind.
            if tmp_path.exists():
                tmp_path.unlink()
            raise

        # Only now is the record durable; mirror it in memory.
        annotations_db.append(annotation)

        logger.info(f"✅ Saved annotation {annotation['id']} with {len(phoneme_errors)} errors")

        return {
            'status': 'success',
            'annotation_id': annotation['id'],
            'total_errors': len(phoneme_errors),
            'message': f"✅ Annotation saved! Total annotations: {len(annotations_db)}"
        }
    except Exception as e:
        logger.error(f"❌ Failed to save annotation: {e}", exc_info=True)
        return {
            'status': 'error',
            'message': f"❌ Failed to save: {str(e)}"
        }
|
|
|
|
|
|
|
|
def create_annotation_interface():
    """Build the Gradio Blocks UI for recording audio and annotating errors.

    The layout has an audio/transcript column, an annotation column, and a
    statistics row. Three buttons are wired up:

    * "Extract Phonemes" shows the G2P phonemes for the expected text.
    * "Add Error" appends one error record to the per-session list.
    * "Save Annotation" stores the audio and the record, then refreshes the
      statistics. After a successful save the pending error list is cleared
      so it cannot leak into the next sample.

    Returns:
        The configured ``gr.Blocks`` interface (not yet launched).
    """
    with gr.Blocks(title="Speech Pathology Data Collection", theme=gr.themes.Soft()) as interface:
        gr.Markdown(
            "\n"
            "# 🎤 Speech Pathology Data Collection Tool\n"
            "\n"
            "**Purpose:** Collect and annotate phoneme-level speech pathology data for training.\n"
            "\n"
            "**Instructions:**\n"
            "1. Upload or record audio (5-30 seconds, 16kHz WAV)\n"
            "2. Enter expected text/transcript\n"
            "3. Review phoneme list\n"
            "4. Annotate errors at phoneme level\n"
            "5. Save annotation\n"
        )

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### 📥 Audio Input")

                audio_input = gr.Audio(
                    type="numpy",
                    label="Record or Upload Audio",
                    sources=["microphone", "upload"],
                    format="wav"
                )

                expected_text = gr.Textbox(
                    label="Expected Text/Transcript",
                    placeholder="Enter the expected text that should be spoken",
                    lines=3
                )

                phoneme_display = gr.Textbox(
                    label="Phonemes (G2P)",
                    lines=5,
                    interactive=False,
                    info="Phonemes extracted from expected text"
                )

                btn_get_phonemes = gr.Button("🔍 Extract Phonemes", variant="secondary")

            with gr.Column(scale=1):
                gr.Markdown("### ✏️ Annotation")

                annotator_name = gr.Textbox(
                    label="Annotator Name",
                    placeholder="Your name",
                    value="clinician"
                )

                error_frame_id = gr.Number(
                    label="Frame ID (0-based)",
                    value=0,
                    precision=0,
                    info="Frame number where error occurs"
                )

                error_phoneme = gr.Textbox(
                    label="Phoneme with Error",
                    placeholder="/r/",
                    info="The phoneme that has an error"
                )

                error_type = gr.Dropdown(
                    label="Error Type",
                    choices=["normal", "substitution", "omission", "distortion", "stutter"],
                    value="normal",
                    info="Type of error detected"
                )

                wrong_sound = gr.Textbox(
                    label="Wrong Sound (if substitution)",
                    placeholder="/w/",
                    info="What sound was produced instead (for substitutions)"
                )

                error_severity = gr.Slider(
                    label="Severity (0-1)",
                    minimum=0.0,
                    maximum=1.0,
                    value=0.5,
                    step=0.1,
                    info="Severity of the error"
                )

                error_timestamp = gr.Number(
                    label="Timestamp (seconds)",
                    value=0.0,
                    precision=2,
                    info="Time in audio where error occurs"
                )

                btn_add_error = gr.Button("➕ Add Error", variant="primary")

                errors_list = gr.Dataframe(
                    label="Annotated Errors",
                    headers=["Frame", "Phoneme", "Type", "Wrong Sound", "Severity", "Time"],
                    interactive=False,
                    wrap=True
                )

                notes = gr.Textbox(
                    label="Notes",
                    placeholder="Additional notes about this sample",
                    lines=3
                )

                btn_save = gr.Button("💾 Save Annotation", variant="primary", size="lg")

                output_status = gr.Textbox(
                    label="Status",
                    interactive=False,
                    lines=3
                )

        with gr.Row():
            gr.Markdown("### 📊 Statistics")
            stats_display = gr.Markdown("**Total Annotations:** 0 | **Total Errors:** 0")

        # Per-session list of error dicts collected before the next save.
        errors_data = gr.State(value=[])

        def extract_phonemes(text: str) -> str:
            """Render the phonemes of *text* as space-separated ``/p/`` tokens."""
            if not text:
                return "Enter expected text first"
            phonemes = get_phoneme_list(text)
            return " ".join([f"/{p}/" for p in phonemes]) if phonemes else "No phonemes found"

        def errors_to_rows(errors: List[Dict]) -> List[List[Any]]:
            """Convert error dicts into rows matching the Dataframe headers."""
            return [
                [
                    e['frame_id'],
                    e['phoneme'],
                    e['error_type'],
                    # The key is always present (possibly None), so a .get()
                    # default would never apply; fall back with `or`.
                    e.get('wrong_sound') or 'N/A',
                    f"{e['severity']:.2f}",
                    f"{e['timestamp']:.2f}s"
                ]
                for e in errors
            ]

        def add_error(
            frame_id: int,
            phoneme: str,
            error_type: str,
            wrong_sound: str,
            severity: float,
            timestamp: float,
            current_errors: List[Dict]
        ) -> Tuple[List[Dict], List[List[Any]]]:
            """Append one error record and return the new state and table rows."""
            error = {
                # A cleared gr.Number delivers None; treat that as 0.
                'frame_id': int(frame_id or 0),
                'phoneme': (phoneme or "").strip(),
                'error_type': error_type,
                # Blank or whitespace-only input is stored as None.
                'wrong_sound': (wrong_sound or "").strip() or None,
                'severity': float(severity),
                'timestamp': float(timestamp or 0.0),
                'confidence': 1.0
            }

            # Build a new list instead of mutating the State value in place.
            new_errors = (current_errors or []) + [error]

            return new_errors, errors_to_rows(new_errors)

        def save_annotation_handler(
            audio_data: Optional[Tuple[int, np.ndarray]],
            expected_text: str,
            errors: List[Dict],
            annotator: str,
            notes: str
        ) -> Tuple[str, List[Dict], Any]:
            """Validate the inputs, store audio and annotation, report status.

            Returns ``(status message, error-list state, table update)``. On
            failure the pending errors and the table are left as they are; on
            success both are reset for the next sample.
            """
            errors = errors or []

            def unchanged(message: str) -> Tuple[str, List[Dict], Any]:
                # gr.update() with no arguments leaves the table untouched.
                return message, errors, gr.update()

            if audio_data is None:
                return unchanged("❌ Please provide audio first")

            if not expected_text:
                return unchanged("❌ Please provide expected text")

            filename = f"sample_{int(time.time())}.wav"
            audio_path = save_audio_file(audio_data, filename)

            if not audio_path:
                return unchanged("❌ Failed to save audio file")

            result = save_annotation(
                audio_path=audio_path,
                expected_text=expected_text,
                phoneme_errors=errors,
                annotator_name=annotator,
                notes=notes
            )

            message = result.get('message', 'Unknown status')
            if result.get('status') == 'success':
                # Start the next sample with an empty error list and table.
                return message, [], []
            return unchanged(message)

        def update_stats() -> str:
            """Summarize ``annotations_db`` as Markdown for the stats panel."""
            # Skip anything that is not a record so one malformed entry in the
            # database file cannot break the whole statistics panel.
            records = [a for a in annotations_db if isinstance(a, dict)]
            total_annotations = len(annotations_db)
            total_errors = sum(a.get('total_errors', 0) for a in records)

            error_breakdown = {}
            for ann in records:
                for err_type, count in (ann.get('error_types') or {}).items():
                    error_breakdown[err_type] = error_breakdown.get(err_type, 0) + count

            return (
                "\n"
                f"**Total Annotations:** {total_annotations} | **Total Errors:** {total_errors}\n"
                "\n"
                "**Error Breakdown:**\n"
                f"- Substitution: {error_breakdown.get('substitution', 0)}\n"
                f"- Omission: {error_breakdown.get('omission', 0)}\n"
                f"- Distortion: {error_breakdown.get('distortion', 0)}\n"
                f"- Stutter: {error_breakdown.get('stutter', 0)}\n"
            )

        btn_get_phonemes.click(
            fn=extract_phonemes,
            inputs=[expected_text],
            outputs=[phoneme_display]
        )

        btn_add_error.click(
            fn=add_error,
            inputs=[
                error_frame_id,
                error_phoneme,
                error_type,
                wrong_sound,
                error_severity,
                error_timestamp,
                errors_data
            ],
            outputs=[errors_data, errors_list]
        )

        # Save first, then refresh the statistics from the updated database.
        btn_save.click(
            fn=save_annotation_handler,
            inputs=[audio_input, expected_text, errors_data, annotator_name, notes],
            outputs=[output_status, errors_data, errors_list]
        ).then(
            fn=update_stats,
            outputs=[stats_display]
        )

        # Show statistics for already-stored annotations on page load.
        interface.load(fn=update_stats, outputs=[stats_display])

    return interface
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: build the UI and serve it on port 7861.
    # "0.0.0.0" binds to all network interfaces; no public share link.
    interface = create_annotation_interface()
    interface.launch(
        server_name="0.0.0.0",
        server_port=7861,
        share=False,
    )
|
|
|
|
|
|