# NyayLens-API / src/qa/inference.py
# Sai Pranav Reddy
# Clean lightweight deployment
# 968e24d
import torch
import faiss
import json
import sqlite3
import re
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
class LegalQAEngine:
    """Retrieve legal paragraphs for a question and extract answer spans.

    The pipeline is: embed the question, search a FAISS index for the
    closest paragraphs, load their text from SQLite, run an extractive QA
    model over each (question, paragraph) pair, then filter, score and
    de-duplicate the candidate spans.
    """

    # Patterns like:
    #   "it is not correct to say, ..., that X"
    #   "it cannot be said, ..., that X"
    # The capture group is the proposition X that the paragraph rejects.
    # Compiled once here instead of on every call.
    _REFUTATION_PATTERNS = tuple(
        re.compile(pattern)
        for pattern in (
            r"not correct to say.*?that\s+(.*?)(?:\.|,)",
            r"cannot be said.*?that\s+(.*?)(?:\.|,)",
        )
    )

    # Phrases that earn a candidate answer the doctrinal score boost.
    _DOCTRINAL_MARKERS = ("the court", "held that", "it is clear that", "the law")
    _DOCTRINAL_BOOST = 1.5

    _SPAN_BEAM = 5          # number of top start / end positions combined
    _MAX_SPAN_TOKENS = 80   # largest allowed (end - start) token distance
    _MIN_ANSWER_WORDS = 8   # shorter decoded answers are discarded
    _MAX_INPUT_TOKENS = 512  # tokenizer truncation length

    def __init__(self):
        """Load the QA model, the sentence embedder, the FAISS index and the DB.

        Every artifact is read from a fixed relative path, so the process
        must be started from a directory where ``outputs/`` and
        ``data/processed/`` resolve.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        # ---- Load QA model ----
        self.tokenizer = AutoTokenizer.from_pretrained("outputs/qa_model/final")
        self.qa_model = AutoModelForQuestionAnswering.from_pretrained(
            "outputs/qa_model/final"
        ).to(self.device)
        self.qa_model.eval()

        # ---- Load retriever ----
        self.embedder = SentenceTransformer("BAAI/bge-base-en-v1.5", device=self.device)
        self.index = faiss.read_index("data/processed/faiss/faiss_index.bin")
        # para_ids[i] is the paragraph id stored at row i of the FAISS index.
        with open("data/processed/embeddings/paragraph_ids.json", encoding="utf-8") as f:
            self.para_ids = json.load(f)
        self.db = sqlite3.connect("data/processed/indexed/paragraphs.db")
        self.cursor = self.db.cursor()
        print("✓ Enhanced QA inference system ready")

    def close(self):
        """Release the SQLite cursor and connection opened by ``__init__``."""
        self.cursor.close()
        self.db.close()

    # ------------------------------------------------------------------
    # TEXT NORMALIZATION (critical for PDF artifacts)
    # ------------------------------------------------------------------
    def _normalize(self, text: str) -> str:
        """Lower-case *text*, collapse whitespace runs to one space, and trim."""
        text = text.lower()
        text = re.sub(r"\s+", " ", text)
        return text.strip()

    # ------------------------------------------------------------------
    # REFUTED CLAUSE DETECTION (Article 21 FIX)
    # ------------------------------------------------------------------
    def _is_refuted_clause(self, answer_text, paragraph_text):
        """Return True if the answer lies inside a proposition the paragraph rejects.

        Both strings are normalized first. The paragraph is then scanned for
        the refutation patterns, and the answer is blocked when it is a
        substring of any captured proposition.
        """
        para = self._normalize(paragraph_text)
        ans = self._normalize(answer_text)
        for pattern in self._REFUTATION_PATTERNS:
            for refuted_prop in pattern.findall(para):
                # If answer is part of the refuted proposition -> block
                if ans in refuted_prop:
                    return True
        return False

    # ------------------------------------------------------------------
    # RETRIEVAL
    # ------------------------------------------------------------------
    def retrieve_paragraphs(self, question, top_k=8):
        """Return up to *top_k* paragraphs closest to *question*.

        Args:
            question: The question text to embed and search with.
            top_k: Number of neighbours requested from the index.

        Returns:
            A list of dicts with keys ``judgment_id``, ``page_no``, ``text``
            and ``retrieval_score``, in the order the index returned them.
            Hits with no matching database row are omitted.
        """
        q_emb = self.embedder.encode(
            [question], normalize_embeddings=True, convert_to_numpy=True
        )
        scores, indices = self.index.search(q_emb, top_k)

        results = []
        for score, idx in zip(scores[0], indices[0]):
            idx = int(idx)
            # FAISS pads the result with label -1 when it finds fewer than
            # top_k neighbours; indexing para_ids with -1 would silently
            # return the *last* paragraph as a bogus hit.
            if idx < 0:
                continue
            para_id = self.para_ids[idx]
            self.cursor.execute(
                "SELECT judgment_id, page_no, text FROM paragraphs WHERE id = ?",
                (para_id,),
            )
            row = self.cursor.fetchone()
            if row:
                judgment_id, page_no, text = row
                results.append(
                    {
                        "judgment_id": judgment_id,
                        "page_no": page_no,
                        "text": text,
                        "retrieval_score": float(score),
                    }
                )
        return results

    # ------------------------------------------------------------------
    # ANSWERING
    # ------------------------------------------------------------------
    @staticmethod
    def _context_start(inputs):
        """Return the position of the first paragraph token, or None.

        Uses ``token_type_ids`` when the tokenizer produced them. Otherwise
        it falls back to ``inputs.sequence_ids(0)``, which only exists on
        encodings from fast tokenizers. None means truncation left no
        paragraph tokens in the encoded pair.
        """
        if "token_type_ids" in inputs:
            segment_ids = inputs["token_type_ids"][0].tolist()
        else:
            segment_ids = inputs.sequence_ids(0)
        try:
            return segment_ids.index(1)
        except ValueError:
            return None

    def answer_question(self, question, top_k=8, max_answers=2):
        """Extract the best answer spans for *question* from retrieved paragraphs.

        Args:
            question: The question text.
            top_k: Number of paragraphs to retrieve and read.
            max_answers: Maximum number of distinct answers returned.

        Returns:
            A list of at most *max_answers* dicts, best first, with keys
            ``answer``, ``confidence``, ``judgment_id``, ``page_no``,
            ``paragraph`` and ``retrieval_score``. ``confidence`` is the sum
            of the start and end logits plus any doctrinal boost; it is not
            a probability.
        """
        paragraphs = self.retrieve_paragraphs(question, top_k)
        candidates = []

        for para in paragraphs:
            inputs = self.tokenizer(
                question,
                para["text"],
                return_tensors="pt",
                truncation=True,
                max_length=self._MAX_INPUT_TOKENS,
            ).to(self.device)

            context_start = self._context_start(inputs)
            if context_start is None:
                # Nothing of the paragraph survived truncation, so there is
                # no span to extract from it.
                continue

            with torch.no_grad():
                outputs = self.qa_model(**inputs)
            start_logits = outputs.start_logits[0]
            end_logits = outputs.end_logits[0]

            # Clamp k: topk raises when k exceeds the sequence length.
            beam = min(self._SPAN_BEAM, start_logits.shape[0])
            top_starts = torch.topk(start_logits, k=beam).indices.tolist()
            top_ends = torch.topk(end_logits, k=beam).indices.tolist()
            input_ids = inputs["input_ids"][0]

            for s in top_starts:
                # Block question echo: the span must start in the paragraph.
                if s < context_start:
                    continue
                for e in top_ends:
                    if e < s or (e - s) > self._MAX_SPAN_TOKENS:
                        continue

                    answer_text = self.tokenizer.decode(
                        input_ids[s : e + 1], skip_special_tokens=True
                    ).strip()
                    if len(answer_text.split()) < self._MIN_ANSWER_WORDS:
                        continue

                    # Block refuted propositions
                    if self._is_refuted_clause(answer_text, para["text"]):
                        continue

                    score = start_logits[s].item() + end_logits[e].item()

                    # Doctrinal boost
                    lowered = answer_text.lower()
                    if any(marker in lowered for marker in self._DOCTRINAL_MARKERS):
                        score += self._DOCTRINAL_BOOST

                    candidates.append(
                        {
                            "answer": answer_text,
                            "confidence": score,
                            "judgment_id": para["judgment_id"],
                            "page_no": para["page_no"],
                            "paragraph": para["text"],
                            "retrieval_score": para["retrieval_score"],
                        }
                    )

        # ---- Deduplicate answers ----
        # Walk best-first so the highest-scoring copy of each answer is kept.
        seen = set()
        final = []
        for c in sorted(candidates, key=lambda x: x["confidence"], reverse=True):
            key = self._normalize(c["answer"])
            if key not in seen:
                seen.add(key)
                final.append(c)
        return final[:max_answers]
# ----------------------------------------------------------------------
# DEMO
# ----------------------------------------------------------------------
if __name__ == "__main__":
    # Smoke test: build the engine once, then print the ranked answers and
    # their source paragraph for a few sample questions.
    engine = LegalQAEngine()
    rule = "=" * 90
    sample_questions = (
        "What is the scope of Article 21?",
        "What are the conditions for granting anticipatory bail?",
        "What is the burden of proof in criminal law?",
    )

    for question in sample_questions:
        print("\n" + rule)
        print(f"QUESTION: {question}")
        print(rule)

        for rank, hit in enumerate(engine.answer_question(question), start=1):
            print(f"\nANSWER {rank}:")
            print(hit["answer"])
            print(f"\nSOURCE: {hit['judgment_id']} | Page: {hit['page_no']}")
            print(f"Retrieval score: {hit['retrieval_score']:.3f}")
            print(f"Confidence score: {hit['confidence']:.2f}")
            print("\nPARAGRAPH:")
            # Only the first 700 characters of the paragraph are shown.
            print(hit["paragraph"][:700] + "...")