"""Background page metadata generator — extracts per-page descriptions from the full PDF. Uses parallel batch processing: the PDF is split into 5-page chunks and each chunk is sent to Gemini concurrently for faster metadata extraction. """ from __future__ import annotations import json import logging import math from concurrent.futures import ThreadPoolExecutor, as_completed from google import genai from google.genai import types from config import GOOGLE_API_KEY, METADATA_MODEL from prompts.metadata import METADATA_SYSTEM_PROMPT from tools.pdf_processor import extract_page_range_bytes logger = logging.getLogger(__name__) # Number of PDF pages per batch sent to Gemini in parallel. BATCH_SIZE = 5 # --------------------------------------------------------------------------- # JSON extraction helper # --------------------------------------------------------------------------- def _extract_json_array(response_text: str) -> list[dict]: """Extract the outermost balanced JSON array from a response string.""" start = response_text.find("[") if start == -1: raise ValueError("No JSON array found in metadata generation response") depth = 0 end = None for i in range(start, len(response_text)): if response_text[i] == "[": depth += 1 elif response_text[i] == "]": depth -= 1 if depth == 0: end = i break if end is None: raise ValueError("No matching closing bracket found in metadata response") result = json.loads(response_text[start : end + 1]) if not isinstance(result, list): raise ValueError(f"Expected list, got {type(result)}") return result # --------------------------------------------------------------------------- # Single-batch API call # --------------------------------------------------------------------------- def _generate_batch( pdf_path: str, page_start_0: int, page_end_0: int, page_start_1: int, page_end_1: int, ) -> list[dict]: """Generate metadata for a contiguous range of pages. Args: pdf_path: Path to the full PDF on disk. page_start_0: First page (0-indexed, inclusive) for PDF extraction. page_end_0: Last page (0-indexed, inclusive) for PDF extraction. page_start_1: First page (1-indexed) — used in the prompt text. page_end_1: Last page (1-indexed) — used in the prompt text. Returns: List of metadata dicts for the pages in this batch. """ client = genai.Client(api_key=GOOGLE_API_KEY) batch_pdf_bytes = extract_page_range_bytes(pdf_path, page_start_0, page_end_0) pdf_part = types.Part.from_bytes(data=batch_pdf_bytes, mime_type="application/pdf") num_batch_pages = page_end_1 - page_start_1 + 1 instruction_text = ( f"This PDF excerpt contains {num_batch_pages} page(s), " f"corresponding to pages {page_start_1} through {page_end_1} of the full drawing set.\n" f"Generate structured metadata for ALL {num_batch_pages} page(s). " f"Use page numbers {page_start_1} through {page_end_1} (1-indexed). " f"Return a JSON array with exactly {num_batch_pages} objects." ) instruction_part = types.Part.from_text(text=instruction_text) response = client.models.generate_content( model=METADATA_MODEL, contents=[types.Content(role="user", parts=[pdf_part, instruction_part])], config=types.GenerateContentConfig( system_instruction=METADATA_SYSTEM_PROMPT, ), ) return _extract_json_array(response.text.strip()) # --------------------------------------------------------------------------- # Public entry point # --------------------------------------------------------------------------- def generate_page_metadata( pdf_path: str, num_pages: int, progress_callback=None, ) -> list[dict]: """Extract per-page structured metadata from a PDF using parallel batches. The PDF is split into chunks of ``BATCH_SIZE`` pages. Each chunk is sent to Gemini concurrently via a thread pool. Results are merged, any missing pages are back-filled, and the list is returned sorted by page number. Args: pdf_path: Path to the full PDF. num_pages: Total number of pages. progress_callback: Optional ``(completed_batches, total_batches, page_range_str) -> None`` called after each batch finishes. Returns a list of dicts (1-indexed page_num), one per page. Raises on failure (caller is responsible for error handling). """ num_batches = math.ceil(num_pages / BATCH_SIZE) logger.info( "Starting parallel metadata generation: %d pages in %d batches of %d", num_pages, num_batches, BATCH_SIZE, ) all_results: list[dict] = [] errors: list[str] = [] completed_count = 0 with ThreadPoolExecutor(max_workers=num_batches) as executor: futures = {} for batch_idx in range(num_batches): page_start_0 = batch_idx * BATCH_SIZE page_end_0 = min(page_start_0 + BATCH_SIZE - 1, num_pages - 1) page_start_1 = page_start_0 + 1 page_end_1 = page_end_0 + 1 future = executor.submit( _generate_batch, pdf_path, page_start_0, page_end_0, page_start_1, page_end_1, ) futures[future] = (page_start_1, page_end_1) for future in as_completed(futures): batch_range = futures[future] try: batch_results = future.result() all_results.extend(batch_results) completed_count += 1 logger.info("Batch pages %d-%d complete: %d entries", batch_range[0], batch_range[1], len(batch_results)) if progress_callback is not None: progress_callback( completed_count, num_batches, f"Pages {batch_range[0]}-{batch_range[1]}", ) except Exception as e: completed_count += 1 errors.append(f"Batch pages {batch_range[0]}-{batch_range[1]} failed: {e}") logger.exception("Batch pages %d-%d failed", batch_range[0], batch_range[1]) if progress_callback is not None: progress_callback( completed_count, num_batches, f"Pages {batch_range[0]}-{batch_range[1]} (failed)", ) if errors and not all_results: raise RuntimeError( f"All metadata batches failed:\n" + "\n".join(errors) ) if errors: logger.warning("Some batches failed (results will have gaps): %s", errors) # Metadata stays 1-indexed (as the model produced it) because it will be # passed as context text to the planner model, which also uses 1-indexed. # The planner's *output* is converted to 0-indexed in nodes/planner.py. # Fill in any missing pages with minimal entries (1-indexed) covered_pages = {item.get("page_num") for item in all_results} for p in range(1, num_pages + 1): if p not in covered_pages: all_results.append({ "page_num": p, "sheet_id": "unknown", "sheet_title": "Unknown", "discipline": "other", "page_type": "other", "description": "Metadata not extracted for this page.", "key_elements": [], "spatial_coverage": "", }) # Sort by page number all_results.sort(key=lambda x: x.get("page_num", 0)) return all_results