File size: 8,043 Bytes
e1ced8e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
"""Background page metadata generator — extracts per-page descriptions from the full PDF.



Uses parallel batch processing: the PDF is split into 5-page chunks and each

chunk is sent to Gemini concurrently for faster metadata extraction.

"""
from __future__ import annotations

import json
import logging
import math
from concurrent.futures import ThreadPoolExecutor, as_completed

from google import genai
from google.genai import types

from config import GOOGLE_API_KEY, METADATA_MODEL
from prompts.metadata import METADATA_SYSTEM_PROMPT
from tools.pdf_processor import extract_page_range_bytes

logger = logging.getLogger(__name__)

# Number of PDF pages per batch sent to Gemini in parallel.
BATCH_SIZE = 5


# ---------------------------------------------------------------------------
# JSON extraction helper
# ---------------------------------------------------------------------------

def _extract_json_array(response_text: str) -> list[dict]:
    """Extract the outermost balanced JSON array from a response string."""
    start = response_text.find("[")
    if start == -1:
        raise ValueError("No JSON array found in metadata generation response")

    depth = 0
    end = None
    for i in range(start, len(response_text)):
        if response_text[i] == "[":
            depth += 1
        elif response_text[i] == "]":
            depth -= 1
            if depth == 0:
                end = i
                break

    if end is None:
        raise ValueError("No matching closing bracket found in metadata response")

    result = json.loads(response_text[start : end + 1])
    if not isinstance(result, list):
        raise ValueError(f"Expected list, got {type(result)}")
    return result


# ---------------------------------------------------------------------------
# Single-batch API call
# ---------------------------------------------------------------------------

def _generate_batch(
    pdf_path: str,
    page_start_0: int,
    page_end_0: int,
    page_start_1: int,
    page_end_1: int,
) -> list[dict]:
    """Call Gemini once to produce metadata for one contiguous page range.

    Args:
        pdf_path: Path to the full PDF on disk.
        page_start_0: First page of the range, 0-indexed inclusive (for PDF extraction).
        page_end_0: Last page of the range, 0-indexed inclusive (for PDF extraction).
        page_start_1: First page, 1-indexed — used in the prompt wording.
        page_end_1: Last page, 1-indexed — used in the prompt wording.

    Returns:
        A list of metadata dicts, one per page in the range.
    """
    # A fresh client per call keeps this function independent across
    # worker threads — no shared client state between concurrent batches.
    client = genai.Client(api_key=GOOGLE_API_KEY)

    excerpt_bytes = extract_page_range_bytes(pdf_path, page_start_0, page_end_0)
    excerpt_part = types.Part.from_bytes(data=excerpt_bytes, mime_type="application/pdf")

    page_count = page_end_1 - page_start_1 + 1
    prompt_part = types.Part.from_text(
        text=(
            f"This PDF excerpt contains {page_count} page(s), "
            f"corresponding to pages {page_start_1} through {page_end_1} of the full drawing set.\n"
            f"Generate structured metadata for ALL {page_count} page(s). "
            f"Use page numbers {page_start_1} through {page_end_1} (1-indexed). "
            f"Return a JSON array with exactly {page_count} objects."
        )
    )

    response = client.models.generate_content(
        model=METADATA_MODEL,
        contents=[types.Content(role="user", parts=[excerpt_part, prompt_part])],
        config=types.GenerateContentConfig(system_instruction=METADATA_SYSTEM_PROMPT),
    )

    return _extract_json_array(response.text.strip())


# ---------------------------------------------------------------------------
# Public entry point
# ---------------------------------------------------------------------------

def generate_page_metadata(
    pdf_path: str,
    num_pages: int,
    progress_callback=None,
) -> list[dict]:
    """Extract per-page structured metadata from a PDF using parallel batches.

    The PDF is split into chunks of ``BATCH_SIZE`` pages. Each chunk is sent to
    Gemini concurrently via a thread pool.  Results are merged, any missing
    pages are back-filled, and the list is returned sorted by page number.

    Args:
        pdf_path: Path to the full PDF.
        num_pages: Total number of pages.
        progress_callback: Optional ``(completed_batches, total_batches, page_range_str) -> None``
            called after each batch finishes (whether it succeeded or failed).

    Returns a list of dicts (1-indexed page_num), one per page; empty list
    when ``num_pages`` is zero or negative.
    Raises RuntimeError when every batch fails (caller handles the error).
    """
    # Guard the degenerate case up front: ThreadPoolExecutor raises
    # ValueError for max_workers == 0, so a 0-page PDF would crash below.
    if num_pages <= 0:
        return []

    num_batches = math.ceil(num_pages / BATCH_SIZE)
    logger.info(
        "Starting parallel metadata generation: %d pages in %d batches of %d",
        num_pages, num_batches, BATCH_SIZE,
    )

    all_results: list[dict] = []
    errors: list[str] = []
    completed_count = 0

    # Cap the pool: one thread per batch is fine for small sets, but a large
    # PDF would otherwise spawn an unbounded number of concurrent API calls.
    with ThreadPoolExecutor(max_workers=min(num_batches, 16)) as executor:
        futures = {}
        for batch_idx in range(num_batches):
            page_start_0 = batch_idx * BATCH_SIZE
            page_end_0 = min(page_start_0 + BATCH_SIZE - 1, num_pages - 1)
            # 1-indexed equivalents, used in prompt text and progress strings.
            page_start_1 = page_start_0 + 1
            page_end_1 = page_end_0 + 1

            future = executor.submit(
                _generate_batch,
                pdf_path,
                page_start_0,
                page_end_0,
                page_start_1,
                page_end_1,
            )
            futures[future] = (page_start_1, page_end_1)

        for future in as_completed(futures):
            batch_range = futures[future]
            try:
                batch_results = future.result()
                all_results.extend(batch_results)
                completed_count += 1
                logger.info(
                    "Batch pages %d-%d complete: %d entries",
                    batch_range[0], batch_range[1], len(batch_results),
                )
                if progress_callback is not None:
                    progress_callback(
                        completed_count,
                        num_batches,
                        f"Pages {batch_range[0]}-{batch_range[1]}",
                    )
            except Exception as e:
                # A failed batch is recorded but does not abort the others;
                # its pages are back-filled with placeholder entries below.
                completed_count += 1
                errors.append(f"Batch pages {batch_range[0]}-{batch_range[1]} failed: {e}")
                logger.exception("Batch pages %d-%d failed", batch_range[0], batch_range[1])
                if progress_callback is not None:
                    progress_callback(
                        completed_count,
                        num_batches,
                        f"Pages {batch_range[0]}-{batch_range[1]} (failed)",
                    )

    if errors and not all_results:
        raise RuntimeError(
            "All metadata batches failed:\n" + "\n".join(errors)
        )

    if errors:
        logger.warning("Some batches failed (results will have gaps): %s", errors)

    # Metadata stays 1-indexed (as the model produced it) because it will be
    # passed as context text to the planner model, which also uses 1-indexed.
    # The planner's *output* is converted to 0-indexed in nodes/planner.py.

    # Fill in any missing pages with minimal entries (1-indexed)
    covered_pages = {item.get("page_num") for item in all_results}
    for p in range(1, num_pages + 1):
        if p not in covered_pages:
            all_results.append({
                "page_num": p,
                "sheet_id": "unknown",
                "sheet_title": "Unknown",
                "discipline": "other",
                "page_type": "other",
                "description": "Metadata not extracted for this page.",
                "key_elements": [],
                "spatial_coverage": "",
            })

    # Sort by page number
    all_results.sort(key=lambda x: x.get("page_num", 0))

    return all_results