mumble-cleanup / src /cleanup /data /inject.py
adikuma's picture
initial upload: cleanup code and 688-pair seed dataset
fd0b01f verified
Raw
History Blame Contribute Delete
4.53 kB
# 9 disfluency operators. each is a pure function (text, rng, cfg) -> text.
# compose via apply(text, ops_cfg, sampling_cfg, rng), which samples N operators
# per example and chains them deterministically given the rng seed.
#
# the clean target is always recoverable from the raw via deletion + casing +
# punct + homophone normalization. faithfulness is guaranteed by construction
# because nothing here invents content.
import random
import re
import string
from typing import Callable, Dict
# marker strings for injected disfluencies.
# NOTE(review): not referenced in the visible part of this file; presumably
# consumed by another module. the trailing space on "so " and "well " looks
# deliberate -- confirm before normalizing.
INJECT_PROBES = (
    "um",
    "uh",
    "er",
    "ah",
    "like",
    "you know",
    "i mean",
    "so ",
    "well ",
)
def _split_words(text: str) -> list[str]:
    """Break ``text`` into whitespace-separated tokens.

    Runs of whitespace count as a single separator and leading/trailing
    whitespace yields no empty tokens, so an empty or blank string gives ``[]``.
    """
    tokens = text.split()
    return tokens
def _join_words(words: list[str]) -> str:
    """Glue ``words`` back into one string, separated by single spaces."""
    separator = " "
    return separator.join(words)
def add_filler(text: str, rng: random.Random, op_cfg: dict) -> str:
    """Insert filler tokens at random word boundaries of ``text``.

    Config keys read from ``op_cfg``:
      * ``"vocab"`` (required): sequence of filler strings to draw from.
      * ``"max_inserts"`` (default 2): inclusive upper bound on the number of
        fillers inserted; at least one is always inserted.

    Texts with fewer than two words are returned unchanged. Otherwise the
    result is rebuilt from the split words, so whitespace runs in the input
    come back as single spaces.
    """
    tokens = _split_words(text)
    if len(tokens) < 2:
        return text

    how_many = rng.randint(1, op_cfg.get("max_inserts", 2))
    fillers = op_cfg["vocab"]

    remaining = how_many
    while remaining > 0:
        # slot ranges over every boundary, including before the first and
        # after the last token; it grows as earlier fillers are inserted.
        slot = rng.randint(0, len(tokens))
        filler = rng.choice(fillers)
        tokens.insert(slot, filler)
        remaining -= 1

    return _join_words(tokens)
def word_stutter(text: str, rng: random.Random, op_cfg: dict) -> str:
    """Repeat one randomly chosen word of ``text`` in place.

    The chosen word gains between 1 and ``op_cfg["max_repeats"]`` (default 1)
    extra copies directly in front of it. Texts with fewer than two words are
    returned unchanged; otherwise the output is re-joined with single spaces.
    """
    tokens = _split_words(text)
    if len(tokens) < 2:
        return text

    idx = rng.randint(0, len(tokens) - 1)
    extra_copies = rng.randint(1, op_cfg.get("max_repeats", 1))

    stuttered = tokens[:idx]
    stuttered.extend([tokens[idx]] * extra_copies)
    # tokens[idx:] still starts with the chosen word, so it appears
    # extra_copies + 1 times in total.
    stuttered.extend(tokens[idx:])
    return _join_words(stuttered)
def false_start(text: str, rng: random.Random, op_cfg: dict) -> str:
    """Prepend a randomly chosen prefix from ``op_cfg["prefixes"]`` to ``text``.

    The prefix and the original text are separated by a single space.
    ``"prefixes"`` is required; a missing key raises ``KeyError``.
    """
    opener = rng.choice(op_cfg["prefixes"])
    return "{} {}".format(opener, text)
def strip_punct(text: str, rng: random.Random, op_cfg: dict) -> str:
    """Randomly delete ASCII punctuation characters from ``text``.

    Each character in ``string.punctuation`` other than the apostrophe is
    removed with probability ``op_cfg["drop_rate"]`` (default 0.6). All other
    characters, apostrophes included, are always kept.
    """
    rate = op_cfg.get("drop_rate", 0.6)

    def _is_dropped(ch: str) -> bool:
        # the rng is consulted only for droppable punctuation, so the random
        # stream advances once per such character and never otherwise.
        if ch == "'" or ch not in string.punctuation:
            return False
        return rng.random() < rate

    return "".join(ch for ch in text if not _is_dropped(ch))
def lowercase(text: str, rng: random.Random, op_cfg: dict) -> str:
    """Return ``text`` converted to lowercase.

    ``rng`` and ``op_cfg`` are not read; they are accepted so this function
    has the same signature as the other operators.
    """
    lowered = text.lower()
    return lowered
def merge_sentences(text: str, rng: random.Random, op_cfg: dict) -> str:
    """Fuse two adjacent sentences of ``text`` into a run-on.

    A sentence boundary is a ``.``, ``!`` or ``?`` followed by whitespace and
    a word character. One boundary is picked at random; its punctuation and
    whitespace are replaced by a single space and the following character is
    lowercased. ``text`` is returned unchanged when it has no such boundary.
    """
    boundary_re = re.compile(r"([.!?])\s+(\w)")
    boundaries = [hit for hit in boundary_re.finditer(text)]
    if boundaries:
        picked = rng.choice(boundaries)
        head = text[: picked.start()]
        tail = text[picked.end():]
        # group(2) is the first character of the next sentence.
        return head + " " + picked.group(2).lower() + tail
    return text
def dropped_apostrophe(text: str, rng: random.Random, op_cfg: dict) -> str:
    """Remove every ASCII apostrophe (``'``) from ``text``.

    ``rng`` and ``op_cfg`` are not read; they are accepted so this function
    has the same signature as the other operators.
    """
    pieces = text.split("'")
    return "".join(pieces)
def mishear_homophone(text: str, rng: random.Random, op_cfg: dict) -> str:
    """Swap one occurrence of a word for its configured homophone.

    ``op_cfg["pairs"]`` (default empty) holds two-item ``(source, replacement)``
    entries. One pair is drawn at random; ``source`` is searched for as a
    whole word, case-insensitively, and one randomly chosen occurrence is
    replaced by ``replacement`` exactly as configured (the matched text's
    casing is not carried over). ``text`` is returned unchanged when no pairs
    are configured or the drawn source word does not occur.
    """
    candidates = op_cfg.get("pairs", [])
    if not candidates:
        return text

    source_word, replacement = rng.choice(candidates)
    whole_word = r"\b" + re.escape(source_word) + r"\b"
    hits = list(re.finditer(whole_word, text, flags=re.IGNORECASE))
    if not hits:
        return text

    target = rng.choice(hits)
    before = text[: target.start()]
    after = text[target.end():]
    return before + replacement + after
def repeated_chunk(text: str, rng: random.Random, op_cfg: dict) -> str:
    """Duplicate a random run of consecutive words right after itself.

    The run length is drawn between ``op_cfg["chunk_size_min"]`` (default 2)
    and ``op_cfg["chunk_size_max"]`` (default 4), capped at one less than the
    number of words. Texts with fewer than ``chunk_size_min + 1`` words are
    returned unchanged; otherwise the output is re-joined with single spaces.
    """
    tokens = _split_words(text)
    smallest = op_cfg.get("chunk_size_min", 2)
    largest = op_cfg.get("chunk_size_max", 4)
    if len(tokens) < smallest + 1:
        return text

    total = len(tokens)
    size = rng.randint(smallest, min(largest, total - 1))
    begin = rng.randint(0, total - size)
    end = begin + size

    # emit everything up to the end of the chunk, then the chunk again, then
    # whatever followed it.
    echoed = tokens[:end]
    echoed.extend(tokens[begin:end])
    echoed.extend(tokens[end:])
    return _join_words(echoed)
# operator registry, keyed by operator name (the original note says these keys
# match the yaml ops.* keys). each key is taken from the function's own
# __name__, so the names below must stay identical to the config keys.
OPS: Dict[str, Callable[[str, random.Random, dict], str]] = {
    operator.__name__: operator
    for operator in (
        add_filler,
        word_stutter,
        false_start,
        strip_punct,
        lowercase,
        merge_sentences,
        dropped_apostrophe,
        mishear_homophone,
        repeated_chunk,
    )
}
def apply(text: str, ops_cfg: dict, sampling_cfg, rng: random.Random) -> str:
    """Corrupt ``text`` with a random selection of the registered operators.

    A count is drawn between ``sampling_cfg.ops_per_example_min`` and
    ``sampling_cfg.ops_per_example_max`` (inclusive). The operator names are
    shuffled and the first ``count`` of them become candidates. Each candidate
    then fires with its own probability ``ops_cfg[name]["p"]``; an operator
    with no config entry or no ``"p"`` never fires. Operators that fire are
    chained in shuffled order, each receiving the previous one's output.

    Returns the result with every whitespace run collapsed to a single space
    and leading/trailing whitespace removed.
    """
    count = rng.randint(
        sampling_cfg.ops_per_example_min,
        sampling_cfg.ops_per_example_max,
    )

    # shuffling before slicing randomizes both which operators are candidates
    # and the order they are applied in.
    candidates = list(OPS)
    rng.shuffle(candidates)

    result = text
    for op_name in candidates[:count]:
        settings = ops_cfg.get(op_name, {})
        fire_prob = settings.get("p", 0.0)
        if rng.random() < fire_prob:
            operator = OPS[op_name]
            result = operator(result, rng, settings)

    # final whitespace squeeze
    squeezed = re.sub(r"\s+", " ", result)
    return squeezed.strip()