"""
Dataset class that handles all data sources and produces training triplets:
(input_text, style_vector, target_text)
Data sources priority:
1. W&I+LOCNESS — real learner errors with expert corrections
2. JFLEG — naturalistic fluency corrections
3. GYAFC — informal→formal style transfer
4. Synthetic — dyslexia simulator augmentation on Wikipedia/books
5. Custom — any user-provided correction pairs
OPTIMISATION: Everything is pre-computed at init and cached to disk:
- Tokenisation (input_ids, attention_mask, labels)
- Style vectors (spaCy + MLP)
- Disk cache at data/cache/<hash>.pt — skips re-computation on re-runs
__getitem__ is a pure dict return — zero computation per batch.
"""
import json
import os
from pathlib import Path
from typing import List, Dict, Optional
import torch
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer
from ..style.fingerprinter import StyleFingerprinter
from ..preprocessing.dyslexia_simulator import DyslexiaSimulator
from loguru import logger
import random
import hashlib
# Instruction prompt that is prepended to every raw input text before it is
# tokenised. Each sentence is separated by one space; the final fragment keeps
# its trailing space so the learner text can be appended directly after it.
TASK_PREFIX = " ".join(
    (
        "Correct the following text for grammar, spelling, and clarity.",
        "Maintain the author's original tone and writing style.",
        "Elevate vocabulary to academic register.",
        "Do NOT change the meaning or add new information.",
        "Preserve named entities exactly.",
        "Text to correct: ",
    )
)

# Directory (relative to the working directory) where pre-computed datasets
# are stored as <hash>.pt files.
CACHE_DIR = Path("data") / "cache"
class WritingCorrectionDataset(Dataset):
    """PyTorch dataset for writing correction training triplets.

    Fully pre-computed at init with disk caching:
    - First run: tokenises + extracts style vectors (~10 min), saves to disk
    - Subsequent runs: loads from disk cache (~5 seconds)
    - __getitem__ is a pure list lookup (zero computation)

    Each item is a dict with the keys ``input_ids``, ``attention_mask``,
    ``labels``, ``style_vector``, ``input_text`` and ``target_text``.
    """

    def __init__(
        self,
        data_path: str,
        tokenizer: "PreTrainedTokenizer",
        fingerprinter: "StyleFingerprinter",
        max_input_length: int = 256,
        max_target_length: int = 256,
        augment_with_synthetic: bool = True,
        synthetic_ratio: float = 0.3,
    ):
        """Load, optionally augment, then pre-compute (or restore) all items.

        Args:
            data_path: Path to a JSONL file with "input" and "target" fields.
            tokenizer: Tokenizer applied to the prefixed input and the target.
            fingerprinter: Object whose ``extract_vector(text)`` yields the
                style vector stored alongside each example.
            max_input_length: Padding/truncation length for the input.
            max_target_length: Padding/truncation length for the target.
            augment_with_synthetic: Whether to append simulator-corrupted
                copies of existing targets.
            synthetic_ratio: Number of synthetic candidates as a fraction of
                the number of loaded examples.
        """
        self.tokenizer = tokenizer
        self.fingerprinter = fingerprinter
        self.max_input_length = max_input_length
        self.max_target_length = max_target_length

        # Load data
        self.examples = self._load(data_path)
        logger.info(f"Loaded {len(self.examples)} examples from {data_path}")

        # Augment with synthetic dyslexia data
        if augment_with_synthetic and self.examples:
            self._add_synthetic(synthetic_ratio)
        logger.info(f"Total dataset size: {len(self.examples)} examples")

        # Compute cache key from data content + config
        cache_key = self._compute_cache_key(data_path, augment_with_synthetic, synthetic_ratio)
        cache_path = CACHE_DIR / f"{cache_key}.pt"

        # Try the disk cache first; fall back to a full pre-computation.
        cached = self._load_cache(cache_path)
        if cached is not None:
            self._precomputed = cached
        else:
            self._precomputed = self._precompute_all()
            self._save_cache(cache_path)

    def _load_cache(self, cache_path: Path) -> Optional[List[Dict]]:
        """Return the cached items at *cache_path*, or None if unusable."""
        if not cache_path.exists():
            return None
        logger.info(f"Loading pre-computed dataset from cache: {cache_path}")
        # SECURITY: weights_only=False unpickles arbitrary Python objects.
        # That is needed here because items mix tensors and strings, but it
        # means the cache directory must only ever contain trusted files.
        cached = torch.load(cache_path, map_location="cpu", weights_only=False)
        if len(cached) != len(self.examples):
            # The key matched but the payload does not describe the examples
            # we just built (e.g. the augmentation code changed) — rebuild.
            logger.warning(
                f"Ignoring stale cache {cache_path}: {len(cached)} cached items "
                f"vs {len(self.examples)} examples"
            )
            return None
        logger.info(f"Loaded {len(cached)} cached examples")
        return cached

    def _save_cache(self, cache_path: Path) -> None:
        """Write the pre-computed items to *cache_path* atomically."""
        cache_path.parent.mkdir(parents=True, exist_ok=True)
        # Write to a temporary sibling and rename, so an interrupted run can
        # never leave a truncated file under the final cache name.
        tmp_path = cache_path.with_name(f"{cache_path.name}.{os.getpid()}.tmp")
        try:
            torch.save(self._precomputed, tmp_path)
            os.replace(tmp_path, cache_path)
        finally:
            if tmp_path.exists():
                tmp_path.unlink()
        logger.info(f"Saved pre-computed dataset to cache: {cache_path}")

    def _compute_cache_key(self, data_path: str, augment: bool, ratio: float) -> str:
        """Generate a cache key based on data file content and processing params.

        The key covers the data bytes (or the path string when the file is
        missing), the augmentation settings, both max lengths, the tokenizer
        identity and the task prefix, since each of them changes the tensors
        that end up in the cache.

        NOTE(review): the fingerprinter is not part of the key; if its
        weights change, the cache has to be deleted manually.
        """
        h = hashlib.md5()
        # Hash the data file content
        try:
            h.update(Path(data_path).read_bytes())
        except FileNotFoundError:
            h.update(str(data_path).encode())
        # Hash processing parameters
        h.update(f"aug={augment}|ratio={ratio}|maxin={self.max_input_length}|maxtgt={self.max_target_length}".encode())
        # Hash what determines the token ids: tokenizer identity and prompt.
        tokenizer = getattr(self, "tokenizer", None)
        tokenizer_id = getattr(tokenizer, "name_or_path", None) or type(tokenizer).__name__
        h.update(f"|tok={tokenizer_id}|prefix={TASK_PREFIX}".encode())
        return h.hexdigest()[:16]

    def _load(self, path: str) -> List[Dict]:
        """Load a JSONL data file.

        Keeps every line that parses to a JSON object whose "input" and
        "target" values are both strings. Blank lines are ignored; any other
        line is skipped and counted in a warning. A missing file yields an
        empty list and a warning.
        """
        examples: List[Dict] = []
        skipped = 0
        try:
            # Explicit encoding: learner text is full of non-ASCII characters
            # and must not depend on the platform's default locale.
            with open(path, encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    try:
                        obj = json.loads(line)
                    except json.JSONDecodeError:
                        skipped += 1
                        continue
                    # A valid JSON line may still be a number, list or string;
                    # only objects with two text fields are usable.
                    if (
                        isinstance(obj, dict)
                        and isinstance(obj.get("input"), str)
                        and isinstance(obj.get("target"), str)
                    ):
                        examples.append(obj)
                    else:
                        skipped += 1
        except FileNotFoundError:
            logger.warning(f"Data file not found: {path}")
        if skipped:
            logger.warning(f"Skipped {skipped} malformed or incomplete lines in {path}")
        return examples

    def _add_synthetic(self, ratio: float):
        """Augment dataset with synthetic dyslexia examples.

        Draws ``int(len(examples) * ratio)`` examples (with replacement, fixed
        seed), runs each target through the simulator and appends the result
        as a new example tagged ``"source": "synthetic"``.
        """
        simulator = DyslexiaSimulator(error_rate=0.15, seed=42)
        num_synthetic = int(len(self.examples) * ratio)
        # Sample target texts to corrupt
        source_examples = random.Random(42).choices(self.examples, k=num_synthetic)
        synthetic_count = 0
        for example in source_examples:
            corrupted, clean = simulator.simulate(example["target"])
            # Only add if corruption actually changed the text
            if corrupted != clean:
                self.examples.append({
                    "input": corrupted,
                    "target": clean,
                    "source": "synthetic",
                })
                synthetic_count += 1
        logger.info(f"Added {synthetic_count} synthetic augmentation examples")

    def _precompute_all(self) -> List[Dict]:
        """Pre-compute tokenisation + style vectors for ALL examples.

        This makes __getitem__ a pure lookup with zero computation.

        Returns:
            One dict per example holding the tokenised input, its attention
            mask, the loss labels, the style vector and both raw texts.
        """
        logger.info("Pre-computing tokenisation and style vectors for all examples...")
        precomputed = []
        # Identical target texts share one style vector; keyed by the text
        # itself so there is no possibility of a hash collision.
        style_cache = {}
        for i, example in enumerate(self.examples):
            input_text = TASK_PREFIX + example["input"]
            target_text = example["target"]

            # Tokenise input
            input_encoding = self.tokenizer(
                input_text,
                max_length=self.max_input_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            )
            # Tokenise target (labels)
            target_encoding = self.tokenizer(
                target_text,
                max_length=self.max_target_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            )

            # Style vector (computed once per distinct target text)
            if target_text not in style_cache:
                with torch.no_grad():
                    style_cache[target_text] = self.fingerprinter.extract_vector(target_text)
            style_vector = style_cache[target_text]

            # Labels — padded positions become -100 so they're ignored in the
            # loss. The attention mask is used instead of the pad-token id so
            # that real tokens are never masked when a tokenizer reuses
            # another token (e.g. EOS) as its pad token.
            # squeeze(0) drops only the batch axis, even when max_length is 1.
            labels = target_encoding["input_ids"].squeeze(0).clone()
            labels[target_encoding["attention_mask"].squeeze(0) == 0] = -100

            precomputed.append({
                "input_ids": input_encoding["input_ids"].squeeze(0),
                "attention_mask": input_encoding["attention_mask"].squeeze(0),
                "labels": labels,
                "style_vector": style_vector,
                "input_text": example["input"],
                "target_text": target_text,
            })
            if (i + 1) % 2000 == 0:
                logger.info(f" Pre-computed: {i + 1}/{len(self.examples)}")
        logger.info(f"Pre-computation complete ({len(style_cache)} unique style vectors)")
        return precomputed

    def __len__(self):
        """Return the number of pre-computed items."""
        return len(self._precomputed)

    def __getitem__(self, idx: int) -> Dict:
        """Pure lookup — zero computation per batch."""
        return self._precomputed[idx]
|