# rewrite/src/inference/postprocessor.py
# morpheuslord's picture
# Add files using upload-large-folder tool
# 12fd5f2 verified
"""
Post-processing utilities for generated text.
Handles cleanup, formatting, and final quality checks.
"""
import re
from typing import List, Tuple
from loguru import logger
class PostProcessor:
    """Cleans and formats generated text after model output.

    The three public methods are independent of each other:

    * ``clean`` strips special-token artifacts and normalises whitespace
      and punctuation spacing.
    * ``restore_entities`` puts named entities back into their original
      casing.
    * ``format_output`` applies final capitalisation and terminal
      punctuation.
    """

    # Common generation artifacts to remove.
    # Each entry is either a plain literal (escaped automatically when the
    # pattern is compiled) or an already-escaped regex fragment, recognised
    # by the presence of a backslash.
    ARTIFACTS = [
        r'<pad>',
        r'</s>',
        r'<s>',
        r'<unk>',
        r'\[PAD\]',
        r'\[CLS\]',
        r'\[SEP\]',
        r'<\|endoftext\|>',
    ]

    # Em dash (U+2014) and en dash (U+2013) are both rewritten to a comma.
    _DASH_TO_COMMA = str.maketrans({'\u2014': ',', '\u2013': ','})

    def __init__(self):
        """Compile the case-insensitive artifact-removal regex from ``ARTIFACTS``."""
        self._artifact_pattern = re.compile(
            '|'.join(self._artifact_regex(a) for a in self.ARTIFACTS),
            re.IGNORECASE,
        )

    @staticmethod
    def _artifact_regex(artifact: str) -> str:
        """Return *artifact* as a regex fragment.

        Entries that contain a backslash are treated as pre-escaped regex
        and used verbatim; anything else is a literal and gets escaped.
        Checking only the *first* character (as was done previously) would
        double-escape entries such as ``<\\|endoftext\\|>`` so that the
        real ``<|endoftext|>`` token was never matched.
        """
        return artifact if '\\' in artifact else re.escape(artifact)

    def clean(self, text: str) -> str:
        """Remove generation artifacts and normalise whitespace.

        Args:
            text: Raw generated text. May be empty or ``None``.

        Returns:
            The cleaned text, or ``""`` when *text* is falsy.
        """
        if not text:
            return ""

        # Remove generation artifacts (special tokens).
        result = self._artifact_pattern.sub('', text)

        # Replace em dashes and en dashes with commas.
        result = result.translate(self._DASH_TO_COMMA)

        # Normalise whitespace: collapse every run (including newlines)
        # into a single space and trim the ends.
        result = re.sub(r'\s+', ' ', result).strip()

        # Fix common post-generation spacing issues.
        result = re.sub(r'\s+([.,!?;:])', r'\1', result)  # no space before punctuation
        result = re.sub(r'([.,!?;:])([A-Za-z])', r'\1 \2', result)  # space after punctuation
        result = re.sub(r'\(\s+', '(', result)  # no space after opening paren
        result = re.sub(r'\s+\)', ')', result)  # no space before closing paren

        # Collapse runs of the same '.', '?' or '!' into a single character
        # (note: this also turns an ellipsis "..." into ".").
        result = re.sub(r'([.?!])\1+', r'\1', result)

        return result

    def restore_entities(
        self,
        text: str,
        original_entities: List[str],
        protected_spans: List[Tuple[int, int]],
    ) -> str:
        """Restore named entities whose casing was altered during generation.

        For every entity that does not already appear verbatim in the text,
        the first case-insensitive occurrence is replaced with the
        original form. Entities with no case-insensitive match are left
        alone.

        Args:
            text: Generated text to repair.
            original_entities: Entities in their original, correct form.
            protected_spans: Accepted for interface compatibility; not used
                by the current implementation.

        Returns:
            The text with matched entities restored to their original form.
        """
        if not original_entities:
            return text

        result = text
        for entity in original_entities:
            # Already present in the correct form: nothing to do.
            if entity in result:
                continue

            # Case-insensitive match, first occurrence only. A callable
            # replacement inserts the entity verbatim; passing the entity
            # as a template string would interpret backslashes in it
            # (e.g. "C:\Users" or "\1") as regex escapes.
            pattern = re.compile(re.escape(entity), re.IGNORECASE)
            restored, replaced = pattern.subn(lambda _m: entity, result, count=1)
            if replaced:
                result = restored
                logger.debug(f"Restored entity: {entity}")

        return result

    def format_output(self, text: str) -> str:
        """Apply final formatting (capitalisation, punctuation, spacing).

        Args:
            text: Text to format. May be empty or ``None``.

        Returns:
            The formatted text, or ``""`` when *text* is falsy or consists
            only of whitespace.
        """
        if not text:
            return ""

        result = text.strip()

        # Ensure first letter is capitalised.
        if result and result[0].islower():
            result = result[0].upper() + result[1:]

        # Ensure text ends with sentence punctuation.
        if result and result[-1] not in '.!?':
            result += '.'

        # Capitalise the first letter after sentence-ending punctuation.
        result = re.sub(
            r'([.!?]\s+)([a-z])',
            lambda m: m.group(1) + m.group(2).upper(),
            result,
        )

        # Fix standalone "i" -> "I".
        result = re.sub(r'\bi\b', 'I', result)

        # Remove trailing whitespace from each line.
        result = '\n'.join(line.rstrip() for line in result.split('\n'))

        return result