# fintext-extractor — fintext/extractor.py
# Uploaded to the Hugging Face Hub via huggingface_hub (commit ca3ccd1, verified).
"""Two-stage ONNX inference pipeline for transaction extraction from SMS text.
Stage 1 — Classification: determines whether the message describes a completed
financial transaction (debit or credit).
Stage 2 — Extraction: pulls structured fields (amount, date, type, description,
masked account digits) from messages classified as transactions.
"""
from __future__ import annotations
import os
import numpy as np
import onnxruntime as ort
from tokenizers import Tokenizer
from fintext.utils import (
CLASSIFICATION_LABELS,
EXTRACTION_FIELDS,
SCHEMA_TOKENS,
decode_spans,
normalize_date,
parse_amount,
split_into_words,
)
class FintextExtractor:
    """Two-stage ONNX inference for transaction extraction from SMS text.

    Stage 1 (classification) decides whether the message describes a
    completed financial transaction; stage 2 (extraction) pulls structured
    fields (amount, date, type, description, masked account digits) from
    messages classified as transactions.
    """

    def __init__(self, model_dir: str, precision: str = "fp16") -> None:
        """Load ONNX models and tokenizers from a local directory.

        Args:
            model_dir: Path to directory containing ``onnx/``, ``tokenizer/``
                and ``tokenizer_extraction/`` sub-directories.
            precision: ``"fp16"`` or ``"fp32"`` -- which ONNX model variant
                to load.

        Raises:
            ValueError: If ``precision`` is not ``"fp16"`` or ``"fp32"``.
        """
        if precision not in ("fp16", "fp32"):
            raise ValueError(f"precision must be 'fp16' or 'fp32', got '{precision}'")
        self._precision = precision
        self._model_dir = model_dir

        # Shared ONNX session options: full graph optimization, modest
        # intra-op parallelism for CPU inference.
        opts = ort.SessionOptions()
        opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        opts.intra_op_num_threads = 4

        # Stage-1 classification model.
        cls_path = os.path.join(model_dir, "onnx", f"deberta_classifier_{precision}.onnx")
        self._cls_session = ort.InferenceSession(
            cls_path, opts, providers=["CPUExecutionProvider"]
        )

        # Stage-2 extraction model.
        ext_path = os.path.join(model_dir, "onnx", f"extraction_full_{precision}.onnx")
        self._ext_session = ort.InferenceSession(
            ext_path, opts, providers=["CPUExecutionProvider"]
        )

        # Each stage uses its own tokenizer vocabulary.
        cls_tok_path = os.path.join(model_dir, "tokenizer", "tokenizer.json")
        ext_tok_path = os.path.join(model_dir, "tokenizer_extraction", "tokenizer.json")
        self._cls_tokenizer = Tokenizer.from_file(cls_tok_path)
        self._ext_tokenizer = Tokenizer.from_file(ext_tok_path)

        # The classifier expects fixed-length (128-token) padded inputs.
        self._cls_tokenizer.enable_truncation(max_length=128)
        self._cls_tokenizer.enable_padding(length=128)

    @classmethod
    def from_pretrained(
        cls,
        repo_id: str = "Sowrabhm/fintext-extractor",
        precision: str = "fp16",
    ) -> FintextExtractor:
        """Download models from Hugging Face Hub and load them.

        Args:
            repo_id: Hugging Face model repo ID.
            precision: ``"fp16"`` or ``"fp32"``.

        Returns:
            A ready-to-use :class:`FintextExtractor`.
        """
        from huggingface_hub import snapshot_download

        # Download only the files needed for the requested precision.
        allow = [
            f"onnx/deberta_classifier_{precision}.onnx",
            f"onnx/deberta_classifier_{precision}.onnx.data",
            f"onnx/extraction_full_{precision}.onnx",
            f"onnx/extraction_full_{precision}.onnx.data",
            "tokenizer/*",
            "tokenizer_extraction/*",
            "config.json",
        ]
        local_dir = snapshot_download(repo_id, allow_patterns=allow)
        return cls(local_dir, precision=precision)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def extract(self, text: str, received_date: str | None = None) -> dict:
        """Run the full two-stage pipeline on a single SMS text.

        Args:
            text: SMS / notification text.
            received_date: Optional fallback date in DD-MM-YYYY format.

        Returns:
            dict with keys: ``is_transaction``, ``transaction_amount``,
            ``transaction_type``, ``transaction_date``,
            ``transaction_description``, ``masked_account_digits``.
        """
        # Stage 1: classification gate — skip extraction for non-transactions.
        cls_result = self.classify(text)
        if not cls_result["is_transaction"]:
            return {
                "is_transaction": False,
                "transaction_amount": None,
                "transaction_type": None,
                "transaction_date": None,
                "transaction_description": None,
                "masked_account_digits": None,
            }
        # Stage 2: field extraction.
        return self._extract_fields(text, received_date)

    def classify(self, text: str) -> dict:
        """Run classification only (stage 1).

        Args:
            text: SMS / notification text.

        Returns:
            dict with ``is_transaction`` (bool) and ``confidence`` (float,
            the probability of the predicted class).
        """
        # Tokenize; the tokenizer is configured to pad/truncate to 128.
        encoded = self._cls_tokenizer.encode(text)
        input_ids = np.array([encoded.ids], dtype=np.int64)
        attention_mask = np.array([encoded.attention_mask], dtype=np.int64)

        outputs = self._cls_session.run(
            None,
            {"input_ids": input_ids, "attention_mask": attention_mask},
        )
        logits = outputs[0][0]  # [2] -- logits for [non-transaction, transaction]

        # Numerically stable softmax.
        exp_logits = np.exp(logits - np.max(logits))
        probs = exp_logits / exp_logits.sum()

        is_transaction = bool(probs[1] > 0.5)
        # Confidence is the probability of whichever class was predicted.
        confidence = float(probs[1]) if is_transaction else float(probs[0])
        return {"is_transaction": is_transaction, "confidence": confidence}

    def extract_batch(
        self, texts: list[str], received_date: str | None = None
    ) -> list[dict]:
        """Run extraction on multiple texts sequentially.

        Args:
            texts: List of SMS / notification texts.
            received_date: Optional fallback date applied to every text.

        Returns:
            List of extraction result dicts, one per input text.
        """
        return [self.extract(t, received_date) for t in texts]

    # ------------------------------------------------------------------
    # Internals
    # ------------------------------------------------------------------
    def _extract_fields(self, text: str, received_date: str | None = None) -> dict:
        """Stage 2: extract transaction fields using the extraction model.

        Args:
            text: SMS / notification text (already classified as a transaction).
            received_date: Optional fallback date in DD-MM-YYYY format, used
                when no date span is found in the text.

        Returns:
            Result dict with the same keys as :meth:`extract`.
        """
        # Split text into words with character spans.
        word_info = split_into_words(text)
        words = [w for w, _, _ in word_info]
        word_spans = [(s, e) for _, s, e in word_info]
        num_words = len(words)
        text_words_lower = [w.lower() for w in words]

        # Build combined schema + text input.
        combined_tokens = SCHEMA_TOKENS + text_words_lower
        schema_len = len(SCHEMA_TOKENS)

        # Subword-tokenize each combined token and build words_mask: the
        # first subword of the i-th TEXT word carries the 1-indexed word
        # number; all other positions (continuation subwords and schema
        # tokens) carry 0.
        all_subword_ids: list[int] = []
        words_mask_values: list[int] = []
        for i, token in enumerate(combined_tokens):
            encoded = self._ext_tokenizer.encode(token, add_special_tokens=False)
            subword_ids = encoded.ids
            all_subword_ids.extend(subword_ids)
            if i >= schema_len:
                word_number = i - schema_len + 1
                words_mask_values.append(word_number)
                words_mask_values.extend([0] * (len(subword_ids) - 1))
            else:
                words_mask_values.extend([0] * len(subword_ids))

        # Truncate to the model's maximum sequence length if needed.
        max_len = 512
        seq_len = min(len(all_subword_ids), max_len)
        truncated_mask = words_mask_values[:seq_len]

        # BUGFIX: when truncation drops text words, text_lengths must count
        # only the words whose first subword survived; otherwise the model is
        # told there are more words than words_mask actually indexes. Trim
        # words/word_spans to match so span decoding stays aligned.
        if seq_len < len(all_subword_ids):
            surviving = [v for v in truncated_mask if v > 0]
            num_words = max(surviving) if surviving else 0
            words = words[:num_words]
            word_spans = word_spans[:num_words]

        # Build input tensors (batch size 1).
        input_ids = np.array([all_subword_ids[:seq_len]], dtype=np.int64)
        attention_mask = np.ones((1, seq_len), dtype=np.int64)
        words_mask = np.array([truncated_mask], dtype=np.int64)
        text_lengths = np.array([num_words], dtype=np.int64)

        outputs = self._ext_session.run(
            None,
            {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "words_mask": words_mask,
                "text_lengths": text_lengths,
            },
        )
        type_logits = outputs[0][0]  # [2] -- scores for [DEBIT, CREDIT]; argmax picks the type
        span_scores = outputs[1][0]  # [4, num_words, max_width] entity span scores

        # Decode transaction type and entity spans.
        transaction_type = CLASSIFICATION_LABELS[int(np.argmax(type_logits))]
        spans = decode_spans(span_scores, text, words, word_spans)

        # Post-process fields; each spans entry is a list, take the first hit.
        raw_amount = spans.get("transaction_amount")
        raw_date = spans.get("transaction_date")
        raw_desc = spans.get("transaction_description")
        raw_digits = spans.get("masked_account_digits")

        amount = parse_amount(raw_amount[0]) if raw_amount else None
        # Fall back to received_date when no date span was found.
        date = normalize_date(raw_date[0], received_date) if raw_date else received_date
        description = raw_desc[0] if raw_desc else None
        digits = raw_digits[0] if raw_digits else None

        # Validate: an amount (plus a decoded type) is required for a valid
        # transaction; otherwise the stage-1 positive is overridden.
        is_transaction = amount is not None and transaction_type is not None
        return {
            "is_transaction": is_transaction,
            "transaction_amount": amount,
            "transaction_type": transaction_type if is_transaction else None,
            "transaction_date": date,
            "transaction_description": description,
            "masked_account_digits": digits,
        }