# src/inference/corrector.py
"""
End-to-end inference pipeline.
Accepts raw dyslectic text (and optionally a master copy),
returns corrected academic text with metadata.
"""
from ..preprocessing.pipeline import PreprocessingPipeline
from ..style.fingerprinter import StyleFingerprinter
from ..vocabulary.lexical_substitution import LexicalElevator, RegisterFilter
from ..model.base_model import load_model_and_tokenizer
from ..model.style_conditioner import StyleConditioner, prepend_style_prefix
from ..model.generation_utils import generate_correction
from .postprocessor import PostProcessor
from ..evaluation.style_metrics import StyleEvaluator
from ..vocabulary.awl_loader import AWLLoader
import torch
from typing import Optional
from dataclasses import dataclass
from loguru import logger
import yaml
# Instruction prefix prepended to every input chunk before generation.
# Its token length is measured and subtracted from the 128-token input
# budget in AcademicCorrector.correct(), so editing this text changes the
# effective chunk size.  The "Text to correct: " tail marks where the
# user text begins for the model.
TASK_PREFIX = (
    "Correct the following text for grammar, spelling, and clarity. "
    "Maintain the author's original tone and writing style. "
    "Elevate vocabulary to academic register. "
    "Do NOT change the meaning or add new information. "
    "Preserve named entities exactly. "
    "Text to correct: "
)
@dataclass
class CorrectionResult:
    """Result bundle returned by ``AcademicCorrector.correct()``."""

    original: str            # raw input text, exactly as supplied by the caller
    corrected: str           # final corrected / vocabulary-elevated output text
    preprocessed: str        # spell-corrected text before model generation
    style_similarity: float  # evaluator score between raw input and output
    awl_coverage: float      # Academic Word List coverage score of the output
    readability: dict        # readability stats produced by the preprocessor
    changes_summary: str     # human-readable summary of change categories applied
class AcademicCorrector:
    """Full inference pipeline: preprocess → fingerprint → generate → elevate → filter.

    Wires together every stage of the correction system:

    1. ``PreprocessingPipeline`` — spell correction, entity protection, parsing.
    2. ``StyleFingerprinter`` — stylometric vector extraction (and its spaCy
       pipeline, reused for sentence splitting).
    3. A seq2seq model — base model, full fine-tuned checkpoint, or base
       model + PEFT adapter, depending on configuration.
    4. ``LexicalElevator`` / ``RegisterFilter`` — academic vocabulary pass.
    5. ``PostProcessor`` + ``StyleEvaluator`` — cleanup and quality metrics.
    """

    def __init__(self, config: dict):
        """Build the pipeline from a parsed configuration dict.

        Args:
            config: nested dict with optional sections ``model``,
                ``generation``, ``vocabulary``, ``style_conditioner`` and
                ``fingerprinter``; missing sections fall back to defaults.
        """
        logger.info("Initialising AcademicCorrector...")
        model_cfg = config.get("model", {})
        gen_cfg = config.get("generation", {})
        vocab_cfg = config.get("vocabulary", {})
        style_cfg = config.get("style_conditioner", {})

        # 1. Load model and tokenizer.  Three cases, most specific first:
        #    PEFT adapter checkpoint, full merged checkpoint, plain base model.
        #    Each fallible load degrades best-effort to the base model.
        model_key = model_cfg.get("key", "flan-t5-small")
        checkpoint = model_cfg.get("checkpoint_path", None)
        use_lora = model_cfg.get("use_lora", False)
        if checkpoint and use_lora:
            # PEFT adapter checkpoint: load base model + apply adapter.
            try:
                from peft import PeftModel

                logger.info(f"Loading base model '{model_key}' + PEFT adapter from '{checkpoint}'")
                self.model, self.tokenizer, self.is_seq2seq = load_model_and_tokenizer(
                    model_key, quantize=False, use_lora=False
                )
                self.model = PeftModel.from_pretrained(self.model, checkpoint)
                logger.info(f"PEFT adapter loaded from {checkpoint}")
            except Exception as e:
                logger.warning(f"PEFT loading failed ({e}), loading base model only")
                self.model, self.tokenizer, self.is_seq2seq = load_model_and_tokenizer(
                    model_key, quantize=False, use_lora=False
                )
        elif checkpoint:
            # Full model checkpoint (merged weights).
            try:
                from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

                self.model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
                self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)
                self.is_seq2seq = True
                logger.info(f"Loaded full model from checkpoint: {checkpoint}")
            except Exception:
                logger.warning(f"Checkpoint not found, loading base model: {model_key}")
                self.model, self.tokenizer, self.is_seq2seq = load_model_and_tokenizer(
                    model_key, quantize=False, use_lora=False
                )
        else:
            self.model, self.tokenizer, self.is_seq2seq = load_model_and_tokenizer(
                model_key, quantize=False, use_lora=False
            )
        self.model.eval()
        self.generation_config = gen_cfg

        # 2. Preprocessor
        self.preprocessor = PreprocessingPipeline()

        # 3. Style fingerprinter
        fp_cfg = config.get("fingerprinter", {})
        self.fingerprinter = StyleFingerprinter(
            spacy_model=fp_cfg.get("spacy_model", "en_core_web_sm"),
            awl_path=vocab_cfg.get("awl_path", "data/awl/coxhead_awl.txt"),
        )

        # 4. Style conditioner — auto-detect hidden dim from the loaded model
        #    (T5-family configs expose ``d_model``; BERT-style expose ``hidden_size``).
        if hasattr(self.model.config, "d_model"):
            auto_hidden_dim = self.model.config.d_model
        elif hasattr(self.model.config, "hidden_size"):
            auto_hidden_dim = self.model.config.hidden_size
        else:
            auto_hidden_dim = 512  # Safe default for T5-Small
        logger.info(f"Auto-detected model hidden dim: {auto_hidden_dim}")
        self.conditioner = StyleConditioner(
            style_dim=style_cfg.get("style_dim", 512),
            model_hidden_dim=style_cfg.get("model_hidden_dim", auto_hidden_dim),
            n_prefix_tokens=style_cfg.get("n_prefix_tokens", 10),
        )
        self.conditioner.eval()

        # 5. Vocabulary elevator — optional; its heavy MLM/semantic models may
        #    be unavailable, in which case elevation is disabled, not fatal.
        try:
            self.elevator = LexicalElevator(
                awl_path=vocab_cfg.get("awl_path", "data/awl/coxhead_awl.txt"),
                spacy_model="en_core_web_sm",
                mlm_model=vocab_cfg.get("mlm_model", "bert-large-uncased"),
                sem_model=vocab_cfg.get("sem_model", "all-mpnet-base-v2"),
            )
        except Exception as e:
            logger.warning(f"Lexical elevator init failed: {e}, elevation disabled")
            self.elevator = None

        # 6. Register filter
        self.register_filter = RegisterFilter()
        # 7. Post-processor
        self.postprocessor = PostProcessor()
        # 8. Evaluator
        awl = AWLLoader(primary_path=vocab_cfg.get("awl_path", "data/awl/coxhead_awl.txt"))
        self.evaluator = StyleEvaluator(self.fingerprinter, awl)
        logger.info("AcademicCorrector initialised successfully")

    @staticmethod
    def _chunk_sentences(sentences, budget, count_tokens):
        """Greedily group ``sentences`` into chunks of at most ``budget`` tokens.

        Args:
            sentences: sentence strings in document order.
            budget: maximum token count per chunk.
            count_tokens: callable mapping a sentence string to its token count.

        Returns:
            List of chunk strings (sentences joined by single spaces).  A lone
            sentence longer than ``budget`` still becomes its own chunk — the
            tokenizer truncates it downstream.
        """
        chunks = []
        current_chunk = []
        current_tokens = 0
        for sent in sentences:
            sent_tokens = count_tokens(sent)
            if current_tokens + sent_tokens > budget and current_chunk:
                chunks.append(" ".join(current_chunk))
                current_chunk = [sent]
                current_tokens = sent_tokens
            else:
                current_chunk.append(sent)
                current_tokens += sent_tokens
        if current_chunk:
            chunks.append(" ".join(current_chunk))
        return chunks

    def correct(
        self,
        raw_text: str,
        master_copy: Optional[str] = None,
        style_alpha: float = 0.6,
    ) -> CorrectionResult:
        """
        Full correction pipeline:
        1. Pre-process (spell correct + parse)
        2. Style fingerprint
        3. Generate with style conditioning
        4. Academic vocabulary elevation
        5. Register filter
        6. Compute quality metrics

        Args:
            raw_text: the raw (possibly dyslectic) input text.
            master_copy: optional reference text whose style is blended
                with the user's.
            style_alpha: blend weight passed to the fingerprinter when a
                master copy is provided.

        Returns:
            A populated ``CorrectionResult``.
        """
        # Step 1: Pre-process
        logger.info("Step 1: Preprocessing...")
        doc = self.preprocessor.process(raw_text)

        # Step 2: Style fingerprint
        logger.info("Step 2: Extracting style fingerprint...")
        user_style = self.fingerprinter.extract_vector(doc.corrected_text)
        if master_copy:
            master_style = self.fingerprinter.extract_vector(master_copy)
            target_style = self.fingerprinter.blend_vectors(user_style, master_style, alpha=style_alpha)
        else:
            target_style = user_style
        # NOTE(review): target_style (and self.conditioner from __init__) is
        # computed but never passed into generation below — confirm whether
        # style conditioning is intentionally disabled in this pipeline.

        # Step 3: Generate correction (sentence-chunked).
        # The model was trained on max_input_length=128 tokens, so split the
        # text into sentence groups that fit that window, correct each chunk
        # independently, then reassemble.
        logger.info("Step 3: Generating correction (chunked)...")
        MAX_INPUT_TOKENS = 128
        # Measure how many tokens the task prefix uses.
        prefix_tokens = len(self.tokenizer.encode(TASK_PREFIX, add_special_tokens=False))
        budget = MAX_INPUT_TOKENS - prefix_tokens - 2  # 2 for special tokens

        # Split into sentences using spaCy (already loaded for fingerprinting).
        sent_doc = self.fingerprinter.nlp(doc.corrected_text)
        sentences = [sent.text.strip() for sent in sent_doc.sents if sent.text.strip()]
        chunks = self._chunk_sentences(
            sentences,
            budget,
            lambda s: len(self.tokenizer.encode(s, add_special_tokens=False)),
        )
        logger.info(f" Split into {len(chunks)} chunks from {len(sentences)} sentences")

        corrected_chunks = []
        device = next(self.model.parameters()).device
        for i, chunk in enumerate(chunks):
            chunk_input = TASK_PREFIX + chunk
            inputs = self.tokenizer(
                chunk_input,
                max_length=MAX_INPUT_TOKENS,
                truncation=True,
                return_tensors="pt",
            )
            input_ids = inputs["input_ids"].to(device)
            attention_mask = inputs["attention_mask"].to(device)
            chunk_output = generate_correction(
                self.model,
                self.tokenizer,
                input_ids,
                attention_mask,
                self.generation_config,
            )
            corrected_chunks.append(chunk_output)
            logger.debug(f" Chunk {i+1}/{len(chunks)}: {len(chunk.split())}{len(chunk_output.split())} words")
        generated = " ".join(corrected_chunks)

        # Step 4: Post-process
        logger.info("Step 4: Post-processing...")
        generated = self.postprocessor.clean(generated)
        generated = self.postprocessor.restore_entities(
            generated,
            [e.text for e in doc.entities],
            doc.protected_spans,
        )

        # Step 5: Vocabulary elevation (best-effort; skipped when disabled)
        logger.info("Step 5: Vocabulary elevation...")
        if self.elevator:
            try:
                generated = self.elevator.elevate(generated, doc.protected_spans)
            except Exception as e:
                logger.warning(f"Vocabulary elevation failed: {e}")

        # Step 6: Register filter
        logger.info("Step 6: Register filtering...")
        generated = self.register_filter.apply(generated)

        # Final formatting
        generated = self.postprocessor.format_output(generated)

        # Step 7: Compute quality metrics
        logger.info("Step 7: Computing metrics...")
        style_sim = self.evaluator.style_similarity(raw_text, generated)
        awl_cov = self.evaluator.awl_coverage(generated)

        # Build a human-readable summary of what changed.
        changes = []
        if doc.original_text != doc.corrected_text:
            changes.append("Spelling/grammar corrections applied")
        if generated != doc.corrected_text:
            changes.append("Text restructured and elevated")
        changes_summary = "; ".join(changes) if changes else "No changes needed"

        return CorrectionResult(
            original=raw_text,
            corrected=generated,
            preprocessed=doc.corrected_text,
            style_similarity=style_sim,
            awl_coverage=awl_cov,
            readability=doc.readability,
            changes_summary=changes_summary,
        )