Spaces:
Running
Running
| """Two-stage ONNX inference pipeline for transaction extraction from SMS text. | |
| Stage 1 — Classification: determines whether the message describes a completed | |
| financial transaction (debit or credit). | |
| Stage 2 — Extraction: pulls structured fields (amount, date, type, description, | |
| masked account digits) from messages classified as transactions. | |
| """ | |
| from __future__ import annotations | |
| import os | |
| import numpy as np | |
| import onnxruntime as ort | |
| from tokenizers import Tokenizer | |
| from fintext.utils import ( | |
| CLASSIFICATION_LABELS, | |
| EXTRACTION_FIELDS, | |
| SCHEMA_TOKENS, | |
| decode_spans, | |
| normalize_date, | |
| parse_amount, | |
| split_into_words, | |
| ) | |
class FintextExtractor:
    """Two-stage ONNX inference for transaction extraction from SMS text.

    Stage 1 (classification) decides whether the message describes a
    completed financial transaction; stage 2 (extraction) pulls structured
    fields from messages that pass stage 1.
    """

    def __init__(self, model_dir: str, precision: str = "fp16") -> None:
        """Load ONNX models and tokenizers from a local directory.

        Args:
            model_dir: Path to directory containing onnx/, tokenizer/,
                tokenizer_extraction/ sub-directories.
            precision: ``"fp16"`` or ``"fp32"`` -- which ONNX model variant to
                load.

        Raises:
            ValueError: If ``precision`` is not ``"fp16"`` or ``"fp32"``.
        """
        if precision not in ("fp16", "fp32"):
            raise ValueError(f"precision must be 'fp16' or 'fp32', got '{precision}'")
        self._precision = precision
        self._model_dir = model_dir
        # ONNX session options -- full graph optimization, modest thread count
        # so the extractor plays nicely alongside other CPU work.
        opts = ort.SessionOptions()
        opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        opts.intra_op_num_threads = 4
        # Load classification model (stage 1)
        cls_path = os.path.join(model_dir, "onnx", f"deberta_classifier_{precision}.onnx")
        self._cls_session = ort.InferenceSession(
            cls_path, opts, providers=["CPUExecutionProvider"]
        )
        # Load extraction model (stage 2)
        ext_path = os.path.join(model_dir, "onnx", f"extraction_full_{precision}.onnx")
        self._ext_session = ort.InferenceSession(
            ext_path, opts, providers=["CPUExecutionProvider"]
        )
        # Load tokenizers -- classification and extraction use separate vocabularies.
        cls_tok_path = os.path.join(model_dir, "tokenizer", "tokenizer.json")
        ext_tok_path = os.path.join(model_dir, "tokenizer_extraction", "tokenizer.json")
        self._cls_tokenizer = Tokenizer.from_file(cls_tok_path)
        self._ext_tokenizer = Tokenizer.from_file(ext_tok_path)
        # Classification model expects fixed-length (128) input; the extraction
        # tokenizer is left unpadded because its sequence length is dynamic.
        self._cls_tokenizer.enable_truncation(max_length=128)
        self._cls_tokenizer.enable_padding(length=128)

    # BUG FIX: this alternate constructor takes ``cls`` and calls
    # ``cls(local_dir, ...)`` but was missing @classmethod, so
    # ``FintextExtractor.from_pretrained()`` bound the repo-id argument to
    # ``cls`` and failed. Decorating it restores the intended API.
    @classmethod
    def from_pretrained(
        cls,
        repo_id: str = "Sowrabhm/fintext-extractor",
        precision: str = "fp16",
    ) -> FintextExtractor:
        """Download models from Hugging Face Hub and load them.

        Args:
            repo_id: Hugging Face model repo ID.
            precision: ``"fp16"`` or ``"fp32"``.

        Returns:
            A ready-to-use :class:`FintextExtractor` instance.
        """
        from huggingface_hub import snapshot_download

        # Download only the files needed for the requested precision; the
        # .onnx.data patterns cover external-weight files when present.
        allow = [
            f"onnx/deberta_classifier_{precision}.onnx",
            f"onnx/deberta_classifier_{precision}.onnx.data",
            f"onnx/extraction_full_{precision}.onnx",
            f"onnx/extraction_full_{precision}.onnx.data",
            "tokenizer/*",
            "tokenizer_extraction/*",
            "config.json",
        ]
        local_dir = snapshot_download(repo_id, allow_patterns=allow)
        return cls(local_dir, precision=precision)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def extract(self, text: str, received_date: str | None = None) -> dict:
        """Run full two-stage pipeline on a single SMS text.

        Args:
            text: SMS / notification text.
            received_date: Optional fallback date in DD-MM-YYYY format.

        Returns:
            dict with keys: ``is_transaction``, ``transaction_amount``,
            ``transaction_type``, ``transaction_date``,
            ``transaction_description``, ``masked_account_digits``.
        """
        # Stage 1: Classification -- skip the (heavier) extraction model for
        # non-transaction messages and return an all-None result.
        cls_result = self.classify(text)
        if not cls_result["is_transaction"]:
            return {
                "is_transaction": False,
                "transaction_amount": None,
                "transaction_type": None,
                "transaction_date": None,
                "transaction_description": None,
                "masked_account_digits": None,
            }
        # Stage 2: Extraction
        return self._extract_fields(text, received_date)

    def classify(self, text: str) -> dict:
        """Run classification only (stage 1).

        Args:
            text: SMS / notification text.

        Returns:
            dict with ``is_transaction`` (bool) and ``confidence`` (float) --
            the probability of the predicted class.
        """
        # Tokenizer was configured in __init__ to pad/truncate to 128.
        encoded = self._cls_tokenizer.encode(text)
        input_ids = np.array([encoded.ids], dtype=np.int64)
        attention_mask = np.array([encoded.attention_mask], dtype=np.int64)
        outputs = self._cls_session.run(
            None,
            {"input_ids": input_ids, "attention_mask": attention_mask},
        )
        logits = outputs[0][0]  # [2] -- logits for [non-transaction, transaction]
        # Numerically-stable softmax (subtract max before exp).
        exp_logits = np.exp(logits - np.max(logits))
        probs = exp_logits / exp_logits.sum()
        is_transaction = bool(probs[1] > 0.5)
        # Confidence is the probability of whichever class was predicted.
        confidence = float(probs[1]) if is_transaction else float(probs[0])
        return {"is_transaction": is_transaction, "confidence": confidence}

    def extract_batch(
        self, texts: list[str], received_date: str | None = None
    ) -> list[dict]:
        """Run extraction on multiple texts sequentially.

        Args:
            texts: List of SMS / notification texts.
            received_date: Optional fallback date applied to every text.

        Returns:
            List of extraction result dicts, one per input text.
        """
        return [self.extract(t, received_date) for t in texts]

    # ------------------------------------------------------------------
    # Internals
    # ------------------------------------------------------------------
    def _extract_fields(self, text: str, received_date: str | None = None) -> dict:
        """Stage 2: Extract transaction fields using the extraction model.

        Builds a combined schema+text token sequence, runs the extraction
        ONNX model, then decodes the type head and span head into the final
        result dict.
        """
        # Split text into words with character spans (start/end offsets into
        # the original text, used later by decode_spans).
        word_info = split_into_words(text)
        words = [w for w, _, _ in word_info]
        word_spans = [(s, e) for _, s, e in word_info]
        num_words = len(words)
        text_words_lower = [w.lower() for w in words]
        # The model input is the schema prompt tokens followed by the
        # lower-cased text words.
        combined_tokens = SCHEMA_TOKENS + text_words_lower
        schema_len = len(SCHEMA_TOKENS)
        # Subword-tokenize each combined token and build words_mask: the first
        # subword of the i-th *text* word carries its 1-indexed word number;
        # continuation subwords and all schema subwords carry 0.
        all_subword_ids: list[int] = []
        words_mask_values: list[int] = []
        for i, token in enumerate(combined_tokens):
            encoded = self._ext_tokenizer.encode(token, add_special_tokens=False)
            subword_ids = encoded.ids
            all_subword_ids.extend(subword_ids)
            if i >= schema_len:
                word_number = i - schema_len + 1
                words_mask_values.append(word_number)
                words_mask_values.extend([0] * (len(subword_ids) - 1))
            else:
                words_mask_values.extend([0] * len(subword_ids))
        # Truncate to the model's maximum sequence length if needed.
        max_len = 512
        seq_len = min(len(all_subword_ids), max_len)
        input_ids = np.array([all_subword_ids[:seq_len]], dtype=np.int64)
        attention_mask = np.ones((1, seq_len), dtype=np.int64)
        words_mask = np.array([words_mask_values[:seq_len]], dtype=np.int64)
        text_lengths = np.array([num_words], dtype=np.int64)
        outputs = self._ext_session.run(
            None,
            {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "words_mask": words_mask,
                "text_lengths": text_lengths,
            },
        )
        type_logits = outputs[0][0]  # [2] -- scores for [DEBIT, CREDIT]
        span_scores = outputs[1][0]  # [4, num_words, max_width]
        # Decode transaction type from the classification head.
        transaction_type = CLASSIFICATION_LABELS[int(np.argmax(type_logits))]
        # Decode entity spans back to text via the recorded character offsets.
        spans = decode_spans(span_scores, text, words, word_spans)
        # Post-process fields; each span value is a list, take the top match.
        raw_amount = spans.get("transaction_amount")
        raw_date = spans.get("transaction_date")
        raw_desc = spans.get("transaction_description")
        raw_digits = spans.get("masked_account_digits")
        amount = parse_amount(raw_amount[0]) if raw_amount else None
        # Fall back to the message's received date when no date was extracted.
        date = normalize_date(raw_date[0], received_date) if raw_date else received_date
        description = raw_desc[0] if raw_desc else None
        digits = raw_digits[0] if raw_digits else None
        # Validate: a parseable amount (plus a type, which argmax always
        # yields) is required for the message to count as a transaction.
        is_transaction = amount is not None and transaction_type is not None
        return {
            "is_transaction": is_transaction,
            "transaction_amount": amount,
            "transaction_type": transaction_type if is_transaction else None,
            "transaction_date": date,
            "transaction_description": description,
            "masked_account_digits": digits,
        }