"""Dataset class that handles all data sources and produces training triplets.

Each triplet is (input_text, style_vector, target_text).

Data sources priority:
    1. W&I+LOCNESS — real learner errors with expert corrections
    2. JFLEG — naturalistic fluency corrections
    3. GYAFC — informal→formal style transfer
    4. Synthetic — dyslexia simulator augmentation on Wikipedia/books
    5. Custom — any user-provided correction pairs

OPTIMISATION: Everything is pre-computed at init and cached to disk:
    - Tokenisation (input_ids, attention_mask, labels)
    - Style vectors (spaCy + MLP)
    - Disk cache at data/cache/<cache_key>.pt — skips re-computation on re-runs

__getitem__ is a pure dict return — zero computation per batch.
"""

import hashlib
import json
import os
import random
from pathlib import Path
from typing import Any, Dict, List, Optional

import torch
from loguru import logger
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer

from ..style.fingerprinter import StyleFingerprinter
from ..preprocessing.dyslexia_simulator import DyslexiaSimulator

# Instruction prepended to every input text before tokenisation.
TASK_PREFIX = (
    "Correct the following text for grammar, spelling, and clarity. "
    "Maintain the author's original tone and writing style. "
    "Elevate vocabulary to academic register. "
    "Do NOT change the meaning or add new information. "
    "Preserve named entities exactly. "
    "Text to correct: "
)

# Directory (relative to the working directory) holding pre-computed datasets.
CACHE_DIR = Path("data/cache")


class WritingCorrectionDataset(Dataset):
    """PyTorch dataset for writing correction training triplets.

    Fully pre-computed at init with disk caching:
        - First run: tokenises + extracts style vectors (author's estimate:
          ~10 min), then saves the result to disk.
        - Subsequent runs: loads the pre-computed items from the disk cache
          (author's estimate: ~5 seconds).
        - __getitem__ is a pure dict return (zero computation).

    Note that the raw examples are always loaded (and optionally augmented)
    into ``self.examples``, even when the pre-computed cache is hit.
    """

    def __init__(
        self,
        data_path: str,
        tokenizer: PreTrainedTokenizer,
        fingerprinter: StyleFingerprinter,
        max_input_length: int = 256,
        max_target_length: int = 256,
        augment_with_synthetic: bool = True,
        synthetic_ratio: float = 0.3,
    ):
        """Load, optionally augment, and pre-compute (or load cached) examples.

        Args:
            data_path: Path to a JSONL file whose records contain at least the
                keys ``"input"`` and ``"target"``.
            tokenizer: Tokenizer used for both inputs and targets.
            fingerprinter: Object whose ``extract_vector(text)`` yields the
                style vector stored with each example.
            max_input_length: Pad/truncate length for tokenised inputs.
            max_target_length: Pad/truncate length for tokenised targets.
            augment_with_synthetic: Whether to append synthetic corrupted
                examples derived from existing targets.
            synthetic_ratio: Number of synthetic candidates to draw, as a
                fraction of the loaded example count.
        """
        self.tokenizer = tokenizer
        self.fingerprinter = fingerprinter
        self.max_input_length = max_input_length
        self.max_target_length = max_target_length

        # Load data
        self.examples = self._load(data_path)
        logger.info(f"Loaded {len(self.examples)} examples from {data_path}")

        # Augment with synthetic dyslexia data
        if augment_with_synthetic and self.examples:
            self._add_synthetic(synthetic_ratio)
        logger.info(f"Total dataset size: {len(self.examples)} examples")

        # Compute cache key from data content + config
        cache_key = self._compute_cache_key(data_path, augment_with_synthetic, synthetic_ratio)
        cache_path = CACHE_DIR / f"{cache_key}.pt"

        if cache_path.exists():
            logger.info(f"Loading pre-computed dataset from cache: {cache_path}")
            # SECURITY NOTE: weights_only=False unpickles arbitrary Python
            # objects. This is only acceptable because the cache is written by
            # this class; never point CACHE_DIR at untrusted files.
            self._precomputed = torch.load(cache_path, map_location="cpu", weights_only=False)
            logger.info(f"Loaded {len(self._precomputed)} cached examples")
        else:
            # Pre-compute everything and save to disk
            self._precomputed = self._precompute_all()
            CACHE_DIR.mkdir(parents=True, exist_ok=True)
            # Write to a temporary file and move it into place, so an
            # interrupted run never leaves a truncated cache file under the
            # final name for later runs to load.
            tmp_path = cache_path.with_name(f"{cache_path.name}.{os.getpid()}.tmp")
            try:
                torch.save(self._precomputed, tmp_path)
                os.replace(tmp_path, cache_path)
            finally:
                # Only still present if saving or moving failed.
                if tmp_path.exists():
                    tmp_path.unlink()
            logger.info(f"Saved pre-computed dataset to cache: {cache_path}")

    def _compute_cache_key(self, data_path: str, augment: bool, ratio: float) -> str:
        """Generate a cache key from the data file content and processing params.

        The key covers everything visible here that influences the cached
        tensors: the raw file bytes, the augmentation settings, the max
        lengths, the tokenizer identity, and the task prefix.

        Args:
            data_path: Path to the data file; its bytes are hashed. If the file
                does not exist, the path string is hashed instead.
            augment: Whether synthetic augmentation was requested.
            ratio: Synthetic augmentation ratio.

        Returns:
            The first 16 hex characters of an MD5 digest.
        """
        h = hashlib.md5()

        # Hash the data file content
        try:
            h.update(Path(data_path).read_bytes())
        except FileNotFoundError:
            h.update(data_path.encode())

        # Hash processing parameters
        h.update(
            f"aug={augment}|ratio={ratio}|maxin={self.max_input_length}|maxtgt={self.max_target_length}".encode()
        )

        # The cached token ids depend on the tokenizer and on TASK_PREFIX, so
        # both must be part of the key; otherwise a cache built with another
        # tokenizer or prefix would be silently reused.
        tokenizer_id = getattr(self.tokenizer, "name_or_path", "") or type(self.tokenizer).__name__
        h.update(f"|tok={tokenizer_id}|prefix={TASK_PREFIX}".encode())

        return h.hexdigest()[:16]

    def _load(self, path: str) -> List[Dict]:
        """Load a JSONL data file.

        Blank lines are ignored. Lines that are not valid JSON, are not JSON
        objects, or lack the ``"input"``/``"target"`` keys are skipped and
        counted; the count is logged as a warning.

        Args:
            path: Path to the JSONL file.

        Returns:
            The list of usable records; empty if the file does not exist.
        """
        examples: List[Dict] = []
        skipped = 0
        try:
            # Explicit encoding: the platform default is not always UTF-8.
            with open(path, encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    try:
                        obj = json.loads(line)
                    except json.JSONDecodeError:
                        skipped += 1
                        continue
                    # A valid JSON line may be a number/bool/null; the ``in``
                    # test would raise TypeError on those, so check the type.
                    if isinstance(obj, dict) and "input" in obj and "target" in obj:
                        examples.append(obj)
                    else:
                        skipped += 1
        except FileNotFoundError:
            logger.warning(f"Data file not found: {path}")

        if skipped:
            logger.warning(f"Skipped {skipped} malformed or incomplete records in {path}")
        return examples

    def _add_synthetic(self, ratio: float):
        """Augment the dataset in place with synthetic dyslexia examples.

        Draws ``int(len(self.examples) * ratio)`` examples with replacement
        (fixed seed 42), corrupts each target with the simulator, and appends
        the (corrupted, clean) pair when the corruption changed the text.

        Args:
            ratio: Fraction of the current example count to draw as candidates.
        """
        simulator = DyslexiaSimulator(error_rate=0.15, seed=42)
        num_synthetic = int(len(self.examples) * ratio)

        # Sample target texts to corrupt
        source_examples = random.Random(42).choices(self.examples, k=num_synthetic)

        synthetic_count = 0
        for example in source_examples:
            target = example["target"]
            corrupted, clean = simulator.simulate(target)

            # Only add if corruption actually changed the text
            if corrupted != clean:
                self.examples.append({
                    "input": corrupted,
                    "target": clean,
                    "source": "synthetic",
                })
                synthetic_count += 1

        logger.info(f"Added {synthetic_count} synthetic augmentation examples")

    def _precompute_all(self) -> List[Dict[str, Any]]:
        """Pre-compute tokenisation + style vectors for ALL examples.

        This makes __getitem__ a pure dict return with zero computation.

        Returns:
            One dict per example with keys ``input_ids``, ``attention_mask``,
            ``labels``, ``style_vector``, ``input_text`` and ``target_text``.
        """
        logger.info("Pre-computing tokenisation and style vectors for all examples...")
        precomputed: List[Dict[str, Any]] = []
        # Identical target texts share one style vector. Keyed by the text
        # itself, so there is no hash-collision risk.
        style_cache: Dict[str, Any] = {}
        total = len(self.examples)

        for i, example in enumerate(self.examples):
            input_text = TASK_PREFIX + example["input"]
            target_text = example["target"]

            # Tokenise input
            input_encoding = self.tokenizer(
                input_text,
                max_length=self.max_input_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            )

            # Tokenise target (labels)
            target_encoding = self.tokenizer(
                target_text,
                max_length=self.max_target_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            )

            # Style vector (computed once per distinct target text)
            if target_text not in style_cache:
                with torch.no_grad():
                    style_cache[target_text] = self.fingerprinter.extract_vector(target_text)
            style_vector = style_cache[target_text]

            # squeeze(0) removes only the leading batch axis added by
            # return_tensors="pt"; a bare squeeze() would also collapse the
            # sequence axis when the max length is 1.
            # Labels — set padding tokens to -100 so they're ignored in loss
            labels = target_encoding["input_ids"].squeeze(0)
            labels[labels == self.tokenizer.pad_token_id] = -100

            precomputed.append({
                "input_ids": input_encoding["input_ids"].squeeze(0),
                "attention_mask": input_encoding["attention_mask"].squeeze(0),
                "labels": labels,
                "style_vector": style_vector,
                "input_text": example["input"],
                "target_text": target_text,
            })

            if (i + 1) % 2000 == 0:
                logger.info(f"  Pre-computed: {i + 1}/{total}")

        logger.info(f"Pre-computation complete ({len(style_cache)} unique style vectors)")
        return precomputed

    def __len__(self):
        """Return the number of pre-computed examples."""
        return len(self._precomputed)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Pure dict return — zero computation per batch.

        The returned dict mixes tensors (``input_ids``, ``attention_mask``,
        ``labels``), the style vector, and the raw ``input_text`` /
        ``target_text`` strings.
        """
        return self._precomputed[idx]