# Source: Hugging Face Space file "model.py" by tejashsr (commit b3c41c2, verified).
import os
import re
import sys
import pandas as pd
import numpy as np
import torch
from sentence_transformers import SentenceTransformer, CrossEncoder
from rank_bm25 import BM25Okapi
from sklearn.metrics.pairwise import cosine_similarity
from flashrank import Ranker, RerankRequest
# --- GLOBAL ENGINES ---
# Populated lazily by load_engines(); all remain None until the first request.
nlp = None        # spaCy pipeline (NER only; parser disabled)
retriever = None  # SentenceTransformer bi-encoder used for vector retrieval
ranker = None     # FlashRank cross-encoder reranker
nli_model = None  # DeBERTa NLI CrossEncoder used for entailment scoring
def load_engines():
    """Lazy-load all ML engines into module globals.

    Idempotent: returns immediately once `nlp` is set, so the Space can
    start without timing out and only pays the load cost on first use.
    Side effects: assigns the module globals nlp, retriever, ranker,
    nli_model; may download the spaCy model on first boot.
    """
    global nlp, retriever, ranker, nli_model
    if nlp is not None:
        return
    print("⚡ Awakening Titanium Brain...")
    # 1. NLP Core (Spacy) — parser disabled: only NER entities are used downstream.
    import spacy
    try:
        nlp = spacy.load("en_core_web_sm", disable=["parser"])
    except OSError:
        # spacy.load raises OSError when the model package is not installed
        # (fresh container). A bare `except:` here would also swallow
        # KeyboardInterrupt/SystemExit, so we catch OSError specifically,
        # download the model, and retry once.
        import subprocess
        subprocess.check_call([sys.executable, "-m", "spacy", "download", "en_core_web_sm", "--quiet"])
        nlp = spacy.load("en_core_web_sm", disable=["parser"])
    # 2. Vector Search (MiniLM)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    retriever = SentenceTransformer('all-MiniLM-L6-v2', device=device)
    # 3. Precision Reranker (FlashRank)
    ranker = Ranker(model_name="ms-marco-TinyBERT-L-2-v2", cache_dir="/app/cache")
    # 4. Logic Core (DeBERTa Cross-Encoder)
    # The CrossEncoder handles (Premise, Hypothesis) logic with high accuracy.
    nli_model = CrossEncoder('cross-encoder/nli-deberta-v3-base', device=device)
    print(f"✅ Engines Ready on {device.upper()}")
# --- UNIVERSAL GRAPH KNOWLEDGE BASE ---
class UniversalGraphKB:
    """In-memory hybrid (vector + BM25) index of ingested book text,
    plus a tiny entity "graph" summarizing the dominant person/location."""

    def __init__(self):
        self.indices = {}        # key -> {"text": chunks, "vectors": ndarray, "bm25": BM25Okapi}
        self.context_graph = {}  # key -> "Main Person | Main Location" summary string

    def get_chunks(self, text):
        """Split *text* into ~400-word chunks with a 50-word overlap.

        Windows start every 350 words and span up to 400, so consecutive
        chunks share 50 words of context. Chunks of 50 characters or fewer
        (tiny tail fragments) are dropped.
        """
        words = re.findall(r'\S+', text)
        chunks = []
        for start in range(0, len(words), 350):
            # Join once per window (the original joined the same slice twice:
            # once for the length filter and again for the kept value).
            chunk = " ".join(words[start:start + 400])
            if len(chunk) > 50:
                chunks.append(chunk)
        return chunks

    def ingest_book(self, text, key="session"):
        """Index *text* under *key* and return its entity-graph summary.

        Requires load_engines() to have populated the global `nlp` and
        `retriever`. Only the first 100k characters are scanned for entities
        to bound spaCy's runtime.
        """
        chunks = self.get_chunks(text)
        # Entity Graph Extraction: most frequent PERSON and GPE/LOC entities.
        doc = nlp(text[:100000])
        people = [ent.text.lower() for ent in doc.ents if ent.label_ == "PERSON"]
        locs = [ent.text.lower() for ent in doc.ents if ent.label_ in ["GPE", "LOC"]]
        main_person = pd.Series(people).value_counts().index[0] if people else "Unknown"
        main_loc = pd.Series(locs).value_counts().index[0] if locs else "Unknown"
        self.context_graph[key] = f"{main_person.title()} | {main_loc.title()}"
        # Vector + BM25 Hybrid Indexing.
        # NOTE(review): an empty `chunks` list (very short text) would reach
        # BM25Okapi([]) — presumably inputs are full books; verify upstream.
        self.indices[key] = {
            "text": chunks,
            "vectors": retriever.encode(chunks, show_progress_bar=False),
            "bm25": BM25Okapi([re.findall(r'\w+', c.lower()) for c in chunks])
        }
        return self.context_graph[key]
# Module-level singleton KB shared by extract_features() and predict_logic().
kb = UniversalGraphKB()
# --- TITANIUM LOGIC ENGINE ---
def normalize_dates(text):
    """Symbolic layer: append a representative year for each century word.

    Lowercases *text*, then for every century phrase found appends its
    mid-century anchor year (e.g. "the 19th century" ->
    "the 19th century (1850) ") so downstream regexes can compare timelines
    numerically.
    """
    text = text.lower()
    century_anchors = (
        ("eighteenth", "1750"), ("18th", "1750"),
        ("nineteenth", "1850"), ("19th", "1850"),
        ("twentieth", "1950"), ("20th", "1950"),
        ("twenty-first", "2050"), ("21st", "2050"),
    )
    for phrase, year in century_anchors:
        if phrase in text:
            text = f"{text} ({year}) "
    return text
def extract_features(backstory, key="session"):
    """Score a backstory claim against the book indexed under *key*.

    Pipeline: hybrid vector retrieval -> FlashRank rerank -> symbolic
    timeline guardrail -> NLI entailment. Returns a tuple of
    (entailment_probability, evidence_chunk, rationale_string).
    """
    if key not in kb.indices:
        return 0.0, "No Data", "Ingestion Failed"

    index = kb.indices[key]
    graph_context = kb.context_graph.get(key, "")

    # 1. Hybrid Retrieval: augment the query with the entity-graph summary.
    aug_query = f"{backstory} (Context: {graph_context})"
    query_vec = retriever.encode(aug_query)
    sim_scores = cosine_similarity([query_vec], index['vectors'])[0]

    # 2. FlashRank Reranking over the top-15 vector hits (best first).
    top_ids = sim_scores.argsort()[-15:][::-1]
    candidates = [{"id": i, "text": index['text'][i]} for i in top_ids]
    reranked = ranker.rerank(RerankRequest(query=backstory, passages=candidates))
    best_chunk = reranked[0]['text']

    # 3. TITANIUM GUARDRAIL: compare 4-digit years (1000-2999) from the
    # date-normalized claim and evidence; reject when no pair is within 5 years.
    year_pattern = r'\b([1-2][0-9]{3})\b'
    claim_years = [int(y) for y in re.findall(year_pattern, normalize_dates(backstory))]
    evidence_years = [int(y) for y in re.findall(year_pattern, normalize_dates(best_chunk))]
    if claim_years and evidence_years:
        aligned = any(abs(cy - ey) < 5 for cy in claim_years for ey in evidence_years)
        if not aligned:
            return 0.0, best_chunk, f"TIMELINE DISCREPANCY: {claim_years[0]} vs {evidence_years[0]}"

    # 4. NEURAL BRAIN: Cross-Encoder NLI with (Evidence, Claim) ordering,
    # softmaxed manually; index 1 is taken as the entailment probability.
    logits = nli_model.predict([(best_chunk, aug_query)])[0]
    shifted = np.exp(logits - np.max(logits))
    probs = shifted / shifted.sum()
    return float(probs[1]), best_chunk, "Neural Semantic Verification"
def predict_logic(book_text, backstory):
    """Main entry point for the API.

    Loads the engines (lazily), ingests *book_text* under the "session"
    key, scores *backstory* against it, and returns a JSON-serializable
    verdict dict with prediction, rationale, truncated evidence, and score.
    """
    load_engines()
    kb.ingest_book(book_text, "session")
    score, evidence, reason = extract_features(backstory, "session")
    # Final Verdict Threshold: entailment probability above 0.5 counts as support.
    prediction = "Consistent" if score > 0.5 else "Contradiction"
    # Fix: only append "..." when the evidence was actually truncated, so
    # short strings (e.g. "No Data") are no longer rendered as "No Data...".
    evidence_preview = evidence[:400] + ("..." if len(evidence) > 400 else "")
    return {
        "prediction": prediction,
        "rationale": reason,
        "evidence": evidence_preview,
        "score": round(score, 2)
    }