"""Post-processing utilities for transaction extraction. Ported from the Android Kotlin GLiNER2 ONNX runner. Provides tokenisation, span decoding, amount parsing, and date normalisation for bank SMS messages. """ from __future__ import annotations import re from typing import Optional import numpy as np # --------------------------------------------------------------------------- # Constants # --------------------------------------------------------------------------- SCHEMA_TOKENS: list[str] = [ "(", "[P]", "transaction_type", "(", "[L]", "DEBIT", "[L]", "CREDIT", ")", ")", "[SEP_STRUCT]", "(", "[P]", "transaction_info", "(", "[C]", "transaction_amount", "[C]", "transaction_date", "[C]", "transaction_description", "[C]", "masked_account_digits", ")", ")", "[SEP_TEXT]", ] """Fixed schema token sequence matching the exported ONNX model.""" EXTRACTION_FIELDS: list[str] = [ "transaction_amount", "transaction_date", "transaction_description", "masked_account_digits", ] """Ordered field names for the span-extraction head.""" CLASSIFICATION_LABELS: list[str] = ["DEBIT", "CREDIT"] """Labels emitted by the classification head.""" # --------------------------------------------------------------------------- # Tokenisation # --------------------------------------------------------------------------- _WORD_PATTERN = re.compile( r"(?:https?://\S+|www\.\S+)" # URLs r"|[a-z0-9._%+\-]+@[a-z0-9.\-]+\.[a-z]{2,}" # emails r"|@[a-z0-9_]+" # @-mentions r"|\w+(?:[-_]\w+)*" # words (with hyphens/underscores) r"|\S", # single non-space fallback re.IGNORECASE, ) def split_into_words(text: str) -> list[tuple[str, int, int]]: """Whitespace-aware tokeniser matching GLiNER2's WhitespaceTokenSplitter. Returns a list of *(word, char_start, char_end)* tuples. """ return [(m.group(), m.start(), m.end()) for m in _WORD_PATTERN.finditer(text)] # --------------------------------------------------------------------------- # Amount parsing # --------------------------------------------------------------------------- _CURRENCY_PATTERN = re.compile(r"(?:Rs\.?|INR|₹)\s*", re.IGNORECASE) _NUMBER_PATTERN = re.compile(r"[\d,]+(?:\.\d+)?") def parse_amount(raw: str) -> float | None: """Strip currency symbols and extract the first numeric value. Handles Rs., Rs, INR, and the rupee sign. Commas are removed before conversion. Returns *None* when no number can be found. """ cleaned = _CURRENCY_PATTERN.sub("", raw).strip() match = _NUMBER_PATTERN.search(cleaned) if not match: return None try: return float(match.group().replace(",", "")) except ValueError: return None # --------------------------------------------------------------------------- # Date normalisation # --------------------------------------------------------------------------- _MONTH_MAP: dict[str, int] = { "jan": 1, "january": 1, "feb": 2, "february": 2, "mar": 3, "march": 3, "apr": 4, "april": 4, "may": 5, "jun": 6, "june": 6, "jul": 7, "july": 7, "aug": 8, "august": 8, "sep": 9, "september": 9, "oct": 10, "october": 10, "nov": 11, "november": 11, "dec": 12, "december": 12, } # Patterns ordered from most specific to least specific. _DATE_PATTERNS: list[re.Pattern[str]] = [ # DD-MM-YYYY or DD/MM/YYYY re.compile(r"(\d{1,2})[/\-](\d{1,2})[/\-](\d{4})"), # DD-Mon-YYYY or DD/Mon/YYYY re.compile( r"(\d{1,2})[/\-]([A-Za-z]+)[/\-](\d{4})" ), # DD-MM-YY or DD/MM/YY re.compile(r"(\d{1,2})[/\-](\d{1,2})[/\-](\d{2})(?!\d)"), # DD-Mon-YY or DD/Mon/YY re.compile( r"(\d{1,2})[/\-]([A-Za-z]+)[/\-](\d{2})(?!\d)" ), # DDMonYYYY (e.g. 23Dec2025) re.compile(r"(\d{1,2})([A-Za-z]+)(\d{4})"), ] def _century_window(yy: int) -> int: """Apply century windowing: YY > 50 -> 19YY, else 20YY.""" return 1900 + yy if yy > 50 else 2000 + yy def _parse_month(token: str) -> int | None: """Return 1-12 for a numeric or named month, or *None*.""" if token.isdigit(): val = int(token) return val if 1 <= val <= 12 else None return _MONTH_MAP.get(token.lower()) def normalize_date(raw: str, received_date: str | None = None) -> str | None: """Parse a date string in various Indian SMS formats and return DD-MM-YYYY. Supported input formats: DD-MM-YYYY, DD/MM/YYYY, DD-MM-YY, DD/MM/YY, DD-Mon-YYYY, DD-Mon-YY, DDMonYYYY. Falls back to *received_date* (which must already be DD-MM-YYYY) when *raw* cannot be parsed. Returns *None* if nothing works. """ for pattern in _DATE_PATTERNS: m = pattern.search(raw) if not m: continue day_s, month_s, year_s = m.group(1), m.group(2), m.group(3) day = int(day_s) month = _parse_month(month_s) if month is None: continue year = int(year_s) if year < 100: year = _century_window(year) # Basic validation if not (2000 <= year <= 2100): continue if not (1 <= month <= 12): continue if not (1 <= day <= 31): continue return f"{day:02d}-{month:02d}-{year}" # Fallback if received_date is not None: return received_date return None # --------------------------------------------------------------------------- # Span decoding # --------------------------------------------------------------------------- def decode_spans( span_scores: np.ndarray, text: str, words: list[str], word_spans: list[tuple[int, int]], threshold: float = 0.3, ) -> dict[str, Optional[tuple[str, float]]]: """Decode the span-extraction head output into field values. Parameters ---------- span_scores: Array of shape ``[4, num_words, max_width]`` — one slice per extraction field. text: The original SMS text. words: Tokenised words (from :func:`split_into_words`). word_spans: ``(char_start, char_end)`` pairs for each word. threshold: Minimum confidence to accept a span. Returns ------- dict Mapping of field name to ``(extracted_text, confidence)`` or *None* when no span exceeds *threshold*. """ num_words = len(words) result: dict[str, Optional[tuple[str, float]]] = {} for field_idx, field_name in enumerate(EXTRACTION_FIELDS): field_scores = span_scores[field_idx] # [num_words, max_width] best_score = 0.0 best_span: tuple[int, int, float] | None = None for start in range(min(num_words, field_scores.shape[0])): for width in range(field_scores.shape[1]): end = start + width if end >= num_words: break score = float(field_scores[start, width]) if score > best_score and score > threshold: best_score = score best_span = (start, end, score) if best_span is not None: s, e, conf = best_span char_start = word_spans[s][0] char_end = word_spans[e][1] result[field_name] = (text[char_start:char_end], conf) else: result[field_name] = None return result