---
language:
  - en
tags:
  - text2text-generation
  - dyslexia
  - grammar-correction
  - style-preservation
  - lora
  - flan-t5
license: mit
base_model: google/flan-t5-small
datasets:
  - jhu-clsp/jfleg
  - bea2019st/wi_locness
pipeline_tag: translation
---

Dyslexia Academic Writing Correction System

A style-preserving, grammar-correcting AI system that elevates academic vocabulary and corrects dyslexic writing while maintaining the author's personal voice, tone, and authorship signal. It is a corrector, not a rewriter.

Overview

This system takes text written by dyslexic students and corrects grammar, spelling, and fluency errors while:

  1. Preserving the author's unique writing style via a 512-dimensional style fingerprint vector
  2. Elevating vocabulary to academic register using Coxhead's Academic Word List (AWL) and BERT-based lexical substitution
  3. Resisting AI detection through a frozen Human Pattern Classifier that penalises AI-typical writing during training
  4. Maintaining semantic meaning with a cosine-similarity-based semantic preservation loss

The core model is Google Flan-T5-Small fine-tuned with LoRA (Low-Rank Adaptation, r=16), trained on real learner error corpora (JFLEG, W&I+LOCNESS) augmented with synthetic dyslexia-simulated data.


Latest Evaluation Results (v3)

| Metric | Score | Description |
|---|---|---|
| GLEU | 0.7593 | Grammar + fluency correction quality |
| BERTScore F1 | 0.9758 | Semantic closeness to reference corrections |
| 1 − WER | 0.8552 | Word-level accuracy (WER = 14.48%) |
| Composite | 0.8634 | (GLEU + BERTScore F1 + (1 − WER)) / 3; gating score for Hub push |
| Faithfulness reverts | 11 | Outputs whose cosine similarity to the input fell below 0.75, reverted to source |

The model is only pushed to the Hub when the composite score strictly beats the saved baseline from the previous run, ensuring the Hub always holds the best-seen weights.
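The composite and the push gate reduce to a mean and a strict comparison. This is a minimal sketch, not the code in train_and_upgrade.py: the function names are hypothetical, and the `{"composite": ...}` layout of baseline_score.json is an assumption.

```python
import json
from pathlib import Path
from typing import Optional


def composite_score(gleu: float, bertscore_f1: float, wer: float,
                    errant_f05: Optional[float] = None) -> float:
    """Mean of GLEU, BERTScore F1 and (1 - WER); ERRANT F0.5 joins the mean when available."""
    parts = [gleu, bertscore_f1, 1.0 - wer]
    if errant_f05 is not None:
        parts.append(errant_f05)
    return sum(parts) / len(parts)


def load_baseline(path: Path) -> Optional[float]:
    """Read the saved composite score; None if no baseline exists yet."""
    if not path.exists():
        return None
    # Hypothetical schema {"composite": 0.8634}; the real file layout may differ.
    return float(json.loads(path.read_text())["composite"])


def beats_baseline(new_score: float, baseline: Optional[float]) -> bool:
    """Strict comparison: a tie does not trigger a Hub push."""
    if baseline is None:
        return True  # assumption: the very first run always establishes a baseline
    return new_score > baseline


def save_baseline(path: Path, score: float) -> None:
    path.write_text(json.dumps({"composite": score}))
```

With the v3 numbers, `composite_score(0.7593, 0.9758, 0.1448)` rounds to 0.8634, which strictly beats the v2 baseline of 0.8576, so the push would proceed.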

Score Progression

| Metric | v1 | v2 | v3 | Δ v2→v3 |
|---|---|---|---|---|
| GLEU | n/a | 0.7506 | 0.7593 | +0.0087 |
| BERTScore F1 | n/a | 0.9733 | 0.9758 | +0.0025 |
| 1 − WER | n/a | 0.8488 | 0.8552 | +0.0064 |
| Composite | n/a | 0.8576 | 0.8634 | +0.0058 |

What Changed in v3

v3 keeps v2's base model and LoRA rank but changes most other stages of the pipeline: a wider context window (256 tokens), stronger beam-search settings, optional extra training data (C4-200M-GEC), a semantic faithfulness gate that reverts meaning-changing corrections, and optional ERRANT F0.5 evaluation.

| Parameter | v2 | v3 |
|---|---|---|
| Context window | 128 tokens | 256 tokens |
| Additional data | JFLEG + W&I only | + C4-200M-GEC (~100k pairs, falls back if unavailable) |
| Beam search | num_beams=2 | num_beams=5, length_penalty=1.2, repetition_penalty=1.3, no_repeat_ngram_size=3 |
| Faithfulness gate | none | cosine sim < 0.75 → revert output to source |
| Human-pattern loss | skipped on CPU | active on GPU (loads classifier from Hub if present) |
| Evaluation cap | always 200 samples | 200 on CPU, full test set on GPU |
| ERRANT F0.5 | not present | optional metric (install errant + en_core_web_sm) |
| Composite | mean(GLEU, BERTScore, 1 − WER) | mean(GLEU, BERTScore, 1 − WER [, ERRANT F0.5 if available]) |

Semantic Faithfulness Gate (v3)

After generation, each output is checked against its source input using all-MiniLM-L6-v2 sentence embeddings. If cosine similarity falls below 0.75, the output is discarded and the original input is returned as a fallback. This prevents corrections that accidentally change meaning.

In the v3 evaluation run, 11 outputs (of 228 test pairs evaluated) were reverted. Without the gate, these meaning-altering outputs would have been scored as predictions and would likely have lowered all three metrics.
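A minimal sketch of the gate, assuming embeddings come from an injected callable so the logic runs with the standard library alone. The function names are illustrative, not the repository's API; the `num_fallback` count mirrors the counter mentioned under Known Limitations.

```python
import math
from typing import Callable, List, Sequence, Tuple

Embedder = Callable[[str], Sequence[float]]


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity; returns 0.0 if either vector is all zeros."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)


def apply_faithfulness_gate(sources: List[str], outputs: List[str], embed: Embedder,
                            threshold: float = 0.75) -> Tuple[List[str], int]:
    """Keep each output unless it drifts below `threshold`; otherwise fall back to the source."""
    gated: List[str] = []
    num_fallback = 0
    for src, out in zip(sources, outputs):
        if cosine(embed(src), embed(out)) < threshold:
            gated.append(src)      # meaning drift suspected: return the input unchanged
            num_fallback += 1
        else:
            gated.append(out)
    return gated, num_fallback
```

With sentence-transformers installed, `embed` could be `SentenceTransformer("all-MiniLM-L6-v2").encode`; encoding all sources and outputs in one batched call would be faster than the per-pair calls shown here.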

Combined Loss (v3; unchanged from v2 on CPU)

L = L_CE + 0.3·L_style + 0.5·L_semantic                 (CPU)
L = L_CE + 0.3·L_style + 0.5·L_semantic + 0.4·L_human   (GPU)

| Term | Purpose | Weight |
|---|---|---|
| L_CE | Cross-entropy with label smoothing (0.1) | 1.0 |
| L_style | 1 − cos_sim(style(input), style(output)) | 0.3 |
| L_semantic | 1 − cos_sim(input_emb, output_emb) | 0.5 |
| L_human | 1 − HumanPatternClassifier(output); anti-AI penalty | 0.4 (GPU only) |
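The weighting can be written out as a small function. This sketch works on plain floats for readability; in the actual trainer each term would be a tensor computed per batch, and L_CE already includes label smoothing. `LossWeights` and `combined_loss` are hypothetical names.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class LossWeights:
    style: float = 0.3
    semantic: float = 0.5
    human: float = 0.4


def combined_loss(l_ce: float, style_cos: float, semantic_cos: float,
                  human_score: Optional[float] = None,
                  w: LossWeights = LossWeights()) -> float:
    """L = L_CE + w.style*(1 - style_cos) + w.semantic*(1 - semantic_cos) [+ w.human*(1 - human_score)]."""
    l_style = 1.0 - style_cos          # style fingerprint drift
    l_semantic = 1.0 - semantic_cos    # meaning drift
    total = l_ce + w.style * l_style + w.semantic * l_semantic
    if human_score is not None:        # GPU path: classifier score, 1 = human, 0 = AI
        total += w.human * (1.0 - human_score)
    return total
```

For example, a cross-entropy of 2.0 with style similarity 0.9 and semantic similarity 0.8 gives 2.0 + 0.03 + 0.10 = 2.13 on the CPU path.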

What Changed in v2

The original model had a critical bug: CorrectionTrainer.compute_loss() used only the cross-entropy loss. The multi-objective loss (L_CE + λ_style·L_style + λ_semantic·L_semantic + λ_human·L_human) was fully designed in loss_functions.py but never wired into the trainer. v2 fixes this and upgrades several other parameters.

| Parameter | v1 (Original) | v2 (Upgraded) |
|---|---|---|
| LoRA rank | r=8, α=16 | r=16, α=32 |
| Epochs | 5 | 10 |
| Effective batch size | 32 (4×8 accum) | 64 (2×32 accum) |
| Learning rate | 3e-4 | 2e-4 |
| Warmup ratio | 5% | 10% |
| Label smoothing | none | 0.1 |
| Loss function | CE only (bug) | CE + Style + Semantic (fixed) |
| Human-pattern loss | designed, unused | omitted on CPU; falls back to CE + style + semantic |
| Evaluation | GLEU only | GLEU + BERTScore F1 + (1 − WER) composite |
| Eval/save strategy | every 100 steps | per epoch |
| Early stopping | none | patience=3 |
| Hub gate | none | composite must beat saved baseline |
| Warm-start strategy | cold start | merge r=8 adapter → apply fresh r=16 LoRA |
| Data split | 90%/10% train/val | 88%/7%/5% train/val/test |
| Dyslexia augmentation error rate | 15% | 20% |
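The 88%/7%/5% split can be reproduced with a seeded shuffle. This is a sketch under assumptions: the shuffle, the integer rounding, the seed value, and the function name are illustrative choices, not taken from the repository.

```python
import random
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def split_train_val_test(pairs: Sequence[T], seed: int = 42) -> Tuple[List[T], List[T], List[T]]:
    """Shuffle deterministically, then cut 88% train / 7% val / remainder (~5%) test."""
    items = list(pairs)
    random.Random(seed).shuffle(items)
    n = len(items)
    n_train = n * 88 // 100   # integer arithmetic avoids float rounding surprises
    n_val = n * 7 // 100
    train = items[:n_train]
    val = items[n_train:n_train + n_val]
    test = items[n_train + n_val:]
    return train, val, test
```

For 1,000 pairs this yields 880 / 70 / 50; any rounding remainder falls into the test split.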

Features

| Feature | Description |
|---|---|
| Two-pass spell correction | Dyslexia-aware phonetic pattern handling via LanguageTool |
| Style fingerprinting | 41 raw features → MLP → 512-dim L2-normalised style vector |
| LoRA fine-tuning | r=16, α=32, dropout=0.05; targets all attention + FFN projections |
| Academic vocabulary elevation | BERT fill-mask → AWL candidate filtering → semantic similarity gate |
| Human pattern anti-AI loss | Pre-trained frozen MLP classifier (17-dim features including GPT-2 perplexity) |
| Combined training loss | L_CE + λ₁·L_style + λ₂·L_semantic (+ λ₃·L_human on GPU) |
| Semantic faithfulness gate | Outputs with cosine sim < 0.75 to source are reverted, preventing meaning drift |
| Sentence-chunked inference | Long texts split into 256-token chunks matching the training window |
| FastAPI server | RESTful /correct endpoint with CORS and rate limiting |
| Multi-stage training | Orchestrated via train.sh with checkpoint system (Skip/Redo/Continue) |
| Synthetic data augmentation | DyslexiaSimulator generates realistic errors from clean text (20% error rate) |
| Composite score gating | Hub push only if the new model strictly beats the saved baseline |

Project Structure

```
Rewriter/
├── configs/
│   ├── training_config.yaml        # Full training hyperparameters
│   ├── training_config_fast.yaml   # Quick iteration config
│   ├── inference_config.yaml       # Inference & generation settings
│   ├── model_config.yaml           # Model architecture registry
│   └── awl_config.yaml             # Academic Word List settings
├── scripts/
│   ├── train.py                    # Main training script (Click CLI)
│   ├── evaluate.py                 # Test set evaluation (GLEU, ERRANT, BERTScore)
│   ├── run_inference.py            # Interactive CLI inference
│   ├── preprocess_data.py          # Raw datasets → unified JSONL
│   ├── pretrain_human_pattern_classifier.py  # Stage 3: anti-AI classifier
│   ├── download_datasets.sh        # BEA-2019 dataset downloader
│   └── download_kaggle_datasets.sh # Kaggle human/AI data downloader
├── src/
│   ├── model/
│   │   ├── base_model.py           # Model loader (T5/BART/Llama + LoRA + quantization)
│   │   ├── style_conditioner.py    # Prefix tuning: style → virtual tokens
│   │   ├── generation_utils.py     # Beam search, sampling, batch generation
│   │   └── lora_adapter.py         # LoRA configuration helpers
│   ├── preprocessing/
│   │   ├── pipeline.py             # Full preprocessing orchestrator
│   │   ├── spell_corrector.py      # LanguageTool + dyslexia-aware correction
│   │   ├── dyslexia_simulator.py   # Synthetic error generation (Rello et al.)
│   │   ├── dependency_parser.py    # spaCy dependency tree analysis
│   │   ├── ner_tagger.py           # Named entity protection
│   │   └── sentence_segmenter.py   # Sentence boundary detection
│   ├── style/
│   │   ├── fingerprinter.py        # 41 features → 512-dim style vector
│   │   ├── style_vector.py         # Style vector dataclass
│   │   ├── formality_classifier.py # Rule-based formality scoring
│   │   └── emotion_classifier.py   # Emotion detection
│   ├── training/
│   │   ├── dataset.py              # Pre-tokenized cached dataset with style vectors
│   │   ├── trainer.py              # CorrectionTrainer (HF Trainer + PEFT fixes)
│   │   ├── loss_functions.py       # V1 and V2 combined losses
│   │   ├── human_pattern_extractor.py  # 17-dim feature extraction + classifier
│   │   └── callbacks.py            # Evaluation logging callbacks
│   ├── vocabulary/
│   │   ├── lexical_substitution.py # BERT fill-mask → AWL substitution pipeline
│   │   ├── awl_loader.py           # Coxhead Academic Word List loader
│   │   └── register_filter.py      # Contraction expansion + colloquial replacement
│   ├── inference/
│   │   ├── corrector.py            # End-to-end inference pipeline orchestrator
│   │   └── postprocessor.py        # Cleanup, entity restore, formatting
│   ├── evaluation/
│   │   ├── gleu_scorer.py          # GLEU + BERTScore computation
│   │   ├── errant_evaluator.py     # ERRANT P/R/F0.5 evaluation
│   │   ├── style_metrics.py        # Style similarity + AWL coverage
│   │   └── authorship_verifier.py  # AI detection resistance testing
│   └── api/
│       ├── main.py                 # FastAPI application
│       ├── schemas.py              # Pydantic request/response models
│       └── middleware.py           # Rate limiting + CORS
├── train_and_upgrade.py            # v3 upgrade pipeline (self-improving Hub push)
├── data/
│   ├── raw/                        # Original datasets (JFLEG, W&I+LOCNESS)
│   ├── processed/                  # Unified JSONL (train/val/test splits)
│   ├── cache/                      # Pre-tokenized dataset caches (.pt files)
│   └── awl/                        # Coxhead Academic Word List
├── train.sh                        # Multi-stage training orchestrator
├── start.sh                        # Inference launcher (CLI or API mode)
├── baseline_score.json             # Saved composite score (0.8634); gate for Hub push
├── Dockerfile                      # Production container
├── docker-compose.yml              # Docker deployment
├── requirements.txt                # Python dependencies
└── pyproject.toml                  # Project metadata
```

Model Architecture


Mermaid Diagram:

```mermaid
graph TB
    subgraph INFERENCE["🔮 Inference Pipeline"]
        direction TB
        INPUT["📝 Raw Dyslectic Text"]
        subgraph PREPROCESS["Pre-Processing"]
            SPELL["Spell Corrector<br/><i>dyslexia-aware phonetic</i>"]
            SENT_SEG["Sentence Segmenter"]
            DEP_PARSE["Dependency Parser"]
            NER["NER Tagger"]
        end
        subgraph STYLE["Style Analysis"]
            FINGER["Style Fingerprinter<br/><i>512-dim vector</i>"]
            EMOTION["Emotion Classifier"]
            FORMALITY["Formality Classifier"]
            STYLE_VEC["Style Vector Composer"]
        end
        subgraph GENERATION["Core Generation"]
            STYLE_COND["Style Conditioner<br/><i>prefix tuning</i>"]
            BASE_MODEL["Base LM<br/><i>Flan-T5-Small (warm-merged)</i>"]
            LORA["LoRA Adapter<br/><i>r=16</i>"]
            GEN_UTILS["Generation Utils<br/><i>beam search, sampling</i>"]
        end
        subgraph POSTPROCESS["Post-Processing"]
            FAITH["Faithfulness Gate<br/><i>cos sim &lt; 0.75 → revert</i>"]
            POSTPROC["Post-Processor<br/><i>formatting, cleanup</i>"]
            VOCAB_SUB["Lexical Substitution<br/><i>BERT-based</i>"]
            AWL["AWL Loader<br/><i>Coxhead Academic Word List</i>"]
            REG_FILTER["Register Filter<br/><i>academic tone gate</i>"]
        end
        OUTPUT["✅ Corrected Academic Text"]
        INPUT --> SPELL --> SENT_SEG --> DEP_PARSE --> NER
        INPUT --> FINGER --> EMOTION --> FORMALITY --> STYLE_VEC
        NER --> STYLE_COND
        STYLE_VEC --> STYLE_COND
        STYLE_COND --> BASE_MODEL
        LORA -.->|"merged weights"| BASE_MODEL
        BASE_MODEL --> GEN_UTILS --> FAITH --> POSTPROC
        POSTPROC --> VOCAB_SUB
        AWL --> VOCAB_SUB
        VOCAB_SUB --> REG_FILTER --> OUTPUT
    end

    subgraph TRAINING["🏋️ Training Pipeline (v3)"]
        direction TB
        subgraph WARMSTART["Warm-Start Merge"]
            HUB_ADAPTER["Hub LoRA Adapter<br/><i>r=16 (v2)</i>"]
            MERGE["merge_and_unload()"]
            FRESH_LORA["Fresh LoRA r=16"]
        end
        subgraph DATA["Data Pipeline"]
            JFLEG["jhu-clsp/jfleg<br/><i>~5k pairs, 4 refs each</i>"]
            WILOCNESS["bea2019st/wi_locness<br/><i>~34k pairs</i>"]
            C4GEC["C4-200M-GEC<br/><i>~100k pairs (optional)</i>"]
            DYSLEXIA_AUG["DyslexiaSimulator<br/><i>20% error rate augmentation</i>"]
            SPLIT["88% train / 7% val / 5% test"]
        end
        subgraph LOSS["Combined Loss (v3)"]
            L_CE["L_CE + label_smoothing=0.1"]
            L_STYLE["0.3 · L_style"]
            L_SEM["0.5 · L_semantic"]
            L_HUMAN["0.4 · L_human<br/><i>(GPU only)</i>"]
        end
        subgraph EVAL["Composite Evaluation"]
            GLEU_E["GLEU"]
            BERT_E["BERTScore F1"]
            WER_E["1 − WER"]
            ERRANT_E["ERRANT F0.5<br/><i>(optional)</i>"]
            COMPOSITE["Composite = mean(3 or 4)"]
            GATE["Beat baseline?"]
            HUB_PUSH["Push to Hub ✅"]
        end
        HUB_ADAPTER --> MERGE --> FRESH_LORA
        JFLEG --> DYSLEXIA_AUG
        WILOCNESS --> DYSLEXIA_AUG
        C4GEC --> DYSLEXIA_AUG
        DYSLEXIA_AUG --> SPLIT
        L_CE --> COMPOSITE
        L_STYLE --> COMPOSITE
        L_SEM --> COMPOSITE
        GLEU_E --> COMPOSITE
        BERT_E --> COMPOSITE
        WER_E --> COMPOSITE
        ERRANT_E -.->|"if installed"| COMPOSITE
        COMPOSITE --> GATE --> HUB_PUSH
    end
```

Design Choices & Rationale

Why Flan-T5-Small?

| Consideration | Decision |
|---|---|
| Hardware constraint | RTX 3050 Laptop GPU (4GB VRAM), which rules out models > 500M params |
| Architecture | Encoder-decoder (seq2seq) is ideal for text-to-text correction tasks |
| Instruction tuning | Flan-T5 is instruction-tuned on 1,800+ tasks, so it follows correction prompts naturally |
| LoRA efficiency | Trainable params scale linearly with r: r=16 → ~2.56M (~3.3% of the 77M base), which still fits in 4GB |
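The ~2.56M figure can be checked by hand. The sketch below assumes the standard google/flan-t5-small configuration (d_model=512, d_ff=1024, 6 heads with d_kv=64, 8 encoder and 8 decoder layers, gated-GELU FFN with wi_0/wi_1/wo) and counts the A (r×d_in) and B (d_out×r) matrices LoRA adds to each targeted projection. The helper names are hypothetical; after `get_peft_model`, peft's `print_trainable_parameters()` should report a matching count.

```python
def lora_params_per_linear(r: int, d_in: int, d_out: int) -> int:
    """LoRA adds A (r x d_in) and B (d_out x r) beside a frozen d_out x d_in weight."""
    return r * (d_in + d_out)


def flan_t5_small_lora_params(r: int) -> int:
    # Assumed google/flan-t5-small config values.
    d_model, d_ff = 512, 1024
    inner = 6 * 64                # num_heads * d_kv = 384
    n_enc, n_dec = 8, 8
    # One attention block: q, k, v (d_model -> inner) and o (inner -> d_model).
    attn = 3 * lora_params_per_linear(r, d_model, inner) + lora_params_per_linear(r, inner, d_model)
    # Gated-GELU FFN: wi_0, wi_1 (d_model -> d_ff) and wo (d_ff -> d_model).
    ffn = 2 * lora_params_per_linear(r, d_model, d_ff) + lora_params_per_linear(r, d_ff, d_model)
    enc_layer = attn + ffn        # self-attention + FFN
    dec_layer = 2 * attn + ffn    # self-attention + cross-attention + FFN
    return n_enc * enc_layer + n_dec * dec_layer
```

`flan_t5_small_lora_params(16)` gives 2,555,904 (≈2.56M), and r=8 gives exactly half, which is why doubling the rank doubles the adapter size.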

Why LoRA over Full Fine-Tuning?

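  • Memory: Full fine-tuning of T5-Small keeps gradients and optimizer state for all ~77M parameters (~2.5GB during training); LoRA r=16 keeps them only for ~2.56M trainable parameters (~400MB)
  • Warm-start safety: Merging the previous adapter's weights (r=8 in v2) into the base preserves learned corrections before a fresh r=16 adapter adds capacity
  • Merging: LoRA weights are merged into the base model once before deployment, so inference has zero extra latency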
  • Configuration: r=16, alpha=32, dropout=0.05, targeting all attention + FFN projections (q, k, v, o, wi_0, wi_1, wo)
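Why merging adds no latency: a LoRA layer computes W·x + (α/r)·B·(A·x), and because this is linear in x, the update can be folded once, ahead of time, into a single matrix W' = W + (α/r)·B·A. The pure-Python sketch below (small random matrices, hypothetical helper names, no PyTorch) checks that both forms give the same output; peft's `merge_and_unload()` performs this folding on the real weights.

```python
import random
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def matvec(m: Matrix, x: Vector) -> Vector:
    return [sum(w * xi for w, xi in zip(row, x)) for row in m]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    cols = list(zip(*b))  # columns of b
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]


def lora_forward(W: Matrix, A: Matrix, B: Matrix, x: Vector, alpha: float, r: int) -> Vector:
    """Unmerged adapter: W x + (alpha / r) * B (A x). Two extra matmuls per call."""
    scale = alpha / r
    return [wx + scale * d for wx, d in zip(matvec(W, x), matvec(B, matvec(A, x)))]


def merge(W: Matrix, A: Matrix, B: Matrix, alpha: float, r: int) -> Matrix:
    """Fold the adapter into the base weight once: W' = W + (alpha / r) * B A."""
    scale = alpha / r
    return [[w + scale * ba for w, ba in zip(w_row, ba_row)]
            for w_row, ba_row in zip(W, matmul(B, A))]


def max_merge_error(d_in: int = 6, d_out: int = 4, r: int = 2,
                    alpha: float = 4.0, seed: int = 0) -> float:
    rng = random.Random(seed)
    rand = lambda rows, cols: [[rng.gauss(0, 1) for _ in range(cols)] for _ in range(rows)]
    W, A, B = rand(d_out, d_in), rand(r, d_in), rand(d_out, r)
    # Real LoRA initialises B to zeros, so a fresh adapter starts as a no-op on the
    # merged base; random B is used here only so the check is non-trivial.
    x = [rng.gauss(0, 1) for _ in range(d_in)]
    merged_out = matvec(merge(W, A, B, alpha, r), x)
    return max(abs(a - b) for a, b in zip(lora_forward(W, A, B, x, alpha, r), merged_out))
```

With the configuration above (r=16, alpha=32) the scale factor is 2.0. The zero-initialised B is also what makes the warm start safe: the freshly applied r=16 adapter initially leaves the merged model's outputs unchanged.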

Why a Combined Multi-Objective Loss?

The system uses (on CPU): L = L_CE + 0.3·L_style + 0.5·L_semantic

On GPU (with the human-pattern classifier available): L = L_CE + 0.3·L_style + 0.5·L_semantic + 0.4·L_human

| Term | Purpose | Weight |
|---|---|---|
| L_CE | Cross-entropy with label smoothing (0.1) | 1.0 |
| L_style | 1 − cos_sim(style(input), style(output)); preserves the writing fingerprint | 0.3 |
| L_semantic | 1 − cos_sim(input_emb, output_emb); preserves meaning | 0.5 |
| L_human | 1 − HumanPatternClassifier(output); penalises AI-like text patterns | 0.4 |

Why a Semantic Faithfulness Gate?

Even a well-trained correction model can occasionally produce outputs that drift semantically from the input, particularly when dyslexic phonetic patterns are ambiguous (e.g. "becaus" could be "because" or "becaused"). Rather than accepting every model output blindly, v3 computes cosine similarity between the source and output using all-MiniLM-L6-v2 sentence embeddings. Outputs below 0.75 similarity are treated as unreliable, and the original input is returned unchanged. This is conservative by design: an uncorrected source that keeps its meaning is preferable to a fluent correction that changes it.

Why a Human Pattern Classifier?

AI-generated text has detectable statistical signatures:

  • Lower GPT-2 perplexity (AI text is more "predictable")
  • Lower burstiness (AI has uniform sentence lengths; humans vary)
  • Higher AI marker density (overuse of "delve", "leverage", "furthermore")
  • Lower n-gram novelty (AI reuses phrases more)

The classifier is a 3-layer MLP (17→128→64→1) pre-trained on ~100k samples from two Kaggle datasets (Shanegerami AI_Human.csv + Starblasters8), then frozen during main training. Its output score (0 = AI, 1 = human) is used as a reward signal. It requires a GPU for GPT-2 perplexity scoring and falls back gracefully on CPU.
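Three of the signals listed above can be approximated with simple text statistics. The definitions below are illustrative assumptions (burstiness as the coefficient of variation of sentence lengths, marker density per 100 words, n-gram novelty as distinct over total bigrams), not the exact formulas behind the 17-dim feature vector; GPT-2 perplexity is omitted because it needs a language model.

```python
import re
import statistics
from typing import List

# Example markers taken from the list above; a real list would be much longer (assumption).
AI_MARKERS = {"delve", "leverage", "furthermore"}


def _sentences(text: str) -> List[str]:
    return [s for s in re.split(r"(?<=[.!?])\s+", text.strip()) if s]


def _words(text: str) -> List[str]:
    return re.findall(r"[a-z']+", text.lower())


def burstiness(text: str) -> float:
    """Coefficient of variation of sentence lengths: higher means more varied (more human-like)."""
    lengths = [len(_words(s)) for s in _sentences(text)]
    if len(lengths) < 2 or statistics.mean(lengths) == 0:
        return 0.0
    return statistics.pstdev(lengths) / statistics.mean(lengths)


def ai_marker_density(text: str) -> float:
    """AI-typical marker words per 100 words."""
    words = _words(text)
    if not words:
        return 0.0
    return 100.0 * sum(w in AI_MARKERS for w in words) / len(words)


def bigram_novelty(text: str) -> float:
    """Distinct bigrams / total bigrams: lower means more phrase reuse."""
    words = _words(text)
    bigrams = list(zip(words, words[1:]))
    if not bigrams:
        return 0.0
    return len(set(bigrams)) / len(bigrams)
```

Two sentences of identical length score a burstiness of 0.0, the uniform pattern the list above associates with AI text.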

Why Sentence-Chunked Inference?

The model is trained with max_input_length=256 tokens. The task prefix alone consumes ~40 tokens, leaving ~216 tokens for actual text. Long inputs are:

  1. Split into sentences using spaCy
  2. Grouped into chunks that fit the 256-token budget
  3. Each chunk is corrected independently
  4. Results are joined back together
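The grouping step can be implemented as a greedy packer. This sketch assumes sentences were already produced by spaCy and that a token counter is supplied (for example `lambda s: len(tokenizer.encode(s, add_special_tokens=False))` with the Flan-T5 tokenizer). The default budget of 216 is 256 minus the ~40-token prefix; summing per-sentence counts only approximates tokenizing the joined chunk. The function name is hypothetical.

```python
from typing import Callable, List


def chunk_sentences(sentences: List[str], count_tokens: Callable[[str], int],
                    budget: int = 216) -> List[str]:
    """Greedily pack consecutive sentences into chunks of at most `budget` tokens.

    A single sentence longer than the budget becomes its own (oversized) chunk.
    """
    chunks: List[str] = []
    current: List[str] = []
    current_len = 0
    for sent in sentences:
        n = count_tokens(sent)
        if current and current_len + n > budget:
            chunks.append(" ".join(current))   # close the chunk before it overflows
            current, current_len = [], 0
        current.append(sent)
        current_len += n
    if current:
        chunks.append(" ".join(current))
    return chunks
```

Each returned chunk is corrected independently and the corrected chunks are joined with spaces. An oversized single sentence must still be truncated or split further, which is the mid-clause case noted under Known Limitations.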

Why Post-Generation Vocabulary Elevation?

T5-Small lacks the capacity to produce academic vocabulary reliably on its own, so instead of relying solely on the model, a separate BERT-based lexical substitution pipeline is applied after generation:

  1. POS-tag the output with spaCy
  2. Identify non-AWL content words (nouns, verbs, adjectives, adverbs)
  3. Mask each candidate → run BERT fill-mask → filter to AWL-only predictions
  4. Accept substitution only if semantic_similarity > 0.82 (measured with all-mpnet-base-v2)
  5. Track used substitutions to prevent duplicate replacements
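The acceptance logic for one candidate position might look like the sketch below. Assumptions: candidates arrive ranked best-first (e.g. the `token_str` fields returned by a transformers `pipeline("fill-mask", model="bert-base-uncased")` call on the sentence with the word replaced by `[MASK]`), the AWL is a set of lowercase word forms, and the 0.82 similarity is computed between the original and substituted sentences (for example with all-mpnet-base-v2 embeddings). `substitute_at` is a hypothetical name.

```python
from typing import Callable, Iterable, List, Set

Similarity = Callable[[str, str], float]


def substitute_at(tokens: List[str], idx: int, candidates: Iterable[str], awl: Set[str],
                  similarity: Similarity, used: Set[str], threshold: float = 0.82) -> List[str]:
    """Return tokens with tokens[idx] replaced by the first acceptable AWL candidate, or unchanged."""
    original_sentence = " ".join(tokens)
    for cand in candidates:
        word = cand.strip().lower()
        if word == tokens[idx].lower():   # fill-mask often predicts the original word
            continue
        if word not in awl:               # step 3: AWL-only predictions
            continue
        if word in used:                  # step 5: no duplicate replacements
            continue
        new_tokens = tokens[:idx] + [word] + tokens[idx + 1:]
        if similarity(original_sentence, " ".join(new_tokens)) > threshold:  # step 4
            used.add(word)
            return new_tokens
    return tokens
```

Passing the same `used` set across all positions in a document enforces step 5 globally rather than per sentence.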

Quick Start

Prerequisites

  • Python ≥ 3.10
  • NVIDIA GPU with โ‰ฅ 4GB VRAM (or CPU, slower)
  • ~10GB disk space for models and datasets

Option A: Self-Improving Upgrade Pipeline (v3)

This pipeline loads the existing Hub adapter, upgrades it, evaluates, and only pushes if it improves.

```bash
git clone https://huggingface.co/morpheuslord/rewrite && cd rewrite
pip install -r requirements.txt

export HF_TOKEN="your-hf-token-with-write-access"
python train_and_upgrade.py
```

The pipeline handles all 10 steps automatically: Load adapter → Warm-start merge → Apply r=16 LoRA → Load data → Train → Evaluate → Gate → Save → Merge → Push

Option B: Manual Step-by-Step (original pipeline)

```bash
# 1. Install dependencies
pip install -r requirements.txt
python -m spacy download en_core_web_sm

# 2. Preprocess datasets (FCE, W&I+LOCNESS, JFLEG → unified JSONL)
python scripts/preprocess_data.py

# 3. Pre-train the human pattern classifier
python scripts/pretrain_human_pattern_classifier.py

# 4. Train the correction model
PYTHONPATH=. python scripts/train.py --config configs/training_config.yaml --use-v2-loss

# 5. Merge LoRA adapter into base model for inference
python -c "
from peft import PeftModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch
model = AutoModelForSeq2SeqLM.from_pretrained('google/flan-t5-small', torch_dtype=torch.bfloat16)
model = PeftModel.from_pretrained(model, 'checkpoints/checkpoint-BEST')
model = model.merge_and_unload()
model.save_pretrained('checkpoints/best_model_merged')
AutoTokenizer.from_pretrained('google/flan-t5-small').save_pretrained('checkpoints/best_model_merged')
"

# 6. Run inference
PYTHONPATH=. python scripts/run_inference.py --text "The studnet recieved alot of informtion."

# 7. Or start the API server
PYTHONPATH=. python -m uvicorn src.api.main:app --host 0.0.0.0 --port 8000
```

Training Pipeline

v3 Upgrade Pipeline (train_and_upgrade.py): 10 Steps

| Step | Action |
|---|---|
| 1 | Load existing LoRA adapter (r=16, v2) from Hub |
| 2 | Merge into base weights (merge_and_unload) as a warm start |
| 3 | Apply fresh LoRA r=16 on merged base |
| 4 | Load JFLEG + W&I+LOCNESS + C4-GEC (optional); augment with DyslexiaSimulator (20% error rate) |
| 5 | Train with combined loss for 10 epochs, early stopping patience=3 |
| 6 | Evaluate on test set: GLEU + BERTScore F1 + (1 − WER) [+ ERRANT F0.5 if installed] |
| 7 | Apply semantic faithfulness gate: revert outputs with cosine sim < 0.75 to source |
| 8 | Compare composite score against baseline_score.json |
| 9 | If improved: merge adapter → save full model |
| 10 | Push adapter (repo root) + merged model (merged/ subfolder) to Hub; update baseline |

v2 Upgrade Pipeline: 10 Steps

| Step | Action |
|---|---|
| 1 | Load existing LoRA adapter (r=8) from Hub |
| 2 | Merge into base weights (merge_and_unload) as a warm start |
| 3 | Apply fresh LoRA r=16 on merged base |
| 4 | Load JFLEG + W&I+LOCNESS; augment with DyslexiaSimulator (20% error rate) |
| 5 | Train with combined loss for 10 epochs, early stopping patience=3 |
| 6 | Evaluate on test set: GLEU + BERTScore F1 + (1 − WER) |
| 7 | Compare composite score against baseline_score.json |
| 8 | If improved: save LoRA adapter |
| 9 | Merge adapter → save full model |
| 10 | Push adapter + merged model to Hub; update baseline |

v1 Original Pipeline (train.sh): 5 Stages

| Stage | Action |
|---|---|
| 1 | Setup & Dependencies |
| 2 | Data Preprocessing (FCE + W&I+LOCNESS + JFLEG → JSONL) |
| 3 | Human Pattern Classifier Pre-Training |
| 4 | Main Model Training (LoRA r=8, 5 epochs, CE only) |
| 5 | Evaluation (GLEU only) |

Hyperparameter Reference

v3 (train_and_upgrade.py)

```python
LORA_R          = 16
LORA_ALPHA      = 32
LORA_DROPOUT    = 0.05
TARGET_MODULES  = ["q", "v", "k", "o", "wi_0", "wi_1", "wo"]

EPOCHS          = 10
BATCH_SIZE      = 2            # per device (CPU); 8 on GPU
GRAD_ACCUM      = 32           # effective batch = 2 × 32 = 64 (CPU)
LR              = 2e-4
WARMUP_RATIO    = 0.10
LABEL_SMOOTHING = 0.1
MAX_INPUT_LEN   = 256          # up from 128 in v2
MAX_TARGET_LEN  = 256

LAMBDA_STYLE    = 0.3
LAMBDA_SEMANTIC = 0.5
LAMBDA_HUMAN    = 0.4          # GPU only

FAITHFULNESS_THRESHOLD = 0.75  # new in v3
```

v2 (train_and_upgrade.py)

```python
LORA_R          = 16
LORA_ALPHA      = 32
LORA_DROPOUT    = 0.05
TARGET_MODULES  = ["q", "v", "k", "o", "wi_0", "wi_1", "wo"]

EPOCHS          = 10
BATCH_SIZE      = 2
GRAD_ACCUM      = 32           # effective batch = 64
LR              = 2e-4
WARMUP_RATIO    = 0.10
LABEL_SMOOTHING = 0.1
MAX_INPUT_LEN   = 128
MAX_TARGET_LEN  = 128

LAMBDA_STYLE    = 0.3
LAMBDA_SEMANTIC = 0.5
LAMBDA_HUMAN    = 0.4          # GPU only
```

v1 (configs/training_config.yaml)

```yaml
lora:
  r: 8
  lora_alpha: 16
  lora_dropout: 0.05
  target_modules: [q, v, k, o, wi_0, wi_1, wo]

training:
  per_device_train_batch_size: 4
  gradient_accumulation_steps: 8  # effective batch = 32
  learning_rate: 3.0e-4
  lr_scheduler_type: cosine
  bf16: true

loss:
  lambda_style: 0.3
  lambda_semantic: 0.5
  lambda_human_pattern: 0.4
```

configs/inference_config.yaml

```yaml
model:
  key: "flan-t5-small"
  checkpoint_path: "checkpoints/best_model_merged"
  use_lora: false

generation:
  num_beams: 5
  length_penalty: 1.2
  repetition_penalty: 1.3
  no_repeat_ngram_size: 3
  max_new_tokens: 256

vocabulary:
  semantic_threshold: 0.82

faithfulness:
  threshold: 0.75
```

Inference Pipeline (8 Steps)

```
Raw Text
  │
  ▼
1. Preprocessing ─────── LanguageTool spell correction + spaCy parsing
  │
  ▼
2. Style Fingerprinting ─ Extract 41 features → MLP → 512-dim vector
  │
  ▼
3. Sentence-Chunked Generation ─ Split into 256-token chunks → Flan-T5 → rejoin
  │
  ▼
4. Faithfulness Gate ──── cosine_sim(source, output) < 0.75 → revert to source  [NEW v3]
  │
  ▼
5. Post-Processing ───── Remove artifacts, replace em dashes, fix spacing
  │
  ▼
6. Vocabulary Elevation ─ BERT fill-mask → AWL filtering → semantic gate (threshold 0.82)
  │
  ▼
7. Register Filtering ── Expand contractions, replace colloquialisms
  │
  ▼
8. Metrics ──────────── Style similarity, AWL coverage, readability scores
  │
  ▼
Corrected Text
```

API Usage

```bash
# Start the server
PYTHONPATH=. python -m uvicorn src.api.main:app --host 0.0.0.0 --port 8000

# Correct text
curl -X POST http://localhost:8000/correct \
  -H "Content-Type: application/json" \
  -d '{"text": "The studnet recieved alot of informtion.", "style_alpha": 0.6}'

# Health check
curl http://localhost:8000/health
```

Interactive docs at http://localhost:8000/docs.
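The same request can be sent from Python using only the standard library. This is a sketch: the request body mirrors the curl example above, the function names are hypothetical, the 120-second timeout is an arbitrary allowance for the LanguageTool warm-up noted under Known Limitations, and the response is returned as a raw dict because its schema (defined in src/api/schemas.py) is not reproduced here.

```python
import json
import urllib.request


def build_correct_request(text: str, style_alpha: float = 0.6,
                          base_url: str = "http://localhost:8000") -> urllib.request.Request:
    """Build a POST /correct request with the same JSON body as the curl example."""
    payload = json.dumps({"text": text, "style_alpha": style_alpha}).encode("utf-8")
    return urllib.request.Request(
        f"{base_url}/correct",
        data=payload,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def correct(text: str, style_alpha: float = 0.6,
            base_url: str = "http://localhost:8000") -> dict:
    """Send the request to a running server and return the decoded JSON response."""
    req = build_correct_request(text, style_alpha, base_url)
    with urllib.request.urlopen(req, timeout=120) as resp:
        return json.loads(resp.read().decode("utf-8"))
```

With the server running, `correct("The studnet recieved alot of informtion.")` returns whatever JSON the endpoint produces.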


Hardware Requirements

| Tier | GPU | LoRA Config | Epochs | Training Time |
|---|---|---|---|---|
| Tested (v1) | RTX 3050 4GB | r=8 | 5 | ~45 min |
| Tested (v2 CPU) | None (HF Space CPU Basic) | r=16 | 10 | ~12–24 hours |
| Tested (v3 CPU) | None (HF Space CPU Basic) | r=16 | 10 | ~12–24 hours |
| Recommended | RTX 3090 24GB | r=16 + human-pattern loss | 10 | ~2–3h |
| Maximum | A100 80GB | Full pipeline with GPT-2 perplexity + ERRANT | 10 | ~12h |

Data Sources

| Dataset | Type | Size | Access |
|---|---|---|---|
| JFLEG (jhu-clsp/jfleg) | Fluency corrections (4 refs each) | ~5k pairs | HF Hub, no registration |
| W&I+LOCNESS (bea2019st/wi_locness) | Learner errors + corrections | ~34k pairs | HF Hub, no registration |
| C4-200M-GEC (cointegrated/c4_200m-gec-filtered) | Synthetic GEC pairs | ~100k pairs (capped) | HF Hub, no registration; falls back silently if unavailable |
| FCE v2.1 | Learner errors + corrections | ~28k pairs | BEA-2019 (registration required) |
| Shanegerami AI_Human.csv | Human vs AI classification | ~50k samples | Kaggle |
| Starblasters8 data.parquet | Human vs AI classification | ~50k samples | Kaggle |
| Coxhead AWL | Academic Word List | 570 families / 549 headwords | Victoria University |

Note: train_and_upgrade.py uses JFLEG + W&I+LOCNESS + C4-GEC (freely accessible via HF Hub). FCE and Kaggle datasets are used in the full manual pipeline only.


Dyslexia Error Simulation

The DyslexiaSimulator generates synthetic training data based on research by Rello et al. (2013, 2017). v2+ uses a 20% per-word error rate (up from 15% in v1).

| Error Type | Frequency | Example |
|---|---|---|
| Phonetic substitution | 35% | "because" → "becaus" |
| Letter transposition | 18% | "the" → "teh" |
| Letter omission | 16% | "important" → "importnt" |
| Letter doubling | 12% | "letter" → "lettter" |
| Letter reversal (b/d, p/q) | 10% | "bad" → "dad" |
| Word boundary errors | 9% | "a lot" → "alot" |

Style Fingerprint Vector

The 512-dimensional style vector captures 41 raw features:

| Group | Features | Count |
|---|---|---|
| Sentence stats | mean, std, skew of sentence lengths | 3 |
| Word stats | mean, std of word lengths | 2 |
| Lexical | type-token ratio, lexical density | 2 |
| Syntactic | passive/active voice ratio, subordinate clause ratio, avg dependency tree depth | 4 |
| Discourse | 20 academic discourse markers (per 100 words) | 20 |
| Register | hedging frequency, formality score, nominalization ratio | 3 |
| Readability | Flesch reading ease, avg syllables per word | 2 |
| Pronouns | first-person ratio, third-person ratio | 2 |
| Other | question ratio, exclamation ratio, AWL coverage | 3 |

The raw features are projected through a 2-layer MLP (41 → 256 → 512) with LayerNorm and GELU activation, then L2-normalised.
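The projection head can be sketched in pure Python. The layer order (Linear → LayerNorm → GELU → Linear → L2 norm), the LayerNorm without learned scale and shift, the random initialisation, and the `StyleProjector` name are all assumptions; fingerprinter.py is presumably a trained PyTorch module. Because the output has unit length, the cosine similarity in L_style reduces to a dot product of two style vectors.

```python
import math
import random
from typing import List


def _linear(x: List[float], W: List[List[float]], b: List[float]) -> List[float]:
    return [sum(wi * xi for wi, xi in zip(row, x)) + bi for row, bi in zip(W, b)]


def _layer_norm(x: List[float], eps: float = 1e-5) -> List[float]:
    mu = sum(x) / len(x)
    var = sum((v - mu) ** 2 for v in x) / len(x)
    return [(v - mu) / math.sqrt(var + eps) for v in x]


def _gelu(v: float) -> float:
    return 0.5 * v * (1.0 + math.erf(v / math.sqrt(2.0)))  # exact (erf-based) GELU


class StyleProjector:
    """41 raw features -> 256 hidden -> 512-dim unit-length style vector (untrained weights)."""

    def __init__(self, d_in: int = 41, d_hidden: int = 256, d_out: int = 512, seed: int = 0):
        rng = random.Random(seed)

        def init(rows: int, cols: int) -> List[List[float]]:
            bound = 1.0 / math.sqrt(cols)
            return [[rng.uniform(-bound, bound) for _ in range(cols)] for _ in range(rows)]

        self.W1, self.b1 = init(d_hidden, d_in), [0.0] * d_hidden
        self.W2, self.b2 = init(d_out, d_hidden), [0.0] * d_out

    def __call__(self, features: List[float]) -> List[float]:
        h = [_gelu(v) for v in _layer_norm(_linear(features, self.W1, self.b1))]
        z = _linear(h, self.W2, self.b2)
        norm = math.sqrt(sum(v * v for v in z)) or 1.0  # guard against an all-zero output
        return [v / norm for v in z]
```

In a trained model the weights would come from fingerprinter.py's checkpoint; only the shapes and the final normalisation are meant to match the description above.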


Known Limitations

  1. Model capacity: Flan-T5-Small (77M params) has limited correction ability compared to larger models. Doubling the LoRA rank (r=8 → r=16) partially addresses this.
  2. Training window: 256-token max input (up from 128 in v1/v2); very long paragraphs may still be split mid-clause.
  3. Vocabulary elevation: BERT fill-mask can suggest semantically inappropriate AWL words; the 0.82 similarity threshold is a trade-off between coverage and accuracy.
  4. Already-correct text: The model is trained on error→correction pairs; feeding it clean text produces unpredictable output.
  5. LanguageTool latency: The first spell-correction call takes ~15–20s due to JVM startup.
  6. Human-pattern loss on CPU: The GPT-2 perplexity-based loss is skipped on CPU for performance. Full loss is only active on GPU.
  7. Faithfulness gate conservatism: The 0.75 cosine similarity threshold occasionally reverts valid but heavily corrected outputs. Reverts are logged; monitor num_fallback in evaluation to tune the threshold.