hmc-rag / scripts /build_indexes.py
webmuppet
Initial commit — health marketing compliance RAG
bad8b6c
"""
Build PageIndex trees for each domain compilation in the corpus.
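Uses a local MLX server (Qwen 2.5 14B by default) for summary generation — zero API cost. Set ECE_MODEL to another preset (gemma, sonnet, opus) or a litellm model id to switch; remote API models do incur cost.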
Shows real-time progress for each LLM call.
Usage:
uv run python scripts/build_indexes.py # build all
uv run python scripts/build_indexes.py advertising-standards # build one domain
uv run python scripts/build_indexes.py --list # show available domains
uv run python scripts/build_indexes.py --dry-run # show node counts and estimated LLM calls without building
"""
import sys
import os
import json
import time
import argparse
from dotenv import load_dotenv
# Load .env from project root
load_dotenv(os.path.join(os.path.dirname(__file__), "..", ".env"), override=True)
# Add PageIndex repo to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "PageIndex"))
import litellm
from pageindex.page_index_md import (
extract_nodes_from_markdown,
extract_node_text_content,
build_tree_from_nodes,
)
from pageindex.utils import (
structure_to_list,
write_node_id,
format_structure,
count_tokens,
create_clean_structure_for_description,
)
# Model selection — supports local MLX and remote API models.
# ECE_MODEL takes priority (presets: qwen, gemma, sonnet, opus); ECE_MLX_MODEL
# is still read as a fallback for backwards compatibility. A value that is not
# a preset name is used verbatim as a litellm model id.
_MODEL_PRESETS = {
    "qwen": "openai/mlx-community/Qwen2.5-14B-Instruct-4bit",
    "gemma": "openai/mlx-community/gemma-4-26b-a4b-it-4bit",
    "sonnet": "anthropic/claude-sonnet-4-6",
    "opus": "anthropic/claude-opus-4-6",
}
_model_choice = next(
    (value for value in (os.environ.get("ECE_MODEL"), os.environ.get("ECE_MLX_MODEL")) if value),
    "qwen",
)
INDEX_MODEL = _MODEL_PRESETS[_model_choice] if _model_choice in _MODEL_PRESETS else _model_choice
LLM_TIMEOUT = 120  # per-request timeout handed to litellm

# "openai/" models are routed to the local mlx-lm server's OpenAI-compatible
# endpoint by pointing litellm's OpenAI settings at it.
# NOTE: every "openai/" model id is treated this way, including real OpenAI ones.
MLX_BASE_URL = "http://localhost:8080/v1"
if INDEX_MODEL.startswith("openai/"):
    os.environ.update(OPENAI_API_KEY="mlx", OPENAI_API_BASE=MLX_BASE_URL)
# Corpus markdown is read from <project root>/corpus and built indexes are
# written to <project root>/indexes (created here if it does not exist yet).
_PROJECT_ROOT = os.path.join(os.path.dirname(__file__), "..")
CORPUS_DIR = os.path.join(_PROJECT_ROOT, "corpus")
INDEX_DIR = os.path.join(_PROJECT_ROOT, "indexes")
os.makedirs(INDEX_DIR, exist_ok=True)

# Buildable domains; each one expects a corpus file named <domain>.md.
DOMAINS = [
    "medicines-and-supplements",
    "advertising-standards",
    "consumer-protection",
    "marketing-comms",
    "practitioner-regulation",
    "professional-codes",
]
def check_mlx_server():
    """Pre-flight check that the local mlx-lm server is reachable.

    Sends ``GET {MLX_BASE_URL}/models`` with a 3-second timeout.

    Returns:
        True when the server answers with HTTP 200.

    Exits the process with status 1 (after printing how to start the
    server) when the request fails or returns any other status.
    """
    import urllib.request
    import urllib.error

    try:
        req = urllib.request.Request(f"{MLX_BASE_URL}/models", method="GET")
        with urllib.request.urlopen(req, timeout=3) as resp:
            if resp.status == 200:
                return True
    except (urllib.error.URLError, OSError):
        pass
    # The "openai/" prefix only tells litellm to use its OpenAI-compatible
    # client; the model id mlx-lm should load is what follows it.
    # (Previously this referenced an undefined MLX_MODEL and raised NameError.)
    served_model = INDEX_MODEL
    if served_model.startswith("openai/"):
        served_model = served_model[len("openai/"):]
    print(f"\n ERROR: mlx-lm server not running at {MLX_BASE_URL}")
    print(" Start it with:")
    print(f" uv run python -m mlx_lm server --model {served_model} --port 8080\n")
    sys.exit(1)
# System prompt shared by the per-node summary and document-description calls.
# The corpus is NZ health marketing compliance material (medicines and
# supplements, advertising standards, consumer protection, practitioner
# regulation), so the rules and examples below target that domain rather than
# the early-childhood-education criteria this prompt was originally written for.
SUMMARY_SYSTEM_PROMPT = """\
You write concise retrieval summaries for a NZ healthcare marketing compliance system.
Your summaries will be compared against user search queries to decide which sections are relevant. \
A good summary front-loads the specific topics, requirements, and terminology that someone would \
search for. A bad summary is generic ("this section covers requirements for...") and could match anything.
Rules:
- 1-3 sentences, under 80 words
- Start with the clause or section reference (if any) and the topic, not "This section covers..."
- Name specific requirements: product types, permitted or prohibited claims, required statements or disclaimers, target audiences, qualifications, time periods
- Use the same terminology the source text uses (its defined terms, product categories and clause numbers) rather than paraphrasing
- Mention any secondary topics the section also covers
- Do not include opinions or interpretation — just what the section contains"""
def llm_call(model, prompt, system_prompt=None, timeout=LLM_TIMEOUT):
    """Send a single chat-completion request through litellm and return its text.

    Args:
        model: litellm model identifier.
        prompt: content of the user message.
        system_prompt: optional system message sent before the user message.
        timeout: request timeout forwarded to ``litellm.completion``.

    Returns:
        The first choice's content with surrounding whitespace stripped, or
        "" when the model returned no content. Exceptions raised by litellm
        propagate to the caller.
    """
    system_part = [{"role": "system", "content": system_prompt}] if system_prompt else []
    messages = system_part + [{"role": "user", "content": prompt}]
    response = litellm.completion(
        model=model,
        messages=messages,
        temperature=0,  # minimise sampling variation between index rebuilds
        timeout=timeout,
        max_tokens=500,
    )
    content = response.choices[0].message.content
    return (content or "").strip()
def generate_summary_sync(node, model):
    """Return a retrieval summary for one tree node.

    Nodes whose text measures under 200 tokens (``count_tokens`` for
    ``model``) are returned verbatim with no LLM call — the text is short
    enough to serve as its own summary. Longer nodes are summarised via
    ``llm_call`` with ``SUMMARY_SYSTEM_PROMPT``.

    Args:
        node: tree node dict; its ``"text"`` and ``"title"`` keys are read.
        model: litellm model identifier.

    Returns:
        The summary string. If the LLM call raises, a marker string starting
        with ``"[summary error: "`` is returned instead, so one failed node
        does not abort the build.
    """
    node_text = node.get("text", "")
    if count_tokens(node_text, model=model) < 200:
        return node_text
    title = node.get("title", "")
    # The corpus is health marketing compliance material; the previous prompt
    # wrongly described it as early-childhood-education licensing criteria.
    prompt = (
        "Write a retrieval summary for this section of NZ health marketing "
        "compliance material.\n\n"
        f"Section title: {title}\n\n"
        f"Full text:\n{node_text}\n\n"
        "Summary:"
    )
    try:
        return llm_call(model, prompt, system_prompt=SUMMARY_SYSTEM_PROMPT)
    except Exception as e:  # best-effort: record the failure and keep going
        return f"[summary error: {e}]"
def generate_doc_description_sync(structure, model):
    """Return a one-sentence description of a whole domain document.

    ``structure`` is reduced with PageIndex's
    ``create_clean_structure_for_description`` and embedded in the prompt.

    Args:
        structure: the document's tree structure.
        model: litellm model identifier.

    Returns:
        The description string. If the LLM call raises, a marker string
        starting with ``"[description error: "`` is returned instead.
    """
    clean_structure = create_clean_structure_for_description(structure)
    # Describe the corpus as health marketing compliance material (the prompt
    # previously referred to Early Childhood Education documents).
    prompt = (
        "You are given the structure of a NZ health marketing compliance document. "
        "Write a one-sentence description that distinguishes it from the other "
        "documents in this corpus. "
        "Be specific about which legislation, codes, or regulatory topics it covers.\n\n"
        f"Document Structure: {clean_structure}\n\n"
        "Description:"
    )
    try:
        return llm_call(model, prompt, system_prompt=SUMMARY_SYSTEM_PROMPT)
    except Exception as e:  # best-effort: record the failure and keep going
        return f"[description error: {e}]"
def build_tree_with_progress(domain):
    """Build a PageIndex tree for one domain, printing progress for each step.

    Reads ``CORPUS_DIR/<domain>.md``, extracts heading nodes, builds the
    tree, generates a summary for every node (an LLM call only for nodes at
    or above the 200-token threshold), then generates a document description.

    Args:
        domain: domain name; its corpus file is ``<domain>.md``.

    Returns:
        A dict with ``doc_name``, ``doc_description``, ``line_count`` and
        ``structure`` keys, or ``None`` if the corpus file does not exist.
    """
    md_path = os.path.join(CORPUS_DIR, f"{domain}.md")
    if not os.path.exists(md_path):
        print(f" SKIP: {md_path} not found")
        return None
    size_kb = os.path.getsize(md_path) / 1024
    with open(md_path, "r", encoding="utf-8") as f:
        content = f.read()
    line_count = content.count("\n") + 1
    print(f" Source: {size_kb:.0f} KB, {line_count} lines")

    # Step 1: Extract nodes (fast, no LLM)
    t0 = time.time()
    node_list, markdown_lines = extract_nodes_from_markdown(content)
    nodes_with_content = extract_node_text_content(node_list, markdown_lines)
    print(f" Extracting nodes... {len(nodes_with_content)} found ({time.time()-t0:.1f}s)")

    # Step 2: Build tree structure (fast, no LLM)
    tree_structure = build_tree_from_nodes(nodes_with_content)
    write_node_id(tree_structure)
    # Keep "text" in the structure: generate_summary_sync reads node["text"].
    # NOTE(review): nothing below removes "text", so it appears to be saved in
    # the index as well — confirm that is intended.
    tree_structure = format_structure(
        tree_structure,
        order=["title", "node_id", "line_num", "summary", "prefix_summary", "text", "nodes"],
    )

    # Step 3: Generate summaries one by one
    all_nodes = structure_to_list(tree_structure)
    total = len(all_nodes)
    # Decide up front which nodes go to the LLM (the same 200-token threshold
    # generate_summary_sync applies) so the count and per-node labels agree.
    needs_llm = [
        count_tokens(n.get("text", ""), model=INDEX_MODEL) >= 200 for n in all_nodes
    ]
    llm_needed = sum(needs_llm)
    print(f" Generating summaries ({total} nodes, {llm_needed} need LLM)...")
    t_sum = time.time()
    skipped = 0
    for i, (node, uses_llm) in enumerate(zip(all_nodes, needs_llm), 1):
        t_node = time.time()
        summary = generate_summary_sync(node, INDEX_MODEL)
        title = node.get("title", "?")
        # Only LLM-backed summaries can fail. Short nodes come back as their own
        # text, which may legitimately begin with "[" (e.g. a markdown link), so
        # match the full error marker rather than a bare "[".
        if uses_llm and summary.startswith("[summary error:"):
            skipped += 1
            print(f" {i}/{total} SKIP {title[:40]} {summary}")
        else:
            label = f"({time.time() - t_node:.1f}s)" if uses_llm else "(skip)"
            print(f" {i}/{total} {label} {title[:50]}")
        # NOTE(review): failed summaries are stored as their error marker;
        # re-run the domain to retry them.
        if not node.get("nodes"):
            node["summary"] = summary
        else:
            node["prefix_summary"] = summary
    sum_elapsed = time.time() - t_sum
    print(f" Summaries done: {total - skipped}/{total} in {sum_elapsed:.0f}s")

    # Step 4: Generate doc description (1 LLM call)
    print(f" Generating doc description... ", end="", flush=True)
    t_desc = time.time()
    doc_description = generate_doc_description_sync(tree_structure, INDEX_MODEL)
    print(f"done ({time.time()-t_desc:.1f}s)")

    return {
        "doc_name": domain,
        "doc_description": doc_description,
        "line_count": line_count,
        "structure": tree_structure,
    }
def save_tree(domain, tree):
    """Write ``tree`` as JSON to ``INDEX_DIR/<domain>.json`` and report its size.

    The JSON is first written to a ``.tmp`` sibling file and then moved into
    place with ``os.replace``. An interrupted or failed write therefore never
    replaces a previous good index with a truncated file.

    Args:
        domain: domain name used for the output filename.
        tree: JSON-serialisable tree dict.

    Returns:
        ``(output_path, size_kb)`` for the written file.
    """
    output_path = os.path.join(INDEX_DIR, f"{domain}.json")
    tmp_path = f"{output_path}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(tree, f, indent=2, ensure_ascii=False)
        os.replace(tmp_path, output_path)
    except BaseException:
        # Don't leave a half-written temp file behind; re-raise the original error.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
    size_kb = os.path.getsize(output_path) / 1024
    return output_path, size_kb
def dry_run_report(targets):
    """Print per-domain node counts and estimated LLM calls without calling an LLM.

    Each domain's corpus markdown is parsed the same way the build parses it.
    The report counts heading nodes per heading level, and counts the nodes
    whose text reaches the 200-token summary threshold. A domain needs one
    LLM call per such node plus one for the document description.

    Args:
        targets: domain names to report on; missing corpus files are noted
            and skipped.
    """
    from collections import Counter

    print("=" * 60)
    print("Dry run — no API calls will be made")
    print(f"Index model: {INDEX_MODEL}")
    print("=" * 60)
    total_nodes = 0
    total_llm = 0
    for domain in targets:
        md_path = os.path.join(CORPUS_DIR, f"{domain}.md")
        if not os.path.exists(md_path):
            print(f"\n {domain}: corpus not found")
            continue
        size_kb = os.path.getsize(md_path) / 1024
        with open(md_path, "r", encoding="utf-8") as f:
            content = f.read()
        line_count = content.count("\n") + 1

        # Extract nodes (no LLM)
        node_list, markdown_lines = extract_nodes_from_markdown(content)
        nodes_with_content = extract_node_text_content(node_list, markdown_lines)

        # Count by heading level.
        # NOTE(review): PageIndex's extracted nodes appear to carry the heading
        # depth under "level" (verify against pageindex.page_index_md); reading
        # only "heading_level" reported every node as H0. "heading_level" is
        # kept as a fallback, and a node with neither key counts as level 0.
        level_counts = Counter(
            n.get("level", n.get("heading_level", 0)) for n in nodes_with_content
        )

        # Count how many need LLM (nodes at or above the 200-token threshold)
        llm_needed = sum(
            1 for n in nodes_with_content
            if count_tokens(n.get("text", ""), model=INDEX_MODEL) >= 200
        )
        api_calls = llm_needed + 1  # +1 for doc description
        total_nodes += len(nodes_with_content)
        total_llm += api_calls

        # Check existing index
        idx_path = os.path.join(INDEX_DIR, f"{domain}.json")
        idx_status = ""
        if os.path.exists(idx_path):
            idx_size = os.path.getsize(idx_path) / 1024
            idx_status = f" (existing index: {idx_size:.0f} KB)"

        print(f"\n {domain}")
        print(f" Corpus: {size_kb:.0f} KB, {line_count} lines")
        levels_str = ", ".join(f"H{k}={v}" for k, v in sorted(level_counts.items()))
        print(f" Nodes: {len(nodes_with_content)} ({levels_str})")
        print(f" LLM calls: {llm_needed} summaries + 1 description = {api_calls}{idx_status}")

    print(f"\n{'─' * 60}")
    print(f" Total: {total_nodes} nodes, {total_llm} API calls")
    print(f"{'=' * 60}")
def main():
    """Command-line entry point: list domains, dry-run, or build domain indexes.

    With no positional arguments every domain in ``DOMAINS`` is processed.
    An unknown domain name prints the valid choices and exits with status 1.
    """
    parser = argparse.ArgumentParser(description="Build PageIndex trees")
    parser.add_argument("domains", nargs="*", help="Specific domains to build (default: all)")
    parser.add_argument("--list", action="store_true", help="List available domains")
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Show node counts and estimated API calls without building",
    )
    args = parser.parse_args()

    if args.list:
        print("Available domains:")
        for name in DOMAINS:
            corpus_path = os.path.join(CORPUS_DIR, f"{name}.md")
            size = (
                f"{os.path.getsize(corpus_path)/1024:.0f} KB"
                if os.path.exists(corpus_path)
                else "not found"
            )
            print(f" {name:25s} {size}")
        return

    # Validate domain names before doing any work.
    targets = args.domains or DOMAINS
    unknown = next((name for name in targets if name not in DOMAINS), None)
    if unknown is not None:
        print(f"Unknown domain: {unknown}")
        print(f"Available: {', '.join(DOMAINS)}")
        sys.exit(1)

    if args.dry_run:
        dry_run_report(targets)
        return

    banner = "=" * 60
    print(banner)
    print("Building PageIndex trees")
    print(f"Index model: {INDEX_MODEL}")
    print(f"Domains: {', '.join(targets)}")
    print(banner)

    # "openai/" models are served by the local mlx-lm server, so confirm it is
    # reachable before starting; remote API models need no pre-flight check.
    if INDEX_MODEL.startswith("openai/"):
        print("\nChecking mlx-lm server... ", end="", flush=True)
        check_mlx_server()
        print("OK")
    else:
        print(f"\nUsing API model: {INDEX_MODEL}")

    run_start = time.time()
    count = len(targets)
    for position, domain in enumerate(targets, start=1):
        print(f"\n[{position}/{count}] {domain}")
        print(f" {'─' * 40}")
        domain_start = time.time()
        tree = build_tree_with_progress(domain)
        if tree is None:
            continue
        output_path, size_kb = save_tree(domain, tree)
        domain_elapsed = time.time() - domain_start
        print(f" Saved: {output_path} ({size_kb:.0f} KB)")
        print(f" Total: {domain_elapsed:.0f}s")

    print(f"\n{banner}")
    print(f"Done in {time.time()-run_start:.0f}s")
    print(banner)


if __name__ == "__main__":
    main()