tejashsr committed on
Commit
d27a216
·
verified ·
1 Parent(s): 8172812

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +33 -40
model.py CHANGED
@@ -5,34 +5,32 @@ import numpy as np
5
  import spacy
6
  import torch
7
  from flashrank import Ranker, RerankRequest
8
- from sentence_transformers import SentenceTransformer
9
  from rank_bm25 import BM25Okapi
10
  from sklearn.metrics.pairwise import cosine_similarity
11
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
12
 
13
- # --- GLOBAL ENGINES (LAZY LOAD) ---
14
  nlp = None
15
  retriever = None
16
  ranker = None
17
- tokenizer = None
18
  nli_model = None
19
 
20
  def load_engines():
21
- global nlp, retriever, ranker, tokenizer, nli_model
22
  if nlp is not None: return
23
 
24
- print("⚡ TITANIUM: Waking up Neural Engines...")
25
  nlp = spacy.load("en_core_web_sm", disable=["parser"])
26
 
27
- # 1. Retrieval Engine
28
  retriever = SentenceTransformer('all-MiniLM-L6-v2')
29
 
30
- # 2. Rerank Engine
31
  ranker = Ranker(model_name="ms-marco-TinyBERT-L-2-v2", cache_dir="/app/cache")
32
 
33
- # 3. Logic Engine (DeBERTa-v3)
34
- tokenizer = AutoTokenizer.from_pretrained("cross-encoder/nli-deberta-v3-base")
35
- nli_model = AutoModelForSequenceClassification.from_pretrained("cross-encoder/nli-deberta-v3-base")
36
  print("✅ TITANIUM: Ready.")
37
 
38
  # --- UNIVERSAL KNOWLEDGE GRAPH ---
@@ -52,8 +50,6 @@ class UniversalGraphKB:
52
 
53
  def ingest_book(self, text, key="session"):
54
  chunks = self.get_chunks(text)
55
-
56
- # Auto-Protagonist Detection
57
  doc = nlp(text[:100000])
58
  names = [ent.text.lower() for ent in doc.ents if ent.label_ == "PERSON"]
59
  main_char = pd.Series(names).value_counts().index[0] if names else "Unknown"
@@ -70,7 +66,6 @@ kb = UniversalGraphKB()
70
 
71
  # --- TITANIUM LOGIC GUARDRAILS ---
72
  def normalize_dates(text):
73
- """Visual Confirmation: Turns words to numbers for the Logic Engine."""
74
  text = text.lower()
75
  mapping = {
76
  "eighteenth": "1750", "18th": "1750",
@@ -82,15 +77,8 @@ def normalize_dates(text):
82
  if word in text: text += f" ({year}) "
83
  return text
84
 
85
- def get_nli_score(premise, hypothesis):
86
- inputs = tokenizer(premise, hypothesis, return_tensors='pt', truncation=True, max_length=512)
87
- with torch.no_grad():
88
- outputs = nli_model(**inputs)
89
- probs = torch.softmax(outputs.logits, dim=1).cpu().numpy()[0]
90
- return float(probs[1]) # Entailment
91
-
92
  def extract_features(backstory, key="session"):
93
- if key not in kb.indices: return [0,0], "", None, "Book not uploaded"
94
  idx = kb.indices[key]
95
  protagonist = kb.protagonists.get(key, "")
96
 
@@ -98,50 +86,55 @@ def extract_features(backstory, key="session"):
98
  backstory_norm = normalize_dates(backstory)
99
  aug_query = f"{backstory} (Context: {protagonist})"
100
 
101
- # 2. Search & Rerank
102
  q_vec = retriever.encode(aug_query)
103
  v_scores = cosine_similarity([q_vec], idx['vectors'])[0]
104
- candidates = list(v_scores.argsort()[-30:][::-1])
105
  passages = [{"id": i, "text": idx['text'][i]} for i in candidates]
106
 
 
107
  results = ranker.rerank(RerankRequest(query=backstory, passages=passages))
108
  best_chunk = results[0]['text']
109
  best_chunk_norm = normalize_dates(best_chunk)
110
 
111
- # --- GUARDRAILS (The "Visual Confirmation") ---
112
 
113
  # A. Exact Match
114
  if backstory.strip() in best_chunk:
115
- return [1.0, 0], best_chunk, 1, "VERIFIED: Exact Text Match"
116
 
117
- # B. Math Timeline Guardrail
118
  YEAR_PATTERN = r'\b([1-2][0-9]{3})\b'
119
  q_years = [int(y) for y in re.findall(YEAR_PATTERN, backstory_norm)]
120
  e_years = [int(y) for y in re.findall(YEAR_PATTERN, best_chunk_norm)]
121
 
122
  if q_years and e_years:
123
  if not any(abs(by - ey) < 5 for by in q_years for ey in e_years):
124
- return [0.0, 1], best_chunk, 0, f"CRITICAL: Timeline Mismatch ({q_years[0]} vs {e_years[0]})"
125
 
126
- # C. Neural Semantic Check
127
- score = get_nli_score(aug_query, best_chunk)
128
- return [score, 0], best_chunk, None, ""
 
 
 
 
 
 
 
129
 
130
  # --- API WRAPPER ---
131
  def predict_logic(book_text, backstory):
132
  load_engines()
133
  kb.ingest_book(book_text, "session")
134
- feats, ev, verdict, rat = extract_features(backstory, "session")
135
 
136
- # Guardrail Triggered
137
- if verdict is not None:
138
- return {"prediction": "Consistent" if verdict==1 else "Contradiction", "rationale": rat, "evidence": ev[:350] + "...", "score": 1.0 if verdict==1 else 0.0}
139
 
140
- # Neural Decision (Threshold 0.15)
141
- pred = 1 if feats[0] > 0.15 else 0
142
  return {
143
- "prediction": "Consistent" if pred==1 else "Contradiction",
144
- "rationale": f"Semantic Consistency Score: {feats[0]:.2f}",
145
  "evidence": ev[:350] + "...",
146
- "score": round(feats[0], 2)
147
  }
 
5
  import spacy
6
  import torch
7
  from flashrank import Ranker, RerankRequest
8
+ from sentence_transformers import SentenceTransformer, CrossEncoder
9
  from rank_bm25 import BM25Okapi
10
  from sklearn.metrics.pairwise import cosine_similarity
 
11
 
12
# --- GLOBAL ENGINES (lazy-loaded on first request) ---
nlp = None
retriever = None
ranker = None
nli_model = None

def load_engines():
    """Load the spaCy pipeline and all neural models exactly once.

    Idempotent: if ``nlp`` is already populated, every engine is assumed
    loaded and the call returns immediately.
    """
    global nlp, retriever, ranker, nli_model
    if nlp is not None:
        return

    print("⚡ TITANIUM: Waking up...")
    # spaCy is only needed for NER downstream; disabling the parser
    # keeps startup and inference fast.
    nlp = spacy.load("en_core_web_sm", disable=["parser"])

    # 1. Retrieval: MiniLM sentence-embedding model.
    retriever = SentenceTransformer('all-MiniLM-L6-v2')

    # 2. Rerank: FlashRank cross-encoder.
    ranker = Ranker(model_name="ms-marco-TinyBERT-L-2-v2", cache_dir="/app/cache")

    # 3. Logic: NLI CrossEncoder — the wrapper maps labels automatically.
    nli_model = CrossEncoder('cross-encoder/nli-deberta-v3-base')
    print("✅ TITANIUM: Ready.")
35
 
36
  # --- UNIVERSAL KNOWLEDGE GRAPH ---
 
50
 
51
  def ingest_book(self, text, key="session"):
52
  chunks = self.get_chunks(text)
 
 
53
  doc = nlp(text[:100000])
54
  names = [ent.text.lower() for ent in doc.ents if ent.label_ == "PERSON"]
55
  main_char = pd.Series(names).value_counts().index[0] if names else "Unknown"
 
66
 
67
  # --- TITANIUM LOGIC GUARDRAILS ---
68
  def normalize_dates(text):
 
69
  text = text.lower()
70
  mapping = {
71
  "eighteenth": "1750", "18th": "1750",
 
77
  if word in text: text += f" ({year}) "
78
  return text
79
 
 
 
 
 
 
 
 
80
def extract_features(backstory, key="session"):
    """Score a backstory claim against the ingested book.

    Pipeline: dense retrieval (MiniLM) -> cross-encoder rerank (FlashRank)
    -> rule guardrails (exact match, year arithmetic) -> NLI entailment.

    Args:
        backstory: The claim to verify against the book text.
        key: Knowledge-base session key used by ``kb``.

    Returns:
        ``(score, evidence, reason)`` — score in [0, 1] (1.0 = consistent,
        0.0 = contradiction), the best supporting chunk, and a short label
        naming the stage that produced the verdict.
    """
    if key not in kb.indices:
        return 0.0, "", "Book not uploaded"
    idx = kb.indices[key]
    protagonist = kb.protagonists.get(key, "")

    # 1. Normalize date words and augment the query with the protagonist.
    backstory_norm = normalize_dates(backstory)
    aug_query = f"{backstory} (Context: {protagonist})"

    # 2. Dense retrieval: top-15 chunks by cosine similarity.
    q_vec = retriever.encode(aug_query)
    v_scores = cosine_similarity([q_vec], idx['vectors'])[0]
    candidates = list(v_scores.argsort()[-15:][::-1])
    passages = [{"id": i, "text": idx['text'][i]} for i in candidates]

    # 3. Rerank candidates with FlashRank; keep the single best chunk.
    results = ranker.rerank(RerankRequest(query=backstory, passages=passages))
    best_chunk = results[0]['text']
    best_chunk_norm = normalize_dates(best_chunk)

    # --- GUARDRAILS ---

    # A. Exact match: verbatim inclusion is definitive evidence.
    if backstory.strip() in best_chunk:
        return 1.0, best_chunk, "VERIFIED: Exact Text Match"

    # B. Timeline check: if both texts mention 4-digit years and no pair is
    # within 5 years of each other, call it a hard contradiction.
    YEAR_PATTERN = r'\b([1-2][0-9]{3})\b'
    q_years = [int(y) for y in re.findall(YEAR_PATTERN, backstory_norm)]
    e_years = [int(y) for y in re.findall(YEAR_PATTERN, best_chunk_norm)]

    if q_years and e_years:
        if not any(abs(by - ey) < 5 for by in q_years for ey in e_years):
            return 0.0, best_chunk, f"TIMELINE MISMATCH: {q_years[0]} vs {e_years[0]}"

    # C. Neural semantic check (CrossEncoder NLI).
    # Raw logits, ordered [contradiction, entailment, neutral] per the
    # model card — NOTE(review): confirm against the loaded checkpoint.
    scores = nli_model.predict([(aug_query, best_chunk)])[0]

    # Numerically stable softmax: subtracting the max logit prevents
    # np.exp overflow for large logits without changing the result.
    shifted = scores - np.max(scores)
    exp_scores = np.exp(shifted)
    probs = exp_scores / np.sum(exp_scores)
    entailment_score = probs[1]  # probability mass on "entailment"

    return float(entailment_score), best_chunk, "SEMANTIC ANALYSIS"
125
 
126
  # --- API WRAPPER ---
127
  def predict_logic(book_text, backstory):
128
  load_engines()
129
  kb.ingest_book(book_text, "session")
130
+ score, ev, reason = extract_features(backstory, "session")
131
 
132
+ # Decision Threshold
133
+ pred = "Consistent" if score > 0.3 else "Contradiction"
 
134
 
 
 
135
  return {
136
+ "prediction": pred,
137
+ "rationale": reason,
138
  "evidence": ev[:350] + "...",
139
+ "score": round(score, 2)
140
  }