"""Background page metadata generator — extracts per-page descriptions from the full PDF.
Uses parallel batch processing: the PDF is split into 5-page chunks and each
chunk is sent to Gemini concurrently for faster metadata extraction.
"""
from __future__ import annotations
import json
import logging
import math
from concurrent.futures import ThreadPoolExecutor, as_completed
from google import genai
from google.genai import types
from config import GOOGLE_API_KEY, METADATA_MODEL
from prompts.metadata import METADATA_SYSTEM_PROMPT
from tools.pdf_processor import extract_page_range_bytes
logger = logging.getLogger(__name__)
# Number of PDF pages per batch sent to Gemini in parallel.
BATCH_SIZE = 5
# ---------------------------------------------------------------------------
# JSON extraction helper
# ---------------------------------------------------------------------------
def _extract_json_array(response_text: str) -> list[dict]:
"""Extract the outermost balanced JSON array from a response string."""
start = response_text.find("[")
if start == -1:
raise ValueError("No JSON array found in metadata generation response")
depth = 0
end = None
for i in range(start, len(response_text)):
if response_text[i] == "[":
depth += 1
elif response_text[i] == "]":
depth -= 1
if depth == 0:
end = i
break
if end is None:
raise ValueError("No matching closing bracket found in metadata response")
result = json.loads(response_text[start : end + 1])
if not isinstance(result, list):
raise ValueError(f"Expected list, got {type(result)}")
return result
# ---------------------------------------------------------------------------
# Single-batch API call
# ---------------------------------------------------------------------------
def _generate_batch(
    pdf_path: str,
    page_start_0: int,
    page_end_0: int,
    page_start_1: int,
    page_end_1: int,
) -> list[dict]:
    """Generate metadata for a contiguous range of pages.

    Extracts the requested page range from the PDF on disk, sends it to
    Gemini together with an instruction naming the 1-indexed page numbers,
    and parses the JSON array out of the model's reply.

    Args:
        pdf_path: Path to the full PDF on disk.
        page_start_0: First page (0-indexed, inclusive) for PDF extraction.
        page_end_0: Last page (0-indexed, inclusive) for PDF extraction.
        page_start_1: First page (1-indexed) — used in the prompt text.
        page_end_1: Last page (1-indexed) — used in the prompt text.

    Returns:
        List of metadata dicts for the pages in this batch.
    """
    n_pages = page_end_1 - page_start_1 + 1
    chunk_bytes = extract_page_range_bytes(pdf_path, page_start_0, page_end_0)
    # Build both message parts up front: the PDF excerpt plus the textual
    # instruction anchoring the model to the excerpt's true page numbers.
    parts = [
        types.Part.from_bytes(data=chunk_bytes, mime_type="application/pdf"),
        types.Part.from_text(
            text=(
                f"This PDF excerpt contains {n_pages} page(s), "
                f"corresponding to pages {page_start_1} through {page_end_1} of the full drawing set.\n"
                f"Generate structured metadata for ALL {n_pages} page(s). "
                f"Use page numbers {page_start_1} through {page_end_1} (1-indexed). "
                f"Return a JSON array with exactly {n_pages} objects."
            )
        ),
    ]
    gemini = genai.Client(api_key=GOOGLE_API_KEY)
    reply = gemini.models.generate_content(
        model=METADATA_MODEL,
        contents=[types.Content(role="user", parts=parts)],
        config=types.GenerateContentConfig(
            system_instruction=METADATA_SYSTEM_PROMPT,
        ),
    )
    return _extract_json_array(reply.text.strip())
# ---------------------------------------------------------------------------
# Public entry point
# ---------------------------------------------------------------------------
def generate_page_metadata(
    pdf_path: str,
    num_pages: int,
    progress_callback=None,
) -> list[dict]:
    """Extract per-page structured metadata from a PDF using parallel batches.

    The PDF is split into chunks of ``BATCH_SIZE`` pages. Each chunk is sent
    to Gemini concurrently via a thread pool. Results are merged, any missing
    pages are back-filled with placeholder entries, and the list is returned
    sorted by page number.

    Args:
        pdf_path: Path to the full PDF.
        num_pages: Total number of pages.
        progress_callback: Optional ``(completed_batches, total_batches, page_range_str) -> None``
            called after each batch finishes (successfully or not).

    Returns:
        A list of dicts (1-indexed ``page_num``), one per page. Empty list
        when ``num_pages`` is zero or negative.

    Raises:
        RuntimeError: If every batch failed. Partial failures only log a
            warning; their pages are back-filled with placeholder entries.
    """
    # Guard the empty case: ThreadPoolExecutor(max_workers=0) raises
    # ValueError, and there is nothing to extract anyway.
    if num_pages <= 0:
        return []
    num_batches = math.ceil(num_pages / BATCH_SIZE)
    logger.info(
        "Starting parallel metadata generation: %d pages in %d batches of %d",
        num_pages, num_batches, BATCH_SIZE,
    )
    all_results: list[dict] = []
    errors: list[str] = []
    completed_count = 0
    # Cap the pool so a very large PDF doesn't spawn one thread per batch.
    with ThreadPoolExecutor(max_workers=min(num_batches, 16)) as executor:
        futures = {}
        for batch_idx in range(num_batches):
            page_start_0 = batch_idx * BATCH_SIZE
            page_end_0 = min(page_start_0 + BATCH_SIZE - 1, num_pages - 1)
            # 1-indexed equivalents, used in prompts, logs, and progress text.
            page_start_1 = page_start_0 + 1
            page_end_1 = page_end_0 + 1
            future = executor.submit(
                _generate_batch,
                pdf_path,
                page_start_0,
                page_end_0,
                page_start_1,
                page_end_1,
            )
            futures[future] = (page_start_1, page_end_1)
        for future in as_completed(futures):
            batch_range = futures[future]
            completed_count += 1
            try:
                batch_results = future.result()
                all_results.extend(batch_results)
                logger.info("Batch pages %d-%d complete: %d entries", batch_range[0], batch_range[1], len(batch_results))
                status = f"Pages {batch_range[0]}-{batch_range[1]}"
            except Exception as e:
                errors.append(f"Batch pages {batch_range[0]}-{batch_range[1]} failed: {e}")
                logger.exception("Batch pages %d-%d failed", batch_range[0], batch_range[1])
                status = f"Pages {batch_range[0]}-{batch_range[1]} (failed)"
            # Invoke the callback outside the try so an exception raised by
            # the callback is not mis-recorded as a batch failure (which
            # previously also double-incremented completed_count).
            if progress_callback is not None:
                progress_callback(completed_count, num_batches, status)
    if errors and not all_results:
        raise RuntimeError(
            "All metadata batches failed:\n" + "\n".join(errors)
        )
    if errors:
        logger.warning("Some batches failed (results will have gaps): %s", errors)
    # Metadata stays 1-indexed (as the model produced it) because it will be
    # passed as context text to the planner model, which also uses 1-indexed.
    # The planner's *output* is converted to 0-indexed in nodes/planner.py.
    # Fill in any missing pages with minimal entries (1-indexed)
    covered_pages = {item.get("page_num") for item in all_results}
    for p in range(1, num_pages + 1):
        if p not in covered_pages:
            all_results.append({
                "page_num": p,
                "sheet_id": "unknown",
                "sheet_title": "Unknown",
                "discipline": "other",
                "page_type": "other",
                "description": "Metadata not extracted for this page.",
                "key_elements": [],
                "spatial_coverage": "",
            })
    # Sort by page number
    all_results.sort(key=lambda x: x.get("page_num", 0))
    return all_results
|