File size: 4,778 Bytes
c9e92b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
"""
Tera.VO Text Processing Module
Full text normalization and encoding pipeline built from scratch.
"""

import re
import numpy as np
import inflect

# Single inflect engine created at import time; TextProcessor uses it to
# spell out cardinal numbers and ordinals as English words.
_inflect_engine = inflect.engine()

# === Symbol Set ===
# Special markers used by TextProcessor: padding, end-of-sequence and
# beginning-of-sequence.
_pad = '_'
_eos = '~'
_bos = '^'
_punctuation = '!\'(),.:;? -"'
_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'

# Id order is fixed by list position: pad (0), bos (1), eos (2), then every
# punctuation character, then every letter.
symbols = [_pad, _bos, _eos, *_punctuation, *_letters]
symbol_to_id = {sym: idx for idx, sym in enumerate(symbols)}
id_to_symbol = dict(enumerate(symbols))
NUM_SYMBOLS = len(symbols)


class TextProcessor:
    """Text normalization and symbol-encoding pipeline for Tera.VO.

    Normalization expands known abbreviations, spells out numbers
    (currency, percentages, ordinals, plain numbers), replaces a few
    symbols with words and collapses whitespace. Encoding maps each
    character of the normalized text to its id in the module-level
    symbol table.
    """

    # Patterns are compiled once at class creation instead of on every call.
    # The currency pattern only consumes a '.' when digits follow it, so a
    # sentence-ending period ("It costs $5.") is left in the text.
    _CURRENCY_RE = re.compile(r'\$(\d+(?:\.\d+)?)')
    _PERCENT_RE = re.compile(r'(\d+\.?\d*)%')
    _ORDINAL_RE = re.compile(r'(\d+)(st|nd|rd|th)\b')
    _NUMBER_RE = re.compile(r'\b\d+\.?\d*\b')
    _WHITESPACE_RE = re.compile(r'\s+')

    # Single-character symbols and the (space-padded) words that replace them.
    _SYMBOL_WORDS = str.maketrans({
        '&': ' and ', '@': ' at ', '#': ' hash ',
        '+': ' plus ', '=': ' equals ', '/': ' slash ',
    })

    def __init__(self):
        """Bind the module-level symbol tables and the abbreviation map."""
        self.symbol_to_id = symbol_to_id
        self.id_to_symbol = id_to_symbol
        self.num_symbols = NUM_SYMBOLS

        # Lower-case abbreviation -> expansion; matching is case-insensitive.
        self.abbreviations = {
            'mr.': 'mister', 'mrs.': 'missus', 'dr.': 'doctor',
            'prof.': 'professor', 'sr.': 'senior', 'jr.': 'junior',
            'st.': 'saint', 'vs.': 'versus', 'etc.': 'etcetera',
            'govt.': 'government', 'dept.': 'department',
            'jan.': 'january', 'feb.': 'february', 'mar.': 'march',
            'apr.': 'april', 'aug.': 'august', 'sep.': 'september',
            'oct.': 'october', 'nov.': 'november', 'dec.': 'december',
            'approx.': 'approximately', 'univ.': 'university',
        }

    def normalize_text(self, text):
        """Run the full normalization pipeline on ``text``.

        Steps, in order: strip, expand abbreviations, expand numbers,
        expand symbols, collapse whitespace.

        Args:
            text: Raw input string.

        Returns:
            The normalized string.
        """
        text = text.strip()
        text = self._expand_abbreviations(text)
        text = self._expand_numbers(text)
        text = self._expand_symbols(text)
        text = self._collapse_whitespace(text)
        return text

    def _expand_abbreviations(self, text):
        """Replace each known abbreviation with its expansion.

        Matching is case-insensitive and only happens where the
        abbreviation is not preceded by a word character, so the tail of a
        longer word followed by a period ("best.", "grammar.") is left
        untouched.
        """
        for abbr, full in self.abbreviations.items():
            pattern = r'(?<!\w)' + re.escape(abbr)
            # A callable replacement keeps the expansion literal even if it
            # contains backslashes; the default argument binds it per item.
            text = re.sub(
                pattern,
                lambda _match, word=full: word,
                text,
                flags=re.IGNORECASE,
            )
        return text

    def _expand_numbers(self, text):
        """Spell out numeric expressions as words.

        Order matters: currency amounts and percentages are handled before
        ordinals, and plain numbers last, so the more specific forms are
        consumed first.
        """
        text = self._CURRENCY_RE.sub(
            lambda m: self._currency(m.group(1)), text
        )
        text = self._PERCENT_RE.sub(
            lambda m: self._number_words(m.group(1)) + ' percent', text
        )
        text = self._ORDINAL_RE.sub(
            lambda m: self._ordinal(int(m.group(1))), text
        )
        text = self._NUMBER_RE.sub(
            lambda m: self._number_words(m.group(0)), text
        )
        return text

    def _currency(self, amount_str):
        """Convert a dollar amount such as ``"5"`` or ``"5.50"`` to words.

        Only the first two fractional digits are used as cents (a single
        digit is right-padded, so ``"1.5"`` means fifty cents). The cents
        part is omitted when it is zero, and a trailing ``"."`` with no
        digits is treated as no cents.

        Raises:
            ValueError: If the dollar or cents part is not an integer string.
        """
        parts = amount_str.split('.')
        dollars = int(parts[0])
        fraction = parts[1] if len(parts) > 1 else ''

        result = self._number_words(str(dollars))
        result += ' dollar' + ('s' if dollars != 1 else '')

        # Decide on the value that is actually spoken (two digits), so
        # "1.005" does not produce "and zero cents".
        cents = int(fraction[:2].ljust(2, '0')) if fraction else 0
        if cents > 0:
            result += ' and ' + self._number_words(str(cents))
            result += ' cent' + ('s' if cents != 1 else '')
        return result

    def _number_words(self, num_str):
        """Convert a numeric string to English words via inflect.

        Whole values (including ``"3.0"``) are spoken as integers; other
        decimals are passed to inflect as written. Input that cannot be
        parsed as a number is returned unchanged.
        """
        try:
            # Pure digit strings go straight through int() so long numbers
            # are not rounded by a float round-trip.
            if isinstance(num_str, str) and num_str.isdecimal():
                return _inflect_engine.number_to_words(int(num_str))
            num = float(num_str)
            if num == int(num):
                return _inflect_engine.number_to_words(int(num))
            return _inflect_engine.number_to_words(num_str)
        except (ValueError, TypeError, OverflowError):
            return num_str

    def _ordinal(self, num):
        """Convert an integer to its ordinal words (e.g. 21 -> twenty-first).

        Best-effort: if inflect fails for any reason the number is returned
        as a plain decimal string.
        """
        try:
            return _inflect_engine.ordinal(
                _inflect_engine.number_to_words(num)
            )
        except Exception:
            return str(num)

    def _expand_symbols(self, text):
        """Replace ``& @ # + = /`` with their space-padded word forms."""
        return text.translate(self._SYMBOL_WORDS)

    def _collapse_whitespace(self, text):
        """Collapse whitespace runs to one space and trim both ends."""
        return self._WHITESPACE_RE.sub(' ', text).strip()

    def text_to_sequence(self, text):
        """Normalize ``text`` and encode it as a list of symbol ids.

        The result starts with the beginning-of-sequence id and ends with
        the end-of-sequence id. Characters that have no entry in the symbol
        table are silently skipped.
        """
        text = self.normalize_text(text)
        seq = [self.symbol_to_id[_bos]]
        seq.extend(
            self.symbol_to_id[ch] for ch in text if ch in self.symbol_to_id
        )
        seq.append(self.symbol_to_id[_eos])
        return seq

    def sequence_to_text(self, sequence):
        """Decode a sequence of symbol ids back to text.

        Unknown ids and the pad / beginning / end markers are dropped.
        """
        special = (_pad, _bos, _eos)
        decoded = (
            self.id_to_symbol[idx]
            for idx in sequence
            if idx in self.id_to_symbol
        )
        return ''.join(s for s in decoded if s not in special)

    def pad_sequence(self, seq, max_len):
        """Return ``seq`` padded with the pad id or cut to ``max_len``.

        Truncation keeps the first ``max_len`` ids, so a trailing
        end-of-sequence id is cut off when the sequence is too long.
        """
        if len(seq) >= max_len:
            return seq[:max_len]
        return seq + [self.symbol_to_id[_pad]] * (max_len - len(seq))

    def get_vocab_size(self):
        """Return the number of symbols in the vocabulary."""
        return self.num_symbols


# Module-level TextProcessor instance, constructed once at import time.
text_processor = TextProcessor()