"""
End-to-end inference pipeline.

Accepts raw dyslectic text (and optionally a master copy), returns
corrected academic text with metadata.
"""
from ..preprocessing.pipeline import PreprocessingPipeline
from ..style.fingerprinter import StyleFingerprinter
from ..vocabulary.lexical_substitution import LexicalElevator, RegisterFilter
from ..model.base_model import load_model_and_tokenizer
from ..model.style_conditioner import StyleConditioner, prepend_style_prefix
from ..model.generation_utils import generate_correction
from .postprocessor import PostProcessor
from ..evaluation.style_metrics import StyleEvaluator
from ..vocabulary.awl_loader import AWLLoader
import torch
from typing import Optional
from dataclasses import dataclass
from loguru import logger
import yaml

# Instruction prepended to every model input chunk. The model was fine-tuned
# with this exact prefix, so its wording must stay stable.
TASK_PREFIX = (
    "Correct the following text for grammar, spelling, and clarity. "
    "Maintain the author's original tone and writing style. "
    "Elevate vocabulary to academic register. "
    "Do NOT change the meaning or add new information. "
    "Preserve named entities exactly. "
    "Text to correct: "
)


@dataclass
class CorrectionResult:
    """Output of one full correction run."""

    original: str           # raw input text, untouched
    corrected: str          # final corrected/elevated text
    preprocessed: str       # text after spell-correction, before generation
    style_similarity: float # similarity between input and output style vectors
    awl_coverage: float     # fraction of output tokens on the Academic Word List
    readability: dict       # readability metrics computed by the preprocessor
    changes_summary: str    # short human-readable description of what changed


class AcademicCorrector:
    """Full inference pipeline: preprocess → fingerprint → generate → elevate → filter."""

    def __init__(self, config: dict):
        """Build every pipeline stage from ``config``.

        ``config`` is a nested dict (typically parsed from YAML) with
        optional sections ``model``, ``generation``, ``vocabulary``,
        ``style_conditioner`` and ``fingerprinter``; each key falls back
        to defaults suited to a FLAN-T5-small setup.
        """
        logger.info("Initialising AcademicCorrector...")
        model_cfg = config.get("model", {})
        gen_cfg = config.get("generation", {})
        vocab_cfg = config.get("vocabulary", {})
        style_cfg = config.get("style_conditioner", {})

        # 1. Load model and tokenizer (checkpoint/adapter handling in helper)
        self._load_model(model_cfg)
        self.model.eval()
        self.generation_config = gen_cfg

        # 2. Preprocessor (spell correction + parsing + entity protection)
        self.preprocessor = PreprocessingPipeline()

        # 3. Style fingerprinter
        fp_cfg = config.get("fingerprinter", {})
        self.fingerprinter = StyleFingerprinter(
            spacy_model=fp_cfg.get("spacy_model", "en_core_web_sm"),
            awl_path=vocab_cfg.get("awl_path", "data/awl/coxhead_awl.txt"),
        )

        # 4. Style conditioner — auto-detect hidden dim from loaded model.
        # T5-family models expose d_model; most other transformers expose
        # hidden_size; otherwise fall back to T5-Small's 512.
        if hasattr(self.model.config, "d_model"):
            auto_hidden_dim = self.model.config.d_model
        elif hasattr(self.model.config, "hidden_size"):
            auto_hidden_dim = self.model.config.hidden_size
        else:
            auto_hidden_dim = 512  # Safe default for T5-Small
        logger.info(f"Auto-detected model hidden dim: {auto_hidden_dim}")
        self.conditioner = StyleConditioner(
            style_dim=style_cfg.get("style_dim", 512),
            model_hidden_dim=style_cfg.get("model_hidden_dim", auto_hidden_dim),
            n_prefix_tokens=style_cfg.get("n_prefix_tokens", 10),
        )
        self.conditioner.eval()

        # 5. Vocabulary elevator — heavy models; degrade gracefully if any
        # of them fail to load (elevation is then skipped in correct()).
        try:
            self.elevator = LexicalElevator(
                awl_path=vocab_cfg.get("awl_path", "data/awl/coxhead_awl.txt"),
                spacy_model="en_core_web_sm",
                mlm_model=vocab_cfg.get("mlm_model", "bert-large-uncased"),
                sem_model=vocab_cfg.get("sem_model", "all-mpnet-base-v2"),
            )
        except Exception as e:
            logger.warning(f"Lexical elevator init failed: {e}, elevation disabled")
            self.elevator = None

        # 6. Register filter
        self.register_filter = RegisterFilter()

        # 7. Post-processor
        self.postprocessor = PostProcessor()

        # 8. Evaluator
        awl = AWLLoader(primary_path=vocab_cfg.get("awl_path", "data/awl/coxhead_awl.txt"))
        self.evaluator = StyleEvaluator(self.fingerprinter, awl)

        logger.info("AcademicCorrector initialised successfully")

    def _load_model(self, model_cfg: dict) -> None:
        """Populate ``self.model``, ``self.tokenizer`` and ``self.is_seq2seq``.

        Three loading paths, each falling back to the plain base model on
        failure: PEFT adapter checkpoint (``checkpoint_path`` + ``use_lora``),
        full merged checkpoint (``checkpoint_path`` only), or base model.
        """
        model_key = model_cfg.get("key", "flan-t5-small")
        checkpoint = model_cfg.get("checkpoint_path", None)
        use_lora = model_cfg.get("use_lora", False)

        if checkpoint and use_lora:
            # PEFT adapter checkpoint: load base model + apply adapter
            try:
                from peft import PeftModel

                logger.info(f"Loading base model '{model_key}' + PEFT adapter from '{checkpoint}'")
                self._load_base_model(model_key)
                self.model = PeftModel.from_pretrained(self.model, checkpoint)
                logger.info(f"PEFT adapter loaded from {checkpoint}")
            except Exception as e:
                logger.warning(f"PEFT loading failed ({e}), loading base model only")
                self._load_base_model(model_key)
        elif checkpoint:
            # Full model checkpoint (merged weights)
            try:
                from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

                self.model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
                self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)
                self.is_seq2seq = True
                logger.info(f"Loaded full model from checkpoint: {checkpoint}")
            except Exception as e:
                # Surface the actual failure instead of discarding it —
                # "not found" was only one of the possible causes.
                logger.warning(f"Checkpoint not found, loading base model: {model_key} ({e})")
                self._load_base_model(model_key)
        else:
            self._load_base_model(model_key)

    def _load_base_model(self, model_key: str) -> None:
        """Load the un-adapted base model (shared fallback path)."""
        self.model, self.tokenizer, self.is_seq2seq = load_model_and_tokenizer(
            model_key, quantize=False, use_lora=False
        )

    def _chunk_sentences(self, text: str, max_input_tokens: int) -> list:
        """Split ``text`` into sentence groups that each fit in the model's
        input window after accounting for the task prefix and special tokens.

        A single sentence longer than the budget becomes its own chunk and
        is truncated later at tokenization time.
        """
        # Measure how many tokens the task prefix uses
        prefix_tokens = len(self.tokenizer.encode(TASK_PREFIX, add_special_tokens=False))
        # 2 for special tokens; floor the budget so it can never go
        # non-positive (which would make every sentence overflow).
        budget = max(max_input_tokens - prefix_tokens - 2, 16)

        # Split into sentences using spaCy (already loaded for fingerprinting)
        sent_doc = self.fingerprinter.nlp(text)
        sentences = [sent.text.strip() for sent in sent_doc.sents if sent.text.strip()]

        # Group sentences into chunks that fit the token budget
        chunks = []
        current_chunk = []
        current_tokens = 0
        for sent in sentences:
            sent_tokens = len(self.tokenizer.encode(sent, add_special_tokens=False))
            if current_tokens + sent_tokens > budget and current_chunk:
                chunks.append(" ".join(current_chunk))
                current_chunk = [sent]
                current_tokens = sent_tokens
            else:
                current_chunk.append(sent)
                current_tokens += sent_tokens
        if current_chunk:
            chunks.append(" ".join(current_chunk))

        logger.info(f" Split into {len(chunks)} chunks from {len(sentences)} sentences")
        return chunks

    def correct(
        self,
        raw_text: str,
        master_copy: Optional[str] = None,
        style_alpha: float = 0.6,
    ) -> CorrectionResult:
        """
        Full correction pipeline:
        1. Pre-process (spell correct + parse)
        2. Style fingerprint
        3. Generate with style conditioning
        4. Academic vocabulary elevation
        5. Register filter
        6. Compute quality metrics
        """
        # Step 1: Pre-process
        logger.info("Step 1: Preprocessing...")
        doc = self.preprocessor.process(raw_text)

        # Step 2: Style fingerprint
        logger.info("Step 2: Extracting style fingerprint...")
        user_style = self.fingerprinter.extract_vector(doc.corrected_text)
        if master_copy:
            master_style = self.fingerprinter.extract_vector(master_copy)
            target_style = self.fingerprinter.blend_vectors(user_style, master_style, alpha=style_alpha)
        else:
            target_style = user_style
        # NOTE(review): target_style (and self.conditioner) is computed but
        # never fed into generation below — confirm whether style
        # conditioning is intentionally disabled at inference time.

        # Step 3: Generate correction (sentence-chunked).
        # The model was trained on max_input_length=128 tokens, so split the
        # text into sentence groups that fit within that window, process
        # each chunk, then reassemble.
        logger.info("Step 3: Generating correction (chunked)...")
        MAX_INPUT_TOKENS = 128
        chunks = self._chunk_sentences(doc.corrected_text, MAX_INPUT_TOKENS)

        corrected_chunks = []
        device = next(self.model.parameters()).device
        for i, chunk in enumerate(chunks):
            chunk_input = TASK_PREFIX + chunk
            inputs = self.tokenizer(
                chunk_input,
                max_length=MAX_INPUT_TOKENS,
                truncation=True,
                return_tensors="pt",
            )
            input_ids = inputs["input_ids"].to(device)
            attention_mask = inputs["attention_mask"].to(device)
            chunk_output = generate_correction(
                self.model,
                self.tokenizer,
                input_ids,
                attention_mask,
                self.generation_config,
            )
            corrected_chunks.append(chunk_output)
            logger.debug(f" Chunk {i+1}/{len(chunks)}: {len(chunk.split())} → {len(chunk_output.split())} words")
        generated = " ".join(corrected_chunks)

        # Step 4: Post-process (cleanup + restore protected named entities)
        logger.info("Step 4: Post-processing...")
        generated = self.postprocessor.clean(generated)
        generated = self.postprocessor.restore_entities(
            generated,
            [e.text for e in doc.entities],
            doc.protected_spans,
        )

        # Step 5: Vocabulary elevation (best-effort; skipped if the
        # elevator failed to initialise)
        logger.info("Step 5: Vocabulary elevation...")
        if self.elevator:
            try:
                generated = self.elevator.elevate(generated, doc.protected_spans)
            except Exception as e:
                logger.warning(f"Vocabulary elevation failed: {e}")

        # Step 6: Register filter
        logger.info("Step 6: Register filtering...")
        generated = self.register_filter.apply(generated)

        # Final formatting
        generated = self.postprocessor.format_output(generated)

        # Step 7: Compute quality metrics
        logger.info("Step 7: Computing metrics...")
        style_sim = self.evaluator.style_similarity(raw_text, generated)
        awl_cov = self.evaluator.awl_coverage(generated)

        # Build changes summary
        changes = []
        if doc.original_text != doc.corrected_text:
            changes.append("Spelling/grammar corrections applied")
        if generated != doc.corrected_text:
            changes.append("Text restructured and elevated")
        changes_summary = "; ".join(changes) if changes else "No changes needed"

        return CorrectionResult(
            original=raw_text,
            corrected=generated,
            preprocessed=doc.corrected_text,
            style_similarity=style_sim,
            awl_coverage=awl_cov,
            readability=doc.readability,
            changes_summary=changes_summary,
        )