tejashsr committed on
Commit
b3c41c2
·
verified ·
1 Parent(s): fb84c7a

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +39 -35
model.py CHANGED
@@ -4,44 +4,46 @@ import sys
4
  import pandas as pd
5
  import numpy as np
6
  import torch
7
- from flashrank import Ranker, RerankRequest
8
  from sentence_transformers import SentenceTransformer, CrossEncoder
9
  from rank_bm25 import BM25Okapi
10
  from sklearn.metrics.pairwise import cosine_similarity
 
11
 
12
- # --- 1. GLOBAL ENGINES (Lazy Loading) ---
13
- # We load these globally so they don't reload on every request
14
  nlp = None
15
  retriever = None
16
  ranker = None
17
  nli_model = None
18
 
19
  def load_engines():
 
20
  global nlp, retriever, ranker, nli_model
21
  if nlp is not None: return
22
 
23
- print("⚡ TITANIUM: Waking up...")
24
 
25
- # NLP
26
  import spacy
27
- try: nlp = spacy.load("en_core_web_sm", disable=["parser"])
28
- except:
 
29
  import subprocess
30
  subprocess.check_call([sys.executable, "-m", "spacy", "download", "en_core_web_sm", "--quiet"])
31
  nlp = spacy.load("en_core_web_sm", disable=["parser"])
32
-
33
- # SEARCH
34
  device = "cuda" if torch.cuda.is_available() else "cpu"
35
  retriever = SentenceTransformer('all-MiniLM-L6-v2', device=device)
36
-
37
- # RERANK
38
- ranker = Ranker(model_name="ms-marco-TinyBERT-L-2-v2", cache_dir="/opt")
39
-
40
- # LOGIC BRAIN (CrossEncoder = 100% Accuracy Fix)
 
41
  nli_model = CrossEncoder('cross-encoder/nli-deberta-v3-base', device=device)
42
- print(f"✅ TITANIUM: Ready on {device.upper()}")
43
 
44
- # --- 2. UNIVERSAL GRAPH BUILDER ---
45
  class UniversalGraphKB:
46
  def __init__(self):
47
  self.indices = {}
@@ -54,14 +56,17 @@ class UniversalGraphKB:
54
  def ingest_book(self, text, key="session"):
55
  chunks = self.get_chunks(text)
56
 
 
57
  doc = nlp(text[:100000])
58
  people = [ent.text.lower() for ent in doc.ents if ent.label_ == "PERSON"]
59
  locs = [ent.text.lower() for ent in doc.ents if ent.label_ in ["GPE", "LOC"]]
 
60
  main_person = pd.Series(people).value_counts().index[0] if people else "Unknown"
61
  main_loc = pd.Series(locs).value_counts().index[0] if locs else "Unknown"
62
 
63
  self.context_graph[key] = f"{main_person.title()} | {main_loc.title()}"
64
 
 
65
  self.indices[key] = {
66
  "text": chunks,
67
  "vectors": retriever.encode(chunks, show_progress_bar=False),
@@ -71,8 +76,9 @@ class UniversalGraphKB:
71
 
72
  kb = UniversalGraphKB()
73
 
74
- # --- 3. TITANIUM LOGIC GUARDRAILS ---
75
  def normalize_dates(text):
 
76
  text = text.lower()
77
  mapping = {
78
  "eighteenth": "1750", "18th": "1750", "nineteenth": "1850", "19th": "1850",
@@ -83,21 +89,21 @@ def normalize_dates(text):
83
  return text
84
 
85
  def extract_features(backstory, key="session"):
86
- if key not in kb.indices: return 0.0, "", "No Data"
87
  idx = kb.indices[key]
88
  context = kb.context_graph.get(key, "")
89
 
90
- # A. Search
91
  aug_query = f"{backstory} (Context: {context})"
92
  q_vec = retriever.encode(aug_query)
93
  v_scores = cosine_similarity([q_vec], idx['vectors'])[0]
94
- candidates = [{"id": i, "text": idx['text'][i]} for i in v_scores.argsort()[-10:][::-1]]
95
 
96
- # B. Rerank
 
97
  reranked = ranker.rerank(RerankRequest(query=backstory, passages=candidates))
98
  best_chunk = reranked[0]['text']
99
 
100
- # C. Date Guardrail
101
  norm_claim = normalize_dates(backstory)
102
  norm_ev = normalize_dates(best_chunk)
103
  years_c = [int(y) for y in re.findall(r'\b([1-2][0-9]{3})\b', norm_claim)]
@@ -105,30 +111,28 @@ def extract_features(backstory, key="session"):
105
 
106
  if years_c and years_e:
107
  if not any(abs(yc - ye) < 5 for yc in years_c for ye in years_e):
108
- return 0.0, best_chunk, f"TIMELINE ERROR: {years_c[0]} vs {years_e[0]}"
109
 
110
- # D. NEURAL CHECK (The Fix: Truth -> Claim)
 
111
  scores = nli_model.predict([(best_chunk, aug_query)])[0]
112
-
113
- # Softmax
114
  exp_scores = np.exp(scores - np.max(scores))
115
  probs = exp_scores / exp_scores.sum()
116
 
117
- return float(probs[1]), best_chunk, "SEMANTIC ANALYSIS"
 
118
 
119
- # --- 4. THE MISSING LINK (Paste this at the bottom) ---
120
  def predict_logic(book_text, backstory):
121
- load_engines() # Ensure engines are loaded
 
122
  kb.ingest_book(book_text, "session")
123
-
124
  score, ev, reason = extract_features(backstory, "session")
125
 
126
- # Decision
127
- pred = "Consistent" if score > 0.5 else "Contradiction"
128
-
129
  return {
130
- "prediction": pred,
131
  "rationale": reason,
132
- "evidence": ev[:350] + "...",
133
  "score": round(score, 2)
134
  }
 
4
  import pandas as pd
5
  import numpy as np
6
  import torch
 
7
  from sentence_transformers import SentenceTransformer, CrossEncoder
8
  from rank_bm25 import BM25Okapi
9
  from sklearn.metrics.pairwise import cosine_similarity
10
+ from flashrank import Ranker, RerankRequest
11
 
12
# --- GLOBAL ENGINES ---
# Module-level singletons, lazily populated by load_engines(); all None until first use.
nlp = None         # spaCy pipeline ("en_core_web_sm", parser disabled)
retriever = None   # SentenceTransformer bi-encoder ('all-MiniLM-L6-v2') for vector search
ranker = None      # FlashRank Ranker ('ms-marco-TinyBERT-L-2-v2') reranker
nli_model = None   # CrossEncoder ('cross-encoder/nli-deberta-v3-base') for entailment scoring
 
def load_engines():
    """Lazily load all models into module-level globals.

    Idempotent: returns immediately if the engines are already loaded, so the
    Space can start without paying the model-download cost at import time.
    Populates the globals ``nlp``, ``retriever``, ``ranker`` and ``nli_model``.
    """
    global nlp, retriever, ranker, nli_model
    if nlp is not None:
        return

    print("⚡ Awakening Titanium Brain...")

    # 1. NLP Core (Spacy)
    import spacy
    try:
        nlp = spacy.load("en_core_web_sm", disable=["parser"])
    except OSError:
        # spacy.load raises OSError when the model package is not installed.
        # A bare `except:` here would also mask KeyboardInterrupt/SystemExit
        # and unrelated bugs, so we catch only the missing-model case,
        # download the model once, then retry the load.
        import subprocess
        subprocess.check_call([sys.executable, "-m", "spacy", "download", "en_core_web_sm", "--quiet"])
        nlp = spacy.load("en_core_web_sm", disable=["parser"])

    # 2. Vector Search (MiniLM)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    retriever = SentenceTransformer('all-MiniLM-L6-v2', device=device)

    # 3. Precision Reranker (FlashRank)
    ranker = Ranker(model_name="ms-marco-TinyBERT-L-2-v2", cache_dir="/app/cache")

    # 4. Logic Core (DeBERTa Cross-Encoder)
    # The CrossEncoder handles (Premise, Hypothesis) logic with high accuracy.
    nli_model = CrossEncoder('cross-encoder/nli-deberta-v3-base', device=device)
    print(f"✅ Engines Ready on {device.upper()}")
 
46
+ # --- UNIVERSAL GRAPH KNOWLEDGE BASE ---
47
  class UniversalGraphKB:
48
  def __init__(self):
49
  self.indices = {}
 
56
  def ingest_book(self, text, key="session"):
57
  chunks = self.get_chunks(text)
58
 
59
+ # Entity Graph Extraction
60
  doc = nlp(text[:100000])
61
  people = [ent.text.lower() for ent in doc.ents if ent.label_ == "PERSON"]
62
  locs = [ent.text.lower() for ent in doc.ents if ent.label_ in ["GPE", "LOC"]]
63
+
64
  main_person = pd.Series(people).value_counts().index[0] if people else "Unknown"
65
  main_loc = pd.Series(locs).value_counts().index[0] if locs else "Unknown"
66
 
67
  self.context_graph[key] = f"{main_person.title()} | {main_loc.title()}"
68
 
69
+ # Vector + BM25 Hybrid Indexing
70
  self.indices[key] = {
71
  "text": chunks,
72
  "vectors": retriever.encode(chunks, show_progress_bar=False),
 
76
 
77
  kb = UniversalGraphKB()
78
 
79
+ # --- TITANIUM LOGIC ENGINE ---
80
  def normalize_dates(text):
81
+ """Symbolic Layer: Translates words to years for mathematical comparison."""
82
  text = text.lower()
83
  mapping = {
84
  "eighteenth": "1750", "18th": "1750", "nineteenth": "1850", "19th": "1850",
 
89
  return text
90
 
91
  def extract_features(backstory, key="session"):
92
+ if key not in kb.indices: return 0.0, "No Data", "Ingestion Failed"
93
  idx = kb.indices[key]
94
  context = kb.context_graph.get(key, "")
95
 
96
+ # 1. Hybrid Retrieval
97
  aug_query = f"{backstory} (Context: {context})"
98
  q_vec = retriever.encode(aug_query)
99
  v_scores = cosine_similarity([q_vec], idx['vectors'])[0]
 
100
 
101
+ # 2. FlashRank Reranking
102
+ candidates = [{"id": i, "text": idx['text'][i]} for i in v_scores.argsort()[-15:][::-1]]
103
  reranked = ranker.rerank(RerankRequest(query=backstory, passages=candidates))
104
  best_chunk = reranked[0]['text']
105
 
106
+ # 3. TITANIUM GUARDRAIL: Timeline Analysis
107
  norm_claim = normalize_dates(backstory)
108
  norm_ev = normalize_dates(best_chunk)
109
  years_c = [int(y) for y in re.findall(r'\b([1-2][0-9]{3})\b', norm_claim)]
 
111
 
112
  if years_c and years_e:
113
  if not any(abs(yc - ye) < 5 for yc in years_c for ye in years_e):
114
+ return 0.0, best_chunk, f"TIMELINE DISCREPANCY: {years_c[0]} vs {years_e[0]}"
115
 
116
+ # 4. NEURAL BRAIN: Semantic Entailment
117
+ # Order: (Evidence, Claim) -> Cross-Encoder calculates if Truth proves Claim
118
  scores = nli_model.predict([(best_chunk, aug_query)])[0]
 
 
119
  exp_scores = np.exp(scores - np.max(scores))
120
  probs = exp_scores / exp_scores.sum()
121
 
122
+ # Return Entailment Score (Index 1)
123
+ return float(probs[1]), best_chunk, "Neural Semantic Verification"
124
 
 
125
def predict_logic(book_text, backstory):
    """Main entry point for the API.

    Ingests the book into the knowledge base, scores the backstory against it,
    and returns a verdict dict with prediction, rationale, evidence and score.
    """
    load_engines()
    kb.ingest_book(book_text, "session")

    consistency_score, evidence_chunk, rationale = extract_features(backstory, "session")

    # Final verdict threshold: entailment probability above 0.5 means consistent.
    if consistency_score > 0.5:
        verdict = "Consistent"
    else:
        verdict = "Contradiction"

    return {
        "prediction": verdict,
        "rationale": rationale,
        "evidence": evidence_chunk[:400] + "...",
        "score": round(consistency_score, 2),
    }