# src/preprocessing/pipeline.py
"""
Master pre-processing pipeline. Runs all NLP stages in sequence.
Returns a PreprocessedDoc object with all annotations attached.
"""
import spacy
from dataclasses import dataclass, field
from typing import List, Dict, Any, Optional
from .spell_corrector import DyslexiaAwareSpellCorrector
import textstat
from loguru import logger
@dataclass
class EntitySpan:
    """A named entity found by spaCy NER, with character offsets into the corrected text.

    Stored on PreprocessedDoc.entities; the rewriter must never modify these spans.
    """
    text: str  # surface form of the entity as it appears in the corrected text
    label: str  # spaCy entity label (e.g. PERSON, ORG, GPE)
    start_char: int  # start character offset in the corrected text
    end_char: int  # end character offset in the corrected text
@dataclass
class PreprocessedDoc:
    """Container for every annotation produced by PreprocessingPipeline.process()."""
    original_text: str  # raw input text, unmodified
    corrected_text: str  # text after dyslexia-aware spell correction
    sentences: List[str]  # stripped, non-empty sentence texts from the corrected text
    entities: List[EntitySpan]  # Never to be modified by rewriter
    dependency_trees: List[Dict]  # Grammatical skeletons per sentence
    pos_tags: List[List[tuple]]  # (token, POS) per sentence
    readability: Dict[str, float]  # Flesch-Kincaid, Gunning Fog, etc.
    sentence_lengths: List[int]  # word counts, one per entry in `sentences`
    protected_spans: List[tuple]  # (start, end) char spans to never touch
class PreprocessingPipeline:
    """Orchestrates all pre-processing stages: spell correction, parsing, NER, readability."""

    def __init__(self, model_name: str = "en_core_web_trf"):
        """Load the spaCy model and initialise the spell corrector.

        Args:
            model_name: spaCy model to load. If it is not installed, falls
                back to 'en_core_web_sm' (which then must be installed).

        Raises:
            OSError: if neither the requested model nor the fallback is available.
        """
        try:
            self.nlp = spacy.load(model_name)
        except OSError:
            logger.warning(f"spaCy model '{model_name}' not found, falling back to 'en_core_web_sm'")
            self.nlp = spacy.load("en_core_web_sm")
        self.spell_corrector = DyslexiaAwareSpellCorrector()
        logger.info("PreprocessingPipeline initialised")

    def _extract_readability(self, text: str) -> Dict[str, float]:
        """Compute readability scores (Flesch-Kincaid, Gunning Fog, etc.).

        Empty or whitespace-only input short-circuits to all-zero scores so
        that textstat is never called on degenerate text.
        """
        if not text or not text.strip():
            return {
                "flesch_kincaid_grade": 0.0,
                "gunning_fog": 0.0,
                "smog_index": 0.0,
                "automated_readability_index": 0.0,
                "flesch_reading_ease": 0.0,
                "coleman_liau_index": 0.0,
            }
        return {
            "flesch_kincaid_grade": textstat.flesch_kincaid_grade(text),
            "gunning_fog": textstat.gunning_fog(text),
            "smog_index": textstat.smog_index(text),
            "automated_readability_index": textstat.automated_readability_index(text),
            "flesch_reading_ease": textstat.flesch_reading_ease(text),
            "coleman_liau_index": textstat.coleman_liau_index(text),
        }

    def _extract_dep_tree(self, sent) -> Dict:
        """Extract a grammatical skeleton (subjects/verbs/objects/root) for one spaCy sentence span.

        Verbs are de-duplicated (first occurrence wins); subjects and objects
        keep duplicates in document order.
        """
        subjects = []
        verbs = []
        objects = []
        for token in sent:
            if token.dep_ in ("nsubj", "nsubjpass"):
                subjects.append(token.text)
                # Record the verb governing each subject.
                if token.head.pos_ == "VERB":
                    verbs.append(token.head.text)
            elif token.dep_ in ("dobj", "pobj", "attr"):
                objects.append(token.text)
        return {
            "sentence": sent.text,
            "subjects": subjects,
            "verbs": list(dict.fromkeys(verbs)),  # order-preserving de-dup
            "objects": objects,
            "root": sent.root.text if sent.root else "",
        }

    def process(self, raw_text: str) -> PreprocessedDoc:
        """Run the full pre-processing pipeline on raw text.

        7-step pipeline:
            1. Spell correction (phonetic + spellcheck + grammar)
            2. spaCy parsing
            3. Sentence segmentation
            4. Named entity recognition
            5. Dependency tree extraction
            6. POS tagging
            7. Readability scoring

        Empty or whitespace-only input yields an empty PreprocessedDoc
        without touching the NLP models.
        """
        if not raw_text or not raw_text.strip():
            return PreprocessedDoc(
                # Normalise None to "" so both text fields honour their str type
                # (previously only corrected_text was normalised).
                original_text=raw_text or "",
                corrected_text=raw_text or "",
                sentences=[],
                entities=[],
                dependency_trees=[],
                pos_tags=[],
                readability=self._extract_readability(""),
                sentence_lengths=[],
                protected_spans=[],
            )
        # Step 1: Spell correction.
        corrected = self.spell_corrector.correct(raw_text)
        # Step 2: Parse the corrected text with spaCy.
        doc = self.nlp(corrected)
        # Steps 3, 5, 6 share a single pass over the sentence segmentation.
        # Dependency trees and POS tags cover every spaCy sentence; the
        # `sentences` list keeps only non-empty stripped text.
        sentences: List[str] = []
        dependency_trees: List[Dict] = []
        pos_tags: List[List[tuple]] = []
        for sent in doc.sents:
            stripped = sent.text.strip()
            if stripped:
                sentences.append(stripped)
            dependency_trees.append(self._extract_dep_tree(sent))
            pos_tags.append([(token.text, token.pos_) for token in sent])
        # Step 4: NER — entity spans double as regions the rewriter must never touch.
        entities: List[EntitySpan] = []
        protected_spans: List[tuple] = []
        for ent in doc.ents:
            entities.append(EntitySpan(
                text=ent.text,
                label=ent.label_,
                start_char=ent.start_char,
                end_char=ent.end_char,
            ))
            protected_spans.append((ent.start_char, ent.end_char))
        # Step 7: Readability of the corrected text.
        readability = self._extract_readability(corrected)
        return PreprocessedDoc(
            original_text=raw_text,
            corrected_text=corrected,
            sentences=sentences,
            entities=entities,
            dependency_trees=dependency_trees,
            pos_tags=pos_tags,
            readability=readability,
            sentence_lengths=[len(s.split()) for s in sentences],
            protected_spans=protected_spans,
        )