EDEN / eden /data.py
Rybib's picture
Upload EDEN model and code
2f65125 verified
Raw
History Blame Contribute Delete
17.3 kB
"""Text cleaning, synthetic corruption, dataset loading, and tokenization."""
from __future__ import annotations
import csv
import json
import random
import re
from pathlib import Path
from typing import Callable, Iterable
import torch
from .config import TrainConfig
from .constants import *
from .io_utils import *
from .runtime import require_package
def normalise_text(text: str) -> str:
    """Return *text* as a tidy single-line string.

    Falsy input (``None``, ``""``, ``0``) becomes ``""``; anything else is
    passed through ``str()``. Typographic single and double quotes
    (U+2018/U+2019 and U+201C/U+201D) are mapped to their ASCII forms, every
    whitespace run collapses to one space, and the ends are stripped.
    """
    result = str(text) if text else ""
    for fancy, plain in (("\u2018", "'"), ("\u2019", "'"), ("\u201c", '"'), ("\u201d", '"')):
        result = result.replace(fancy, plain)
    return re.sub(r"\s+", " ", result).strip()
def strip_instruction(text: str) -> str:
    """Normalise *text* and drop a leading task instruction such as ``"fix grammar:"``.

    Matching is case-insensitive and the first listed prefix that fits wins.
    After the prefix is cut, spaces, hyphens and colons are trimmed from both
    ends. Text with no known prefix is returned normalised but otherwise as-is.
    """
    cleaned = normalise_text(text)
    folded = cleaned.lower()
    known_prefixes = (
        "fix grammar:", "fix the grammar:", "correct grammar:",
        "correct spelling:", "fix spelling:", "rewrite:", "rewrite this:",
        "paraphrase:", "improve:", "improve this:", "make this sound better:",
        "clarify:", "punctuate:", "capitalize:",
    )
    hit = next((prefix for prefix in known_prefixes if folded.startswith(prefix)), None)
    if hit is None:
        return cleaned
    return cleaned[len(hit):].strip(" -:")
def valid_pair(noisy: str, clean: str) -> bool:
    """Decide whether a (noisy, clean) pair is usable for training.

    Both sides are normalised first. A pair is rejected when either side is
    empty, when both sides are identical and shorter than 15 characters, or
    when either side is shorter than 4 or longer than 1400 characters.
    """
    src = normalise_text(noisy)
    tgt = normalise_text(clean)
    if not (src and tgt):
        return False
    # Short identity pairs carry almost no signal; longer ones are kept so the
    # model learns to leave good text alone.
    if src == tgt and len(tgt) < 15:
        return False
    shortest = min(len(src), len(tgt))
    longest = max(len(src), len(tgt))
    return shortest >= 4 and longest <= 1400
def add_pair(rows: list[tuple[str, str]], noisy: str, clean: str) -> None:
    """Strip instruction prefixes from both sides and append the pair to *rows* if valid.

    *rows* is mutated in place; nothing is returned. Invalid pairs (see
    ``valid_pair``) are silently discarded.
    """
    candidate = (strip_instruction(noisy), strip_instruction(clean))
    if valid_pair(*candidate):
        rows.append(candidate)
def dedupe_pairs(rows: Iterable[tuple[str, str]], limit: int | None = None) -> list[tuple[str, str]]:
    """Normalise, validate and case-insensitively de-duplicate (noisy, clean) pairs.

    The first occurrence of each pair wins, and input order is preserved.
    Pairs failing ``valid_pair`` are dropped. When *limit* is truthy, at most
    *limit* pairs are returned; a falsy limit (``None`` or ``0``) means no cap.
    """
    kept: list[tuple[str, str]] = []
    seen_keys = set()
    for raw_noisy, raw_clean in rows:
        src = normalise_text(raw_noisy)
        tgt = normalise_text(raw_clean)
        fingerprint = (src.lower(), tgt.lower())
        if not valid_pair(src, tgt) or fingerprint in seen_keys:
            continue
        seen_keys.add(fingerprint)
        kept.append((src, tgt))
        if limit and len(kept) >= limit:
            break
    return kept
def read_pairs_jsonl(path: Path) -> list[tuple[str, str]]:
    """Read (noisy, clean) pairs from a JSON Lines file.

    Each non-blank line is parsed as JSON and may be either:
      * an object whose source is the first truthy value of ``input`` /
        ``noisy`` / ``src`` / ``bad`` and whose target is the first truthy
        value of ``target`` / ``clean`` / ``tgt`` / ``good``; or
      * a two-element ``[noisy, clean]`` array (same as the ``.json`` loader).
    Any other JSON value is skipped. Pairs are filtered by ``add_pair`` and
    de-duplicated with ``dedupe_pairs``.

    Raises:
        json.JSONDecodeError: if a line is not valid JSON.
    """
    path = Path(path)
    rows: list[tuple[str, str]] = []
    # utf-8-sig drops a leading byte-order mark. Without it, the first line
    # starts with U+FEFF and json.loads rejects it. BOM-less files decode
    # exactly as with plain utf-8.
    with path.open("r", encoding="utf-8-sig") as fh:
        for line in fh:
            line = line.strip()
            if not line:
                continue
            item = json.loads(line)
            if isinstance(item, dict):
                src = item.get("input") or item.get("noisy") or item.get("src") or item.get("bad")
                tgt = item.get("target") or item.get("clean") or item.get("tgt") or item.get("good")
            elif isinstance(item, list) and len(item) == 2:
                src, tgt = item
            else:
                # Scalars / odd-sized arrays used to crash on ``.get``.
                continue
            add_pair(rows, src, tgt)
    return dedupe_pairs(rows)
_SOURCE_KEYS = ("input", "noisy", "src", "bad")
_TARGET_KEYS = ("target", "clean", "tgt", "good")


def _first_field(item, keys):
    """Return the first truthy ``item.get(key)`` over *keys*, else the last lookup's value.

    Mirrors ``item.get(k1) or item.get(k2) or ...`` exactly.
    """
    value = None
    for key in keys:
        value = item.get(key)
        if value:
            return value
    return value


def read_pairs_file(path: Path) -> list[tuple[str, str]]:
    """Load (noisy, clean) pairs from a ``.jsonl``/``.ndjson``, ``.json``, ``.csv`` or ``.tsv`` file.

    Records are mappings. The source is the first truthy value of ``input``,
    ``noisy``, ``src`` or ``bad``, and the target is the first truthy value
    of ``target``, ``clean``, ``tgt`` or ``good``. A ``.json`` file must hold
    a top-level list, which may also contain ``[noisy, clean]`` two-element
    lists; any other top-level value yields no pairs. Pairs are filtered by
    ``add_pair`` and de-duplicated.

    Raises:
        FileNotFoundError: if *path* does not exist.
        ValueError: for an unsupported suffix (malformed JSON raises
            ``json.JSONDecodeError``, a ``ValueError`` subclass).
    """
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(path)
    suffix = path.suffix.lower()
    if suffix in {".jsonl", ".ndjson"}:
        return read_pairs_jsonl(path)
    if suffix == ".json":
        # utf-8-sig strips a leading BOM, which json.loads would otherwise
        # reject. BOM-less files decode identically.
        raw = json.loads(path.read_text(encoding="utf-8-sig"))
        rows: list[tuple[str, str]] = []
        for item in raw if isinstance(raw, list) else []:
            if isinstance(item, dict):
                add_pair(rows, _first_field(item, _SOURCE_KEYS), _first_field(item, _TARGET_KEYS))
            elif isinstance(item, (list, tuple)) and len(item) == 2:
                add_pair(rows, item[0], item[1])
        return dedupe_pairs(rows)
    if suffix in {".csv", ".tsv"}:
        rows = []
        delimiter = "\t" if suffix == ".tsv" else ","
        # With a BOM and plain utf-8, the first header would read
        # "\ufeffinput", so every lookup of that column would miss and all
        # rows would be silently dropped.
        with path.open("r", encoding="utf-8-sig", newline="") as fh:
            for item in csv.DictReader(fh, delimiter=delimiter):
                add_pair(rows, _first_field(item, _SOURCE_KEYS), _first_field(item, _TARGET_KEYS))
        return dedupe_pairs(rows)
    raise ValueError(f"Unsupported data file: {path}")
def keyboard_typo(word: str) -> str:
    """Replace one randomly chosen character of *word* with a keyboard neighbour.

    Words shorter than two characters, and positions whose lower-cased
    character has no ``KEYBOARD_ADJ`` entry, are returned unchanged. An
    upper-case original yields an upper-cased replacement.
    """
    if len(word) < 2:
        return word
    pos = random.randrange(len(word))
    original = word[pos]
    key = original.lower()
    if key not in KEYBOARD_ADJ:
        return word
    neighbour = random.choice(KEYBOARD_ADJ[key])
    neighbour = neighbour.upper() if original.isupper() else neighbour
    return f"{word[:pos]}{neighbour}{word[pos + 1:]}"
def corrupt_word(word: str) -> str:
    """Return a randomly misspelled variant of *word* (possibly unchanged).

    Words of two characters or fewer are left alone. A word listed in
    ``COMMON_TYPOS`` is replaced by one of its known typos 65% of the time,
    keeping an initial capital. Otherwise one mode is picked uniformly:
      * swap    - transpose two adjacent characters (needs len > 3)
      * drop    - delete one interior character (needs len > 4)
      * double  - repeat one character
      * keyboard - see ``keyboard_typo``
      * dyslexia - swap each character via ``LETTER_SWAPS`` with p=0.55
    When the chosen mode's length requirement is not met, the word is
    returned unchanged.
    """
    size = len(word)
    if size <= 2:
        return word
    folded = word.lower()
    if folded in COMMON_TYPOS and random.random() < 0.65:
        typo = random.choice(COMMON_TYPOS[folded])
        if word[0].isupper():
            return typo.capitalize()
        return typo
    mode = random.choice(["swap", "drop", "double", "keyboard", "dyslexia"])
    if mode == "swap":
        if size <= 3:
            return word
        j = random.randint(0, size - 2)
        return word[:j] + word[j + 1] + word[j] + word[j + 2:]
    if mode == "drop":
        if size <= 4:
            return word
        # Never drop the first or last letter.
        j = random.randint(1, size - 2)
        return word[:j] + word[j + 1:]
    if mode == "double":
        j = random.randrange(size)
        return word[:j] + word[j] + word[j:]
    if mode == "keyboard":
        return keyboard_typo(word)
    if mode == "dyslexia":
        swapped = []
        for letter in word:
            alt = LETTER_SWAPS.get(letter.lower())
            # The RNG is only consulted when a swap exists.
            if alt and random.random() < 0.55:
                letter = alt.upper() if letter.isupper() else alt
            swapped.append(letter)
        return "".join(swapped)
    return word
def maybe_homophone(token: str) -> str:
    """Swap *token* for a random homophone from ``HOMOPHONES``, if any is listed.

    The lookup key is the lower-cased token with everything except a-z and
    apostrophes removed. An initial capital is carried over to the
    replacement; tokens with no entry are returned unchanged.
    """
    key = re.sub(r"[^a-z']", "", token.lower())
    options = [alt for src, alt in HOMOPHONES if src == key]
    if not options:
        return token
    pick = random.choice(options)
    if token[:1].isupper():
        return pick.capitalize()
    return pick
def _peel_punctuation(token: str) -> tuple[str, str, str]:
    """Split *token* into (leading non-alnum, alnum-bounded core, trailing non-alnum)."""
    start, end = 0, len(token)
    while start < end and not token[start].isalnum():
        start += 1
    while end > start and not token[end - 1].isalnum():
        end -= 1
    return token[:start], token[start:end], token[end:]


def corrupt_sentence(sentence: str, intensity: float = 0.35) -> str:
    """Produce a noisy, human-like rendition of *sentence* for training.

    Per whitespace token, the core word (surrounding punctuation peeled off)
    may independently:
      * become a homophone (p=0.10);
      * be misspelled via ``corrupt_word`` (p=*intensity*);
      * have its case flattened (p=0.08; lower 75% / upper 25%).
    Separately, each token's trailing commas and periods are dropped with
    p=0.22.

    Whole-text passes follow, in order:
      * lower-case everything (p=0.30);
      * strip final ``.!?`` (p=0.10);
      * merge the first sentence boundary (p=0.12);
      * glue the first two words together (p=0.08).
    The result is normalised.
    """
    tokens = []
    for raw in sentence.split():
        prefix, word, suffix = _peel_punctuation(raw)
        if word:
            if random.random() < 0.10:
                word = maybe_homophone(word)
            if random.random() < intensity:
                word = corrupt_word(word)
            if random.random() < 0.08:
                word = word.lower() if random.random() < 0.75 else word.upper()
        if random.random() < 0.22:
            suffix = suffix.replace(",", "").replace(".", "")
        tokens.append(f"{prefix}{word}{suffix}")
    text = " ".join(tokens)
    # Applied in this exact order; each pass draws one random number.
    finishing_passes = (
        (0.30, str.lower),
        (0.10, lambda s: s.rstrip(".!?")),
        (0.12, lambda s: re.sub(r"([.!?])\s+", " ", s, count=1)),
        (0.08, lambda s: re.sub(r"\s+", " ", s.replace(" ", "", 1))),
    )
    for chance, transform in finishing_passes:
        if random.random() < chance:
            text = transform(text)
    return normalise_text(text)
def sentence_split(text: str) -> list[str]:
    """Split normalised *text* after each ``.``, ``!`` or ``?`` that is followed by whitespace.

    The terminating punctuation stays with its sentence; empty pieces are
    discarded.
    """
    sentences = []
    for chunk in re.split(r"(?<=[.!?])\s+", normalise_text(text)):
        chunk = chunk.strip()
        if chunk:
            sentences.append(chunk)
    return sentences
def synthetic_pairs(clean_texts: Iterable[str], max_pairs: int) -> list[tuple[str, str]]:
    """Build up to *max_pairs* unique synthetic (noisy, clean) pairs from clean text.

    Texts are normalised, filtered to 11-899 characters and shuffled. For each
    text, the following candidates are generated:
      * ``corrupt_sentence`` outputs at intensities 0.18, 0.35 and 0.55;
      * a punctuation-stripped copy (if the text has ``.,!?;:``);
      * a lower-cased copy (if it has upper case);
      * a copy with sentence ends merged (if it has 2+ sentences);
      * an identity pair with p=0.30, teaching the model to leave good text
        alone.
    Every candidate passes through ``add_pair`` validation.

    Duplicates are dropped as they are produced. The budget therefore counts
    unique pairs, and the final dedupe cannot shrink the result below
    *max_pairs* while source text is still available. Returns ``[]`` when
    *max_pairs* <= 0.
    """
    if max_pairs <= 0:
        return []
    rows: list[tuple[str, str]] = []
    seen: set[tuple[str, str]] = set()

    def emit(noisy: str, clean: str) -> None:
        # Same key as dedupe_pairs: rows produced by add_pair are already
        # normalised, so lower-casing them here matches its de-duplication.
        before = len(rows)
        add_pair(rows, noisy, clean)
        if len(rows) > before:
            key = (rows[-1][0].lower(), rows[-1][1].lower())
            if key in seen:
                rows.pop()
            else:
                seen.add(key)

    texts = []
    for raw in clean_texts:
        text = normalise_text(raw)  # normalise once, not twice
        if 10 < len(text) < 900:
            texts.append(text)
    random.shuffle(texts)

    for text in texts:
        for intensity in (0.18, 0.35, 0.55):
            emit(corrupt_sentence(text, intensity), text)
            if len(rows) >= max_pairs:
                return dedupe_pairs(rows, max_pairs)
        # Punctuation and capitalization restoration.
        if any(c in text for c in ".,!?;:"):
            emit(re.sub(r"[.,!?;:]", "", text), text)
        if any(c.isupper() for c in text):
            emit(text.lower(), text)
        # Paragraph flow: strip the punctuation between sentences so the model
        # learns to restore the boundaries.
        sents = sentence_split(text)
        if len(sents) >= 2:
            emit(" ".join(s.rstrip(".!?") for s in sents), text)
        # Identity preservation: do not rewrite good text unnecessarily.
        if random.random() < 0.30:
            emit(text, text)
        if len(rows) >= max_pairs:
            break
    return dedupe_pairs(rows, max_pairs)
def try_load_dataset(log_fn: Callable[[str], None], *args, **kwargs):
    """Call ``datasets.load_dataset(*args, **kwargs)``, returning ``None`` on failure.

    The ``datasets`` package must be available; ``require_package`` is
    called outside the ``try``, so its failure is not swallowed. Any error
    raised while loading is reported through *log_fn* and the dataset is
    skipped. This is deliberately best-effort, because optional corpora may
    be unavailable offline.
    """
    hf_datasets = require_package("datasets")
    try:
        dataset = hf_datasets.load_dataset(*args, **kwargs)
    except Exception as err:
        log_fn(f" skipped {args}: {err}")
        return None
    else:
        return dataset
def _fill_rows(
    rows: list[tuple[str, str]],
    ds: Iterable[dict],
    target: int,
    extract: Callable[[dict], Iterable[tuple[object, object]]],
) -> None:
    """Append pairs yielded by ``extract(item)`` for each item of *ds* until ``len(rows) >= target``.

    Every candidate goes through ``add_pair``, so invalid pairs are dropped
    and do not count towards *target*. Returning from inside the nested loop
    stops the whole scan. A dataset that yields several pairs per item
    therefore cannot overshoot its quota by moving on to the next item.
    """
    if len(rows) >= target:
        return
    for item in ds:
        for src, tgt in extract(item):
            add_pair(rows, src, tgt)
            if len(rows) >= target:
                return


def _jfleg_pairs(item: dict) -> Iterable[tuple[object, object]]:
    """JFLEG: pair the source sentence with each of its corrections."""
    src = item.get("sentence", "")
    for correction in item.get("corrections", []) or []:
        yield src, correction


def _coedit_pairs(item: dict) -> Iterable[tuple[object, object]]:
    """CoEdIT: one source/target pair per item, trying several field names."""
    yield (
        item.get("src") or item.get("input") or item.get("source") or "",
        item.get("tgt") or item.get("target") or item.get("output") or "",
    )


def _wi_locness_pairs(item: dict) -> Iterable[tuple[object, object]]:
    """W&I/LOCNESS: pair each edit's original text with its first correction."""
    for edit in item.get("edits", []) or []:
        # NOTE(review): expects each edit to be a mapping with "orig"/"cor"
        # keys. Non-mapping entries are skipped instead of crashing on
        # ``.get``. Confirm these field names against the dataset schema.
        if not isinstance(edit, dict):
            continue
        corrections = edit.get("cor") or []
        if corrections:
            yield edit.get("orig") or "", corrections[0]


def _asset_pairs(item: dict) -> Iterable[tuple[object, object]]:
    """ASSET: pair the original sentence with each simplification."""
    src = item.get("original") or ""
    for tgt in item.get("simplifications", []) or []:
        yield src, tgt


def _wiki_split_pairs(item: dict) -> Iterable[tuple[object, object]]:
    """WikiSplit: complex sentence -> its (up to two) simple sentences joined."""
    s1 = item.get("simple_sentence_1") or ""
    s2 = item.get("simple_sentence_2") or ""
    tgt = normalise_text(f"{s1} {s2}") if s2 else s1
    yield item.get("complex_sentence") or "", tgt


def _mrpc_pairs(item: dict) -> Iterable[tuple[object, object]]:
    """MRPC: both directions of every positively labelled paraphrase pair."""
    if int(item.get("label", 0)) == 1:
        s1 = item.get("sentence1") or ""
        s2 = item.get("sentence2") or ""
        yield s1, s2
        yield s2, s1


def _c4_pairs(item: dict) -> Iterable[tuple[object, object]]:
    """C4-200M GEC: one input/output pair per streamed item."""
    yield (
        item.get("input") or item.get("src") or "",
        item.get("output") or item.get("tgt") or "",
    )


def load_builtin_pairs(max_pairs: int, include_c4: bool, log_fn: Callable[[str], None]) -> list[tuple[str, str]]:
    """Assemble up to *max_pairs* training pairs from seeds, public datasets and synthetic noise.

    Sources are used in this order:
      1. Corrupted, lower-cased and identity variants of
         ``SEED_CLEAN_SENTENCES``.
      2. JFLEG, CoEdIT, W&I/LOCNESS, ASSET, WikiSplit, MRPC paraphrases and,
         if *include_c4*, a C4-200M stream. Each is capped at its own
         fraction of *max_pairs* (at least 500) and never more than the
         remaining budget.
      3. ``synthetic_pairs`` built from the clean side of the deduplicated
         pairs plus the seeds, requesting at least 1000 pairs.

    Datasets that fail to load are logged and skipped by
    ``try_load_dataset``. Once the budget is full, further datasets are
    not downloaded at all. The combined rows are shuffled, de-duplicated
    and trimmed to *max_pairs*.

    Args:
        max_pairs: overall pair budget.
        include_c4: whether to stream the optional C4-200M corpus.
        log_fn: receives progress and skip messages.
    """
    rows: list[tuple[str, str]] = []

    def remaining() -> int:
        return max(0, max_pairs - len(rows))

    def quota(frac: float, floor: int = 500) -> int:
        return min(max(floor, int(max_pairs * frac)), remaining())

    def harvest(frac: float, extract, *dataset_args, **dataset_kwargs) -> None:
        # Check the budget before downloading anything.
        if remaining() <= 0:
            log_fn(f" skipped {dataset_args}: pair budget already reached")
            return
        ds = try_load_dataset(log_fn, *dataset_args, **dataset_kwargs)
        if ds is not None:
            _fill_rows(rows, ds, len(rows) + quota(frac), extract)

    # Seed examples make smoke/offline setup useful and strengthen the exact
    # everyday writing style this app is for.
    for clean in SEED_CLEAN_SENTENCES:
        add_pair(rows, corrupt_sentence(clean, 0.45), clean)
        add_pair(rows, clean.lower(), clean)
        add_pair(rows, clean, clean)

    log_fn("Loading JFLEG grammar correction...")
    for split in ("validation", "test"):
        harvest(0.08, _jfleg_pairs, "jfleg", split=split)

    log_fn("Loading Grammarly CoEdIT correction/rewrite tasks...")
    harvest(0.32, _coedit_pairs, "grammarly/coedit", split="train")

    log_fn("Loading W&I/LOCNESS learner-English correction if available...")
    harvest(0.12, _wi_locness_pairs, "wi_locness", "wi", split="train")

    log_fn("Loading ASSET simplification/rewrite examples...")
    harvest(0.08, _asset_pairs, "asset", "simplification", split="validation")

    log_fn("Loading WikiSplit sentence-flow examples...")
    harvest(0.08, _wiki_split_pairs, "wiki_split", split="train")

    log_fn("Loading MRPC paraphrase pairs...")
    harvest(0.06, _mrpc_pairs, "glue", "mrpc", split="train")

    if include_c4:
        log_fn("Loading optional C4-200M GEC stream...")
        harvest(0.15, _c4_pairs, "liweili/c4_200m", split="train", streaming=True)

    base = dedupe_pairs(rows)
    clean_pool = [clean for _, clean in base]
    log_fn("Generating synthetic typo, dyslexia-like, punctuation, and preservation pairs...")
    synth_target = max(1000, max_pairs - len(base))
    # Unpacking works whether SEED_CLEAN_SENTENCES is a list or a tuple.
    combined = base + synthetic_pairs([*clean_pool, *SEED_CLEAN_SENTENCES], synth_target)
    random.shuffle(combined)
    return dedupe_pairs(combined, max_pairs)
def save_pairs(rows: list[tuple[str, str]], path: Path) -> None:
    """Write *rows* to *path* as JSON Lines (``{"input": ..., "target": ...}`` per line).

    Parent directories are created as needed. Data is first written to a
    sibling ``<name>.tmp`` file and then moved over *path* with
    ``Path.replace``, so readers never see a half-written file. The temp
    name appends to the full file name; ``with_suffix(".tmp")`` would map
    ``pairs.jsonl`` onto an unrelated ``pairs.tmp`` and clobber it. If
    writing fails, the temp file is removed and the error is re-raised.
    Non-ASCII text is written as-is (UTF-8), not ``\\u``-escaped.
    """
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp = path.with_name(path.name + ".tmp")
    try:
        with tmp.open("w", encoding="utf-8") as fh:
            fh.writelines(
                json.dumps({"input": noisy, "target": clean}, ensure_ascii=False) + "\n"
                for noisy, clean in rows
            )
        tmp.replace(path)
    except BaseException:
        tmp.unlink(missing_ok=True)
        raise
def load_prepared_pairs(path: Path = PAIRS_PATH) -> list[tuple[str, str]]:
    """Return the pairs stored in the prepared JSONL file, or ``[]`` if it does not exist."""
    return read_pairs_jsonl(path) if path.exists() else []
def train_tokenizer(rows: list[tuple[str, str]], vocab_size: int, path: Path = TOKENIZER_PATH):
    """Train a byte-level BPE tokenizer on both sides of *rows*, save it to *path*, and return it.

    The trainer is seeded with the full 256-symbol byte-level alphabet.
    Without it, only byte symbols that actually occur in the training text
    enter the vocabulary, so unseen characters (other scripts, emoji) would
    later encode as ``[UNK]`` or be lost. Merges need a minimum frequency of
    2. ``SPECIAL_TOKENS`` are reserved; presumably that list includes
    ``"[UNK]"`` and the BOS/EOS/PAD tokens, but confirm in ``constants``.
    Texts are streamed to the trainer rather than copied into a list first.
    """
    require_package("tokenizers")
    from tokenizers import Tokenizer
    from tokenizers.decoders import ByteLevel as ByteLevelDecoder
    from tokenizers.models import BPE
    from tokenizers.pre_tokenizers import ByteLevel
    from tokenizers.trainers import BpeTrainer

    tok = Tokenizer(BPE(unk_token="[UNK]"))
    tok.pre_tokenizer = ByteLevel(add_prefix_space=False)
    tok.decoder = ByteLevelDecoder()
    trainer = BpeTrainer(
        vocab_size=vocab_size,
        min_frequency=2,
        special_tokens=SPECIAL_TOKENS,
        initial_alphabet=ByteLevel.alphabet(),
        show_progress=True,
    )
    texts = (text for noisy, clean in rows for text in (noisy, clean))
    # ``length`` only feeds the progress bar.
    tok.train_from_iterator(texts, trainer=trainer, length=2 * len(rows))
    path.parent.mkdir(parents=True, exist_ok=True)
    tok.save(str(path))
    return tok
def load_tokenizer(path: Path = TOKENIZER_PATH):
    """Load a saved ``tokenizers.Tokenizer`` from *path*.

    If the saved file has no decoder configured, a byte-level decoder is
    attached so ``decode`` turns byte-level symbols back into text.
    """
    require_package("tokenizers")
    from tokenizers import Tokenizer
    from tokenizers.decoders import ByteLevel as ByteLevelDecoder

    tokenizer = Tokenizer.from_file(str(path))
    if tokenizer.decoder is not None:
        return tokenizer
    tokenizer.decoder = ByteLevelDecoder()
    return tokenizer
def encode_pair(tok, noisy: str, clean: str, max_len: int) -> tuple[list[int], list[int], list[int]]:
    """Tokenise one (noisy, clean) pair into encoder input and teacher-forcing targets.

    Returns ``(src, tgt_in, tgt_out)``:
      * ``src``     = ``[BOS] + noisy_ids + [EOS]``
      * ``tgt_in``  = ``[BOS] + clean_ids`` (decoder input)
      * ``tgt_out`` = ``clean_ids + [EOS]`` (decoder input shifted by one)
    Both sides are normalised and their content truncated to
    ``max_len - 2`` tokens, so every sequence fits in *max_len* when
    ``max_len >= 2``. For smaller *max_len*, the content is empty and only
    the BOS/EOS markers remain.
    ``tok`` is assumed to be a ``tokenizers.Tokenizer``-like object whose
    ``encode(text).ids`` is a list of ints; confirm with callers.
    """
    # Clamp at 0: a negative stop such as ids[:-1] would keep almost every
    # token instead of truncating.
    budget = max(0, max_len - 2)
    src_tokens = tok.encode(normalise_text(noisy)).ids[:budget]
    tgt_tokens = tok.encode(normalise_text(clean)).ids[:budget]
    src = [BOS_ID] + src_tokens + [EOS_ID]
    tgt_in = [BOS_ID] + tgt_tokens
    tgt_out = tgt_tokens + [EOS_ID]
    return src, tgt_in, tgt_out
def pad_batch(seqs: list[list[int]], pad_value: int) -> torch.Tensor:
    """Right-pad integer sequences into a ``(len(seqs), longest)`` ``torch.long`` tensor.

    Positions past each sequence's end hold *pad_value*. An empty *seqs*
    returns an empty ``(0, 0)`` tensor; previously it raised ``ValueError``
    from ``max()`` of an empty sequence.
    """
    if not seqs:
        return torch.full((0, 0), pad_value, dtype=torch.long)
    longest = max(len(s) for s in seqs)
    out = torch.full((len(seqs), longest), pad_value, dtype=torch.long)
    for i, seq in enumerate(seqs):
        out[i, : len(seq)] = torch.tensor(seq, dtype=torch.long)
    return out
def collate_pairs(rows: list[tuple[str, str]], tok, cfg: TrainConfig) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Encode a batch of (noisy, clean) pairs and pad it into model-ready tensors.

    Returns ``(src, tgt_in, tgt_out)``. The first two are padded with
    ``PAD_ID``. ``tgt_out`` is padded with ``-100``, presumably the loss's
    ``ignore_index`` (PyTorch ``CrossEntropyLoss`` uses -100 by default);
    confirm in the training loop.
    """
    encoded = [encode_pair(tok, noisy, clean, cfg.max_len) for noisy, clean in rows]
    src_ids = [enc[0] for enc in encoded]
    tgt_in_ids = [enc[1] for enc in encoded]
    tgt_out_ids = [enc[2] for enc in encoded]
    return (
        pad_batch(src_ids, PAD_ID),
        pad_batch(tgt_in_ids, PAD_ID),
        pad_batch(tgt_out_ids, -100),
    )
def make_batches(rows: list[tuple[str, str]], batch_size: int, shuffle_batches: bool) -> list[list[int]]:
    """Group row indices into length-bucketed batches of at most *batch_size*.

    Indices are ordered by combined character length of the two sides, so
    each batch holds similarly sized pairs and padding stays low. With
    *shuffle_batches*:
      * indices are shuffled before the (stable) sort, so equal-length rows
        can end up in different batches from one call to the next;
      * the order in which batches are visited is fully shuffled.

    Batch membership, and therefore padding, is fixed by the sort alone.
    The original per-chunk shuffle was immediately overridden by the full
    shuffle and has been dropped.
    """
    order = list(range(len(rows)))
    if shuffle_batches:
        # Only the relative order of equal-length rows changes here.
        random.shuffle(order)
    order.sort(key=lambda i: len(rows[i][0]) + len(rows[i][1]))
    batches = [order[i : i + batch_size] for i in range(0, len(order), batch_size)]
    if shuffle_batches:
        random.shuffle(batches)
    return batches
def split_train_val(rows: list[tuple[str, str]], val_split: float, seed: int):
    """Deterministically shuffle *rows* with *seed* and split off a validation set.

    Returns ``(train, val)``, where ``val`` holds the first
    ``max(1, int(len(rows) * val_split))`` shuffled rows. At least one
    validation row is taken whenever *rows* is non-empty. The caller's list
    is left untouched; it used to be shuffled in place as a side effect.
    The resulting split is the same as before for the same seed.
    """
    shuffled = list(rows)
    random.Random(seed).shuffle(shuffled)
    val_n = max(1, int(len(shuffled) * val_split))
    return shuffled[val_n:], shuffled[:val_n]