"""Two-stage ONNX inference pipeline for transaction extraction from SMS text. Stage 1 — Classification: determines whether the message describes a completed financial transaction (debit or credit). Stage 2 — Extraction: pulls structured fields (amount, date, type, description, masked account digits) from messages classified as transactions. """ from __future__ import annotations import os import numpy as np import onnxruntime as ort from tokenizers import Tokenizer from fintext.utils import ( CLASSIFICATION_LABELS, EXTRACTION_FIELDS, SCHEMA_TOKENS, decode_spans, normalize_date, parse_amount, split_into_words, ) class FintextExtractor: """Two-stage ONNX inference for transaction extraction from SMS text.""" def __init__(self, model_dir: str, precision: str = "fp16") -> None: """Load ONNX models and tokenizers from a local directory. Args: model_dir: Path to directory containing onnx/, tokenizer/, tokenizer_extraction/ sub-directories. precision: ``"fp16"`` or ``"fp32"`` -- which ONNX model variant to load. """ if precision not in ("fp16", "fp32"): raise ValueError(f"precision must be 'fp16' or 'fp32', got '{precision}'") self._precision = precision self._model_dir = model_dir # ONNX session options opts = ort.SessionOptions() opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL opts.intra_op_num_threads = 4 # Load classification model cls_path = os.path.join(model_dir, "onnx", f"deberta_classifier_{precision}.onnx") self._cls_session = ort.InferenceSession( cls_path, opts, providers=["CPUExecutionProvider"] ) # Load extraction model ext_path = os.path.join(model_dir, "onnx", f"extraction_full_{precision}.onnx") self._ext_session = ort.InferenceSession( ext_path, opts, providers=["CPUExecutionProvider"] ) # Load tokenizers cls_tok_path = os.path.join(model_dir, "tokenizer", "tokenizer.json") ext_tok_path = os.path.join(model_dir, "tokenizer_extraction", "tokenizer.json") self._cls_tokenizer = Tokenizer.from_file(cls_tok_path) self._ext_tokenizer = Tokenizer.from_file(ext_tok_path) # Configure classification tokenizer self._cls_tokenizer.enable_truncation(max_length=128) self._cls_tokenizer.enable_padding(length=128) @classmethod def from_pretrained( cls, repo_id: str = "Sowrabhm/fintext-extractor", precision: str = "fp16", ) -> FintextExtractor: """Download models from Hugging Face Hub and load them. Args: repo_id: Hugging Face model repo ID. precision: ``"fp16"`` or ``"fp32"``. """ from huggingface_hub import snapshot_download # Download only the files needed for the requested precision allow = [ f"onnx/deberta_classifier_{precision}.onnx", f"onnx/deberta_classifier_{precision}.onnx.data", f"onnx/extraction_full_{precision}.onnx", f"onnx/extraction_full_{precision}.onnx.data", "tokenizer/*", "tokenizer_extraction/*", "config.json", ] local_dir = snapshot_download(repo_id, allow_patterns=allow) return cls(local_dir, precision=precision) # ------------------------------------------------------------------ # Public API # ------------------------------------------------------------------ def extract(self, text: str, received_date: str | None = None) -> dict: """Run full two-stage pipeline on a single SMS text. Args: text: SMS / notification text. received_date: Optional fallback date in DD-MM-YYYY format. Returns: dict with keys: ``is_transaction``, ``transaction_amount``, ``transaction_type``, ``transaction_date``, ``transaction_description``, ``masked_account_digits``. """ # Stage 1: Classification cls_result = self.classify(text) if not cls_result["is_transaction"]: return { "is_transaction": False, "transaction_amount": None, "transaction_type": None, "transaction_date": None, "transaction_description": None, "masked_account_digits": None, } # Stage 2: Extraction return self._extract_fields(text, received_date) def classify(self, text: str) -> dict: """Run classification only (stage 1). Returns: dict with ``is_transaction`` (bool) and ``confidence`` (float). """ # Tokenize with padding/truncation to 128 encoded = self._cls_tokenizer.encode(text) input_ids = np.array([encoded.ids], dtype=np.int64) attention_mask = np.array([encoded.attention_mask], dtype=np.int64) # Run classification outputs = self._cls_session.run( None, {"input_ids": input_ids, "attention_mask": attention_mask}, ) logits = outputs[0][0] # [2] -- logits for [non-transaction, transaction] # Softmax exp_logits = np.exp(logits - np.max(logits)) probs = exp_logits / exp_logits.sum() is_transaction = bool(probs[1] > 0.5) confidence = float(probs[1]) if is_transaction else float(probs[0]) return {"is_transaction": is_transaction, "confidence": confidence} def extract_batch( self, texts: list[str], received_date: str | None = None ) -> list[dict]: """Run extraction on multiple texts sequentially. Args: texts: List of SMS / notification texts. received_date: Optional fallback date. Returns: List of extraction result dicts. """ return [self.extract(t, received_date) for t in texts] # ------------------------------------------------------------------ # Internals # ------------------------------------------------------------------ def _extract_fields(self, text: str, received_date: str | None = None) -> dict: """Stage 2: Extract transaction fields using the extraction model.""" # Split text into words with character spans word_info = split_into_words(text) words = [w for w, _, _ in word_info] word_spans = [(s, e) for _, s, e in word_info] num_words = len(words) text_words_lower = [w.lower() for w in words] # Build combined schema + text input combined_tokens = SCHEMA_TOKENS + text_words_lower schema_len = len(SCHEMA_TOKENS) # Subword-tokenize each combined token, build words_mask all_subword_ids: list[int] = [] words_mask_values: list[int] = [] for i, token in enumerate(combined_tokens): encoded = self._ext_tokenizer.encode(token, add_special_tokens=False) subword_ids = encoded.ids all_subword_ids.extend(subword_ids) if i >= schema_len: # Text word: first subword gets 1-indexed word number word_number = i - schema_len + 1 words_mask_values.append(word_number) words_mask_values.extend([0] * (len(subword_ids) - 1)) else: # Schema token: all get 0 words_mask_values.extend([0] * len(subword_ids)) # Truncate to 512 if needed max_len = 512 seq_len = min(len(all_subword_ids), max_len) # Build tensors input_ids = np.array([all_subword_ids[:seq_len]], dtype=np.int64) attention_mask = np.ones((1, seq_len), dtype=np.int64) words_mask = np.array([words_mask_values[:seq_len]], dtype=np.int64) text_lengths = np.array([num_words], dtype=np.int64) # Run extraction model outputs = self._ext_session.run( None, { "input_ids": input_ids, "attention_mask": attention_mask, "words_mask": words_mask, "text_lengths": text_lengths, }, ) type_logits = outputs[0][0] # [2] -- softmax probs for [DEBIT, CREDIT] span_scores = outputs[1][0] # [4, num_words, max_width] # Decode transaction type transaction_type = CLASSIFICATION_LABELS[int(np.argmax(type_logits))] # Decode entity spans spans = decode_spans(span_scores, text, words, word_spans) # Post-process fields raw_amount = spans.get("transaction_amount") raw_date = spans.get("transaction_date") raw_desc = spans.get("transaction_description") raw_digits = spans.get("masked_account_digits") amount = parse_amount(raw_amount[0]) if raw_amount else None date = normalize_date(raw_date[0], received_date) if raw_date else received_date description = raw_desc[0] if raw_desc else None digits = raw_digits[0] if raw_digits else None # Validate: must have amount + type to be a valid transaction is_transaction = amount is not None and transaction_type is not None return { "is_transaction": is_transaction, "transaction_amount": amount, "transaction_type": transaction_type if is_transaction else None, "transaction_date": date, "transaction_description": description, "masked_account_digits": digits, }