File size: 4,237 Bytes
0214972
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0430f42
 
 
 
 
 
 
0214972
 
 
0430f42
0214972
 
 
0430f42
0214972
 
 
 
 
 
0430f42
0214972
 
 
 
 
 
 
0430f42
0214972
 
 
 
 
 
 
 
0430f42
 
 
 
 
 
0214972
 
 
0430f42
 
0214972
0430f42
0214972
0430f42
0214972
 
 
 
0430f42
0214972
 
 
 
 
0430f42
0214972
 
 
 
 
 
 
 
 
0430f42
0214972
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0430f42
0214972
 
 
 
 
0430f42
0214972
 
 
 
0430f42
0214972
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
"""
FAISS retrieval module.

Loads the FAISS index and chunk metadata once at startup.
Given a query embedding, returns the top-k most similar chunks
plus an expanded context window from the parent judgment.

WHY load at startup and not per request?
Loading a 650MB index takes ~3 seconds. If you loaded it per request,
every user query would take 3+ seconds just for setup. Loading once
at startup means retrieval takes ~5ms per query.
"""

import json
import numpy as np
import faiss
import os
from typing import List, Dict

INDEX_PATH = os.getenv("FAISS_INDEX_PATH", "models/faiss_index/index.faiss")
METADATA_PATH = os.getenv("METADATA_PATH", "models/faiss_index/chunk_metadata.jsonl")
PARENT_PATH = os.getenv("PARENT_PATH", "data/parent_judgments.jsonl")

# Number of neighbours requested from FAISS per query. Overridable through
# the environment, like the file paths above; defaults to 5.
TOP_K = int(os.getenv("RETRIEVAL_TOP_K", "5"))

# Cutoff for out-of-domain detection.
# NOTE: despite its name this is a *distance* threshold, not a similarity:
# retrieve() treats the FAISS score as "lower = closer = better match", which
# holds for L2-type indexes. If the index is an IndexFlatL2, FAISS reports
# *squared* L2 distances, so calibrate this value against the raw scores.
# Observed ranges (per the original author's measurements):
#   legal queries                         ~0.6 - 0.8
#   out-of-domain (cricket, Bollywood)    0.9+
# A query is blocked when even its best match scores above this value.
# Overridable through the environment; defaults to 0.85.
SIMILARITY_THRESHOLD = float(os.getenv("SIMILARITY_THRESHOLD", "0.85"))


def _iter_jsonl(path: str):
    """Yield one parsed JSON object per non-blank line of the JSON Lines file at *path*."""
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            # Skip blank lines (e.g. a trailing empty line); json.loads
            # would raise on them.
            if line.strip():
                yield json.loads(line)


def _load_resources(
    index_path: str = INDEX_PATH,
    metadata_path: str = METADATA_PATH,
    parent_path: str = PARENT_PATH,
):
    """
    Load the FAISS index, chunk metadata and parent store.

    Called once at module import (see the assignment below this function).

    Args:
        index_path: FAISS index file, read with faiss.read_index.
        metadata_path: JSONL file; one metadata dict per line. Entry *i* is
            looked up by retrieve() for FAISS label *i*.
        parent_path: JSONL file whose lines each carry "judgment_id" and
            "text" keys.

    Returns:
        (index, metadata, parent_store) where metadata is a list of dicts and
        parent_store maps judgment_id -> full judgment text.
    """
    print("Loading FAISS index...")
    index = faiss.read_index(index_path)
    print(f"Index loaded: {index.ntotal} vectors")

    print("Loading chunk metadata...")
    metadata = list(_iter_jsonl(metadata_path))
    print(f"Metadata loaded: {len(metadata)} chunks")
    if len(metadata) != index.ntotal:
        # retrieve() maps FAISS labels to metadata by position, so a mismatch
        # means some hits cannot be resolved (or resolve to the wrong chunk).
        print(
            f"WARNING: index has {index.ntotal} vectors but metadata has "
            f"{len(metadata)} entries; index and metadata may be out of sync"
        )

    print("Loading parent judgments...")
    # Streamed so only the text of each judgment is kept in memory.
    parent_store = {
        parent["judgment_id"]: parent["text"]
        for parent in _iter_jsonl(parent_path)
    }
    print(f"Parent store loaded: {len(parent_store)} judgments")

    return index, metadata, parent_store


# Loaded once at import so each retrieve() call only pays for the search.
_index, _metadata, _parent_store = _load_resources()


def retrieve(query_embedding: np.ndarray, top_k: int = TOP_K) -> List[Dict]:
    """
    Find the top-k chunks closest to the query embedding.

    Returns an empty list when the query looks out of domain, i.e. when even
    the best match scores above SIMILARITY_THRESHOLD or the best score is not
    a finite number. Also returns an empty list when ``top_k`` < 1.

    Scores are used as distances (assumes an L2-type index):
        low score  = close match = good = let through
        high score = far match   = bad  = block

    Args:
        query_embedding: query vector; reshaped to (1, -1) and cast to float32
            before the search, so its length must equal the index dimension.
        top_k: number of neighbours to request from FAISS.

    Returns:
        One dict per resolvable hit, in FAISS rank order, with keys
        chunk_id, judgment_id, title, year, chunk_text, expanded_context and
        similarity_score (the raw FAISS score as a float).
    """
    if top_k < 1:
        # Nothing to search for; a k < 1 search either fails in FAISS or
        # returns an empty row, and scores[0][0] below would fail either way.
        return []

    query_vec = query_embedding.reshape(1, -1).astype(np.float32)
    scores, indices = _index.search(query_vec, top_k)

    # Block if even the best match is too far away. NaN compares False with
    # everything, so a plain ">" would let a NaN score (e.g. from an
    # embedding containing NaN) through; isfinite rejects it explicitly.
    best_score = float(scores[0][0])
    if not np.isfinite(best_score) or best_score > SIMILARITY_THRESHOLD:
        return []  # Out of domain — agent will handle this

    results = []
    for score, idx in zip(scores[0], indices[0]):
        # FAISS pads with -1 when it finds fewer than top_k neighbours; the
        # upper bound guards against metadata shorter than the index.
        if idx < 0 or idx >= len(_metadata):
            continue

        chunk = _metadata[idx]
        expanded = _get_expanded_context(
            chunk["judgment_id"],
            chunk["text"]
        )

        results.append({
            "chunk_id": chunk["chunk_id"],
            "judgment_id": chunk["judgment_id"],
            "title": chunk.get("title", ""),
            "year": chunk.get("year", ""),
            "chunk_text": chunk["text"],
            "expanded_context": expanded,
            "similarity_score": float(score)
        })

    return results


def _get_expanded_context(judgment_id: str, chunk_text: str, window: int = 4096) -> str:
    """
    Return a wider slice of the parent judgment around ``chunk_text``.

    The slice starts ``window // 4`` characters before the chunk and ends
    ``window`` characters after the chunk's start. Both ends are clamped to
    the judgment's bounds. The end is extended if needed so the whole chunk
    is always included. With the default of 4096 characters (~1024 tokens at
    ~4 chars per token) this is ~1024 characters of lead-in plus ~4096
    characters from the chunk onward. The window is biased towards the text
    that follows the chunk; it is not centred.

    Falls back to ``chunk_text`` unchanged when it is empty, when the
    judgment is not in the parent store, or when the chunk cannot be located
    in the judgment text.

    WHY expand context?
    The chunk is 512 tokens — enough for retrieval.
    But the LLM needs more surrounding context to give a complete answer.
    We go back to the full judgment and extract a wider window.

    Args:
        judgment_id: key into the parent store.
        chunk_text: text of the retrieved chunk.
        window: expansion size in characters (expected to be positive).
    """
    parent_text = _parent_store.get(judgment_id, "")
    if not parent_text or not chunk_text:
        # An empty chunk would "match" at position 0 and return an
        # arbitrary opening slice of the judgment.
        return chunk_text

    # Locate the chunk via its first 80 characters. find() returns the first
    # occurrence, so a prefix that also appears earlier in the judgment
    # anchors the window at that earlier spot.
    anchor = chunk_text[:80]
    start_pos = parent_text.find(anchor)
    if start_pos == -1:
        return chunk_text

    expand_start = max(0, start_pos - window // 4)
    # Never cut the chunk itself short, even if it is longer than the window.
    chunk_end = start_pos + len(chunk_text)
    expand_end = min(len(parent_text), max(start_pos + window, chunk_end))

    return parent_text[expand_start:expand_end]