| """ |
| Post-processing utilities for generated text. |
| Handles cleanup, formatting, and final quality checks. |
| """ |
|
|
| import re |
| from typing import List, Tuple |
| from loguru import logger |
|
|
|
|
class PostProcessor:
    """Clean and format generated text after model output.

    Offers three independent steps: ``clean`` strips tokenizer artifacts and
    normalises whitespace and punctuation, ``restore_entities`` puts back the
    original casing of known entities, and ``format_output`` applies
    sentence-level capitalisation and terminal punctuation.
    """

    # Special tokens that may leak into decoded text. An entry containing a
    # backslash is taken as regex source; any other entry is literal text.
    ARTIFACTS = [
        r'<pad>',
        r'</s>',
        r'<s>',
        r'<unk>',
        r'\[PAD\]',
        r'\[CLS\]',
        r'\[SEP\]',
        r'<\|endoftext\|>',
    ]

    # Closing quotes/brackets that may follow a sentence's final punctuation.
    _TRAILING_CLOSERS = '"\'”’)]}'

    def __init__(self):
        """Compile ``ARTIFACTS`` into one case-insensitive alternation."""
        # Decide regex-vs-literal by the presence of a backslash anywhere in
        # the entry, not only at its start: r'<\|endoftext\|>' begins with
        # '<' but is regex source, and escaping it again would make it match
        # literal backslashes instead of the real '<|endoftext|>' token.
        self._artifact_pattern = re.compile(
            '|'.join(
                a if '\\' in a else re.escape(a)
                for a in self.ARTIFACTS
            ),
            re.IGNORECASE,
        )

    def clean(self, text: str) -> str:
        """Remove generation artifacts and normalise whitespace.

        Args:
            text: Raw generated text.

        Returns:
            The cleaned text on a single line, or ``""`` when ``text`` is
            empty or ``None``.
        """
        if not text:
            return ""

        # Drop special tokens (matched case-insensitively).
        result = self._artifact_pattern.sub('', text)

        # Em and en dashes become commas; spacing is repaired further down.
        result = result.replace('—', ',')
        result = result.replace('–', ',')

        # Collapse every whitespace run, newlines included, to one space.
        result = re.sub(r'\s+', ' ', result)
        result = result.strip()

        # No space before punctuation, one space after it when a letter
        # follows, and no padding just inside parentheses.
        # NOTE(review): the second rule also splits dotted tokens such as
        # "e.g." or "example.com" ("e. g.", "example. com").
        result = re.sub(r'\s+([.,!?;:])', r'\1', result)
        result = re.sub(r'([.,!?;:])([A-Za-z])', r'\1 \2', result)
        result = re.sub(r'\(\s+', '(', result)
        result = re.sub(r'\s+\)', ')', result)

        # Collapse repeated terminal punctuation; an ellipsis therefore
        # becomes a single period.
        result = re.sub(r'\.{2,}', '.', result)
        result = re.sub(r'\?{2,}', '?', result)
        result = re.sub(r'!{2,}', '!', result)

        return result

    def restore_entities(
        self,
        text: str,
        original_entities: List[str],
        protected_spans: List[Tuple[int, int]],
    ) -> str:
        """Restore the original casing of entities altered during generation.

        For each entity that does not already occur verbatim in the text, the
        first case-insensitive occurrence is replaced with the original form.
        Matching is exact apart from letter case.

        Args:
            text: Generated text to repair.
            original_entities: Entities in their original form.
            protected_spans: Not read by this method; kept so existing
                callers keep working.

        Returns:
            The text with matching entities restored. ``text`` is returned
            unchanged when ``original_entities`` is empty.
        """
        if not original_entities:
            return text

        result = text
        for entity in original_entities:
            # Already present in its exact original form.
            if entity in result:
                continue

            pattern = re.compile(re.escape(entity), re.IGNORECASE)
            # A function replacement inserts the entity verbatim. Passed as a
            # string it would be parsed as a replacement template, so an
            # entity containing a backslash (e.g. "C:\Users") would raise
            # re.error or be rewritten.
            result, replaced = pattern.subn(lambda _match: entity, result, count=1)
            if replaced:
                logger.debug(f"Restored entity: {entity}")

        return result

    def format_output(self, text: str) -> str:
        """Apply final formatting (capitalisation, punctuation, spacing).

        Args:
            text: Text to format.

        Returns:
            The formatted text, or ``""`` when ``text`` is empty or ``None``.
        """
        if not text:
            return ""

        result = text.strip()

        # Capitalise the first character.
        if result and result[0].islower():
            result = result[0].upper() + result[1:]

        # Make sure the text ends with terminal punctuation. Look past
        # trailing closing quotes/brackets so 'She said "stop."' does not
        # become 'She said "stop.".'.
        if result:
            core = result.rstrip(self._TRAILING_CLOSERS)
            if not core or core[-1] not in '.!?':
                result += '.'

        # Capitalise the first letter after each sentence terminator.
        result = re.sub(
            r'([.!?]\s+)([a-z])',
            lambda m: m.group(1) + m.group(2).upper(),
            result
        )

        # A standalone lowercase "i" becomes "I".
        result = re.sub(r'\bi\b', 'I', result)

        # Strip trailing whitespace on every line.
        result = '\n'.join(line.rstrip() for line in result.split('\n'))

        return result
|
|