Sowrabhm committed on
Commit
ca3ccd1
·
verified ·
1 Parent(s): 0ecdc93

Upload folder using huggingface_hub

Browse files
Files changed (5) hide show
  1. app.py +94 -0
  2. fintext/__init__.py +4 -0
  3. fintext/extractor.py +261 -0
  4. fintext/utils.py +241 -0
  5. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ fintext-extractor: Transaction Extraction from Bank SMS
3
+
4
+ Gradio demo for the two-stage NER pipeline that extracts structured
5
+ transaction data from bank SMS/notifications using ONNX Runtime.
6
+ """
7
+
8
+ import json
9
+ import time
10
+ import gradio as gr
11
+ from fintext import FintextExtractor
12
+
13
+ # Load model at startup (downloads ~1.8GB on first run)
14
+ print("Loading fintext-extractor model...")
15
+ extractor = FintextExtractor.from_pretrained("Sowrabhm/fintext-extractor", precision="fp16")
16
+ print("Model loaded!")
17
+
18
+
19
def extract_transaction(sms_text: str) -> str:
    """Run the two-stage extraction pipeline and return formatted JSON.

    Args:
        sms_text: Raw SMS / notification text entered in the textbox.

    Returns:
        Pretty-printed JSON string containing the classification verdict,
        its confidence, the extracted transaction fields (only when the
        message is a transaction), and the wall-clock inference time.
    """
    if not sms_text or not sms_text.strip():
        return json.dumps({"error": "Please enter SMS text"}, indent=2)

    start = time.time()

    # Stage 1: Classification
    cls_result = extractor.classify(sms_text)

    if cls_result["is_transaction"]:
        # Stage 2: Full extraction (only worth running for transactions).
        result = extractor.extract(sms_text)
    else:
        # Short-circuit on a negative verdict. extractor.extract() runs the
        # classifier again internally, so calling it here would execute the
        # classification model a second time only to get back this same
        # all-None payload.
        result = {
            "is_transaction": False,
            "transaction_amount": None,
            "transaction_type": None,
            "transaction_date": None,
            "transaction_description": None,
            "masked_account_digits": None,
        }

    elapsed_ms = (time.time() - start) * 1000

    # Build rich output: verdict and confidence first, then the extracted
    # fields (transactions only), then latency.
    output = {
        "is_transaction": result["is_transaction"],
        "confidence": round(cls_result["confidence"], 3),
    }

    if result["is_transaction"]:
        for field in (
            "transaction_amount",
            "transaction_type",
            "transaction_date",
            "transaction_description",
            "masked_account_digits",
        ):
            output[field] = result[field]

    output["inference_time_ms"] = round(elapsed_ms, 1)

    # default=str guards against any non-JSON-serializable field values.
    return json.dumps(output, indent=2, default=str)
50
+
51
+
52
# Example SMS messages (all synthetic)
# The first five describe completed debits/credits; the last three (an OTP,
# an FD maturity notice, an EMI reminder) are distractors that should
# presumably be filtered out by the stage-1 classifier -- useful for
# demonstrating the two-stage behavior. (Verify against model output.)
examples = [
    ["Rs.5,000 debited from a/c XX1234 for Amazon Pay on 08-Mar-26"],
    ["Credit Alert: INR 25,000 credited to a/c XX5678 on 15-Jan-2026"],
    ["INR 3,499.00 paid to Netflix via card ending 9876 on 01-Feb-2026"],
    ["Dear Customer, Rs.850.50 has been credited to your a/c XX2468 on 05-Mar-2026. UPI Ref: 678912345"],
    ["Transaction of Rs.15,750 at Flipkart on 28-Feb-2026 from card XX3579. Avl bal: Rs.42,300"],
    ["OTP 483921 for transaction of Rs.1,200. Do not share with anyone."],
    ["Your FD of Rs.50,000 matures on 20-Apr-2026. Visit branch."],
    ["Reminder: EMI of Rs.12,500 due on 10-Mar-2026 for loan XX4321"],
]
63
+
64
# Build Gradio interface: one textbox in, syntax-highlighted JSON out.
demo = gr.Interface(
    fn=extract_transaction,  # single text -> JSON string
    inputs=gr.Textbox(
        label="SMS / Notification Text",
        placeholder="Paste a bank SMS here...",
        lines=3,
    ),
    outputs=gr.Code(
        label="Extracted Transaction Data",
        language="json",  # highlight the returned JSON string
    ),
    title="fintext-extractor",
    description=(
        "Extract structured transaction data from bank SMS using on-device NER. "
        "Two-stage pipeline: DeBERTa classifier filters non-transactions, "
        "then GLiNER2 extracts amount, date, type, description, and account digits.\n\n"
        "**Try the examples below or paste your own SMS text.**"
    ),
    article=(
        "**Links:** "
        "[Model](https://huggingface.co/Sowrabhm/fintext-extractor) | "
        "[GitHub](https://github.com/sowrabhmv/fintext-extractor) | "
        "License: CC-BY-4.0"
    ),
    examples=examples,
    # Run examples on click rather than pre-computing them at startup.
    cache_examples=False,
)

# Standard script entry point.
if __name__ == "__main__":
    demo.launch()
fintext/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
"""fintext package: public entry point for the SMS transaction extractor."""

from fintext.extractor import FintextExtractor

# Public API surface and package version.
__all__ = ["FintextExtractor"]
__version__ = "1.0.0"
fintext/extractor.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Two-stage ONNX inference pipeline for transaction extraction from SMS text.
2
+
3
+ Stage 1 — Classification: determines whether the message describes a completed
4
+ financial transaction (debit or credit).
5
+
6
+ Stage 2 — Extraction: pulls structured fields (amount, date, type, description,
7
+ masked account digits) from messages classified as transactions.
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import os
13
+
14
+ import numpy as np
15
+ import onnxruntime as ort
16
+ from tokenizers import Tokenizer
17
+
18
+ from fintext.utils import (
19
+ CLASSIFICATION_LABELS,
20
+ EXTRACTION_FIELDS,
21
+ SCHEMA_TOKENS,
22
+ decode_spans,
23
+ normalize_date,
24
+ parse_amount,
25
+ split_into_words,
26
+ )
27
+
28
+
29
class FintextExtractor:
    """Two-stage ONNX inference for transaction extraction from SMS text.

    Stage 1 (:meth:`classify`) runs a sequence classifier; stage 2
    (:meth:`_extract_fields`) runs a GLiNER2-style span-extraction model.
    Both models execute on CPU via ONNX Runtime.
    """
    # Instance attributes (set in __init__):
    #   _precision     -- "fp16" or "fp32" model variant in use
    #   _model_dir     -- local directory the models were loaded from
    #   _cls_session   -- ONNX Runtime session for the classifier
    #   _ext_session   -- ONNX Runtime session for the extraction model
    #   _cls_tokenizer -- classifier tokenizer (padded/truncated to 128)
    #   _ext_tokenizer -- extraction tokenizer (no padding configured)

    def __init__(self, model_dir: str, precision: str = "fp16") -> None:
        """Load ONNX models and tokenizers from a local directory.

        Args:
            model_dir: Path to directory containing onnx/, tokenizer/,
                tokenizer_extraction/ sub-directories.
            precision: ``"fp16"`` or ``"fp32"`` -- which ONNX model variant to
                load.

        Raises:
            ValueError: If ``precision`` is not a supported value.
        """
        if precision not in ("fp16", "fp32"):
            raise ValueError(f"precision must be 'fp16' or 'fp32', got '{precision}'")

        self._precision = precision
        self._model_dir = model_dir

        # ONNX session options: full graph optimization, fixed 4 intra-op
        # threads (a conservative choice; not derived from host CPU count).
        opts = ort.SessionOptions()
        opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        opts.intra_op_num_threads = 4

        # Load classification model (stage 1)
        cls_path = os.path.join(model_dir, "onnx", f"deberta_classifier_{precision}.onnx")
        self._cls_session = ort.InferenceSession(
            cls_path, opts, providers=["CPUExecutionProvider"]
        )

        # Load extraction model (stage 2)
        ext_path = os.path.join(model_dir, "onnx", f"extraction_full_{precision}.onnx")
        self._ext_session = ort.InferenceSession(
            ext_path, opts, providers=["CPUExecutionProvider"]
        )

        # Load tokenizers -- the two stages use separate tokenizer files.
        cls_tok_path = os.path.join(model_dir, "tokenizer", "tokenizer.json")
        ext_tok_path = os.path.join(model_dir, "tokenizer_extraction", "tokenizer.json")
        self._cls_tokenizer = Tokenizer.from_file(cls_tok_path)
        self._ext_tokenizer = Tokenizer.from_file(ext_tok_path)

        # Configure classification tokenizer: every input is both truncated
        # and padded to exactly 128 tokens, matching the exported model.
        self._cls_tokenizer.enable_truncation(max_length=128)
        self._cls_tokenizer.enable_padding(length=128)

    @classmethod
    def from_pretrained(
        cls,
        repo_id: str = "Sowrabhm/fintext-extractor",
        precision: str = "fp16",
    ) -> FintextExtractor:
        """Download models from Hugging Face Hub and load them.

        Args:
            repo_id: Hugging Face model repo ID.
            precision: ``"fp16"`` or ``"fp32"``.

        Returns:
            A ready-to-use extractor backed by the downloaded snapshot.
        """
        # Imported lazily so huggingface_hub is only needed for Hub loading.
        from huggingface_hub import snapshot_download

        # Download only the files needed for the requested precision
        # (the .onnx.data patterns cover external-weight files, if any).
        allow = [
            f"onnx/deberta_classifier_{precision}.onnx",
            f"onnx/deberta_classifier_{precision}.onnx.data",
            f"onnx/extraction_full_{precision}.onnx",
            f"onnx/extraction_full_{precision}.onnx.data",
            "tokenizer/*",
            "tokenizer_extraction/*",
            "config.json",
        ]
        local_dir = snapshot_download(repo_id, allow_patterns=allow)
        return cls(local_dir, precision=precision)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def extract(self, text: str, received_date: str | None = None) -> dict:
        """Run full two-stage pipeline on a single SMS text.

        Args:
            text: SMS / notification text.
            received_date: Optional fallback date in DD-MM-YYYY format.

        Returns:
            dict with keys: ``is_transaction``, ``transaction_amount``,
            ``transaction_type``, ``transaction_date``,
            ``transaction_description``, ``masked_account_digits``.
            Field values are all ``None`` for non-transactions.
        """
        # Stage 1: Classification -- cheap gate before span extraction.
        cls_result = self.classify(text)

        if not cls_result["is_transaction"]:
            # Return the full key set with null fields so callers always
            # see a uniform result shape.
            return {
                "is_transaction": False,
                "transaction_amount": None,
                "transaction_type": None,
                "transaction_date": None,
                "transaction_description": None,
                "masked_account_digits": None,
            }

        # Stage 2: Extraction
        return self._extract_fields(text, received_date)

    def classify(self, text: str) -> dict:
        """Run classification only (stage 1).

        Returns:
            dict with ``is_transaction`` (bool) and ``confidence`` (float).
            ``confidence`` is the probability of the *predicted* class, so
            it is always >= 0.5.
        """
        # Tokenize with padding/truncation to 128 (configured in __init__)
        encoded = self._cls_tokenizer.encode(text)
        input_ids = np.array([encoded.ids], dtype=np.int64)
        attention_mask = np.array([encoded.attention_mask], dtype=np.int64)

        # Run classification
        outputs = self._cls_session.run(
            None,
            {"input_ids": input_ids, "attention_mask": attention_mask},
        )

        logits = outputs[0][0]  # [2] -- logits for [non-transaction, transaction]

        # Softmax (max-subtracted for numerical stability)
        exp_logits = np.exp(logits - np.max(logits))
        probs = exp_logits / exp_logits.sum()

        is_transaction = bool(probs[1] > 0.5)
        confidence = float(probs[1]) if is_transaction else float(probs[0])

        return {"is_transaction": is_transaction, "confidence": confidence}

    def extract_batch(
        self, texts: list[str], received_date: str | None = None
    ) -> list[dict]:
        """Run extraction on multiple texts sequentially.

        Note: this is a convenience loop, not a batched model call -- each
        text incurs a full pipeline run.

        Args:
            texts: List of SMS / notification texts.
            received_date: Optional fallback date.

        Returns:
            List of extraction result dicts.
        """
        return [self.extract(t, received_date) for t in texts]

    # ------------------------------------------------------------------
    # Internals
    # ------------------------------------------------------------------

    def _extract_fields(self, text: str, received_date: str | None = None) -> dict:
        """Stage 2: Extract transaction fields using the extraction model."""
        # Split text into words with character spans
        word_info = split_into_words(text)
        words = [w for w, _, _ in word_info]
        word_spans = [(s, e) for _, s, e in word_info]
        num_words = len(words)
        text_words_lower = [w.lower() for w in words]

        # Build combined schema + text input: fixed schema prompt first,
        # then the lower-cased message words.
        combined_tokens = SCHEMA_TOKENS + text_words_lower
        schema_len = len(SCHEMA_TOKENS)

        # Subword-tokenize each combined token, build words_mask
        all_subword_ids: list[int] = []
        words_mask_values: list[int] = []

        for i, token in enumerate(combined_tokens):
            encoded = self._ext_tokenizer.encode(token, add_special_tokens=False)
            subword_ids = encoded.ids
            all_subword_ids.extend(subword_ids)

            if i >= schema_len:
                # Text word: first subword gets 1-indexed word number
                word_number = i - schema_len + 1
                words_mask_values.append(word_number)
                words_mask_values.extend([0] * (len(subword_ids) - 1))
            else:
                # Schema token: all get 0
                words_mask_values.extend([0] * len(subword_ids))

        # Truncate to 512 if needed
        # NOTE(review): truncation can drop trailing words while
        # text_lengths below still reports the full word count -- confirm
        # the exported model tolerates this mismatch on long messages.
        max_len = 512
        seq_len = min(len(all_subword_ids), max_len)

        # Build tensors
        input_ids = np.array([all_subword_ids[:seq_len]], dtype=np.int64)
        attention_mask = np.ones((1, seq_len), dtype=np.int64)
        words_mask = np.array([words_mask_values[:seq_len]], dtype=np.int64)
        text_lengths = np.array([num_words], dtype=np.int64)

        # Run extraction model
        outputs = self._ext_session.run(
            None,
            {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "words_mask": words_mask,
                "text_lengths": text_lengths,
            },
        )

        type_logits = outputs[0][0]  # [2] -- softmax probs for [DEBIT, CREDIT]
        span_scores = outputs[1][0]  # [4, num_words, max_width]

        # Decode transaction type
        transaction_type = CLASSIFICATION_LABELS[int(np.argmax(type_logits))]

        # Decode entity spans -- each value is (extracted_text, confidence)
        # or None when no span cleared the threshold.
        spans = decode_spans(span_scores, text, words, word_spans)

        # Post-process fields
        raw_amount = spans.get("transaction_amount")
        raw_date = spans.get("transaction_date")
        raw_desc = spans.get("transaction_description")
        raw_digits = spans.get("masked_account_digits")

        amount = parse_amount(raw_amount[0]) if raw_amount else None
        date = normalize_date(raw_date[0], received_date) if raw_date else received_date
        description = raw_desc[0] if raw_desc else None
        digits = raw_digits[0] if raw_digits else None

        # Validate: must have amount + type to be a valid transaction.
        # (transaction_type is always non-None here, so the amount decides.)
        is_transaction = amount is not None and transaction_type is not None

        return {
            "is_transaction": is_transaction,
            "transaction_amount": amount,
            "transaction_type": transaction_type if is_transaction else None,
            "transaction_date": date,
            "transaction_description": description,
            "masked_account_digits": digits,
        }
fintext/utils.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Post-processing utilities for transaction extraction.
2
+
3
+ Ported from the Android Kotlin GLiNER2 ONNX runner. Provides tokenisation,
4
+ span decoding, amount parsing, and date normalisation for bank SMS messages.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import re
10
+ from typing import Optional
11
+
12
+ import numpy as np
13
+
14
+ # ---------------------------------------------------------------------------
15
+ # Constants
16
+ # ---------------------------------------------------------------------------
17
+
18
SCHEMA_TOKENS: list[str] = [
    "(", "[P]", "transaction_type",
    "(", "[L]", "DEBIT", "[L]", "CREDIT", ")", ")",
    "[SEP_STRUCT]",
    "(", "[P]", "transaction_info",
    "(", "[C]", "transaction_amount",
    "[C]", "transaction_date",
    "[C]", "transaction_description",
    "[C]", "masked_account_digits", ")", ")",
    "[SEP_TEXT]",
]
"""Fixed schema token sequence matching the exported ONNX model."""
# NOTE(review): marker semantics are presumably [P]=parent/task name,
# [L]=classification label, [C]=child extraction field, with the [SEP_*]
# tokens separating schema sections from the message text -- confirm
# against the GLiNER2 export script before editing this sequence.

EXTRACTION_FIELDS: list[str] = [
    "transaction_amount",
    "transaction_date",
    "transaction_description",
    "masked_account_digits",
]
"""Ordered field names for the span-extraction head."""
# Order must match the first axis of the model's span-score output
# (decode_spans indexes span_scores by position in this list).

CLASSIFICATION_LABELS: list[str] = ["DEBIT", "CREDIT"]
"""Labels emitted by the classification head."""
# Indices 0/1 must line up with the type-logits output of the extraction
# model (decoded via argmax in the extractor).
41
+
42
+ # ---------------------------------------------------------------------------
43
+ # Tokenisation
44
+ # ---------------------------------------------------------------------------
45
+
46
_WORD_PATTERN = re.compile(
    r"(?:https?://\S+|www\.\S+)"  # URLs
    r"|[a-z0-9._%+\-]+@[a-z0-9.\-]+\.[a-z]{2,}"  # emails
    r"|@[a-z0-9_]+"  # @-mentions
    r"|\w+(?:[-_]\w+)*"  # words (with hyphens/underscores)
    r"|\S",  # single non-space fallback
    re.IGNORECASE,
)


def split_into_words(text: str) -> list[tuple[str, int, int]]:
    """Tokenise *text* into ``(word, char_start, char_end)`` triples.

    Mirrors GLiNER2's WhitespaceTokenSplitter: URLs, emails, and
    @-mentions are kept whole; any other non-space character that is not
    part of a word becomes its own token.
    """
    tokens: list[tuple[str, int, int]] = []
    for match in _WORD_PATTERN.finditer(text):
        tokens.append((match.group(), match.start(), match.end()))
    return tokens
62
+
63
+
64
+ # ---------------------------------------------------------------------------
65
+ # Amount parsing
66
+ # ---------------------------------------------------------------------------
67
+
68
+ _CURRENCY_PATTERN = re.compile(r"(?:Rs\.?|INR|₹)\s*", re.IGNORECASE)
69
+ _NUMBER_PATTERN = re.compile(r"[\d,]+(?:\.\d+)?")
70
+
71
+
72
+ def parse_amount(raw: str) -> float | None:
73
+ """Strip currency symbols and extract the first numeric value.
74
+
75
+ Handles Rs., Rs, INR, and the rupee sign. Commas are removed before
76
+ conversion. Returns *None* when no number can be found.
77
+ """
78
+ cleaned = _CURRENCY_PATTERN.sub("", raw).strip()
79
+ match = _NUMBER_PATTERN.search(cleaned)
80
+ if not match:
81
+ return None
82
+ try:
83
+ return float(match.group().replace(",", ""))
84
+ except ValueError:
85
+ return None
86
+
87
+
88
+ # ---------------------------------------------------------------------------
89
+ # Date normalisation
90
+ # ---------------------------------------------------------------------------
91
+
92
+ _MONTH_MAP: dict[str, int] = {
93
+ "jan": 1, "january": 1,
94
+ "feb": 2, "february": 2,
95
+ "mar": 3, "march": 3,
96
+ "apr": 4, "april": 4,
97
+ "may": 5,
98
+ "jun": 6, "june": 6,
99
+ "jul": 7, "july": 7,
100
+ "aug": 8, "august": 8,
101
+ "sep": 9, "september": 9,
102
+ "oct": 10, "october": 10,
103
+ "nov": 11, "november": 11,
104
+ "dec": 12, "december": 12,
105
+ }
106
+
107
+ # Patterns ordered from most specific to least specific.
108
+ _DATE_PATTERNS: list[re.Pattern[str]] = [
109
+ # DD-MM-YYYY or DD/MM/YYYY
110
+ re.compile(r"(\d{1,2})[/\-](\d{1,2})[/\-](\d{4})"),
111
+ # DD-Mon-YYYY or DD/Mon/YYYY
112
+ re.compile(
113
+ r"(\d{1,2})[/\-]([A-Za-z]+)[/\-](\d{4})"
114
+ ),
115
+ # DD-MM-YY or DD/MM/YY
116
+ re.compile(r"(\d{1,2})[/\-](\d{1,2})[/\-](\d{2})(?!\d)"),
117
+ # DD-Mon-YY or DD/Mon/YY
118
+ re.compile(
119
+ r"(\d{1,2})[/\-]([A-Za-z]+)[/\-](\d{2})(?!\d)"
120
+ ),
121
+ # DDMonYYYY (e.g. 23Dec2025)
122
+ re.compile(r"(\d{1,2})([A-Za-z]+)(\d{4})"),
123
+ ]
124
+
125
+
126
+ def _century_window(yy: int) -> int:
127
+ """Apply century windowing: YY > 50 -> 19YY, else 20YY."""
128
+ return 1900 + yy if yy > 50 else 2000 + yy
129
+
130
+
131
+ def _parse_month(token: str) -> int | None:
132
+ """Return 1-12 for a numeric or named month, or *None*."""
133
+ if token.isdigit():
134
+ val = int(token)
135
+ return val if 1 <= val <= 12 else None
136
+ return _MONTH_MAP.get(token.lower())
137
+
138
+
139
+ def normalize_date(raw: str, received_date: str | None = None) -> str | None:
140
+ """Parse a date string in various Indian SMS formats and return DD-MM-YYYY.
141
+
142
+ Supported input formats:
143
+ DD-MM-YYYY, DD/MM/YYYY, DD-MM-YY, DD/MM/YY,
144
+ DD-Mon-YYYY, DD-Mon-YY, DDMonYYYY.
145
+
146
+ Falls back to *received_date* (which must already be DD-MM-YYYY) when
147
+ *raw* cannot be parsed. Returns *None* if nothing works.
148
+ """
149
+ for pattern in _DATE_PATTERNS:
150
+ m = pattern.search(raw)
151
+ if not m:
152
+ continue
153
+
154
+ day_s, month_s, year_s = m.group(1), m.group(2), m.group(3)
155
+
156
+ day = int(day_s)
157
+ month = _parse_month(month_s)
158
+ if month is None:
159
+ continue
160
+
161
+ year = int(year_s)
162
+ if year < 100:
163
+ year = _century_window(year)
164
+
165
+ # Basic validation
166
+ if not (2000 <= year <= 2100):
167
+ continue
168
+ if not (1 <= month <= 12):
169
+ continue
170
+ if not (1 <= day <= 31):
171
+ continue
172
+
173
+ return f"{day:02d}-{month:02d}-{year}"
174
+
175
+ # Fallback
176
+ if received_date is not None:
177
+ return received_date
178
+
179
+ return None
180
+
181
+
182
+ # ---------------------------------------------------------------------------
183
+ # Span decoding
184
+ # ---------------------------------------------------------------------------
185
+
186
def decode_spans(
    span_scores: np.ndarray,
    text: str,
    words: list[str],
    word_spans: list[tuple[int, int]],
    threshold: float = 0.3,
) -> dict[str, Optional[tuple[str, float]]]:
    """Turn the span-extraction head's scores into per-field text spans.

    Parameters
    ----------
    span_scores:
        Array of shape ``[4, num_words, max_width]`` -- one slice per
        extraction field, in ``EXTRACTION_FIELDS`` order.
    text:
        The original SMS text.
    words:
        Tokenised words (from :func:`split_into_words`).
    word_spans:
        ``(char_start, char_end)`` pairs for each word.
    threshold:
        Minimum confidence to accept a span.

    Returns
    -------
    dict
        Mapping of field name to ``(extracted_text, confidence)``, or
        *None* when no span for that field exceeds *threshold*.
    """
    total_words = len(words)
    decoded: dict[str, Optional[tuple[str, float]]] = {}

    for field_index, field_name in enumerate(EXTRACTION_FIELDS):
        scores = span_scores[field_index]  # [num_words, max_width]
        start_limit, width_limit = scores.shape

        # Scan every (start, width) cell; keep the single best span above
        # the threshold. Strict ">" means the earliest cell wins ties.
        top_score = 0.0
        top_span: tuple[int, int, float] | None = None
        for begin in range(min(total_words, start_limit)):
            for width in range(width_limit):
                last = begin + width
                if last >= total_words:
                    break  # span would run past the text
                value = float(scores[begin, width])
                if value > top_score and value > threshold:
                    top_score = value
                    top_span = (begin, last, value)

        if top_span is None:
            decoded[field_name] = None
        else:
            first_word, last_word, confidence = top_span
            span_text = text[word_spans[first_word][0]:word_spans[last_word][1]]
            decoded[field_name] = (span_text, confidence)

    return decoded
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ onnxruntime>=1.17.0
2
+ numpy>=1.24.0
3
+ huggingface_hub>=0.20.0
4
+ tokenizers>=0.15.0
5
+ gradio>=4.0.0