mumble-cleanup / scripts /explore_data.py
adikuma's picture
initial upload: cleanup code and 688-pair seed dataset
fd0b01f verified
Raw
History Blame Contribute Delete
13.6 kB
# explore the synthetic seed dataset and produce docs/data_report.md plus a
# set of supporting charts under docs/data_images/. mirrors the privacy-filter
# explore_data.py pattern. cpu only, no model needed.
#
# usage: uv run python scripts/explore_data.py
import json
import re
from collections import Counter
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
# input jsonl read by load_rows(); downstream code reads the "raw", "clean"
# and "category" keys of every row.
SEED_PATH = Path("data/seed/synthetic_pairs.jsonl")
# generated markdown report, and the directory its png charts are saved into
# (the report links them as data_images/<name>.png).
OUT_DOC = Path("docs/data_report.md")
OUT_IMAGES = Path("docs/data_images")
# filler words / phrases, matched as whole words against the lowercased raw
# side. "like", "so" and "well" also have ordinary non-filler uses, so counts
# built on this set may include some non-disfluent occurrences. a set of
# strings iterates in a hash-seed-dependent order; sort it wherever a stable
# order matters.
FILLERS = {"um", "uh", "er", "ah", "like", "you know", "i mean", "so", "well"}
def load_rows(path: "Path | None" = None) -> list[dict]:
    """Load the seed pairs from a JSONL file, one JSON value per line.

    Whitespace-only lines are skipped; every other line is decoded and the
    results are returned in file order.

    Args:
        path: file to read. Defaults to ``SEED_PATH`` so existing callers that
            pass no argument keep their behavior.

    Returns:
        The decoded rows, one per non-blank line.

    Raises:
        FileNotFoundError: if the file does not exist.
        json.JSONDecodeError: (a ``ValueError``) if a line is not valid JSON.
            The message is prefixed with ``<path>:<line number>`` so a bad row
            in a large dataset can be located directly.
    """
    if path is None:
        path = SEED_PATH
    rows = []
    with open(path, "r", encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                rows.append(json.loads(line))
            except json.JSONDecodeError as err:
                # json only knows the position inside this single line; add the
                # file and line number, keeping the original exception type.
                raise json.JSONDecodeError(f"{path}:{lineno}: {err.msg}", err.doc, err.pos) from err
    return rows
def word_count(s: str) -> int:
    """Return the number of whitespace-delimited tokens in ``s``.

    Runs of whitespace and leading/trailing whitespace never yield empty
    tokens, so ``""`` and ``"   "`` both count as 0.
    """
    return sum(1 for _token in s.split())
def filler_count(text: str, fillers: "set[str] | None" = None) -> int:
    """Count whole-word filler occurrences in ``text``, case-insensitively.

    Each filler is matched on its own and only where it is not attached to
    other word characters. So "um" does not match inside "umbrella", but it
    does match in "um," or "um-". A multi-word filler such as "you know"
    matches only with the exact spacing used in the filler string.

    Args:
        text: the utterance to scan (typically the raw side of a pair).
        fillers: filler words/phrases to count (any iterable of strings).
            Defaults to ``FILLERS``. Passing a subset lets callers count
            individual fillers with exactly the same matching rule.

    Returns:
        Total matches summed over all fillers.
    """
    if fillers is None:
        fillers = FILLERS
    lower = text.lower()
    # the (?<!\w) / (?!\w) lookarounds already succeed at the start and end of
    # the string, so the text needs no space padding to match edge fillers.
    return sum(len(re.findall(rf"(?<!\w){re.escape(f)}(?!\w)", lower)) for f in fillers)
def content_words(text: str) -> set[str]:
    """Return the set of lowercase word tokens in ``text``.

    Every character that is not a word character, whitespace or apostrophe
    becomes a space, then the text is split on whitespace. Despite the name,
    no stop-word or filler filtering is done: every remaining token counts.

    Only word-internal apostrophes are kept:
    - typographic apostrophes (U+2019, U+2018) are normalized to ``'``, so
      "don’t" and "don't" tokenize the same instead of the curly form
      splitting into "don" + "t";
    - apostrophes at either end of a token are stripped, so single-quoted
      phrases ("'ship it'") and plural possessives ("teams'") yield the bare
      word rather than a token that cannot match a raw side without the
      quote marks.
    """
    lower = text.lower().replace("\u2019", "'").replace("\u2018", "'")
    stripped = re.sub(r"[^\w\s']", " ", lower)
    tokens = (tok.strip("'") for tok in stripped.split())
    # a token made only of apostrophes strips to "" and is dropped.
    return {tok for tok in tokens if tok}
def faithfulness(raw: str, clean: str) -> float:
    """Fraction of the clean side's distinct word tokens that also occur in raw.

    Both sides are tokenized with ``content_words``. A score of 1.0 means every
    clean token has a source in raw. Lower scores mean the clean side brought
    in tokens the raw side lacks (e.g. digits produced from spelled-out
    numbers); that is the pattern that would teach the model to hallucinate.
    A clean side with no tokens scores 1.0, since it adds nothing.
    """
    raw_tokens = content_words(raw)
    clean_tokens = content_words(clean)
    if len(clean_tokens) == 0:
        return 1.0
    shared = sum(1 for tok in clean_tokens if tok in raw_tokens)
    return shared / len(clean_tokens)
def plot_category_counts(rows: list[dict]):
    """Save a bar chart of pair counts per category, largest first.

    Each bar is labelled with its count. Writes ``category_counts.png`` into
    ``OUT_IMAGES``.
    """
    counts = Counter(r["category"] for r in rows)
    order = sorted(counts, key=counts.get, reverse=True)
    values = [counts[c] for c in order]
    fig, ax = plt.subplots(figsize=(8, 4.2))
    bars = ax.bar(order, values, color="#5b8cb8")
    ax.set_title("pairs per category")
    ax.set_ylabel("count")
    # pin tick positions before relabelling them. Calling set_xticklabels on
    # the default categorical locator makes matplotlib warn that labels are
    # only reliable with a fixed set of ticks. Categorical bars sit at
    # 0..n-1 in plotting order, so these ticks line up with the bars.
    ax.set_xticks(range(len(order)))
    ax.set_xticklabels(order, rotation=22, ha="right")
    for b, v in zip(bars, values):
        # a small fixed offset (in data units) puts the label just above the bar.
        ax.text(b.get_x() + b.get_width() / 2, v + 1, str(v), ha="center", fontsize=9)
    ax.grid(True, axis="y", alpha=0.3)
    fig.tight_layout()
    fig.savefig(OUT_IMAGES / "category_counts.png", dpi=130)
    plt.close(fig)
def plot_length_distribution(rows: list[dict]):
    """Save violin plots of raw-side word counts, one violin per category.

    Categories are drawn in alphabetical order with medians shown and means
    hidden. Writes ``length_distribution.png`` into ``OUT_IMAGES``.
    """
    categories = sorted({row["category"] for row in rows})
    figure, axes = plt.subplots(figsize=(9, 4.8))
    lengths_by_category = []
    for category in categories:
        lengths_by_category.append(
            [word_count(row["raw"]) for row in rows if row["category"] == category]
        )
    violins = axes.violinplot(lengths_by_category, showmeans=False, showmedians=True)
    for body in violins["bodies"]:
        body.set_facecolor("#5b8cb8")
        body.set_alpha(0.65)
    # violinplot places dataset i at x = i + 1 by default, so ticks start at 1.
    axes.set_xticks(range(1, len(categories) + 1))
    axes.set_xticklabels(categories, rotation=22, ha="right")
    axes.set_ylabel("raw side word count")
    axes.set_title("input length distribution by category")
    axes.grid(True, axis="y", alpha=0.3)
    figure.tight_layout()
    figure.savefig(OUT_IMAGES / "length_distribution.png", dpi=130)
    plt.close(figure)
def plot_raw_vs_clean_length(rows: list[dict]):
    """Save a scatter of clean-side vs raw-side word count, one point per pair.

    A dashed y = x reference line runs from 0 to the largest count seen on
    either side; points below it are pairs whose clean side is shorter.
    Writes ``raw_vs_clean_length.png`` into ``OUT_IMAGES``.
    """
    xs = [word_count(row["raw"]) for row in rows]
    ys = [word_count(row["clean"]) for row in rows]
    figure, axes = plt.subplots(figsize=(6.5, 6))
    axes.scatter(xs, ys, alpha=0.32, s=14, color="#5b8cb8")
    top = max(xs + ys)
    axes.plot([0, top], [0, top], "k--", alpha=0.4, linewidth=1)
    axes.set_xlabel("raw word count")
    axes.set_ylabel("clean word count")
    axes.set_title("raw vs clean length (clean below diagonal is expected)")
    axes.grid(True, alpha=0.3)
    figure.tight_layout()
    figure.savefig(OUT_IMAGES / "raw_vs_clean_length.png", dpi=130)
    plt.close(figure)
def plot_filler_intensity(rows: list[dict]):
    """Save a bar chart of the mean filler count per raw input, by category.

    Categories are in alphabetical order, and each bar is labelled with its
    mean to one decimal place. Writes ``filler_intensity.png`` into
    ``OUT_IMAGES``.
    """
    cats = sorted({r["category"] for r in rows})
    means = [
        np.mean([filler_count(r["raw"]) for r in rows if r["category"] == c]) for c in cats
    ]
    fig, ax = plt.subplots(figsize=(8, 4.2))
    bars = ax.bar(cats, means, color="#c08a55")
    ax.set_title("average filler count per raw input by category")
    ax.set_ylabel("avg fillers per pair")
    # pin tick positions (categorical bars sit at 0..n-1) before relabelling.
    # set_xticklabels without set_xticks makes matplotlib warn that labels are
    # only reliable with a fixed set of ticks.
    ax.set_xticks(range(len(cats)))
    ax.set_xticklabels(cats, rotation=22, ha="right")
    for b, v in zip(bars, means):
        ax.text(b.get_x() + b.get_width() / 2, v + 0.05, f"{v:.1f}", ha="center", fontsize=9)
    ax.grid(True, axis="y", alpha=0.3)
    fig.tight_layout()
    fig.savefig(OUT_IMAGES / "filler_intensity.png", dpi=130)
    plt.close(fig)
def plot_top_fillers(rows: list[dict]):
    """Save a horizontal bar chart of total occurrences of each filler in raw.

    Every filler in ``FILLERS`` gets a bar (zero counts included), with the
    most frequent at the top and ties broken alphabetically, so the chart is
    identical from run to run. Writes ``top_fillers.png`` into ``OUT_IMAGES``.
    """
    # compile each whole-word pattern once rather than per row. sorted() fixes
    # the iteration order: a set of strings iterates in a hash-seed-dependent
    # order, which used to leak into the bar order of tied fillers.
    patterns = {f: re.compile(rf"(?<!\w){re.escape(f)}(?!\w)") for f in sorted(FILLERS)}
    counts: Counter = Counter({f: 0 for f in patterns})
    for r in rows:
        lower = r["raw"].lower()
        for f, pat in patterns.items():
            counts[f] += len(pat.findall(lower))
    # count descending, then name ascending, so tied fillers (often several
    # zeros) always land in the same order.
    items = sorted(counts.items(), key=lambda kv: (-kv[1], kv[0]))
    labels, values = zip(*items)
    fig, ax = plt.subplots(figsize=(7, 4.2))
    # barh draws the first item at the bottom; reverse so the largest is on top.
    ax.barh(labels[::-1], values[::-1], color="#5b8cb8")
    ax.set_title("top fillers across all raw inputs")
    ax.set_xlabel("total occurrences in raw")
    ax.grid(True, axis="x", alpha=0.3)
    fig.tight_layout()
    fig.savefig(OUT_IMAGES / "top_fillers.png", dpi=130)
    plt.close(fig)
def plot_faithfulness(rows: list[dict]):
    """Save a histogram (24 bins) of per-pair faithfulness scores.

    A dashed red vertical line marks 0.95, the same cut-off main() uses for
    the report's pass count. Writes ``faithfulness.png`` into ``OUT_IMAGES``.
    """
    scores = []
    for row in rows:
        scores.append(faithfulness(row["raw"], row["clean"]))
    figure, axes = plt.subplots(figsize=(7, 4.2))
    axes.hist(scores, bins=24, color="#5b8cb8", edgecolor="white")
    threshold = 0.95
    axes.axvline(threshold, color="red", linestyle="--", linewidth=1, label="0.95 threshold")
    axes.set_title("faithfulness: fraction of clean content words present in raw")
    axes.set_xlabel("faithfulness score (1.0 = perfect)")
    axes.set_ylabel("number of pairs")
    axes.legend()
    axes.grid(True, axis="y", alpha=0.3)
    figure.tight_layout()
    figure.savefig(OUT_IMAGES / "faithfulness.png", dpi=130)
    plt.close(figure)
def write_report(rows: list[dict], stats: dict):
    """Render the dataset report as markdown and write it to ``OUT_DOC``.

    Args:
        rows: the loaded pairs; each is read for its ``raw``, ``clean`` and
            ``category`` keys.
        stats: summary numbers computed by ``main``. This function reads
            ``raw_min``, ``raw_max``, ``raw_median``, ``clean_median``,
            ``faith_mean``, ``faith_median``, ``faith_pass`` and ``faith_90``.

    The chart images are only linked by relative path (``data_images/...``)
    and are not produced here. Creates ``OUT_DOC``'s parent directory if
    needed and overwrites any existing report.
    """
    cats = sorted({r["category"] for r in rows})
    cat_counts = Counter(r["category"] for r in rows)
    # counts quoted in the prose come from the data itself, so the narrative
    # cannot contradict the category table after the dataset is regenerated.
    total = len(rows)
    long_form_count = cat_counts["long_form_thoughts"]
    lines: list[str] = []

    # --- intro ---
    lines.append("# data report")
    lines.append("")
    lines.append(f"The synthetic seed dataset that backs the Mumble cleanup model. Built by a multi-agent workflow that spawned 8 specialist agents in parallel and produced 612 pairs across 8 dictation categories; a polish pass added 76 more `long_form_thoughts` pairs with strictly diverse openers, bringing the total to **{total} pairs**.")
    lines.append("")
    lines.append("Every pair is `{ raw: <Parakeet-shaped lowercase no-punct disfluent input>, clean: <proper English output> }`. The clean side is faithful by construction: every content word in `clean` exists in `raw` (modulo standard homophone fixes, contractions, and casing). This is what stops the model from learning to hallucinate.")
    lines.append("")

    # --- category mix ---
    lines.append("## category mix")
    lines.append("")
    lines.append("![category counts](data_images/category_counts.png)")
    lines.append("")
    lines.append("| category | count |")
    lines.append("|---|---:|")
    for c in cats:
        lines.append(f"| `{c}` | {cat_counts[c]} |")
    lines.append(f"| **total** | **{total}** |")
    lines.append("")
    lines.append(f"`long_form_thoughts` is intentionally over-weighted because paragraph-length cleanup is the hardest behavior (multiple sentence boundaries, sustained context, false starts) and {long_form_count} examples gives the model the signal it needs to handle 60-90 word inputs.")
    lines.append("")

    # --- lengths ---
    lines.append("## length distribution")
    lines.append("")
    lines.append("![length distribution](data_images/length_distribution.png)")
    lines.append("")
    lines.append(f"Raw inputs span **{stats['raw_min']} to {stats['raw_max']} words** with a median of **{stats['raw_median']:.0f}**. Clean outputs are slightly shorter on average ({stats['clean_median']:.0f} median words) because they have fillers and stutters removed. The categories show meaningfully different length distributions: short utterances dominate `casual_messages`, `questions_and_asks`, and `mixed_content`; long paragraph-shaped inputs dominate `long_form_thoughts`.")
    lines.append("")
    lines.append("## raw vs clean length")
    lines.append("")
    lines.append("![raw vs clean length](data_images/raw_vs_clean_length.png)")
    lines.append("")
    lines.append("Points below the diagonal mean clean is shorter than raw — the model is being trained to remove material, not add it. The cluster sits just below the diagonal, which is the expected shape for a faithful cleanup task: a few words removed per input on average, never more than ~25%.")
    lines.append("")

    # --- disfluency ---
    lines.append("## disfluency intensity")
    lines.append("")
    lines.append("![filler intensity by category](data_images/filler_intensity.png)")
    lines.append("")
    lines.append("Average filler-word count per raw input, by category. `meeting_notes` and `long_form_thoughts` carry the heaviest disfluency load (people think out loud during meetings); `mixed_content` and `questions_and_asks` are leanest (those categories are about precision, not verbosity).")
    lines.append("")
    lines.append("![top fillers](data_images/top_fillers.png)")
    lines.append("")
    lines.append("Distribution of filler words across the entire dataset. `um` and `uh` dominate (matching real Parakeet output), with `like`, `you know`, and `so` following at a moderate rate. The mix matches what shows up in real dictation transcripts.")
    lines.append("")

    # --- faithfulness ---
    lines.append("## faithfulness check")
    lines.append("")
    lines.append("![faithfulness distribution](data_images/faithfulness.png)")
    lines.append("")
    lines.append("For each pair, we compute the fraction of content words in the clean side that also appear in the raw side. A perfect value is 1.0 (every clean content word came from raw); lower values indicate the clean introduced content the raw did not have, which would train the model to hallucinate.")
    lines.append("")
    lines.append(f"- **Mean faithfulness**: {stats['faith_mean']:.3f}")
    lines.append(f"- **Median faithfulness**: {stats['faith_median']:.3f}")
    lines.append(f"- **Pairs above 0.95 threshold**: {stats['faith_pass']} of {total} ({100 * stats['faith_pass'] / total:.1f}%)")
    lines.append(f"- **Pairs above 0.90 threshold**: {stats['faith_90']} of {total} ({100 * stats['faith_90'] / total:.1f}%)")
    lines.append("")
    # NOTE(review): content_words() lowercases both sides before comparing, so
    # casing alone ("acme" -> "Acme") cannot lower the score; the proper-noun
    # caveat in this paragraph may be overstated. Left as written pending
    # author confirmation.
    lines.append("Small drops below 1.0 come from legitimate sources: number-word to digit conversion (\"two thirty\" -> \"2:30\"), proper-noun capitalization that adds new tokens to the content-word set under our simple lowercase comparison (\"acme\" -> \"Acme\" should be counted as matching but our naive check might miss some), and contractions (\"i\" -> \"I'm\" via apostrophe restoration).")
    lines.append("")

    # --- samples: first and middle row of each category, in file order ---
    lines.append("## sample pairs")
    lines.append("")
    lines.append("Two per category, illustrating the shape of the dataset:")
    lines.append("")
    for cat in cats:
        cat_rows = [r for r in rows if r["category"] == cat]
        samples = [cat_rows[0], cat_rows[len(cat_rows) // 2]] if len(cat_rows) >= 2 else cat_rows
        lines.append(f"### `{cat}`")
        lines.append("")
        for s in samples:
            lines.append(f"- **raw**: `{s['raw']}`")
            lines.append(f"- **clean**: {s['clean']}")
            lines.append("")

    # --- limitations ---
    lines.append("## limitations")
    lines.append("")
    lines.append("- **Synthetic origin**: every pair was generated by an LLM workflow, not transcribed from real Parakeet output. The disfluency patterns are modeled to match real ASR failure modes but may under-represent edge cases the model will face in production.")
    lines.append(f"- **Size**: {total} pairs is on the lower-middle end of the documented sweet spot for narrow LoRA fine-tunes (200-500 floor, 2k-5k comfortable). Adequate for a v1 ship; if eval pass rate is below 0.85 we regenerate another 600-1000 pairs and retrain.")
    lines.append("- **Faithfulness is statistical, not strict**: a few pairs may drop below 0.95 because of legitimate transformations (numeric formatting, proper-noun casing). We don't filter these out because the training task explicitly wants the model to learn those transformations.")
    lines.append("- **English only.**")
    lines.append("")
    OUT_DOC.parent.mkdir(parents=True, exist_ok=True)
    # the trailing "" entry makes the joined document end with a newline.
    OUT_DOC.write_text("\n".join(lines), encoding="utf-8")
def main() -> None:
    """Build every chart and the markdown report from the seed dataset.

    Ensures ``OUT_IMAGES`` exists, loads the rows, writes the six charts and
    then computes the summary stats that ``write_report`` embeds. Progress is
    printed to stdout.
    """
    OUT_IMAGES.mkdir(parents=True, exist_ok=True)
    rows = load_rows()
    print(f"loaded {len(rows)} pairs from {SEED_PATH}")

    # each plotting function writes one png into OUT_IMAGES.
    for plot in (
        plot_category_counts,
        plot_length_distribution,
        plot_raw_vs_clean_length,
        plot_filler_intensity,
        plot_top_fillers,
        plot_faithfulness,
    ):
        plot(rows)

    raw_lengths = [word_count(row["raw"]) for row in rows]
    clean_lengths = [word_count(row["clean"]) for row in rows]
    scores = [faithfulness(row["raw"], row["clean"]) for row in rows]
    # key order matters: the dict is printed verbatim below.
    stats = dict(
        raw_min=min(raw_lengths),
        raw_max=max(raw_lengths),
        raw_median=float(np.median(raw_lengths)),
        clean_min=min(clean_lengths),
        clean_max=max(clean_lengths),
        clean_median=float(np.median(clean_lengths)),
        faith_mean=float(np.mean(scores)),
        faith_median=float(np.median(scores)),
        faith_pass=sum(score >= 0.95 for score in scores),
        faith_90=sum(score >= 0.90 for score in scores),
    )
    print(f"stats: {stats}")
    write_report(rows, stats)
    print(f"wrote {OUT_DOC} and {OUT_IMAGES}/")


if __name__ == "__main__":
    main()