# hierarchical-filing-classifier / inference_wrapper.py
# Hugging Face Hub upload by silashundhausen (commit 6f2deaf, verified).
import numpy as np
import joblib
import torch
from transformers import AutoModel
import os
class FinancialFilingClassifier:
    """Two-stage cascade classifier for financial filings.

    Stage 1 ("router"): the document is embedded with Jina-v3 and an
    XGBoost router predicts a general category.
    Stage 2 ("specialist"): if a specialist XGBoost model exists for that
    category, it predicts the specific document type.  Specialists are
    lazy-loaded from ``model_dir`` and cached (including negative results,
    so a missing specialist file is only probed once).
    """

    def __init__(self, model_dir):
        """Load the Jina encoder and the XGBoost router.

        Args:
            model_dir (str): Directory holding the pickled router
                (``router_xgb.pkl`` / ``router_le.pkl``) and any
                ``specialist_<name>_xgb.pkl`` / ``_le.pkl`` files.
        """
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Loading Jina Encoder on {self.device}...")
        # Jina-v3 handles Flash Attention internally if installed
        self.encoder = AutoModel.from_pretrained(
            "jinaai/jina-embeddings-v3",
            trust_remote_code=True,
            # fp16 halves GPU memory; CPU kernels generally need fp32
            torch_dtype=torch.float16 if self.device == 'cuda' else torch.float32
        ).to(self.device)
        print("Loading XGBoost Cascade...")
        self.router = joblib.load(os.path.join(model_dir, "router_xgb.pkl"))
        self.router_le = joblib.load(os.path.join(model_dir, "router_le.pkl"))
        # Cache: safe_name -> (clf, label_encoder), or None when no
        # specialist model exists for that category.
        self.specialists = {}
        self.model_dir = model_dir

    def _get_vector(self, text, doc_length=None):
        """
        Generates the feature vector: [Embedding (1024) + Log_Length (1)]

        Args:
            text (str): The document text content.
            doc_length (int, optional): The real length of the document.
                If None, defaults to len(text).

        Returns:
            np.ndarray: Shape (1, embedding_dim + 1) feature row suitable
            for the XGBoost models.
        """
        # 1. Feature Engineering: Log Length
        # We use the explicitly passed length if available. This allows
        # the model to know a document is massive (e.g. Annual Report)
        # even if we only run inference on the first 8k tokens.
        val = int(doc_length) if doc_length is not None else len(str(text))
        log_len = np.log1p(val)
        with torch.no_grad():
            # 2. Embedding Generation
            # SAFETY MARGIN: We cap at 8100 (instead of 8192) to prevent
            # "Rotary Embedding" off-by-one crashes in Flash Attention.
            vec = self.encoder.encode([text], task="classification", max_length=8100)
        # 3. XGBoost Compatibility
        # XGBoost requires CPU-bound Numpy arrays.
        if isinstance(vec, torch.Tensor):
            vec = vec.cpu().numpy()
        elif isinstance(vec, list):
            vec = np.array(vec)
        # 4. Concatenate: [Embedding vector, Log_Length]
        return np.hstack([vec, [[log_len]]])

    def _load_specialist(self, category):
        """Lazy-load the specialist model pair for *category*.

        Returns:
            tuple | None: ``(classifier, label_encoder)``, or ``None``
            when no specialist files exist for this category.
        """
        safe_name = category.replace(" ", "_").replace("&", "and").replace("/", "_")
        if safe_name not in self.specialists:
            try:
                clf = joblib.load(os.path.join(self.model_dir, f"specialist_{safe_name}_xgb.pkl"))
                le = joblib.load(os.path.join(self.model_dir, f"specialist_{safe_name}_le.pkl"))
                self.specialists[safe_name] = (clf, le)
            except FileNotFoundError:
                # Cache the miss: without this, every predict() call for a
                # category lacking a specialist re-hits the filesystem.
                self.specialists[safe_name] = None
        return self.specialists[safe_name]

    def predict(self, text, doc_length=None):
        """
        Predicts the category and type of a financial document.

        Args:
            text (str): The document text.
            doc_length (int, optional): The true character length of the document.
                Recommended for highest accuracy.

        Returns:
            dict: ``{"category": str, "label": str, "score": float}`` for
            the highest-scoring candidate.
        """
        vector = self._get_vector(text, doc_length=doc_length)
        # 1. Router Prediction (General Category)
        router_probs = self.router.predict_proba(vector)[0]
        # We look at the top 2 candidates to handle ambiguous edge cases
        top_indices = np.argsort(router_probs)[::-1][:2]
        candidates = []
        for idx in top_indices:
            category = self.router_le.classes_[idx]
            router_conf = router_probs[idx]
            # 2. Specialist Prediction (Specific Type)
            specialist = self._load_specialist(category)
            if specialist is not None:
                clf, le = specialist
                spec_probs = clf.predict_proba(vector)[0]
                best_idx = np.argmax(spec_probs)
                label = le.classes_[best_idx]
                spec_conf = spec_probs[best_idx]
                # Combine Confidence Scores (Geometric Mean)
                combined_score = np.sqrt(router_conf * spec_conf)
                candidates.append({
                    "category": category,
                    "label": label,
                    "score": float(combined_score)
                })
            else:
                # Fallback if no specialist exists for this category
                candidates.append({
                    "category": category,
                    "label": category,
                    "score": float(router_conf)
                })
        # Return the highest scoring candidate
        return max(candidates, key=lambda x: x['score'])