import os

import joblib
import numpy as np
import torch
from transformers import AutoModel


class FinancialFilingClassifier:
    """Two-stage cascade classifier for financial filings.

    Stage 1: an XGBoost "router" predicts a coarse document category from a
    Jina-v3 text embedding augmented with a log-length feature.
    Stage 2: a per-category XGBoost "specialist" (lazily loaded from disk)
    refines the prediction to a specific document type.  The two confidence
    scores are combined via a geometric mean.
    """

    def __init__(self, model_dir):
        """Load the encoder and router; specialists are loaded lazily.

        Args:
            model_dir (str): Directory containing ``router_xgb.pkl``,
                ``router_le.pkl`` and the per-category specialist pickles
                (``specialist_<name>_xgb.pkl`` / ``specialist_<name>_le.pkl``).
        """
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Loading Jina Encoder on {self.device}...")
        # Jina-v3 handles Flash Attention internally if installed
        self.encoder = AutoModel.from_pretrained(
            "jinaai/jina-embeddings-v3",
            trust_remote_code=True,
            torch_dtype=torch.float16 if self.device == 'cuda' else torch.float32
        ).to(self.device)
        print("Loading XGBoost Cascade...")
        self.router = joblib.load(os.path.join(model_dir, "router_xgb.pkl"))
        self.router_le = joblib.load(os.path.join(model_dir, "router_le.pkl"))
        # Lazy cache of specialist models keyed by sanitized category name.
        # A value of None marks a category known to have no specialist on
        # disk, so a missing file is probed at most once (see
        # _load_specialist).
        self.specialists = {}
        self.model_dir = model_dir

    def _get_vector(self, text, doc_length=None):
        """Generate the feature vector: [Embedding (1024) + Log_Length (1)].

        Args:
            text (str): The document text content.
            doc_length (int, optional): The real length of the document.
                If None, defaults to len(text).

        Returns:
            np.ndarray: Shape ``(1, embedding_dim + 1)`` feature matrix.
        """
        # 1. Feature Engineering: Log Length
        # We use the explicitly passed length if available. This allows
        # the model to know a document is massive (e.g. Annual Report)
        # even if we only run inference on the first 8k tokens.
        val = int(doc_length) if doc_length is not None else len(str(text))
        log_len = np.log1p(val)

        with torch.no_grad():
            # 2. Embedding Generation
            # SAFETY MARGIN: We cap at 8100 (instead of 8192) to prevent
            # "Rotary Embedding" off-by-one crashes in Flash Attention.
            vec = self.encoder.encode([text], task="classification", max_length=8100)

        # 3. XGBoost Compatibility
        # XGBoost requires CPU-bound Numpy arrays.
        if isinstance(vec, torch.Tensor):
            vec = vec.cpu().numpy()
        elif isinstance(vec, list):
            vec = np.array(vec)

        # 4. Concatenate: [Embedding vector, Log_Length]
        # atleast_2d guards against encode() returning a flat (dim,) vector,
        # which would otherwise make hstack raise a shape mismatch.
        vec = np.atleast_2d(vec)
        return np.hstack([vec, [[log_len]]])

    def _load_specialist(self, category):
        """Lazily load the (classifier, label_encoder) pair for a category.

        Both hits and misses are cached: a category whose pickle files do
        not exist is recorded as None so the filesystem is not re-probed on
        every prediction.

        Returns:
            tuple | None: ``(clf, label_encoder)`` or None if no specialist
            model exists for this category.
        """
        safe_name = category.replace(" ", "_").replace("&", "and").replace("/", "_")
        if safe_name not in self.specialists:
            try:
                clf = joblib.load(os.path.join(self.model_dir, f"specialist_{safe_name}_xgb.pkl"))
                le = joblib.load(os.path.join(self.model_dir, f"specialist_{safe_name}_le.pkl"))
                self.specialists[safe_name] = (clf, le)
            except FileNotFoundError:
                # Cache the miss; callers still receive None as before.
                self.specialists[safe_name] = None
        return self.specialists[safe_name]

    def predict(self, text, doc_length=None):
        """Predict the category and type of a financial document.

        Args:
            text (str): The document text.
            doc_length (int, optional): The true character length of the
                document. Recommended for highest accuracy.

        Returns:
            dict: ``{"category": str, "label": str, "score": float}`` for
            the highest-scoring candidate.
        """
        vector = self._get_vector(text, doc_length=doc_length)

        # 1. Router Prediction (General Category)
        router_probs = self.router.predict_proba(vector)[0]
        # We look at the top 2 candidates to handle ambiguous edge cases
        top_indices = np.argsort(router_probs)[::-1][:2]

        candidates = []
        for idx in top_indices:
            category = self.router_le.classes_[idx]
            router_conf = router_probs[idx]

            # 2. Specialist Prediction (Specific Type)
            specialist = self._load_specialist(category)
            if specialist:
                clf, le = specialist
                spec_probs = clf.predict_proba(vector)[0]
                best_idx = np.argmax(spec_probs)
                label = le.classes_[best_idx]
                spec_conf = spec_probs[best_idx]

                # Combine Confidence Scores (Geometric Mean)
                combined_score = np.sqrt(router_conf * spec_conf)
                candidates.append({
                    "category": category,
                    "label": label,
                    "score": float(combined_score)
                })
            else:
                # Fallback if no specialist exists for this category
                candidates.append({
                    "category": category,
                    "label": category,
                    "score": float(router_conf)
                })

        # Return the highest scoring candidate
        return max(candidates, key=lambda x: x['score'])