File size: 3,643 Bytes
12fd5f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
"""
Post-processing utilities for generated text.
Handles cleanup, formatting, and final quality checks.
"""

import re
from typing import List, Tuple
from loguru import logger


class PostProcessor:
    """Cleans and formats generated text after model output.

    ``clean`` strips tokenizer artifacts and normalises spacing and
    punctuation, ``restore_entities`` re-applies the original casing of named
    entities, and ``format_output`` applies final capitalisation and
    sentence punctuation.
    """

    # Common generation artifacts to remove. Entries containing a backslash
    # are treated as regular expressions (they escape their own
    # metacharacters); every other entry is matched literally.
    ARTIFACTS = [
        r'<pad>',
        r'</s>',
        r'<s>',
        r'<unk>',
        r'\[PAD\]',
        r'\[CLS\]',
        r'\[SEP\]',
        r'<\|endoftext\|>',
    ]

    # Closing quotes/brackets that may legitimately follow sentence-ending
    # punctuation, e.g. 'He said "hi."'.
    _TRAILING_CLOSERS = '"\')]}\u2019\u201d'

    def __init__(self):
        # Detect regex entries by *any* backslash, not only a leading one:
        # r'<\|endoftext\|>' must stay a regex so it matches '<|endoftext|>'
        # (re.escape on it would demand literal backslashes in the text).
        parts = (a if '\\' in a else re.escape(a) for a in self.ARTIFACTS)
        # Group each alternative so one entry's regex cannot bleed into the
        # next through the '|' join.
        self._artifact_pattern = re.compile(
            '|'.join(f'(?:{p})' for p in parts),
            re.IGNORECASE,
        )

    def clean(self, text: str) -> str:
        """Remove generation artifacts and normalise whitespace and punctuation.

        Args:
            text: Raw generated text.

        Returns:
            The cleaned text. Every whitespace run (including newlines) is
            collapsed to a single space, so the result is a single line.
            Returns ``""`` for empty input.
        """
        if not text:
            return ""

        # Remove generation artifacts (case-insensitive).
        result = self._artifact_pattern.sub('', text)

        # Replace em dashes and en dashes with commas.
        # NOTE(review): this also turns numeric ranges such as "1–5" into
        # "1,5" — confirm that is intended.
        result = result.replace('—', ',')
        result = result.replace('–', ',')

        # Normalise whitespace
        result = re.sub(r'\s+', ' ', result)
        result = result.strip()

        # Fix common post-generation spacing issues
        result = re.sub(r'\s+([.,!?;:])', r'\1', result)  # Remove space before punctuation
        # NOTE(review): this also splits abbreviations and hostnames
        # ("e.g." -> "e. g.", "example.com" -> "example. com").
        result = re.sub(r'([.,!?;:])([A-Za-z])', r'\1 \2', result)  # Add space after punctuation
        result = re.sub(r'\(\s+', '(', result)  # Remove space after opening paren
        result = re.sub(r'\s+\)', ')', result)  # Remove space before closing paren

        # Collapse repeated terminal punctuation (this also reduces "..." to ".").
        result = re.sub(r'\.{2,}', '.', result)
        result = re.sub(r'\?{2,}', '?', result)
        result = re.sub(r'!{2,}', '!', result)

        return result

    def restore_entities(
        self,
        text: str,
        original_entities: List[str],
        protected_spans: List[Tuple[int, int]],
    ) -> str:
        """Restore the original casing of named entities in generated text.

        For each entity that is not already present verbatim, the first
        case-insensitive occurrence in the text is replaced with the entity's
        original form. Matching is a plain substring match (no word
        boundaries, no edit-distance fuzziness).

        Args:
            text: Generated text to repair.
            original_entities: Entity strings in their original form.
            protected_spans: Currently unused; accepted for interface
                compatibility.

        Returns:
            ``text`` with entity casing restored where a match was found.
        """
        if not original_entities:
            return text

        result = text
        for entity in original_entities:
            # Already present in the correct form (also covers "").
            if entity in result:
                continue

            pattern = re.compile(re.escape(entity), re.IGNORECASE)
            # A callable replacement inserts the entity literally; passing the
            # entity as a template string would interpret backslashes and
            # \g<...> sequences inside it.
            result, n_subs = pattern.subn(lambda _m: entity, result, count=1)
            if n_subs:
                logger.debug(f"Restored entity: {entity}")

        return result

    def format_output(self, text: str) -> str:
        """Apply final formatting (capitalisation, punctuation, spacing).

        Upper-cases a lowercase first character and the first lowercase
        letter after ``.``/``!``/``?`` plus whitespace, appends ``.`` when
        the text does not already end in sentence punctuation (looking past
        trailing closing quotes/brackets), upper-cases a standalone ``i``,
        and strips trailing whitespace from every line.

        Args:
            text: Text to format.

        Returns:
            The formatted text, or ``""`` for empty/whitespace-only input.
        """
        if not text:
            return ""

        result = text.strip()
        if not result:
            return ""

        # Ensure first letter is capitalised
        if result[0].islower():
            result = result[0].upper() + result[1:]

        # Ensure text ends with punctuation. Look past closing quotes and
        # brackets so 'He said "hi."' does not become 'He said "hi.".'.
        core = result.rstrip(self._TRAILING_CLOSERS)
        if not core or core[-1] not in '.!?':
            result += '.'

        # Capitalise after sentence-ending punctuation
        result = re.sub(
            r'([.!?]\s+)([a-z])',
            lambda m: m.group(1) + m.group(2).upper(),
            result
        )

        # Fix "i" → "I" when standalone
        result = re.sub(r'\bi\b', 'I', result)

        # Remove trailing whitespace from lines
        result = '\n'.join(line.rstrip() for line in result.split('\n'))

        return result