Spaces:
Sleeping
Sleeping
Update model.py
Browse files
model.py
CHANGED
|
@@ -4,44 +4,46 @@ import sys
|
|
| 4 |
import pandas as pd
|
| 5 |
import numpy as np
|
| 6 |
import torch
|
| 7 |
-
from flashrank import Ranker, RerankRequest
|
| 8 |
from sentence_transformers import SentenceTransformer, CrossEncoder
|
| 9 |
from rank_bm25 import BM25Okapi
|
| 10 |
from sklearn.metrics.pairwise import cosine_similarity
|
|
|
|
| 11 |
|
| 12 |
-
# ---
|
| 13 |
-
# We load these globally so they don't reload on every request
|
| 14 |
nlp = None
|
| 15 |
retriever = None
|
| 16 |
ranker = None
|
| 17 |
nli_model = None
|
| 18 |
|
| 19 |
def load_engines():
|
|
|
|
| 20 |
global nlp, retriever, ranker, nli_model
|
| 21 |
if nlp is not None: return
|
| 22 |
|
| 23 |
-
print("⚡
|
| 24 |
|
| 25 |
-
# NLP
|
| 26 |
import spacy
|
| 27 |
-
try:
|
| 28 |
-
|
|
|
|
| 29 |
import subprocess
|
| 30 |
subprocess.check_call([sys.executable, "-m", "spacy", "download", "en_core_web_sm", "--quiet"])
|
| 31 |
nlp = spacy.load("en_core_web_sm", disable=["parser"])
|
| 32 |
-
|
| 33 |
-
#
|
| 34 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 35 |
retriever = SentenceTransformer('all-MiniLM-L6-v2', device=device)
|
| 36 |
-
|
| 37 |
-
#
|
| 38 |
-
ranker = Ranker(model_name="ms-marco-TinyBERT-L-2-v2", cache_dir="/
|
| 39 |
-
|
| 40 |
-
#
|
|
|
|
| 41 |
nli_model = CrossEncoder('cross-encoder/nli-deberta-v3-base', device=device)
|
| 42 |
-
print(f"✅
|
| 43 |
|
| 44 |
-
# ---
|
| 45 |
class UniversalGraphKB:
|
| 46 |
def __init__(self):
|
| 47 |
self.indices = {}
|
|
@@ -54,14 +56,17 @@ class UniversalGraphKB:
|
|
| 54 |
def ingest_book(self, text, key="session"):
|
| 55 |
chunks = self.get_chunks(text)
|
| 56 |
|
|
|
|
| 57 |
doc = nlp(text[:100000])
|
| 58 |
people = [ent.text.lower() for ent in doc.ents if ent.label_ == "PERSON"]
|
| 59 |
locs = [ent.text.lower() for ent in doc.ents if ent.label_ in ["GPE", "LOC"]]
|
|
|
|
| 60 |
main_person = pd.Series(people).value_counts().index[0] if people else "Unknown"
|
| 61 |
main_loc = pd.Series(locs).value_counts().index[0] if locs else "Unknown"
|
| 62 |
|
| 63 |
self.context_graph[key] = f"{main_person.title()} | {main_loc.title()}"
|
| 64 |
|
|
|
|
| 65 |
self.indices[key] = {
|
| 66 |
"text": chunks,
|
| 67 |
"vectors": retriever.encode(chunks, show_progress_bar=False),
|
|
@@ -71,8 +76,9 @@ class UniversalGraphKB:
|
|
| 71 |
|
| 72 |
kb = UniversalGraphKB()
|
| 73 |
|
| 74 |
-
# ---
|
| 75 |
def normalize_dates(text):
|
|
|
|
| 76 |
text = text.lower()
|
| 77 |
mapping = {
|
| 78 |
"eighteenth": "1750", "18th": "1750", "nineteenth": "1850", "19th": "1850",
|
|
@@ -83,21 +89,21 @@ def normalize_dates(text):
|
|
| 83 |
return text
|
| 84 |
|
| 85 |
def extract_features(backstory, key="session"):
|
| 86 |
-
if key not in kb.indices: return 0.0, "", "
|
| 87 |
idx = kb.indices[key]
|
| 88 |
context = kb.context_graph.get(key, "")
|
| 89 |
|
| 90 |
-
#
|
| 91 |
aug_query = f"{backstory} (Context: {context})"
|
| 92 |
q_vec = retriever.encode(aug_query)
|
| 93 |
v_scores = cosine_similarity([q_vec], idx['vectors'])[0]
|
| 94 |
-
candidates = [{"id": i, "text": idx['text'][i]} for i in v_scores.argsort()[-10:][::-1]]
|
| 95 |
|
| 96 |
-
#
|
|
|
|
| 97 |
reranked = ranker.rerank(RerankRequest(query=backstory, passages=candidates))
|
| 98 |
best_chunk = reranked[0]['text']
|
| 99 |
|
| 100 |
-
#
|
| 101 |
norm_claim = normalize_dates(backstory)
|
| 102 |
norm_ev = normalize_dates(best_chunk)
|
| 103 |
years_c = [int(y) for y in re.findall(r'\b([1-2][0-9]{3})\b', norm_claim)]
|
|
@@ -105,30 +111,28 @@ def extract_features(backstory, key="session"):
|
|
| 105 |
|
| 106 |
if years_c and years_e:
|
| 107 |
if not any(abs(yc - ye) < 5 for yc in years_c for ye in years_e):
|
| 108 |
-
return 0.0, best_chunk, f"TIMELINE
|
| 109 |
|
| 110 |
-
#
|
|
|
|
| 111 |
scores = nli_model.predict([(best_chunk, aug_query)])[0]
|
| 112 |
-
|
| 113 |
-
# Softmax
|
| 114 |
exp_scores = np.exp(scores - np.max(scores))
|
| 115 |
probs = exp_scores / exp_scores.sum()
|
| 116 |
|
| 117 |
-
|
|
|
|
| 118 |
|
| 119 |
-
# --- 4. THE MISSING LINK (Paste this at the bottom) ---
|
| 120 |
def predict_logic(book_text, backstory):
|
| 121 |
-
|
|
|
|
| 122 |
kb.ingest_book(book_text, "session")
|
| 123 |
-
|
| 124 |
score, ev, reason = extract_features(backstory, "session")
|
| 125 |
|
| 126 |
-
#
|
| 127 |
-
|
| 128 |
-
|
| 129 |
return {
|
| 130 |
-
"prediction":
|
| 131 |
"rationale": reason,
|
| 132 |
-
"evidence": ev[:
|
| 133 |
"score": round(score, 2)
|
| 134 |
}
|
|
|
|
| 4 |
import pandas as pd
|
| 5 |
import numpy as np
|
| 6 |
import torch
|
|
|
|
| 7 |
from sentence_transformers import SentenceTransformer, CrossEncoder
|
| 8 |
from rank_bm25 import BM25Okapi
|
| 9 |
from sklearn.metrics.pairwise import cosine_similarity
|
| 10 |
+
from flashrank import Ranker, RerankRequest
|
| 11 |
|
| 12 |
+
# --- GLOBAL ENGINES ---
|
|
|
|
| 13 |
nlp = None
|
| 14 |
retriever = None
|
| 15 |
ranker = None
|
| 16 |
nli_model = None
|
| 17 |
|
| 18 |
def load_engines():
    """Lazy loads models to ensure the Space starts without timing out."""
    global nlp, retriever, ranker, nli_model
    if nlp is not None: return

    print("⚡ Awakening Titanium Brain...")

    # 1. NLP Core (Spacy)
    import spacy
    try:
        nlp = spacy.load("en_core_web_sm", disable=["parser"])
    except OSError:
        # spacy.load raises OSError when the model package is not installed;
        # catch only that (a bare `except:` would also swallow KeyboardInterrupt
        # and mask real failures). Download once, then load.
        import subprocess
        subprocess.check_call([sys.executable, "-m", "spacy", "download", "en_core_web_sm", "--quiet"])
        nlp = spacy.load("en_core_web_sm", disable=["parser"])

    # 2. Vector Search (MiniLM)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    retriever = SentenceTransformer('all-MiniLM-L6-v2', device=device)

    # 3. Precision Reranker (FlashRank)
    ranker = Ranker(model_name="ms-marco-TinyBERT-L-2-v2", cache_dir="/app/cache")

    # 4. Logic Core (DeBERTa Cross-Encoder)
    # The CrossEncoder handles (Premise, Hypothesis) logic with high accuracy.
    nli_model = CrossEncoder('cross-encoder/nli-deberta-v3-base', device=device)
    print(f"✅ Engines Ready on {device.upper()}")
|
| 45 |
|
| 46 |
+
# --- UNIVERSAL GRAPH KNOWLEDGE BASE ---
|
| 47 |
class UniversalGraphKB:
|
| 48 |
def __init__(self):
|
| 49 |
self.indices = {}
|
|
|
|
| 56 |
def ingest_book(self, text, key="session"):
|
| 57 |
chunks = self.get_chunks(text)
|
| 58 |
|
| 59 |
+
# Entity Graph Extraction
|
| 60 |
doc = nlp(text[:100000])
|
| 61 |
people = [ent.text.lower() for ent in doc.ents if ent.label_ == "PERSON"]
|
| 62 |
locs = [ent.text.lower() for ent in doc.ents if ent.label_ in ["GPE", "LOC"]]
|
| 63 |
+
|
| 64 |
main_person = pd.Series(people).value_counts().index[0] if people else "Unknown"
|
| 65 |
main_loc = pd.Series(locs).value_counts().index[0] if locs else "Unknown"
|
| 66 |
|
| 67 |
self.context_graph[key] = f"{main_person.title()} | {main_loc.title()}"
|
| 68 |
|
| 69 |
+
# Vector + BM25 Hybrid Indexing
|
| 70 |
self.indices[key] = {
|
| 71 |
"text": chunks,
|
| 72 |
"vectors": retriever.encode(chunks, show_progress_bar=False),
|
|
|
|
| 76 |
|
| 77 |
kb = UniversalGraphKB()
|
| 78 |
|
| 79 |
+
# --- TITANIUM LOGIC ENGINE ---
|
| 80 |
def normalize_dates(text):
|
| 81 |
+
"""Symbolic Layer: Translates words to years for mathematical comparison."""
|
| 82 |
text = text.lower()
|
| 83 |
mapping = {
|
| 84 |
"eighteenth": "1750", "18th": "1750", "nineteenth": "1850", "19th": "1850",
|
|
|
|
| 89 |
return text
|
| 90 |
|
| 91 |
def extract_features(backstory, key="session"):
    """Score *backstory* against the ingested book for key *key*.

    Returns a ``(score, evidence, rationale)`` tuple where *score* is an
    entailment probability in [0, 1], *evidence* is the best supporting
    chunk, and *rationale* is a short explanation string.
    """
    if key not in kb.indices: return 0.0, "No Data", "Ingestion Failed"
    idx = kb.indices[key]
    context = kb.context_graph.get(key, "")

    # 1. Hybrid Retrieval — augment the claim with the book's entity context.
    aug_query = f"{backstory} (Context: {context})"
    q_vec = retriever.encode(aug_query)
    v_scores = cosine_similarity([q_vec], idx['vectors'])[0]

    # 2. FlashRank Reranking over the top-15 vector hits.
    candidates = [{"id": i, "text": idx['text'][i]} for i in v_scores.argsort()[-15:][::-1]]
    reranked = ranker.rerank(RerankRequest(query=backstory, passages=candidates))
    best_chunk = reranked[0]['text']

    # 3. TITANIUM GUARDRAIL: Timeline Analysis
    norm_claim = normalize_dates(backstory)
    norm_ev = normalize_dates(best_chunk)
    years_c = [int(y) for y in re.findall(r'\b([1-2][0-9]{3})\b', norm_claim)]
    # NOTE(review): this line was missing from the diff view; reconstructed to
    # mirror years_c over the evidence text (required by the guard below) —
    # confirm against the committed file.
    years_e = [int(y) for y in re.findall(r'\b([1-2][0-9]{3})\b', norm_ev)]

    if years_c and years_e:
        # Hard veto when no claim year lands within 5 years of any evidence year.
        if not any(abs(yc - ye) < 5 for yc in years_c for ye in years_e):
            return 0.0, best_chunk, f"TIMELINE DISCREPANCY: {years_c[0]} vs {years_e[0]}"

    # 4. NEURAL BRAIN: Semantic Entailment
    # Order: (Evidence, Claim) -> Cross-Encoder calculates if Truth proves Claim
    scores = nli_model.predict([(best_chunk, aug_query)])[0]
    # Numerically stable softmax over the NLI logits.
    exp_scores = np.exp(scores - np.max(scores))
    probs = exp_scores / exp_scores.sum()

    # Return Entailment Score (Index 1)
    return float(probs[1]), best_chunk, "Neural Semantic Verification"
|
| 124 |
|
|
|
|
| 125 |
def predict_logic(book_text, backstory):
    """Main Entry point for the API.

    Ingests *book_text*, scores *backstory* against it, and returns a
    JSON-serializable verdict dict with prediction, rationale, evidence
    excerpt, and rounded score.
    """
    load_engines()
    kb.ingest_book(book_text, "session")
    score, ev, reason = extract_features(backstory, "session")

    # Final Verdict Threshold
    prediction = "Consistent" if score > 0.5 else "Contradiction"

    # Only append the ellipsis when the evidence was actually truncated;
    # the original added "..." unconditionally, falsely implying truncation.
    evidence = ev[:400] + "..." if len(ev) > 400 else ev

    return {
        "prediction": prediction,
        "rationale": reason,
        "evidence": evidence,
        "score": round(score, 2)
    }
|