hmc-rag / src /retriever.py
webmuppet
Initial commit — health marketing compliance RAG
bad8b6c
"""
Retriever: searches PageIndex trees for relevant sections.
Uses two-level LLM-based tree traversal — first pick the right branch,
then pick the right leaf nodes within it.
"""
import json
import litellm
from src.config import MODEL, DOCUMENT_REGISTRY
from src.usage import _extract_usage, _empty_usage, _sum_usage
# System prompt to suppress thinking on structured output calls (Gemma 4 etc.)
_NO_THINK = {"role": "system", "content": "Do not use thinking. Respond directly with the JSON only."}
def load_tree(domain_key: str) -> dict:
    """Read and parse the PageIndex tree registered for ``domain_key``.

    The file path is taken from ``DOCUMENT_REGISTRY[domain_key]["index"]`` and
    the file is decoded as UTF-8 JSON. The parsed document is returned as-is.
    """
    registry_entry = DOCUMENT_REGISTRY[domain_key]
    with open(registry_entry["index"], mode="r", encoding="utf-8") as handle:
        parsed_tree = json.load(handle)
    return parsed_tree
def format_nodes_shallow(nodes: list, depth: int = 0) -> str:
    """Render one text line per node in ``nodes`` without descending into children.

    Each line is ``[node_id] title``, indented by ``depth`` spaces, followed by
    the first 80 characters of the node's ``summary`` (or ``prefix_summary``
    when ``summary`` is empty) and, for nodes that have children, a
    ``(N sub-sections)`` marker. Lines are joined with newlines.
    """

    def _render(entry: dict) -> str:
        pad = " " * depth
        blurb = entry.get("summary") or entry.get("prefix_summary", "")
        # Summaries are clipped so the prompt built from this text stays short.
        blurb_part = f" — {blurb[:80]}" if blurb else ""
        n_children = len(entry.get("nodes", []))
        children_part = f" ({n_children} sub-sections)" if n_children > 0 else ""
        head = f"{pad}[{entry.get('node_id', '?')}] "
        return head + f"{entry.get('title', '')}{blurb_part}{children_part}"

    return "\n".join(_render(entry) for entry in nodes)
def format_branch(node: dict) -> str:
    """Render ``node`` and every descendant as an indented outline.

    Nodes are listed parent-first, children in their stored order, one line
    each: ``[node_id] title`` plus the first 80 characters of the summary
    (``summary``, else ``prefix_summary``). Each nesting level adds one space
    of indentation.
    """
    rendered = []
    # Explicit stack of (node, nesting level); children are pushed in reverse
    # so they pop in their original left-to-right order.
    pending = [(node, 0)]
    while pending:
        current, level = pending.pop()
        blurb = current.get("summary") or current.get("prefix_summary", "")
        blurb_part = f" — {blurb[:80]}" if blurb else ""
        pad = " " * level
        rendered.append(
            f"{pad}[{current.get('node_id', '?')}] {current.get('title', '')}{blurb_part}"
        )
        kids = current.get("nodes", [])
        pending.extend((kid, level + 1) for kid in reversed(kids))
    return "\n".join(rendered)
def get_tree_overview(tree: dict) -> str:
    """Render the whole tree under ``tree["structure"]`` as an indented outline.

    Every node contributes one line, ``[node_id] title`` followed by up to 100
    characters of its summary (``summary``, else ``prefix_summary``), indented
    one space per nesting level. A tree without ``structure`` yields ``""``.
    """

    def _walk(entry, level):
        # Yield this node's line, then its descendants' lines in order.
        blurb = entry.get("summary") or entry.get("prefix_summary", "")
        tail = f" — {blurb[:100]}" if blurb else ""
        pad = " " * level
        yield f"{pad}[{entry.get('node_id', '?')}] {entry.get('title', '')}{tail}"
        for kid in entry.get("nodes", []):
            yield from _walk(kid, level + 1)

    return "\n".join(
        row
        for root in tree.get("structure", [])
        for row in _walk(root, 0)
    )
def find_node_by_id(tree: dict, node_id: str) -> dict | None:
"""Find a node in the tree by its node_id."""
def search(nodes):
for node in nodes:
if node.get("node_id") == node_id:
return node
found = search(node.get("nodes", []))
if found:
return found
return None
return search(tree.get("structure", []))
def _parse_json_array(content: str) -> list:
"""Parse a JSON array from LLM output, handling markdown wrapping."""
if "```" in content:
content = content.split("```")[1]
if content.startswith("json"):
content = content[4:]
content = content.strip()
return json.loads(content)
def _collect_leaf_nodes(node: dict) -> list[dict]:
"""Get all leaf nodes (no children) under a given node."""
children = node.get("nodes", [])
if not children:
return [node]
leaves = []
for child in children:
leaves.extend(_collect_leaf_nodes(child))
return leaves
def retrieve_from_domain(
    query: str,
    domain_key: str,
    max_sections: int = 3,
    profession: str | None = None,
) -> tuple[list[dict], dict]:
    """Retrieve relevant sections from one domain's PageIndex tree.

    The tree is loaded from disk and its nodes are counted. Small trees
    (20 nodes or fewer in total) are handled with a single LLM pass over the
    full outline; larger trees use two-level traversal:

    Level 1: Pick which top-level branches are relevant (small prompt).
    Level 2: Within each selected branch, pick specific nodes.

    Profession (if set) is injected into the LLM prompts as soft context so
    the retriever favours sections binding the user's profession.

    Returns:
        A ``(results, usage)`` pair: the list of result dicts and the usage
        record produced by the chosen retrieval strategy.
    """
    tree = load_tree(domain_key)

    def count_nodes(nodes):
        # Total number of nodes at this level plus all levels beneath it.
        return len(nodes) + sum(count_nodes(n.get("nodes", [])) for n in nodes)

    total_nodes = count_nodes(tree.get("structure", []))

    # Both strategies share one signature and return shape, so only the
    # choice of strategy depends on the tree size.
    strategy = _single_pass_retrieve if total_nodes <= 20 else _two_level_retrieve
    return strategy(query, tree, domain_key, max_sections, profession)
def _profession_line(profession: str | None) -> str:
"""Standard soft-context line for retrieval prompts."""
if not profession:
return ""
return (
f"\nThe user has stated their profession: **{profession}**. "
f"When the source material includes `binds:` scope tags, prefer "
f"sections that bind {profession} or apply across all professions. "
f"Sections binding only OTHER professions are less relevant unless "
f"the question is explicitly comparative.\n"
)
def _single_pass_retrieve(query, tree, domain_key, max_sections, profession=None):
    """For small trees, pick sections with a single LLM call over the full outline.

    The model is shown the complete tree overview and asked for up to
    ``max_sections`` node_ids as a JSON array; those ids are then resolved to
    result dicts.

    Returns:
        ``(results, usage)``. On any failure the error is printed and
        ``results`` is ``[]``; ``usage`` still reflects the completion call if
        it succeeded before the failure (e.g. an unparseable reply), so spent
        tokens are not under-reported.
    """
    overview = get_tree_overview(tree)
    prompt = f"""You retrieve sections from NZ healthcare marketing compliance documents to answer a user's question.
Document: {DOCUMENT_REGISTRY[domain_key]['description']}
{_profession_line(profession)}
Structure:
{overview}
Question: {query}
Pick up to {max_sections} node_ids that best answer this question. Prefer specific regulatory text over general introductions. Choose complementary sections rather than overlapping ones.
Return ONLY a JSON array. Example: ["0003", "0007"]"""
    usage = _empty_usage()
    try:
        response = litellm.completion(
            model=MODEL,
            messages=[_NO_THINK, {"role": "user", "content": prompt}],
            temperature=0,
            max_tokens=500,
        )
        # Record usage before parsing so a bad reply does not lose it.
        usage = _extract_usage(response)
        node_ids = _parse_json_array((response.choices[0].message.content or "").strip())
        if not isinstance(node_ids, list):
            # A bare JSON string would otherwise be sliced into characters.
            raise TypeError(
                f"expected a JSON array of node_ids, got {type(node_ids).__name__}"
            )
        return _nodes_to_results(tree, domain_key, node_ids, max_sections), usage
    except Exception as e:
        # Best-effort retrieval: report the problem and return no sections.
        print(f"Retrieval error ({domain_key}): {e}")
        return [], usage
def _two_level_retrieve(query, tree, domain_key, max_sections, profession=None):
    """Two-level traversal: pick branches, then pick nodes inside them.

    Level 1 shows the model only the top-level sections (the children of the
    root when the tree has a single root wrapper) and asks for 1-2 branch ids.
    Level 2 runs once per selected branch (at most two): a childless branch is
    returned directly; otherwise the model sees the full branch outline and
    picks up to ``max_sections`` node_ids. If a level-2 call fails, the branch
    node itself is used as the result for that branch.

    Returns:
        ``(results, usage)`` where ``results`` is capped at ``max_sections``
        and ``usage`` sums every completion call that succeeded. A level-1
        failure yields ``[]`` plus whatever usage was already recorded.
    """
    structure = tree.get("structure", [])
    all_usage = []

    # Get the top-level children (skip the root wrapper if it exists)
    top_nodes = structure
    if len(structure) == 1 and structure[0].get("nodes"):
        top_nodes = structure[0].get("nodes", [])

    # Level 1: Which top-level branches are relevant?
    branch_overview = format_nodes_shallow(top_nodes)
    prompt = f"""You retrieve sections from NZ healthcare marketing compliance documents to answer a user's question.
Document: {DOCUMENT_REGISTRY[domain_key]['description']}
{_profession_line(profession)}
Top-level sections:
{branch_overview}
Question: {query}
Which 1-2 top-level sections are most likely to contain the answer? Think about what the user needs to know, not just keyword matches.
Return ONLY a JSON array of node_ids. Example: ["0002", "0018"]"""
    try:
        response = litellm.completion(
            model=MODEL,
            messages=[_NO_THINK, {"role": "user", "content": prompt}],
            temperature=0,
            max_tokens=500,
        )
        all_usage.append(_extract_usage(response))
        branch_ids = _parse_json_array((response.choices[0].message.content or "").strip())
        if not isinstance(branch_ids, list):
            # Checked inside the try: the slice below runs outside it, so a
            # non-list reply would otherwise escape as an unhandled error.
            raise TypeError(
                f"expected a JSON array of node_ids, got {type(branch_ids).__name__}"
            )
    except Exception as e:
        print(f"Level 1 retrieval error ({domain_key}): {e}")
        # Keep the usage of a call that completed but returned a bad reply.
        return [], _sum_usage(*all_usage) if all_usage else _empty_usage()

    # Level 2: Within selected branches, find specific nodes
    all_results = []
    for bid in branch_ids[:2]:
        branch_node = find_node_by_id(tree, bid)
        if not branch_node:
            continue

        if not branch_node.get("nodes", []):
            # Branch is a leaf — return it directly
            all_results.extend(_nodes_to_results(tree, domain_key, [bid], 1))
            continue

        # Show the branch's children to the LLM
        branch_detail = format_branch(branch_node)
        prompt2 = f"""You retrieve sections from NZ healthcare marketing compliance documents to answer a user's question.
Section: {branch_node.get('title', '')}
{_profession_line(profession)}
Sub-sections:
{branch_detail}
Question: {query}
Pick up to {max_sections} specific sub-sections that answer this question. Prefer nodes with regulatory text over general introductions. Choose complementary sections rather than overlapping ones.
Return ONLY a JSON array of node_ids. Example: ["0021", "0022"]"""
        try:
            response = litellm.completion(
                model=MODEL,
                messages=[_NO_THINK, {"role": "user", "content": prompt2}],
                temperature=0,
                max_tokens=500,
            )
            all_usage.append(_extract_usage(response))
            leaf_ids = _parse_json_array((response.choices[0].message.content or "").strip())
            if not isinstance(leaf_ids, list):
                raise TypeError(
                    f"expected a JSON array of node_ids, got {type(leaf_ids).__name__}"
                )
            all_results.extend(
                _nodes_to_results(tree, domain_key, leaf_ids, max_sections)
            )
        except Exception as e:
            print(f"Level 2 retrieval error ({domain_key}/{bid}): {e}")
            # Fallback: return the branch itself
            all_results.extend(_nodes_to_results(tree, domain_key, [bid], 1))

    return all_results[:max_sections], _sum_usage(*all_usage) if all_usage else _empty_usage()
def _extract_source_url(text: str) -> str | None:
"""Extract a Source: URL from the beginning of node text."""
for line in text.split("\n")[:5]:
line = line.strip()
if line.startswith("Source: http"):
return line[len("Source: "):].strip()
return None
def _domain_fallback_url(domain_key: str) -> str | None:
    """Look up the registry-level ``source_url`` for ``domain_key``.

    Returns ``None`` when the domain is not in ``DOCUMENT_REGISTRY`` or its
    entry has no ``source_url``.
    """
    return DOCUMENT_REGISTRY.get(domain_key, {}).get("source_url")
def _nodes_to_results(tree, domain_key, node_ids, max_sections):
    """Resolve node_ids into result dicts, skipping ids not found in the tree.

    Only the first ``max_sections`` ids are considered. Each result carries the
    domain, node_id, title, text, line_num and a source URL taken from the
    node's ``source_url``, else a ``Source:`` line in its text, else the
    domain's fallback URL.
    """
    hits = []
    for wanted_id in node_ids[:max_sections]:
        match = find_node_by_id(tree, wanted_id)
        if not match:
            continue
        body = match.get("text", "")
        link = (
            match.get("source_url")
            or _extract_source_url(body)
            or _domain_fallback_url(domain_key)
        )
        hits.append(
            {
                "domain": domain_key,
                "node_id": wanted_id,
                "title": match.get("title", ""),
                "text": body,
                "line_num": match.get("line_num"),
                "source_url": link,
            }
        )
    return hits
def retrieve(
    query: str,
    domain_keys: list[str],
    max_sections_per_domain: int = 3,
    profession: str | None = None,
) -> tuple[list[dict], dict]:
    """Run per-domain retrieval for each key and merge the outcomes.

    Domains are queried in the order given; their results are concatenated and
    their usage records summed. ``profession`` (if set) is passed through to
    the per-domain retriever as soft context. With no domains the result is an
    empty list and an empty usage record.
    """
    sections: list[dict] = []
    usages = []
    for key in domain_keys:
        found, spent = retrieve_from_domain(
            query, key, max_sections_per_domain, profession=profession
        )
        sections += found
        usages.append(spent)

    if usages:
        total_usage = _sum_usage(*usages)
    else:
        total_usage = _empty_usage()
    return sections, total_usage