"""
FAISS retrieval module.
Loads the FAISS index and chunk metadata once at startup.
Given a query embedding, returns the top-k most similar chunks
plus an expanded context window from the parent judgment.
WHY load at startup and not per request?
Loading a 650MB index takes ~3 seconds. If you loaded it per request,
every user query would take 3+ seconds just for setup. Loading once
at startup means retrieval takes ~5ms per query.
"""
import json
import numpy as np
import faiss
import os
from typing import List, Dict
# Paths are overridable via environment variables so dev/prod deployments
# can point at different artefact locations without code changes.
INDEX_PATH = os.getenv("FAISS_INDEX_PATH", "models/faiss_index/index.faiss")
METADATA_PATH = os.getenv("METADATA_PATH", "models/faiss_index/chunk_metadata.jsonl")
PARENT_PATH = os.getenv("PARENT_PATH", "data/parent_judgments.jsonl")

# Default number of chunks returned per query (see retrieve()).
TOP_K = 5

# Similarity threshold for out-of-domain detection.
# This index uses L2 distance — HIGHER score = FURTHER AWAY = worse match.
# Legal queries typically score 0.6 - 0.8.
# Out-of-domain queries (cricket, Bollywood) score 0.9+.
# Block anything where the best match is above this threshold.
SIMILARITY_THRESHOLD = 0.85
def _load_resources():
    """Read the FAISS index, chunk metadata and parent-judgment store from disk.

    Designed to run exactly once, at module import time, so that per-query
    retrieval never pays the multi-second load cost.

    Returns:
        Tuple of (faiss index, list of chunk-metadata dicts,
        dict mapping judgment_id -> full judgment text).
    """
    print("Loading FAISS index...")
    index = faiss.read_index(INDEX_PATH)
    print(f"Index loaded: {index.ntotal} vectors")

    print("Loading chunk metadata...")
    # One JSON object per line (JSONL).
    with open(METADATA_PATH, "r", encoding="utf-8") as f:
        metadata = [json.loads(line) for line in f]
    print(f"Metadata loaded: {len(metadata)} chunks")

    print("Loading parent judgments...")
    with open(PARENT_PATH, "r", encoding="utf-8") as f:
        records = (json.loads(line) for line in f)
        parent_store = {rec["judgment_id"]: rec["text"] for rec in records}
    print(f"Parent store loaded: {len(parent_store)} judgments")

    return index, metadata, parent_store
# Load everything once at import time (see module docstring): every request
# thereafter pays only for the FAISS search itself, not the ~3s load.
_index, _metadata, _parent_store = _load_resources()
def retrieve(query_embedding: np.ndarray, top_k: int = TOP_K) -> List[Dict]:
    """
    Return the top-k chunks closest to the query embedding.

    The index uses L2 distance, so a LOWER score means a CLOSER match.
    If even the nearest neighbour scores above SIMILARITY_THRESHOLD, the
    query is considered out of domain and an empty list is returned for
    the agent to handle.

    Args:
        query_embedding: 1-D (or reshapeable) embedding of the user query.
        top_k: How many neighbours to fetch from the index.

    Returns:
        List of result dicts (chunk text, expanded context, metadata,
        similarity score), or [] when the query looks out of domain.
    """
    vec = query_embedding.reshape(1, -1).astype(np.float32)
    distances, ids = _index.search(vec, top_k)

    # Guard clause: FAISS returns neighbours nearest-first, so if the very
    # first distance is already past the threshold, nothing qualifies.
    if float(distances[0][0]) > SIMILARITY_THRESHOLD:
        return []  # Out of domain — agent will handle this

    hits: List[Dict] = []
    for dist, pos in zip(distances[0], ids[0]):
        # FAISS pads with -1 when the index holds fewer than top_k vectors.
        if pos == -1:
            continue
        meta = _metadata[pos]
        hits.append({
            "chunk_id": meta["chunk_id"],
            "judgment_id": meta["judgment_id"],
            "title": meta.get("title", ""),
            "year": meta.get("year", ""),
            "chunk_text": meta["text"],
            "expanded_context": _get_expanded_context(meta["judgment_id"], meta["text"]),
            "similarity_score": float(dist),
        })
    return hits
def _get_expanded_context(judgment_id: str, chunk_text: str) -> str:
    """
    Expand a retrieved chunk into a wider window of its parent judgment.

    Locates the chunk inside the full parent text and returns roughly
    5,120 characters (~1,280 tokens at ~4 chars/token): ~1,024 characters
    of preceding context plus ~4,096 characters starting at the chunk.
    NOTE: the window is deliberately asymmetric (more context after the
    chunk than before), not centred on it.

    WHY expand context?
    The chunk is 512 tokens — enough for retrieval. But the LLM needs more
    surrounding context to give a complete answer, so we go back to the
    full judgment and extract a wider window.

    Args:
        judgment_id: Key into the parent judgment store.
        chunk_text: The retrieved chunk; its first 80 chars anchor the search.

    Returns:
        The expanded window, or chunk_text unchanged when the parent is
        missing or the chunk cannot be located in it.
    """
    parent_text = _parent_store.get(judgment_id, "")
    if not parent_text or not chunk_text:
        # Guarding empty chunk_text matters: parent_text.find("") == 0,
        # which would silently return an unrelated prefix of the judgment.
        return chunk_text

    # Anchor on the chunk's first 80 chars; matching the full chunk would be
    # brittle if the chunker normalised whitespace near the chunk's end.
    anchor = chunk_text[:80]
    start_pos = parent_text.find(anchor)
    if start_pos == -1:
        return chunk_text

    # ~4 chars per token: 4096 chars ≈ 1024 tokens from the chunk start,
    # plus a quarter-window (1024 chars) of preceding context.
    WINDOW = 4096
    expand_start = max(0, start_pos - WINDOW // 4)
    expand_end = min(len(parent_text), start_pos + WINDOW)
    return parent_text[expand_start:expand_end]