# finbertteacher_v1 / preprocess.py
# aimlresearch2023 - initial commit 89c54bf
"""SMS/text preprocessing for the BERT classifier (stdlib only)."""
from __future__ import annotations
import re
import unicodedata
from typing import Optional
# Character -> replacement registry, populated by the module-level statements
# that follow and compiled into ``_DESLOPIFY_TABLE`` via ``str.maketrans``.
# A value may be longer than one character (e.g. "..." for an ellipsis) or
# empty, in which case ``str.translate`` deletes the character.
_CHAR_REPLACEMENTS: dict[str, str] = {}


def _add(chars: str, replacement: str) -> None:
    """Register *replacement* for every individual character of *chars*.

    A later registration for the same character overwrites an earlier one.
    """
    _CHAR_REPLACEMENTS.update(dict.fromkeys(chars, replacement))
# Typographic dashes, hyphens and the minus sign -> ASCII hyphen-minus.
_add("—–‐‑‒―−", "-")
# Double quotation marks, including curly \u201c / \u201d -> ASCII '"'.
_add("„‟«»\u201c\u201d", '"')
# Single quotation marks and primes, including curly \u2018 / \u2019 -> "'".
_add("‚‛′‵\u2018\u2019", "'")
# Bullet-style glyphs -> "*".
_add("•◦●○▪▫∙⁌⁍⁃", "*")
# No-break and fixed-width typographic spaces -> a plain space.
_add(
    "\u00a0\u2000\u2001\u2002\u2003\u2004\u2005\u2006\u2007\u2008\u2009\u200a"
    "\u202f\u205f\u3000",
    " ",
)
# Zero-width space / non-joiner / joiner have no visible width, so they are
# deleted rather than widened into a space.
_add("\u200b\u200c\u200d", "")
# One-off symbols that expand to a short ASCII spelling.
_CHAR_REPLACEMENTS.update(
    {
        "…": "...",
        "‥": "..",
        "․": ".",
        "×": "*",
        "÷": "/",
        "±": "+/-",
        "≈": "~",
        "≠": "!=",
        "≤": "<=",
        "≥": ">=",
        "∞": "infinite",
        "→": "->",
        "←": "<-",
        "↑": "^",
        "↓": "v",
        "↔": "<->",
        "©": "(c)",
        "®": "(R)",
        "™": "TM",
        "°": "degree",
        "‰": "0/00",
        "‱": "0/000",
        "†": "+",
        "‡": "++",
        "§": "S",
        "¶": "P",
        "‹": "<",
        "›": ">",
    }
)
# Vulgar-fraction glyphs -> "numerator/denominator" (paired positionally).
_CHAR_REPLACEMENTS.update(
    zip(
        "½¼¾⅓⅔⅕⅖⅗⅘⅙⅚⅛⅜⅝⅞⅐⅑⅒",
        (
            "1/2", "1/4", "3/4",
            "1/3", "2/3",
            "1/5", "2/5", "3/5", "4/5",
            "1/6", "5/6",
            "1/8", "3/8", "5/8", "7/8",
            "1/7", "1/9", "1/10",
        ),
    )
)
# Superscript digits/signs/parentheses -> "^" followed by the ASCII character.
_CHAR_REPLACEMENTS.update(
    {sup: "^" + plain for sup, plain in zip("¹²³⁴⁵⁶⁷⁸⁹⁰⁺⁻⁼⁽⁾", "1234567890+-=()")}
)
# Subscript digits/signs/parentheses -> "_" followed by the ASCII character.
_CHAR_REPLACEMENTS.update(
    {sub: "_" + plain for sub, plain in zip("₁₂₃₄₅₆₇₈₉₀₊₋₌₍₎", "1234567890+-=()")}
)
# Halfwidth/Fullwidth Forms U+FF01..U+FF5E are printable ASCII 0x21..0x7E
# shifted by 0xFEE0; fold each one (digits, letters and punctuation) back to
# its ASCII counterpart. The table is computed from code points rather than
# spelled with literal fullwidth glyphs, because those glyphs are easily
# normalized to ASCII by editors/viewers. That turns the entries into
# identity mappings, or into syntax errors for the quote and backslash.
_CHAR_REPLACEMENTS.update(
    {chr(code + 0xFEE0): chr(code) for code in range(0x21, 0x7F)}
)
# Roman numeral characters -> Arabic numerals, for both the uppercase
# (U+2160..U+216F) and lowercase (U+2170..U+217F) series. Offsets 0..11 are
# the numerals one through twelve; offsets 12..15 are fifty, one hundred,
# five hundred and one thousand. Explicit code points are used because the
# L/C/D/M numerals look identical to ASCII letters when written literally.
# The comprehension keeps its loop variables out of the module namespace.
_CHAR_REPLACEMENTS.update(
    {
        chr(base + offset): value
        for base in (0x2160, 0x2170)
        for offset, value in enumerate(
            [str(k) for k in range(1, 13)] + ["50", "100", "500", "1000"]
        )
    }
)
# Mathematical Alphanumeric Symbols (U+1D400..U+1D7FF): styled Latin
# letters and digits (bold, italic, script, fraktur, double-struck,
# sans-serif, monospace). NFKC folds each of them to the plain ASCII
# letter/digit. Greek letters, dotless i/j and reserved code points do not
# fold to a single ASCII alphanumeric and are skipped.
# (U+1D000..U+1D0FF is the Byzantine Musical Symbols block, not letters.)
# NOTE(review): this changes model inputs for styled text; re-validate the
# classifier against it.
_CHAR_REPLACEMENTS.update(
    (styled, folded)
    for styled, folded in (
        (chr(code), unicodedata.normalize("NFKC", chr(code)))
        for code in range(0x1D400, 0x1D800)
    )
    if len(folded) == 1 and folded.isascii() and folded.isalnum()
)
# Freeze the registry into a translation table for ``str.translate``.
_DESLOPIFY_TABLE = str.maketrans(_CHAR_REPLACEMENTS)
# Fixed placeholders substituted for every matched e-mail address / URL.
STATIC_EMAIL = "a@b.com"
STATIC_URL = "www.a.com"
_EMAIL_RE = re.compile(r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}")
# A URL runs up to the next whitespace but may not END on sentence
# punctuation or a closing parenthesis. Otherwise "see www.x.com." would have
# its final "." swallowed by the mask (including the period that
# normalize_terminal_punctuation appends to unpunctuated messages).
_URL_RE = re.compile(
    r"(?:https?://|www\.)[^\s]*[^\s.,;:!?)]",
    flags=re.IGNORECASE,
)
# Separator placed between organisation candidates in the model input.
ORG_CANDIDATE_SEP = " [SEP] "
# ISO 4217 currency codes that are folded to "$".
_CURRENCY_ISO_CODES: tuple[str, ...] = (
    "AED",
    "BDT",
    "CNY",
    "EUR",
    "GBP",
    "INR",
    "JPY",
    "KRW",
    "LKR",
    "NPR",
    "PKR",
    "SAR",
    "USD",
)
# Longest-first alternation so a shorter code can never pre-empt a longer one
# (all codes are currently three letters; the ordering keeps that safe).
_CURRENCY_CODE_ALT = "|".join(
    re.escape(code) for code in sorted(_CURRENCY_ISO_CODES, key=len, reverse=True)
)
# Codes must stand alone on the letter axis: without the lookbehind, word
# endings such as "entrepreneur" / "amateur" / "Caesar" were rewritten to
# "entreprene$" / "amat$" / "Cae$". Digits may still touch the code
# ("inr100", "100INR"), mirroring the rupee/"rs" patterns below.
_CURRENCY_CODE_RE = re.compile(
    rf"(?i)(?<![A-Za-z])(?:{_CURRENCY_CODE_ALT})(?![A-Za-z])"
)
_CURRENCY_RUPEE_WORDS_RE = re.compile(r"(?i)(?<![A-Za-z])(?:rupees|rupee)(?![A-Za-z])")
_CURRENCY_RS_RE = re.compile(r"(?i)(?<![A-Za-z])rs(?![A-Za-z])")
# Symbols only need a trailing guard: "₹500" and "abc₹5" are both rewritten.
_CURRENCY_SYM_RE = re.compile(r"[€£¢¥₹₽₿₩](?![A-Za-z])")
_CURRENCY_DUP_DOLLAR_RE = re.compile(r"\${2,}")


def normalize_currency_markers(text: str) -> str:
    """Map currency codes, rupee words, "Rs" and currency symbols to '$'.

    ISO codes, "rupee(s)" and "rs" are replaced (case-insensitively) only
    when they are not adjacent to an ASCII letter on either side. Symbols are
    replaced when not followed by an ASCII letter. Runs of '$' produced by
    adjacent markers (e.g. "USD$5") are collapsed to a single '$'.
    Falsy input is returned unchanged.
    """
    if not text:
        return text
    text = _CURRENCY_CODE_RE.sub("$", text)
    text = _CURRENCY_RUPEE_WORDS_RE.sub("$", text)
    text = _CURRENCY_RS_RE.sub("$", text)
    text = _CURRENCY_SYM_RE.sub("$", text)
    text = _CURRENCY_DUP_DOLLAR_RE.sub("$", text)
    return text
def deslopify(text: str) -> str:
    """Fold typographic/decorative characters using ``_DESLOPIFY_TABLE``.

    Characters without an entry in the table pass through unchanged.
    Falsy input (``""`` or ``None``) is returned as-is.
    """
    return text.translate(_DESLOPIFY_TABLE) if text else text
def filter_ascii(text: str) -> str:
    """Drop every character outside printable ASCII (0x20-0x7E).

    Tab, newline, form feed and carriage return are kept. Falsy input is
    returned as-is.
    """
    if not text:
        return text
    return re.sub(r"[^\x20-\x7e\t\n\f\r]", "", text)
def normalize_terminal_punctuation(text: str) -> str:
    """Strip surrounding whitespace and make sure the text ends in punctuation.

    A "." is appended when the last character is not in a Unicode
    punctuation category (P*); symbols such as "$" therefore still get one.
    Empty, whitespace-only or ``None`` input yields ``""``.
    """
    body = text.strip() if text else ""
    if body and not unicodedata.category(body[-1]).startswith("P"):
        body += "."
    return body
def digits_to_ones(text: str) -> str:
    """Replace every character for which ``str.isdigit`` is true with "1"."""
    chars = list(text)
    for idx, ch in enumerate(chars):
        if ch.isdigit():
            chars[idx] = "1"
    return "".join(chars)
def mask_emails(text: str) -> str:
    """Replace every e-mail address matched by ``_EMAIL_RE`` with ``STATIC_EMAIL``."""
    return re.sub(_EMAIL_RE, STATIC_EMAIL, text)
def mask_urls(text: str) -> str:
    """Replace every URL matched by ``_URL_RE`` with ``STATIC_URL``."""
    return re.sub(_URL_RE, STATIC_URL, text)
def normalize_freeform_text(text: str) -> str:
    """Mask e-mail addresses, then URLs, then turn every digit into "1"."""
    masked = mask_urls(mask_emails(text))
    return digits_to_ones(masked)
def preprocess_text(raw_text: str) -> str:
    """Run the full cleaning pipeline on a single message.

    Stages, in order: fold typographic characters (``deslopify``), map
    currency markers to "$", drop characters outside printable ASCII, strip
    and ensure terminal punctuation, then mask e-mails/URLs and replace
    digits with "1". Falsy input yields ``""``.
    """
    if not raw_text:
        return ""
    stages = (
        deslopify,
        normalize_currency_markers,
        filter_ascii,
        normalize_terminal_punctuation,
        normalize_freeform_text,
    )
    text = raw_text
    for stage in stages:
        text = stage(text)
    return text
def preprocess_batch(raw_texts: list[str]) -> list[str]:
    """Apply :func:`preprocess_text` to each message, preserving order."""
    return list(map(preprocess_text, raw_texts))
def preprocess_for_model(
    raw_text: str,
    org_candidates: Optional[list[str]] = None,
    *,
    org_sep: str = ORG_CANDIDATE_SEP,
    tokenizer_sep: str = " [SEP] ",
) -> str:
    """Build the model input string for one message.

    The message is cleaned with :func:`preprocess_text`. When
    *org_candidates* is non-empty, the candidates are used verbatim, joined
    by *org_sep*, and prepended to the cleaned message with *tokenizer_sep*
    in between:
    ``"<cand1><org_sep><cand2><tokenizer_sep><cleaned message>"``.
    Otherwise the cleaned message is returned alone.
    """
    cleaned = preprocess_text(raw_text)
    if not org_candidates:
        return cleaned
    prefix = org_sep.join(org_candidates)
    return f"{prefix}{tokenizer_sep}{cleaned}"
if __name__ == "__main__":
    # Manual smoke check of currency-marker handling: print each sample next
    # to its fully preprocessed form. ascii() keeps non-ASCII input readable.
    _samples = (
        "Rs 100",
        "RS.500",
        "inr100",
        "Rs1",
        "Rsuper",
        "INRing",
        "₹500",
        "₩1000",
        "99 Rs",
        "50 rupees",
        "pay 1 rupee",
    )
    for _sample in _samples:
        print(f"{ascii(_sample)} -> {ascii(preprocess_text(_sample))}")