tejashsr commited on
Commit
f63de94
·
verified ·
1 Parent(s): e39f0e1

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +55 -77
model.py CHANGED
@@ -1,4 +1,5 @@
1
- import os, re, sys, subprocess
 
2
  import pandas as pd
3
  import numpy as np
4
  import spacy
@@ -7,41 +8,34 @@ from flashrank import Ranker, RerankRequest
7
  from sentence_transformers import SentenceTransformer
8
  from rank_bm25 import BM25Okapi
9
  from sklearn.metrics.pairwise import cosine_similarity
10
- from sklearn.ensemble import RandomForestClassifier
11
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
12
 
13
- # ============================
14
- # 1. LAZY LOADING GLOBALS
15
- # ============================
16
  nlp = None
17
  retriever = None
18
  ranker = None
19
  tokenizer = None
20
  nli_model = None
21
- kb = None
22
- clf = None
23
 
24
- def load_engines_if_needed():
25
- global nlp, retriever, ranker, tokenizer, nli_model, kb, clf
26
-
27
- if nlp is None:
28
- print("⏳ Lazy Loading: Starting Engines...")
29
- try: nlp = spacy.load("en_core_web_sm", disable=["parser"])
30
- except:
31
- subprocess.check_call([sys.executable, "-m", "spacy", "download", "en_core_web_sm"])
32
- nlp = spacy.load("en_core_web_sm", disable=["parser"])
33
 
34
- retriever = SentenceTransformer('all-MiniLM-L6-v2')
35
- ranker = Ranker(model_name="ms-marco-TinyBERT-L-2-v2", cache_dir="/opt")
36
- tokenizer = AutoTokenizer.from_pretrained("cross-encoder/nli-deberta-v3-base")
37
- nli_model = AutoModelForSequenceClassification.from_pretrained("cross-encoder/nli-deberta-v3-base")
38
-
39
- kb = UniversalGraphKB()
40
- print("✅ Engines Ready")
 
 
 
 
 
 
41
 
42
- # ============================
43
- # 2. UNIVERSAL KNOWLEDGE GRAPH
44
- # ============================
45
  class UniversalGraphKB:
46
  def __init__(self):
47
  self.indices = {}
@@ -50,14 +44,16 @@ class UniversalGraphKB:
50
  def get_chunks(self, text):
51
  words = re.findall(r'\S+', text)
52
  chunks = []
53
- step = 400
54
  for i in range(0, len(words), step):
55
  chunk = " ".join(words[i:i + 500])
56
  if len(chunk) > 50: chunks.append(chunk)
57
  return chunks
58
 
59
- def ingest_book(self, text, key="session_book"):
60
  chunks = self.get_chunks(text)
 
 
61
  doc = nlp(text[:100000])
62
  names = [ent.text.lower() for ent in doc.ents if ent.label_ == "PERSON"]
63
  main_char = pd.Series(names).value_counts().index[0] if names else "Unknown"
@@ -70,11 +66,11 @@ class UniversalGraphKB:
70
  }
71
  return main_char.title()
72
 
73
- # ============================
74
- # 3. TITANIUM LOGIC (With Word-to-Number Patch)
75
- # ============================
76
  def normalize_dates(text):
77
- """Converts text centuries to approximate years for the Regex to catch."""
78
  text = text.lower()
79
  mapping = {
80
  "eighteenth": "1750", "18th": "1750",
@@ -83,8 +79,7 @@ def normalize_dates(text):
83
  "twenty-first": "2050", "21st": "2050"
84
  }
85
  for word, year in mapping.items():
86
- if word in text:
87
- text += f" ({year}) " # Append the digit so Regex sees it
88
  return text
89
 
90
  def get_nli_score(premise, hypothesis):
@@ -92,78 +87,61 @@ def get_nli_score(premise, hypothesis):
92
  with torch.no_grad():
93
  outputs = nli_model(**inputs)
94
  probs = torch.softmax(outputs.logits, dim=1).cpu().numpy()[0]
95
- # DeBERTa-v3-base NLI Output: [Contradiction, Entailment, Neutral] (usually)
96
- # We return the Entailment score (Index 1) minus Contradiction (Index 0)
97
- # Higher = More Consistent
98
- return float(probs[1])
99
 
100
- def extract_features(backstory, book_key="session_book"):
101
- if kb is None or book_key not in kb.indices: return [0,0], "", None, "Book not ingested."
102
- idx = kb.indices[book_key]
103
- protagonist = kb.protagonists.get(book_key, "")
104
 
105
- # Pre-process backstory for dates
106
  backstory_norm = normalize_dates(backstory)
107
-
108
  aug_query = f"{backstory} (Context: {protagonist})"
109
 
 
110
  q_vec = retriever.encode(aug_query)
111
  v_scores = cosine_similarity([q_vec], idx['vectors'])[0]
112
  candidates = list(v_scores.argsort()[-30:][::-1])
113
  passages = [{"id": i, "text": idx['text'][i]} for i in candidates]
114
 
115
- rerank_req = RerankRequest(query=backstory, passages=passages)
116
- results = ranker.rerank(rerank_req)
117
  best_chunk = results[0]['text']
118
-
119
- # NORMALIZE CHUNK FOR DATES TOO
120
  best_chunk_norm = normalize_dates(best_chunk)
121
 
122
- # --- LOGIC GUARDRAILS ---
 
 
123
  if backstory.strip() in best_chunk:
124
  return [1.0, 0], best_chunk, 1, "VERIFIED: Exact Text Match"
125
 
 
126
  YEAR_PATTERN = r'\b([1-2][0-9]{3})\b'
127
  q_years = [int(y) for y in re.findall(YEAR_PATTERN, backstory_norm)]
128
  e_years = [int(y) for y in re.findall(YEAR_PATTERN, best_chunk_norm)]
129
 
130
  if q_years and e_years:
131
- # If gap > 5 years -> Contradiction
132
  if not any(abs(by - ey) < 5 for by in q_years for ey in e_years):
133
  return [0.0, 1], best_chunk, 0, f"CRITICAL: Timeline Mismatch ({q_years[0]} vs {e_years[0]})"
134
 
 
135
  score = get_nli_score(aug_query, best_chunk)
136
  return [score, 0], best_chunk, None, ""
137
 
138
- # ============================
139
- # 4. PUBLIC WRAPPER
140
- # ============================
141
- def predict_for_website(backstory, book_text=None):
142
- load_engines_if_needed()
143
 
144
- if book_text:
145
- kb.ingest_book(book_text, "session_book")
146
-
147
- feats, ev, verdict, rat = extract_features(backstory, "session_book")
148
-
149
- # 1. Guardrail Verdict (Math/Exact)
150
  if verdict is not None:
151
- return {
152
- "prediction": "Consistent" if verdict==1 else "Contradiction",
153
- "confidence": 1.0,
154
- "rationale": rat,
155
- "evidence": ev[:300] + "...",
156
- "protagonist": kb.protagonists.get("session_book", "Unknown")
157
- }
158
-
159
- # 2. Neural Verdict (NLI Score)
160
- # Threshold 0.5: If Entailment > 0.5, it's consistent.
161
- pred = 1 if feats[0] > 0.2 else 0
162
 
 
 
163
  return {
164
- "prediction": "Consistent" if pred==1 else "Contradiction",
165
- "confidence": round(feats[0], 2),
166
- "rationale": f"Semantic Consistency Score: {feats[0]:.2f}",
167
- "evidence": ev[:300] + "...",
168
- "protagonist": kb.protagonists.get("session_book", "Unknown")
169
  }
 
1
+ import re
2
+ import sys
3
  import pandas as pd
4
  import numpy as np
5
  import spacy
 
8
  from sentence_transformers import SentenceTransformer
9
  from rank_bm25 import BM25Okapi
10
  from sklearn.metrics.pairwise import cosine_similarity
 
11
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
12
 
13
+ # --- GLOBAL ENGINES (LAZY LOAD) ---
 
 
14
  nlp = None
15
  retriever = None
16
  ranker = None
17
  tokenizer = None
18
  nli_model = None
 
 
19
 
20
def load_engines():
    """Lazily initialize the global NLP / retrieval / rerank / NLI engines.

    Safe to call repeatedly: returns immediately once ``nlp`` is set.
    Populates module-level globals so the heavy models are loaded on first
    use instead of at import time.
    """
    global nlp, retriever, ranker, tokenizer, nli_model
    if nlp is not None:
        return  # already warmed up

    print("⚡ TITANIUM: Waking up Neural Engines...")

    # spaCy NER pipeline (parser disabled — only entities are used).
    # Fall back to downloading the model so a fresh environment does not
    # crash with OSError on the first call.
    try:
        nlp = spacy.load("en_core_web_sm", disable=["parser"])
    except OSError:
        import subprocess
        subprocess.check_call([sys.executable, "-m", "spacy", "download", "en_core_web_sm"])
        nlp = spacy.load("en_core_web_sm", disable=["parser"])

    # 1. Retrieval Engine (dense sentence embeddings)
    retriever = SentenceTransformer('all-MiniLM-L6-v2')

    # 2. Rerank Engine
    ranker = Ranker(model_name="ms-marco-TinyBERT-L-2-v2", cache_dir="/app/cache")

    # 3. Logic Engine (DeBERTa-v3 NLI cross-encoder)
    tokenizer = AutoTokenizer.from_pretrained("cross-encoder/nli-deberta-v3-base")
    nli_model = AutoModelForSequenceClassification.from_pretrained("cross-encoder/nli-deberta-v3-base")
    print("✅ TITANIUM: Ready.")
37
 
38
+ # --- UNIVERSAL KNOWLEDGE GRAPH ---
 
 
39
  class UniversalGraphKB:
40
  def __init__(self):
41
  self.indices = {}
 
44
  def get_chunks(self, text):
45
  words = re.findall(r'\S+', text)
46
  chunks = []
47
+ step = 400
48
  for i in range(0, len(words), step):
49
  chunk = " ".join(words[i:i + 500])
50
  if len(chunk) > 50: chunks.append(chunk)
51
  return chunks
52
 
53
+ def ingest_book(self, text, key="session"):
54
  chunks = self.get_chunks(text)
55
+
56
+ # Auto-Protagonist Detection
57
  doc = nlp(text[:100000])
58
  names = [ent.text.lower() for ent in doc.ents if ent.label_ == "PERSON"]
59
  main_char = pd.Series(names).value_counts().index[0] if names else "Unknown"
 
66
  }
67
  return main_char.title()
68
 
69
+ kb = UniversalGraphKB()
70
+
71
+ # --- TITANIUM LOGIC GUARDRAILS ---
72
  def normalize_dates(text):
73
+ """Visual Confirmation: Turns words to numbers for the Logic Engine."""
74
  text = text.lower()
75
  mapping = {
76
  "eighteenth": "1750", "18th": "1750",
 
79
  "twenty-first": "2050", "21st": "2050"
80
  }
81
  for word, year in mapping.items():
82
+ if word in text: text += f" ({year}) "
 
83
  return text
84
 
85
  def get_nli_score(premise, hypothesis):
 
87
  with torch.no_grad():
88
  outputs = nli_model(**inputs)
89
  probs = torch.softmax(outputs.logits, dim=1).cpu().numpy()[0]
90
+ return float(probs[1]) # Entailment
 
 
 
91
 
92
def extract_features(backstory, key="session"):
    """Score *backstory* against the ingested book identified by *key*.

    Returns a 4-tuple:
        feats     -- [consistency_score, guardrail_flag]
        evidence  -- best matching book chunk ("" when no book ingested)
        verdict   -- 1 consistent / 0 contradiction when a deterministic
                     guardrail fires, otherwise None (caller thresholds
                     the neural score)
        rationale -- human-readable reason for a guardrail verdict
    """
    if key not in kb.indices:
        return [0, 0], "", None, "Book not uploaded"

    index = kb.indices[key]
    protagonist = kb.protagonists.get(key, "")

    # Turn century words into digit years so the regex guardrail sees them.
    backstory_norm = normalize_dates(backstory)
    aug_query = f"{backstory} (Context: {protagonist})"

    # Dense retrieval: top-30 chunks by cosine similarity, then rerank.
    query_vec = retriever.encode(aug_query)
    sims = cosine_similarity([query_vec], index['vectors'])[0]
    top_ids = list(sims.argsort()[-30:][::-1])
    passages = [{"id": cid, "text": index['text'][cid]} for cid in top_ids]

    reranked = ranker.rerank(RerankRequest(query=backstory, passages=passages))
    best_chunk = reranked[0]['text']
    best_chunk_norm = normalize_dates(best_chunk)

    # Guardrail A: the backstory appears verbatim in the book.
    if backstory.strip() in best_chunk:
        return [1.0, 0], best_chunk, 1, "VERIFIED: Exact Text Match"

    # Guardrail B: hard timeline check — any year pair within 5 years passes.
    YEAR_PATTERN = r'\b([1-2][0-9]{3})\b'
    q_years = [int(y) for y in re.findall(YEAR_PATTERN, backstory_norm)]
    e_years = [int(y) for y in re.findall(YEAR_PATTERN, best_chunk_norm)]
    if q_years and e_years:
        timeline_ok = any(abs(by - ey) < 5 for by in q_years for ey in e_years)
        if not timeline_ok:
            return [0.0, 1], best_chunk, 0, f"CRITICAL: Timeline Mismatch ({q_years[0]} vs {e_years[0]})"

    # No guardrail fired — fall through to the neural entailment score.
    score = get_nli_score(aug_query, best_chunk)
    return [score, 0], best_chunk, None, ""
129
 
130
+ # --- API WRAPPER ---
131
def predict_logic(book_text, backstory):
    """End-to-end check: ingest *book_text*, then judge *backstory* against it.

    Returns a dict with keys: prediction, rationale, evidence, score.
    """
    load_engines()
    kb.ingest_book(book_text, "session")
    feats, evidence, verdict, rationale = extract_features(backstory, "session")

    # A deterministic guardrail (exact match / timeline math) overrides the
    # neural score entirely.
    if verdict is not None:
        return {
            "prediction": "Consistent" if verdict == 1 else "Contradiction",
            "rationale": rationale,
            "evidence": evidence[:350] + "...",
            "score": 1.0 if verdict == 1 else 0.0,
        }

    # Neural decision: entailment score above 0.15 counts as consistent.
    pred = 1 if feats[0] > 0.15 else 0
    return {
        "prediction": "Consistent" if pred == 1 else "Contradiction",
        "rationale": f"Semantic Consistency Score: {feats[0]:.2f}",
        "evidence": evidence[:350] + "...",
        "score": round(feats[0], 2),
    }