# Updated_code_complaince/nodes/metadata_generator.py
"""Background page metadata generator — extracts per-page descriptions from the full PDF.
Uses parallel batch processing: the PDF is split into 5-page chunks and each
chunk is sent to Gemini concurrently for faster metadata extraction.
"""
from __future__ import annotations
import json
import logging
import math
from concurrent.futures import ThreadPoolExecutor, as_completed
from google import genai
from google.genai import types
from config import GOOGLE_API_KEY, METADATA_MODEL
from prompts.metadata import METADATA_SYSTEM_PROMPT
from tools.pdf_processor import extract_page_range_bytes
# Module-level logger; batch progress and failures are reported through it.
logger = logging.getLogger(__name__)
# Number of PDF pages per batch sent to Gemini in parallel.  Each batch
# becomes one generate_content call (see _generate_batch below).
BATCH_SIZE = 5
# ---------------------------------------------------------------------------
# JSON extraction helper
# ---------------------------------------------------------------------------
def _extract_json_array(response_text: str) -> list[dict]:
"""Extract the outermost balanced JSON array from a response string."""
start = response_text.find("[")
if start == -1:
raise ValueError("No JSON array found in metadata generation response")
depth = 0
end = None
for i in range(start, len(response_text)):
if response_text[i] == "[":
depth += 1
elif response_text[i] == "]":
depth -= 1
if depth == 0:
end = i
break
if end is None:
raise ValueError("No matching closing bracket found in metadata response")
result = json.loads(response_text[start : end + 1])
if not isinstance(result, list):
raise ValueError(f"Expected list, got {type(result)}")
return result
# ---------------------------------------------------------------------------
# Single-batch API call
# ---------------------------------------------------------------------------
def _generate_batch(
    pdf_path: str,
    page_start_0: int,
    page_end_0: int,
    page_start_1: int,
    page_end_1: int,
) -> list[dict]:
    """Request metadata from Gemini for one contiguous page range.

    Extracts the page range from the PDF on disk, sends it together with a
    textual instruction to the metadata model, and parses the JSON array
    out of the response.

    Args:
        pdf_path: Path to the full PDF on disk.
        page_start_0: First page of the range (0-indexed, inclusive), used
            for PDF byte extraction.
        page_end_0: Last page of the range (0-indexed, inclusive), used
            for PDF byte extraction.
        page_start_1: Same first page, 1-indexed — embedded in the prompt.
        page_end_1: Same last page, 1-indexed — embedded in the prompt.

    Returns:
        A list of metadata dicts, one per page in the range.
    """
    page_count = page_end_1 - page_start_1 + 1
    # The model only sees the excerpt, so the prompt must spell out which
    # absolute (1-indexed) page numbers the excerpt corresponds to.
    prompt = (
        f"This PDF excerpt contains {page_count} page(s), "
        f"corresponding to pages {page_start_1} through {page_end_1} of the full drawing set.\n"
        f"Generate structured metadata for ALL {page_count} page(s). "
        f"Use page numbers {page_start_1} through {page_end_1} (1-indexed). "
        f"Return a JSON array with exactly {page_count} objects."
    )
    excerpt_bytes = extract_page_range_bytes(pdf_path, page_start_0, page_end_0)
    request_parts = [
        types.Part.from_bytes(data=excerpt_bytes, mime_type="application/pdf"),
        types.Part.from_text(text=prompt),
    ]
    client = genai.Client(api_key=GOOGLE_API_KEY)
    response = client.models.generate_content(
        model=METADATA_MODEL,
        contents=[types.Content(role="user", parts=request_parts)],
        config=types.GenerateContentConfig(system_instruction=METADATA_SYSTEM_PROMPT),
    )
    return _extract_json_array(response.text.strip())
# ---------------------------------------------------------------------------
# Public entry point
# ---------------------------------------------------------------------------
def generate_page_metadata(
    pdf_path: str,
    num_pages: int,
    progress_callback=None,
) -> list[dict]:
    """Extract per-page structured metadata from a PDF using parallel batches.

    The PDF is split into chunks of ``BATCH_SIZE`` pages. Each chunk is sent
    to Gemini concurrently via a thread pool. Results are merged, any missing
    pages are back-filled with placeholder entries, and the list is returned
    sorted by page number.

    Args:
        pdf_path: Path to the full PDF.
        num_pages: Total number of pages. Non-positive values return ``[]``.
        progress_callback: Optional
            ``(completed_batches, total_batches, page_range_str) -> None``
            called after each batch finishes (successfully or not).

    Returns:
        A list of metadata dicts (1-indexed ``page_num``), one per page.

    Raises:
        RuntimeError: If *every* batch fails. Partial failures only log a
            warning; their pages are back-filled with placeholder entries.
    """
    # Guard the empty-document case: there is nothing to process, and it
    # would otherwise reach ThreadPoolExecutor(max_workers=0), which raises
    # ValueError.
    if num_pages <= 0:
        return []
    num_batches = math.ceil(num_pages / BATCH_SIZE)
    logger.info(
        "Starting parallel metadata generation: %d pages in %d batches of %d",
        num_pages, num_batches, BATCH_SIZE,
    )
    all_results: list[dict] = []
    errors: list[str] = []
    completed_count = 0
    # Cap the pool size: batches are I/O-bound API calls, so a modest number
    # of threads suffices, and a very large PDF must not spawn one thread per
    # batch (previously max_workers == num_batches, unbounded).
    with ThreadPoolExecutor(max_workers=min(num_batches, 16)) as executor:
        futures = {}
        for batch_idx in range(num_batches):
            # 0-indexed inclusive range for PDF extraction; the final batch
            # is clamped to the last page.
            page_start_0 = batch_idx * BATCH_SIZE
            page_end_0 = min(page_start_0 + BATCH_SIZE - 1, num_pages - 1)
            # 1-indexed equivalents used in prompt text and progress labels.
            page_start_1 = page_start_0 + 1
            page_end_1 = page_end_0 + 1
            future = executor.submit(
                _generate_batch,
                pdf_path,
                page_start_0,
                page_end_0,
                page_start_1,
                page_end_1,
            )
            futures[future] = (page_start_1, page_end_1)
        for future in as_completed(futures):
            batch_range = futures[future]
            try:
                batch_results = future.result()
                all_results.extend(batch_results)
                completed_count += 1
                logger.info(
                    "Batch pages %d-%d complete: %d entries",
                    batch_range[0], batch_range[1], len(batch_results),
                )
                if progress_callback is not None:
                    progress_callback(
                        completed_count,
                        num_batches,
                        f"Pages {batch_range[0]}-{batch_range[1]}",
                    )
            except Exception as e:
                # A failed batch is recorded but does not abort the others;
                # its pages are back-filled with placeholders below.
                completed_count += 1
                errors.append(f"Batch pages {batch_range[0]}-{batch_range[1]} failed: {e}")
                logger.exception("Batch pages %d-%d failed", batch_range[0], batch_range[1])
                if progress_callback is not None:
                    progress_callback(
                        completed_count,
                        num_batches,
                        f"Pages {batch_range[0]}-{batch_range[1]} (failed)",
                    )
    if errors and not all_results:
        # No f-string here: the literal contains no placeholders.
        raise RuntimeError(
            "All metadata batches failed:\n" + "\n".join(errors)
        )
    if errors:
        logger.warning("Some batches failed (results will have gaps): %s", errors)
    # Metadata stays 1-indexed (as the model produced it) because it will be
    # passed as context text to the planner model, which also uses 1-indexed.
    # The planner's *output* is converted to 0-indexed in nodes/planner.py.
    # Fill in any missing pages with minimal placeholder entries (1-indexed).
    covered_pages = {item.get("page_num") for item in all_results}
    for p in range(1, num_pages + 1):
        if p not in covered_pages:
            all_results.append({
                "page_num": p,
                "sheet_id": "unknown",
                "sheet_title": "Unknown",
                "discipline": "other",
                "page_type": "other",
                "description": "Metadata not extracted for this page.",
                "key_elements": [],
                "spatial_coverage": "",
            })
    # Sort so the caller receives a contiguous 1..num_pages ordering.
    all_results.sort(key=lambda x: x.get("page_num", 0))
    return all_results