"""
Build PageIndex trees for each domain compilation in the corpus.

By default uses a local MLX server (Qwen 2.5 14B) for summary generation —
zero API cost. Set ECE_MODEL (presets: qwen, gemma, sonnet, opus, or a full
litellm model string) to pick a different local or remote model.
Shows real-time progress for each LLM call.

Usage:
    uv run python scripts/build_indexes.py                      # build all
    uv run python scripts/build_indexes.py consumer-protection  # build one domain
    uv run python scripts/build_indexes.py --list               # show available domains
    uv run python scripts/build_indexes.py --dry-run            # show costs without building
"""

import sys
import os
import json
import time
import argparse

from dotenv import load_dotenv

# Load .env from project root
load_dotenv(os.path.join(os.path.dirname(__file__), "..", ".env"), override=True)

# Add PageIndex repo to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "PageIndex"))

import litellm
from pageindex.page_index_md import (
    extract_nodes_from_markdown,
    extract_node_text_content,
    build_tree_from_nodes,
)
from pageindex.utils import (
    structure_to_list,
    write_node_id,
    format_structure,
    count_tokens,
    create_clean_structure_for_description,
)

# Model selection — supports local MLX and remote API models
# ECE_MODEL takes priority (presets: qwen, gemma, sonnet, opus)
# Falls back to ECE_MLX_MODEL for backwards compat
_MODEL_PRESETS = {
    "qwen": "openai/mlx-community/Qwen2.5-14B-Instruct-4bit",
    "gemma": "openai/mlx-community/gemma-4-26b-a4b-it-4bit",
    "sonnet": "anthropic/claude-sonnet-4-6",
    "opus": "anthropic/claude-opus-4-6",
}
_model_choice = os.environ.get("ECE_MODEL") or os.environ.get("ECE_MLX_MODEL") or "qwen"
# An unknown choice is passed through unchanged so a full model string can be used.
INDEX_MODEL = _MODEL_PRESETS.get(_model_choice, _model_choice)
LLM_TIMEOUT = 120

# Nodes with fewer tokens than this use their own text as the summary (no LLM call).
_MIN_TOKENS_FOR_LLM = 200

# Sentinel prefixes returned by the generate_*_sync helpers when the LLM call fails.
_SUMMARY_ERROR_PREFIX = "[summary error:"
_DESCRIPTION_ERROR_PREFIX = "[description error:"

# Only configure MLX server env vars for local models
_LOCAL_MODEL_PREFIX = "openai/"
MLX_BASE_URL = "http://localhost:8080/v1"
if INDEX_MODEL.startswith(_LOCAL_MODEL_PREFIX):
    os.environ["OPENAI_API_KEY"] = "mlx"
    os.environ["OPENAI_API_BASE"] = MLX_BASE_URL

# Model name without the "openai/" routing prefix, used in the
# "start the server" hint printed by check_mlx_server().
MLX_MODEL = (
    INDEX_MODEL[len(_LOCAL_MODEL_PREFIX):]
    if INDEX_MODEL.startswith(_LOCAL_MODEL_PREFIX)
    else INDEX_MODEL
)

CORPUS_DIR = os.path.join(os.path.dirname(__file__), "..", "corpus")
INDEX_DIR = os.path.join(os.path.dirname(__file__), "..", "indexes")
os.makedirs(INDEX_DIR, exist_ok=True)

DOMAINS = [
    "medicines-and-supplements",
    "advertising-standards",
    "consumer-protection",
    "marketing-comms",
    "practitioner-regulation",
    "professional-codes",
]


def check_mlx_server():
    """Pre-flight check: mlx-lm server is running.

    Returns:
        True when ``GET {MLX_BASE_URL}/models`` answers with HTTP 200.

    Exits the process with status 1 (after printing a start-up hint) when the
    request fails or returns any other status.
    """
    import urllib.request
    import urllib.error

    try:
        req = urllib.request.Request(f"{MLX_BASE_URL}/models", method="GET")
        with urllib.request.urlopen(req, timeout=3) as resp:
            if resp.status == 200:
                return True
    except (urllib.error.URLError, OSError):
        # Fall through to the shared error message below.
        pass

    print(f"\n ERROR: mlx-lm server not running at {MLX_BASE_URL}")
    print(" Start it with:")
    print(f" uv run python -m mlx_lm server --model {MLX_MODEL} --port 8080\n")
    sys.exit(1)


SUMMARY_SYSTEM_PROMPT = """\
You write concise retrieval summaries for a NZ healthcare marketing compliance system.

Your summaries will be compared against user search queries to decide which sections are relevant. \
A good summary front-loads the specific topics, requirements, and terminology that someone would \
search for. A bad summary is generic ("this section covers requirements for...") and could match anything.

Rules:
- 1-3 sentences, under 80 words
- Start with the criterion code and topic, not "This section covers..."
- Name specific requirements: age groups, ratios, qualifications, equipment types, time periods
- Use the same terminology the regulation uses (e.g. "under-2s" not "infants", "kaiako" not "teachers")
- Mention any secondary topics the section also covers
- Do not include opinions or interpretation — just what the section contains"""


def llm_call(model, prompt, system_prompt=None, timeout=LLM_TIMEOUT):
    """Single LLM call via litellm.

    Args:
        model: litellm model string.
        prompt: User message content.
        system_prompt: Optional system message; omitted from the request when falsy.
        timeout: Timeout value passed through to ``litellm.completion``.

    Returns:
        The first choice's message content, stripped; ``""`` when the content
        is empty or None.
    """
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": prompt})

    response = litellm.completion(
        model=model,
        messages=messages,
        temperature=0,
        timeout=timeout,
        max_tokens=500,
    )
    return (response.choices[0].message.content or "").strip()


def generate_summary_sync(node, model):
    """Generate a node summary. Short nodes skip the LLM entirely.

    Args:
        node: Tree node dict; its ``"text"`` and ``"title"`` keys are read.
        model: litellm model string (also passed to ``count_tokens``).

    Returns:
        The node text itself when it is under the token threshold, otherwise
        the LLM summary. On any LLM failure a ``"[summary error: ...]"``
        sentinel string is returned instead of raising, so one bad node does
        not abort the whole build.
    """
    node_text = node.get("text", "")
    num_tokens = count_tokens(node_text, model=model)
    if num_tokens < _MIN_TOKENS_FOR_LLM:
        return node_text

    title = node.get("title", "")
    # NOTE(review): this prompt says "NZ ECE licensing criteria" while the
    # system prompt and DOMAINS are about healthcare marketing compliance —
    # looks like a leftover from another project; confirm before changing.
    prompt = (
        f"Write a retrieval summary for this section of NZ ECE licensing criteria.\n\n"
        f"Section title: {title}\n\n"
        f"Full text:\n{node_text}\n\n"
        f"Summary:"
    )
    try:
        return llm_call(model, prompt, system_prompt=SUMMARY_SYSTEM_PROMPT)
    except Exception as e:
        # Deliberate best-effort: report the failure through the return value.
        return f"{_SUMMARY_ERROR_PREFIX} {e}]"


def generate_doc_description_sync(structure, model):
    """Generate doc description.

    Args:
        structure: Tree structure; reduced via
            ``create_clean_structure_for_description`` before prompting.
        model: litellm model string.

    Returns:
        The LLM's description, or a ``"[description error: ...]"`` sentinel
        string when the call fails.
    """
    clean_structure = create_clean_structure_for_description(structure)
    # NOTE(review): prompt wording refers to ECE documents — see the note in
    # generate_summary_sync; confirm the intended domain.
    prompt = (
        "You are given the structure of a NZ Early Childhood Education regulatory document. "
        "Write a one-sentence description that distinguishes it from other ECE documents. "
        "Be specific about what domains, criteria codes, or legislation it covers.\n\n"
        f"Document Structure: {clean_structure}\n\n"
        "Description:"
    )
    try:
        return llm_call(model, prompt, system_prompt=SUMMARY_SYSTEM_PROMPT)
    except Exception as e:
        return f"{_DESCRIPTION_ERROR_PREFIX} {e}]"


def build_tree_with_progress(domain):
    """Build a PageIndex tree with per-step progress output.

    Args:
        domain: Domain name; the corpus file ``{CORPUS_DIR}/{domain}.md`` is read.

    Returns:
        A dict with ``doc_name``, ``doc_description``, ``line_count`` and
        ``structure`` keys, or None when the corpus file does not exist.
    """
    md_path = os.path.join(CORPUS_DIR, f"{domain}.md")
    if not os.path.exists(md_path):
        print(f" SKIP: {md_path} not found")
        return None

    size_kb = os.path.getsize(md_path) / 1024
    with open(md_path, "r", encoding="utf-8") as f:
        content = f.read()
    line_count = content.count("\n") + 1
    print(f" Source: {size_kb:.0f} KB, {line_count} lines")

    # Step 1: Extract nodes (fast, no LLM)
    t0 = time.time()
    node_list, markdown_lines = extract_nodes_from_markdown(content)
    nodes_with_content = extract_node_text_content(node_list, markdown_lines)
    print(f" Extracting nodes... {len(nodes_with_content)} found ({time.time()-t0:.1f}s)")

    # Step 2: Build tree structure (fast, no LLM)
    tree_structure = build_tree_from_nodes(nodes_with_content)
    write_node_id(tree_structure)

    # Format with text included (needed for summary generation)
    tree_structure = format_structure(
        tree_structure,
        order=["title", "node_id", "line_num", "summary", "prefix_summary", "text", "nodes"],
    )

    # Step 3: Generate summaries one by one
    all_nodes = structure_to_list(tree_structure)
    total = len(all_nodes)
    llm_needed = sum(
        1
        for n in all_nodes
        if count_tokens(n.get("text", ""), model=INDEX_MODEL) >= _MIN_TOKENS_FOR_LLM
    )
    print(f" Generating summaries ({total} nodes, {llm_needed} need LLM)...")

    t_sum = time.time()
    skipped = 0
    for i, node in enumerate(all_nodes, 1):
        t_node = time.time()
        summary = generate_summary_sync(node, INDEX_MODEL)

        # Match the exact error sentinel: short nodes return their raw text,
        # which may legitimately begin with "[" (e.g. a markdown link).
        if summary.startswith(_SUMMARY_ERROR_PREFIX):
            skipped += 1
            print(f" {i}/{total} SKIP {node.get('title', '?')[:40]} — {summary}")
        else:
            elapsed_node = time.time() - t_node
            # Near-instant results are the short nodes that bypassed the LLM.
            label = f"({elapsed_node:.1f}s)" if elapsed_node > 0.1 else "(skip)"
            print(f" {i}/{total} {label} {node.get('title', '?')[:50]}")

        # Leaf nodes carry "summary"; nodes with children carry "prefix_summary".
        if not node.get("nodes"):
            node["summary"] = summary
        else:
            node["prefix_summary"] = summary

    sum_elapsed = time.time() - t_sum
    print(f" Summaries done: {total - skipped}/{total} in {sum_elapsed:.0f}s")

    # Step 4: Generate doc description (1 LLM call)
    print(" Generating doc description... ", end="", flush=True)
    t_desc = time.time()
    doc_description = generate_doc_description_sync(tree_structure, INDEX_MODEL)
    print(f"done ({time.time()-t_desc:.1f}s)")

    tree = {
        "doc_name": domain,
        "doc_description": doc_description,
        "line_count": line_count,
        "structure": tree_structure,
    }
    return tree


def save_tree(domain, tree):
    """Save tree to JSON file.

    Args:
        domain: Domain name; output goes to ``{INDEX_DIR}/{domain}.json``.
        tree: JSON-serialisable tree dict.

    Returns:
        Tuple ``(output_path, size_kb)`` for the written file.
    """
    output_path = os.path.join(INDEX_DIR, f"{domain}.json")
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(tree, f, indent=2, ensure_ascii=False)
    size_kb = os.path.getsize(output_path) / 1024
    return output_path, size_kb


def dry_run_report(targets):
    """Show node counts and estimated LLM calls without calling the API.

    Args:
        targets: Iterable of domain names to report on. Domains whose corpus
            file is missing are reported and skipped.
    """
    print("=" * 60)
    print("Dry run — no API calls will be made")
    print(f"Index model: {INDEX_MODEL}")
    print("=" * 60)

    total_nodes = 0
    total_llm = 0

    for domain in targets:
        md_path = os.path.join(CORPUS_DIR, f"{domain}.md")
        if not os.path.exists(md_path):
            print(f"\n {domain}: corpus not found")
            continue

        size_kb = os.path.getsize(md_path) / 1024
        with open(md_path, "r", encoding="utf-8") as f:
            content = f.read()
        line_count = content.count("\n") + 1

        # Extract nodes (no LLM)
        node_list, markdown_lines = extract_nodes_from_markdown(content)
        nodes_with_content = extract_node_text_content(node_list, markdown_lines)

        # Count by heading level
        level_counts = {}
        for n in nodes_with_content:
            level = n.get("heading_level", 0)
            level_counts[level] = level_counts.get(level, 0) + 1

        # Count how many need LLM
        llm_needed = sum(
            1
            for n in nodes_with_content
            if count_tokens(n.get("text", ""), model=INDEX_MODEL) >= _MIN_TOKENS_FOR_LLM
        )
        api_calls = llm_needed + 1  # +1 for doc description

        total_nodes += len(nodes_with_content)
        total_llm += api_calls

        # Check existing index
        idx_path = os.path.join(INDEX_DIR, f"{domain}.json")
        idx_status = ""
        if os.path.exists(idx_path):
            idx_size = os.path.getsize(idx_path) / 1024
            idx_status = f" (existing index: {idx_size:.0f} KB)"

        print(f"\n {domain}")
        print(f" Corpus: {size_kb:.0f} KB, {line_count} lines")
        levels_str = ", ".join(f"H{k}={v}" for k, v in sorted(level_counts.items()))
        print(f" Nodes: {len(nodes_with_content)} ({levels_str})")
        print(f" LLM calls: {llm_needed} summaries + 1 description = {api_calls}{idx_status}")

    print(f"\n{'─' * 60}")
    print(f" Total: {total_nodes} nodes, {total_llm} API calls")
    print(f"{'=' * 60}")


def main():
    """Parse CLI arguments, then list domains, dry-run, or build and save indexes."""
    parser = argparse.ArgumentParser(description="Build PageIndex trees")
    parser.add_argument("domains", nargs="*", help="Specific domains to build (default: all)")
    parser.add_argument("--list", action="store_true", help="List available domains")
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Show node counts and estimated API calls without building",
    )
    args = parser.parse_args()

    if args.list:
        print("Available domains:")
        for d in DOMAINS:
            path = os.path.join(CORPUS_DIR, f"{d}.md")
            exists = os.path.exists(path)
            size = f"{os.path.getsize(path)/1024:.0f} KB" if exists else "not found"
            print(f" {d:25s} {size}")
        return

    # Validate domain names
    targets = args.domains if args.domains else DOMAINS
    for d in targets:
        if d not in DOMAINS:
            print(f"Unknown domain: {d}")
            print(f"Available: {', '.join(DOMAINS)}")
            sys.exit(1)

    if args.dry_run:
        dry_run_report(targets)
        return

    print("=" * 60)
    print("Building PageIndex trees")
    print(f"Index model: {INDEX_MODEL}")
    print(f"Domains: {', '.join(targets)}")
    print("=" * 60)

    # Pre-flight check — only needed for local MLX models
    if INDEX_MODEL.startswith(_LOCAL_MODEL_PREFIX):
        print("\nChecking mlx-lm server... ", end="", flush=True)
        check_mlx_server()
        print("OK")
    else:
        print(f"\nUsing API model: {INDEX_MODEL}")

    total_time = time.time()

    for i, domain in enumerate(targets, 1):
        print(f"\n[{i}/{len(targets)}] {domain}")
        print(f" {'─' * 40}")

        start_time = time.time()
        tree = build_tree_with_progress(domain)
        if tree is None:
            continue

        # Save
        output_path, size_kb = save_tree(domain, tree)
        elapsed = time.time() - start_time
        print(f" Saved: {output_path} ({size_kb:.0f} KB)")
        print(f" Total: {elapsed:.0f}s")

    print(f"\n{'=' * 60}")
    print(f"Done in {time.time()-total_time:.0f}s")
    print(f"{'=' * 60}")


if __name__ == "__main__":
    main()