Feature Extraction
Transformers
Safetensors
PyTorch
English
eden
text-enhancement
grammar-correction
text-rewriting
encoder-decoder
transformer
custom_code
Instructions to use Rybib/EDEN with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Rybib/EDEN with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="Rybib/EDEN", trust_remote_code=True)

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("Rybib/EDEN", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| """Text cleaning, synthetic corruption, dataset loading, and tokenization.""" | |
| from __future__ import annotations | |
| import csv | |
| import json | |
| import random | |
| import re | |
| from pathlib import Path | |
| from typing import Callable, Iterable | |
| import torch | |
| from .config import TrainConfig | |
| from .constants import * | |
| from .io_utils import * | |
| from .runtime import require_package | |
_CURLY_QUOTE_MAP = str.maketrans({
    "\u2018": "'",
    "\u2019": "'",
    "\u201c": '"',
    "\u201d": '"',
})
_WHITESPACE_RUN = re.compile(r"\s+")


def normalise_text(text: str) -> str:
    """Return *text* as a single trimmed line with straight quotes.

    Falsy input (``None``, ``""``, ``0``) becomes ``""``. Curly single and
    double quotes are mapped to their ASCII forms, every run of whitespace
    collapses to one space, and leading/trailing whitespace is removed.
    """
    as_str = str(text or "")
    straightened = as_str.translate(_CURLY_QUOTE_MAP)
    return _WHITESPACE_RUN.sub(" ", straightened).strip()
_INSTRUCTION_PREFIXES = (
    "fix grammar:", "fix the grammar:", "correct grammar:",
    "correct spelling:", "fix spelling:", "rewrite:", "rewrite this:",
    "paraphrase:", "improve:", "improve this:", "make this sound better:",
    "clarify:", "punctuate:", "capitalize:",
)


def strip_instruction(text: str) -> str:
    """Normalise *text* and drop a leading task instruction such as ``"fix grammar:"``.

    Prefixes are matched case-insensitively, first match wins, and at most
    one prefix is removed. After removal, spaces, hyphens and colons are
    trimmed from both ends of the remainder. Text without a known prefix is
    returned normalised but otherwise unchanged.
    """
    cleaned = normalise_text(text)
    folded = cleaned.lower()
    matched = next((p for p in _INSTRUCTION_PREFIXES if folded.startswith(p)), None)
    if matched is None:
        return cleaned
    return cleaned[len(matched):].strip(" -:")
def valid_pair(noisy: str, clean: str) -> bool:
    """Decide whether a (noisy, clean) pair is usable for training.

    Both sides are normalised first. A pair is rejected when either side is
    empty, when the sides are identical and shorter than 15 characters
    (trivial identity examples), or when either side is shorter than 4 or
    longer than 1400 characters.
    """
    src = normalise_text(noisy)
    tgt = normalise_text(clean)
    if not (src and tgt):
        return False
    if src == tgt and len(tgt) < 15:
        return False
    shortest = min(len(src), len(tgt))
    longest = max(len(src), len(tgt))
    return shortest >= 4 and longest <= 1400
def add_pair(rows: list[tuple[str, str]], noisy: str, clean: str) -> None:
    """Strip instruction prefixes from both texts and append the pair to *rows*.

    The pair is appended, in place, only if it passes ``valid_pair``;
    otherwise *rows* is left unchanged.
    """
    pair = (strip_instruction(noisy), strip_instruction(clean))
    if not valid_pair(*pair):
        return
    rows.append(pair)
def dedupe_pairs(rows: Iterable[tuple[str, str]], limit: int | None = None) -> list[tuple[str, str]]:
    """Normalise, validate and de-duplicate (noisy, clean) pairs, keeping order.

    Args:
        rows: pairs in priority order; the first occurrence of a pair wins.
        limit: stop once this many unique pairs have been collected. ``None``
            (and, as before, ``0``) means no limit.

    Returns:
        The surviving pairs, each side passed through ``normalise_text``.

    Pairs are compared by their exact normalised text. Case is deliberately
    part of the key: a capitalisation-restoration pair such as
    ``("the cat sat.", "The cat sat.")`` and the identity pair
    ``("The cat sat.", "The cat sat.")`` differ only in the case of the noisy
    side, and both carry distinct training signal.
    """
    seen: set[tuple[str, str]] = set()
    out: list[tuple[str, str]] = []
    for noisy, clean in rows:
        noisy = normalise_text(noisy)
        clean = normalise_text(clean)
        if not valid_pair(noisy, clean):
            continue
        key = (noisy, clean)
        if key in seen:
            continue
        seen.add(key)
        out.append(key)
        if limit and len(out) >= limit:
            break
    return out
# Column / key names accepted for each side of a pair, in priority order.
_SOURCE_KEYS = ("input", "noisy", "src", "bad")
_TARGET_KEYS = ("target", "clean", "tgt", "good")


def _first_field(record: dict, keys: tuple[str, ...]):
    """Return the first truthy value among *keys* in *record*, else ``None``."""
    for key in keys:
        value = record.get(key)
        if value:
            return value
    return None


def _add_record(rows: list[tuple[str, str]], record) -> None:
    """Append one parsed record to *rows* via ``add_pair``.

    Accepts a mapping using any of the known source/target key names, or a
    2-item list/tuple ``[noisy, clean]``. Any other shape is ignored.
    """
    if isinstance(record, dict):
        add_pair(rows, _first_field(record, _SOURCE_KEYS), _first_field(record, _TARGET_KEYS))
    elif isinstance(record, (list, tuple)) and len(record) == 2:
        add_pair(rows, record[0], record[1])


def read_pairs_jsonl(path: Path) -> list[tuple[str, str]]:
    """Read (noisy, clean) pairs from a JSON Lines file.

    Each non-blank line must be a JSON object (keys as in ``_SOURCE_KEYS`` /
    ``_TARGET_KEYS``) or a ``[noisy, clean]`` array. Invalid pairs are
    dropped and the result is de-duplicated.

    Raises:
        json.JSONDecodeError: a line is not valid JSON.
    """
    rows: list[tuple[str, str]] = []
    # utf-8-sig transparently drops a leading byte-order mark (common in
    # files saved on Windows); json.loads rejects a BOM-prefixed line.
    with Path(path).open("r", encoding="utf-8-sig") as fh:
        for line in fh:
            line = line.strip()
            if not line:
                continue
            _add_record(rows, json.loads(line))
    return dedupe_pairs(rows)


def read_pairs_file(path: Path) -> list[tuple[str, str]]:
    """Load (noisy, clean) pairs from a ``.jsonl``/``.ndjson``, ``.json``, ``.csv`` or ``.tsv`` file.

    ``.json`` must hold a top-level list of objects or 2-item arrays; any
    other top-level value yields no pairs. CSV/TSV files need a header row
    naming the source/target columns.

    Raises:
        FileNotFoundError: *path* does not exist.
        ValueError: the file extension is not supported.
    """
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(path)
    suffix = path.suffix.lower()
    if suffix in {".jsonl", ".ndjson"}:
        return read_pairs_jsonl(path)
    rows: list[tuple[str, str]] = []
    if suffix == ".json":
        raw = json.loads(path.read_text(encoding="utf-8-sig"))
        for item in raw if isinstance(raw, list) else []:
            _add_record(rows, item)
        return dedupe_pairs(rows)
    if suffix in {".csv", ".tsv"}:
        delimiter = "\t" if suffix == ".tsv" else ","
        # utf-8-sig: a BOM would otherwise be glued onto the first header
        # name (e.g. "\ufeffinput") and that column would never match.
        with path.open("r", encoding="utf-8-sig", newline="") as fh:
            for record in csv.DictReader(fh, delimiter=delimiter):
                _add_record(rows, record)
        return dedupe_pairs(rows)
    raise ValueError(f"Unsupported data file: {path}")
def keyboard_typo(word: str) -> str:
    """Replace one randomly chosen letter of *word* with a neighbouring key.

    Words shorter than two characters are returned as-is, as are words whose
    chosen letter has no entry in ``KEYBOARD_ADJ``. If the replaced letter
    was uppercase, its replacement is uppercased too.
    """
    if len(word) < 2:
        return word
    pos = random.randrange(len(word))
    original = word[pos]
    letter = original.lower()
    if letter not in KEYBOARD_ADJ:
        return word
    neighbour = random.choice(KEYBOARD_ADJ[letter])
    neighbour = neighbour.upper() if original.isupper() else neighbour
    return f"{word[:pos]}{neighbour}{word[pos + 1:]}"
_CORRUPTION_MODES = ("swap", "drop", "double", "keyboard", "dyslexia")


def corrupt_word(word: str) -> str:
    """Return a randomly misspelled variant of *word*.

    Words of two characters or fewer are never changed. Words listed in
    ``COMMON_TYPOS`` are replaced by a known misspelling 65% of the time
    (capitalised if the original started uppercase). Otherwise one mode is
    picked uniformly:

    * ``swap``     transpose two adjacent characters (words longer than 3);
    * ``drop``     delete one interior character (words longer than 4);
    * ``double``   repeat one character;
    * ``keyboard`` substitute a neighbouring key via ``keyboard_typo``;
    * ``dyslexia`` swap each letter found in ``LETTER_SWAPS`` with
      probability 0.55, preserving case.

    A mode whose length condition is not met leaves the word unchanged.
    """
    if len(word) <= 2:
        return word
    folded = word.lower()
    if folded in COMMON_TYPOS and random.random() < 0.65:
        typo = random.choice(COMMON_TYPOS[folded])
        return typo.capitalize() if word[0].isupper() else typo

    mode = random.choice(_CORRUPTION_MODES)
    size = len(word)
    if mode == "keyboard":
        return keyboard_typo(word)
    if mode == "double":
        k = random.randrange(size)
        return word[:k + 1] + word[k:]
    if mode == "swap":
        if size <= 3:
            return word
        k = random.randint(0, size - 2)
        return word[:k] + word[k + 1] + word[k] + word[k + 2:]
    if mode == "drop":
        if size <= 4:
            return word
        k = random.randint(1, size - 2)
        return word[:k] + word[k + 1:]
    # Remaining mode: "dyslexia".
    letters = []
    for ch in word:
        swap_to = LETTER_SWAPS.get(ch.lower())
        if swap_to and random.random() < 0.55:
            ch = swap_to.upper() if ch.isupper() else swap_to
        letters.append(ch)
    return "".join(letters)
def maybe_homophone(token: str) -> str:
    """Swap *token* for a homophone from ``HOMOPHONES`` when one is listed.

    The lookup key is the lowercased token with everything except ``a-z``
    and apostrophes removed. ``HOMOPHONES`` is scanned as ``(word,
    replacement)`` pairs; one matching replacement is chosen at random and
    capitalised if *token* starts with an uppercase letter. Tokens with no
    match are returned unchanged.
    """
    key = re.sub(r"[^a-z']", "", token.lower())
    candidates = [right for left, right in HOMOPHONES if left == key]
    if candidates:
        chosen = random.choice(candidates)
        return chosen.capitalize() if token[:1].isupper() else chosen
    return token
def corrupt_sentence(sentence: str, intensity: float = 0.35) -> str:
    """Return a randomly corrupted copy of *sentence* for synthetic training data.

    Each whitespace-separated token is split into leading non-alphanumeric
    characters, an alphanumeric core, and trailing non-alphanumeric
    characters. The core may be replaced by a homophone (probability 0.10),
    misspelled via ``corrupt_word`` (probability *intensity*), or recased
    (probability 0.08: lowercase 75% of the time, otherwise uppercase).
    Trailing ``,`` and ``.`` are dropped with probability 0.22.

    Sentence-level noise is then applied: lowercase everything (0.30),
    strip trailing ``.``/``!``/``?`` (0.10), remove the first ``.``/``!``/``?``
    that is followed by whitespace (0.12), and delete the first space so two
    tokens merge (0.08). The result goes through ``normalise_text``.

    Args:
        sentence: clean input text.
        intensity: per-token probability of calling ``corrupt_word``.

    NOTE(review): the order of ``random`` calls determines the output under a
    fixed seed; keep statement order when editing.
    """
    pieces = []
    for raw in sentence.split():
        # Peel punctuation/symbols off both ends so only the alphanumeric
        # core of the token is corrupted.
        prefix = ""
        suffix = ""
        word = raw
        while word and not word[0].isalnum():
            prefix += word[0]
            word = word[1:]
        while word and not word[-1].isalnum():
            suffix = word[-1] + suffix
            word = word[:-1]
        if word:
            if random.random() < 0.10:
                word = maybe_homophone(word)
            if random.random() < intensity:
                word = corrupt_word(word)
            if random.random() < 0.08:
                # Casing noise: mostly lowercase, occasionally all caps.
                word = word.lower() if random.random() < 0.75 else word.upper()
        if random.random() < 0.22:
            # Missing-punctuation noise on the token's trailing characters.
            suffix = suffix.replace(",", "").replace(".", "")
        pieces.append(prefix + word + suffix)
    text = " ".join(pieces)
    if random.random() < 0.30:
        text = text.lower()
    if random.random() < 0.10:
        text = text.rstrip(".!?")
    if random.random() < 0.12:
        # Run-on noise: drop the first sentence-ending mark that is followed
        # by whitespace (the mark + whitespace become a single space).
        text = re.sub(r"([.!?])\s+", " ", text, count=1)
    if random.random() < 0.08:
        # Remove the first space, fusing two adjacent tokens.
        text = re.sub(r"\s+", " ", text.replace(" ", "", 1))
    return normalise_text(text)
_SENTENCE_BOUNDARY = re.compile(r"(?<=[.!?])\s+")


def sentence_split(text: str) -> list[str]:
    """Split *text* into sentences at whitespace that follows ``.``, ``!`` or ``?``.

    The text is normalised first. Fragments are stripped, and empty ones
    are discarded.
    """
    sentences = []
    for fragment in _SENTENCE_BOUNDARY.split(normalise_text(text)):
        fragment = fragment.strip()
        if fragment:
            sentences.append(fragment)
    return sentences
def synthetic_pairs(clean_texts: Iterable[str], max_pairs: int) -> list[tuple[str, str]]:
    """Generate synthetic (noisy, clean) training pairs from clean sentences.

    Texts are normalised and kept when their length is strictly between 10
    and 900 characters, then shuffled. For each text the following are
    produced:

    * three ``corrupt_sentence`` variants at intensities 0.18 / 0.35 / 0.55;
    * a punctuation-stripped variant (if the text has ``.,!?;:``);
    * a lowercased variant (if the text has any uppercase letter);
    * a sentence-boundary-merged variant (if it has two or more sentences);
    * an identity pair with probability 0.30, so the model learns to leave
      good text alone.

    Args:
        clean_texts: source sentences/paragraphs.
        max_pairs: upper bound on the number of pairs returned.

    Returns:
        At most *max_pairs* de-duplicated pairs. NOTE: the stopping checks
        count rows *before* de-duplication, so the result can hold fewer
        than *max_pairs* pairs even when more clean text was available.
    """
    rows: list[tuple[str, str]] = []
    # Normalise each text once (the filter and the kept value both need it).
    normalised = (normalise_text(t) for t in clean_texts)
    clean = [t for t in normalised if 10 < len(t) < 900]
    random.shuffle(clean)
    for text in clean:
        for intensity in (0.18, 0.35, 0.55):
            add_pair(rows, corrupt_sentence(text, intensity), text)
            if len(rows) >= max_pairs:
                return dedupe_pairs(rows, max_pairs)
        # Punctuation and capitalization restoration.
        if any(c in text for c in ".,!?;:"):
            no_punct = re.sub(r"[.,!?;:]", "", text)
            add_pair(rows, no_punct, text)
        if any(c.isupper() for c in text):
            add_pair(rows, text.lower(), text)
        # Paragraph flow: merge sentence boundaries and ask the model to
        # restore a smoother version.
        sents = sentence_split(text)
        if len(sents) >= 2:
            merged = " ".join(s.rstrip(".!?") for s in sents)
            add_pair(rows, merged, text)
        # Identity preservation teaches the model not to rewrite good text
        # unnecessarily.
        if random.random() < 0.30:
            add_pair(rows, text, text)
        if len(rows) >= max_pairs:
            break
    return dedupe_pairs(rows, max_pairs)
def try_load_dataset(log_fn: Callable[[str], None], *args, **kwargs):
    """Best-effort wrapper around ``datasets.load_dataset``.

    All positional and keyword arguments are forwarded unchanged. On any
    exception from ``load_dataset``, a one-line "skipped" message is passed
    to *log_fn* and ``None`` is returned. ``require_package`` runs outside
    the ``try``, so whatever it raises still propagates.
    """
    datasets_mod = require_package("datasets")
    try:
        loaded = datasets_mod.load_dataset(*args, **kwargs)
    except Exception as err:  # deliberately broad: any failure means "skip this source"
        log_fn(f" skipped {args}: {err}")
        return None
    return loaded
def load_builtin_pairs(max_pairs: int, include_c4: bool, log_fn: Callable[[str], None]) -> list[tuple[str, str]]:
    """Assemble a mixed training corpus of (noisy, clean) pairs.

    Sources, in order:
      1. corrupted / lowercased / identity variants of SEED_CLEAN_SENTENCES;
      2. JFLEG (validation and test splits);
      3. Grammarly CoEdIT;
      4. W&I/LOCNESS edits;
      5. ASSET simplifications;
      6. WikiSplit;
      7. MRPC positive paraphrases, added in both directions;
      8. optionally, the streamed ``liweili/c4_200m`` GEC set.

    Each hub source contributes up to a quota of
    ``min(max(500, int(max_pairs * frac)), remaining budget)`` raw rows,
    checked after each dataset item. Sources that fail to load are logged
    and skipped.

    The de-duplicated result then seeds ``synthetic_pairs`` over its clean
    side plus the seed sentences. The shuffled union is de-duplicated and
    truncated to *max_pairs*.

    Args:
        max_pairs: overall target size of the corpus.
        include_c4: also stream pairs from ``liweili/c4_200m``.
        log_fn: progress logger.
    """
    rows: list[tuple[str, str]] = []

    def remaining() -> int:
        return max(0, max_pairs - len(rows))

    def quota(frac: float, floor: int = 500) -> int:
        return min(max(floor, int(max_pairs * frac)), remaining())

    # Seed examples make smoke/offline setup useful and strengthen the exact
    # everyday writing style this app is for.
    for clean in SEED_CLEAN_SENTENCES:
        add_pair(rows, corrupt_sentence(clean, 0.45), clean)
        add_pair(rows, clean.lower(), clean)
        add_pair(rows, clean, clean)

    log_fn("Loading JFLEG grammar correction...")
    for split in ("validation", "test"):
        ds = try_load_dataset(log_fn, "jfleg", split=split)
        if ds is None:
            continue
        target = len(rows) + quota(0.08)
        for item in ds:
            src = item.get("sentence", "")
            for correction in item.get("corrections", []) or []:
                add_pair(rows, src, correction)
            if len(rows) >= target:
                break

    log_fn("Loading Grammarly CoEdIT correction/rewrite tasks...")
    ds = try_load_dataset(log_fn, "grammarly/coedit", split="train")
    if ds is not None:
        target = len(rows) + quota(0.32)
        for item in ds:
            src = item.get("src") or item.get("input") or item.get("source") or ""
            tgt = item.get("tgt") or item.get("target") or item.get("output") or ""
            add_pair(rows, src, tgt)
            if len(rows) >= target:
                break

    log_fn("Loading W&I/LOCNESS learner-English correction if available...")
    ds = try_load_dataset(log_fn, "wi_locness", "wi", split="train")
    if ds is not None:
        target = len(rows) + quota(0.12)
        for item in ds:
            for edit in item.get("edits", []) or []:
                # HF `datasets` returns a Sequence-of-struct feature as a dict
                # of lists, so iterating it yields key strings rather than
                # per-edit dicts. Skip anything that is not an {"orig", "cor"}
                # mapping instead of crashing on `.get`.
                # NOTE(review): confirm the actual edit schema of this source.
                if not isinstance(edit, dict):
                    continue
                orig = edit.get("orig") or ""
                corrections = edit.get("cor") or []
                if corrections:
                    add_pair(rows, orig, corrections[0])
            if len(rows) >= target:
                break

    log_fn("Loading ASSET simplification/rewrite examples...")
    ds = try_load_dataset(log_fn, "asset", "simplification", split="validation")
    if ds is not None:
        target = len(rows) + quota(0.08)
        for item in ds:
            src = item.get("original") or ""
            for tgt in item.get("simplifications", []) or []:
                add_pair(rows, src, tgt)
            if len(rows) >= target:
                break

    log_fn("Loading WikiSplit sentence-flow examples...")
    ds = try_load_dataset(log_fn, "wiki_split", split="train")
    if ds is not None:
        target = len(rows) + quota(0.08)
        for item in ds:
            src = item.get("complex_sentence") or ""
            s1 = item.get("simple_sentence_1") or ""
            s2 = item.get("simple_sentence_2") or ""
            tgt = normalise_text(f"{s1} {s2}") if s2 else s1
            add_pair(rows, src, tgt)
            if len(rows) >= target:
                break

    log_fn("Loading MRPC paraphrase pairs...")
    ds = try_load_dataset(log_fn, "glue", "mrpc", split="train")
    if ds is not None:
        target = len(rows) + quota(0.06)
        for item in ds:
            if int(item.get("label", 0)) == 1:
                s1 = item.get("sentence1") or ""
                s2 = item.get("sentence2") or ""
                add_pair(rows, s1, s2)
                add_pair(rows, s2, s1)
            if len(rows) >= target:
                break

    if include_c4:
        log_fn("Loading optional C4-200M GEC stream...")
        ds = try_load_dataset(log_fn, "liweili/c4_200m", split="train", streaming=True)
        if ds is not None:
            target = len(rows) + quota(0.15)
            for item in ds:
                src = item.get("input") or item.get("src") or ""
                tgt = item.get("output") or item.get("tgt") or ""
                add_pair(rows, src, tgt)
                if len(rows) >= target:
                    break

    base = dedupe_pairs(rows)
    clean_pool = [clean for _, clean in base]
    log_fn("Generating synthetic typo, dyslexia-like, punctuation, and preservation pairs...")
    synth_target = max(1000, max_pairs - len(base))
    # list() so the concatenation also works if SEED_CLEAN_SENTENCES is a tuple.
    rows = base + synthetic_pairs(clean_pool + list(SEED_CLEAN_SENTENCES), synth_target)
    random.shuffle(rows)
    return dedupe_pairs(rows, max_pairs)
def save_pairs(rows: list[tuple[str, str]], path: Path) -> None:
    """Write *rows* to *path* as JSON Lines via a temporary sibling file.

    Each pair becomes ``{"input": noisy, "target": clean}`` on its own line,
    with non-ASCII characters kept verbatim. Parent directories are created
    as needed. The data is first written to *path* with its suffix replaced
    by ``.tmp``, which is then moved over *path* with ``Path.replace``.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    staging = path.with_suffix(".tmp")
    lines = (
        json.dumps({"input": noisy, "target": clean}, ensure_ascii=False) + "\n"
        for noisy, clean in rows
    )
    with staging.open("w", encoding="utf-8") as handle:
        handle.writelines(lines)
    staging.replace(path)
def load_prepared_pairs(path: Path = PAIRS_PATH) -> list[tuple[str, str]]:
    """Return the JSONL pairs stored at *path* (default ``PAIRS_PATH``), or ``[]`` if the file is absent."""
    return read_pairs_jsonl(path) if path.exists() else []
def train_tokenizer(rows: list[tuple[str, str]], vocab_size: int, path: Path = TOKENIZER_PATH):
    """Train a byte-level BPE tokenizer on both sides of *rows* and save it.

    Args:
        rows: (noisy, clean) pairs; both texts are used as training data.
        vocab_size: target vocabulary size, including ``SPECIAL_TOKENS``.
        path: where the tokenizer JSON is written (parents are created).

    Returns:
        The trained ``tokenizers.Tokenizer``.
    """
    require_package("tokenizers")
    from tokenizers import Tokenizer
    from tokenizers.decoders import ByteLevel as ByteLevelDecoder
    from tokenizers.models import BPE
    from tokenizers.pre_tokenizers import ByteLevel
    from tokenizers.trainers import BpeTrainer

    tok = Tokenizer(BPE(unk_token="[UNK]"))
    tok.pre_tokenizer = ByteLevel(add_prefix_space=False)
    tok.decoder = ByteLevelDecoder()
    trainer = BpeTrainer(
        vocab_size=vocab_size,
        min_frequency=2,
        special_tokens=SPECIAL_TOKENS,
        show_progress=True,
        # Seed the vocabulary with all 256 byte-level symbols so any input
        # can be encoded. Without this, bytes that never occur in the
        # training text (rare accents, emoji, ...) become unknown tokens and
        # cannot be copied through by the correction model.
        initial_alphabet=ByteLevel.alphabet(),
    )

    def _texts():
        # Stream both sides instead of materialising a 2*len(rows) list.
        for noisy, clean in rows:
            yield noisy
            yield clean

    tok.train_from_iterator(_texts(), trainer=trainer, length=2 * len(rows))
    path.parent.mkdir(parents=True, exist_ok=True)
    tok.save(str(path))
    return tok
def load_tokenizer(path: Path = TOKENIZER_PATH):
    """Load a saved tokenizer JSON file (default ``TOKENIZER_PATH``).

    If the file defines no decoder, a byte-level decoder is attached so
    that ``decode`` maps byte-level symbols back to text.
    """
    require_package("tokenizers")
    from tokenizers import Tokenizer
    from tokenizers.decoders import ByteLevel as ByteLevelDecoder

    tokenizer = Tokenizer.from_file(str(path))
    needs_decoder = tokenizer.decoder is None
    if needs_decoder:
        tokenizer.decoder = ByteLevelDecoder()
    return tokenizer
def encode_pair(tok, noisy: str, clean: str, max_len: int) -> tuple[list[int], list[int], list[int]]:
    """Encode a (noisy, clean) pair into seq2seq id sequences.

    Both texts are normalised and tokenised, and each token-id list is
    truncated to ``max_len - 2`` ids. Returns ``(src, tgt_in, tgt_out)``:

    * ``src``:     ``[BOS] + noisy_ids + [EOS]``;
    * ``tgt_in``:  ``[BOS] + clean_ids`` (decoder input);
    * ``tgt_out``: ``clean_ids + [EOS]`` (labels, ``tgt_in`` shifted left by one).
    """
    budget = max_len - 2
    noisy_ids = tok.encode(normalise_text(noisy)).ids[:budget]
    clean_ids = tok.encode(normalise_text(clean)).ids[:budget]
    return (
        [BOS_ID, *noisy_ids, EOS_ID],
        [BOS_ID, *clean_ids],
        [*clean_ids, EOS_ID],
    )
def pad_batch(seqs: list[list[int]], pad_value: int) -> torch.Tensor:
    """Right-pad integer sequences into a ``(len(seqs), longest)`` ``torch.long`` tensor.

    Positions past the end of each sequence hold *pad_value*. An empty
    *seqs* yields an empty ``(0, 0)`` tensor instead of raising
    ``ValueError`` from ``max()`` on an empty sequence.
    """
    max_len = max((len(s) for s in seqs), default=0)
    out = torch.full((len(seqs), max_len), pad_value, dtype=torch.long)
    for i, seq in enumerate(seqs):
        if seq:
            out[i, : len(seq)] = torch.tensor(seq, dtype=torch.long)
    return out
def collate_pairs(rows: list[tuple[str, str]], tok, cfg: TrainConfig) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Encode a batch of (noisy, clean) pairs and pad them into tensors.

    Each pair is encoded with ``encode_pair`` using ``cfg.max_len``. Returns
    ``(src, tgt_in, tgt_out)``: ``src`` and ``tgt_in`` are padded with
    ``PAD_ID``, and ``tgt_out`` is padded with ``-100``. That matches the
    default ``ignore_index`` of ``torch.nn.CrossEntropyLoss``, presumably so
    padded label positions are ignored (confirm in the training loop).
    """
    encoded = [encode_pair(tok, noisy, clean, cfg.max_len) for noisy, clean in rows]
    src_seqs = [triple[0] for triple in encoded]
    tin_seqs = [triple[1] for triple in encoded]
    tout_seqs = [triple[2] for triple in encoded]
    return (
        pad_batch(src_seqs, PAD_ID),
        pad_batch(tin_seqs, PAD_ID),
        pad_batch(tout_seqs, -100),
    )
def make_batches(rows: list[tuple[str, str]], batch_size: int, shuffle_batches: bool) -> list[list[int]]:
    """Group row indices into length-bucketed batches.

    Indices are sorted by the combined character length of the noisy and
    clean text, then cut into consecutive groups of *batch_size* (the last
    group may be shorter). Each batch therefore holds similarly sized
    examples and needs little padding.

    With *shuffle_batches*, rows of equal length are ordered randomly before
    grouping, so batch membership varies between calls, and the order of the
    batches is shuffled. Without it, the result is deterministic.
    """
    idx = list(range(len(rows)))
    if shuffle_batches:
        # list.sort() is stable; shuffling first breaks length ties randomly
        # so identical-length rows do not always share a batch.
        random.shuffle(idx)
    idx.sort(key=lambda i: len(rows[i][0]) + len(rows[i][1]))
    batches = [idx[i : i + batch_size] for i in range(0, len(idx), batch_size)]
    if shuffle_batches:
        # Padding is fixed by batch membership, so a single global shuffle of
        # batch order adds stochasticity without increasing padding.
        random.shuffle(batches)
    return batches
def split_train_val(rows: list[tuple[str, str]], val_split: float, seed: int):
    """Deterministically shuffle a copy of *rows* and split it into ``(train, val)``.

    The shuffle uses ``random.Random(seed)``, so the same rows and seed
    always give the same split. *rows* itself is left untouched, so
    repeated calls are reproducible. The validation set is the first
    ``max(1, int(len(rows) * val_split))`` rows of the shuffled copy; train
    is the rest.
    """
    shuffled = list(rows)
    random.Random(seed).shuffle(shuffled)
    val_n = max(1, int(len(shuffled) * val_split))
    return shuffled[val_n:], shuffled[:val_n]