|
|
import numpy as np |
|
|
import joblib |
|
|
import torch |
|
|
from transformers import AutoModel |
|
|
import os |
|
|
|
|
|
class FinancialFilingClassifier:
    """Two-stage cascade classifier for financial filings.

    Stage 1 (router): an XGBoost model predicts the document's top-level
    category from a Jina embedding concatenated with a log-length feature.
    Stage 2 (specialist): a per-category XGBoost model refines the category
    into a fine-grained label. Specialists are lazy-loaded to save RAM.
    """

    def __init__(self, model_dir):
        """Load the sentence encoder and the router models.

        Args:
            model_dir (str): Directory containing ``router_xgb.pkl``,
                ``router_le.pkl`` and per-category specialist pickles.
        """
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Loading Jina Encoder on {self.device}...")

        # fp16 halves GPU memory; CPU inference stays in fp32.
        self.encoder = AutoModel.from_pretrained(
            "jinaai/jina-embeddings-v3",
            trust_remote_code=True,
            torch_dtype=torch.float16 if self.device == 'cuda' else torch.float32
        ).to(self.device)

        print("Loading XGBoost Cascade...")
        self.router = joblib.load(os.path.join(model_dir, "router_xgb.pkl"))
        self.router_le = joblib.load(os.path.join(model_dir, "router_le.pkl"))
        # Cache of lazily loaded specialists: safe_name -> (clf, le) or None.
        self.specialists = {}
        self.model_dir = model_dir

    def _get_vector(self, text, doc_length=None):
        """Generate the feature vector: [Embedding (1024) + Log_Length (1)].

        Args:
            text (str): The document text content.
            doc_length (int, optional): The real length of the document.
                If None, defaults to len(text).

        Returns:
            np.ndarray: 2D array of shape (1, embedding_dim + 1).
        """
        val = int(doc_length) if doc_length is not None else len(str(text))
        # log1p compresses the length scale and is safe for zero-length docs.
        log_len = np.log1p(val)

        with torch.no_grad():
            vec = self.encoder.encode([text], task="classification", max_length=8100)

        # Normalize the encoder output to a 2D numpy array regardless of
        # whether the backend returned a Tensor, a list, or an ndarray.
        if isinstance(vec, torch.Tensor):
            vec = vec.cpu().numpy()
        elif isinstance(vec, list):
            vec = np.array(vec)
        # Guard: some encode() paths return a 1D vector for a single input;
        # hstack with [[log_len]] requires a (1, dim) shape.
        vec = np.atleast_2d(vec)

        return np.hstack([vec, [[log_len]]])

    def _load_specialist(self, category):
        """Lazy-load the specialist (clf, label-encoder) pair for *category*.

        Returns None when no specialist model exists for the category.
        Both hits and misses are cached, so a missing model file is probed
        on disk only once per process lifetime.
        """
        safe_name = category.replace(" ", "_").replace("&", "and").replace("/", "_")
        if safe_name not in self.specialists:
            try:
                clf = joblib.load(os.path.join(self.model_dir, f"specialist_{safe_name}_xgb.pkl"))
                le = joblib.load(os.path.join(self.model_dir, f"specialist_{safe_name}_le.pkl"))
                self.specialists[safe_name] = (clf, le)
            except FileNotFoundError:
                # Fix: cache the miss. Previously every call for a category
                # without a specialist re-attempted the disk load.
                self.specialists[safe_name] = None
        return self.specialists[safe_name]

    def predict(self, text, doc_length=None):
        """Predict the category and type of a financial document.

        Args:
            text (str): The document text.
            doc_length (int, optional): The true character length of the document.
                Recommended for highest accuracy.

        Returns:
            dict: ``{"category": str, "label": str, "score": float}`` for
            the best-scoring of the router's top-2 candidate categories.
        """
        vector = self._get_vector(text, doc_length=doc_length)

        router_probs = self.router.predict_proba(vector)[0]
        # Consider the router's two most probable categories and let the
        # specialists arbitrate between them.
        top_indices = np.argsort(router_probs)[::-1][:2]

        candidates = []
        for idx in top_indices:
            category = self.router_le.classes_[idx]
            router_conf = router_probs[idx]

            specialist = self._load_specialist(category)
            if specialist is not None:
                clf, le = specialist
                spec_probs = clf.predict_proba(vector)[0]
                best_idx = np.argmax(spec_probs)
                label = le.classes_[best_idx]
                spec_conf = spec_probs[best_idx]

                # Geometric mean balances router and specialist confidence
                # so one over-confident stage cannot dominate the score.
                combined_score = np.sqrt(router_conf * spec_conf)
                candidates.append({
                    "category": category,
                    "label": label,
                    "score": float(combined_score)
                })
            else:
                # No specialist trained for this category: fall back to the
                # coarse category itself as the final label.
                candidates.append({
                    "category": category,
                    "label": category,
                    "score": float(router_conf)
                })

        return max(candidates, key=lambda x: x['score'])