File size: 9,935 Bytes
ca3ccd1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
"""Two-stage ONNX inference pipeline for transaction extraction from SMS text.



Stage 1 — Classification: determines whether the message describes a completed

financial transaction (debit or credit).



Stage 2 — Extraction: pulls structured fields (amount, date, type, description,

masked account digits) from messages classified as transactions.

"""

from __future__ import annotations

import os

import numpy as np
import onnxruntime as ort
from tokenizers import Tokenizer

from fintext.utils import (
    CLASSIFICATION_LABELS,
    EXTRACTION_FIELDS,
    SCHEMA_TOKENS,
    decode_spans,
    normalize_date,
    parse_amount,
    split_into_words,
)


class FintextExtractor:
    """Two-stage ONNX inference for transaction extraction from SMS text.

    Stage 1 (:meth:`classify`) decides whether a message describes a completed
    financial transaction; stage 2 (:meth:`_extract_fields`) pulls structured
    fields from messages that pass stage 1. :meth:`extract` runs both stages.
    """

    def __init__(self, model_dir: str, precision: str = "fp16") -> None:
        """Load ONNX models and tokenizers from a local directory.

        Args:
            model_dir: Path to a directory containing ``onnx/``, ``tokenizer/``
                and ``tokenizer_extraction/`` sub-directories.
            precision: ``"fp16"`` or ``"fp32"`` -- which ONNX model variant to
                load.

        Raises:
            ValueError: If ``precision`` is not ``"fp16"`` or ``"fp32"``.
        """
        if precision not in ("fp16", "fp32"):
            raise ValueError(f"precision must be 'fp16' or 'fp32', got '{precision}'")

        self._precision = precision
        self._model_dir = model_dir

        # Shared session options for both models: full graph optimization,
        # modest intra-op CPU parallelism.
        opts = ort.SessionOptions()
        opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        opts.intra_op_num_threads = 4

        # Stage-1 classification model.
        cls_path = os.path.join(model_dir, "onnx", f"deberta_classifier_{precision}.onnx")
        self._cls_session = ort.InferenceSession(
            cls_path, opts, providers=["CPUExecutionProvider"]
        )

        # Stage-2 extraction model.
        ext_path = os.path.join(model_dir, "onnx", f"extraction_full_{precision}.onnx")
        self._ext_session = ort.InferenceSession(
            ext_path, opts, providers=["CPUExecutionProvider"]
        )

        # Each stage has its own tokenizer file (separate vocabularies).
        cls_tok_path = os.path.join(model_dir, "tokenizer", "tokenizer.json")
        ext_tok_path = os.path.join(model_dir, "tokenizer_extraction", "tokenizer.json")
        self._cls_tokenizer = Tokenizer.from_file(cls_tok_path)
        self._ext_tokenizer = Tokenizer.from_file(ext_tok_path)

        # The classifier expects fixed-length 128-token input; the extraction
        # tokenizer is driven per-word in _extract_fields and stays unpadded.
        self._cls_tokenizer.enable_truncation(max_length=128)
        self._cls_tokenizer.enable_padding(length=128)

    @classmethod
    def from_pretrained(
        cls,
        repo_id: str = "Sowrabhm/fintext-extractor",
        precision: str = "fp16",
    ) -> FintextExtractor:
        """Download models from the Hugging Face Hub and load them.

        Args:
            repo_id: Hugging Face model repo ID.
            precision: ``"fp16"`` or ``"fp32"``.

        Returns:
            A ready-to-use :class:`FintextExtractor`.
        """
        # Imported lazily so huggingface_hub is only required for Hub loading,
        # not for local model directories.
        from huggingface_hub import snapshot_download

        # Download only the files needed for the requested precision variant
        # (.onnx.data holds externally-stored weights, if present).
        allow = [
            f"onnx/deberta_classifier_{precision}.onnx",
            f"onnx/deberta_classifier_{precision}.onnx.data",
            f"onnx/extraction_full_{precision}.onnx",
            f"onnx/extraction_full_{precision}.onnx.data",
            "tokenizer/*",
            "tokenizer_extraction/*",
            "config.json",
        ]
        local_dir = snapshot_download(repo_id, allow_patterns=allow)
        return cls(local_dir, precision=precision)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def extract(self, text: str, received_date: str | None = None) -> dict:
        """Run the full two-stage pipeline on a single SMS text.

        Args:
            text: SMS / notification text.
            received_date: Optional fallback date in DD-MM-YYYY format, used
                when no date can be extracted from the text.

        Returns:
            dict with keys: ``is_transaction``, ``transaction_amount``,
            ``transaction_type``, ``transaction_date``,
            ``transaction_description``, ``masked_account_digits``.
        """
        # Stage 1: cheap classification gate — skip extraction entirely for
        # non-transaction messages.
        cls_result = self.classify(text)
        if not cls_result["is_transaction"]:
            return self._empty_result()

        # Stage 2: structured field extraction.
        return self._extract_fields(text, received_date)

    def classify(self, text: str) -> dict:
        """Run classification only (stage 1).

        Args:
            text: SMS / notification text.

        Returns:
            dict with ``is_transaction`` (bool) and ``confidence`` (float),
            where confidence is the probability of the predicted class.
        """
        # Tokenizer is configured in __init__ to pad/truncate to 128 tokens.
        encoded = self._cls_tokenizer.encode(text)
        input_ids = np.array([encoded.ids], dtype=np.int64)
        attention_mask = np.array([encoded.attention_mask], dtype=np.int64)

        outputs = self._cls_session.run(
            None,
            {"input_ids": input_ids, "attention_mask": attention_mask},
        )

        # Shape [2]: raw logits for [non-transaction, transaction].
        logits = outputs[0][0]

        # Numerically stable softmax (subtract max before exponentiating).
        exp_logits = np.exp(logits - np.max(logits))
        probs = exp_logits / exp_logits.sum()

        is_transaction = bool(probs[1] > 0.5)
        # Report the probability of whichever class was predicted.
        confidence = float(probs[1]) if is_transaction else float(probs[0])

        return {"is_transaction": is_transaction, "confidence": confidence}

    def extract_batch(
        self, texts: list[str], received_date: str | None = None
    ) -> list[dict]:
        """Run extraction on multiple texts sequentially.

        Args:
            texts: List of SMS / notification texts.
            received_date: Optional fallback date applied to every text.

        Returns:
            List of extraction result dicts, one per input text.
        """
        return [self.extract(t, received_date) for t in texts]

    # ------------------------------------------------------------------
    # Internals
    # ------------------------------------------------------------------

    @staticmethod
    def _empty_result() -> dict:
        """Result dict for a message classified as a non-transaction.

        Keeps the key schema in one place so it cannot drift from the dict
        built by _extract_fields.
        """
        return {
            "is_transaction": False,
            "transaction_amount": None,
            "transaction_type": None,
            "transaction_date": None,
            "transaction_description": None,
            "masked_account_digits": None,
        }

    def _extract_fields(self, text: str, received_date: str | None = None) -> dict:
        """Stage 2: extract transaction fields using the extraction model.

        Args:
            text: SMS / notification text already classified as a transaction.
            received_date: Optional fallback date used when no date span is
                found in the text.

        Returns:
            Result dict with the same keys as :meth:`extract`.
        """
        # Split text into words with character spans so decoded word-level
        # spans can be mapped back to the original text.
        word_info = split_into_words(text)
        words = [w for w, _, _ in word_info]
        word_spans = [(s, e) for _, s, e in word_info]
        num_words = len(words)
        text_words_lower = [w.lower() for w in words]

        # Model input is the schema prompt tokens followed by the text words.
        combined_tokens = SCHEMA_TOKENS + text_words_lower
        schema_len = len(SCHEMA_TOKENS)

        # Subword-tokenize each combined token. words_mask marks the first
        # subword of each TEXT word with its 1-indexed word number; all other
        # subwords (and all schema subwords) get 0, so the model can pool
        # subwords back into word-level representations.
        all_subword_ids: list[int] = []
        words_mask_values: list[int] = []

        for i, token in enumerate(combined_tokens):
            encoded = self._ext_tokenizer.encode(token, add_special_tokens=False)
            subword_ids = encoded.ids
            all_subword_ids.extend(subword_ids)

            if i >= schema_len:
                word_number = i - schema_len + 1
                words_mask_values.append(word_number)
                words_mask_values.extend([0] * (len(subword_ids) - 1))
            else:
                words_mask_values.extend([0] * len(subword_ids))

        # Truncate to the model's 512-subword limit.
        # NOTE(review): text_lengths below still reports the full num_words
        # even when truncation drops trailing words — confirm whether the
        # model tolerates text_lengths exceeding the words present in
        # words_mask, or whether it should be recomputed post-truncation.
        max_len = 512
        seq_len = min(len(all_subword_ids), max_len)

        input_ids = np.array([all_subword_ids[:seq_len]], dtype=np.int64)
        attention_mask = np.ones((1, seq_len), dtype=np.int64)
        words_mask = np.array([words_mask_values[:seq_len]], dtype=np.int64)
        text_lengths = np.array([num_words], dtype=np.int64)

        outputs = self._ext_session.run(
            None,
            {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "words_mask": words_mask,
                "text_lengths": text_lengths,
            },
        )

        # Shape [2]: per-class scores for [DEBIT, CREDIT]; only argmax is
        # used, so whether they are logits or softmax probs does not matter.
        type_logits = outputs[0][0]
        # Shape [4, num_words, max_width]: span scores per extraction field.
        span_scores = outputs[1][0]

        # Decode transaction type and entity spans.
        transaction_type = CLASSIFICATION_LABELS[int(np.argmax(type_logits))]
        spans = decode_spans(span_scores, text, words, word_spans)

        # Post-process fields; each span value is a list whose first element
        # is the best-scoring match.
        raw_amount = spans.get("transaction_amount")
        raw_date = spans.get("transaction_date")
        raw_desc = spans.get("transaction_description")
        raw_digits = spans.get("masked_account_digits")

        amount = parse_amount(raw_amount[0]) if raw_amount else None
        # Fall back to received_date both when no date span was found and as
        # the reference date for normalizing relative/partial dates.
        date = normalize_date(raw_date[0], received_date) if raw_date else received_date
        description = raw_desc[0] if raw_desc else None
        digits = raw_digits[0] if raw_digits else None

        # Validate: a parseable amount is required for a valid transaction.
        # (transaction_type comes from argmax and is checked defensively in
        # case CLASSIFICATION_LABELS maps an index to None.)
        is_transaction = amount is not None and transaction_type is not None

        return {
            "is_transaction": is_transaction,
            "transaction_amount": amount,
            "transaction_type": transaction_type if is_transaction else None,
            "transaction_date": date,
            "transaction_description": description,
            "masked_account_digits": digits,
        }