"""
Retriever: searches PageIndex trees for relevant sections.

Uses two-level LLM-based tree traversal — first pick the right branch,
then pick the right leaf nodes within it.
"""

import json

import litellm

from src.config import MODEL, DOCUMENT_REGISTRY
from src.usage import _extract_usage, _empty_usage, _sum_usage

# System prompt to suppress thinking on structured output calls (Gemma 4 etc.)
_NO_THINK = {"role": "system", "content": "Do not use thinking. Respond directly with the JSON only."}

# Trees with at most this many nodes (counted at every depth) are searched with
# a single LLM call; larger trees use the two-level traversal.
_SINGLE_PASS_MAX_NODES = 20


def _children(node: dict) -> list:
    """Return a node's child list, treating a missing or null ``nodes`` as empty."""
    return node.get("nodes") or []


def _summary_of(node: dict) -> str:
    """Return the node's ``summary``, falling back to ``prefix_summary`` (or "")."""
    return node.get("summary") or node.get("prefix_summary") or ""


def _format_node_line(node: dict, depth: int, max_summary: int) -> str:
    """Render one node as ``<indent>[node_id] title — summary``.

    The summary is truncated to ``max_summary`` characters to keep prompts short;
    the `` — summary`` part is omitted when the node has no summary.
    """
    indent = " " * depth
    summary = _summary_of(node)
    summary_str = f" — {summary[:max_summary]}" if summary else ""
    return f"{indent}[{node.get('node_id', '?')}] {node.get('title', '')}{summary_str}"


def _format_subtree(node: dict, depth: int, max_summary: int) -> list[str]:
    """Return pre-order formatted lines for ``node`` and all of its descendants."""
    lines = [_format_node_line(node, depth, max_summary)]
    for child in _children(node):
        lines.extend(_format_subtree(child, depth + 1, max_summary))
    return lines


def load_tree(domain_key: str) -> dict:
    """Load the PageIndex tree for ``domain_key`` from disk.

    The path comes from ``DOCUMENT_REGISTRY[domain_key]["index"]`` and the file
    is parsed as UTF-8 JSON.
    """
    index_path = DOCUMENT_REGISTRY[domain_key]["index"]
    with open(index_path, "r", encoding="utf-8") as f:
        return json.load(f)


def format_nodes_shallow(nodes: list, depth: int = 0) -> str:
    """Format nodes at the current level only (no recursion into children).

    Each line shows the node id, title and an 80-character summary; non-leaf
    nodes are suffixed with the number of direct sub-sections they contain.
    """
    lines = []
    for node in nodes:
        child_count = len(_children(node))
        children_str = f" ({child_count} sub-sections)" if child_count > 0 else ""
        lines.append(_format_node_line(node, depth, 80) + children_str)
    return "\n".join(lines)


def format_branch(node: dict) -> str:
    """Format a single branch — the node and all its descendants (full depth).

    Lines are indented by depth and summaries truncated to 80 characters.
    """
    return "\n".join(_format_subtree(node, 0, 80))


def get_tree_overview(tree: dict) -> str:
    """Full tree overview (used by test_pipeline for inspection).

    Every node at every depth is listed, with summaries truncated to 100
    characters.
    """
    lines = []
    for node in tree.get("structure", []):
        lines.extend(_format_subtree(node, 0, 100))
    return "\n".join(lines)


def find_node_by_id(tree: dict, node_id: str) -> dict | None:
    """Find a node in the tree by its node_id.

    Searches depth-first in pre-order and returns the first match, or None
    when no node carries that id.
    """
    def search(nodes):
        for node in nodes:
            if node.get("node_id") == node_id:
                return node
            found = search(_children(node))
            # Compare against None so a matched-but-empty dict is still returned.
            if found is not None:
                return found
        return None

    return search(tree.get("structure", []))


def _parse_json_array(content: str) -> list:
    """Parse a JSON array from LLM output.

    Handles markdown code fences (with or without a ``json`` tag) and falls back
    to the outermost ``[...]`` span when the array is surrounded by prose.

    Raises:
        ValueError: if no JSON array can be parsed (``json.JSONDecodeError`` is a
            ``ValueError`` subclass), or the parsed value is not a list.
    """
    if "```" in content:
        content = content.split("```")[1]
        if content[:4].lower() == "json":
            content = content[4:]
    content = content.strip()
    try:
        parsed = json.loads(content)
    except json.JSONDecodeError:
        # Models sometimes answer "Here are the ids: [...]"; retry on the bracketed span.
        start, end = content.find("["), content.rfind("]")
        if start == -1 or end <= start:
            raise
        parsed = json.loads(content[start:end + 1])
    if not isinstance(parsed, list):
        raise ValueError(f"Expected a JSON array, got {type(parsed).__name__}")
    return parsed


def _collect_leaf_nodes(node: dict) -> list[dict]:
    """Get all leaf nodes (no children) under a given node, in pre-order."""
    children = _children(node)
    if not children:
        return [node]
    leaves = []
    for child in children:
        leaves.extend(_collect_leaf_nodes(child))
    return leaves


def _count_nodes(nodes: list) -> int:
    """Count ``nodes`` together with all of their descendants."""
    return sum(1 + _count_nodes(_children(n)) for n in nodes)


def _unique(items, key=lambda x: x) -> list:
    """Order-preserving de-duplication of ``items`` by ``key``.

    Uses ``==`` membership rather than a set so unhashable LLM output (e.g. a
    stray dict in the id list) cannot crash it; inputs here are a handful of ids.
    """
    seen = []
    out = []
    for item in items:
        k = key(item)
        if k not in seen:
            seen.append(k)
            out.append(item)
    return out


def _combine_usage(parts: list) -> dict:
    """Sum the collected usage records, or return an empty usage record if none."""
    return _sum_usage(*parts) if parts else _empty_usage()


def retrieve_from_domain(
    query: str,
    domain_key: str,
    max_sections: int = 3,
    profession: str | None = None,
) -> tuple[list[dict], dict]:
    """Retrieve relevant sections using two-level tree traversal.

    Level 1: Pick which top-level branches are relevant (small prompt).
    Level 2: Within each selected branch, pick specific leaf nodes.

    For small trees (at most 20 nodes total), skip to single-pass retrieval.

    Profession (if set) is injected into the LLM prompts as soft context so the
    retriever favours sections binding the user's profession.

    Returns:
        ``(results, usage)``: at most ``max_sections`` result dicts and the
        combined token usage of the LLM calls made.
    """
    tree = load_tree(domain_key)
    if _count_nodes(tree.get("structure", [])) <= _SINGLE_PASS_MAX_NODES:
        # Small trees: single-pass is fine.
        return _single_pass_retrieve(query, tree, domain_key, max_sections, profession)
    # Large trees: two-level traversal.
    return _two_level_retrieve(query, tree, domain_key, max_sections, profession)


def _profession_line(profession: str | None) -> str:
    """Standard soft-context line for retrieval prompts ("" when no profession)."""
    if not profession:
        return ""
    return (
        f"\nThe user has stated their profession: **{profession}**. "
        f"When the source material includes `binds:` scope tags, prefer "
        f"sections that bind {profession} or apply across all professions. "
        f"Sections binding only OTHER professions are less relevant unless "
        f"the question is explicitly comparative.\n"
    )


def _ask_for_node_ids(prompt: str, usage_sink: list) -> list:
    """Send ``prompt`` to the model and parse the reply as a JSON array of node_ids.

    The call's usage is appended to ``usage_sink`` as soon as the completion
    returns, so tokens are still accounted for if parsing the reply fails.

    Raises:
        Whatever ``litellm.completion`` raises, or ``ValueError`` when the reply
        is not a parseable JSON array.
    """
    response = litellm.completion(
        model=MODEL,
        messages=[_NO_THINK, {"role": "user", "content": prompt}],
        temperature=0,
        max_tokens=500,
    )
    usage_sink.append(_extract_usage(response))
    return _parse_json_array((response.choices[0].message.content or "").strip())


def _single_pass_retrieve(query, tree, domain_key, max_sections, profession=None):
    """For small trees, do a single LLM call over the full structure.

    Returns ``(results, usage)``; on any LLM or parsing error the results are
    empty but usage from a completed call is still reported.
    """
    overview = get_tree_overview(tree)
    prompt = f"""You retrieve sections from NZ healthcare marketing compliance documents to answer a user's question.

Document: {DOCUMENT_REGISTRY[domain_key]['description']}
{_profession_line(profession)}
Structure:
{overview}

Question: {query}

Pick up to {max_sections} node_ids that best answer this question.
Prefer specific regulatory text over general introductions.
Choose complementary sections rather than overlapping ones.
Return ONLY a JSON array.
Example: ["0003", "0007"]"""

    usage_parts = []
    try:
        node_ids = _ask_for_node_ids(prompt, usage_parts)
        results = _nodes_to_results(tree, domain_key, node_ids, max_sections)
    except Exception as e:
        # Best-effort boundary: provider/network/parse failures must not abort
        # retrieval across the other domains.
        print(f"Retrieval error ({domain_key}): {e}")
        return [], _combine_usage(usage_parts)
    return results, _combine_usage(usage_parts)


def _two_level_retrieve(query, tree, domain_key, max_sections, profession=None):
    """Two-level traversal: pick branches, then pick leaves.

    Level 1 asks the model for up to two top-level branches. Level 2 asks,
    per branch, for specific sub-sections; a leaf branch (or a branch whose
    level-2 call fails) is returned as a section itself.

    Returns ``(results, usage)`` with at most ``max_sections`` unique results.
    """
    structure = tree.get("structure", [])
    usage_parts = []

    # Get the top-level children (skip the root wrapper if it exists).
    top_nodes = structure
    if len(structure) == 1 and _children(structure[0]):
        top_nodes = _children(structure[0])

    # Level 1: Which top-level branches are relevant?
    branch_overview = format_nodes_shallow(top_nodes)
    prompt = f"""You retrieve sections from NZ healthcare marketing compliance documents to answer a user's question.

Document: {DOCUMENT_REGISTRY[domain_key]['description']}
{_profession_line(profession)}
Top-level sections:
{branch_overview}

Question: {query}

Which 1-2 top-level sections are most likely to contain the answer?
Think about what the user needs to know, not just keyword matches.
Return ONLY a JSON array of node_ids.
Example: ["0002", "0018"]"""

    try:
        branch_ids = _ask_for_node_ids(prompt, usage_parts)
    except Exception as e:
        print(f"Level 1 retrieval error ({domain_key}): {e}")
        return [], _combine_usage(usage_parts)

    # Level 2: Within selected branches, find specific nodes.
    all_results = []
    # At most two branches, matching the "1-2" wording of the level-1 prompt;
    # de-duplicate first so a repeated id doesn't cost a second LLM call.
    for bid in _unique(branch_ids)[:2]:
        branch_node = find_node_by_id(tree, bid)
        if not branch_node:
            continue

        if not _children(branch_node):
            # Branch is a leaf — return it directly.
            all_results.append(_node_to_result(domain_key, bid, branch_node))
            continue

        # Show the branch's children to the LLM.
        branch_detail = format_branch(branch_node)
        prompt2 = f"""You retrieve sections from NZ healthcare marketing compliance documents to answer a user's question.

Section: {branch_node.get('title', '')}
{_profession_line(profession)}
Sub-sections:
{branch_detail}

Question: {query}

Pick up to {max_sections} specific sub-sections that answer this question.
Prefer nodes with regulatory text over general introductions.
Choose complementary sections rather than overlapping ones.
Return ONLY a JSON array of node_ids.
Example: ["0021", "0022"]"""

        try:
            leaf_ids = _ask_for_node_ids(prompt2, usage_parts)
            all_results.extend(_nodes_to_results(tree, domain_key, leaf_ids, max_sections))
        except Exception as e:
            print(f"Level 2 retrieval error ({domain_key}/{bid}): {e}")
            # Fallback: return the branch itself.
            all_results.append(_node_to_result(domain_key, bid, branch_node))

    # Leaf ids are resolved against the whole tree, so two branches can yield
    # the same node; keep only its first occurrence.
    unique_results = _unique(all_results, key=lambda r: r["node_id"])
    return unique_results[:max_sections], _combine_usage(usage_parts)


def _extract_source_url(text: str) -> str | None:
    """Extract a ``Source: http...`` URL from the first five lines of node text."""
    if not text:
        return None
    # maxsplit avoids splitting the whole (possibly long) section body.
    for line in text.split("\n", 5)[:5]:
        line = line.strip()
        if line.startswith("Source: http"):
            return line.removeprefix("Source: ").strip()
    return None


def _domain_fallback_url(domain_key: str) -> str | None:
    """Get the fallback source URL for a domain from the registry."""
    info = DOCUMENT_REGISTRY.get(domain_key, {})
    return info.get("source_url")


def _node_to_result(domain_key: str, node_id, node: dict) -> dict:
    """Build the result dict for one retrieved node.

    ``source_url`` prefers the node's own field, then a ``Source:`` line in its
    text, then the domain's registry fallback.
    """
    text = node.get("text") or ""
    return {
        "domain": domain_key,
        "node_id": node_id,
        "title": node.get("title", ""),
        "text": text,
        "line_num": node.get("line_num"),
        "source_url": (
            node.get("source_url")
            or _extract_source_url(text)
            or _domain_fallback_url(domain_key)
        ),
    }


def _nodes_to_results(tree, domain_key, node_ids, max_sections):
    """Convert LLM-chosen node_ids to result dicts, in the order given.

    Repeated ids and ids that do not resolve to a node are skipped (they no
    longer use up a slot), and at most ``max_sections`` results are returned.
    """
    results = []
    for nid in _unique(node_ids):
        if len(results) >= max_sections:
            break
        node = find_node_by_id(tree, nid)
        if node:
            results.append(_node_to_result(domain_key, nid, node))
    return results


def retrieve(
    query: str,
    domain_keys: list[str],
    max_sections_per_domain: int = 3,
    profession: str | None = None,
) -> tuple[list[dict], dict]:
    """Retrieve relevant sections from multiple domains.

    Profession (if set) is forwarded as soft context to the per-domain
    retriever so the LLM-driven tree traversal favours sections binding the
    user's profession.

    Returns ``(results, usage)``: results from every domain in ``domain_keys``
    order, and the summed token usage.
    """
    all_results = []
    all_usage = []
    for domain_key in domain_keys:
        results, usage = retrieve_from_domain(
            query, domain_key, max_sections_per_domain, profession=profession
        )
        all_results.extend(results)
        all_usage.append(usage)
    return all_results, _combine_usage(all_usage)