silashundhausen committed on
Commit
6f2deaf
·
verified ·
1 Parent(s): 1933e19

Update inference_wrapper.py

Browse files
Files changed (1) hide show
  1. inference_wrapper.py +62 -9
inference_wrapper.py CHANGED
@@ -1,4 +1,3 @@
1
-
2
  import numpy as np
3
  import joblib
4
  import torch
@@ -9,6 +8,8 @@ class FinancialFilingClassifier:
9
  def __init__(self, model_dir):
10
  self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
11
  print(f"Loading Jina Encoder on {self.device}...")
 
 
12
  self.encoder = AutoModel.from_pretrained(
13
  "jinaai/jina-embeddings-v3",
14
  trust_remote_code=True,
@@ -21,13 +22,40 @@ class FinancialFilingClassifier:
21
  self.specialists = {}
22
  self.model_dir = model_dir
23
 
24
- def _get_vector(self, text):
25
- log_len = np.log1p(len(str(text)))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  with torch.no_grad():
27
- vec = self.encoder.encode([text], task="classification", max_length=8192)
 
 
 
 
 
 
 
 
 
 
 
 
28
  return np.hstack([vec, [[log_len]]])
29
 
30
  def _load_specialist(self, category):
 
31
  safe_name = category.replace(" ", "_").replace("&", "and").replace("/", "_")
32
  if safe_name not in self.specialists:
33
  try:
@@ -38,15 +66,28 @@ class FinancialFilingClassifier:
38
  return None
39
  return self.specialists[safe_name]
40
 
41
- def predict(self, text):
42
- vector = self._get_vector(text)
 
 
 
 
 
 
 
 
 
 
43
  router_probs = self.router.predict_proba(vector)[0]
 
44
  top_indices = np.argsort(router_probs)[::-1][:2]
45
 
46
  candidates = []
47
  for idx in top_indices:
48
  category = self.router_le.classes_[idx]
49
  router_conf = router_probs[idx]
 
 
50
  specialist = self._load_specialist(category)
51
 
52
  if specialist:
@@ -55,9 +96,21 @@ class FinancialFilingClassifier:
55
  best_idx = np.argmax(spec_probs)
56
  label = le.classes_[best_idx]
57
  spec_conf = spec_probs[best_idx]
 
 
58
  combined_score = np.sqrt(router_conf * spec_conf)
59
- candidates.append({"category": category, "label": label, "score": float(combined_score)})
 
 
 
 
60
  else:
61
- candidates.append({"category": category, "label": category, "score": float(router_conf)})
 
 
 
 
 
62
 
63
- return max(candidates, key=lambda x: x['score'])
 
 
 
1
  import numpy as np
2
  import joblib
3
  import torch
 
8
  def __init__(self, model_dir):
9
  self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
10
  print(f"Loading Jina Encoder on {self.device}...")
11
+
12
+ # Jina-v3 handles Flash Attention internally if installed
13
  self.encoder = AutoModel.from_pretrained(
14
  "jinaai/jina-embeddings-v3",
15
  trust_remote_code=True,
 
22
  self.specialists = {}
23
  self.model_dir = model_dir
24
 
25
def _get_vector(self, text, doc_length=None):
    """Build the feature row fed to the downstream tree models.

    The row is the Jina embedding of ``text`` with one extra scalar
    column appended: log(1 + document length).

    Args:
        text (str): Document text to embed.
        doc_length (int, optional): True character length of the full
            document. Falls back to ``len(text)`` when omitted — passing
            it matters because only a truncated prefix is ever embedded.

    Returns:
        numpy.ndarray: A single-row 2-D array ``[embedding | log_length]``.
    """
    # Length feature: prefer the caller-supplied true length so a huge
    # filing still "looks" huge even though the encoder only sees the
    # first ~8k tokens of it.
    if doc_length is None:
        length = len(str(text))
    else:
        length = int(doc_length)
    log_len = np.log1p(length)

    # Embed under no_grad; max_length is capped at 8100 (rather than the
    # model's 8192) as a safety margin against rotary-embedding
    # off-by-one crashes when Flash Attention is in use.
    with torch.no_grad():
        embedding = self.encoder.encode(
            [text], task="classification", max_length=8100
        )

    # The XGBoost-style models downstream need a CPU-resident numpy
    # array, so normalise whatever container the encoder returned.
    if isinstance(embedding, torch.Tensor):
        embedding = embedding.cpu().numpy()
    elif isinstance(embedding, list):
        embedding = np.array(embedding)

    # Append the scalar length feature as the final column.
    return np.hstack([embedding, [[log_len]]])
56
 
57
  def _load_specialist(self, category):
58
+ """Lazy loads specialist models to save RAM until needed."""
59
  safe_name = category.replace(" ", "_").replace("&", "and").replace("/", "_")
60
  if safe_name not in self.specialists:
61
  try:
 
66
  return None
67
  return self.specialists[safe_name]
68
 
69
+ def predict(self, text, doc_length=None):
70
+ """
71
+ Predicts the category and type of a financial document.
72
+
73
+ Args:
74
+ text (str): The document text.
75
+ doc_length (int, optional): The true character length of the document.
76
+ Recommended for highest accuracy.
77
+ """
78
+ vector = self._get_vector(text, doc_length=doc_length)
79
+
80
+ # 1. Router Prediction (General Category)
81
  router_probs = self.router.predict_proba(vector)[0]
82
+ # We look at the top 2 candidates to handle ambiguous edge cases
83
  top_indices = np.argsort(router_probs)[::-1][:2]
84
 
85
  candidates = []
86
  for idx in top_indices:
87
  category = self.router_le.classes_[idx]
88
  router_conf = router_probs[idx]
89
+
90
+ # 2. Specialist Prediction (Specific Type)
91
  specialist = self._load_specialist(category)
92
 
93
  if specialist:
 
96
  best_idx = np.argmax(spec_probs)
97
  label = le.classes_[best_idx]
98
  spec_conf = spec_probs[best_idx]
99
+
100
+ # Combine Confidence Scores (Geometric Mean)
101
  combined_score = np.sqrt(router_conf * spec_conf)
102
+ candidates.append({
103
+ "category": category,
104
+ "label": label,
105
+ "score": float(combined_score)
106
+ })
107
  else:
108
+ # Fallback if no specialist exists for this category
109
+ candidates.append({
110
+ "category": category,
111
+ "label": category,
112
+ "score": float(router_conf)
113
+ })
114
 
115
+ # Return the highest scoring candidate
116
+ return max(candidates, key=lambda x: x['score'])