CaffeinatedCoding committed on
Commit
0430f42
·
verified ·
1 Parent(s): 639ffe2

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. src/retrieval.py +27 -20
  2. src/verify.py +63 -35
src/retrieval.py CHANGED
@@ -22,25 +22,28 @@ METADATA_PATH = os.getenv("METADATA_PATH", "models/faiss_index/chunk_metadata.js
22
  PARENT_PATH = os.getenv("PARENT_PATH", "data/parent_judgments.jsonl")
23
  TOP_K = 5
24
 
25
- # Similarity threshold if best score is below this, query is out of domain
26
- # Score range: 0 to 1 (cosine similarity with normalized vectors)
27
- # 0.3 = very loose match, 0.5 = decent match, 0.7 = strong match
28
- SIMILARITY_THRESHOLD = 0.45
 
 
 
29
 
30
  def _load_resources():
31
  """Load index, metadata and parent store. Called once at module import."""
32
-
33
  print("Loading FAISS index...")
34
  index = faiss.read_index(INDEX_PATH)
35
  print(f"Index loaded: {index.ntotal} vectors")
36
-
37
  print("Loading chunk metadata...")
38
  metadata = []
39
  with open(METADATA_PATH, "r", encoding="utf-8") as f:
40
  for line in f:
41
  metadata.append(json.loads(line))
42
  print(f"Metadata loaded: {len(metadata)} chunks")
43
-
44
  print("Loading parent judgments...")
45
  parent_store = {}
46
  with open(PARENT_PATH, "r", encoding="utf-8") as f:
@@ -48,7 +51,7 @@ def _load_resources():
48
  parent = json.loads(line)
49
  parent_store[parent["judgment_id"]] = parent["text"]
50
  print(f"Parent store loaded: {len(parent_store)} judgments")
51
-
52
  return index, metadata, parent_store
53
 
54
  _index, _metadata, _parent_store = _load_resources()
@@ -57,28 +60,32 @@ _index, _metadata, _parent_store = _load_resources()
57
  def retrieve(query_embedding: np.ndarray, top_k: int = TOP_K) -> List[Dict]:
58
  """
59
  Find top-k chunks most similar to the query embedding.
60
- Returns empty list if best score is below SIMILARITY_THRESHOLD
61
- (meaning the query is likely out of domain).
 
 
 
 
62
  """
63
  query_vec = query_embedding.reshape(1, -1).astype(np.float32)
64
  scores, indices = _index.search(query_vec, top_k)
65
-
66
- # Check if best match is above threshold
67
  best_score = float(scores[0][0])
68
- if best_score < SIMILARITY_THRESHOLD:
69
  return [] # Out of domain — agent will handle this
70
-
71
  results = []
72
  for score, idx in zip(scores[0], indices[0]):
73
  if idx == -1:
74
  continue
75
-
76
  chunk = _metadata[idx]
77
  expanded = _get_expanded_context(
78
  chunk["judgment_id"],
79
  chunk["text"]
80
  )
81
-
82
  results.append({
83
  "chunk_id": chunk["chunk_id"],
84
  "judgment_id": chunk["judgment_id"],
@@ -88,7 +95,7 @@ def retrieve(query_embedding: np.ndarray, top_k: int = TOP_K) -> List[Dict]:
88
  "expanded_context": expanded,
89
  "similarity_score": float(score)
90
  })
91
-
92
  return results
93
 
94
 
@@ -105,16 +112,16 @@ def _get_expanded_context(judgment_id: str, chunk_text: str) -> str:
105
  parent_text = _parent_store.get(judgment_id, "")
106
  if not parent_text:
107
  return chunk_text
108
-
109
  # Find chunk position in parent
110
  anchor = chunk_text[:80]
111
  start_pos = parent_text.find(anchor)
112
  if start_pos == -1:
113
  return chunk_text
114
-
115
  # ~4 chars per token, 1024 tokens = ~4096 chars
116
  WINDOW = 4096
117
  expand_start = max(0, start_pos - WINDOW // 4)
118
  expand_end = min(len(parent_text), start_pos + WINDOW)
119
-
120
  return parent_text[expand_start:expand_end]
 
22
  PARENT_PATH = os.getenv("PARENT_PATH", "data/parent_judgments.jsonl")
23
  TOP_K = 5
24
 
25
+ # Similarity threshold for out-of-domain detection.
26
+ # This index uses L2 distance: HIGHER score = FURTHER AWAY = worse match.
27
+ # Legal queries typically score 0.6 - 0.8.
28
+ # Out-of-domain queries (cricket, Bollywood) score 0.9+.
29
+ # Block anything where the best match is above this threshold.
30
+ SIMILARITY_THRESHOLD = 0.85
31
+
32
 
33
  def _load_resources():
34
  """Load index, metadata and parent store. Called once at module import."""
35
+
36
  print("Loading FAISS index...")
37
  index = faiss.read_index(INDEX_PATH)
38
  print(f"Index loaded: {index.ntotal} vectors")
39
+
40
  print("Loading chunk metadata...")
41
  metadata = []
42
  with open(METADATA_PATH, "r", encoding="utf-8") as f:
43
  for line in f:
44
  metadata.append(json.loads(line))
45
  print(f"Metadata loaded: {len(metadata)} chunks")
46
+
47
  print("Loading parent judgments...")
48
  parent_store = {}
49
  with open(PARENT_PATH, "r", encoding="utf-8") as f:
 
51
  parent = json.loads(line)
52
  parent_store[parent["judgment_id"]] = parent["text"]
53
  print(f"Parent store loaded: {len(parent_store)} judgments")
54
+
55
  return index, metadata, parent_store
56
 
57
  _index, _metadata, _parent_store = _load_resources()
 
60
  def retrieve(query_embedding: np.ndarray, top_k: int = TOP_K) -> List[Dict]:
61
  """
62
  Find top-k chunks most similar to the query embedding.
63
+ Returns empty list if best score is above SIMILARITY_THRESHOLD
64
+ (meaning the query is likely out of domain — no close match found).
65
+
66
+ L2 distance logic:
67
+ low score = close match = good = let through
68
+ high score = far match = bad = block
69
  """
70
  query_vec = query_embedding.reshape(1, -1).astype(np.float32)
71
  scores, indices = _index.search(query_vec, top_k)
72
+
73
+ # Block if even the best match is too far away
74
  best_score = float(scores[0][0])
75
+ if best_score > SIMILARITY_THRESHOLD:
76
  return [] # Out of domain — agent will handle this
77
+
78
  results = []
79
  for score, idx in zip(scores[0], indices[0]):
80
  if idx == -1:
81
  continue
82
+
83
  chunk = _metadata[idx]
84
  expanded = _get_expanded_context(
85
  chunk["judgment_id"],
86
  chunk["text"]
87
  )
88
+
89
  results.append({
90
  "chunk_id": chunk["chunk_id"],
91
  "judgment_id": chunk["judgment_id"],
 
95
  "expanded_context": expanded,
96
  "similarity_score": float(score)
97
  })
98
+
99
  return results
100
 
101
 
 
112
  parent_text = _parent_store.get(judgment_id, "")
113
  if not parent_text:
114
  return chunk_text
115
+
116
  # Find chunk position in parent
117
  anchor = chunk_text[:80]
118
  start_pos = parent_text.find(anchor)
119
  if start_pos == -1:
120
  return chunk_text
121
+
122
  # ~4 chars per token, 1024 tokens = ~4096 chars
123
  WINDOW = 4096
124
  expand_start = max(0, start_pos - WINDOW // 4)
125
  expand_end = min(len(parent_text), start_pos + WINDOW)
126
+
127
  return parent_text[expand_start:expand_end]
src/verify.py CHANGED
@@ -1,45 +1,73 @@
1
  """
2
- Citation verification. Deterministic string matching — no ML.
3
-
4
- LOGIC:
5
- - Extract all quoted phrases (in double quotes) from LLM answer
6
- - Check each phrase verbatim against all retrieved chunk texts
7
- - ALL found Verified
8
- - ANY missing → Unverified
9
- - No quotes in answer → Verified (no verifiable claim made)
10
-
11
- DOCUMENTED LIMITATION:
12
- Paraphrased claims that are not quoted pass as Verified.
13
- Full NLI-based verification is out of scope — documented in README.
14
  """
15
 
16
  import re
17
- from typing import List, Dict, Tuple
 
 
 
 
 
 
 
 
 
18
 
19
- def extract_quotes(text: str) -> List[str]:
20
- """Extract double-quoted phrases of at least 8 characters."""
21
- return re.findall(r'"([^"]{8,})"', text)
22
 
23
- def verify_citations(
24
- llm_answer: str,
25
- retrieved_chunks: List[Dict]
26
- ) -> Tuple[str, List[str]]:
 
 
 
 
 
 
 
 
 
 
 
27
  """
28
- Returns (status, unverified_quotes).
29
- status: "Verified" | "Unverified" | "No verifiable claims"
 
 
 
 
 
 
 
 
 
30
  """
31
- quotes = extract_quotes(llm_answer)
32
-
33
  if not quotes:
34
- return "No verifiable claims", []
35
-
36
- all_context = " ".join(
37
- c.get("expanded_context", c.get("chunk_text", ""))
38
- for c in retrieved_chunks
39
- ).lower()
40
-
41
- unverified = [q for q in quotes if q.lower() not in all_context]
42
-
 
 
 
 
 
 
 
 
43
  if unverified:
44
- return "Unverified", unverified
45
- return "Verified", []
 
1
  """
2
+ Citation verification module.
3
+ Checks whether quoted phrases in LLM answer appear in retrieved context.
4
+
5
+ Deterministic — no ML inference.
6
+ Documented limitation: paraphrases pass as verified because
7
+ verifying paraphrases requires NLI, which is out of scope.
 
 
 
 
 
 
8
  """
9
 
10
  import re
11
+ import unicodedata
12
+
13
+
14
+ def _normalise(text: str) -> str:
15
+ """Lowercase, strip punctuation, collapse whitespace."""
16
+ text = text.lower()
17
+ text = unicodedata.normalize("NFKD", text)
18
+ text = re.sub(r"[^\w\s]", " ", text)
19
+ text = re.sub(r"\s+", " ", text).strip()
20
+ return text
21
 
 
 
 
22
 
23
+ def _extract_quotes(text: str) -> list[str]:
24
+ """Extract all quoted phrases from text."""
25
+ patterns = [
26
+ r'"([^"]{10,})"', # standard double quotes
27
+ r'\u201c([^\u201d]{10,})\u201d', # curly double quotes
28
+ r"'([^']{10,})'", # single quotes
29
+ ]
30
+ quotes = []
31
+ for pattern in patterns:
32
+ found = re.findall(pattern, text)
33
+ quotes.extend(found)
34
+ return quotes
35
+
36
+
37
def verify_citations(answer: str, contexts: list[dict]) -> tuple[bool, list[str]]:
    """
    Check whether quoted phrases in *answer* appear in the retrieved context.

    Args:
        answer: LLM answer text, possibly containing quoted phrases.
        contexts: retrieved chunks. Each dict may carry its text under
            "expanded_context" / "chunk_text" (the keys produced by
            src/retrieval.py) or "text" / "excerpt".

    Returns:
        (verified, unverified_quotes):
        - no quotes in answer -> (True, []) — no verifiable claims made
        - all quotes found in context -> (True, [])
        - any quote missing -> (False, [missing quotes])
    """
    quotes = _extract_quotes(answer)

    if not quotes:
        return True, []

    # BUG FIX: retrieval.py emits "expanded_context"/"chunk_text", not
    # "text"/"excerpt" — looking up only the latter made every context
    # empty, so every quoted answer came back unverified. Check the
    # retrieval keys first, keep the old keys for backward compatibility.
    def _context_text(ctx: dict) -> str:
        for key in ("expanded_context", "chunk_text", "text", "excerpt"):
            value = ctx.get(key)
            if value:
                return value
        return ""

    # Build a single normalised context corpus for substring checks.
    all_context_text = " ".join(_normalise(_context_text(ctx)) for ctx in contexts)

    unverified = []
    for quote in quotes:
        normalised_quote = _normalise(quote)
        # Very short normalised quotes are likely extraction artifacts.
        if len(normalised_quote) < 8:
            continue
        if normalised_quote not in all_context_text:
            unverified.append(quote)

    if unverified:
        return False, unverified
    return True, []