File size: 4,844 Bytes
51454d1
 
 
 
 
 
 
 
 
 
6f2deaf
 
51454d1
 
 
 
 
 
 
 
 
 
 
 
6f2deaf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51454d1
6f2deaf
 
 
 
 
 
 
 
 
 
 
 
 
51454d1
 
 
6f2deaf
51454d1
 
 
 
 
 
 
 
 
 
6f2deaf
 
 
 
 
 
 
 
 
 
 
 
51454d1
6f2deaf
51454d1
 
 
 
 
 
6f2deaf
 
51454d1
 
 
 
 
 
 
 
6f2deaf
 
51454d1
6f2deaf
 
 
 
 
51454d1
6f2deaf
 
 
 
 
 
51454d1
6f2deaf
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import numpy as np
import joblib
import torch
from transformers import AutoModel
import os

class FinancialFilingClassifier:
    """Two-stage cascade classifier for financial filings.

    Stage 1 ("router"): the document is embedded with Jina-v3, a
    log-length feature is appended, and an XGBoost router predicts a
    broad category.
    Stage 2 ("specialist"): a per-category XGBoost model refines the
    prediction to a specific document type.  Specialists are lazily
    loaded from ``model_dir`` on first use and cached.
    """

    def __init__(self, model_dir):
        """Load the Jina encoder and the XGBoost router models.

        Args:
            model_dir (str): Directory containing ``router_xgb.pkl``,
                ``router_le.pkl`` and the optional per-category
                ``specialist_<name>_xgb.pkl`` / ``specialist_<name>_le.pkl``
                pickles.
        """
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Loading Jina Encoder on {self.device}...")

        # Jina-v3 handles Flash Attention internally if installed.
        # fp16 on GPU halves memory; fp32 on CPU avoids unsupported half ops.
        self.encoder = AutoModel.from_pretrained(
            "jinaai/jina-embeddings-v3", 
            trust_remote_code=True, 
            torch_dtype=torch.float16 if self.device == 'cuda' else torch.float32
        ).to(self.device)

        print("Loading XGBoost Cascade...")
        self.router = joblib.load(os.path.join(model_dir, "router_xgb.pkl"))
        self.router_le = joblib.load(os.path.join(model_dir, "router_le.pkl"))
        # Lazy-load cache: safe_name -> (clf, le), or None when the
        # specialist files are known to be missing (negative cache).
        self.specialists = {}
        self.model_dir = model_dir

    def _get_vector(self, text, doc_length=None):
        """
        Generates the feature vector: [Embedding (1024) + Log_Length (1)]
        
        Args:
            text (str): The document text content.
            doc_length (int, optional): The real length of the document. 
                                      If None, defaults to len(text).

        Returns:
            np.ndarray: shape (1, embedding_dim + 1) feature row.
        """
        # 1. Feature Engineering: Log Length
        # We use the explicitly passed length if available. This allows
        # the model to know a document is massive (e.g. Annual Report) 
        # even if we only run inference on the first 8k tokens.
        val = int(doc_length) if doc_length is not None else len(str(text))
        log_len = np.log1p(val)

        with torch.no_grad():
            # 2. Embedding Generation
            # SAFETY MARGIN: We cap at 8100 (instead of 8192) to prevent 
            # "Rotary Embedding" off-by-one crashes in Flash Attention.
            vec = self.encoder.encode([text], task="classification", max_length=8100)
            
            # 3. XGBoost Compatibility
            # XGBoost requires CPU-bound Numpy arrays.
            if isinstance(vec, torch.Tensor):
                vec = vec.cpu().numpy()
            elif isinstance(vec, list):
                vec = np.array(vec)

        # 4. Concatenate: [Embedding vector, Log_Length]
        return np.hstack([vec, [[log_len]]])

    def _load_specialist(self, category):
        """Lazily load the specialist model pair for *category*.

        Returns:
            tuple | None: ``(classifier, label_encoder)``, or ``None`` when
            no specialist exists for this category.  Both hits and misses
            are cached so the filesystem is probed at most once per
            category.
        """
        safe_name = category.replace(" ", "_").replace("&", "and").replace("/", "_")
        if safe_name not in self.specialists:
            try:
                clf = joblib.load(os.path.join(self.model_dir, f"specialist_{safe_name}_xgb.pkl"))
                le = joblib.load(os.path.join(self.model_dir, f"specialist_{safe_name}_le.pkl"))
                self.specialists[safe_name] = (clf, le)
            except FileNotFoundError:
                # BUGFIX: cache the miss too. Previously a missing
                # specialist was never cached, so every prediction for a
                # specialist-less category re-attempted the disk loads.
                self.specialists[safe_name] = None
        return self.specialists[safe_name]

    def predict(self, text, doc_length=None):
        """
        Predicts the category and type of a financial document.
        
        Args:
            text (str): The document text.
            doc_length (int, optional): The true character length of the document.
                                      Recommended for highest accuracy.

        Returns:
            dict: ``{"category": str, "label": str, "score": float}`` for
            the highest-scoring of the top-2 router candidates.
        """
        vector = self._get_vector(text, doc_length=doc_length)
        
        # 1. Router Prediction (General Category)
        router_probs = self.router.predict_proba(vector)[0]
        # We look at the top 2 candidates to handle ambiguous edge cases
        top_indices = np.argsort(router_probs)[::-1][:2]
        
        candidates = []
        for idx in top_indices:
            category = self.router_le.classes_[idx]
            router_conf = router_probs[idx]
            
            # 2. Specialist Prediction (Specific Type)
            specialist = self._load_specialist(category)
            
            if specialist is not None:
                clf, le = specialist
                spec_probs = clf.predict_proba(vector)[0]
                best_idx = np.argmax(spec_probs)
                label = le.classes_[best_idx]
                spec_conf = spec_probs[best_idx]
                
                # Combine Confidence Scores (Geometric Mean) so a weak
                # link in either stage drags the overall score down.
                combined_score = np.sqrt(router_conf * spec_conf)
                candidates.append({
                    "category": category, 
                    "label": label, 
                    "score": float(combined_score)
                })
            else:
                # Fallback if no specialist exists for this category:
                # the router's category doubles as the label.
                candidates.append({
                    "category": category, 
                    "label": category, 
                    "score": float(router_conf)
                })
        
        # Return the highest scoring candidate
        return max(candidates, key=lambda x: x['score'])