"""
End-to-end inference pipeline.
Accepts raw dyslexic text (and optionally a master copy),
returns corrected academic text with metadata.
"""

from dataclasses import dataclass
from typing import Optional

import yaml
from loguru import logger

from ..evaluation.style_metrics import StyleEvaluator
from ..model.base_model import load_model_and_tokenizer
from ..model.generation_utils import generate_correction
from ..model.style_conditioner import StyleConditioner
from ..preprocessing.pipeline import PreprocessingPipeline
from ..style.fingerprinter import StyleFingerprinter
from ..vocabulary.awl_loader import AWLLoader
from ..vocabulary.lexical_substitution import LexicalElevator, RegisterFilter
from .postprocessor import PostProcessor


TASK_PREFIX = (
    "Correct the following text for grammar, spelling, and clarity. "
    "Maintain the author's original tone and writing style. "
    "Elevate vocabulary to academic register. "
    "Do NOT change the meaning or add new information. "
    "Preserve named entities exactly. "
    "Text to correct: "
)


@dataclass
class CorrectionResult:
    original: str
    corrected: str
    preprocessed: str
    style_similarity: float
    awl_coverage: float
    readability: dict
    changes_summary: str


class AcademicCorrector:
    """Full inference pipeline: preprocess → fingerprint → generate → elevate → filter."""

    def __init__(self, config: dict):
        logger.info("Initialising AcademicCorrector...")

        model_cfg = config.get("model", {})
        gen_cfg = config.get("generation", {})
        vocab_cfg = config.get("vocabulary", {})
        style_cfg = config.get("style_conditioner", {})

        # 1. Load model and tokenizer
        model_key = model_cfg.get("key", "flan-t5-small")
        checkpoint = model_cfg.get("checkpoint_path", None)
        use_lora = model_cfg.get("use_lora", False)

        if checkpoint and use_lora:
            # PEFT adapter checkpoint: load base model + apply adapter
            try:
                from peft import PeftModel
                logger.info(f"Loading base model '{model_key}' + PEFT adapter from '{checkpoint}'")
                self.model, self.tokenizer, self.is_seq2seq = load_model_and_tokenizer(
                    model_key, quantize=False, use_lora=False
                )
                self.model = PeftModel.from_pretrained(self.model, checkpoint)
                logger.info(f"PEFT adapter loaded from {checkpoint}")
            except Exception as e:
                logger.warning(f"PEFT loading failed ({e}), loading base model only")
                self.model, self.tokenizer, self.is_seq2seq = load_model_and_tokenizer(
                    model_key, quantize=False, use_lora=False
                )
        elif checkpoint:
            # Full model checkpoint (merged weights)
            try:
                from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
                self.model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
                self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)
                self.is_seq2seq = True
                logger.info(f"Loaded full model from checkpoint: {checkpoint}")
            except Exception as e:
                logger.warning(f"Failed to load checkpoint '{checkpoint}' ({e}), loading base model: {model_key}")
                self.model, self.tokenizer, self.is_seq2seq = load_model_and_tokenizer(
                    model_key, quantize=False, use_lora=False
                )
        else:
            self.model, self.tokenizer, self.is_seq2seq = load_model_and_tokenizer(
                model_key, quantize=False, use_lora=False
            )

        self.model.eval()
        self.generation_config = gen_cfg

        # 2. Preprocessor
        self.preprocessor = PreprocessingPipeline()

        # 3. Style fingerprinter
        fp_cfg = config.get("fingerprinter", {})
        self.fingerprinter = StyleFingerprinter(
            spacy_model=fp_cfg.get("spacy_model", "en_core_web_sm"),
            awl_path=vocab_cfg.get("awl_path", "data/awl/coxhead_awl.txt"),
        )

        # 4. Style conditioner — auto-detect hidden dim from loaded model
        if hasattr(self.model.config, "d_model"):
            auto_hidden_dim = self.model.config.d_model
        elif hasattr(self.model.config, "hidden_size"):
            auto_hidden_dim = self.model.config.hidden_size
        else:
            auto_hidden_dim = 512  # Safe default for T5-Small
        logger.info(f"Auto-detected model hidden dim: {auto_hidden_dim}")

        self.conditioner = StyleConditioner(
            style_dim=style_cfg.get("style_dim", 512),
            model_hidden_dim=style_cfg.get("model_hidden_dim", auto_hidden_dim),
            n_prefix_tokens=style_cfg.get("n_prefix_tokens", 10),
        )
        self.conditioner.eval()

        # 5. Vocabulary elevator
        try:
            self.elevator = LexicalElevator(
                awl_path=vocab_cfg.get("awl_path", "data/awl/coxhead_awl.txt"),
                spacy_model="en_core_web_sm",
                mlm_model=vocab_cfg.get("mlm_model", "bert-large-uncased"),
                sem_model=vocab_cfg.get("sem_model", "all-mpnet-base-v2"),
            )
        except Exception as e:
            logger.warning(f"Lexical elevator init failed: {e}, elevation disabled")
            self.elevator = None

        # 6. Register filter
        self.register_filter = RegisterFilter()

        # 7. Post-processor
        self.postprocessor = PostProcessor()

        # 8. Evaluator
        awl = AWLLoader(primary_path=vocab_cfg.get("awl_path", "data/awl/coxhead_awl.txt"))
        self.evaluator = StyleEvaluator(self.fingerprinter, awl)

        logger.info("AcademicCorrector initialised successfully")

    def correct(
        self,
        raw_text: str,
        master_copy: Optional[str] = None,
        style_alpha: float = 0.6,
    ) -> CorrectionResult:
        """
        Full correction pipeline:
        1. Pre-process (spell correct + parse)
        2. Style fingerprint
        3. Generate correction (sentence-chunked)
        4. Academic vocabulary elevation
        5. Register filter
        6. Compute quality metrics
        """
        # Step 1: Pre-process
        logger.info("Step 1: Preprocessing...")
        doc = self.preprocessor.process(raw_text)

        # Step 2: Style fingerprint
        logger.info("Step 2: Extracting style fingerprint...")
        user_style = self.fingerprinter.extract_vector(doc.corrected_text)

        if master_copy:
            master_style = self.fingerprinter.extract_vector(master_copy)
            target_style = self.fingerprinter.blend_vectors(user_style, master_style, alpha=style_alpha)
        else:
            target_style = user_style
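
        # NOTE: the chunked generation below conditions on the task prefix
        # only; target_style (and self.conditioner, initialised in __init__)
        # is not applied on this path yet.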

        # Step 3: Generate correction (sentence-chunked)
        # The model was trained on max_input_length=128 tokens.
        # Split text into sentence groups that fit within that window,
        # process each chunk, then reassemble.
        logger.info("Step 3: Generating correction (chunked)...")

        MAX_INPUT_TOKENS = 128
        # Measure how many tokens the task prefix uses
        prefix_tokens = len(self.tokenizer.encode(TASK_PREFIX, add_special_tokens=False))
        budget = MAX_INPUT_TOKENS - prefix_tokens - 2  # 2 for special tokens
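        # Rough arithmetic: with the T5 tokenizer the prefix above encodes to
        # ~45 tokens, leaving roughly 80 tokens of source text per chunk
        # (exact numbers depend on the tokenizer in use).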

        # Split into sentences using spaCy (already loaded for fingerprinting)
        sent_doc = self.fingerprinter.nlp(doc.corrected_text)
        sentences = [sent.text.strip() for sent in sent_doc.sents if sent.text.strip()]

        # Group sentences into chunks that fit the token budget
        chunks = []
        current_chunk = []
        current_tokens = 0

        for sent in sentences:
            sent_tokens = len(self.tokenizer.encode(sent, add_special_tokens=False))
            if current_tokens + sent_tokens > budget and current_chunk:
                chunks.append(" ".join(current_chunk))
                current_chunk = [sent]
                current_tokens = sent_tokens
            else:
                current_chunk.append(sent)
                current_tokens += sent_tokens

        if current_chunk:
            chunks.append(" ".join(current_chunk))
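        # A single sentence longer than the budget becomes its own chunk and
        # is truncated at encode time below (truncation=True).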

        logger.info(f"  Split into {len(chunks)} chunks from {len(sentences)} sentences")

        corrected_chunks = []
        device = next(self.model.parameters()).device

        for i, chunk in enumerate(chunks):
            chunk_input = TASK_PREFIX + chunk
            inputs = self.tokenizer(
                chunk_input,
                max_length=MAX_INPUT_TOKENS,
                truncation=True,
                return_tensors="pt",
            )

            input_ids = inputs["input_ids"].to(device)
            attention_mask = inputs["attention_mask"].to(device)

            chunk_output = generate_correction(
                self.model,
                self.tokenizer,
                input_ids,
                attention_mask,
                self.generation_config,
            )
            corrected_chunks.append(chunk_output)
            logger.debug(f"  Chunk {i+1}/{len(chunks)}: {len(chunk.split())} → {len(chunk_output.split())} words")

        generated = " ".join(corrected_chunks)

        # Step 4: Post-process
        logger.info("Step 4: Post-processing...")
        generated = self.postprocessor.clean(generated)
        generated = self.postprocessor.restore_entities(
            generated,
            [e.text for e in doc.entities],
            doc.protected_spans,
        )

        # Step 5: Vocabulary elevation
        logger.info("Step 5: Vocabulary elevation...")
        if self.elevator:
            try:
                generated = self.elevator.elevate(generated, doc.protected_spans)
            except Exception as e:
                logger.warning(f"Vocabulary elevation failed: {e}")

        # Step 6: Register filter
        logger.info("Step 6: Register filtering...")
        generated = self.register_filter.apply(generated)

        # Final formatting
        generated = self.postprocessor.format_output(generated)

        # Step 7: Compute quality metrics
        logger.info("Step 7: Computing metrics...")
        style_sim = self.evaluator.style_similarity(raw_text, generated)
        awl_cov = self.evaluator.awl_coverage(generated)

        # Build changes summary
        changes = []
        if doc.original_text != doc.corrected_text:
            changes.append("Spelling/grammar corrections applied")
        if generated != doc.corrected_text:
            changes.append("Text restructured and elevated")
        changes_summary = "; ".join(changes) if changes else "No changes needed"

        return CorrectionResult(
            original=raw_text,
            corrected=generated,
            preprocessed=doc.corrected_text,
            style_similarity=style_sim,
            awl_coverage=awl_cov,
            readability=doc.readability,
            changes_summary=changes_summary,
        )
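

# ---------------------------------------------------------------------------
# Minimal usage sketch. The config path below is illustrative -- adjust it to
# the project's actual layout; the pipeline only needs the "model",
# "generation", "vocabulary", "style_conditioner", and "fingerprinter"
# sections read in __init__ above, and every key falls back to a default.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    with open("configs/inference.yaml") as f:  # hypothetical config path
        config = yaml.safe_load(f)

    corrector = AcademicCorrector(config)
    result = corrector.correct(
        "Teh experimant shows that the resalts are signifacant.",
        master_copy=None,     # optionally pass a reference text to blend style
        style_alpha=0.6,      # blend weight used when a master copy is given
    )
    print(result.corrected)
    print(f"Style similarity: {result.style_similarity:.3f} | "
          f"AWL coverage: {result.awl_coverage:.3f} | {result.changes_summary}")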