| """ |
| Dataset class that handles all data sources and produces training triplets: |
| (input_text, style_vector, target_text) |
| |
| Data sources priority: |
| 1. W&I+LOCNESS — real learner errors with expert corrections |
| 2. JFLEG — naturalistic fluency corrections |
| 3. GYAFC — informal→formal style transfer |
| 4. Synthetic — dyslexia simulator augmentation on Wikipedia/books |
| 5. Custom — any user-provided correction pairs |
| |
| OPTIMISATION: Everything is pre-computed at init and cached to disk: |
| - Tokenisation (input_ids, attention_mask, labels) |
| - Style vectors (spaCy + MLP) |
| - Disk cache at data/cache/<hash>.pt — skips re-computation on re-runs |
| __getitem__ is a pure dict return — zero computation per batch. |
| """ |
|
|
| import json |
| import os |
| from pathlib import Path |
| from typing import List, Dict, Optional |
| import torch |
| from torch.utils.data import Dataset |
| from transformers import PreTrainedTokenizer |
| from ..style.fingerprinter import StyleFingerprinter |
| from ..preprocessing.dyslexia_simulator import DyslexiaSimulator |
| from loguru import logger |
| import random |
| import hashlib |
|
|
|
|
# Instruction prefix prepended to every example's input text before
# tokenisation (see WritingCorrectionDataset._precompute_all).  The exact
# wording is part of the model's training signal — do not edit it casually,
# or previously cached/pre-trained behaviour will no longer match.
TASK_PREFIX = (
    "Correct the following text for grammar, spelling, and clarity. "
    "Maintain the author's original tone and writing style. "
    "Elevate vocabulary to academic register. "
    "Do NOT change the meaning or add new information. "
    "Preserve named entities exactly. "
    "Text to correct: "
)


# Directory for pre-computed dataset caches (one <hash>.pt file per
# dataset/parameter combination), relative to the working directory.
CACHE_DIR = Path("data/cache")
|
|
|
|
class WritingCorrectionDataset(Dataset):
    """PyTorch dataset yielding pre-computed writing-correction triplets.

    Each item is a plain dict with:
      - ``input_ids`` / ``attention_mask``: tokenised ``TASK_PREFIX + input``
      - ``labels``: tokenised target ids with padding positions set to -100
      - ``style_vector``: fingerprint vector extracted from the target text
      - ``input_text`` / ``target_text``: the raw strings (handy for eval)

    Everything is pre-computed at init with disk caching:
      - First run: tokenises + extracts style vectors, saves to disk.
      - Subsequent runs: loads the disk cache keyed on data content,
        processing parameters, and tokenizer identity.
      - ``__getitem__`` is a pure dict return (zero computation per batch).
    """

    def __init__(
        self,
        data_path: str,
        tokenizer: "PreTrainedTokenizer",
        fingerprinter: "StyleFingerprinter",
        max_input_length: int = 256,
        max_target_length: int = 256,
        augment_with_synthetic: bool = True,
        synthetic_ratio: float = 0.3,
    ):
        """Load, optionally augment, and fully pre-compute the dataset.

        Args:
            data_path: JSONL file of ``{"input": ..., "target": ...}`` records.
            tokenizer: HuggingFace tokenizer used for inputs and targets.
            fingerprinter: Style-vector extractor applied to target texts.
            max_input_length: Pad/truncate length for the prefixed input.
            max_target_length: Pad/truncate length for the target.
            augment_with_synthetic: If True, append dyslexia-simulator examples.
            synthetic_ratio: Synthetic examples as a fraction of real ones.
        """
        self.tokenizer = tokenizer
        self.fingerprinter = fingerprinter
        self.max_input_length = max_input_length
        self.max_target_length = max_target_length

        self.examples = self._load(data_path)
        logger.info(f"Loaded {len(self.examples)} examples from {data_path}")

        if augment_with_synthetic and self.examples:
            self._add_synthetic(synthetic_ratio)

        logger.info(f"Total dataset size: {len(self.examples)} examples")

        cache_key = self._compute_cache_key(data_path, augment_with_synthetic, synthetic_ratio)
        cache_path = CACHE_DIR / f"{cache_key}.pt"

        if cache_path.exists():
            logger.info(f"Loading pre-computed dataset from cache: {cache_path}")
            # weights_only=False is required: the cache holds plain Python
            # dicts/strings alongside tensors.  The file is produced locally
            # by this class, so unpickling it is a trusted operation.
            self._precomputed = torch.load(cache_path, map_location="cpu", weights_only=False)
            logger.info(f"Loaded {len(self._precomputed)} cached examples")
        else:
            self._precomputed = self._precompute_all()
            CACHE_DIR.mkdir(parents=True, exist_ok=True)
            torch.save(self._precomputed, cache_path)
            logger.info(f"Saved pre-computed dataset to cache: {cache_path}")

    def _compute_cache_key(self, data_path: str, augment: bool, ratio: float) -> str:
        """Derive a cache key from the data file content plus every parameter
        that affects pre-computation, so a stale cache is never reused.

        Returns a 16-hex-char key (md5 prefix — cache addressing only, not
        a security boundary).
        """
        h = hashlib.md5()

        try:
            h.update(Path(data_path).read_bytes())
        except FileNotFoundError:
            # Missing file yields an empty dataset; hash the path so the
            # key is still stable and distinct per path.
            h.update(data_path.encode())

        # Include the tokenizer identity: the same corpus tokenised by a
        # different tokenizer must NOT hit the same cache entry.
        tok_id = getattr(self.tokenizer, "name_or_path", type(self.tokenizer).__name__)
        h.update(
            f"tok={tok_id}|aug={augment}|ratio={ratio}|"
            f"maxin={self.max_input_length}|maxtgt={self.max_target_length}".encode()
        )
        return h.hexdigest()[:16]

    def _load(self, path: str) -> List[Dict]:
        """Load a JSONL data file, keeping only well-formed records.

        Blank lines and lines that fail to parse as JSON are skipped
        silently; a record must contain both "input" and "target" keys to
        be kept.  Returns an empty list when the file does not exist.
        """
        examples = []
        try:
            # Explicit encoding: JSONL is UTF-8 regardless of platform locale.
            with open(path, encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    try:
                        obj = json.loads(line)
                    except json.JSONDecodeError:
                        # Tolerate corrupt lines rather than abort the run.
                        continue
                    if "input" in obj and "target" in obj:
                        examples.append(obj)
        except FileNotFoundError:
            logger.warning(f"Data file not found: {path}")
        return examples

    def _add_synthetic(self, ratio: float):
        """Append synthetic dyslexia-style corruptions of existing targets.

        Fixed seeds keep the augmentation deterministic, which also keeps
        it consistent with the disk cache key.
        """
        simulator = DyslexiaSimulator(error_rate=0.15, seed=42)
        num_synthetic = int(len(self.examples) * ratio)

        # Sample source targets with replacement (seeded for determinism).
        source_examples = random.Random(42).choices(self.examples, k=num_synthetic)

        synthetic_count = 0
        for example in source_examples:
            target = example["target"]
            corrupted, clean = simulator.simulate(target)

            # Keep only pairs where the simulator actually introduced errors.
            if corrupted != clean:
                self.examples.append({
                    "input": corrupted,
                    "target": clean,
                    "source": "synthetic",
                })
                synthetic_count += 1

        logger.info(f"Added {synthetic_count} synthetic augmentation examples")

    def _precompute_all(self) -> List[Dict[str, torch.Tensor]]:
        """Pre-compute tokenisation + style vectors for ALL examples.

        Makes ``__getitem__`` a pure dict return with zero computation.
        Style vectors are memoised per unique target text, since synthetic
        augmentation and repeated targets make duplicates common.
        """
        logger.info("Pre-computing tokenisation and style vectors for all examples...")
        precomputed = []
        style_cache: Dict[str, torch.Tensor] = {}
        pad_id = self.tokenizer.pad_token_id

        for i, example in enumerate(self.examples):
            input_text = TASK_PREFIX + example["input"]
            target_text = example["target"]

            input_encoding = self.tokenizer(
                input_text,
                max_length=self.max_input_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            )

            target_encoding = self.tokenizer(
                target_text,
                max_length=self.max_target_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            )

            # Memoise on the target text itself: exact, no hash collisions.
            if target_text not in style_cache:
                with torch.no_grad():
                    style_cache[target_text] = self.fingerprinter.extract_vector(target_text)
            style_vector = style_cache[target_text]

            # Mask padding with -100 so the loss ignores those positions.
            # squeeze(0) — not squeeze() — so a length-1 sequence keeps its
            # sequence dimension.
            labels = target_encoding["input_ids"].squeeze(0)
            if pad_id is not None:  # some tokenizers define no pad token
                labels[labels == pad_id] = -100

            precomputed.append({
                "input_ids": input_encoding["input_ids"].squeeze(0),
                "attention_mask": input_encoding["attention_mask"].squeeze(0),
                "labels": labels,
                "style_vector": style_vector,
                "input_text": example["input"],
                "target_text": target_text,
            })

            if (i + 1) % 2000 == 0:
                logger.info(f"  Pre-computed: {i + 1}/{len(self.examples)}")

        logger.info(f"Pre-computation complete ({len(style_cache)} unique style vectors)")
        return precomputed

    def __len__(self):
        """Number of pre-computed examples."""
        return len(self._precomputed)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Pure dict return — zero computation per batch.

        NOTE: the same dict object is returned on repeated access; callers
        must not mutate it in place.
        """
        return self._precomputed[idx]
|
|