Spaces:
Sleeping
Sleeping
| """ | |
| Build PageIndex trees for each domain compilation in the corpus. | |
| Uses a local MLX server (Qwen 2.5 14B) by default for summary generation — zero API cost. Set ECE_MODEL to choose another preset or a remote API model. | |
| Shows real-time progress for each LLM call. | |
| Usage: | |
| uv run python scripts/build_indexes.py # build all | |
| uv run python scripts/build_indexes.py advertising-standards # build one domain | |
| uv run python scripts/build_indexes.py --list # show available domains | |
| uv run python scripts/build_indexes.py --dry-run # show costs without building | |
| """ | |
| import sys | |
| import os | |
| import json | |
| import time | |
| import argparse | |
| from dotenv import load_dotenv | |
| # Load .env from project root | |
| load_dotenv(os.path.join(os.path.dirname(__file__), "..", ".env"), override=True) | |
| # Add PageIndex repo to path | |
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "PageIndex")) | |
| import litellm | |
| from pageindex.page_index_md import ( | |
| extract_nodes_from_markdown, | |
| extract_node_text_content, | |
| build_tree_from_nodes, | |
| ) | |
| from pageindex.utils import ( | |
| structure_to_list, | |
| write_node_id, | |
| format_structure, | |
| count_tokens, | |
| create_clean_structure_for_description, | |
| ) | |
# Model selection: works with local MLX models and remote API models.
# ECE_MODEL wins when set (presets: qwen, gemma, sonnet, opus); the older
# ECE_MLX_MODEL variable is still honoured for backwards compatibility.
# A value that is not a preset name is used verbatim as the model id.
_MODEL_PRESETS = dict(
    qwen="openai/mlx-community/Qwen2.5-14B-Instruct-4bit",
    gemma="openai/mlx-community/gemma-4-26b-a4b-it-4bit",
    sonnet="anthropic/claude-sonnet-4-6",
    opus="anthropic/claude-opus-4-6",
)
_model_choice = (
    os.environ.get("ECE_MODEL")
    or os.environ.get("ECE_MLX_MODEL")
    or "qwen"
)
if _model_choice in _MODEL_PRESETS:
    INDEX_MODEL = _MODEL_PRESETS[_model_choice]
else:
    INDEX_MODEL = _model_choice

# Per-call timeout (seconds) used as the default for LLM requests.
LLM_TIMEOUT = 120

# The MLX server settings are only exported for "openai/"-prefixed models.
MLX_BASE_URL = "http://localhost:8080/v1"
if INDEX_MODEL.startswith("openai/"):
    os.environ.update(OPENAI_API_KEY="mlx", OPENAI_API_BASE=MLX_BASE_URL)

# Input corpus and output index directories live next to the scripts folder.
CORPUS_DIR = os.path.join(os.path.dirname(__file__), os.pardir, "corpus")
INDEX_DIR = os.path.join(os.path.dirname(__file__), os.pardir, "indexes")
os.makedirs(INDEX_DIR, exist_ok=True)
# Known domain names. Each one maps to a "<domain>.md" file under CORPUS_DIR
# and a "<domain>.json" index under INDEX_DIR; main() rejects any other name.
DOMAINS = [
    "medicines-and-supplements",
    "advertising-standards",
    "consumer-protection",
    "marketing-comms",
    "practitioner-regulation",
    "professional-codes",
]
def check_mlx_server():
    """Pre-flight check that the local mlx-lm server is reachable.

    Sends a GET request to ``{MLX_BASE_URL}/models`` with a 3-second timeout.

    Returns:
        True when the server answers with HTTP 200.

    Exits:
        Prints start-up instructions and calls ``sys.exit(1)`` when the
        request fails or the server answers with any other status.
    """
    import urllib.request
    import urllib.error

    try:
        req = urllib.request.Request(f"{MLX_BASE_URL}/models", method="GET")
        with urllib.request.urlopen(req, timeout=3) as resp:
            if resp.status == 200:
                return True
    except (urllib.error.URLError, OSError):
        # Connection problems are expected here; fall through to the help text.
        pass

    # The start-up hint previously referenced an undefined MLX_MODEL name and
    # crashed with NameError. Derive the model id from INDEX_MODEL instead,
    # dropping the "openai/" prefix the presets carry.
    # NOTE(review): presumably mlx_lm expects the bare model id; confirm.
    prefix = "openai/"
    if INDEX_MODEL.startswith(prefix):
        mlx_model = INDEX_MODEL[len(prefix):]
    else:
        mlx_model = INDEX_MODEL

    print(f"\n ERROR: mlx-lm server not running at {MLX_BASE_URL}")
    print(" Start it with:")
    print(f" uv run python -m mlx_lm server --model {mlx_model} --port 8080\n")
    sys.exit(1)
# System prompt sent with both the per-node summary calls and the
# doc-description call in this script.
# NOTE(review): the opening line targets healthcare marketing compliance, but
# the examples in the rules ("under-2s", "kaiako", ratios, equipment types)
# read like early-childhood-education wording — confirm they suit the corpus.
SUMMARY_SYSTEM_PROMPT = """\
You write concise retrieval summaries for a NZ healthcare marketing compliance system.
Your summaries will be compared against user search queries to decide which sections are relevant. \
A good summary front-loads the specific topics, requirements, and terminology that someone would \
search for. A bad summary is generic ("this section covers requirements for...") and could match anything.
Rules:
- 1-3 sentences, under 80 words
- Start with the criterion code and topic, not "This section covers..."
- Name specific requirements: age groups, ratios, qualifications, equipment types, time periods
- Use the same terminology the regulation uses (e.g. "under-2s" not "infants", "kaiako" not "teachers")
- Mention any secondary topics the section also covers
- Do not include opinions or interpretation — just what the section contains"""
def llm_call(model, prompt, system_prompt=None, timeout=LLM_TIMEOUT):
    """Send one chat completion request through litellm and return its text.

    Args:
        model: Model identifier handed to ``litellm.completion``.
        prompt: Content of the user message.
        system_prompt: Optional system message; omitted when falsy.
        timeout: Request timeout passed through to litellm.

    Returns:
        The first choice's message content, stripped of surrounding
        whitespace, or an empty string when the content is empty/None.
    """
    system_part = (
        [{"role": "system", "content": system_prompt}] if system_prompt else []
    )
    conversation = system_part + [{"role": "user", "content": prompt}]

    reply = litellm.completion(
        model=model,
        messages=conversation,
        temperature=0,
        timeout=timeout,
        max_tokens=500,
    )

    text = reply.choices[0].message.content
    if not text:
        text = ""
    return text.strip()
def generate_summary_sync(node, model):
    """Generate a retrieval summary for one tree node.

    Nodes whose text is under 200 tokens skip the LLM and use their own text
    as the summary.

    Args:
        node: Node dict; its ``"text"`` and ``"title"`` entries are read.
        model: Model identifier used for token counting and the LLM call.

    Returns:
        The node text (short nodes), the LLM summary, or a string of the form
        ``"[summary error: ...]"`` when the LLM call raises.
    """
    # Treat a missing or None text the same as an empty section.
    node_text = node.get("text") or ""
    num_tokens = count_tokens(node_text, model=model)
    if num_tokens < 200:
        return node_text

    title = node.get("title", "")
    # The prompt previously described the text as "NZ ECE licensing criteria",
    # contradicting the system prompt and the healthcare-marketing domains
    # this script indexes.
    prompt = (
        "Write a retrieval summary for this section of NZ healthcare "
        "marketing compliance material.\n\n"
        f"Section title: {title}\n\n"
        f"Full text:\n{node_text}\n\n"
        "Summary:"
    )
    try:
        return llm_call(model, prompt, system_prompt=SUMMARY_SYSTEM_PROMPT)
    except Exception as e:
        # Best-effort: one failed node must not abort the whole build. The
        # caller recognises this marker and reports the node as skipped.
        return f"[summary error: {e}]"
def generate_doc_description_sync(structure, model):
    """Generate a one-sentence description of a whole document.

    Args:
        structure: Tree structure of the document; it is passed through
            ``create_clean_structure_for_description`` before prompting.
        model: Model identifier used for the LLM call.

    Returns:
        The LLM's description, or a string of the form
        ``"[description error: ...]"`` when the LLM call raises.
    """
    clean_structure = create_clean_structure_for_description(structure)
    # The prompt previously called the input an "Early Childhood Education"
    # document, contradicting the system prompt and the healthcare-marketing
    # domains this script indexes.
    prompt = (
        "You are given the structure of a NZ healthcare marketing compliance document. "
        "Write a one-sentence description that distinguishes it from the other "
        "compliance documents in the corpus. "
        "Be specific about what domains, codes, or legislation it covers.\n\n"
        f"Document Structure: {clean_structure}\n\n"
        "Description:"
    )
    try:
        return llm_call(model, prompt, system_prompt=SUMMARY_SYSTEM_PROMPT)
    except Exception as e:
        # Best-effort: return a marker so the build can still be saved.
        return f"[description error: {e}]"
def build_tree_with_progress(domain):
    """Build a PageIndex tree for one domain, printing progress per step.

    Reads ``<CORPUS_DIR>/<domain>.md``, extracts heading nodes, builds the
    tree, generates one summary per node and one document description.

    Args:
        domain: Domain name; selects the markdown file to read.

    Returns:
        A dict with ``doc_name``, ``doc_description``, ``line_count`` and
        ``structure`` keys, or None when the corpus file does not exist.
    """
    md_path = os.path.join(CORPUS_DIR, f"{domain}.md")
    if not os.path.exists(md_path):
        print(f" SKIP: {md_path} not found")
        return None

    size_kb = os.path.getsize(md_path) / 1024
    with open(md_path, "r", encoding="utf-8") as f:
        content = f.read()
    line_count = content.count("\n") + 1
    print(f" Source: {size_kb:.0f} KB, {line_count} lines")

    # Step 1: Extract nodes (fast, no LLM)
    t0 = time.time()
    node_list, markdown_lines = extract_nodes_from_markdown(content)
    nodes_with_content = extract_node_text_content(node_list, markdown_lines)
    print(f" Extracting nodes... {len(nodes_with_content)} found ({time.time()-t0:.1f}s)")

    # Step 2: Build tree structure (fast, no LLM)
    tree_structure = build_tree_from_nodes(nodes_with_content)
    write_node_id(tree_structure)
    # Format with text included (needed for summary generation)
    tree_structure = format_structure(
        tree_structure,
        order=["title", "node_id", "line_num", "summary", "prefix_summary", "text", "nodes"],
    )

    # Step 3: Generate summaries one by one
    all_nodes = structure_to_list(tree_structure)
    total = len(all_nodes)
    llm_needed = sum(
        1 for n in all_nodes
        if count_tokens(n.get("text", ""), model=INDEX_MODEL) >= 200
    )
    print(f" Generating summaries ({total} nodes, {llm_needed} need LLM)...")
    t_sum = time.time()
    skipped = 0
    for i, node in enumerate(all_nodes, 1):
        t_node = time.time()
        summary = generate_summary_sync(node, INDEX_MODEL)
        # Match the exact error marker. Short nodes are returned verbatim and
        # may legitimately start with "[" (markdown links, bracketed
        # citations); a bare "[" check misreported those as failures.
        if summary.startswith("[summary error:"):
            skipped += 1
            print(f" {i}/{total} SKIP {node.get('title', '?')[:40]} — {summary}")
        else:
            elapsed_node = time.time() - t_node
            # Near-instant results mean the LLM was bypassed for a short node.
            label = f"({elapsed_node:.1f}s)" if elapsed_node > 0.1 else "(skip)"
            print(f" {i}/{total} {label} {node.get('title', '?')[:50]}")
        # Leaves get "summary"; nodes with children get "prefix_summary".
        # The value is stored for failed nodes as well, so the marker stays
        # visible in the saved index.
        if not node.get("nodes"):
            node["summary"] = summary
        else:
            node["prefix_summary"] = summary
    sum_elapsed = time.time() - t_sum
    print(f" Summaries done: {total - skipped}/{total} in {sum_elapsed:.0f}s")

    # Step 4: Generate doc description (1 LLM call)
    print(" Generating doc description... ", end="", flush=True)
    t_desc = time.time()
    doc_description = generate_doc_description_sync(tree_structure, INDEX_MODEL)
    print(f"done ({time.time()-t_desc:.1f}s)")

    tree = {
        "doc_name": domain,
        "doc_description": doc_description,
        "line_count": line_count,
        "structure": tree_structure,
    }
    return tree
def save_tree(domain, tree):
    """Write a built tree to ``<INDEX_DIR>/<domain>.json``.

    Args:
        domain: Domain name used as the output file stem.
        tree: JSON-serialisable tree dict.

    Returns:
        A ``(path, size_in_kb)`` tuple for the file that was written.
    """
    destination = os.path.join(INDEX_DIR, f"{domain}.json")
    with open(destination, "w", encoding="utf-8") as out:
        # Keep non-ASCII characters readable in the output file.
        json.dump(tree, out, indent=2, ensure_ascii=False)
    written_kb = os.stat(destination).st_size / 1024
    return destination, written_kb
def dry_run_report(targets):
    """Print per-domain node counts and estimated LLM calls; call no model.

    Args:
        targets: Iterable of domain names to report on. Domains without a
            corpus file are reported as missing and excluded from the totals.
    """
    banner = "=" * 60
    print(banner)
    print("Dry run — no API calls will be made")
    print(f"Index model: {INDEX_MODEL}")
    print(banner)

    grand_nodes = 0
    grand_calls = 0
    for domain in targets:
        corpus_file = os.path.join(CORPUS_DIR, f"{domain}.md")
        if not os.path.exists(corpus_file):
            print(f"\n {domain}: corpus not found")
            continue

        corpus_kb = os.path.getsize(corpus_file) / 1024
        with open(corpus_file, "r", encoding="utf-8") as handle:
            text = handle.read()
        n_lines = text.count("\n") + 1

        # Node extraction is purely local (no LLM involved).
        raw_nodes, md_lines = extract_nodes_from_markdown(text)
        nodes = extract_node_text_content(raw_nodes, md_lines)

        # Tally nodes per heading level and count those above the LLM
        # threshold (200 tokens or more).
        # NOTE(review): nodes lacking a "heading_level" key all fall into
        # level 0 — confirm the key name against the extractor's output.
        per_level = {}
        needs_llm = 0
        for entry in nodes:
            lvl = entry.get("heading_level", 0)
            per_level[lvl] = per_level.get(lvl, 0) + 1
        for entry in nodes:
            if count_tokens(entry.get("text", ""), model=INDEX_MODEL) >= 200:
                needs_llm += 1

        calls = needs_llm + 1  # one extra call for the doc description
        grand_nodes += len(nodes)
        grand_calls += calls

        # Mention an index that already exists on disk for this domain.
        existing = os.path.join(INDEX_DIR, f"{domain}.json")
        if os.path.exists(existing):
            existing_kb = os.path.getsize(existing) / 1024
            idx_status = f" (existing index: {existing_kb:.0f} KB)"
        else:
            idx_status = ""

        levels_str = ", ".join(
            f"H{lvl}={count}" for lvl, count in sorted(per_level.items())
        )
        print(f"\n {domain}")
        print(f" Corpus: {corpus_kb:.0f} KB, {n_lines} lines")
        print(f" Nodes: {len(nodes)} ({levels_str})")
        print(f" LLM calls: {needs_llm} summaries + 1 description = {calls}{idx_status}")

    print(f"\n{'─' * 60}")
    print(f" Total: {grand_nodes} nodes, {grand_calls} API calls")
    print(f"{'=' * 60}")
def main():
    """Command-line entry point: list domains, dry-run, or build indexes.

    Parses the optional positional domain names plus ``--list`` and
    ``--dry-run``. Exits with status 1 when an unknown domain is requested.
    """
    cli = argparse.ArgumentParser(description="Build PageIndex trees")
    cli.add_argument("domains", nargs="*", help="Specific domains to build (default: all)")
    cli.add_argument("--list", action="store_true", help="List available domains")
    cli.add_argument("--dry-run", action="store_true",
                     help="Show node counts and estimated API calls without building")
    args = cli.parse_args()

    if args.list:
        print("Available domains:")
        for name in DOMAINS:
            corpus_file = os.path.join(CORPUS_DIR, f"{name}.md")
            if os.path.exists(corpus_file):
                size = f"{os.path.getsize(corpus_file)/1024:.0f} KB"
            else:
                size = "not found"
            print(f" {name:25s} {size}")
        return

    # Fall back to every known domain, then reject the first unknown name.
    targets = args.domains or DOMAINS
    unknown = next((name for name in targets if name not in DOMAINS), None)
    if unknown is not None:
        print(f"Unknown domain: {unknown}")
        print(f"Available: {', '.join(DOMAINS)}")
        sys.exit(1)

    if args.dry_run:
        dry_run_report(targets)
        return

    banner = "=" * 60
    print(banner)
    print("Building PageIndex trees")
    print(f"Index model: {INDEX_MODEL}")
    print(f"Domains: {', '.join(targets)}")
    print(banner)

    # The reachability check only applies to "openai/"-prefixed models.
    if not INDEX_MODEL.startswith("openai/"):
        print(f"\nUsing API model: {INDEX_MODEL}")
    else:
        print("\nChecking mlx-lm server... ", end="", flush=True)
        check_mlx_server()
        print("OK")

    run_started = time.time()
    n_targets = len(targets)
    for position, domain in enumerate(targets, 1):
        print(f"\n[{position}/{n_targets}] {domain}")
        print(f" {'─' * 40}")
        domain_started = time.time()
        tree = build_tree_with_progress(domain)
        if tree is None:
            # Corpus file missing: nothing to save for this domain.
            continue
        saved_path, saved_kb = save_tree(domain, tree)
        took = time.time() - domain_started
        print(f" Saved: {saved_path} ({saved_kb:.0f} KB)")
        print(f" Total: {took:.0f}s")

    print(f"\n{'=' * 60}")
    print(f"Done in {time.time()-run_started:.0f}s")
    print(f"{'=' * 60}")
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()