Update inference_wrapper.py
Browse files- inference_wrapper.py +62 -9
inference_wrapper.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
-
|
| 2 |
import numpy as np
|
| 3 |
import joblib
|
| 4 |
import torch
|
|
@@ -9,6 +8,8 @@ class FinancialFilingClassifier:
|
|
| 9 |
def __init__(self, model_dir):
|
| 10 |
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 11 |
print(f"Loading Jina Encoder on {self.device}...")
|
|
|
|
|
|
|
| 12 |
self.encoder = AutoModel.from_pretrained(
|
| 13 |
"jinaai/jina-embeddings-v3",
|
| 14 |
trust_remote_code=True,
|
|
@@ -21,13 +22,40 @@ class FinancialFilingClassifier:
|
|
| 21 |
self.specialists = {}
|
| 22 |
self.model_dir = model_dir
|
| 23 |
|
| 24 |
-
def _get_vector(self, text):
|
| 25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
with torch.no_grad():
|
| 27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
return np.hstack([vec, [[log_len]]])
|
| 29 |
|
| 30 |
def _load_specialist(self, category):
|
|
|
|
| 31 |
safe_name = category.replace(" ", "_").replace("&", "and").replace("/", "_")
|
| 32 |
if safe_name not in self.specialists:
|
| 33 |
try:
|
|
@@ -38,15 +66,28 @@ class FinancialFilingClassifier:
|
|
| 38 |
return None
|
| 39 |
return self.specialists[safe_name]
|
| 40 |
|
| 41 |
-
def predict(self, text):
|
| 42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
router_probs = self.router.predict_proba(vector)[0]
|
|
|
|
| 44 |
top_indices = np.argsort(router_probs)[::-1][:2]
|
| 45 |
|
| 46 |
candidates = []
|
| 47 |
for idx in top_indices:
|
| 48 |
category = self.router_le.classes_[idx]
|
| 49 |
router_conf = router_probs[idx]
|
|
|
|
|
|
|
| 50 |
specialist = self._load_specialist(category)
|
| 51 |
|
| 52 |
if specialist:
|
|
@@ -55,9 +96,21 @@ class FinancialFilingClassifier:
|
|
| 55 |
best_idx = np.argmax(spec_probs)
|
| 56 |
label = le.classes_[best_idx]
|
| 57 |
spec_conf = spec_probs[best_idx]
|
|
|
|
|
|
|
| 58 |
combined_score = np.sqrt(router_conf * spec_conf)
|
| 59 |
-
candidates.append({
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
else:
|
| 61 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
-
|
|
|
|
|
|
|
|
|
| 1 |
import numpy as np
|
| 2 |
import joblib
|
| 3 |
import torch
|
|
|
|
| 8 |
def __init__(self, model_dir):
|
| 9 |
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 10 |
print(f"Loading Jina Encoder on {self.device}...")
|
| 11 |
+
|
| 12 |
+
# Jina-v3 handles Flash Attention internally if installed
|
| 13 |
self.encoder = AutoModel.from_pretrained(
|
| 14 |
"jinaai/jina-embeddings-v3",
|
| 15 |
trust_remote_code=True,
|
|
|
|
| 22 |
self.specialists = {}
|
| 23 |
self.model_dir = model_dir
|
| 24 |
|
| 25 |
+
def _get_vector(self, text, doc_length=None):
|
| 26 |
+
"""
|
| 27 |
+
Generates the feature vector: [Embedding (1024) + Log_Length (1)]
|
| 28 |
+
|
| 29 |
+
Args:
|
| 30 |
+
text (str): The document text content.
|
| 31 |
+
doc_length (int, optional): The real length of the document.
|
| 32 |
+
If None, defaults to len(text).
|
| 33 |
+
"""
|
| 34 |
+
# 1. Feature Engineering: Log Length
|
| 35 |
+
# We use the explicitly passed length if available. This allows
|
| 36 |
+
# the model to know a document is massive (e.g. Annual Report)
|
| 37 |
+
# even if we only run inference on the first 8k tokens.
|
| 38 |
+
val = int(doc_length) if doc_length is not None else len(str(text))
|
| 39 |
+
log_len = np.log1p(val)
|
| 40 |
+
|
| 41 |
with torch.no_grad():
|
| 42 |
+
# 2. Embedding Generation
|
| 43 |
+
# SAFETY MARGIN: We cap at 8100 (instead of 8192) to prevent
|
| 44 |
+
# "Rotary Embedding" off-by-one crashes in Flash Attention.
|
| 45 |
+
vec = self.encoder.encode([text], task="classification", max_length=8100)
|
| 46 |
+
|
| 47 |
+
# 3. XGBoost Compatibility
|
| 48 |
+
# XGBoost requires CPU-bound Numpy arrays.
|
| 49 |
+
if isinstance(vec, torch.Tensor):
|
| 50 |
+
vec = vec.cpu().numpy()
|
| 51 |
+
elif isinstance(vec, list):
|
| 52 |
+
vec = np.array(vec)
|
| 53 |
+
|
| 54 |
+
# 4. Concatenate: [Embedding vector, Log_Length]
|
| 55 |
return np.hstack([vec, [[log_len]]])
|
| 56 |
|
| 57 |
def _load_specialist(self, category):
|
| 58 |
+
"""Lazy loads specialist models to save RAM until needed."""
|
| 59 |
safe_name = category.replace(" ", "_").replace("&", "and").replace("/", "_")
|
| 60 |
if safe_name not in self.specialists:
|
| 61 |
try:
|
|
|
|
| 66 |
return None
|
| 67 |
return self.specialists[safe_name]
|
| 68 |
|
| 69 |
+
def predict(self, text, doc_length=None):
|
| 70 |
+
"""
|
| 71 |
+
Predicts the category and type of a financial document.
|
| 72 |
+
|
| 73 |
+
Args:
|
| 74 |
+
text (str): The document text.
|
| 75 |
+
doc_length (int, optional): The true character length of the document.
|
| 76 |
+
Recommended for highest accuracy.
|
| 77 |
+
"""
|
| 78 |
+
vector = self._get_vector(text, doc_length=doc_length)
|
| 79 |
+
|
| 80 |
+
# 1. Router Prediction (General Category)
|
| 81 |
router_probs = self.router.predict_proba(vector)[0]
|
| 82 |
+
# We look at the top 2 candidates to handle ambiguous edge cases
|
| 83 |
top_indices = np.argsort(router_probs)[::-1][:2]
|
| 84 |
|
| 85 |
candidates = []
|
| 86 |
for idx in top_indices:
|
| 87 |
category = self.router_le.classes_[idx]
|
| 88 |
router_conf = router_probs[idx]
|
| 89 |
+
|
| 90 |
+
# 2. Specialist Prediction (Specific Type)
|
| 91 |
specialist = self._load_specialist(category)
|
| 92 |
|
| 93 |
if specialist:
|
|
|
|
| 96 |
best_idx = np.argmax(spec_probs)
|
| 97 |
label = le.classes_[best_idx]
|
| 98 |
spec_conf = spec_probs[best_idx]
|
| 99 |
+
|
| 100 |
+
# Combine Confidence Scores (Geometric Mean)
|
| 101 |
combined_score = np.sqrt(router_conf * spec_conf)
|
| 102 |
+
candidates.append({
|
| 103 |
+
"category": category,
|
| 104 |
+
"label": label,
|
| 105 |
+
"score": float(combined_score)
|
| 106 |
+
})
|
| 107 |
else:
|
| 108 |
+
# Fallback if no specialist exists for this category
|
| 109 |
+
candidates.append({
|
| 110 |
+
"category": category,
|
| 111 |
+
"label": category,
|
| 112 |
+
"score": float(router_conf)
|
| 113 |
+
})
|
| 114 |
|
| 115 |
+
# Return the highest scoring candidate
|
| 116 |
+
return max(candidates, key=lambda x: x['score'])
|