Spaces:
Sleeping
Sleeping
| """ | |
| Retriever: searches PageIndex trees for relevant sections. | |
| Uses two-level LLM-based tree traversal — first pick the right branch, | |
| then pick the right leaf nodes within it. | |
| """ | |
| import json | |
| import litellm | |
| from src.config import MODEL, DOCUMENT_REGISTRY | |
| from src.usage import _extract_usage, _empty_usage, _sum_usage | |
| # System prompt to suppress thinking on structured output calls (Gemma 4 etc.) | |
| _NO_THINK = {"role": "system", "content": "Do not use thinking. Respond directly with the JSON only."} | |
def load_tree(domain_key: str) -> dict:
    """Read and parse the PageIndex tree JSON registered for ``domain_key``.

    The file path is taken from ``DOCUMENT_REGISTRY[domain_key]["index"]``;
    an unknown ``domain_key`` raises ``KeyError``.
    """
    registry_entry = DOCUMENT_REGISTRY[domain_key]
    with open(registry_entry["index"], "r", encoding="utf-8") as handle:
        tree = json.load(handle)
    return tree
def format_nodes_shallow(nodes: list, depth: int = 0) -> str:
    """Render one line per node in ``nodes`` without descending into children.

    Each line has the form ``[node_id] title — summary (N sub-sections)``;
    the summary and the sub-section count are omitted when absent/zero.
    """

    def _render(node: dict) -> str:
        pad = " " * depth
        blurb = node.get("summary") or node.get("prefix_summary", "")
        # Summaries are clipped to 80 characters to keep the prompt short.
        blurb_part = f" — {blurb[:80]}" if blurb else ""
        n_children = len(node.get("nodes", []))
        children_part = f" ({n_children} sub-sections)" if n_children > 0 else ""
        head = f"{pad}[{node.get('node_id', '?')}] "
        return head + f"{node.get('title', '')}{blurb_part}{children_part}"

    return "\n".join(_render(node) for node in nodes)
def format_branch(node: dict) -> str:
    """Render ``node`` and every descendant, one line each, in pre-order.

    Lines look like ``[node_id] title — summary`` and are indented by one
    indent unit per level below ``node``.
    """
    rendered = []
    # Explicit stack instead of recursion; children are pushed in reverse so
    # they are popped (and therefore emitted) in their original order.
    pending = [(node, 0)]
    while pending:
        current, level = pending.pop()
        pad = " " * level
        blurb = current.get("summary") or current.get("prefix_summary", "")
        blurb_part = f" — {blurb[:80]}" if blurb else ""
        rendered.append(
            f"{pad}[{current.get('node_id', '?')}] {current.get('title', '')}{blurb_part}"
        )
        kids = current.get("nodes", [])
        pending.extend((kid, level + 1) for kid in reversed(kids))
    return "\n".join(rendered)
def get_tree_overview(tree: dict) -> str:
    """Render the whole tree under ``tree["structure"]`` as indented text.

    Used by test_pipeline for inspection. Every node becomes one line of the
    form ``[node_id] title — summary`` (summary clipped to 100 characters),
    emitted in pre-order.
    """

    def _walk(node, level=0):
        pad = " " * level
        blurb = node.get("summary") or node.get("prefix_summary", "")
        blurb_part = f" — {blurb[:100]}" if blurb else ""
        yield f"{pad}[{node.get('node_id', '?')}] {node.get('title', '')}{blurb_part}"
        for kid in node.get("nodes", []):
            yield from _walk(kid, level + 1)

    roots = tree.get("structure", [])
    return "\n".join(line for root in roots for line in _walk(root))
| def find_node_by_id(tree: dict, node_id: str) -> dict | None: | |
| """Find a node in the tree by its node_id.""" | |
| def search(nodes): | |
| for node in nodes: | |
| if node.get("node_id") == node_id: | |
| return node | |
| found = search(node.get("nodes", [])) | |
| if found: | |
| return found | |
| return None | |
| return search(tree.get("structure", [])) | |
| def _parse_json_array(content: str) -> list: | |
| """Parse a JSON array from LLM output, handling markdown wrapping.""" | |
| if "```" in content: | |
| content = content.split("```")[1] | |
| if content.startswith("json"): | |
| content = content[4:] | |
| content = content.strip() | |
| return json.loads(content) | |
| def _collect_leaf_nodes(node: dict) -> list[dict]: | |
| """Get all leaf nodes (no children) under a given node.""" | |
| children = node.get("nodes", []) | |
| if not children: | |
| return [node] | |
| leaves = [] | |
| for child in children: | |
| leaves.extend(_collect_leaf_nodes(child)) | |
| return leaves | |
def retrieve_from_domain(
    query: str,
    domain_key: str,
    max_sections: int = 3,
    profession: str | None = None,
) -> tuple[list[dict], dict]:
    """Retrieve relevant sections from one domain's PageIndex tree.

    Strategy is chosen by tree size:
      * trees with at most 20 nodes in total are handled with a single LLM
        pass over the full structure;
      * larger trees use two-level traversal — Level 1 picks the relevant
        top-level branches (small prompt), Level 2 picks specific nodes
        within each selected branch.

    Profession (if set) is injected into the LLM prompts as soft context so
    the retriever favours sections binding the user's profession.

    Args:
        query: The user's question.
        domain_key: Key into ``DOCUMENT_REGISTRY`` identifying the document.
        max_sections: Upper bound on the number of sections returned.
        profession: Optional profession used as soft context in the prompts.

    Returns:
        A ``(results, usage)`` tuple: the list of result dicts and the usage
        record produced by the chosen retrieval strategy.
    """
    tree = load_tree(domain_key)

    def count_nodes(nodes):
        # Each node counts once, plus everything beneath it.
        return sum(1 + count_nodes(n.get("nodes", [])) for n in nodes)

    total_nodes = count_nodes(tree.get("structure", []))

    # Small trees: the whole structure fits comfortably in one prompt.
    if total_nodes <= 20:
        return _single_pass_retrieve(query, tree, domain_key, max_sections, profession)

    # Large trees: narrow down by branch first, then by node.
    return _two_level_retrieve(query, tree, domain_key, max_sections, profession)
| def _profession_line(profession: str | None) -> str: | |
| """Standard soft-context line for retrieval prompts.""" | |
| if not profession: | |
| return "" | |
| return ( | |
| f"\nThe user has stated their profession: **{profession}**. " | |
| f"When the source material includes `binds:` scope tags, prefer " | |
| f"sections that bind {profession} or apply across all professions. " | |
| f"Sections binding only OTHER professions are less relevant unless " | |
| f"the question is explicitly comparative.\n" | |
| ) | |
def _single_pass_retrieve(query, tree, domain_key, max_sections, profession=None):
    """For small trees, do a single LLM call over the full structure.

    Args:
        query: The user's question.
        tree: The loaded PageIndex tree.
        domain_key: Key into ``DOCUMENT_REGISTRY`` for this document.
        max_sections: Maximum number of node_ids the model is asked for and
            the maximum number of results returned.
        profession: Optional profession injected as soft context.

    Returns:
        ``(results, usage)``. On any failure the error is printed and
        ``results`` is empty; ``usage`` still reflects the LLM call if it
        completed before the failure (e.g. the reply could not be parsed),
        so consumed tokens are not under-reported.
    """
    overview = get_tree_overview(tree)
    prompt = f"""You retrieve sections from NZ healthcare marketing compliance documents to answer a user's question.
Document: {DOCUMENT_REGISTRY[domain_key]['description']}
{_profession_line(profession)}
Structure:
{overview}
Question: {query}
Pick up to {max_sections} node_ids that best answer this question. Prefer specific regulatory text over general introductions. Choose complementary sections rather than overlapping ones.
Return ONLY a JSON array. Example: ["0003", "0007"]"""

    # Starts empty and is replaced as soon as the call returns, so the except
    # branch reports whatever was actually spent.
    usage = _empty_usage()
    try:
        response = litellm.completion(
            model=MODEL,
            messages=[_NO_THINK, {"role": "user", "content": prompt}],
            temperature=0,
            max_tokens=500,
        )
        usage = _extract_usage(response)
        node_ids = _parse_json_array((response.choices[0].message.content or "").strip())
        return _nodes_to_results(tree, domain_key, node_ids, max_sections), usage
    except Exception as e:
        # Best-effort retrieval: report and return nothing rather than crash.
        print(f"Retrieval error ({domain_key}): {e}")
        return [], usage
def _two_level_retrieve(query, tree, domain_key, max_sections, profession=None):
    """Two-level traversal: pick branches, then pick nodes within them.

    Level 1 shows the model only the top-level sections and asks for the
    1-2 most relevant ones. Level 2 shows each selected branch in full and
    asks for up to ``max_sections`` specific sub-sections.

    Args:
        query: The user's question.
        tree: The loaded PageIndex tree.
        domain_key: Key into ``DOCUMENT_REGISTRY`` for this document.
        max_sections: Cap on sub-sections requested per branch and on the
            total number of results returned.
        profession: Optional profession injected as soft context.

    Returns:
        ``(results, usage)``. ``usage`` aggregates every LLM call that
        completed, including calls whose reply later failed to parse. A
        Level 1 failure yields no results; a Level 2 failure falls back to
        returning the branch node itself.
    """
    structure = tree.get("structure", [])
    all_usage = []

    def _usage_so_far():
        # Sum of every usage record collected up to this point.
        return _sum_usage(*all_usage) if all_usage else _empty_usage()

    # Get the top-level children (skip the root wrapper if it exists)
    top_nodes = structure
    if len(structure) == 1 and structure[0].get("nodes"):
        top_nodes = structure[0].get("nodes", [])

    # Level 1: Which top-level branches are relevant?
    branch_overview = format_nodes_shallow(top_nodes)
    prompt = f"""You retrieve sections from NZ healthcare marketing compliance documents to answer a user's question.
Document: {DOCUMENT_REGISTRY[domain_key]['description']}
{_profession_line(profession)}
Top-level sections:
{branch_overview}
Question: {query}
Which 1-2 top-level sections are most likely to contain the answer? Think about what the user needs to know, not just keyword matches.
Return ONLY a JSON array of node_ids. Example: ["0002", "0018"]"""
    try:
        response = litellm.completion(
            model=MODEL,
            messages=[_NO_THINK, {"role": "user", "content": prompt}],
            temperature=0,
            max_tokens=500,
        )
        all_usage.append(_extract_usage(response))
        branch_ids = _parse_json_array((response.choices[0].message.content or "").strip())
        # Validate here, inside the try: the slice below would otherwise
        # raise outside any handler if the model returned an object/string.
        if not isinstance(branch_ids, list):
            raise ValueError(
                f"expected a JSON array of node_ids, got {type(branch_ids).__name__}"
            )
    except Exception as e:
        print(f"Level 1 retrieval error ({domain_key}): {e}")
        # Report tokens already spent even though the reply was unusable.
        return [], _usage_so_far()

    # Level 2: Within selected branches, find specific nodes
    all_results = []
    for bid in branch_ids[:2]:
        branch_node = find_node_by_id(tree, bid)
        if not branch_node:
            continue

        if not branch_node.get("nodes", []):
            # Branch is a leaf — return it directly
            all_results.extend(_nodes_to_results(tree, domain_key, [bid], 1))
            continue

        # Show the branch's children to the LLM
        branch_detail = format_branch(branch_node)
        prompt2 = f"""You retrieve sections from NZ healthcare marketing compliance documents to answer a user's question.
Section: {branch_node.get('title', '')}
{_profession_line(profession)}
Sub-sections:
{branch_detail}
Question: {query}
Pick up to {max_sections} specific sub-sections that answer this question. Prefer nodes with regulatory text over general introductions. Choose complementary sections rather than overlapping ones.
Return ONLY a JSON array of node_ids. Example: ["0021", "0022"]"""
        try:
            response = litellm.completion(
                model=MODEL,
                messages=[_NO_THINK, {"role": "user", "content": prompt2}],
                temperature=0,
                max_tokens=500,
            )
            all_usage.append(_extract_usage(response))
            leaf_ids = _parse_json_array((response.choices[0].message.content or "").strip())
            all_results.extend(
                _nodes_to_results(tree, domain_key, leaf_ids, max_sections)
            )
        except Exception as e:
            print(f"Level 2 retrieval error ({domain_key}/{bid}): {e}")
            # Fallback: return the branch itself
            all_results.extend(_nodes_to_results(tree, domain_key, [bid], 1))

    return all_results[:max_sections], _usage_so_far()
| def _extract_source_url(text: str) -> str | None: | |
| """Extract a Source: URL from the beginning of node text.""" | |
| for line in text.split("\n")[:5]: | |
| line = line.strip() | |
| if line.startswith("Source: http"): | |
| return line[len("Source: "):].strip() | |
| return None | |
def _domain_fallback_url(domain_key: str) -> str | None:
    """Look up the registry-level ``source_url`` for ``domain_key``.

    Returns ``None`` when the domain is not registered or has no
    ``source_url`` entry.
    """
    return DOCUMENT_REGISTRY.get(domain_key, {}).get("source_url")
def _nodes_to_results(tree, domain_key, node_ids, max_sections):
    """Turn the first ``max_sections`` entries of ``node_ids`` into result dicts.

    Ids that cannot be found in ``tree`` are skipped. The ``source_url`` of a
    result is the node's own ``source_url``, else a ``Source:`` URL found in
    its text, else the domain-level fallback from the registry.
    """
    results = []
    for nid in node_ids[:max_sections]:
        node = find_node_by_id(tree, nid)
        if not node:
            continue
        body = node.get("text", "")
        link = (
            node.get("source_url")
            or _extract_source_url(body)
            or _domain_fallback_url(domain_key)
        )
        results.append(dict(
            domain=domain_key,
            node_id=nid,
            title=node.get("title", ""),
            text=body,
            line_num=node.get("line_num"),
            source_url=link,
        ))
    return results
def retrieve(
    query: str,
    domain_keys: list[str],
    max_sections_per_domain: int = 3,
    profession: str | None = None,
) -> tuple[list[dict], dict]:
    """Run per-domain retrieval for each key and merge the outcomes.

    Results are concatenated in ``domain_keys`` order and the per-domain
    usage records are summed. Profession (if set) is passed through to the
    per-domain retriever as soft context for the LLM-driven traversal.
    """
    sections = []
    usages = []
    for key in domain_keys:
        found, used = retrieve_from_domain(
            query, key, max_sections_per_domain, profession=profession
        )
        sections += found
        usages.append(used)

    if not usages:
        # No domains requested: nothing was spent.
        return sections, _empty_usage()
    return sections, _sum_usage(*usages)