# 9 disfluency operators. each is a pure function (text, rng, cfg) -> text.
# compose via apply(text, ops_cfg, sampling_cfg, rng), which samples N
# operators per example and chains them deterministically given the rng seed.
#
# the clean target is always recoverable from the raw via deletion + casing +
# punct + homophone normalization. faithfulness holds by construction: the
# only tokens ever added are configured fillers / false-start prefixes and
# copies of words already present (all removable by deletion); every other
# operator only drops characters, changes case, or swaps a configured
# homophone.
import random
import re
import string
from typing import Callable, Dict

# NOTE(review): not referenced in this module; presumably used by callers
# (e.g. to probe outputs for injected fillers) — confirm before removing.
INJECT_PROBES = (
    "um", "uh", "er", "ah", "like", "you know", "i mean", "so ", "well ",
)


def _split_words(text: str) -> list[str]:
    """split text on any run of whitespace."""
    return text.split()


def _join_words(words: list[str]) -> str:
    """rejoin words with single spaces."""
    return " ".join(words)


def add_filler(text: str, rng: random.Random, op_cfg: dict) -> str:
    """insert 1..max_inserts filler tokens at random word boundaries.

    op_cfg keys: "vocab" (sequence of filler strings), "max_inserts"
    (default 2). texts with fewer than two words, and a missing or empty
    vocab, are returned unchanged.
    """
    words = _split_words(text)
    if len(words) < 2:
        return text
    vocab = op_cfg.get("vocab") or []
    if not vocab:
        # same contract as mishear_homophone: nothing configured -> no-op,
        # rather than KeyError / IndexError from rng.choice([]).
        return text
    inserts = rng.randint(1, op_cfg.get("max_inserts", 2))
    for _ in range(inserts):
        # pos == len(words) is allowed: the filler is appended at the end.
        pos = rng.randint(0, len(words))
        words.insert(pos, rng.choice(vocab))
    return _join_words(words)


def word_stutter(text: str, rng: random.Random, op_cfg: dict) -> str:
    """repeat one randomly chosen word 1..max_repeats extra times in place.

    op_cfg keys: "max_repeats" (default 1). texts with fewer than two words
    are returned unchanged.
    """
    words = _split_words(text)
    if len(words) < 2:
        return text
    pos = rng.randint(0, len(words) - 1)
    repeats = rng.randint(1, op_cfg.get("max_repeats", 1))
    # the extra copies go directly before the original occurrence.
    repeated = [words[pos]] * repeats + words[pos:]
    return _join_words(words[:pos] + repeated)


def false_start(text: str, rng: random.Random, op_cfg: dict) -> str:
    """prepend one randomly chosen prefix from op_cfg["prefixes"].

    a missing or empty prefix list leaves the text unchanged.
    """
    prefixes = op_cfg.get("prefixes") or []
    if not prefixes:
        return text
    prefix = rng.choice(prefixes)
    return f"{prefix} {text}"


def strip_punct(text: str, rng: random.Random, op_cfg: dict) -> str:
    """drop each ascii punctuation char with probability drop_rate (default 0.6).

    apostrophes are always kept; dropping them is dropped_apostrophe's job.
    """
    drop_rate = op_cfg.get("drop_rate", 0.6)
    out_chars = []
    for ch in text:
        if ch in string.punctuation and ch != "'" and rng.random() < drop_rate:
            continue
        out_chars.append(ch)
    return "".join(out_chars)


def lowercase(text: str, rng: random.Random, op_cfg: dict) -> str:
    """lowercase the whole text; rng and op_cfg are unused."""
    return text.lower()


def merge_sentences(text: str, rng: random.Random, op_cfg: dict) -> str:
    """turn one sentence boundary into a run-on.

    drop a single sentence-ending punctuation mark (chosen at random among
    those followed by whitespace and a word character) and lowercase the
    following character so the two sentences read run-on. text with no such
    boundary is returned unchanged.
    """
    matches = list(re.finditer(r"([.!?])\s+(\w)", text))
    if not matches:
        return text
    m = rng.choice(matches)
    return text[: m.start()] + " " + m.group(2).lower() + text[m.end():]


def dropped_apostrophe(text: str, rng: random.Random, op_cfg: dict) -> str:
    """remove every ascii apostrophe ("don't" -> "dont"); rng/op_cfg unused."""
    return text.replace("'", "")


def mishear_homophone(text: str, rng: random.Random, op_cfg: dict) -> str:
    """replace one whole-word occurrence of a homophone with its pair.

    op_cfg["pairs"] is a list of (a, b) pairs; one pair is chosen at random
    and one case-insensitive whole-word match of `a` is replaced by `b`
    verbatim (the match's original casing is not preserved). if the chosen
    pair does not occur, or no pairs are configured, the text is unchanged.
    """
    pairs = op_cfg.get("pairs", [])
    if not pairs:
        return text
    a, b = rng.choice(pairs)
    pattern = re.compile(rf"\b{re.escape(a)}\b", re.IGNORECASE)
    matches = list(pattern.finditer(text))
    if not matches:
        return text
    m = rng.choice(matches)
    # splice instead of re.sub so `b` is inserted literally (no backrefs).
    return text[: m.start()] + b + text[m.end():]


def repeated_chunk(text: str, rng: random.Random, op_cfg: dict) -> str:
    """duplicate a random run of consecutive words right after itself.

    op_cfg keys: "chunk_size_min" (default 2), "chunk_size_max" (default 4).
    the chunk is always strictly shorter than the text; texts with at most
    chunk_size_min words are returned unchanged.
    """
    words = _split_words(text)
    chunk_min = op_cfg.get("chunk_size_min", 2)
    chunk_max = op_cfg.get("chunk_size_max", 4)
    if len(words) < chunk_min + 1:
        return text
    chunk = rng.randint(chunk_min, min(chunk_max, len(words) - 1))
    start = rng.randint(0, len(words) - chunk)
    repeated = words[start : start + chunk]
    return _join_words(words[: start + chunk] + repeated + words[start + chunk :])


# operator registry. key matches the yaml ops.* keys.
OPS: Dict[str, Callable[[str, random.Random, dict], str]] = {
    "add_filler": add_filler,
    "word_stutter": word_stutter,
    "false_start": false_start,
    "strip_punct": strip_punct,
    "lowercase": lowercase,
    "merge_sentences": merge_sentences,
    "dropped_apostrophe": dropped_apostrophe,
    "mishear_homophone": mishear_homophone,
    "repeated_chunk": repeated_chunk,
}


def apply(text: str, ops_cfg: dict, sampling_cfg, rng: random.Random) -> str:
    """corrupt `text` with a random chain of operators from OPS.

    pick N operators (N uniform in [ops_per_example_min, ops_per_example_max],
    capped at len(OPS)) in a random order, then for each one coin-flip its own
    probability op_cfg["p"] (default 0.0) and apply it if the flip succeeds.
    the result is fully determined by `rng`'s state.

    ops_cfg maps operator name -> that operator's config dict; a missing
    entry, or one present with an empty (None) body, means "never applied".
    sampling_cfg must expose ops_per_example_min / ops_per_example_max as
    attributes. returns the corrupted text with whitespace squeezed and
    stripped.
    """
    n_min = sampling_cfg.ops_per_example_min
    n_max = sampling_cfg.ops_per_example_max
    n = rng.randint(n_min, n_max)
    # order is randomized to vary the corruption shape across examples.
    op_names = list(OPS.keys())
    rng.shuffle(op_names)
    chosen = op_names[:n]

    out = text
    for name in chosen:
        # `or {}` also covers a yaml key with no body, which loads as None.
        op_cfg = ops_cfg.get(name) or {}
        p = op_cfg.get("p", 0.0)
        if rng.random() < p:
            out = OPS[name](out, rng, op_cfg)
    # final whitespace squeeze
    return re.sub(r"\s+", " ", out).strip()