Spaces:
Sleeping
Sleeping
Upload folder using huggingface_hub
Browse files- src/retrieval.py +27 -20
- src/verify.py +63 -35
src/retrieval.py
CHANGED
|
@@ -22,25 +22,28 @@ METADATA_PATH = os.getenv("METADATA_PATH", "models/faiss_index/chunk_metadata.js
|
|
| 22 |
PARENT_PATH = os.getenv("PARENT_PATH", "data/parent_judgments.jsonl")
|
| 23 |
TOP_K = 5
|
| 24 |
|
| 25 |
-
# Similarity threshold
|
| 26 |
-
#
|
| 27 |
-
#
|
| 28 |
-
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
def _load_resources():
|
| 31 |
"""Load index, metadata and parent store. Called once at module import."""
|
| 32 |
-
|
| 33 |
print("Loading FAISS index...")
|
| 34 |
index = faiss.read_index(INDEX_PATH)
|
| 35 |
print(f"Index loaded: {index.ntotal} vectors")
|
| 36 |
-
|
| 37 |
print("Loading chunk metadata...")
|
| 38 |
metadata = []
|
| 39 |
with open(METADATA_PATH, "r", encoding="utf-8") as f:
|
| 40 |
for line in f:
|
| 41 |
metadata.append(json.loads(line))
|
| 42 |
print(f"Metadata loaded: {len(metadata)} chunks")
|
| 43 |
-
|
| 44 |
print("Loading parent judgments...")
|
| 45 |
parent_store = {}
|
| 46 |
with open(PARENT_PATH, "r", encoding="utf-8") as f:
|
|
@@ -48,7 +51,7 @@ def _load_resources():
|
|
| 48 |
parent = json.loads(line)
|
| 49 |
parent_store[parent["judgment_id"]] = parent["text"]
|
| 50 |
print(f"Parent store loaded: {len(parent_store)} judgments")
|
| 51 |
-
|
| 52 |
return index, metadata, parent_store
|
| 53 |
|
| 54 |
_index, _metadata, _parent_store = _load_resources()
|
|
@@ -57,28 +60,32 @@ _index, _metadata, _parent_store = _load_resources()
|
|
| 57 |
def retrieve(query_embedding: np.ndarray, top_k: int = TOP_K) -> List[Dict]:
|
| 58 |
"""
|
| 59 |
Find top-k chunks most similar to the query embedding.
|
| 60 |
-
Returns empty list if best score is
|
| 61 |
-
(meaning the query is likely out of domain).
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
"""
|
| 63 |
query_vec = query_embedding.reshape(1, -1).astype(np.float32)
|
| 64 |
scores, indices = _index.search(query_vec, top_k)
|
| 65 |
-
|
| 66 |
-
#
|
| 67 |
best_score = float(scores[0][0])
|
| 68 |
-
if best_score
|
| 69 |
return [] # Out of domain — agent will handle this
|
| 70 |
-
|
| 71 |
results = []
|
| 72 |
for score, idx in zip(scores[0], indices[0]):
|
| 73 |
if idx == -1:
|
| 74 |
continue
|
| 75 |
-
|
| 76 |
chunk = _metadata[idx]
|
| 77 |
expanded = _get_expanded_context(
|
| 78 |
chunk["judgment_id"],
|
| 79 |
chunk["text"]
|
| 80 |
)
|
| 81 |
-
|
| 82 |
results.append({
|
| 83 |
"chunk_id": chunk["chunk_id"],
|
| 84 |
"judgment_id": chunk["judgment_id"],
|
|
@@ -88,7 +95,7 @@ def retrieve(query_embedding: np.ndarray, top_k: int = TOP_K) -> List[Dict]:
|
|
| 88 |
"expanded_context": expanded,
|
| 89 |
"similarity_score": float(score)
|
| 90 |
})
|
| 91 |
-
|
| 92 |
return results
|
| 93 |
|
| 94 |
|
|
@@ -105,16 +112,16 @@ def _get_expanded_context(judgment_id: str, chunk_text: str) -> str:
|
|
| 105 |
parent_text = _parent_store.get(judgment_id, "")
|
| 106 |
if not parent_text:
|
| 107 |
return chunk_text
|
| 108 |
-
|
| 109 |
# Find chunk position in parent
|
| 110 |
anchor = chunk_text[:80]
|
| 111 |
start_pos = parent_text.find(anchor)
|
| 112 |
if start_pos == -1:
|
| 113 |
return chunk_text
|
| 114 |
-
|
| 115 |
# ~4 chars per token, 1024 tokens = ~4096 chars
|
| 116 |
WINDOW = 4096
|
| 117 |
expand_start = max(0, start_pos - WINDOW // 4)
|
| 118 |
expand_end = min(len(parent_text), start_pos + WINDOW)
|
| 119 |
-
|
| 120 |
return parent_text[expand_start:expand_end]
|
|
|
|
| 22 |
PARENT_PATH = os.getenv("PARENT_PATH", "data/parent_judgments.jsonl")
|
| 23 |
TOP_K = 5
|
| 24 |
|
| 25 |
+
# Similarity threshold for out-of-domain detection.
|
| 26 |
+
# This index uses L2 distance — HIGHER score = FURTHER AWAY = worse match.
|
| 27 |
+
# Legal queries typically score 0.6 - 0.8.
|
| 28 |
+
# Out-of-domain queries (cricket, Bollywood) score 0.9+.
|
| 29 |
+
# Block anything where the best match is above this threshold.
|
| 30 |
+
SIMILARITY_THRESHOLD = 0.85
|
| 31 |
+
|
| 32 |
|
| 33 |
def _load_resources():
|
| 34 |
"""Load index, metadata and parent store. Called once at module import."""
|
| 35 |
+
|
| 36 |
print("Loading FAISS index...")
|
| 37 |
index = faiss.read_index(INDEX_PATH)
|
| 38 |
print(f"Index loaded: {index.ntotal} vectors")
|
| 39 |
+
|
| 40 |
print("Loading chunk metadata...")
|
| 41 |
metadata = []
|
| 42 |
with open(METADATA_PATH, "r", encoding="utf-8") as f:
|
| 43 |
for line in f:
|
| 44 |
metadata.append(json.loads(line))
|
| 45 |
print(f"Metadata loaded: {len(metadata)} chunks")
|
| 46 |
+
|
| 47 |
print("Loading parent judgments...")
|
| 48 |
parent_store = {}
|
| 49 |
with open(PARENT_PATH, "r", encoding="utf-8") as f:
|
|
|
|
| 51 |
parent = json.loads(line)
|
| 52 |
parent_store[parent["judgment_id"]] = parent["text"]
|
| 53 |
print(f"Parent store loaded: {len(parent_store)} judgments")
|
| 54 |
+
|
| 55 |
return index, metadata, parent_store
|
| 56 |
|
| 57 |
_index, _metadata, _parent_store = _load_resources()
|
|
|
|
| 60 |
def retrieve(query_embedding: np.ndarray, top_k: int = TOP_K) -> List[Dict]:
|
| 61 |
"""
|
| 62 |
Find top-k chunks most similar to the query embedding.
|
| 63 |
+
Returns empty list if best score is above SIMILARITY_THRESHOLD
|
| 64 |
+
(meaning the query is likely out of domain — no close match found).
|
| 65 |
+
|
| 66 |
+
L2 distance logic:
|
| 67 |
+
low score = close match = good = let through
|
| 68 |
+
high score = far match = bad = block
|
| 69 |
"""
|
| 70 |
query_vec = query_embedding.reshape(1, -1).astype(np.float32)
|
| 71 |
scores, indices = _index.search(query_vec, top_k)
|
| 72 |
+
|
| 73 |
+
# Block if even the best match is too far away
|
| 74 |
best_score = float(scores[0][0])
|
| 75 |
+
if best_score > SIMILARITY_THRESHOLD:
|
| 76 |
return [] # Out of domain — agent will handle this
|
| 77 |
+
|
| 78 |
results = []
|
| 79 |
for score, idx in zip(scores[0], indices[0]):
|
| 80 |
if idx == -1:
|
| 81 |
continue
|
| 82 |
+
|
| 83 |
chunk = _metadata[idx]
|
| 84 |
expanded = _get_expanded_context(
|
| 85 |
chunk["judgment_id"],
|
| 86 |
chunk["text"]
|
| 87 |
)
|
| 88 |
+
|
| 89 |
results.append({
|
| 90 |
"chunk_id": chunk["chunk_id"],
|
| 91 |
"judgment_id": chunk["judgment_id"],
|
|
|
|
| 95 |
"expanded_context": expanded,
|
| 96 |
"similarity_score": float(score)
|
| 97 |
})
|
| 98 |
+
|
| 99 |
return results
|
| 100 |
|
| 101 |
|
|
|
|
| 112 |
parent_text = _parent_store.get(judgment_id, "")
|
| 113 |
if not parent_text:
|
| 114 |
return chunk_text
|
| 115 |
+
|
| 116 |
# Find chunk position in parent
|
| 117 |
anchor = chunk_text[:80]
|
| 118 |
start_pos = parent_text.find(anchor)
|
| 119 |
if start_pos == -1:
|
| 120 |
return chunk_text
|
| 121 |
+
|
| 122 |
# ~4 chars per token, 1024 tokens = ~4096 chars
|
| 123 |
WINDOW = 4096
|
| 124 |
expand_start = max(0, start_pos - WINDOW // 4)
|
| 125 |
expand_end = min(len(parent_text), start_pos + WINDOW)
|
| 126 |
+
|
| 127 |
return parent_text[expand_start:expand_end]
|
src/verify.py
CHANGED
|
@@ -1,45 +1,73 @@
|
|
| 1 |
"""
|
| 2 |
-
Citation verification
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
- ANY missing → Unverified
|
| 9 |
-
- No quotes in answer → Verified (no verifiable claim made)
|
| 10 |
-
|
| 11 |
-
DOCUMENTED LIMITATION:
|
| 12 |
-
Paraphrased claims that are not quoted pass as Verified.
|
| 13 |
-
Full NLI-based verification is out of scope — documented in README.
|
| 14 |
"""
|
| 15 |
|
| 16 |
import re
|
| 17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
-
def extract_quotes(text: str) -> List[str]:
|
| 20 |
-
"""Extract double-quoted phrases of at least 8 characters."""
|
| 21 |
-
return re.findall(r'"([^"]{8,})"', text)
|
| 22 |
|
| 23 |
-
def
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
"""
|
| 28 |
-
|
| 29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
"""
|
| 31 |
-
quotes =
|
| 32 |
-
|
| 33 |
if not quotes:
|
| 34 |
-
return
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
if unverified:
|
| 44 |
-
return
|
| 45 |
-
return
|
|
|
|
| 1 |
"""
|
| 2 |
+
Citation verification module.
|
| 3 |
+
Checks whether quoted phrases in LLM answer appear in retrieved context.
|
| 4 |
+
|
| 5 |
+
Deterministic — no ML inference.
|
| 6 |
+
Documented limitation: paraphrases pass as verified because
|
| 7 |
+
exact paraphrase matching requires NLI which is out of scope.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
"""
|
| 9 |
|
| 10 |
import re
|
| 11 |
+
import unicodedata
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def _normalise(text: str) -> str:
|
| 15 |
+
"""Lowercase, strip punctuation, collapse whitespace."""
|
| 16 |
+
text = text.lower()
|
| 17 |
+
text = unicodedata.normalize("NFKD", text)
|
| 18 |
+
text = re.sub(r"[^\w\s]", " ", text)
|
| 19 |
+
text = re.sub(r"\s+", " ", text).strip()
|
| 20 |
+
return text
|
| 21 |
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
+
def _extract_quotes(text: str) -> list[str]:
|
| 24 |
+
"""Extract all quoted phrases from text."""
|
| 25 |
+
patterns = [
|
| 26 |
+
r'"([^"]{10,})"', # standard double quotes
|
| 27 |
+
r'\u201c([^\u201d]{10,})\u201d', # curly double quotes
|
| 28 |
+
r"'([^']{10,})'", # single quotes
|
| 29 |
+
]
|
| 30 |
+
quotes = []
|
| 31 |
+
for pattern in patterns:
|
| 32 |
+
found = re.findall(pattern, text)
|
| 33 |
+
quotes.extend(found)
|
| 34 |
+
return quotes
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def verify_citations(answer: str, contexts: list[dict]) -> tuple[bool, list[str]]:
    """
    Verify that every quoted phrase in *answer* appears (after
    normalisation) somewhere in the retrieved *contexts*.

    Returns:
        (verified, unverified_quotes) — verified is True when every
        checkable quote was located; unverified_quotes lists quotes
        that could not be found.

    An answer with no quotes counts as verified: it makes no
    verifiable claim. Quotes whose normalised form is shorter than
    8 characters are skipped as likely extraction artifacts.
    """
    quotes = _extract_quotes(answer)

    if not quotes:
        # Nothing quoted — nothing to disprove.
        return True, []

    # Flatten all context windows into one normalised corpus so each
    # quote needs only a single substring check.
    corpus = " ".join(
        _normalise(ctx.get("text", "") or ctx.get("excerpt", ""))
        for ctx in contexts
    )

    missing = [
        quote
        for quote in quotes
        if len(nq := _normalise(quote)) >= 8 and nq not in corpus
    ]

    if missing:
        return False, missing
    return True, []
|