EphAsad commited on
Commit
d776bf4
·
verified ·
1 Parent(s): 186a43a

Update rag/rag_retriever.py

Browse files
Files changed (1) hide show
  1. rag/rag_retriever.py +110 -16
rag/rag_retriever.py CHANGED
@@ -1,8 +1,12 @@
1
  # rag/rag_retriever.py
2
  # ============================================================
3
- # RAG retriever:
4
- # - Loads kb_index.json
5
- # - Retrieves best-matching chunks for a given phenotype + genus
 
 
 
 
6
  # ============================================================
7
 
8
  from __future__ import annotations
@@ -13,6 +17,26 @@ import numpy as np
13
  from rag.rag_embedder import embed_text, load_kb_index
14
 
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  def _cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
17
  """
18
  Cosine similarity for normalized embeddings.
@@ -20,12 +44,26 @@ def _cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
20
  return float(np.dot(a, b))
21
 
22
 
 
 
 
 
23
  def retrieve_rag_context(
24
  phenotype_text: str,
25
  target_genus: str,
26
  top_k: int = 5,
27
  kb_path: str = "data/rag/index/kb_index.json",
28
  ) -> Dict[str, Any]:
 
 
 
 
 
 
 
 
 
 
29
 
30
  kb = load_kb_index(kb_path)
31
  records = kb.get("records", [])
@@ -37,64 +75,120 @@ def retrieve_rag_context(
37
  "combined_context": "",
38
  }
39
 
40
- q_emb = embed_text(phenotype_text, normalize=True)
 
 
 
 
 
 
 
 
41
  target_genus_lc = (target_genus or "").strip().lower()
42
 
43
  scored_records: List[Dict[str, Any]] = []
44
 
 
 
 
 
45
  for rec in records:
46
- g = (rec.get("genus") or "").strip().lower()
47
- if target_genus_lc and g != target_genus_lc:
48
  continue
49
 
50
  emb = rec.get("embedding")
51
  if emb is None:
52
  continue
53
 
54
- score = _cosine_similarity(q_emb, emb)
 
 
 
 
 
55
  scored_records.append({
56
  "id": rec.get("id"),
57
  "genus": rec.get("genus"),
58
  "species": rec.get("species"),
59
- "source_type": rec.get("level"),
60
  "path": rec.get("source_file"),
61
  "text": rec.get("text"),
62
  "score": score,
63
  })
64
 
65
- # Fallback: use all records
 
 
 
66
  if not scored_records:
67
  for rec in records:
68
  emb = rec.get("embedding")
69
  if emb is None:
70
  continue
71
- score = _cosine_similarity(q_emb, emb)
 
 
 
 
 
 
72
  scored_records.append({
73
  "id": rec.get("id"),
74
  "genus": rec.get("genus"),
75
  "species": rec.get("species"),
76
- "source_type": rec.get("level"),
77
  "path": rec.get("source_file"),
78
  "text": rec.get("text"),
79
  "score": score,
80
  })
81
 
 
 
 
 
82
  scored_records.sort(key=lambda r: r["score"], reverse=True)
83
- top = scored_records[:top_k]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
  combined_ctx_parts: List[str] = []
86
- for rec in top:
87
- label = rec["genus"]
 
88
  if rec.get("species"):
89
  label = f"{label} {rec['species']}"
 
90
  combined_ctx_parts.append(
91
- f"[{label} {rec['source_type']}] {rec['text']}"
 
92
  )
93
 
94
  combined_context = "\n\n".join(combined_ctx_parts)
95
 
96
  return {
97
  "genus": target_genus,
98
- "chunks": top,
99
  "combined_context": combined_context,
100
  }
 
1
  # rag/rag_retriever.py
2
  # ============================================================
3
+ # RAG retriever (Stage 2 – microbiology-aware)
4
+ #
5
+ # Improvements:
6
+ # - Source-type weighting (species > genus > notes)
7
+ # - Genus-aware query expansion
8
+ # - Diversity enforcement (avoid duplicate sources)
9
+ # - Explicit ranking & score annotations for generator
10
  # ============================================================
11
 
12
  from __future__ import annotations
 
17
  from rag.rag_embedder import embed_text, load_kb_index
18
 
19
 
20
+ # ------------------------------------------------------------
21
+ # Configuration
22
+ # ------------------------------------------------------------
23
+
24
+ # Weight different knowledge chunk types
25
+ SOURCE_TYPE_WEIGHTS = {
26
+ "species": 1.15,
27
+ "genus": 1.00,
28
+ "table": 1.10,
29
+ "note": 0.85,
30
+ }
31
+
32
+ # Max chunks allowed per source file (diversity control)
33
+ MAX_CHUNKS_PER_SOURCE = 1
34
+
35
+
36
+ # ------------------------------------------------------------
37
+ # Similarity helper
38
+ # ------------------------------------------------------------
39
+
40
  def _cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
41
  """
42
  Cosine similarity for normalized embeddings.
 
44
  return float(np.dot(a, b))
45
 
46
 
47
+ # ------------------------------------------------------------
48
+ # Public API
49
+ # ------------------------------------------------------------
50
+
51
  def retrieve_rag_context(
52
  phenotype_text: str,
53
  target_genus: str,
54
  top_k: int = 5,
55
  kb_path: str = "data/rag/index/kb_index.json",
56
  ) -> Dict[str, Any]:
57
+ """
58
+ Retrieve the most relevant RAG chunks for a phenotype + genus.
59
+
60
+ Returns:
61
+ {
62
+ "genus": target_genus,
63
+ "chunks": [...], # ranked chunk metadata
64
+ "combined_context": "..." # formatted context for generator
65
+ }
66
+ """
67
 
68
  kb = load_kb_index(kb_path)
69
  records = kb.get("records", [])
 
75
  "combined_context": "",
76
  }
77
 
78
+ # --------------------------------------------------------
79
+ # Build genus-aware query
80
+ # --------------------------------------------------------
81
+
82
+ query_text = phenotype_text.strip()
83
+ if target_genus:
84
+ query_text = f"{query_text}\nTarget genus: {target_genus}"
85
+
86
+ q_emb = embed_text(query_text, normalize=True)
87
  target_genus_lc = (target_genus or "").strip().lower()
88
 
89
  scored_records: List[Dict[str, Any]] = []
90
 
91
+ # --------------------------------------------------------
92
+ # Primary pass: genus-filtered retrieval
93
+ # --------------------------------------------------------
94
+
95
  for rec in records:
96
+ rec_genus = (rec.get("genus") or "").strip().lower()
97
+ if target_genus_lc and rec_genus != target_genus_lc:
98
  continue
99
 
100
  emb = rec.get("embedding")
101
  if emb is None:
102
  continue
103
 
104
+ base_score = _cosine_similarity(q_emb, emb)
105
+ source_type = rec.get("level")
106
+ weight = SOURCE_TYPE_WEIGHTS.get(source_type, 1.0)
107
+
108
+ score = base_score * weight
109
+
110
  scored_records.append({
111
  "id": rec.get("id"),
112
  "genus": rec.get("genus"),
113
  "species": rec.get("species"),
114
+ "source_type": source_type,
115
  "path": rec.get("source_file"),
116
  "text": rec.get("text"),
117
  "score": score,
118
  })
119
 
120
+ # --------------------------------------------------------
121
+ # Fallback: no genus-matched records
122
+ # --------------------------------------------------------
123
+
124
  if not scored_records:
125
  for rec in records:
126
  emb = rec.get("embedding")
127
  if emb is None:
128
  continue
129
+
130
+ base_score = _cosine_similarity(q_emb, emb)
131
+ source_type = rec.get("level")
132
+ weight = SOURCE_TYPE_WEIGHTS.get(source_type, 1.0)
133
+
134
+ score = base_score * weight
135
+
136
  scored_records.append({
137
  "id": rec.get("id"),
138
  "genus": rec.get("genus"),
139
  "species": rec.get("species"),
140
+ "source_type": source_type,
141
  "path": rec.get("source_file"),
142
  "text": rec.get("text"),
143
  "score": score,
144
  })
145
 
146
+ # --------------------------------------------------------
147
+ # Sort by weighted score
148
+ # --------------------------------------------------------
149
+
150
  scored_records.sort(key=lambda r: r["score"], reverse=True)
151
+
152
+ # --------------------------------------------------------
153
+ # Diversity enforcement (avoid duplicate sources)
154
+ # --------------------------------------------------------
155
+
156
+ selected: List[Dict[str, Any]] = []
157
+ source_counts: Dict[str, int] = {}
158
+
159
+ for rec in scored_records:
160
+ src = rec.get("path") or ""
161
+ count = source_counts.get(src, 0)
162
+
163
+ if count >= MAX_CHUNKS_PER_SOURCE:
164
+ continue
165
+
166
+ selected.append(rec)
167
+ source_counts[src] = count + 1
168
+
169
+ if len(selected) >= top_k:
170
+ break
171
+
172
+ # --------------------------------------------------------
173
+ # Build combined context with explicit ranking
174
+ # --------------------------------------------------------
175
 
176
  combined_ctx_parts: List[str] = []
177
+
178
+ for idx, rec in enumerate(selected, start=1):
179
+ label = rec.get("genus") or "Unknown genus"
180
  if rec.get("species"):
181
  label = f"{label} {rec['species']}"
182
+
183
  combined_ctx_parts.append(
184
+ f"[RANK {idx} | SCORE {rec['score']:.3f} | {label} — {rec['source_type']}]\n"
185
+ f"{rec['text']}"
186
  )
187
 
188
  combined_context = "\n\n".join(combined_ctx_parts)
189
 
190
  return {
191
  "genus": target_genus,
192
+ "chunks": selected,
193
  "combined_context": combined_context,
194
  }