"""
Post-processing utilities for generated text.

Handles cleanup, formatting, and final quality checks.
"""

import re
from typing import List, Tuple

try:
    from loguru import logger
except ImportError:
    # Fall back to stdlib logging so the module stays importable when loguru
    # is not installed; this module only calls ``logger.debug(...)``, which
    # both loggers provide.
    import logging

    logger = logging.getLogger(__name__)


class PostProcessor:
    """Cleans and formats generated text after model output."""

    # Common generation artifacts to remove (matched case-insensitively).
    # Entries that start with a backslash are used as regular expressions
    # verbatim; every other entry is matched as literal text via re.escape().
    # NOTE(review): four empty-string entries (r'') were dropped from this
    # list -- they matched nothing useful and look like special tokens whose
    # text was lost. Restore the intended tokens if they are still needed.
    ARTIFACTS = [
        r'\[PAD\]',
        r'\[CLS\]',
        r'\[SEP\]',
        # Written unescaped on purpose: it does not start with a backslash,
        # so __init__ escapes it and it matches the literal token.
        '<|endoftext|>',
    ]

    # Closing quotes/brackets that may legitimately follow a sentence's final
    # punctuation, e.g. 'He said "hi."'.
    _TRAILING_CLOSERS = '"\')]}»”’'

    def __init__(self):
        # Compile all artifacts into a single case-insensitive alternation.
        # Empty entries are skipped: an empty alternative matches at every
        # position, so the combined pattern would match the empty string.
        self._artifact_pattern = re.compile(
            '|'.join(
                re.escape(a) if not a.startswith('\\') else a
                for a in self.ARTIFACTS
                if a
            ),
            re.IGNORECASE,
        )

    def clean(self, text: str) -> str:
        """Remove generation artifacts and normalise whitespace and punctuation.

        Steps, in order: strip ``ARTIFACTS`` matches; replace em/en dashes
        with commas; collapse every whitespace run (including newlines) into
        one space and trim; tidy spacing around punctuation and parentheses;
        collapse repeated ``.``, ``?`` and ``!`` into a single character.

        Args:
            text: Raw model output. Falsy input yields ``""``.

        Returns:
            The cleaned text.
        """
        if not text:
            return ""

        # Remove generation artifacts
        result = self._artifact_pattern.sub('', text)

        # Replace em dashes and en dashes with commas.
        # NOTE(review): this also turns numeric ranges such as "1–5" into "1,5".
        result = result.replace('—', ',')
        result = result.replace('–', ',')

        # Normalise whitespace
        result = re.sub(r'\s+', ' ', result)
        result = result.strip()

        # Fix common post-generation spacing issues
        result = re.sub(r'\s+([.,!?;:])', r'\1', result)  # Remove space before punctuation
        result = re.sub(r'([.,!?;:])([A-Za-z])', r'\1 \2', result)  # Add space after punctuation
        result = re.sub(r'\(\s+', '(', result)  # Remove space after opening paren
        result = re.sub(r'\s+\)', ')', result)  # Remove space before closing paren

        # Fix multiple punctuation (note: this also collapses "..." to ".")
        result = re.sub(r'\.{2,}', '.', result)
        result = re.sub(r'\?{2,}', '?', result)
        result = re.sub(r'!{2,}', '!', result)

        return result

    def restore_entities(
        self,
        text: str,
        original_entities: List[str],
        protected_spans: List[Tuple[int, int]],
    ) -> str:
        """Restore the original casing of named entities in generated text.

        For each entity that does not already appear verbatim as a whole word
        in the text, the first case-insensitive whole-word occurrence is
        replaced with the entity's original form. Matching is exact apart from
        letter case; no fuzzy matching is performed.

        Args:
            text: Generated text to repair.
            original_entities: Entity strings in their original form. Empty
                strings are ignored.
            protected_spans: Accepted for interface compatibility; not used by
                this implementation.

        Returns:
            The text with entity casing restored.
        """
        if not original_entities:
            return text

        result = text
        for entity in original_entities:
            if not entity:
                continue  # an empty pattern would match between characters

            # Whole-word match. Lookarounds (rather than \b) also behave
            # correctly when the entity starts/ends with punctuation ("C++"),
            # and they stop "Al" from rewriting the "al" inside "also".
            bounded = rf'(?<!\w){re.escape(entity)}(?!\w)'

            # Already present in the correct form -- nothing to restore.
            if re.search(bounded, result):
                continue

            # Use a function as the replacement so the entity is inserted
            # literally; a plain string would be parsed as a re.sub template
            # (backslashes / group references in the entity would break).
            result, replaced = re.subn(
                bounded,
                lambda _match, original=entity: original,
                result,
                count=1,
                flags=re.IGNORECASE,
            )
            if replaced:
                logger.debug(f"Restored entity: {entity}")

        return result

    def format_output(self, text: str) -> str:
        """Apply final formatting (capitalisation, punctuation, spacing).

        Trims the text, capitalises its first character and every lowercase
        letter that follows ``.``, ``!`` or ``?`` plus whitespace, appends a
        period when the text does not end in sentence punctuation (looking
        through trailing closing quotes/brackets), upper-cases the standalone
        pronoun "i", and strips trailing whitespace from each line.

        Args:
            text: Text to format. Falsy input yields ``""``.

        Returns:
            The formatted text.
        """
        if not text:
            return ""

        result = text.strip()

        # Ensure first letter is capitalised
        if result and result[0].islower():
            result = result[0].upper() + result[1:]

        # Ensure text ends with punctuation. Look through trailing closing
        # quotes/brackets so 'He said "hi."' does not become 'He said "hi.".'
        core = result.rstrip(self._TRAILING_CLOSERS)
        if result and (not core or core[-1] not in '.!?'):
            result += '.'

        # Capitalise after sentence-ending punctuation
        result = re.sub(
            r'([.!?]\s+)([a-z])',
            lambda m: m.group(1) + m.group(2).upper(),
            result
        )

        # Fix "i" → "I" when standalone, but not in abbreviations like "i.e."
        result = re.sub(r'\bi\b(?!\.\w)', 'I', result)

        # Remove trailing whitespace from lines
        result = '\n'.join(line.rstrip() for line in result.split('\n'))

        return result