Sowrabhm's picture
Upload folder using huggingface_hub
ca3ccd1 verified
"""Post-processing utilities for transaction extraction.
Ported from the Android Kotlin GLiNER2 ONNX runner. Provides tokenisation,
span decoding, amount parsing, and date normalisation for bank SMS messages.
"""
from __future__ import annotations

import datetime
import re
from typing import Optional

import numpy as np
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# Layout below is indented to mirror the parenthesis nesting of the token
# sequence; the list contents themselves are order-sensitive.
SCHEMA_TOKENS: list[str] = [
    # transaction_type group: lists the DEBIT and CREDIT entries.
    "(", "[P]", "transaction_type",
        "(",
            "[L]", "DEBIT",
            "[L]", "CREDIT",
        ")",
    ")",
    "[SEP_STRUCT]",
    # transaction_info group: lists the four extraction field names.
    "(", "[P]", "transaction_info",
        "(",
            "[C]", "transaction_amount",
            "[C]", "transaction_date",
            "[C]", "transaction_description",
            "[C]", "masked_account_digits",
        ")",
    ")",
    "[SEP_TEXT]",
]
"""Fixed schema token sequence matching the exported ONNX model."""

# The position of a name in this list is the index used to select that
# field's slice of the span-score array in ``decode_spans``.
EXTRACTION_FIELDS: list[str] = [
    "transaction_amount",
    "transaction_date",
    "transaction_description",
    "masked_account_digits",
]
"""Ordered field names for the span-extraction head."""

CLASSIFICATION_LABELS: list[str] = ["DEBIT", "CREDIT"]
"""Labels emitted by the classification head."""
# ---------------------------------------------------------------------------
# Tokenisation
# ---------------------------------------------------------------------------
# Alternatives are tried left to right at each position, so the more specific
# shapes (URLs, e-mails, @-mentions) win over the generic word rule, and any
# other non-space character becomes a one-character token.
_WORD_PATTERN = re.compile(
    r"(?:https?://\S+|www\.\S+)"  # URLs
    r"|[a-z0-9._%+\-]+@[a-z0-9.\-]+\.[a-z]{2,}"  # emails
    r"|@[a-z0-9_]+"  # @-mentions
    r"|\w+(?:[-_]\w+)*"  # words (with hyphens/underscores)
    r"|\S",  # single non-space fallback
    re.IGNORECASE,
)


def split_into_words(text: str) -> list[tuple[str, int, int]]:
    """Tokenise *text* into words together with their character offsets.

    Intended to mirror GLiNER2's ``WhitespaceTokenSplitter``. Whitespace is
    never part of a token; it only separates them.

    Returns:
        One ``(word, char_start, char_end)`` tuple per token, in order of
        appearance, where ``text[char_start:char_end] == word``. An empty
        list when *text* contains no non-space characters.
    """
    return [(hit[0], *hit.span()) for hit in _WORD_PATTERN.finditer(text)]
# ---------------------------------------------------------------------------
# Amount parsing
# ---------------------------------------------------------------------------
_CURRENCY_PATTERN = re.compile(r"(?:Rs\.?|INR|₹)\s*", re.IGNORECASE)
# A number must *start* with a digit. Allowing a leading comma would let a
# stray "," earlier in the text (e.g. "Hi, Rs 500") be taken as the first
# "number", which then fails to convert and hides the real amount.
_NUMBER_PATTERN = re.compile(r"\d[\d,]*(?:\.\d+)?")


def parse_amount(raw: str) -> float | None:
    """Strip currency markers and return the first numeric value in *raw*.

    Handles ``Rs.``, ``Rs``, ``INR`` and the rupee sign (case-insensitive).
    Thousands/lakh separators (commas) inside the number are removed before
    conversion, so ``"Rs.1,23,456.50"`` yields ``123456.5``.

    Args:
        raw: Text containing an amount, typically a span extracted from an
            SMS.

    Returns:
        The amount as a float, or *None* when *raw* contains no digits.
    """
    cleaned = _CURRENCY_PATTERN.sub("", raw)
    match = _NUMBER_PATTERN.search(cleaned)
    if match is None:
        return None
    try:
        return float(match.group().replace(",", ""))
    except ValueError:
        # Defensive only: the pattern should always produce a convertible
        # string, but keep the original "never raises" contract.
        return None
# ---------------------------------------------------------------------------
# Date normalisation
# ---------------------------------------------------------------------------
# Lower-cased month names and abbreviations -> month number.
_MONTH_MAP: dict[str, int] = {
    "jan": 1, "january": 1,
    "feb": 2, "february": 2,
    "mar": 3, "march": 3,
    "apr": 4, "april": 4,
    "may": 5,
    "jun": 6, "june": 6,
    "jul": 7, "july": 7,
    "aug": 8, "august": 8,
    "sep": 9, "sept": 9, "september": 9,
    "oct": 10, "october": 10,
    "nov": 11, "november": 11,
    "dec": 12, "december": 12,
}

# Patterns ordered from most specific to least specific. Each captures
# (day, month, year) in groups 1-3.
#
# Every pattern starts with ``(?<!\d)`` so a match cannot begin in the middle
# of a longer digit run. Without it, an ISO-style "2025-12-23" would match the
# DD-MM-YY pattern as "25-12-23" and be reported as 25-12-2023.
_DATE_PATTERNS: list[re.Pattern[str]] = [
    # DD-MM-YYYY or DD/MM/YYYY
    re.compile(r"(?<!\d)(\d{1,2})[/\-](\d{1,2})[/\-](\d{4})"),
    # DD-Mon-YYYY or DD/Mon/YYYY
    re.compile(r"(?<!\d)(\d{1,2})[/\-]([A-Za-z]+)[/\-](\d{4})"),
    # DD-MM-YY or DD/MM/YY
    re.compile(r"(?<!\d)(\d{1,2})[/\-](\d{1,2})[/\-](\d{2})(?!\d)"),
    # DD-Mon-YY or DD/Mon/YY
    re.compile(r"(?<!\d)(\d{1,2})[/\-]([A-Za-z]+)[/\-](\d{2})(?!\d)"),
    # DDMonYYYY (e.g. 23Dec2025)
    re.compile(r"(?<!\d)(\d{1,2})([A-Za-z]+)(\d{4})"),
]


def _century_window(yy: int) -> int:
    """Expand a two-digit year: YY > 50 -> 19YY, else 20YY.

    Note that ``normalize_date`` only accepts years 2000-2100, so a 19YY
    result is subsequently rejected there.
    """
    return 1900 + yy if yy > 50 else 2000 + yy


def _parse_month(token: str) -> int | None:
    """Return 1-12 for a numeric or named month, or *None* if unrecognised."""
    if token.isdigit():
        val = int(token)
        return val if 1 <= val <= 12 else None
    return _MONTH_MAP.get(token.lower())


def normalize_date(raw: str, received_date: str | None = None) -> str | None:
    """Parse a date in various Indian SMS formats and return ``DD-MM-YYYY``.

    Supported input formats:
        DD-MM-YYYY, DD/MM/YYYY, DD-MM-YY, DD/MM/YY,
        DD-Mon-YYYY, DD-Mon-YY, DDMonYYYY.

    Patterns are tried from most to least specific; within a pattern every
    occurrence in *raw* is considered, so an invalid earlier candidate does
    not hide a later valid one. A candidate is accepted only if its year is
    in 2000-2100 and day/month form a real calendar date.

    Args:
        raw: Text that may contain a date.
        received_date: Fallback value, expected to already be DD-MM-YYYY.
            It is returned as-is (not validated) when *raw* cannot be parsed.

    Returns:
        The normalised date string, else *received_date*, which is *None*
        when no fallback was supplied.
    """
    for pattern in _DATE_PATTERNS:
        for m in pattern.finditer(raw):
            month = _parse_month(m.group(2))
            if month is None:
                continue

            year = int(m.group(3))
            if year < 100:
                year = _century_window(year)
            if not 2000 <= year <= 2100:
                continue

            day = int(m.group(1))
            try:
                # Rejects day 0, day > 31 and impossible dates (e.g. 31 Feb).
                datetime.date(year, month, day)
            except ValueError:
                continue

            return f"{day:02d}-{month:02d}-{year}"

    # Nothing parseable in *raw*: fall back to the caller-supplied date.
    return received_date
# ---------------------------------------------------------------------------
# Span decoding
# ---------------------------------------------------------------------------
def decode_spans(
    span_scores: np.ndarray,
    text: str,
    words: list[str],
    word_spans: list[tuple[int, int]],
    threshold: float = 0.3,
) -> dict[str, Optional[tuple[str, float]]]:
    """Pick the single best-scoring span for every extraction field.

    For each field, all ``(start_word, width)`` candidates that stay inside
    the sentence are scanned and the highest score strictly above
    *threshold* wins; on ties the earliest candidate in scan order is kept.

    Parameters
    ----------
    span_scores:
        Array of shape ``[4, num_words, max_width]`` — one slice per
        entry of ``EXTRACTION_FIELDS``. Entry ``[f, s, w]`` scores the span
        covering words ``s`` through ``s + w`` inclusive.
    text:
        The original SMS text.
    words:
        Tokenised words (from :func:`split_into_words`); only the count
        is used here.
    word_spans:
        ``(char_start, char_end)`` pairs for each word, used to slice the
        winning span out of *text*.
    threshold:
        Minimum confidence to accept a span.

    Returns
    -------
    dict
        Mapping of field name to ``(extracted_text, confidence)`` or
        *None* when no span exceeds *threshold*.
    """
    word_count = len(words)
    decoded: dict[str, Optional[tuple[str, float]]] = {}

    for idx, name in enumerate(EXTRACTION_FIELDS):
        scores = span_scores[idx]  # [num_words, max_width]

        top_score = 0.0
        winner: tuple[int, int] | None = None
        # The score array may have more rows than there are words, so
        # only rows for real words are scanned.
        for first in range(min(word_count, scores.shape[0])):
            # Only widths whose last word is still inside the sentence.
            usable_widths = min(scores.shape[1], word_count - first)
            for width in range(usable_widths):
                candidate = float(scores[first, width])
                if candidate > threshold and candidate > top_score:
                    top_score = candidate
                    winner = (first, first + width)

        if winner is None:
            decoded[name] = None
            continue

        first, last = winner
        # Span runs from the first word's start to the last word's end.
        begin = word_spans[first][0]
        finish = word_spans[last][1]
        decoded[name] = (text[begin:finish], top_score)

    return decoded