Spaces:
Running
Running
| """Post-processing utilities for transaction extraction. | |
| Ported from the Android Kotlin GLiNER2 ONNX runner. Provides tokenisation, | |
| span decoding, amount parsing, and date normalisation for bank SMS messages. | |
| """ | |
| from __future__ import annotations | |
| import re | |
| from typing import Optional | |
| import numpy as np | |
| # --------------------------------------------------------------------------- | |
| # Constants | |
| # --------------------------------------------------------------------------- | |
# Token order and duplicates are significant: the docstring below states the
# sequence matches the exported ONNX model, so do not edit independently of it.
SCHEMA_TOKENS: list[str] = [
    "(", "[P]", "transaction_type",
    "(", "[L]", "DEBIT", "[L]", "CREDIT", ")", ")",
    "[SEP_STRUCT]",
    "(", "[P]", "transaction_info",
    "(", "[C]", "transaction_amount",
    "[C]", "transaction_date",
    "[C]", "transaction_description",
    "[C]", "masked_account_digits", ")", ")",
    "[SEP_TEXT]",
]
"""Fixed schema token sequence matching the exported ONNX model."""
# Position in this list is the field axis of the span head's output:
# decode_spans indexes span_scores[field_idx] by enumeration order here.
EXTRACTION_FIELDS: list[str] = [
    "transaction_amount",
    "transaction_date",
    "transaction_description",
    "masked_account_digits",
]
"""Ordered field names for the span-extraction head."""
# NOTE(review): presumably index 0/1 matches the classification head's
# output column order — confirm against the exported model.
CLASSIFICATION_LABELS: list[str] = ["DEBIT", "CREDIT"]
"""Labels emitted by the classification head."""
| # --------------------------------------------------------------------------- | |
| # Tokenisation | |
| # --------------------------------------------------------------------------- | |
# Matches one token at a time. Alternatives are ordered most- to
# least-specific: URLs and emails stay whole, then @-handles, then
# hyphen/underscore-joined words, with any single non-space char as fallback.
_WORD_PATTERN = re.compile(
    r"(?:https?://\S+|www\.\S+)"
    r"|[a-z0-9._%+\-]+@[a-z0-9.\-]+\.[a-z]{2,}"
    r"|@[a-z0-9_]+"
    r"|\w+(?:[-_]\w+)*"
    r"|\S",
    re.IGNORECASE,
)
def split_into_words(text: str) -> list[tuple[str, int, int]]:
    """Tokenise *text* like GLiNER2's WhitespaceTokenSplitter.

    Returns one ``(word, char_start, char_end)`` tuple per matched token;
    the character offsets index into the original string.
    """
    tokens: list[tuple[str, int, int]] = []
    for match in _WORD_PATTERN.finditer(text):
        tokens.append((match.group(0), match.start(), match.end()))
    return tokens
| # --------------------------------------------------------------------------- | |
| # Amount parsing | |
| # --------------------------------------------------------------------------- | |
| _CURRENCY_PATTERN = re.compile(r"(?:Rs\.?|INR|₹)\s*", re.IGNORECASE) | |
| _NUMBER_PATTERN = re.compile(r"[\d,]+(?:\.\d+)?") | |
| def parse_amount(raw: str) -> float | None: | |
| """Strip currency symbols and extract the first numeric value. | |
| Handles Rs., Rs, INR, and the rupee sign. Commas are removed before | |
| conversion. Returns *None* when no number can be found. | |
| """ | |
| cleaned = _CURRENCY_PATTERN.sub("", raw).strip() | |
| match = _NUMBER_PATTERN.search(cleaned) | |
| if not match: | |
| return None | |
| try: | |
| return float(match.group().replace(",", "")) | |
| except ValueError: | |
| return None | |
| # --------------------------------------------------------------------------- | |
| # Date normalisation | |
| # --------------------------------------------------------------------------- | |
# Lowercase month-name lookup (abbreviated and full forms); _parse_month
# lowercases the token before calling .get() on this map.
_MONTH_MAP: dict[str, int] = {
    "jan": 1, "january": 1,
    "feb": 2, "february": 2,
    "mar": 3, "march": 3,
    "apr": 4, "april": 4,
    "may": 5,
    "jun": 6, "june": 6,
    "jul": 7, "july": 7,
    "aug": 8, "august": 8,
    "sep": 9, "september": 9,
    "oct": 10, "october": 10,
    "nov": 11, "november": 11,
    "dec": 12, "december": 12,
}
| # Patterns ordered from most specific to least specific. | |
| _DATE_PATTERNS: list[re.Pattern[str]] = [ | |
| # DD-MM-YYYY or DD/MM/YYYY | |
| re.compile(r"(\d{1,2})[/\-](\d{1,2})[/\-](\d{4})"), | |
| # DD-Mon-YYYY or DD/Mon/YYYY | |
| re.compile( | |
| r"(\d{1,2})[/\-]([A-Za-z]+)[/\-](\d{4})" | |
| ), | |
| # DD-MM-YY or DD/MM/YY | |
| re.compile(r"(\d{1,2})[/\-](\d{1,2})[/\-](\d{2})(?!\d)"), | |
| # DD-Mon-YY or DD/Mon/YY | |
| re.compile( | |
| r"(\d{1,2})[/\-]([A-Za-z]+)[/\-](\d{2})(?!\d)" | |
| ), | |
| # DDMonYYYY (e.g. 23Dec2025) | |
| re.compile(r"(\d{1,2})([A-Za-z]+)(\d{4})"), | |
| ] | |
| def _century_window(yy: int) -> int: | |
| """Apply century windowing: YY > 50 -> 19YY, else 20YY.""" | |
| return 1900 + yy if yy > 50 else 2000 + yy | |
| def _parse_month(token: str) -> int | None: | |
| """Return 1-12 for a numeric or named month, or *None*.""" | |
| if token.isdigit(): | |
| val = int(token) | |
| return val if 1 <= val <= 12 else None | |
| return _MONTH_MAP.get(token.lower()) | |
def normalize_date(raw: str, received_date: str | None = None) -> str | None:
    """Parse an Indian-SMS-style date from *raw* and return it as DD-MM-YYYY.

    Accepted inputs: DD-MM-YYYY, DD/MM/YYYY, DD-MM-YY, DD/MM/YY,
    DD-Mon-YYYY, DD-Mon-YY, and DDMonYYYY. When nothing parses, returns
    *received_date* (expected to already be DD-MM-YYYY) or *None*.
    """
    for pattern in _DATE_PATTERNS:
        match = pattern.search(raw)
        if match is None:
            continue
        day = int(match.group(1))
        month = _parse_month(match.group(2))
        year = int(match.group(3))
        if year < 100:
            year = _century_window(year)
        # NOTE(review): _century_window maps YY > 50 to 19YY, which the
        # 2000-2100 check below always rejects — such inputs fall back to
        # received_date. Presumably deliberate for recent bank SMS; confirm.
        accepted = (
            month is not None
            and 2000 <= year <= 2100
            and 1 <= month <= 12
            and 1 <= day <= 31
        )
        if accepted:
            return f"{day:02d}-{month:02d}-{year}"
    # No pattern produced a valid date; defer to the message timestamp.
    return received_date
| # --------------------------------------------------------------------------- | |
| # Span decoding | |
| # --------------------------------------------------------------------------- | |
def decode_spans(
    span_scores: np.ndarray,
    text: str,
    words: list[str],
    word_spans: list[tuple[int, int]],
    threshold: float = 0.3,
) -> dict[str, Optional[tuple[str, float]]]:
    """Decode the span-extraction head output into field values.

    For each field, selects the highest-scoring ``(start, width)`` cell whose
    span fits inside the tokenised text, accepting it only when the score is
    strictly above both *threshold* and zero. Uses a masked vectorised argmax
    instead of a Python-level double loop; ties resolve to the earliest cell
    in row-major (start, width) order, exactly like a sequential scan, and
    NaN scores are skipped.

    Parameters
    ----------
    span_scores:
        Array of shape ``[4, num_words, max_width]`` — one slice per
        extraction field, ordered like ``EXTRACTION_FIELDS``.
    text:
        The original SMS text.
    words:
        Tokenised words (from :func:`split_into_words`).
    word_spans:
        ``(char_start, char_end)`` pairs for each word.
    threshold:
        Minimum confidence to accept a span (exclusive).

    Returns
    -------
    dict
        Mapping of field name to ``(extracted_text, confidence)`` or
        *None* when no span qualifies.
    """
    num_words = len(words)
    result: dict[str, Optional[tuple[str, float]]] = {}
    for field_idx, field_name in enumerate(EXTRACTION_FIELDS):
        field_scores = np.asarray(span_scores[field_idx], dtype=np.float64)
        rows = min(num_words, field_scores.shape[0])
        if rows == 0 or field_scores.shape[1] == 0:
            result[field_name] = None
            continue
        scores = field_scores[:rows].copy()
        # Mask spans that run past the last word, and NaN scores, which a
        # sequential strict-> comparison would never have selected.
        starts = np.arange(rows)[:, None]
        widths = np.arange(scores.shape[1])[None, :]
        scores[(starts + widths) >= num_words] = -np.inf
        scores[np.isnan(scores)] = -np.inf
        # argmax returns the first maximum in row-major order — the same
        # winner a (start, width) scan with strict > updates would keep.
        flat_best = int(np.argmax(scores))
        best_score = float(scores.flat[flat_best])
        # Strictly above threshold AND above zero (the scan's 0.0 initial best).
        if best_score > threshold and best_score > 0.0:
            start, width = divmod(flat_best, scores.shape[1])
            char_start = word_spans[start][0]
            char_end = word_spans[start + width][1]
            result[field_name] = (text[char_start:char_end], best_score)
        else:
            result[field_name] = None
    return result