# src/training/dataset.py
"""
Dataset class that handles all data sources and produces training triplets:
(input_text, style_vector, target_text)
Data sources priority:
1. W&I+LOCNESS — real learner errors with expert corrections
2. JFLEG — naturalistic fluency corrections
3. GYAFC — informal→formal style transfer
4. Synthetic — dyslexia simulator augmentation on Wikipedia/books
5. Custom — any user-provided correction pairs
OPTIMISATION: Everything is pre-computed at init and cached to disk:
- Tokenisation (input_ids, attention_mask, labels)
- Style vectors (spaCy + MLP)
- Disk cache at data/cache/<hash>.pt — skips re-computation on re-runs
__getitem__ is a pure dict return — zero computation per batch.
"""
import json
import os
from pathlib import Path
from typing import List, Dict, Optional
import torch
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer
from ..style.fingerprinter import StyleFingerprinter
from ..preprocessing.dyslexia_simulator import DyslexiaSimulator
from loguru import logger
import random
import hashlib
# Instruction prefix prepended to every raw input before tokenisation
# (see WritingCorrectionDataset._precompute_all). NOTE: this text is not
# part of the disk-cache key — if you edit it, clear data/cache/ so the
# inputs are re-tokenised with the new prefix.
TASK_PREFIX = (
    "Correct the following text for grammar, spelling, and clarity. "
    "Maintain the author's original tone and writing style. "
    "Elevate vocabulary to academic register. "
    "Do NOT change the meaning or add new information. "
    "Preserve named entities exactly. "
    "Text to correct: "
)
# Directory holding pre-computed dataset tensors, one <hash>.pt file per
# (data file content, processing params) combination.
CACHE_DIR = Path("data/cache")
class WritingCorrectionDataset(Dataset):
    """PyTorch dataset of writing-correction training triplets.

    Each item is a dict of (input_ids, attention_mask, labels,
    style_vector, input_text, target_text). Everything is pre-computed
    at init and cached to disk, so ``__getitem__`` is a pure dict
    lookup with zero per-batch computation:

    - First run: tokenises + extracts style vectors, saves to disk.
    - Subsequent runs: loads from the disk cache.
    """

    def __init__(
        self,
        data_path: str,
        tokenizer: PreTrainedTokenizer,
        fingerprinter: StyleFingerprinter,
        max_input_length: int = 256,
        max_target_length: int = 256,
        augment_with_synthetic: bool = True,
        synthetic_ratio: float = 0.3,
    ):
        """Load, optionally augment, and pre-compute the dataset.

        Args:
            data_path: JSONL file; each line needs "input" and "target" keys.
            tokenizer: HF tokenizer used for both inputs and labels.
            fingerprinter: extracts a style vector from each target text.
            max_input_length: truncation/padding length for inputs.
            max_target_length: truncation/padding length for targets.
            augment_with_synthetic: add dyslexia-simulator corruptions of
                target texts as extra (corrupted -> clean) pairs.
            synthetic_ratio: synthetic examples as a fraction of the
                loaded dataset size.
        """
        self.tokenizer = tokenizer
        self.fingerprinter = fingerprinter
        self.max_input_length = max_input_length
        self.max_target_length = max_target_length

        # Load data.
        self.examples = self._load(data_path)
        logger.info(f"Loaded {len(self.examples)} examples from {data_path}")

        # Augment with synthetic dyslexia data. Fully deterministic
        # (fixed seeds), so the same cache key always maps to the same
        # augmented dataset.
        if augment_with_synthetic and self.examples:
            self._add_synthetic(synthetic_ratio)
        logger.info(f"Total dataset size: {len(self.examples)} examples")

        # Compute cache key from data content + tokenizer + config.
        cache_key = self._compute_cache_key(data_path, augment_with_synthetic, synthetic_ratio)
        cache_path = CACHE_DIR / f"{cache_key}.pt"

        # Try loading from disk cache.
        if cache_path.exists():
            logger.info(f"Loading pre-computed dataset from cache: {cache_path}")
            # weights_only=False is needed because the cache holds plain
            # Python dicts/strings alongside tensors. The file is produced
            # locally by this class (never downloaded), so unpickling it
            # is trusted input — do not point this at untrusted files.
            self._precomputed = torch.load(cache_path, map_location="cpu", weights_only=False)
            logger.info(f"Loaded {len(self._precomputed)} cached examples")
        else:
            # Pre-compute everything and save to disk for the next run.
            self._precomputed = self._precompute_all()
            CACHE_DIR.mkdir(parents=True, exist_ok=True)
            torch.save(self._precomputed, cache_path)
            logger.info(f"Saved pre-computed dataset to cache: {cache_path}")

    def _compute_cache_key(self, data_path: str, augment: bool, ratio: float) -> str:
        """Generate a cache key from data content, tokenizer, and params.

        FIX: the tokenizer identity is now part of the key. Previously,
        switching tokenizers silently reused a cache that was tokenised
        with the old tokenizer, producing wrong input_ids/labels.
        """
        h = hashlib.md5()
        # Hash the data file content.
        try:
            h.update(Path(data_path).read_bytes())
        except FileNotFoundError:
            # Missing file still yields a stable (empty-dataset) key.
            h.update(data_path.encode())
        # Hash tokenizer identity + processing parameters.
        tok_id = getattr(self.tokenizer, "name_or_path", type(self.tokenizer).__name__)
        h.update(
            f"tok={tok_id}|aug={augment}|ratio={ratio}|"
            f"maxin={self.max_input_length}|maxtgt={self.max_target_length}".encode()
        )
        return h.hexdigest()[:16]

    def _load(self, path: str) -> List[Dict]:
        """Load a JSONL file, keeping only rows with "input" and "target".

        Malformed JSON lines and blank lines are skipped (best-effort
        load); a missing file logs a warning and yields an empty list.
        """
        examples = []
        try:
            # Explicit UTF-8: JSONL corpora routinely contain non-ASCII
            # text, and the platform default encoding is not reliable.
            with open(path, encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    try:
                        obj = json.loads(line)
                    except json.JSONDecodeError:
                        continue
                    if "input" in obj and "target" in obj:
                        examples.append(obj)
        except FileNotFoundError:
            logger.warning(f"Data file not found: {path}")
        return examples

    def _add_synthetic(self, ratio: float):
        """Augment the dataset with synthetic dyslexia examples.

        Samples target texts (with replacement, seeded for determinism),
        corrupts them with the dyslexia simulator, and appends
        (corrupted -> clean) pairs tagged "source": "synthetic".
        """
        simulator = DyslexiaSimulator(error_rate=0.15, seed=42)
        num_synthetic = int(len(self.examples) * ratio)
        # Sample target texts to corrupt (snapshot taken before appending,
        # so we never corrupt an already-synthetic example).
        source_examples = random.Random(42).choices(self.examples, k=num_synthetic)
        synthetic_count = 0
        for example in source_examples:
            target = example["target"]
            corrupted, clean = simulator.simulate(target)
            # Only add if corruption actually changed the text.
            if corrupted != clean:
                self.examples.append({
                    "input": corrupted,
                    "target": clean,
                    "source": "synthetic",
                })
                synthetic_count += 1
        logger.info(f"Added {synthetic_count} synthetic augmentation examples")

    def _precompute_all(self) -> List[Dict[str, torch.Tensor]]:
        """Pre-compute tokenisation + style vectors for ALL examples.

        This makes __getitem__ a pure dict return with zero computation.
        Style vectors are de-duplicated by target-text hash, since many
        examples (e.g. synthetic corruptions) share a target.
        """
        logger.info("Pre-computing tokenisation and style vectors for all examples...")
        precomputed = []
        style_cache = {}  # Deduplicate identical target texts.
        for i, example in enumerate(self.examples):
            input_text = TASK_PREFIX + example["input"]
            target_text = example["target"]
            # Tokenise input.
            input_encoding = self.tokenizer(
                input_text,
                max_length=self.max_input_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            )
            # Tokenise target (labels).
            target_encoding = self.tokenizer(
                target_text,
                max_length=self.max_target_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            )
            # Style vector (cached by content hash).
            cache_key = hashlib.md5(target_text.encode()).hexdigest()[:16]
            if cache_key not in style_cache:
                with torch.no_grad():
                    style_cache[cache_key] = self.fingerprinter.extract_vector(target_text)
            style_vector = style_cache[cache_key]
            # Labels — set padding tokens to -100 so the loss ignores them.
            # Guard: a tokenizer without a pad token (pad_token_id=None)
            # must not poison the comparison; with padding="max_length" such
            # a tokenizer would have failed earlier, but stay defensive.
            labels = target_encoding["input_ids"].squeeze()
            if self.tokenizer.pad_token_id is not None:
                labels[labels == self.tokenizer.pad_token_id] = -100
            precomputed.append({
                "input_ids": input_encoding["input_ids"].squeeze(),
                "attention_mask": input_encoding["attention_mask"].squeeze(),
                "labels": labels,
                "style_vector": style_vector,
                "input_text": example["input"],
                "target_text": target_text,
            })
            if (i + 1) % 2000 == 0:
                logger.info(f"  Pre-computed: {i + 1}/{len(self.examples)}")
        logger.info(f"Pre-computation complete ({len(style_cache)} unique style vectors)")
        return precomputed

    def __len__(self):
        """Number of pre-computed examples."""
        return len(self._precomputed)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Pure dict return — zero computation per batch."""
        return self._precomputed[idx]