| # Dyslexia Academic Writing Correction System |
| ## Complete End-to-End Implementation Blueprint for Coding Agents |
|
|
> **System Goal:** A style-preserving, grammar-correcting, academic-vocabulary-elevating AI model that corrects dyslexic writing while maintaining the author's personal voice, tone, and authorship signal — not a rewriter, a corrector.
|
|
| --- |
|
|
| ## Table of Contents |
|
|
| 1. [Repository Structure](#1-repository-structure) |
| 2. [Environment Setup](#2-environment-setup) |
| 3. [Dependency Manifest](#3-dependency-manifest) |
| 4. [System Architecture Overview](#4-system-architecture-overview) |
5. [Layer 1 — Input Pre-Processing Pipeline](#5-layer-1--input-pre-processing-pipeline)
6. [Layer 2 — Style Fingerprinting Module](#6-layer-2--style-fingerprinting-module)
7. [Layer 3 — Core Generation Model](#7-layer-3--core-generation-model)
8. [Layer 4 — Training Data Strategy](#8-layer-4--training-data-strategy)
9. [Layer 5 — Training Loop & Loss Functions](#9-layer-5--training-loop--loss-functions)
10. [Layer 6 — Academic Vocabulary Control Module](#10-layer-6--academic-vocabulary-control-module)
11. [Layer 7 — Evaluation Framework](#11-layer-7--evaluation-framework)
12. [Layer 8 — Inference Pipeline](#12-layer-8--inference-pipeline)
13. [Layer 9 — API Server](#13-layer-9--api-server)
14. [Layer 10 — Configuration Files](#14-layer-10--configuration-files)
15. [Layer 11 — Full Training Run Sequence](#15-layer-11--full-training-run-sequence)
| 16. [Mathematical Formulations](#16-mathematical-formulations) |
| 17. [Hyperparameter Reference](#17-hyperparameter-reference) |
| 18. [Dataset Sources & Download Instructions](#18-dataset-sources--download-instructions) |
| 19. [Hardware Requirements](#19-hardware-requirements) |
| 20. [Testing Suite](#20-testing-suite) |
|
|
| --- |
|
|
| ## 1. Repository Structure |
|
|
| ``` |
dyslexia-writing-ai/
│
├── configs/
│   ├── model_config.yaml
│   ├── training_config.yaml
│   ├── inference_config.yaml
│   └── awl_config.yaml
│
├── data/
│   ├── raw/
│   │   ├── wi_locness/
│   │   ├── jfleg/
│   │   ├── gyafc/
│   │   └── custom_dyslexia/
│   ├── processed/
│   │   ├── train.jsonl
│   │   ├── val.jsonl
│   │   └── test.jsonl
│   └── awl/
│       ├── coxhead_awl.txt
│       ├── academic_synonyms.json
│       └── domain_lexicons/
│           ├── humanities.txt
│           ├── sciences.txt
│           └── social_sciences.txt
│
├── src/
│   ├── preprocessing/
│   │   ├── __init__.py
│   │   ├── spell_corrector.py
│   │   ├── sentence_segmenter.py
│   │   ├── dependency_parser.py
│   │   ├── ner_tagger.py
│   │   ├── dyslexia_simulator.py
│   │   └── pipeline.py
│   │
│   ├── style/
│   │   ├── __init__.py
│   │   ├── fingerprinter.py
│   │   ├── formality_classifier.py
│   │   ├── emotion_classifier.py
│   │   └── style_vector.py
│   │
│   ├── model/
│   │   ├── __init__.py
│   │   ├── base_model.py
│   │   ├── lora_adapter.py
│   │   ├── style_conditioner.py
│   │   └── generation_utils.py
│   │
│   ├── training/
│   │   ├── __init__.py
│   │   ├── dataset.py
│   │   ├── loss_functions.py
│   │   ├── trainer.py
│   │   └── callbacks.py
│   │
│   ├── vocabulary/
│   │   ├── __init__.py
│   │   ├── awl_loader.py
│   │   ├── lexical_substitution.py
│   │   └── register_filter.py
│   │
│   ├── evaluation/
│   │   ├── __init__.py
│   │   ├── gleu_scorer.py
│   │   ├── errant_evaluator.py
│   │   ├── style_metrics.py
│   │   └── authorship_verifier.py
│   │
│   ├── inference/
│   │   ├── __init__.py
│   │   ├── corrector.py
│   │   └── postprocessor.py
│   │
│   └── api/
│       ├── __init__.py
│       ├── main.py
│       ├── schemas.py
│       └── middleware.py
│
├── scripts/
│   ├── download_datasets.sh
│   ├── preprocess_data.py
│   ├── train.py
│   ├── evaluate.py
│   └── run_inference.py
│
├── tests/
│   ├── test_preprocessing.py
│   ├── test_style.py
│   ├── test_model.py
│   ├── test_vocabulary.py
│   └── test_evaluation.py
│
├── notebooks/
│   ├── 01_data_exploration.ipynb
│   ├── 02_style_fingerprint_analysis.ipynb
│   ├── 03_training_diagnostics.ipynb
│   └── 04_evaluation_dashboard.ipynb
│
├── requirements.txt
├── requirements-dev.txt
├── pyproject.toml
├── Dockerfile
├── docker-compose.yml
└── README.md
| ``` |
|
|
| --- |
|
|
| ## 2. Environment Setup |
|
|
| ```bash |
| # Python version requirement |
| python >= 3.10 |
| |
| # Create virtual environment |
| python -m venv venv |
| source venv/bin/activate # Linux/Mac |
| # venv\Scripts\activate # Windows |
| |
| # Install PyTorch with CUDA (choose your CUDA version) |
| pip install torch==2.2.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121 |
| |
| # Install all dependencies |
| pip install -r requirements.txt |
| |
| # Download spaCy transformer model |
| python -m spacy download en_core_web_trf |
| |
| # Download NLTK data |
| python -c "import nltk; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger'); nltk.download('wordnet')" |
| |
| # Install LanguageTool server (Java required) |
| pip install language-tool-python |
| # It auto-downloads the LanguageTool JAR on first run |
| |
| # Setup Weights & Biases for experiment tracking |
| wandb login |
| ``` |
|
|
| --- |
|
|
| ## 3. Dependency Manifest |
|
|
| ### `requirements.txt` |
|
|
| ```txt |
# ── Core ML & Deep Learning ──────────────────────────────────────────────────
| torch==2.2.0 |
| torchvision==0.17.0 |
| torchaudio==2.2.0 |
| transformers==4.40.0 |
| datasets==2.18.0 |
| accelerate==0.29.0 |
| peft==0.10.0 # LoRA / parameter-efficient fine-tuning |
| bitsandbytes==0.43.0 # 8-bit & 4-bit quantization |
| sentencepiece==0.2.0 # T5 tokenizer dependency |
| protobuf==4.25.3 # T5 tokenizer dependency |
| |
# ── Sentence Embeddings ──────────────────────────────────────────────────────
| sentence-transformers==2.6.1 |
| faiss-cpu==1.8.0 # Vector similarity search |
| |
# ── NLP Pre-Processing ───────────────────────────────────────────────────────
| spacy==3.7.4 |
| spacy-transformers==1.3.4 |
| language-tool-python==2.7.1 # LanguageTool grammar checker |
| pyspellchecker==0.8.1 # Context-free spell check (pre-pass) |
| nltk==3.8.1 |
| textstat==0.7.3 # Readability scores (Flesch-Kincaid, etc.) |
| |
# ── Lexical Substitution ─────────────────────────────────────────────────────
| lexsubgen==0.0.4 # BERT-based lexical substitution |
| wordfreq==3.1.1 # Word frequency data |
| PyDictionary==2.0.1 |
| |
# ── Training Infrastructure ──────────────────────────────────────────────────
| wandb==0.16.6 # Experiment tracking |
| tensorboard==2.16.2 |
| numpy==1.26.4 |
| pandas==2.2.1 |
| scikit-learn==1.4.1.post1 |
| scipy==1.13.0 |
| |
# ── Evaluation Tools ─────────────────────────────────────────────────────────
| errant==2.3.3 # Grammar Error Annotation Toolkit |
| sacrebleu==2.4.2 # BLEU/GLEU scoring |
| bert-score==0.3.13 # Semantic similarity scoring |
| rouge-score==0.1.2 |
| |
# ── API Server ───────────────────────────────────────────────────────────────
| fastapi==0.110.1 |
| uvicorn[standard]==0.29.0 |
| pydantic==2.7.0 |
| python-multipart==0.0.9 |
| httpx==0.27.0 |
| |
# ── Inference Optimisation ───────────────────────────────────────────────────
| vllm==0.4.0 # High-throughput LLM serving (optional, GPU only) |
| optimum==1.19.1 # Hugging Face model optimisation |
| |
# ── Utilities ────────────────────────────────────────────────────────────────
| pyyaml==6.0.1 |
| tqdm==4.66.2 |
| loguru==0.7.2 |
| python-dotenv==1.0.1 |
| click==8.1.7 |
| rich==13.7.1 # Beautiful terminal output |
| joblib==1.4.0 |
| ``` |
|
|
| ### `requirements-dev.txt` |
|
|
| ```txt |
| pytest==8.1.1 |
| pytest-asyncio==0.23.6 |
| pytest-cov==5.0.0 |
| black==24.4.0 |
| ruff==0.4.1 |
| mypy==1.9.0 |
| pre-commit==3.7.0 |
| ipykernel==6.29.4 |
| jupyter==1.0.0 |
| ``` |
|
|
| --- |
|
|
| ## 4. System Architecture Overview |
|
|
| ``` |
┌────────────────────────────────────────────────────────────────────────┐
│                       INPUT TEXT (raw dyslexic)                        │
└──────────────────────────────────────────────┬─────────────────────────┘
                                               │
                          ┌────────────────────▼─────────────────────────┐
                          │  LAYER 1: Pre-Processing Pipeline            │
                          │  spell_corrector → segmenter → dep_parser →  │
                          │  NER_tagger → readability_scorer             │
                          └────────────────────┬─────────────────────────┘
                                               │ cleaned + annotated text
                          ┌────────────────────▼─────────────────────────┐
                          │  LAYER 2: Style Fingerprinting               │
                          │  sentence_len_dist, syntactic_complexity,    │
                          │  TTR, voice_ratio, hedging_freq,             │
                          │  discourse_markers, formality_score,         │
                          │  emotion_register → style_vector [512-dim]   │
                          └─────────┬─────────────────────┬──────────────┘
                                    │                     │
                            [user style vec]   [master copy style vec]
                                    │                     │
                          ┌─────────▼─────────────────────▼──────────────┐
                          │  STYLE BLENDING (weighted interpolation)     │
                          │  target_style = α·user + (1-α)·master        │
                          └────────────────────┬─────────────────────────┘
                                               │
                          ┌────────────────────▼─────────────────────────┐
                          │  LAYER 3: Core Generation Model              │
                          │  Base: Flan-T5-XL / BART-large / Llama-3     │
                          │  Fine-tuned with LoRA                        │
                          │  Conditioned on: cleaned_text + style_vector │
                          │  Loss: CE + style_consistency + semantic_sim │
                          └────────────────────┬─────────────────────────┘
                                               │ draft corrected text
                          ┌────────────────────▼─────────────────────────┐
                          │  LAYER 6: Academic Vocabulary Control        │
                          │  AWL substitution → register filter          │
                          │  → nominalisation pass → hedging check       │
                          └────────────────────┬─────────────────────────┘
                                               │
                          ┌────────────────────▼─────────────────────────┐
                          │  LAYER 7: Evaluation & Quality Gate          │
                          │  GLEU, ERRANT, style_sim, authorship_score   │
                          │  If quality < threshold → re-generate        │
                          └────────────────────┬─────────────────────────┘
                                               │
                          ┌────────────────────▼─────────────────────────┐
                          │                FINAL OUTPUT                  │
                          │  Grammatically perfect · Academic register   │
                          │  Style-preserved · Human authorship signal   │
                          └──────────────────────────────────────────────┘
| ``` |
|
|
| --- |
|
|
## 5. Layer 1 — Input Pre-Processing Pipeline
|
|
| ### `src/preprocessing/spell_corrector.py` |
| |
| ```python |
| """ |
| Two-pass spell correction: |
| Pass 1: pyspellchecker (fast, context-free, catches simple typos) |
| Pass 2: LanguageTool (context-aware, catches grammar + dyslexic patterns) |
| |
| Dyslexic error patterns handled: |
| - Letter reversals: b/d, p/q, n/u, m/w |
| - Phonetic spelling: "wuz", "cud", "thay" |
| - Word boundary errors: "alot", "infact", "aswell" |
| - Letter omissions: "becaus", "importnt" |
| - Letter transpositions: "teh", "recieve" |
| - Homophone confusion: there/their/they're |
| """ |
| |
| import language_tool_python |
| from spellchecker import SpellChecker |
| from loguru import logger |
| from typing import Optional |
| import re |
| |
| |
| class DyslexiaAwareSpellCorrector: |
| |
| DYSLEXIC_PHONETIC_MAP = { |
| "wuz": "was", "cud": "could", "wud": "would", "shud": "should", |
| "thay": "they", "thier": "their", "recieve": "receive", |
| "beleive": "believe", "occured": "occurred", "definately": "definitely", |
| "seperate": "separate", "untill": "until", "tommorrow": "tomorrow", |
| "alot": "a lot", "infact": "in fact", "aswell": "as well", |
| "alright": "all right", "cant": "cannot", "wont": "will not", |
| "ive": "I have", "im": "I am", "id": "I would", |
| } |
| |
| def __init__(self, language: str = "en-US"): |
| self.spell = SpellChecker() |
| self.tool = language_tool_python.LanguageTool(language) |
| logger.info("Spell corrector initialised with LanguageTool backend.") |
| |
| def _phonetic_pass(self, text: str) -> str: |
| """Apply known dyslexic phonetic substitutions first.""" |
| pattern = re.compile( |
| r'\b(' + '|'.join(re.escape(k) for k in self.DYSLEXIC_PHONETIC_MAP.keys()) + r')\b', |
| re.IGNORECASE |
| ) |
| def replace(match): |
| return self.DYSLEXIC_PHONETIC_MAP[match.group(0).lower()] |
| return pattern.sub(replace, text) |
| |
    def _spellcheck_pass(self, text: str) -> str:
        """pyspellchecker pass for simple token-level errors."""
        tokens = text.split()
        corrected = []
        for token in tokens:
            clean = re.sub(r'[^\w]', '', token)
            lower = clean.lower()
            if lower and lower not in self.spell:
                correction = self.spell.correction(lower)
                if correction and correction != lower:
                    # Restore the original capitalisation ("Beleive" -> "Believe")
                    if clean[:1].isupper():
                        correction = correction.capitalize()
                    token = token.replace(clean, correction)
            corrected.append(token)
        return ' '.join(corrected)
| |
| def _languagetool_pass(self, text: str) -> str: |
| """LanguageTool pass for context-aware grammar + spelling corrections.""" |
| matches = self.tool.check(text) |
| # Apply corrections in reverse order to preserve offsets |
| for match in reversed(matches): |
| if match.replacements: |
| start = match.offset |
| end = start + match.errorLength |
| text = text[:start] + match.replacements[0] + text[end:] |
| return text |
| |
| def correct(self, text: str) -> str: |
| text = self._phonetic_pass(text) |
| text = self._spellcheck_pass(text) |
| text = self._languagetool_pass(text) |
| return text |
| |
| def close(self): |
| self.tool.close() |
| ``` |
| |
| --- |
| |
| ### `src/preprocessing/pipeline.py` |
| |
| ```python |
| """ |
| Master pre-processing pipeline. Runs all NLP stages in sequence. |
| Returns a PreprocessedDoc object with all annotations attached. |
| """ |
| |
| import spacy |
| from dataclasses import dataclass, field |
| from typing import List, Dict, Any, Optional |
| from .spell_corrector import DyslexiaAwareSpellCorrector |
| import textstat |
|
|
|
|
| @dataclass |
| class EntitySpan: |
| text: str |
| label: str |
| start_char: int |
| end_char: int |
| |
|
|
| @dataclass |
| class PreprocessedDoc: |
| original_text: str |
| corrected_text: str |
| sentences: List[str] |
| entities: List[EntitySpan] # Never to be modified by rewriter |
| dependency_trees: List[Dict] # Grammatical skeletons per sentence |
| pos_tags: List[List[tuple]] # (token, POS) per sentence |
| readability: Dict[str, float] # Flesch-Kincaid, Gunning Fog, etc. |
| sentence_lengths: List[int] |
| protected_spans: List[tuple] # (start, end) char spans to never touch |
| |
|
|
| class PreprocessingPipeline: |
|
|
| def __init__(self, model_name: str = "en_core_web_trf"): |
| self.nlp = spacy.load(model_name) |
| self.corrector = DyslexiaAwareSpellCorrector() |
| |
| def _extract_readability(self, text: str) -> Dict[str, float]: |
| return { |
| "flesch_reading_ease": textstat.flesch_reading_ease(text), |
| "flesch_kincaid_grade": textstat.flesch_kincaid_grade(text), |
| "gunning_fog": textstat.gunning_fog(text), |
| "smog_index": textstat.smog_index(text), |
| "automated_readability_index": textstat.automated_readability_index(text), |
| } |
| |
| def _extract_dep_tree(self, sent) -> Dict: |
| """Extract grammatical skeleton: subject-verb-object per sentence.""" |
| tree = {"tokens": [], "root": None, "svo": []} |
| subjects, verbs, objects = [], [], [] |
| for token in sent: |
| tree["tokens"].append({ |
| "text": token.text, |
| "dep": token.dep_, |
| "pos": token.pos_, |
| "head": token.head.text, |
| }) |
| if token.dep_ == "ROOT": |
| tree["root"] = token.text |
| if token.dep_ in ("nsubj", "nsubjpass"): |
| subjects.append(token.text) |
| if token.pos_ == "VERB": |
| verbs.append(token.text) |
| if token.dep_ in ("dobj", "pobj"): |
| objects.append(token.text) |
| tree["svo"] = {"subjects": subjects, "verbs": verbs, "objects": objects} |
| return tree |
| |
| def process(self, raw_text: str) -> PreprocessedDoc: |
| # Step 1: Spell + grammar correction |
| corrected = self.corrector.correct(raw_text) |
| |
| # Step 2: spaCy full parse |
| doc = self.nlp(corrected) |
| |
| # Step 3: Extract sentences |
| sentences = [sent.text.strip() for sent in doc.sents] |
| sentence_lengths = [len(sent.text.split()) for sent in doc.sents] |
| |
| # Step 4: Named entities (protect these spans) |
| entities = [ |
| EntitySpan(ent.text, ent.label_, ent.start_char, ent.end_char) |
| for ent in doc.ents |
| ] |
| protected_spans = [(e.start_char, e.end_char) for e in entities] |
| |
| # Step 5: Dependency trees |
| dep_trees = [self._extract_dep_tree(sent) for sent in doc.sents] |
| |
| # Step 6: POS tags |
| pos_tags = [ |
| [(token.text, token.pos_) for token in sent] |
| for sent in doc.sents |
| ] |
| |
| # Step 7: Readability |
| readability = self._extract_readability(corrected) |
| |
| return PreprocessedDoc( |
| original_text=raw_text, |
| corrected_text=corrected, |
| sentences=sentences, |
| entities=entities, |
| dependency_trees=dep_trees, |
| pos_tags=pos_tags, |
| readability=readability, |
| sentence_lengths=sentence_lengths, |
| protected_spans=protected_spans, |
| ) |
| ``` |
| |
| --- |
|
|
| ### `src/preprocessing/dyslexia_simulator.py` |
| |
| ```python |
| """ |
Programmatically generates dyslexic training data from clean text.
Used to augment training pairs when real dyslexic examples are scarce.
| |
| Error types simulated (from Rello et al. 2013, 2017 dyslexia research): |
| - Phonetic substitution (most common, ~35% of errors) |
| - Letter transposition (e.g., "teh" for "the") (~18%) |
| - Letter omission (~16%) |
| - Letter doubling (~12%) |
| - Letter reversal b/d, p/q (~10%) |
| - Word boundary errors (~9%) |
| """ |
| |
| import random |
| import re |
| from typing import Tuple |
| |
| |
| class DyslexiaSimulator: |
| |
| LETTER_REVERSALS = {'b': 'd', 'd': 'b', 'p': 'q', 'q': 'p', 'n': 'u', 'u': 'n'} |
| PHONETIC_SUBS = { |
| 'was': 'wuz', 'could': 'cud', 'would': 'wud', 'they': 'thay', |
| 'because': 'becaus', 'important': 'importnt', 'receive': 'recieve', |
| 'believe': 'beleive', 'definitely': 'definately', 'separate': 'seperate', |
| 'a lot': 'alot', 'in fact': 'infact', 'as well': 'aswell', |
| } |
| WORD_MERGES = [ |
| ('a lot', 'alot'), ('in fact', 'infact'), ('as well', 'aswell'), |
| ('all right', 'alright'), ('every one', 'everyone'), |
| ] |
| |
| def __init__(self, error_rate: float = 0.15, seed: int = 42): |
| self.error_rate = error_rate |
| random.seed(seed) |
| |
| def _transpose_letters(self, word: str) -> str: |
| if len(word) < 3: |
| return word |
| i = random.randint(0, len(word) - 2) |
| chars = list(word) |
| chars[i], chars[i+1] = chars[i+1], chars[i] |
| return ''.join(chars) |
| |
| def _omit_letter(self, word: str) -> str: |
| if len(word) < 4: |
| return word |
| i = random.randint(1, len(word) - 2) |
| return word[:i] + word[i+1:] |
| |
| def _double_letter(self, word: str) -> str: |
| if len(word) < 3: |
| return word |
| i = random.randint(1, len(word) - 2) |
| return word[:i] + word[i] + word[i:] |
| |
| def _reverse_letter(self, word: str) -> str: |
| chars = list(word) |
| for i, c in enumerate(chars): |
| if c in self.LETTER_REVERSALS and random.random() < 0.5: |
| chars[i] = self.LETTER_REVERSALS[c] |
| return ''.join(chars) |
| |
| def corrupt_word(self, word: str) -> str: |
| """Apply a single random error to a word.""" |
| if len(word) <= 2 or random.random() > self.error_rate: |
| return word |
| # Check phonetic substitutions first |
| lower = word.lower() |
| if lower in self.PHONETIC_SUBS: |
| return self.PHONETIC_SUBS[lower] |
| choice = random.choice(['transpose', 'omit', 'double', 'reverse']) |
| if choice == 'transpose': |
| return self._transpose_letters(word) |
| elif choice == 'omit': |
| return self._omit_letter(word) |
| elif choice == 'double': |
| return self._double_letter(word) |
| else: |
| return self._reverse_letter(word) |
| |
| def simulate(self, clean_text: str) -> Tuple[str, str]: |
| """Returns (corrupted_text, clean_text) training pair.""" |
| words = clean_text.split() |
| corrupted = [self.corrupt_word(w) for w in words] |
| corrupted_text = ' '.join(corrupted) |
| # Apply word merge errors |
| for correct_phrase, merged in self.WORD_MERGES: |
| if random.random() < 0.3: |
| corrupted_text = corrupted_text.replace(correct_phrase, merged) |
| return corrupted_text, clean_text |
| ``` |
| |
| --- |
|
|
## 6. Layer 2 — Style Fingerprinting Module
|
|
| ### `src/style/fingerprinter.py` |
|
|
| ```python |
| """ |
| Extracts a numerical style vector from any text sample. |
| The style vector encodes the author's unique writing fingerprint |
| and is used both to condition the generation model and to evaluate |
| style preservation after correction. |
| |
| Style vector dimensions (total: 512 after projection): |
| Raw features (~40) β MLP projection β 512-dim dense vector |
| |
| Raw features: |
| - sentence_length_mean, sentence_length_std, sentence_length_skew [3] |
| - word_length_mean, word_length_std [2] |
| - type_token_ratio (TTR) [1] |
| - passive_voice_ratio [1] |
| - active_voice_ratio [1] |
| - subordinate_clause_ratio [1] |
| - avg_dependency_tree_depth [1] |
| - hedging_frequency (per 100 words) [1] |
| - discourse_marker_counts [however, therefore, moreover, ...] [20] |
| - formality_score (0-1) [1] |
| - lexical_density [1] |
| - nominalization_ratio [1] |
| - question_sentence_ratio [1] |
| - exclamation_ratio [1] |
| - first_person_ratio [1] |
| - third_person_ratio [1] |
| - academic_word_coverage [1] |
| - avg_syllables_per_word [1] |
| - flesch_reading_ease [1] |
| """ |
| |
| import spacy |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| from typing import List, Dict, Optional |
| from scipy import stats |
| |
| |
| HEDGING_WORDS = { |
| "perhaps", "possibly", "probably", "might", "may", "could", "seem", |
| "appears", "suggests", "indicates", "tend", "often", "generally", |
| "approximately", "roughly", "somewhat", "relatively", "fairly", |
| } |
| |
| DISCOURSE_MARKERS = [ |
| "however", "therefore", "moreover", "furthermore", "consequently", |
| "nevertheless", "nonetheless", "additionally", "alternatively", |
| "subsequently", "previously", "similarly", "conversely", "thus", |
| "hence", "accordingly", "meanwhile", "indeed", "notably", "specifically", |
| ] |
| |
| NOMINALISATION_SUFFIXES = ( |
| "tion", "sion", "ment", "ness", "ity", "ance", "ence", |
| "hood", "ship", "ism", "al", "ure", |
| ) |
| |
| |
| class StyleProjectionMLP(nn.Module): |
| """Projects raw feature vector to 512-dim style embedding.""" |
| def __init__(self, input_dim: int = 40, hidden_dim: int = 256, output_dim: int = 512): |
| super().__init__() |
| self.net = nn.Sequential( |
| nn.Linear(input_dim, hidden_dim), |
| nn.LayerNorm(hidden_dim), |
| nn.GELU(), |
| nn.Dropout(0.1), |
| nn.Linear(hidden_dim, output_dim), |
| nn.LayerNorm(output_dim), |
| ) |
| |
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| return self.net(x) |
| |
| |
| class StyleFingerprinter: |
| |
| def __init__(self, spacy_model: str = "en_core_web_trf", awl_path: str = "data/awl/coxhead_awl.txt"): |
| self.nlp = spacy.load(spacy_model) |
| self.awl = self._load_awl(awl_path) |
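        # NOTE: the projection MLP below is randomly initialised; load trained
        # weights (learned jointly with the generator) before inference use.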
| self.projection = StyleProjectionMLP() |
| |
| def _load_awl(self, path: str) -> set: |
| try: |
| with open(path) as f: |
| return {line.strip().lower() for line in f if line.strip()} |
| except FileNotFoundError: |
| return set() |
| |
| def _passive_voice_ratio(self, doc) -> float: |
| passive_count = sum( |
| 1 for token in doc |
| if token.dep_ in ("nsubjpass", "auxpass") |
| ) |
| sentences = list(doc.sents) |
| return passive_count / max(len(sentences), 1) |
| |
| def _avg_dep_tree_depth(self, doc) -> float: |
| def depth(token): |
| d = 0 |
| while token.head != token: |
| token = token.head |
| d += 1 |
| return d |
| depths = [depth(token) for token in doc] |
| return np.mean(depths) if depths else 0.0 |
| |
| def _lexical_density(self, doc) -> float: |
| content_pos = {"NOUN", "VERB", "ADJ", "ADV"} |
| content = sum(1 for t in doc if t.pos_ in content_pos) |
| return content / max(len(doc), 1) |
| |
| def extract_raw_features(self, text: str) -> Dict[str, float]: |
| doc = self.nlp(text) |
| sentences = list(doc.sents) |
| words = [t.text for t in doc if not t.is_punct and not t.is_space] |
| word_lengths = [len(w) for w in words] |
| sent_lengths = [len(list(s)) for s in sentences] |
| |
| # Type-Token Ratio |
| unique_words = set(w.lower() for w in words) |
| ttr = len(unique_words) / max(len(words), 1) |
| |
| # Hedging |
| hedging_freq = sum(1 for w in words if w.lower() in HEDGING_WORDS) |
| hedging_per_100 = (hedging_freq / max(len(words), 1)) * 100 |
| |
| # Discourse markers |
| text_lower = text.lower() |
| dm_counts = {dm: text_lower.count(dm) for dm in DISCOURSE_MARKERS} |
| |
| # Nominalisation |
| nom_count = sum(1 for w in words if w.lower().endswith(NOMINALISATION_SUFFIXES)) |
| nom_ratio = nom_count / max(len(words), 1) |
| |
| # Person |
| first_p = sum(1 for t in doc if t.lower_ in {"i", "we", "my", "our", "me", "us"}) |
| third_p = sum(1 for t in doc if t.lower_ in {"he", "she", "they", "it", "his", "her", "their"}) |
| first_ratio = first_p / max(len(words), 1) |
| third_ratio = third_p / max(len(words), 1) |
| |
| # AWL coverage |
| awl_hits = sum(1 for w in words if w.lower() in self.awl) |
| awl_coverage = awl_hits / max(len(words), 1) |
| |
| # Sentence type |
| question_ratio = sum(1 for s in sentences if s.text.strip().endswith("?")) / max(len(sentences), 1) |
| exclaim_ratio = sum(1 for s in sentences if s.text.strip().endswith("!")) / max(len(sentences), 1) |
| |
| raw = { |
| "sent_len_mean": np.mean(sent_lengths) if sent_lengths else 0, |
| "sent_len_std": np.std(sent_lengths) if sent_lengths else 0, |
| "sent_len_skew": float(stats.skew(sent_lengths)) if len(sent_lengths) > 2 else 0, |
| "word_len_mean": np.mean(word_lengths) if word_lengths else 0, |
| "word_len_std": np.std(word_lengths) if word_lengths else 0, |
| "ttr": ttr, |
| "passive_ratio": self._passive_voice_ratio(doc), |
| "active_ratio": 1.0 - self._passive_voice_ratio(doc), |
| "avg_dep_depth": self._avg_dep_tree_depth(doc), |
| "hedging_per_100": hedging_per_100, |
| "nom_ratio": nom_ratio, |
| "lexical_density": self._lexical_density(doc), |
| "question_ratio": question_ratio, |
| "exclaim_ratio": exclaim_ratio, |
| "first_person_ratio": first_ratio, |
| "third_person_ratio": third_ratio, |
| "awl_coverage": awl_coverage, |
| } |
| for dm, count in dm_counts.items(): |
| raw[f"dm_{dm}"] = count / max(len(sentences), 1) |
| |
| return raw |
| |
| def extract_vector(self, text: str) -> torch.Tensor: |
| """Returns a 512-dim style embedding tensor.""" |
| raw = self.extract_raw_features(text) |
| feature_array = np.array(list(raw.values()), dtype=np.float32) |
| # Pad or truncate to expected input_dim |
| expected_dim = 40 |
| if len(feature_array) < expected_dim: |
| feature_array = np.pad(feature_array, (0, expected_dim - len(feature_array))) |
| else: |
| feature_array = feature_array[:expected_dim] |
| feature_tensor = torch.tensor(feature_array).unsqueeze(0) |
| with torch.no_grad(): |
| style_vec = self.projection(feature_tensor) |
| return style_vec.squeeze(0) # [512] |
| |
| def blend_vectors( |
| self, |
| user_vec: torch.Tensor, |
| master_vec: Optional[torch.Tensor], |
| alpha: float = 0.6, |
| ) -> torch.Tensor: |
| """ |
| Blend user style with master copy style. |
| alpha = weight given to user's own style (0.6 = user dominates) |
| (1-alpha) = weight given to master copy style |
| |
| Formula: target = alpha * user_vec + (1 - alpha) * master_vec |
| """ |
| if master_vec is None: |
| return user_vec |
| blended = alpha * user_vec + (1 - alpha) * master_vec |
| # L2 normalise to unit sphere |
| return blended / (blended.norm() + 1e-8) |
| ``` |
|
|
| --- |
|
|
## 7. Layer 3 — Core Generation Model
|
|
| ### Model Selection Decision Tree |
|
|
| ``` |
Do you have ≥ 40GB VRAM (e.g., A100)?
├── YES → Fine-tune Llama-3.1-8B with LoRA (best quality)
└── NO  → Do you have ≥ 16GB VRAM?
          ├── YES → Fine-tune Flan-T5-XL (3B params, best encoder-decoder)
          └── NO  → Fine-tune BART-large (400M params, excellent denoiser)
                    OR Flan-T5-Large (780M params)
| ``` |
|
|
| ### `src/model/base_model.py` |
| |
| ```python |
| """ |
| Loads and wraps the base pretrained model. |
| Supported architectures: |
| - google/flan-t5-xl (recommended, 3B) |
| - google/flan-t5-large (780M, resource-constrained) |
| - facebook/bart-large (400M, excellent denoiser) |
| - meta-llama/Meta-Llama-3.1-8B-Instruct (8B, best quality) |
| """ |
| |
| from transformers import ( |
| AutoTokenizer, AutoModelForSeq2SeqLM, |
| AutoModelForCausalLM, BitsAndBytesConfig |
| ) |
| from peft import get_peft_model, LoraConfig, TaskType |
| import torch |
| from loguru import logger |
| |
| |
| ENCODER_DECODER_MODELS = { |
| "flan-t5-xl": "google/flan-t5-xl", |
| "flan-t5-large": "google/flan-t5-large", |
| "bart-large": "facebook/bart-large", |
| } |
| |
| DECODER_ONLY_MODELS = { |
| "llama-3.1-8b": "meta-llama/Meta-Llama-3.1-8B-Instruct", |
| } |
| |
| |
| def load_model_and_tokenizer(model_key: str, quantize: bool = False, use_lora: bool = True): |
| |
| is_seq2seq = model_key in ENCODER_DECODER_MODELS |
| model_name = ENCODER_DECODER_MODELS.get(model_key) or DECODER_ONLY_MODELS.get(model_key) |
| |
| if not model_name: |
| raise ValueError(f"Unknown model key: {model_key}") |
| |
| logger.info(f"Loading {model_name} ({'seq2seq' if is_seq2seq else 'causal'})...") |
| |
| tokenizer = AutoTokenizer.from_pretrained(model_name) |
| if tokenizer.pad_token is None: |
| tokenizer.pad_token = tokenizer.eos_token |
| |
| # Quantisation config (for large models on limited VRAM) |
| bnb_config = None |
| if quantize: |
| bnb_config = BitsAndBytesConfig( |
| load_in_4bit=True, |
| bnb_4bit_use_double_quant=True, |
| bnb_4bit_quant_type="nf4", |
| bnb_4bit_compute_dtype=torch.bfloat16, |
| ) |
| |
| if is_seq2seq: |
| model = AutoModelForSeq2SeqLM.from_pretrained( |
| model_name, |
| quantization_config=bnb_config, |
| torch_dtype=torch.bfloat16 if not quantize else None, |
| device_map="auto", |
| ) |
| lora_task = TaskType.SEQ_2_SEQ_LM |
| else: |
| model = AutoModelForCausalLM.from_pretrained( |
| model_name, |
| quantization_config=bnb_config, |
| torch_dtype=torch.bfloat16 if not quantize else None, |
| device_map="auto", |
| ) |
| lora_task = TaskType.CAUSAL_LM |
| |
    if use_lora:
        # Target-module names differ per architecture: T5 exposes q/k/v/o and
        # the gated FFN as wi_0/wi_1/wo; BART and Llama use *_proj naming.
        if "t5" in model_name:
            target_modules = ["q", "k", "v", "o", "wi_0", "wi_1", "wo"]
        elif "bart" in model_name:
            target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
        else:  # Llama-style decoder
            target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj"]
        lora_config = LoraConfig(
            task_type=lora_task,
            r=16,                           # LoRA rank — increase for more capacity
            lora_alpha=32,                  # Scaling factor (typically 2x rank)
            target_modules=target_modules,  # Modules to apply LoRA to
            lora_dropout=0.05,
            bias="none",
            inference_mode=False,
        )
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()
| |
| return model, tokenizer, is_seq2seq |
| ``` |
| |
| --- |
|
|
| ### `src/model/style_conditioner.py` |
| |
| ```python |
| """ |
| Injects the style vector into the model via soft prompt conditioning. |
| The style vector is projected to the model's hidden dimension and |
| prepended to the input token embeddings as virtual tokens. |
| |
| This technique is called "prefix tuning" / "style prefix injection". |
| It biases the model's attention toward the desired output style |
| without modifying the base model weights. |
| |
| For Flan-T5: injects into encoder input embeddings |
| For BART: injects into encoder input embeddings |
| For Llama: prepends to the full input context |
| """ |
| |
| import torch |
| import torch.nn as nn |
| |
| |
| class StyleConditioner(nn.Module): |
| """ |
| Projects a 512-dim style vector to n_prefix_tokens virtual tokens |
| in the model's embedding space. |
| """ |
| |
| def __init__( |
| self, |
| style_dim: int = 512, |
| model_hidden_dim: int = 2048, # T5-XL hidden size |
| n_prefix_tokens: int = 10, # Number of virtual prefix tokens |
| ): |
| super().__init__() |
| self.n_prefix_tokens = n_prefix_tokens |
| self.projection = nn.Sequential( |
| nn.Linear(style_dim, model_hidden_dim * n_prefix_tokens), |
| nn.Tanh(), |
| ) |
| |
| def forward(self, style_vector: torch.Tensor) -> torch.Tensor: |
| """ |
| Args: |
| style_vector: [batch_size, 512] |
| Returns: |
| prefix_embeddings: [batch_size, n_prefix_tokens, model_hidden_dim] |
| """ |
| batch_size = style_vector.shape[0] |
| projected = self.projection(style_vector) |
| return projected.view(batch_size, self.n_prefix_tokens, -1) |
| |
|
|
| def prepend_style_prefix( |
| input_embeddings: torch.Tensor, |
| style_prefix: torch.Tensor, |
| ) -> torch.Tensor: |
| """ |
| Concatenates style prefix to input embeddings along sequence dimension. |
| |
| Args: |
| input_embeddings: [batch, seq_len, hidden_dim] |
| style_prefix: [batch, n_prefix, hidden_dim] |
| Returns: |
| [batch, n_prefix + seq_len, hidden_dim] |
| """ |
| return torch.cat([style_prefix, input_embeddings], dim=1) |
| ``` |
| |
| --- |
|
|
## 8. Layer 4 — Training Data Strategy
|
|
| ### `src/training/dataset.py` |
|
|
| ```python |
| """ |
| Dataset class that handles all data sources and produces training triplets: |
| (input_text, style_vector, target_text) |
| |
| Data sources priority: |
1. W&I+LOCNESS — real learner errors with expert corrections
2. JFLEG — naturalistic fluency corrections
3. GYAFC — informal→formal style transfer
4. Synthetic — dyslexia simulator augmentation on Wikipedia/books
5. Custom — any user-provided correction pairs
| |
| Each example is structured as: |
| { |
| "input": "<corrupted/informal text>", |
| "target": "<corrected academic text>", |
| "style_vector": [512 floats], |
| "source": "wi_locness | jfleg | gyafc | synthetic | custom", |
| } |
| """ |
| |
| import json |
| from pathlib import Path |
| from typing import List, Dict, Optional |
| import torch |
| from torch.utils.data import Dataset |
| from transformers import PreTrainedTokenizer |
| from ..style.fingerprinter import StyleFingerprinter |
| from ..preprocessing.dyslexia_simulator import DyslexiaSimulator |
| |
| |
| TASK_PREFIX = ( |
| "Correct the following text for grammar, spelling, and clarity. " |
| "Maintain the author's original tone and writing style. " |
| "Elevate vocabulary to academic register. " |
| "Do NOT change the meaning or add new information. " |
| "Preserve named entities exactly. " |
| "Text to correct: " |
| ) |
| |
| |
| class WritingCorrectionDataset(Dataset): |
| |
| def __init__( |
| self, |
| data_path: str, |
| tokenizer: PreTrainedTokenizer, |
| fingerprinter: StyleFingerprinter, |
| max_input_length: int = 512, |
| max_target_length: int = 512, |
| augment_with_synthetic: bool = True, |
| synthetic_ratio: float = 0.3, |
| ): |
| self.tokenizer = tokenizer |
| self.fingerprinter = fingerprinter |
| self.max_input_length = max_input_length |
| self.max_target_length = max_target_length |
| self.examples = self._load(data_path) |
| |
| if augment_with_synthetic: |
| self._add_synthetic(synthetic_ratio) |
| |
| def _load(self, path: str) -> List[Dict]: |
| examples = [] |
| with open(path) as f: |
| for line in f: |
| obj = json.loads(line.strip()) |
| examples.append(obj) |
| return examples |
| |
| def _add_synthetic(self, ratio: float): |
| simulator = DyslexiaSimulator(error_rate=0.15) |
| n_synthetic = int(len(self.examples) * ratio) |
| # Use clean targets as source for simulation |
| synthetic = [] |
| for ex in self.examples[:n_synthetic]: |
| corrupted, clean = simulator.simulate(ex["target"]) |
| synthetic.append({"input": corrupted, "target": clean, "source": "synthetic"}) |
| self.examples.extend(synthetic) |
| |
| def __len__(self): |
| return len(self.examples) |
| |
| def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: |
| ex = self.examples[idx] |
| input_text = TASK_PREFIX + ex["input"] |
| target_text = ex["target"] |
| |
| # Compute style vector from the TARGET (we want to learn to match this style) |
| style_vec = self.fingerprinter.extract_vector(target_text) |
| |
| # Tokenise input |
| input_enc = self.tokenizer( |
| input_text, |
| max_length=self.max_input_length, |
| padding="max_length", |
| truncation=True, |
| return_tensors="pt", |
| ) |
| |
| # Tokenise target |
| target_enc = self.tokenizer( |
| target_text, |
| max_length=self.max_target_length, |
| padding="max_length", |
| truncation=True, |
| return_tensors="pt", |
| ) |
| |
| labels = target_enc["input_ids"].squeeze() |
| labels[labels == self.tokenizer.pad_token_id] = -100 # Ignore padding in loss |
| |
| return { |
| "input_ids": input_enc["input_ids"].squeeze(), |
| "attention_mask": input_enc["attention_mask"].squeeze(), |
| "labels": labels, |
| "style_vector": style_vec, |
| } |
| ``` |
|
|
| --- |
|
|
## 9. Layer 5 — Training Loop & Loss Functions
|
|
| ### `src/training/loss_functions.py` |
| |
| ```python |
| """ |
Combined training loss:

    L_total = L_CE + λ₁ · L_style + λ₂ · L_semantic

Where:
    L_CE       = cross-entropy language-model loss (standard token prediction)
    L_style    = style consistency loss (cosine distance between output and target style vectors)
    L_semantic = semantic similarity loss (cosine distance between sentence embeddings)
    λ₁         = style loss weight (default 0.3)
    λ₂         = semantic loss weight (default 0.5)

L_style:
    style_sim = cosine_similarity(style_vec(output), style_vec(target))
    L_style = 1 - style_sim

L_semantic:
    sem_emb_output = sentence_transformer.encode(output_text)
    sem_emb_input = sentence_transformer.encode(input_text)
    sem_sim = cosine_similarity(sem_emb_output, sem_emb_input)
    L_semantic = 1 - sem_sim
    (We compare to INPUT meaning — meaning must be preserved, not changed)

NOTE: both auxiliary terms are computed from decoded text under
torch.no_grad(), so they contribute no gradient of their own; they serve
as logged diagnostics and quality-gate signals alongside L_CE.
"""
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from sentence_transformers import SentenceTransformer |
from typing import Dict, List, Optional
| |
| |
| class CombinedCorrectionLoss(nn.Module): |
| |
| def __init__( |
| self, |
| lambda_style: float = 0.3, |
| lambda_semantic: float = 0.5, |
| sem_model_name: str = "all-mpnet-base-v2", |
| device: str = "cuda", |
| ): |
| super().__init__() |
| self.lambda_style = lambda_style |
| self.lambda_semantic = lambda_semantic |
| self.device = device |
| |
| # Frozen sentence transformer for semantic similarity |
| self.sem_model = SentenceTransformer(sem_model_name, device=device) |
| for param in self.sem_model.parameters(): |
| param.requires_grad = False |
| |
| self.ce_loss = nn.CrossEntropyLoss(ignore_index=-100) |
| |
| def _style_loss( |
| self, |
| output_style_vec: torch.Tensor, |
| target_style_vec: torch.Tensor, |
| ) -> torch.Tensor: |
| """ |
| 1 - cosine_similarity(output_style, target_style) |
| Shape: [batch_size, 512] β scalar |
| """ |
| sim = F.cosine_similarity(output_style_vec, target_style_vec, dim=-1) |
| return (1 - sim).mean() |
| |
| def _semantic_loss( |
| self, |
| input_texts: List[str], |
| output_texts: List[str], |
| ) -> torch.Tensor: |
| """ |
| Penalises meaning change between input and output. |
| Uses frozen sentence-transformer embeddings. |
| """ |
| with torch.no_grad(): |
| input_embs = torch.tensor( |
| self.sem_model.encode(input_texts), device=self.device |
| ) |
| output_embs = torch.tensor( |
| self.sem_model.encode(output_texts), device=self.device |
| ) |
| sim = F.cosine_similarity(input_embs, output_embs, dim=-1) |
| return (1 - sim).mean() |
| |
| def forward( |
| self, |
| logits: torch.Tensor, |
| labels: torch.Tensor, |
| output_style_vec: Optional[torch.Tensor] = None, |
| target_style_vec: Optional[torch.Tensor] = None, |
| input_texts: Optional[List[str]] = None, |
| output_texts: Optional[List[str]] = None, |
| ) -> Dict[str, torch.Tensor]: |
| |
| # Standard cross-entropy loss |
| vocab_size = logits.shape[-1] |
| l_ce = self.ce_loss(logits.view(-1, vocab_size), labels.view(-1)) |
| |
| losses = {"l_ce": l_ce, "total": l_ce} |
| |
| if output_style_vec is not None and target_style_vec is not None: |
| l_style = self._style_loss(output_style_vec, target_style_vec) |
| losses["l_style"] = l_style |
| losses["total"] = losses["total"] + self.lambda_style * l_style |
| |
| if input_texts is not None and output_texts is not None: |
| l_sem = self._semantic_loss(input_texts, output_texts) |
| losses["l_semantic"] = l_sem |
| losses["total"] = losses["total"] + self.lambda_semantic * l_sem |
| |
| return losses |
| ``` |
| |
| --- |
|
|
| ### `src/training/trainer.py` |
|
|
| ```python |
| """ |
| Custom HuggingFace Trainer subclass. |
| Overrides compute_loss to use CombinedCorrectionLoss. |
| """ |
| |
| from transformers import Trainer, TrainingArguments |
| from transformers.trainer_utils import EvalLoopOutput |
| import torch |
| from .loss_functions import CombinedCorrectionLoss |
| import wandb |
| |
| |
| class CorrectionTrainer(Trainer): |
| |
| def __init__(self, loss_fn: CombinedCorrectionLoss, fingerprinter, tokenizer, **kwargs): |
| super().__init__(**kwargs) |
| self.loss_fn = loss_fn |
| self.fingerprinter = fingerprinter |
| self.tokenizer = tokenizer |
| |
| def compute_loss(self, model, inputs, return_outputs=False): |
| style_vectors = inputs.pop("style_vector", None) |
| labels = inputs.get("labels") |
| |
| outputs = model(**inputs) |
| logits = outputs.logits |
| |
| # Decode output tokens to text for semantic + style losses |
| pred_token_ids = logits.argmax(dim=-1) |
| output_texts = self.tokenizer.batch_decode(pred_token_ids, skip_special_tokens=True) |
| input_texts = self.tokenizer.batch_decode(inputs["input_ids"], skip_special_tokens=True) |
| |
| # Compute output style vectors for batch |
| output_style_vecs = torch.stack([ |
| self.fingerprinter.extract_vector(t) for t in output_texts |
| ]).to(logits.device) |
| |
| loss_dict = self.loss_fn( |
| logits=logits, |
| labels=labels, |
| output_style_vec=output_style_vecs, |
| target_style_vec=style_vectors.to(logits.device) if style_vectors is not None else None, |
| input_texts=input_texts, |
| output_texts=output_texts, |
| ) |
| |
        # Log to W&B (auxiliary losses may be absent when texts are unavailable)
        if self.state.global_step % 50 == 0:
            log = {
                "loss/total": loss_dict["total"].item(),
                "step": self.state.global_step,
            }
            for key, name in (
                ("l_ce", "loss/ce"),
                ("l_style", "loss/style"),
                ("l_semantic", "loss/semantic"),
            ):
                if key in loss_dict:
                    log[name] = loss_dict[key].item()
            wandb.log(log)
| |
| return (loss_dict["total"], outputs) if return_outputs else loss_dict["total"] |
| ``` |
|
|
| --- |
|
|
| ### `configs/training_config.yaml` |
| |
| ```yaml |
| model: |
| key: "flan-t5-xl" # flan-t5-xl | flan-t5-large | bart-large | llama-3.1-8b |
| quantize: false # Set true for 4-bit on limited VRAM |
| use_lora: true |
|
|
| lora: |
| r: 16 |
| lora_alpha: 32 |
| lora_dropout: 0.05 |
| target_modules: ["q", "v", "k", "o", "wi_0", "wi_1", "wo"] |
| |
| data: |
| train_path: "data/processed/train.jsonl" |
| val_path: "data/processed/val.jsonl" |
| test_path: "data/processed/test.jsonl" |
| max_input_length: 512 |
| max_target_length: 512 |
| augment_synthetic: true |
| synthetic_ratio: 0.3 |
|
|
| training: |
| output_dir: "checkpoints/" |
| num_train_epochs: 5 |
| per_device_train_batch_size: 8 |
| per_device_eval_batch_size: 16 |
| gradient_accumulation_steps: 4 # Effective batch = 8*4 = 32 |
| learning_rate: 3.0e-4 |
| lr_scheduler_type: "cosine" |
| warmup_ratio: 0.05 |
| weight_decay: 0.01 |
| fp16: false |
| bf16: true # Use bfloat16 on Ampere+ GPUs |
| evaluation_strategy: "steps" |
| eval_steps: 500 |
| save_strategy: "steps" |
| save_steps: 500 |
| save_total_limit: 3 |
| load_best_model_at_end: true |
| metric_for_best_model: "gleu" |
| greater_is_better: true |
| logging_dir: "logs/" |
| logging_steps: 50 |
| report_to: ["wandb", "tensorboard"] |
| dataloader_num_workers: 4 |
| seed: 42 |
| push_to_hub: false |
|
|
| loss: |
| lambda_style: 0.3 |
| lambda_semantic: 0.5 |
| sem_model_name: "all-mpnet-base-v2" |
|
|
| generation: |
| num_beams: 5 |
| length_penalty: 1.0 |
| no_repeat_ngram_size: 3 |
| min_length: 10 |
| max_new_tokens: 512 |
| early_stopping: true |
| ``` |
| |
| --- |
| |
## 10. Layer 6 — Academic Vocabulary Control Module
| |
| ### `src/vocabulary/lexical_substitution.py` |
|
|
| ```python |
| """ |
| Post-generation academic vocabulary elevation module. |
| |
| Pipeline: |
| 1. POS-tag the generated output |
| 2. Identify content words (NOUN, VERB, ADJ, ADV) NOT in AWL |
| 3. For each candidate word, generate AWL-aligned substitutions |
| using BERT masked language model (fill-mask) |
| 4. Apply substitution only if: |
| a. Semantic similarity between original and substitution > threshold |
| b. Substitution is in the AWL |
| c. Substitution does not change sentence meaning |
| 5. Apply register-level post-processing (nominalisation, hedging, passive) |
| |
| AWL = Coxhead Academic Word List (570 word families, ~3,000 lemmas) |
| """ |
| |
| import spacy |
| import torch |
| from transformers import pipeline as hf_pipeline |
| from sentence_transformers import SentenceTransformer |
| import torch.nn.functional as F |
| from typing import List, Dict, Tuple, Optional |
| from .awl_loader import AWLLoader |
| |
| |
| class LexicalElevator: |
| |
| # Words that should NEVER be substituted (structural, functional words) |
| PROTECTED_POS = {"PRON", "DET", "CCONJ", "SCONJ", "ADP", "AUX", "PART", "PUNCT", "NUM"} |
| SEMANTIC_THRESHOLD = 0.82 # Minimum cosine similarity to accept substitution |
| |
| def __init__( |
| self, |
| awl_path: str = "data/awl/coxhead_awl.txt", |
| spacy_model: str = "en_core_web_trf", |
| mlm_model: str = "bert-large-uncased", |
| sem_model: str = "all-mpnet-base-v2", |
| ): |
| self.nlp = spacy.load(spacy_model) |
| self.awl = AWLLoader(awl_path) |
| self.fill_mask = hf_pipeline("fill-mask", model=mlm_model, top_k=10) |
| self.sem_model = SentenceTransformer(sem_model) |
| |
    def _sem_similarity(self, word_a: str, word_b: str, context: str) -> float:
        """Compute contextual semantic similarity using sentence embeddings."""
        ctx_a = context
        ctx_b = context.replace(word_a, word_b, 1)
        embs = self.sem_model.encode([ctx_a, ctx_b])
        t = torch.tensor(embs)
        return F.cosine_similarity(t[0].unsqueeze(0), t[1].unsqueeze(0)).item()
| |
| def _get_awl_substitutions(self, sentence: str, word: str, pos: str) -> List[str]: |
| """Generate candidate substitutions using BERT fill-mask.""" |
| masked = sentence.replace(word, "[MASK]", 1) |
| try: |
| predictions = self.fill_mask(masked) |
| candidates = [p["token_str"].strip() for p in predictions] |
| except Exception: |
| return [] |
| # Filter to AWL words only |
| return [c for c in candidates if self.awl.is_academic(c)] |
| |
| def elevate(self, text: str, protected_spans: List[Tuple[int, int]] = None) -> str: |
| """ |
| Main entry point: elevates vocabulary to academic register. |
| protected_spans: list of (start_char, end_char) that must not be modified. |
| """ |
| doc = self.nlp(text) |
| replacements = {} |
| |
| for sent in doc.sents: |
| sent_text = sent.text |
| for token in sent: |
| # Skip protected tokens |
| if token.pos_ in self.PROTECTED_POS: |
| continue |
| if self.awl.is_academic(token.lemma_): |
| continue # Already academic |
| if protected_spans: |
| if any(s <= token.idx < e for s, e in protected_spans): |
| continue |
| |
| candidates = self._get_awl_substitutions(sent_text, token.text, token.pos_) |
| for candidate in candidates: |
| sim = self._sem_similarity(token.text, candidate, sent_text) |
| if sim >= self.SEMANTIC_THRESHOLD: |
| replacements[token.idx] = (token.text, candidate) |
| break |
| |
| # Apply replacements (reverse order to preserve offsets) |
| result = list(text) |
| for idx in sorted(replacements.keys(), reverse=True): |
| original, replacement = replacements[idx] |
| start = idx |
| end = idx + len(original) |
| result[start:end] = list(replacement) |
| |
| return ''.join(result) |
| |
| |
| class RegisterFilter: |
| """ |
| Applies register-level corrections to ensure academic tone: |
| - Converts contractions to full forms |
| - Ensures hedging where appropriate |
| - Flags over-colloquial phrases for review |
| """ |
| |
| CONTRACTIONS = { |
| "don't": "do not", "can't": "cannot", "won't": "will not", |
| "it's": "it is", "that's": "that is", "there's": "there is", |
| "they're": "they are", "we're": "we are", "you're": "you are", |
| "I'm": "I am", "I've": "I have", "I'll": "I will", |
| "isn't": "is not", "aren't": "are not", "wasn't": "was not", |
| "weren't": "were not", "hasn't": "has not", "haven't": "have not", |
| "couldn't": "could not", "wouldn't": "would not", "shouldn't": "should not", |
| } |
| |
| COLLOQUIAL_TO_ACADEMIC = { |
| "a lot of": "a substantial number of", |
| "lots of": "numerous", |
| "big": "substantial", |
| "get": "obtain", |
| "show": "demonstrate", |
| "use": "utilise", |
| "find out": "ascertain", |
| "look at": "examine", |
| "think about": "consider", |
| "talk about": "discuss", |
| "deal with": "address", |
| "carry out": "conduct", |
| "point out": "indicate", |
| "make sure": "ensure", |
| "come up with": "develop", |
| "go up": "increase", |
| "go down": "decrease", |
| "start": "commence", |
| "end": "conclude", |
| "help": "facilitate", |
| "need": "require", |
| "try": "attempt", |
| "want": "seek", |
| } |
| |
    def apply(self, text: str) -> str:
        import re
        # Word boundaries prevent partial matches; capitalisation of the
        # replacement is not preserved (a known limitation of this pass).
        for contraction, full in self.CONTRACTIONS.items():
            text = re.sub(r'\b' + re.escape(contraction) + r'\b', full, text, flags=re.IGNORECASE)
        for colloquial, academic in self.COLLOQUIAL_TO_ACADEMIC.items():
            text = re.sub(r'\b' + re.escape(colloquial) + r'\b', academic, text, flags=re.IGNORECASE)
        return text
| ``` |
|
|
| --- |
|
|
## 11. Layer 7 — Evaluation Framework
|
|
| ### `src/evaluation/gleu_scorer.py` |
| |
| ```python |
| """ |
| GLEU (Generalized Language Evaluation Understanding) score. |
| Preferred over BLEU for grammatical error correction tasks. |
| Designed specifically to handle the GEC task where the reference |
| correction may differ from the source in minimal ways. |
| |
| Also computes BERTScore for semantic similarity evaluation. |
| """ |
| |
from bert_score import score as bert_score_fn
from nltk.translate.gleu_score import corpus_gleu
from typing import List, Tuple


class GLEUScorer:

    def compute_gleu(
        self,
        predictions: List[str],
        references: List[str],
    ) -> float:
        """Corpus-level GLEU score, scaled to 0-100.

        NOTE: NLTK implements Google's GLEU (Wu et al., 2016). The
        GEC-specific GLEU of Napoles et al. (2015) additionally penalises
        n-grams copied from the source; use its reference implementation
        if that exact variant is required.
        """
        refs = [[ref.split()] for ref in references]
        hyps = [pred.split() for pred in predictions]
        return corpus_gleu(refs, hyps) * 100  # scaled to 0-100
| |
| def compute_bert_score( |
| self, |
| predictions: List[str], |
| references: List[str], |
| lang: str = "en", |
| ) -> Tuple[float, float, float]: |
| """ |
| Returns (precision, recall, F1) as averages over the batch. |
| F1 > 0.9 is generally considered high quality. |
| """ |
| P, R, F1 = bert_score_fn(predictions, references, lang=lang, verbose=False) |
| return P.mean().item(), R.mean().item(), F1.mean().item() |
| ``` |
| |
| --- |
|
|
| ### `src/evaluation/style_metrics.py` |
| |
| ```python |
| """ |
| Measures style preservation between input and output. |
| |
| Key metric: Style Vector Cosine Similarity |
| sim = cosine_similarity(style_vec(input), style_vec(output)) |
| Target: > 0.85 |
|
|
| Key metric: Authorship Verification Score |
| A binary classifier trained to answer: "Was this written by the same author?" |
| Uses a fine-tuned RoBERTa model on authorship verification datasets. |
| Target: > 0.80 (model says same author 80%+ of the time) |
|
|
| Key metric: AWL Coverage Score |
| Fraction of content words from the Academic Word List. |
| Target: > 0.25 (25% of content words should be academic) |
| """ |
|
|
import spacy
import torch
import torch.nn.functional as F
from typing import List
from ..style.fingerprinter import StyleFingerprinter
from ..vocabulary.awl_loader import AWLLoader
| |
| |
| class StyleEvaluator: |
| |
    def __init__(self, fingerprinter: StyleFingerprinter, awl: AWLLoader):
        self.fingerprinter = fingerprinter
        self.awl = awl
        # Load the lightweight tagger once rather than per awl_coverage() call
        self._nlp = spacy.load("en_core_web_sm")
| |
| def style_similarity(self, text_a: str, text_b: str) -> float: |
| """Cosine similarity between style vectors. Target: > 0.85""" |
| vec_a = self.fingerprinter.extract_vector(text_a) |
| vec_b = self.fingerprinter.extract_vector(text_b) |
| return F.cosine_similarity(vec_a.unsqueeze(0), vec_b.unsqueeze(0)).item() |
| |
    def awl_coverage(self, text: str) -> float:
        """Fraction of content words in AWL. Target: > 0.25"""
        doc = self._nlp(text)
        content_words = [t.lemma_.lower() for t in doc if t.pos_ in {"NOUN", "VERB", "ADJ", "ADV"}]
        if not content_words:
            return 0.0
        return sum(1 for w in content_words if self.awl.is_academic(w)) / len(content_words)
| |
| def evaluate_batch( |
| self, |
| inputs: List[str], |
| outputs: List[str], |
| references: List[str], |
| ) -> dict: |
| style_sims = [self.style_similarity(i, o) for i, o in zip(inputs, outputs)] |
| awl_scores = [self.awl_coverage(o) for o in outputs] |
| return { |
| "style_similarity_mean": sum(style_sims) / len(style_sims), |
| "style_similarity_min": min(style_sims), |
| "awl_coverage_mean": sum(awl_scores) / len(awl_scores), |
| } |
| ``` |
| |
| --- |
|
|
## 12. Layer 8 — Inference Pipeline
|
|
| ### `src/inference/corrector.py` |
|
|
| ```python |
| """ |
End-to-end inference pipeline.
Accepts raw dyslexic text (and optionally a master copy),
returns corrected academic text with metadata.
| """ |
| |
| from ..preprocessing.pipeline import PreprocessingPipeline |
| from ..style.fingerprinter import StyleFingerprinter |
| from ..vocabulary.lexical_substitution import LexicalElevator, RegisterFilter |
| from ..model.base_model import load_model_and_tokenizer |
| from ..model.style_conditioner import StyleConditioner, prepend_style_prefix |
| import torch |
| from typing import Optional |
| from dataclasses import dataclass |
| |
| |
| TASK_PREFIX = ( |
| "Correct the following text for grammar, spelling, and clarity. " |
| "Maintain the author's original tone and writing style. " |
| "Elevate vocabulary to academic register. " |
| "Do NOT change the meaning or add new information. " |
| "Preserve named entities exactly. " |
| "Text to correct: " |
| ) |
| |
| |
| @dataclass |
| class CorrectionResult: |
| original: str |
| corrected: str |
| preprocessed: str |
| style_similarity: float |
| awl_coverage: float |
| readability: dict |
| changes_summary: str |
| |
| |
| class AcademicCorrector: |
| |
| def __init__(self, config: dict): |
| self.device = "cuda" if torch.cuda.is_available() else "cpu" |
| self.model, self.tokenizer, self.is_seq2seq = load_model_and_tokenizer( |
| config["model"]["key"], |
| quantize=config["model"].get("quantize", False), |
| use_lora=False, # Inference: use merged weights |
| ) |
| self.model.eval() |
| self.preprocessor = PreprocessingPipeline() |
| self.fingerprinter = StyleFingerprinter() |
| self.conditioner = StyleConditioner( |
| style_dim=512, |
| model_hidden_dim=config.get("model_hidden_dim", 2048), |
| n_prefix_tokens=10, |
| ).to(self.device) |
| self.elevator = LexicalElevator() |
| self.register_filter = RegisterFilter() |
| self.gen_config = config.get("generation", {}) |
| |
| def correct( |
| self, |
| raw_text: str, |
| master_copy: Optional[str] = None, |
| style_alpha: float = 0.6, |
| ) -> CorrectionResult: |
| |
| # Step 1: Pre-process |
| doc = self.preprocessor.process(raw_text) |
| |
| # Step 2: Style fingerprinting |
| user_style = self.fingerprinter.extract_vector(doc.corrected_text) |
| master_style = self.fingerprinter.extract_vector(master_copy) if master_copy else None |
| target_style = self.fingerprinter.blend_vectors(user_style, master_style, alpha=style_alpha) |
| |
| # Step 3: Prepare input |
| input_text = TASK_PREFIX + doc.corrected_text |
| inputs = self.tokenizer( |
| input_text, |
| return_tensors="pt", |
| max_length=512, |
| truncation=True, |
| ).to(self.device) |
| |
        # Step 4: Generate with style conditioning. The prefix is prepended to
        # the encoder input embeddings as virtual tokens, and the attention
        # mask is extended to cover the prefix positions.
        style_prefix = self.conditioner(target_style.unsqueeze(0).to(self.device))
        embed_layer = self.model.get_input_embeddings()
        input_embeds = embed_layer(inputs["input_ids"])
        input_embeds = prepend_style_prefix(input_embeds, style_prefix.to(input_embeds.dtype))
        prefix_mask = torch.ones(
            (1, style_prefix.shape[1]),
            dtype=inputs["attention_mask"].dtype,
            device=self.device,
        )
        attention_mask = torch.cat([prefix_mask, inputs["attention_mask"]], dim=1)

        with torch.no_grad():
            generated = self.model.generate(
                inputs_embeds=input_embeds,
                attention_mask=attention_mask,
                num_beams=self.gen_config.get("num_beams", 5),
                length_penalty=self.gen_config.get("length_penalty", 1.0),
                no_repeat_ngram_size=self.gen_config.get("no_repeat_ngram_size", 3),
                max_new_tokens=self.gen_config.get("max_new_tokens", 512),
                early_stopping=True,
            )
| |
| draft = self.tokenizer.decode(generated[0], skip_special_tokens=True) |
| |
| # Step 5: Academic vocabulary elevation |
| elevated = self.elevator.elevate(draft, protected_spans=doc.protected_spans) |
| |
| # Step 6: Register filter (contractions, colloquialisms) |
| final = self.register_filter.apply(elevated) |
| |
| # Step 7: Compute metrics |
| from ..evaluation.style_metrics import StyleEvaluator |
| from ..vocabulary.awl_loader import AWLLoader |
| evaluator = StyleEvaluator(self.fingerprinter, AWLLoader()) |
| style_sim = evaluator.style_similarity(raw_text, final) |
| awl_cov = evaluator.awl_coverage(final) |
| |
| return CorrectionResult( |
| original=raw_text, |
| corrected=final, |
| preprocessed=doc.corrected_text, |
| style_similarity=round(style_sim, 3), |
| awl_coverage=round(awl_cov, 3), |
| readability=doc.readability, |
| changes_summary=f"Style similarity: {style_sim:.2%} | AWL coverage: {awl_cov:.2%}", |
| ) |
| ``` |
|
|
| --- |
|
|
| ## 13. Layer 9 β API Server |
|
|
| ### `src/api/main.py` |
|
|
| ```python |
| from fastapi import FastAPI, HTTPException |
| from fastapi.middleware.cors import CORSMiddleware |
| from .schemas import CorrectionRequest, CorrectionResponse |
| from ..inference.corrector import AcademicCorrector |
| import yaml |
| |
| app = FastAPI( |
| title="Dyslexia Academic Writing Corrector API", |
| description="Style-preserving grammar correction and academic vocabulary elevation.", |
| version="1.0.0", |
| ) |
| |
| app.add_middleware( |
| CORSMiddleware, |
| allow_origins=["*"], |
| allow_methods=["*"], |
| allow_headers=["*"], |
| ) |
| |
| with open("configs/inference_config.yaml") as f: |
| config = yaml.safe_load(f) |
| |
| corrector = AcademicCorrector(config) |
| |
| |
| @app.post("/correct", response_model=CorrectionResponse) |
| async def correct_text(request: CorrectionRequest): |
| try: |
| result = corrector.correct( |
| raw_text=request.text, |
| master_copy=request.master_copy, |
| style_alpha=request.style_alpha, |
| ) |
| return CorrectionResponse( |
| original=result.original, |
| corrected=result.corrected, |
| style_similarity=result.style_similarity, |
| awl_coverage=result.awl_coverage, |
| readability=result.readability, |
| changes_summary=result.changes_summary, |
| ) |
| except Exception as e: |
| raise HTTPException(status_code=500, detail=str(e)) |
| |
| |
| @app.get("/health") |
| async def health(): |
| return {"status": "ok", "model": config["model"]["key"]} |
| ``` |
|
|
| --- |
|
|
| ### `src/api/schemas.py` |
|
|
| ```python |
| from pydantic import BaseModel, Field |
| from typing import Optional, Dict |
| |
| |
| class CorrectionRequest(BaseModel): |
| text: str = Field(..., min_length=10, max_length=5000, description="Raw dyslectic text to correct.") |
| master_copy: Optional[str] = Field(None, description="Optional master copy to match style toward.") |
| style_alpha: float = Field(0.6, ge=0.0, le=1.0, description="Weight given to user's own style (0=full master, 1=full user).") |
| |
| |
| class CorrectionResponse(BaseModel): |
| original: str |
| corrected: str |
| style_similarity: float |
| awl_coverage: float |
| readability: Dict[str, float] |
| changes_summary: str |
| ``` |
|
|
| --- |
|
|
| ## 14. Layer 10 β Configuration Files |
|
|
| ### `configs/model_config.yaml` |
| |
| ```yaml |
| model: |
| key: "flan-t5-xl" |
| checkpoint_path: "checkpoints/best_model" |
| quantize: false |
| use_lora: true |
| model_hidden_dim: 2048 # flan-t5-xl hidden size |
| # model_hidden_dim: 1024 # flan-t5-large |
| # model_hidden_dim: 1024 # bart-large |
| # model_hidden_dim: 4096 # llama-3.1-8b |
|
|
| style_conditioner: |
| style_dim: 512 |
| n_prefix_tokens: 10 |
|
|
| fingerprinter: |
| spacy_model: "en_core_web_trf" |
| awl_path: "data/awl/coxhead_awl.txt" |
| projection_hidden_dim: 256 |
| projection_output_dim: 512 |
|
|
| generation: |
| num_beams: 5 |
| length_penalty: 1.0 |
| no_repeat_ngram_size: 3 |
| min_length: 10 |
| max_new_tokens: 512 |
| early_stopping: true |
  temperature: 0.7         # Only applies when do_sample: true
  do_sample: false         # Beam search by default (temperature is ignored)
|
|
| vocabulary: |
| awl_path: "data/awl/coxhead_awl.txt" |
| mlm_model: "bert-large-uncased" |
| sem_model: "all-mpnet-base-v2" |
| semantic_threshold: 0.82 |
| ``` |
| |
| --- |
| |
| ### `configs/awl_config.yaml` |
|
|
| ```yaml |
| awl: |
| primary: "data/awl/coxhead_awl.txt" |
| supplementary: |
| - "data/awl/domain_lexicons/humanities.txt" |
| - "data/awl/domain_lexicons/sciences.txt" |
| - "data/awl/domain_lexicons/social_sciences.txt" |
| academic_synonyms: "data/awl/academic_synonyms.json" |
| |
| register: |
| expand_contractions: true |
| replace_colloquialisms: true |
| enforce_third_person_academic: false # Keep user's voice (don't force "one") |
| minimum_formality_score: 0.65 |
| ``` |
|
|
| --- |
|
|
| ## 15. Layer 11 β Full Training Run Sequence |
|
|
| ### `scripts/train.py` |
|
|
| ```python |
| """ |
| Full training entry point. |
| Run: python scripts/train.py --config configs/training_config.yaml |
| """ |
| |
| import click |
| import yaml |
| import wandb |
| import torch |
| from transformers import TrainingArguments |
| |
| from src.model.base_model import load_model_and_tokenizer |
| from src.model.style_conditioner import StyleConditioner |
| from src.training.dataset import WritingCorrectionDataset |
| from src.training.loss_functions import CombinedCorrectionLoss |
| from src.training.trainer import CorrectionTrainer |
| from src.style.fingerprinter import StyleFingerprinter |
| from src.evaluation.gleu_scorer import GLEUScorer |
| |
| |
| @click.command() |
| @click.option("--config", default="configs/training_config.yaml") |
| def train(config: str): |
| with open(config) as f: |
| cfg = yaml.safe_load(f) |
| |
| wandb.init(project="dyslexia-writing-ai", config=cfg) |
| |
| # Load model |
| model, tokenizer, is_seq2seq = load_model_and_tokenizer( |
| cfg["model"]["key"], |
| quantize=cfg["model"].get("quantize", False), |
| use_lora=cfg["model"].get("use_lora", True), |
| ) |
| |
| # Fingerprinter |
| fingerprinter = StyleFingerprinter( |
| awl_path=cfg["data"].get("awl_path", "data/awl/coxhead_awl.txt") |
| ) |
| |
| # Datasets |
| train_dataset = WritingCorrectionDataset( |
| data_path=cfg["data"]["train_path"], |
| tokenizer=tokenizer, |
| fingerprinter=fingerprinter, |
| max_input_length=cfg["data"]["max_input_length"], |
| max_target_length=cfg["data"]["max_target_length"], |
| augment_with_synthetic=cfg["data"]["augment_synthetic"], |
| synthetic_ratio=cfg["data"]["synthetic_ratio"], |
| ) |
| val_dataset = WritingCorrectionDataset( |
| data_path=cfg["data"]["val_path"], |
| tokenizer=tokenizer, |
| fingerprinter=fingerprinter, |
| augment_with_synthetic=False, |
| ) |
| |
| # Loss function |
| loss_fn = CombinedCorrectionLoss( |
| lambda_style=cfg["loss"]["lambda_style"], |
| lambda_semantic=cfg["loss"]["lambda_semantic"], |
| sem_model_name=cfg["loss"]["sem_model_name"], |
| device="cuda" if torch.cuda.is_available() else "cpu", |
| ) |
| |
| # Training arguments |
| training_args = TrainingArguments( |
| output_dir=cfg["training"]["output_dir"], |
| num_train_epochs=cfg["training"]["num_train_epochs"], |
| per_device_train_batch_size=cfg["training"]["per_device_train_batch_size"], |
| per_device_eval_batch_size=cfg["training"]["per_device_eval_batch_size"], |
| gradient_accumulation_steps=cfg["training"]["gradient_accumulation_steps"], |
| learning_rate=cfg["training"]["learning_rate"], |
| lr_scheduler_type=cfg["training"]["lr_scheduler_type"], |
| warmup_ratio=cfg["training"]["warmup_ratio"], |
| weight_decay=cfg["training"]["weight_decay"], |
| bf16=cfg["training"]["bf16"], |
| fp16=cfg["training"]["fp16"], |
| evaluation_strategy=cfg["training"]["evaluation_strategy"], |
| eval_steps=cfg["training"]["eval_steps"], |
| save_strategy=cfg["training"]["save_strategy"], |
| save_steps=cfg["training"]["save_steps"], |
| save_total_limit=cfg["training"]["save_total_limit"], |
| load_best_model_at_end=cfg["training"]["load_best_model_at_end"], |
| logging_steps=cfg["training"]["logging_steps"], |
| report_to=cfg["training"]["report_to"], |
| dataloader_num_workers=cfg["training"]["dataloader_num_workers"], |
| seed=cfg["training"]["seed"], |
| ) |
| |
| trainer = CorrectionTrainer( |
| model=model, |
| args=training_args, |
| train_dataset=train_dataset, |
| eval_dataset=val_dataset, |
| loss_fn=loss_fn, |
| fingerprinter=fingerprinter, |
| tokenizer=tokenizer, |
| ) |
| |
| trainer.train() |
| trainer.save_model(cfg["training"]["output_dir"] + "/final") |
| wandb.finish() |
| |
| |
| if __name__ == "__main__": |
| train() |
| ``` |
|
|
| --- |
|
|
| ### `scripts/download_datasets.sh` |
| |
| ```bash |
| #!/bin/bash |
| # Download all training data sources |
| |
| mkdir -p data/raw/wi_locness data/raw/jfleg data/raw/gyafc data/raw/custom_dyslexia |
| |
| # W&I+LOCNESS (Cambridge Grammar Error Correction) |
| # Requires registration at: https://www.cl.cam.ac.uk/research/nl/bea2019st/ |
| echo "W&I+LOCNESS: Download manually from https://www.cl.cam.ac.uk/research/nl/bea2019st/" |
| echo "Place files in data/raw/wi_locness/" |
|
|
| # JFLEG (JHU Fluency-Extended GUG) |
git clone https://github.com/keisks/jfleg.git data/raw/jfleg_repo
cp data/raw/jfleg_repo/test/*.src data/raw/jfleg/
cp data/raw/jfleg_repo/test/*.ref* data/raw/jfleg/
# (optional) also grab the dev split for more pairs
cp data/raw/jfleg_repo/dev/*.src data/raw/jfleg/ 2>/dev/null || true
cp data/raw/jfleg_repo/dev/*.ref* data/raw/jfleg/ 2>/dev/null || true
|
|
| # GYAFC (Formality Corpus - Yahoo Answers) |
| # Requires request from Grammarly: https://github.com/raosudha89/GYAFC-corpus |
| echo "GYAFC: Request access at https://github.com/raosudha89/GYAFC-corpus" |
| echo "Place files in data/raw/gyafc/" |
|
|
# Download Coxhead Academic Word List (all 10 sublists, concatenated).
# NOTE: the university site layout changes occasionally; verify the URL pattern.
mkdir -p data/awl
for i in $(seq 1 10); do
    curl -fsS "https://www.victoria.ac.nz/lals/resources/academicwordlist/sublists/Sublist_${i}.txt" \
        >> data/awl/coxhead_awl.txt || echo "Sublist ${i} failed; download manually"
done
|
|
| echo "Dataset download complete. Check manually downloaded datasets." |
| ``` |
| |
| ### `scripts/preprocess_data.py` |
| |
| ```python |
| """ |
| Converts all raw dataset formats into unified JSONL training format. |
| Output schema per line: |
| {"input": "...", "target": "...", "source": "wi_locness|jfleg|gyafc|synthetic"} |
| """ |
| import json |
| import os |
| from pathlib import Path |
| from src.preprocessing.dyslexia_simulator import DyslexiaSimulator |
|
|
|
|
| def process_jfleg(raw_dir: str, out_file): |
| """JFLEG: .src files (original) and .ref0..ref3 (4 human corrections).""" |
| src_files = list(Path(raw_dir).glob("*.src")) |
| for src_file in src_files: |
| refs = [src_file.with_suffix(f".ref{i}") for i in range(4)] |
| with open(src_file) as sf: |
| src_lines = sf.readlines() |
| for ref_path in refs: |
| if ref_path.exists(): |
| with open(ref_path) as rf: |
| ref_lines = rf.readlines() |
| for src, ref in zip(src_lines, ref_lines): |
| src, ref = src.strip(), ref.strip() |
| if src and ref and src != ref: |
| out_file.write(json.dumps({"input": src, "target": ref, "source": "jfleg"}) + "\n") |
| |
|
|
| def process_gyafc(raw_dir: str, out_file): |
| """GYAFC: informal/ and formal/ subdirectories with parallel files.""" |
| for domain in ["Entertainment_Music", "Family_Relationships"]: |
| for split in ["train", "tune", "test"]: |
            informal = Path(raw_dir) / domain / split / "informal"
            # train/ ships a parallel "formal" file; tune/ and test/ ship
            # references named formal.ref0..formal.ref3
            formal = Path(raw_dir) / domain / split / "formal"
            if not formal.exists():
                formal = Path(raw_dir) / domain / split / "formal.ref0"
            if informal.exists() and formal.exists():
| with open(informal) as inf_f, open(formal) as form_f: |
| for inf_line, form_line in zip(inf_f, form_f): |
| inf_line, form_line = inf_line.strip(), form_line.strip() |
| if inf_line and form_line: |
| out_file.write(json.dumps({"input": inf_line, "target": form_line, "source": "gyafc"}) + "\n") |
| |
|
|
| def main(): |
| os.makedirs("data/processed", exist_ok=True) |
| with open("data/processed/train.jsonl", "w") as out: |
| process_jfleg("data/raw/jfleg", out) |
| process_gyafc("data/raw/gyafc", out) |
| # Add W&I+LOCNESS processing here when available |
| print("Preprocessing complete.") |
| |
|
|
| if __name__ == "__main__": |
| main() |
| ``` |
| |
| --- |
|
|
| ## 16. Mathematical Formulations |
|
|
| ### Total Training Loss |
|
|
| ``` |
L_total = L_CE + λ₁ · L_style + λ₂ · L_semantic

L_CE = -Σ_t log P(y_t | y_{<t}, x)       [cross-entropy over tokens]

L_style = 1 - cos(f(ŷ), f(y))            [style vector cosine distance]
    where f(·) = StyleFingerprinter.extract_vector(·)
          cos(a,b) = (a·b) / (‖a‖‖b‖)

L_semantic = 1 - cos(g(x), g(ŷ))         [semantic distance from input]
    where g(·) = SentenceTransformer.encode(·)
          (frozen during training)

λ₁ = 0.3   (style weight)
λ₂ = 0.5   (semantic weight)
| ``` |
|
|
| ### Style Vector Blending |
|
|
| ``` |
target_style = α · v_user + (1-α) · v_master
target_style = target_style / ‖target_style‖     [L2 normalise to unit sphere]

α = 0.6   (user's style dominates by default)
| ``` |
|
|
| ### LoRA Weight Update |
|
|
| ``` |
W = W₀ + ΔW
ΔW = B · A

Where:
    W₀ ∈ ℝ^{d×k}   (frozen pretrained weights)
    A  ∈ ℝ^{r×k}   (trainable, r << d, initialised from N(0, σ²))
    B  ∈ ℝ^{d×r}   (trainable, initialised to zero)
    r  = 16        (rank hyperparameter)

Effective weight: W_eff = W₀ + (α/r) · B·A
    where α = lora_alpha = 32
| ``` |
|
|
| ### Style Similarity Evaluation Metric |
|
|
| ``` |
StyleSim(input, output) = cos(f(input), f(output))
                        = (f(input) · f(output)) / (‖f(input)‖ · ‖f(output)‖)

Target: StyleSim > 0.85
Acceptable minimum: StyleSim > 0.75
| ``` |
|
|
| ### AWL Coverage Score |
|
|
| ``` |
AWL_Coverage(text) = |{w ∈ content_words(text) : lemma(w) ∈ AWL}|
                     ────────────────────────────────────────────
                               |content_words(text)|

content_words = {w : POS(w) ∈ {NOUN, VERB, ADJ, ADV}}
Target: AWL_Coverage > 0.25
| ``` |
|
|
| --- |
|
|
| ## 17. Hyperparameter Reference |
|
|
| | Hyperparameter | Value | Rationale | |
| |---|---|---| |
| | LoRA rank (r) | 16 | Balances capacity vs. parameter efficiency | |
| | LoRA alpha | 32 | Standard 2x rank scaling | |
| | LoRA dropout | 0.05 | Light regularisation | |
| | Learning rate | 3e-4 | Standard for LoRA fine-tuning | |
| | LR scheduler | cosine | Smooth decay, avoids sharp LR drops | |
| | Warmup ratio | 0.05 | 5% of steps for warmup | |
| | Batch size (device) | 8 | Per GPU | |
| | Gradient accumulation | 4 | Effective batch = 32 | |
| | Training epochs | 5 | Sufficient for fine-tuning on GEC data | |
| λ₁ (style weight) | 0.3 | Strong style signal without dominating CE |
| λ₂ (semantic weight) | 0.5 | Meaning preservation is critical |
| Style blend α | 0.6 | User style dominates over master copy |
| | Style prefix tokens | 10 | Virtual prefix length | |
| | Beam search beams | 5 | Quality vs. speed tradeoff | |
| | No-repeat ngram | 3 | Prevents repetition in output | |
| | Semantic threshold | 0.82 | For lexical substitution acceptance | |
| | Max input tokens | 512 | T5/BART context window | |
| | Style projection dim | 512 | Rich enough to capture style nuance | |
|
|
| --- |
|
|
| ## 18. Dataset Sources & Download Instructions |
|
|
| | Dataset | Size | Task | Access | URL | |
| |---|---|---|---|---| |
| | W&I+LOCNESS | ~35k pairs | Grammar error correction | Free registration | https://www.cl.cam.ac.uk/research/nl/bea2019st/ | |
| | JFLEG | ~1.5k pairs | Fluency correction | Public GitHub | https://github.com/keisks/jfleg | |
| | GYAFC | ~105k pairs | Formality transfer | Request from Grammarly | https://github.com/raosudha89/GYAFC-corpus | |
| | CoNLL-2014 | ~1.3k pairs | Grammar correction | Public | https://www.comp.nus.edu.sg/~nlp/conll14st.html | |
| | FCE Corpus | ~33k pairs | Learner English | Free registration | https://ilexir.co.uk/datasets/index.html | |
| | WikiAtomic | Millions | Style transfer | Public | https://huggingface.co/datasets/wiki_atomic_edits | |
| | Synthetic (generated) | Unlimited | Dyslexia simulation | Self-generated | `scripts/preprocess_data.py` | |
|
|
| --- |
|
|
| ## 19. Hardware Requirements |
|
|
| ### Minimum (Development / Testing) |
|
|
| ``` |
| CPU: 8-core, e.g., Intel i7 / AMD Ryzen 7 |
| RAM: 32 GB |
GPU: NVIDIA RTX 3090 (24 GB VRAM) – Fine-tune Flan-T5-Large or BART-large
| SSD: 500 GB NVMe |
| Model: Flan-T5-Large (780M) or BART-large (400M) |
| Quantize: false |
| ``` |
|
|
| ### Recommended (Production Training) |
|
|
| ``` |
| CPU: 16-core+ |
| RAM: 64 GB |
GPU: NVIDIA A100 80 GB OR 2× RTX 4090 (48 GB total)
| SSD: 2 TB NVMe |
| Model: Flan-T5-XL (3B) with LoRA |
| Quantize: false |
| Training time: ~12 hours on A100 |
| ``` |
|
|
| ### Maximum Quality |
|
|
| ``` |
GPU: 4× A100 80 GB (320 GB total VRAM)
| Model: Llama-3.1-8B with LoRA |
| Training time: ~24-48 hours |
| Use: torchrun --nproc_per_node=4 scripts/train.py |
| ``` |
|
|
| ### Cloud Options |
|
|
| ``` |
AWS: p3.2xlarge (V100 16GB) → BART-large only
     p3.8xlarge (4× V100 64GB) → Flan-T5-XL
     p4d.24xlarge (8× A100) → Llama-3.1-8B

GCP: n1-standard-8 + 1× A100 → Flan-T5-XL
     a2-highgpu-4g (4× A100) → Llama-3.1-8B

Lambda Labs: 1× A100 ~$1.10/hr → Most cost-effective
RunPod: 1× A100 ~$0.99/hr → Alternative
| ``` |
|
|
| --- |
|
|
| ## 20. Testing Suite |
|
|
| ### `tests/test_preprocessing.py` |
| |
| ```python |
| import pytest |
| from src.preprocessing.pipeline import PreprocessingPipeline |
| from src.preprocessing.dyslexia_simulator import DyslexiaSimulator |
|
|
|
|
| @pytest.fixture |
| def pipeline(): |
| return PreprocessingPipeline() |
| |
|
|
| def test_spell_correction(pipeline): |
| result = pipeline.process("i wuz going to the store but cud not find it") |
| assert "was" in result.corrected_text |
| assert "could" in result.corrected_text |
| |
|
|
| def test_entity_protection(pipeline): |
| result = pipeline.process("John Smith livd in London.") |
| entities = [e.text for e in result.entities] |
| assert any("John" in e or "London" in e for e in entities) |
| |
|
|
| def test_sentence_segmentation(pipeline): |
| result = pipeline.process("I went to school. I lerned a lot.") |
| assert len(result.sentences) == 2 |
| |
|
|
| def test_dyslexia_simulator(): |
| sim = DyslexiaSimulator(error_rate=1.0, seed=0) |
| corrupted, clean = sim.simulate("The quick brown fox jumps over the lazy dog.") |
| assert corrupted != clean |
| assert clean == "The quick brown fox jumps over the lazy dog." |
| ``` |
| |
| ### `tests/test_style.py` |
| |
| ```python |
| import pytest |
| import torch |
| from src.style.fingerprinter import StyleFingerprinter |
| |
| |
| @pytest.fixture |
| def fingerprinter(tmp_path): |
| awl = tmp_path / "awl.txt" |
| awl.write_text("analysis\nconsider\nestablish\nsignificant\n") |
| return StyleFingerprinter(spacy_model="en_core_web_sm", awl_path=str(awl)) |
| |
|
|
| def test_style_vector_shape(fingerprinter): |
| vec = fingerprinter.extract_vector("The quick brown fox jumps over the lazy dog.") |
| assert vec.shape == (512,) |
| |
|
|
| def test_style_vector_different_texts(fingerprinter): |
| formal = "The analysis demonstrates significant implications for the field." |
| informal = "So basically it shows that this stuff really matters a lot lol." |
| vec_formal = fingerprinter.extract_vector(formal) |
| vec_informal = fingerprinter.extract_vector(informal) |
| # Vectors should be different |
| assert not torch.allclose(vec_formal, vec_informal) |
| |
|
|
| def test_style_blend(fingerprinter): |
| vec_a = fingerprinter.extract_vector("Short punchy text here.") |
| vec_b = fingerprinter.extract_vector("Elaborate and comprehensive academic discourse.") |
| blended = fingerprinter.blend_vectors(vec_a, vec_b, alpha=0.5) |
| assert blended.shape == (512,) |
| # Blended should be unit vector |
| assert abs(blended.norm().item() - 1.0) < 1e-4 |
| ``` |
| |
| --- |
|
|
| ## Quick Start Execution Order |
|
|
| ```bash |
| # 1. Setup environment |
| python -m venv venv && source venv/bin/activate |
| pip install -r requirements.txt |
| python -m spacy download en_core_web_trf |
| |
| # 2. Download datasets |
| bash scripts/download_datasets.sh |
| |
| # 3. Preprocess all data into unified format |
| python scripts/preprocess_data.py |
| |
| # 4. Run tests to verify setup |
| pytest tests/ -v |
| |
| # 5. Launch training |
| python scripts/train.py --config configs/training_config.yaml |
| |
| # 6. Evaluate on test set |
| python scripts/evaluate.py --config configs/training_config.yaml --split test |
| |
| # 7. Start inference API |
| uvicorn src.api.main:app --host 0.0.0.0 --port 8000 --reload |
| |
| # 8. Test the API |
| curl -X POST http://localhost:8000/correct \ |
| -H "Content-Type: application/json" \ |
| -d '{"text": "i went to the store but cud not find wat i was loking for", "style_alpha": 0.6}' |
| ``` |
|
|
| --- |
|
|
| --- |
|
|
| ## 21. Human-Pattern Anti-AI Training Layer |
|
|
| ### The Core Principle |
|
|
| These two Kaggle datasets are **not used to build a detector**. They are used to teach the model the statistical and linguistic signature of human writing, and to penalise the model when its output drifts toward AI-typical patterns. The training signal flows in one direction: **reward human-like writing, penalise AI-like writing**. |
|
|
This is implemented as an additional loss term, `L_human_pattern`, added to the combined loss from Layer 5. The model learns what human writing looks and feels like at a statistical level, and is penalised during training whenever its generated corrections exhibit the same surface patterns that distinguish AI-generated text from human text in these datasets.
|
|
| --- |
|
|
### Dataset 1 – shanegerami/ai-vs-human-text
|
|
| ``` |
| Source: https://www.kaggle.com/datasets/shanegerami/ai-vs-human-text |
| Size: ~500,000 essays |
Format: CSV – two columns
| Columns: |
| text (str) Full essay text |
| generated (int) 0 = human-written | 1 = AI-generated |
| |
| Human count: 305,797 essays |
| AI count: ~194,203 essays (GPT-family generated) |
| Content type: Academic essays across diverse topics |
| File: train_essays.csv |
| |
| HuggingFace mirror (already split, ~400k rows, use this for convenience): |
| andythetechnerd03/AI-human-text |
| Load: datasets.load_dataset("andythetechnerd03/AI-human-text") |
| ``` |
|
|
### Dataset 2 – starblasters8/human-vs-llm-text-corpus
|
|
| ``` |
| Source: https://www.kaggle.com/datasets/starblasters8/human-vs-llm-text-corpus |
| Size: ~800,000 texts |
Format: Parquet – data.parquet
| Columns: |
| text (str) Full text |
| label (str) "Human" | <LLM model name> (63 different LLMs represented) |
| |
Key feature: covers 63 DIFFERENT LLMs – not just GPT. Includes outputs from
| Llama, Mistral, Falcon, Claude, Gemini, PaLM, Vicuna, Alpaca, and many others. |
| This is critical: the model learns what AI text looks like ACROSS the LLM landscape, |
| not just from one model family. |
| |
| File: data.parquet |
| Read: pd.read_parquet("data/raw/starblasters8/data.parquet") |
| ``` |
|
|
| --- |
|
|
| ### `scripts/download_kaggle_datasets.sh` |
|
|
| ```bash |
| #!/bin/bash |
| # Requires: pip install kaggle |
| # Setup: Place kaggle.json API key at ~/.kaggle/kaggle.json |
# Get key: kaggle.com → Account → Create New API Token
| |
| mkdir -p data/raw/shanegerami data/raw/starblasters8 |
| |
| # Dataset 1: AI vs Human Text (500K essays) |
| kaggle datasets download -d shanegerami/ai-vs-human-text \ |
| -p data/raw/shanegerami --unzip |
| |
| # Dataset 2: Human vs LLM Text Corpus (800K, 63 LLMs) |
| kaggle datasets download -d starblasters8/human-vs-llm-text-corpus \ |
| -p data/raw/starblasters8 --unzip |
| |
| echo "Both datasets downloaded." |
| echo "Dataset 1 (CSV): data/raw/shanegerami/train_essays.csv" |
| echo "Dataset 2 (Parquet): data/raw/starblasters8/data.parquet" |
| ``` |
|
|
| --- |
|
|
| ### `src/training/human_pattern_extractor.py` |
|
|
| ```python |
| """ |
| Extracts the statistical signature of human writing vs AI writing. |
| Uses the two Kaggle datasets to build: |
| |
1. HumanPatternProfile – a statistical distribution of human writing features
2. AIPatternProfile – a statistical distribution of AI writing features
3. HumanPatternClassifier – a lightweight FROZEN classifier used at training time
   to score how "human-like" the model's output looks.
| |
| The classifier is FROZEN during main model training. It is pre-trained separately |
| on the Kaggle datasets, then its output score is used as a reward/penalty signal |
| in the main training loss. |
| |
| Feature set extracted (same dimensions as StyleFingerprinter + additional): |
| - All 40 StyleFingerprinter features |
| - Perplexity under GPT-2 (AI text tends to be lower perplexity) |
| - Burstiness score (human writing has more sentence length variance) |
| - Lexical diversity (AI text has narrower vocab distributions) |
| - Punctuation density patterns (AI overuses certain patterns) |
| - Discourse marker overuse (AI overuses "Furthermore", "Moreover", "Additionally") |
| - Sentence starter diversity (AI repeats sentence openers more) |
| - n-gram novelty score (AI repeats common n-grams more) |
| - Hedging vs certainty ratio (AI is overconfident OR over-hedges in detectable ways) |
| - Paragraph cohesion score (AI has unnaturally perfect paragraph transitions) |
| """ |
| |
import math
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd
import spacy
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
| |
| |
# ── AI-Typical Overused Discourse Markers ───────────────────────────────────
| AI_OVERUSED_MARKERS = { |
| "furthermore", "moreover", "additionally", "consequently", |
| "in conclusion", "to summarize", "it is worth noting", |
| "it is important to note", "in today's world", "in today's society", |
| "in the modern era", "as previously mentioned", "needless to say", |
| "it goes without saying", "at the end of the day", |
| "in terms of", "with regard to", "with respect to", |
| "delve", "leverage", "utilize", "holistic", "paradigm", |
| "transformative", "groundbreaking", "revolutionary", "game-changing", |
| "multifaceted", "nuanced", "comprehensive", "robust", "seamless", |
| "innovative", "synergy", "cutting-edge", "state-of-the-art", |
| } |
| |
| # Words that AI uses far MORE than humans in academic-adjacent writing |
| AI_FINGERPRINT_WORDS = { |
| "delve", "underscore", "tapestry", "intricate", "pivotal", |
| "crucial", "vital", "essential", "significant", "notable", |
| "commendable", "noteworthy", "straightforward", "straightforwardly", |
| "elucidate", "expound", "illuminate", "unravel", "harness", |
| "foster", "facilitate", "leverage", "optimize", "streamline", |
| } |
| |
| |
| class HumanPatternFeatureExtractor: |
| """Extracts 55-dimensional feature vector encoding human vs AI writing patterns.""" |
| |
| def __init__(self, spacy_model: str = "en_core_web_sm"): |
| self.nlp = spacy.load(spacy_model) |
| self.gpt2_model = GPT2LMHeadModel.from_pretrained("gpt2") |
| self.gpt2_tokenizer = GPT2TokenizerFast.from_pretrained("gpt2") |
| self.gpt2_model.eval() |
| |
| def _perplexity(self, text: str, max_len: int = 512) -> float: |
| """ |
| AI text tends to have LOWER perplexity under GPT-2 |
| because LLMs generate high-probability token sequences. |
        Human text is more unpredictable → higher perplexity.
| |
| Lower perplexity = more likely to be AI. |
| Higher perplexity = more likely to be human. |
| """ |
| encodings = self.gpt2_tokenizer(text, return_tensors="pt", truncation=True, max_length=max_len) |
| with torch.no_grad(): |
| outputs = self.gpt2_model(**encodings, labels=encodings["input_ids"]) |
| return math.exp(outputs.loss.item()) |
| |
| def _burstiness(self, sentences: List[str]) -> float: |
| """ |
| Burstiness = coefficient of variation of sentence lengths. |
| Human writing has high burstiness (unpredictable length variation). |
| AI writing has low burstiness (unnaturally uniform sentence lengths). |
| |
| B = std(lengths) / mean(lengths) |
| """ |
| lengths = [len(s.split()) for s in sentences] |
| if len(lengths) < 2 or np.mean(lengths) == 0: |
| return 0.0 |
| return np.std(lengths) / np.mean(lengths) |
| |
| def _sentence_starter_diversity(self, sentences: List[str]) -> float: |
| """ |
| Fraction of unique first words across sentences. |
| AI tends to start sentences with the same words repeatedly. |
| High = human-like. Low = AI-like. |
| """ |
| starters = [s.split()[0].lower() for s in sentences if s.split()] |
| if not starters: |
| return 0.0 |
| return len(set(starters)) / len(starters) |
| |
| def _ngram_novelty(self, text: str, n: int = 3) -> float: |
| """ |
| Ratio of unique n-grams to total n-grams. |
| AI repeats common n-grams more than humans. |
| Higher = more novel = more human-like. |
| """ |
| words = text.lower().split() |
| if len(words) < n: |
| return 1.0 |
| ngrams = [tuple(words[i:i+n]) for i in range(len(words) - n + 1)] |
| return len(set(ngrams)) / len(ngrams) |
| |
    def _ai_marker_density(self, text: str, word_count: int) -> float:
        """
        Occurrences of AI-fingerprint words per 100 words.
        Substring matching is a deliberate approximation: it also catches
        inflected forms ("delves", "leveraging") at the cost of occasional
        false hits inside longer words.
        Higher = more AI-like.
        """
        lower = text.lower()
        hits = sum(lower.count(marker) for marker in AI_FINGERPRINT_WORDS)
        return (hits / max(word_count, 1)) * 100

    def _overused_discourse_density(self, text: str, word_count: int) -> float:
        """
        Occurrences of AI-overused discourse markers per 100 words.
        """
        lower = text.lower()
        hits = sum(lower.count(marker) for marker in AI_OVERUSED_MARKERS)
        return (hits / max(word_count, 1)) * 100
| |
| def _punctuation_pattern(self, text: str, word_count: int) -> Dict[str, float]: |
| """ |
| AI writing exhibits characteristic punctuation patterns: |
| - Overuse of em-dash (β) |
| - Underuse of ellipsis (...) |
| - Very consistent comma density |
| """ |
        em_dash_rate = text.count("—") / max(word_count, 1) * 100
| ellipsis_rate = text.count("...") / max(word_count, 1) * 100 |
| comma_rate = text.count(",") / max(word_count, 1) * 100 |
| semicolon_rate = text.count(";") / max(word_count, 1) * 100 |
| return { |
| "em_dash_rate": em_dash_rate, |
| "ellipsis_rate": ellipsis_rate, |
| "comma_rate": comma_rate, |
| "semicolon_rate": semicolon_rate, |
| } |
| |
    def extract(self, text: str) -> np.ndarray:
        """Extract the raw feature vector (17 features; see module docstring)."""
        text = text[:10000]  # Truncate once so every feature sees the same span
        doc = self.nlp(text)
| sentences = [s.text.strip() for s in doc.sents if s.text.strip()] |
| words = [t.text for t in doc if not t.is_punct and not t.is_space] |
| word_count = len(words) |
| |
| punct = self._punctuation_pattern(text, word_count) |
| |
| features = np.array([ |
| # Human-pattern features |
| self._perplexity(text[:1024]), # Higher = more human |
| self._burstiness(sentences), # Higher = more human |
| self._sentence_starter_diversity(sentences), # Higher = more human |
| self._ngram_novelty(text, n=2), # Higher = more human |
| self._ngram_novelty(text, n=3), # Higher = more human |
| self._ngram_novelty(text, n=4), # Higher = more human |
| |
| # AI-pattern features (higher = more AI) |
| self._ai_marker_density(text, word_count), |
| self._overused_discourse_density(text, word_count), |
| punct["em_dash_rate"], |
| punct["ellipsis_rate"], |
| punct["comma_rate"], |
| punct["semicolon_rate"], |
| |
| # Distributional features |
| float(word_count), |
| float(len(sentences)), |
| np.mean([len(s.split()) for s in sentences]) if sentences else 0, |
| np.std([len(s.split()) for s in sentences]) if sentences else 0, |
| len(set(w.lower() for w in words)) / max(word_count, 1), # TTR |
| ], dtype=np.float32) |
| |
        return features  # [17 raw features – extend toward the 55-dim design as needed]
| |
| |
| class KaggleHumanPatternDataset(Dataset): |
| """ |
| Loads both Kaggle datasets and produces (feature_vector, label) pairs. |
| label = 1 (human) | 0 (AI) |
| """ |
| |
| def __init__( |
| self, |
| shanegerami_path: str, |
| starblasters_path: str, |
| extractor: HumanPatternFeatureExtractor, |
| max_samples_per_source: int = 50000, |
| ): |
| self.extractor = extractor |
| self.samples = [] |
| |
| # Load Dataset 1 (shanegerami) |
| df1 = pd.read_csv(shanegerami_path).dropna() |
| df1 = df1.sample(min(len(df1), max_samples_per_source), random_state=42) |
| for _, row in df1.iterrows(): |
| self.samples.append({ |
| "text": str(row["text"]), |
| "label": int(row["generated"] == 0), # 0βAI, 1βhuman β flip: 1=human |
| "source": "shanegerami", |
| }) |
| |
| # Load Dataset 2 (starblasters β parquet) |
| df2 = pd.read_parquet(starblasters_path).dropna() |
| df2 = df2.sample(min(len(df2), max_samples_per_source), random_state=42) |
| for _, row in df2.iterrows(): |
| label = 1 if str(row["label"]).lower() == "human" else 0 |
| self.samples.append({ |
| "text": str(row["text"]), |
| "label": label, |
| "source": "starblasters", |
| }) |
| |
| print(f"Total samples loaded: {len(self.samples)}") |
| human = sum(1 for s in self.samples if s["label"] == 1) |
| print(f" Human: {human} | AI: {len(self.samples) - human}") |
| |
| def __len__(self): |
| return len(self.samples) |
| |
    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        # Features (including GPT-2 perplexity) are computed on the fly;
        # for large runs, pre-extract once and cache to disk.
        sample = self.samples[idx]
        features = self.extractor.extract(sample["text"])
        return torch.tensor(features), sample["label"]
| |
| |
| class HumanPatternClassifier(nn.Module): |
| """ |
| Lightweight MLP trained to distinguish human from AI writing. |
| Input: feature vector from HumanPatternFeatureExtractor |
| Output: probability that text is human-written (0 to 1) |
| |
| This is PRE-TRAINED on the Kaggle datasets, then FROZEN. |
| Its output score is used as a loss signal in main model training. |
| High score = human-like = good. Low score = AI-like = penalise. |
| """ |
| |
    def __init__(self, input_dim: int = 17, hidden_dim: int = 128):
        # NOTE: hidden_dim must match the pre-training run
        # (scripts/pretrain_human_pattern_classifier.py uses hidden_dim=256),
        # otherwise load_state_dict() will fail.
        super().__init__()
| self.net = nn.Sequential( |
| nn.Linear(input_dim, hidden_dim), |
| nn.LayerNorm(hidden_dim), |
| nn.ReLU(), |
| nn.Dropout(0.2), |
| nn.Linear(hidden_dim, hidden_dim // 2), |
| nn.ReLU(), |
| nn.Dropout(0.1), |
| nn.Linear(hidden_dim // 2, 1), |
| nn.Sigmoid(), |
| ) |
| |
| def forward(self, features: torch.Tensor) -> torch.Tensor: |
| """Returns human-likeness score in [0, 1]. Higher = more human.""" |
| return self.net(features).squeeze(-1) |
| |
| def score(self, text: str, extractor: HumanPatternFeatureExtractor) -> float: |
| """Convenience: score a single text string.""" |
| features = torch.tensor(extractor.extract(text)).unsqueeze(0) |
| with torch.no_grad(): |
| return self.forward(features).item() |
| ``` |
|
|
| --- |
|
|
| ### `scripts/pretrain_human_pattern_classifier.py` |
| |
| ```python |
| """ |
| Pre-trains the HumanPatternClassifier on both Kaggle datasets. |
| Run this BEFORE the main training loop. |
| The saved classifier weights are then loaded frozen during main training. |
| |
| Run: python scripts/pretrain_human_pattern_classifier.py |
| Output: checkpoints/human_pattern_classifier.pt |
| """ |
|
|
import os

import torch
| import torch.nn as nn |
| from torch.utils.data import DataLoader, random_split |
| from sklearn.metrics import accuracy_score, roc_auc_score |
| import numpy as np |
| from loguru import logger |
| import wandb |
|
|
| from src.training.human_pattern_extractor import ( |
| HumanPatternFeatureExtractor, |
| KaggleHumanPatternDataset, |
| HumanPatternClassifier, |
| ) |
| |
|
|
def train_classifier():
    os.makedirs("checkpoints", exist_ok=True)
    wandb.init(project="dyslexia-writing-ai", name="human-pattern-pretrain")
| |
| extractor = HumanPatternFeatureExtractor() |
| dataset = KaggleHumanPatternDataset( |
| shanegerami_path="data/raw/shanegerami/train_essays.csv", |
| starblasters_path="data/raw/starblasters8/data.parquet", |
| extractor=extractor, |
        max_samples_per_source=50000,  # 100k total – adjust for speed
| ) |
| |
| train_size = int(0.85 * len(dataset)) |
| val_size = len(dataset) - train_size |
| train_ds, val_ds = random_split(dataset, [train_size, val_size]) |
| |
| train_loader = DataLoader(train_ds, batch_size=512, shuffle=True, num_workers=4) |
| val_loader = DataLoader(val_ds, batch_size=512, shuffle=False, num_workers=4) |
| |
| input_dim = extractor.extract("sample text").shape[0] |
| model = HumanPatternClassifier(input_dim=input_dim, hidden_dim=256) |
| device = "cuda" if torch.cuda.is_available() else "cpu" |
| model = model.to(device) |
| |
| optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4) |
| scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20) |
| criterion = nn.BCELoss() |
| |
| best_auc = 0.0 |
| for epoch in range(20): |
| # Train |
| model.train() |
| train_losses = [] |
| for features, labels in train_loader: |
| features = features.to(device) |
| labels = labels.float().to(device) |
| preds = model(features) |
| loss = criterion(preds, labels) |
| optimizer.zero_grad() |
| loss.backward() |
| optimizer.step() |
| train_losses.append(loss.item()) |
| |
| # Validate |
| model.eval() |
| all_preds, all_labels = [], [] |
| with torch.no_grad(): |
| for features, labels in val_loader: |
| features = features.to(device) |
| preds = model(features).cpu().numpy() |
| all_preds.extend(preds) |
| all_labels.extend(labels.numpy()) |
| |
| auc = roc_auc_score(all_labels, all_preds) |
| acc = accuracy_score(all_labels, [1 if p > 0.5 else 0 for p in all_preds]) |
| scheduler.step() |
| |
| logger.info(f"Epoch {epoch+1:02d} | Loss: {np.mean(train_losses):.4f} | AUC: {auc:.4f} | Acc: {acc:.4f}") |
| wandb.log({"classifier/train_loss": np.mean(train_losses), "classifier/val_auc": auc, "classifier/val_acc": acc}) |
| |
| if auc > best_auc: |
| best_auc = auc |
| torch.save(model.state_dict(), "checkpoints/human_pattern_classifier.pt") |
| logger.info(f" β Saved best classifier (AUC: {best_auc:.4f})") |
| |
| wandb.finish() |
| logger.info(f"Pre-training complete. Best AUC: {best_auc:.4f}") |
| logger.info("Classifier saved to: checkpoints/human_pattern_classifier.pt") |
| |
|
|
| if __name__ == "__main__": |
| train_classifier() |
| ``` |
| |
| --- |
|
|
| ### Integration into Main Training Loss |
|
|
#### Updated `src/training/loss_functions.py` – add `L_human_pattern`
| |
| ```python |
| """ |
| UPDATED Combined Loss with Human-Pattern Term: |
| |
| L_total = L_CE + Ξ»β Β· L_style + Ξ»β Β· L_semantic + Ξ»β Β· L_human_pattern |
| |
| L_human_pattern = 1 - HumanPatternClassifier.score(output_text) |
| = reward for human-like output |
| = penalty for AI-like output |
|
|
| The HumanPatternClassifier is FROZEN β its weights do not update. |
| It acts as a discriminator/critic, not a trainable component. |
| Ξ»β default = 0.4 |
| """ |
|
|
| class CombinedCorrectionLossV2(nn.Module): |
|
|
| def __init__( |
| self, |
| lambda_style: float = 0.3, |
| lambda_semantic: float = 0.5, |
| lambda_human_pattern: float = 0.4, |
| classifier_path: str = "checkpoints/human_pattern_classifier.pt", |
| sem_model_name: str = "all-mpnet-base-v2", |
| device: str = "cuda", |
| ): |
| super().__init__() |
| self.lambda_style = lambda_style |
| self.lambda_semantic = lambda_semantic |
| self.lambda_human_pattern = lambda_human_pattern |
| self.device = device |
| |
| # Load pre-trained frozen classifier |
| from .human_pattern_extractor import HumanPatternClassifier, HumanPatternFeatureExtractor |
| self.hp_extractor = HumanPatternFeatureExtractor() |
| input_dim = self.hp_extractor.extract("sample").shape[0] |
        self.hp_classifier = HumanPatternClassifier(input_dim=input_dim, hidden_dim=256)  # must match pre-training
| self.hp_classifier.load_state_dict(torch.load(classifier_path, map_location=device)) |
| self.hp_classifier.to(device) |
| for param in self.hp_classifier.parameters(): |
            param.requires_grad = False  # FROZEN – never trains
| self.hp_classifier.eval() |
| |
| # Semantic model (also frozen) |
| from sentence_transformers import SentenceTransformer |
| self.sem_model = SentenceTransformer(sem_model_name, device=device) |
| for param in self.sem_model.parameters(): |
| param.requires_grad = False |
| |
| self.ce_loss = nn.CrossEntropyLoss(ignore_index=-100) |
| |
    def _human_pattern_loss(self, output_texts: List[str]) -> torch.Tensor:
        """
        For each output text, compute how AI-like it is.
        Loss = 1 - human_score (penalise AI-like outputs).

        NOTE: feature extraction runs on *decoded* text, so this term is not
        differentiable with respect to the generator's parameters. As written,
        it is a monitoring/penalty signal; to actually shape gradients it must
        be used as a reward in an RL-style objective (e.g. minimum-risk or
        policy-gradient training) or via a differentiable surrogate.
        """
        features = torch.stack([
            torch.tensor(self.hp_extractor.extract(t))
            for t in output_texts
        ]).to(self.device)

        with torch.no_grad():
            human_scores = self.hp_classifier(features)  # [batch], values in [0, 1]

        # Loss = average AI-likeness = 1 - average human-likeness
        return (1 - human_scores).mean()
| |
| def forward( |
| self, |
| logits: torch.Tensor, |
| labels: torch.Tensor, |
| output_style_vec: Optional[torch.Tensor] = None, |
| target_style_vec: Optional[torch.Tensor] = None, |
| input_texts: Optional[List[str]] = None, |
| output_texts: Optional[List[str]] = None, |
| ) -> Dict[str, torch.Tensor]: |
| |
| vocab_size = logits.shape[-1] |
| l_ce = self.ce_loss(logits.view(-1, vocab_size), labels.view(-1)) |
| losses = {"l_ce": l_ce, "total": l_ce} |
| |
| if output_style_vec is not None and target_style_vec is not None: |
| sim = F.cosine_similarity(output_style_vec, target_style_vec, dim=-1) |
| l_style = (1 - sim).mean() |
| losses["l_style"] = l_style |
| losses["total"] = losses["total"] + self.lambda_style * l_style |
| |
| if input_texts is not None and output_texts is not None: |
| input_embs = torch.tensor(self.sem_model.encode(input_texts), device=self.device) |
| output_embs = torch.tensor(self.sem_model.encode(output_texts), device=self.device) |
| sim = F.cosine_similarity(input_embs, output_embs, dim=-1) |
| l_sem = (1 - sim).mean() |
| losses["l_semantic"] = l_sem |
| losses["total"] = losses["total"] + self.lambda_semantic * l_sem |
| |
| if output_texts is not None: |
| l_hp = self._human_pattern_loss(output_texts) |
| losses["l_human_pattern"] = l_hp |
| losses["total"] = losses["total"] + self.lambda_human_pattern * l_hp |
| |
| return losses |
| ``` |
| |
| --- |
|
|
| ### Updated Mathematical Formulation |
|
|
| ``` |
L_total = L_CE + λ₁ · L_style + λ₂ · L_semantic + λ₃ · L_human_pattern

L_human_pattern = 1 - (1/N) Σᵢ HPC(φ(ŷᵢ))

Where:
    HPC(·) = HumanPatternClassifier (frozen)
    φ(·)   = HumanPatternFeatureExtractor
    ŷᵢ     = model's generated output text for example i
    N      = batch size

λ₁ = 0.3   (style consistency weight)
λ₂ = 0.5   (semantic preservation weight)
λ₃ = 0.4   (human pattern reward weight)

The auxiliary weights are additive on top of the cross-entropy anchor:
    λ₁ + λ₂ + λ₃ = 1.2

The HumanPatternClassifier is trained to maximise AUC on the two Kaggle datasets.
Target pre-training performance: AUC > 0.88, Accuracy > 83%
| ``` |
|
|
| --- |
|
|
| ### Updated Training Sequence |
|
|
| ```bash |
| # 0. Download both Kaggle datasets |
| bash scripts/download_kaggle_datasets.sh |
| |
| # 1. Pre-train the HumanPatternClassifier (runs separately, ~1-2 hours) |
| python scripts/pretrain_human_pattern_classifier.py |
| |
| # 2. Verify classifier quality (target AUC > 0.88) |
| python scripts/evaluate_classifier.py --checkpoint checkpoints/human_pattern_classifier.pt |
| |
| # 3. Then run main model training (classifier is auto-loaded frozen) |
| python scripts/train.py --config configs/training_config.yaml |
| |
# 4. All loss terms are now tracked in W&B:
#    loss/ce, loss/style, loss/semantic, loss/human_pattern, loss/total
| ``` |
|
|
| --- |
|
|
| ### Updated `configs/training_config.yaml` additions |
| |
| ```yaml |
| # Add this section to configs/training_config.yaml: |
|
|
| human_pattern: |
| classifier_path: "checkpoints/human_pattern_classifier.pt" |
| shanegerami_path: "data/raw/shanegerami/train_essays.csv" |
| starblasters_path: "data/raw/starblasters8/data.parquet" |
| max_samples_per_source: 50000 |
| pretrain_epochs: 20 |
| pretrain_lr: 1.0e-3 |
| pretrain_batch_size: 512 |
| target_auc: 0.88 |
| |
| loss: |
| lambda_style: 0.3 |
| lambda_semantic: 0.5 |
  lambda_human_pattern: 0.4  # NEW – added in v2
| sem_model_name: "all-mpnet-base-v2" |
| ``` |
| |
| --- |
| |
| ### What the Two Datasets Teach the Model (Not What They Are Used For) |
| |
| | Learning Target | From Dataset | Mechanism | |
| |---|---|---| |
| | Human sentence length is bursty/unpredictable | Both datasets | Burstiness feature in HPC | |
| | Humans don't start every sentence the same way | Both datasets | Starter diversity feature | |
| | AI text has lower GPT-2 perplexity | Both (AI side) | Perplexity feature in HPC | |
| | AI overuses "delve", "tapestry", "crucial", "pivotal" | Both (AI side) | AI fingerprint word density | |
| | AI overuses "Furthermore", "Moreover", "In conclusion" | Both (AI side) | Discourse marker density | |
| | Humans have higher n-gram novelty | Both | n-gram novelty score | |
| | 63 different LLMs share the same surface patterns | starblasters8 | Broad AI coverage in HPC training | |
| | GPT-family essays are detectable at scale | shanegerami | Dense GPT signature learning | |
| |
| --- |
| |
| ## 22. Complete Dataset Directory |
| |
| All publicly available datasets relevant to this system, across three categories. |
| |
| --- |
| |
### Category A – Grammar Error Correction (Core Training Data)
| |
| | Dataset | Size | Notes | Access | HuggingFace ID | |
| |---|---|---|---|---| |
| W&I+LOCNESS | 35k pairs | Gold standard GEC, learner English, 5 proficiency levels | Free registration | `wi_locness` |
| JFLEG | 1.5k pairs | 4 human references per sentence, fluency focus | Public GitHub | — |
| CoNLL-2014 | 1.3k pairs | 2 human annotators, classic GEC benchmark | Public | — |
| FCE Corpus | 33k pairs | Cambridge First Certificate essays with corrections | Free registration | — |
| NUCLE | 57k sentences | NUS Corpus of Learner English, sentence-level errors | Free registration | — |
| Lang-8 | 1M+ pairs | Crowdsourced learner writing corrections in 80 languages | Request form | — |
| CLANG-8 | 2.6M pairs | Cleaned Lang-8, filtered for English quality | HuggingFace | `google/clang8` |
| Falko-MERLIN | 24k sentences | German learner English (good for multilingual) | Public | — |
| BEA-2019 Shared Task | 4k test pairs | Official GEC evaluation set, gold standard | Free | — |
|
|
| --- |
|
|
### Category B – Formality & Style Transfer (Style Training Data)
|
|
| | Dataset | Size | Notes | Access | HuggingFace ID | |
| |---|---|---|---|---| |
| GYAFC | 105k pairs | Yahoo Answers informal → formal, 2 domains | Request Grammarly | — |
| YELP Sentiment Transfer | 560k reviews | Sentiment-controlled style transfer | Public | `yelp_review_full` |
| Shakespeare Modern | 21k lines | Shakespearean → modern English parallel | Public GitHub | — |
| Europarl | 60M sentences | Formal parliamentary discourse, 21 languages | Public | `Helsinki-NLP/europarl` |
| WikiText-103 | 103M tokens | High-quality Wikipedia prose, formal register | Public | `wikitext` |
| OpenWebText | 40GB | Curated human web text (Reddit upvoted links) | Public | `openwebtext` |
| PAWS | 108k pairs | Paraphrase pairs with controlled syntactic diversity | Public | `paws` |
| ParaBank2 | 50M pairs | Large-scale paraphrase pairs | Public | — |
|
|
| --- |
|
|
### Category C – Human vs AI Distinction (Anti-AI Training Data)
|
|
| #### Your Two Selected Datasets |
|
|
| | Dataset | Size | LLMs Covered | Access | Notes | |
| |---|---|---|---|---| |
| | shanegerami/ai-vs-human-text | 500k essays | GPT-family | Kaggle | Columns: text, generated(0/1) | |
| | starblasters8/human-vs-llm-text-corpus | 800k texts | 63 LLMs | Kaggle | Parquet: text, label(str) | |
|
|
| #### Additional Highly Recommended |
|
|
| | Dataset | Size | LLMs Covered | Access | HuggingFace / URL | |
| |---|---|---|---|---| |
| | RAID Benchmark | 6.2M generations | 11 generators, 8 domains, 11 adversarial attacks | Public | `liamdugan/raid` | |
| | HC3 (Human-ChatGPT Corpus) | 125k QA pairs | ChatGPT only | Public | `Hello-SimpleAI/HC3` | |
| HC3-Plus | 210k pairs | ChatGPT, semantic-invariant variants | Public | — |
| M4GT-Bench | 152k texts | 7 LLMs, 8 languages, 8 domains | Public | — |
| DeepfakeTextDetect | 447k texts | 27 LLMs, 10 domains | Public | `yaful/DeepfakeTextDetect` |
| MGTBench | 21k texts | 6 LLMs, 3 domains | Public | — |
| | MAGE Dataset | 447k texts | Largest multi-model human/AI corpus | Public | `yaful/MAGE` | |
| TuringBench | 168k articles | 20 LLMs including GPT-2 to GPT-3 | Public | — |
| BUST | 25.2k texts | 7 generators, 4 domains | Public | — |
| DetectRL | 235k texts | 4 LLMs, adversarial-robust benchmark | Public | — |
| | GPT-Wiki-Intro | 150k intros | GPT-3.5 vs Wikipedia introductions | Public | `aadityaubhat/GPT-wiki-intro` | |
| | SemEval 2024 Task 8 | ~70k texts | Mixed human/AI, boundary detection task | Public | SemEval 2024 | |
| | PeerRead | 14.7k papers | Scientific paper review AI vs human | Public | `allenai/PeerRead` | |
| | ArXiv AI Abstract Dataset | 500k+ abstracts | Scientific writing, GPT vs real | Public | arxiv bulk API | |
| | ELI5-Human-AI | 30k pairs | Mistral-7B vs human on Explain Like I'm 5 | Public | Research benchmark | |
| HC-Var | 145k texts | ChatGPT variants across prompting strategies | Public | — |
| | WritingPrompts (Human) | 303k stories | Reddit human creative writing β pure human signal | Public | `euclaise/writingprompts` | |
| MultiSocial | 472k texts | Social media, 22 languages, 7 LLMs | Public | — |
| WETBench | 101.9k texts | Web & essay text, 4 LLMs | Public | — |
| | silentone0725/ai-human-text-detection-v1 | 9 corpora merged | HC3, RAID, M4GT-Bench + more, pre-cleaned | Public | `silentone0725/ai-human-text-detection-v1` | |
|
|
| --- |
|
|
### Category D – Dyslexia-Specific Data
|
|
| | Dataset | Size | Notes | Access | URL | |
| |---|---|---|---|---| |
| | DysLexML Corpus | ~2k texts | Actual dyslectic writing samples, annotated | Academic request | Research paper: Rello et al. | |
| | POPSYCLE Corpus | ~800 texts | Dyslexic children's writing with expert annotations | Academic request | Lancaster University | |
| | Write & Improve (W&I) subset | ~5k texts | Includes dyslexia-pattern learner errors | Free registration | Cambridge | |
| | Synthetic (DyslexiaSimulator) | Unlimited | Generated by your own simulator (Layer 1) | Self-generated | `src/preprocessing/dyslexia_simulator.py` | |
|
|
| --- |
|
|
| ### Recommended Dataset Priority Order for Training |
|
|
| ``` |
Phase 1 – Classifier Pre-training (Human Pattern):
  1. starblasters8/human-vs-llm-text-corpus (800k, 63 LLMs – widest coverage)
  2. shanegerami/ai-vs-human-text (500k, dense GPT signal)
  3. RAID Benchmark (6.2M, adversarial robustness)
  4. MAGE Dataset (447k, 27 LLMs)

Phase 2 – Core GEC Model Training:
  1. CLANG-8 (2.6M pairs, largest clean GEC)
  2. W&I+LOCNESS (35k, gold standard, highest quality)
  3. JFLEG (1.5k, fluency focus)
  4. Synthetic dyslexia pairs (generated, unlimited)

Phase 3 – Style Transfer Training:
  1. GYAFC (105k formal/informal pairs)
  2. WikiText-103 (103M tokens, formal register)
  3. OpenWebText (40GB human web text)

Phase 4 – Academic Register Fine-tuning:
  1. PeerRead (14.7k academic papers)
  2. ArXiv abstracts (500k+ scientific writing)
  3. Europarl (60M formal parliamentary)
|
|
| --- |
|
|
| ### `scripts/download_all_huggingface_datasets.py` |
| |
| ```python |
| """ |
| Downloads all publicly available HuggingFace datasets automatically. |
| Datasets requiring registration/request are flagged with instructions. |
| |
| Run: python scripts/download_all_huggingface_datasets.py |
| """ |
|
|
| from datasets import load_dataset |
| import os |
| |
| os.makedirs("data/raw/hf", exist_ok=True) |
|
|
| HF_DATASETS = [ |
| # (hf_identifier, config, split, output_subdir) |
| ("google/clang8", "en", "train", "clang8"), |
| ("liamdugan/raid", None, "train", "raid"), |
| ("Hello-SimpleAI/HC3", "all", "train", "hc3"), |
| ("yaful/MAGE", None, "train", "mage"), |
| ("aadityaubhat/GPT-wiki-intro", None, "train", "gpt_wiki_intro"), |
| ("euclaise/writingprompts", None, "train", "writing_prompts"), |
| ("wikitext", "wikitext-103-raw-v1", "train", "wikitext103"), |
| ("openwebtext", None, "train", "openwebtext"), |
| ("paws", "labeled_final", "train", "paws"), |
| ("allenai/PeerRead", "all", "train", "peerread"), |
| ("silentone0725/ai-human-text-detection-v1", None, "train", "merged_ai_human"), |
| ] |
| |
| for hf_id, config, split, subdir in HF_DATASETS: |
| out_path = f"data/raw/hf/{subdir}" |
| if os.path.exists(out_path): |
| print(f"β Already exists: {subdir}") |
| continue |
| try: |
| print(f"Downloading: {hf_id}...") |
| ds = load_dataset(hf_id, config, split=split, trust_remote_code=True) |
| ds.save_to_disk(out_path) |
| print(f" β Saved to {out_path} ({len(ds)} examples)") |
| except Exception as e: |
| print(f" β Failed: {hf_id} β {e}") |
| |
| # Datasets requiring manual action |
| MANUAL_DATASETS = { |
| "W&I+LOCNESS": "https://www.cl.cam.ac.uk/research/nl/bea2019st/ (free registration)", |
| "GYAFC": "https://github.com/raosudha89/GYAFC-corpus (email request to Grammarly)", |
| "FCE Corpus": "https://ilexir.co.uk/datasets/index.html (free registration)", |
| "NUCLE": "https://www.comp.nus.edu.sg/~nlp/corpora.html (free registration)", |
| "Lang-8": "https://sites.google.com/site/naistlang8corpora/ (request form)", |
| "DysLexML": "Contact Rello et al. authors directly via ResearchGate", |
| } |
| |
| print("\nββ Datasets requiring manual download ββ") |
| for name, url in MANUAL_DATASETS.items(): |
| print(f" {name}: {url}") |
| ``` |
| |
| --- |
|
|
*Blueprint version 2.0 – Dyslexia Academic Writing Correction System*
*Architecture: Style-Preserving Constrained Correction + Human-Pattern Anti-AI Training*
*Datasets: 25+ sources · 10M+ training examples · 63 LLMs covered*
|
|