| """ |
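| End-to-end inference pipeline. |
| Accepts raw text written by a dyslexic author (and, optionally, a master copy |
| of the author's writing used as a style reference) and returns corrected, |
| academic-register text together with quality metrics. |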
| """ |
|
|
| from ..preprocessing.pipeline import PreprocessingPipeline |
| from ..style.fingerprinter import StyleFingerprinter |
| from ..vocabulary.lexical_substitution import LexicalElevator, RegisterFilter |
| from ..model.base_model import load_model_and_tokenizer |
| from ..model.style_conditioner import StyleConditioner, prepend_style_prefix |
| from ..model.generation_utils import generate_correction |
| from .postprocessor import PostProcessor |
| from ..evaluation.style_metrics import StyleEvaluator |
| from ..vocabulary.awl_loader import AWLLoader |
| import torch |
| from typing import Optional |
| from dataclasses import dataclass |
| from loguru import logger |
| import yaml |
|
|
|
|
# Instruction text prepended to every chunk before it is tokenized for the
# model. The individual sentences are space-separated, and the whole prefix
# ends with a trailing space so chunk text can be appended directly after it.
TASK_PREFIX = " ".join(
    [
        "Correct the following text for grammar, spelling, and clarity.",
        "Maintain the author's original tone and writing style.",
        "Elevate vocabulary to academic register.",
        "Do NOT change the meaning or add new information.",
        "Preserve named entities exactly.",
        "Text to correct:",
    ]
) + " "
|
|
|
|
| @dataclass |
| class CorrectionResult: |
| original: str |
| corrected: str |
| preprocessed: str |
| style_similarity: float |
| awl_coverage: float |
| readability: dict |
| changes_summary: str |
|
|
|
|
class AcademicCorrector:
    """Full inference pipeline: preprocess → fingerprint → generate → elevate → filter.

    Every component is built once in ``__init__`` from a nested config dict.
    ``correct`` then runs one piece of text through the whole pipeline and
    returns a ``CorrectionResult``.
    """

    # Default hard cap on tokenizer input length (task prefix + chunk text).
    # Override per instance via ``config["model"]["max_input_tokens"]``.
    DEFAULT_MAX_INPUT_TOKENS = 128

    def __init__(self, config: dict):
        """Build all pipeline components from ``config``.

        Args:
            config: nested configuration dict. The sections read are
                ``model``, ``generation``, ``vocabulary``,
                ``style_conditioner`` and ``fingerprinter``. A missing
                section or key falls back to the defaults used below.
        """
        logger.info("Initialising AcademicCorrector...")

        model_cfg = config.get("model", {})
        gen_cfg = config.get("generation", {})
        vocab_cfg = config.get("vocabulary", {})
        style_cfg = config.get("style_conditioner", {})
        fp_cfg = config.get("fingerprinter", {})

        # One AWL path shared by the fingerprinter, elevator and evaluator.
        awl_path = vocab_cfg.get("awl_path", "data/awl/coxhead_awl.txt")

        # Language model + tokenizer (optionally with a PEFT adapter).
        self.model, self.tokenizer, self.is_seq2seq = self._load_model(model_cfg)
        self.model.eval()
        self.generation_config = gen_cfg
        self.max_input_tokens = int(
            model_cfg.get("max_input_tokens", self.DEFAULT_MAX_INPUT_TOKENS)
        )

        self.preprocessor = PreprocessingPipeline()

        self.fingerprinter = StyleFingerprinter(
            spacy_model=fp_cfg.get("spacy_model", "en_core_web_sm"),
            awl_path=awl_path,
        )

        # Size the conditioner to the model's hidden width unless the config
        # pins it explicitly.
        auto_hidden_dim = self._detect_hidden_dim(self.model)
        logger.info(f"Auto-detected model hidden dim: {auto_hidden_dim}")
        self.conditioner = StyleConditioner(
            style_dim=style_cfg.get("style_dim", 512),
            model_hidden_dim=style_cfg.get("model_hidden_dim", auto_hidden_dim),
            n_prefix_tokens=style_cfg.get("n_prefix_tokens", 10),
        )
        self.conditioner.eval()

        # Lexical elevation is optional: if it cannot be built (e.g. a model
        # download fails), the pipeline runs without it.
        try:
            self.elevator = LexicalElevator(
                awl_path=awl_path,
                spacy_model=vocab_cfg.get("spacy_model", "en_core_web_sm"),
                mlm_model=vocab_cfg.get("mlm_model", "bert-large-uncased"),
                sem_model=vocab_cfg.get("sem_model", "all-mpnet-base-v2"),
            )
        except Exception as e:
            logger.warning(f"Lexical elevator init failed: {e}, elevation disabled")
            self.elevator = None

        self.register_filter = RegisterFilter()
        self.postprocessor = PostProcessor()

        awl = AWLLoader(primary_path=awl_path)
        self.evaluator = StyleEvaluator(self.fingerprinter, awl)

        logger.info("AcademicCorrector initialised successfully")

    @staticmethod
    def _load_model(model_cfg: dict):
        """Return ``(model, tokenizer, is_seq2seq)`` as selected by ``model_cfg``.

        Modes, each falling back to the plain base model on any failure:
          * ``checkpoint_path`` and ``use_lora``: base model plus PEFT adapter.
          * ``checkpoint_path`` only: a full seq2seq checkpoint.
          * neither: the base model named by ``key``.
        """
        model_key = model_cfg.get("key", "flan-t5-small")
        checkpoint = model_cfg.get("checkpoint_path", None)
        use_lora = model_cfg.get("use_lora", False)

        def load_base():
            return load_model_and_tokenizer(model_key, quantize=False, use_lora=False)

        if checkpoint and use_lora:
            try:
                from peft import PeftModel
                logger.info(f"Loading base model '{model_key}' + PEFT adapter from '{checkpoint}'")
                model, tokenizer, is_seq2seq = load_base()
                model = PeftModel.from_pretrained(model, checkpoint)
                logger.info(f"PEFT adapter loaded from {checkpoint}")
                return model, tokenizer, is_seq2seq
            except Exception as e:
                logger.warning(f"PEFT loading failed ({e}), loading base model only")
                return load_base()

        if checkpoint:
            try:
                from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
                model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
                tokenizer = AutoTokenizer.from_pretrained(checkpoint)
                logger.info(f"Loaded full model from checkpoint: {checkpoint}")
                return model, tokenizer, True
            except Exception as e:
                # Report the real cause: this fires for any load error, not
                # only a missing path.
                logger.warning(f"Checkpoint load failed ({e}), loading base model: {model_key}")
                return load_base()

        return load_base()

    @staticmethod
    def _detect_hidden_dim(model, default: int = 512) -> int:
        """Return the model's hidden width (``d_model`` or ``hidden_size``), else ``default``."""
        cfg = model.config
        for attr in ("d_model", "hidden_size"):
            if hasattr(cfg, attr):
                return getattr(cfg, attr)
        return default

    @staticmethod
    def _chunk_sentences(tokenizer, sentences: list, budget: int) -> list:
        """Greedily pack ``sentences`` into space-joined chunks of <= ``budget`` tokens.

        A sentence longer than ``budget`` is first split on whitespace into
        word-level pieces. Without this, the tokenizer's truncation would
        silently drop the tail of the sentence. A single word longer than
        the budget is kept whole, because it cannot be split on a word
        boundary. Token counts come from ``tokenizer.encode(text,
        add_special_tokens=False)``.
        """
        budget = max(int(budget), 1)

        def count(text):
            return len(tokenizer.encode(text, add_special_tokens=False))

        def split_oversized(sentence):
            parts, words, used = [], [], 0
            for word in sentence.split():
                n = count(word)
                if words and used + n > budget:
                    parts.append(" ".join(words))
                    words, used = [], 0
                words.append(word)
                used += n
            if words:
                parts.append(" ".join(words))
            return parts

        chunks, current, used = [], [], 0
        for sentence in sentences:
            n = count(sentence)
            if n <= budget:
                units = [(sentence, n)]
            else:
                units = [(piece, count(piece)) for piece in split_oversized(sentence)]
            for text, size in units:
                if current and used + size > budget:
                    chunks.append(" ".join(current))
                    current, used = [], 0
                current.append(text)
                used += size

        if current:
            chunks.append(" ".join(current))
        return chunks

    def _generate_chunks(self, chunks: list) -> list:
        """Run the model on each chunk (with TASK_PREFIX prepended); return outputs in order."""
        device = next(self.model.parameters()).device
        outputs = []
        # Inference only: guarantee no autograd graph is recorded, even if
        # generate_correction calls the model's forward pass directly.
        with torch.no_grad():
            for i, chunk in enumerate(chunks):
                inputs = self.tokenizer(
                    TASK_PREFIX + chunk,
                    max_length=self.max_input_tokens,
                    truncation=True,
                    return_tensors="pt",
                )
                chunk_output = generate_correction(
                    self.model,
                    self.tokenizer,
                    inputs["input_ids"].to(device),
                    inputs["attention_mask"].to(device),
                    self.generation_config,
                )
                outputs.append(chunk_output)
                logger.debug(f" Chunk {i+1}/{len(chunks)}: {len(chunk.split())} → {len(chunk_output.split())} words")
        return outputs

    def correct(
        self,
        raw_text: str,
        master_copy: Optional[str] = None,
        style_alpha: float = 0.6,
    ) -> "CorrectionResult":
        """Correct ``raw_text`` and return it together with quality metrics.

        Pipeline:
            1. Pre-process (spell correct + parse)
            2. Style fingerprint, blended with ``master_copy`` when given
            3. Chunked generation with the task prefix
            4. Post-process (clean-up + entity restoration)
            5. Academic vocabulary elevation (best effort)
            6. Register filter, then output formatting
            7. Compute quality metrics

        Args:
            raw_text: the text to correct.
            master_copy: optional reference text whose style fingerprint is
                blended with the author's own.
            style_alpha: passed as ``alpha`` to ``blend_vectors``.

        Raises:
            ValueError: if ``max_input_tokens`` leaves no room for input
                text after the task prefix.
        """
        logger.info("Step 1: Preprocessing...")
        doc = self.preprocessor.process(raw_text)

        logger.info("Step 2: Extracting style fingerprint...")
        user_style = self.fingerprinter.extract_vector(doc.corrected_text)
        if master_copy:
            master_style = self.fingerprinter.extract_vector(master_copy)
            target_style = self.fingerprinter.blend_vectors(user_style, master_style, alpha=style_alpha)
        else:
            target_style = user_style
        # NOTE(review): target_style is computed but not yet passed to
        # generation (self.conditioner / prepend_style_prefix are unused), so
        # style_alpha and master_copy currently do not affect the output.

        logger.info("Step 3: Generating correction (chunked)...")
        prefix_tokens = len(self.tokenizer.encode(TASK_PREFIX, add_special_tokens=False))
        # Reserve 2 tokens for special tokens the tokenizer appends.
        budget = self.max_input_tokens - prefix_tokens - 2
        if budget <= 0:
            raise ValueError(
                f"max_input_tokens={self.max_input_tokens} leaves no room for input "
                f"text after the {prefix_tokens}-token task prefix"
            )

        sent_doc = self.fingerprinter.nlp(doc.corrected_text)
        sentences = [sent.text.strip() for sent in sent_doc.sents if sent.text.strip()]
        chunks = self._chunk_sentences(self.tokenizer, sentences, budget)
        logger.info(f" Split into {len(chunks)} chunks from {len(sentences)} sentences")

        generated = " ".join(self._generate_chunks(chunks))

        logger.info("Step 4: Post-processing...")
        generated = self.postprocessor.clean(generated)
        generated = self.postprocessor.restore_entities(
            generated,
            [e.text for e in doc.entities],
            doc.protected_spans,
        )

        logger.info("Step 5: Vocabulary elevation...")
        if self.elevator:
            try:
                generated = self.elevator.elevate(generated, doc.protected_spans)
            except Exception as e:
                # Best effort: keep the un-elevated text rather than fail.
                logger.warning(f"Vocabulary elevation failed: {e}")

        logger.info("Step 6: Register filtering...")
        generated = self.register_filter.apply(generated)
        generated = self.postprocessor.format_output(generated)

        logger.info("Step 7: Computing metrics...")
        style_sim = self.evaluator.style_similarity(raw_text, generated)
        awl_cov = self.evaluator.awl_coverage(generated)

        changes = []
        if doc.original_text != doc.corrected_text:
            changes.append("Spelling/grammar corrections applied")
        if generated != doc.corrected_text:
            changes.append("Text restructured and elevated")
        changes_summary = "; ".join(changes) if changes else "No changes needed"

        return CorrectionResult(
            original=raw_text,
            corrected=generated,
            preprocessed=doc.corrected_text,
            style_similarity=style_sim,
            awl_coverage=awl_cov,
            readability=doc.readability,
            changes_summary=changes_summary,
        )
|
|