File size: 4,823 Bytes
e1ced8e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
"""compliance_planner node — dual-plan generation (crops + code queries)."""
from __future__ import annotations

import json
import logging
import re
from datetime import datetime

from google import genai
from google.genai import types

from config import GOOGLE_API_KEY, PLANNER_MODEL
from prompts.compliance_planner import COMPLIANCE_PLANNER_SYSTEM_PROMPT
from state import AgentMessage, CodeQuery, ComplianceState, CropTask


def _as_int(value, default: int) -> int:
    """Return *value* converted to ``int``, or *default* if it cannot be converted."""
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        return default


def _as_text(value, default: str) -> str:
    """Return *value* as a string, or *default* when it is ``None``."""
    return default if value is None else str(value)


def _as_list(value) -> list:
    """Return *value* if it is a list, otherwise an empty list."""
    return value if isinstance(value, list) else []


def _zero_indexed_pages(raw_pages, num_pages: int) -> list[int]:
    """Convert 1-indexed page numbers to 0-indexed, dropping invalid or out-of-range entries."""
    pages: list[int] = []
    for raw in _as_list(raw_pages):
        page = _as_int(raw, 0) - 1  # unparseable -> -1, rejected by the range check
        if 0 <= page < num_pages:
            pages.append(page)
    return pages


def _parse_planner_response(
    response_text: str, num_pages: int
) -> tuple[list[int], list[int], list[CropTask], list[CodeQuery]]:
    """Extract (target_pages, legend_pages, crop_tasks, code_queries) from the model reply.

    Malformed entries are skipped individually so that one bad item does not
    discard the rest of the plan. Four empty lists are returned when no JSON
    object can be decoded from *response_text*.
    """
    target_pages: list[int] = []
    legend_pages: list[int] = []
    crop_tasks: list[CropTask] = []
    code_queries: list[CodeQuery] = []
    nothing = (target_pages, legend_pages, crop_tasks, code_queries)

    # Greedy match from the first "{" to the last "}" so surrounding prose or
    # markdown fences around the JSON object are ignored.
    json_match = re.search(r"\{.*\}", response_text, re.DOTALL)
    if not json_match:
        logging.getLogger(__name__).warning(
            "compliance_planner: no JSON object found in model response; using fallback plan."
        )
        return nothing

    try:
        parsed = json.loads(json_match.group())
    except json.JSONDecodeError as exc:
        logging.getLogger(__name__).warning(
            "compliance_planner: could not decode model response as JSON (%s); "
            "using fallback plan.",
            exc,
        )
        return nothing
    if not isinstance(parsed, dict):
        return nothing

    target_pages = _zero_indexed_pages(parsed.get("target_pages"), num_pages)
    legend_pages = _zero_indexed_pages(parsed.get("legend_pages"), num_pages)

    for task in _as_list(parsed.get("crop_tasks")):
        if not isinstance(task, dict):
            continue
        # A missing page_num defaults to page 1; an unparseable one becomes 0
        # and is rejected by the range check below.
        raw_page = _as_int(task.get("page_num", 1), 0)
        if not 1 <= raw_page <= num_pages:
            continue
        crop_tasks.append(
            CropTask(
                page_num=raw_page - 1,
                crop_instruction=_as_text(task.get("crop_instruction"), ""),
                annotate=bool(task.get("annotate", False)),
                annotation_prompt=_as_text(task.get("annotation_prompt"), ""),
                label=_as_text(task.get("label"), f"Page {raw_page} crop"),
                priority=_as_int(task.get("priority", 1), 1),
            )
        )

    for query in _as_list(parsed.get("code_queries")):
        if not isinstance(query, dict):
            continue
        code_queries.append(
            CodeQuery(
                query=_as_text(query.get("query"), ""),
                focus_area=_as_text(query.get("focus_area"), ""),
                context=_as_text(query.get("context"), ""),
                priority=_as_int(query.get("priority", 0), 0),
            )
        )

    return target_pages, legend_pages, crop_tasks, code_queries


def compliance_planner(state: ComplianceState) -> dict:
    """Produce a dual plan (image crops + code lookups) for a compliance question.

    Builds a prompt from the user question, the page count, the investigation
    round and the page metadata (when present), sends it to ``PLANNER_MODEL``
    and parses the JSON object found in the reply.

    Args:
        state: Graph state. ``question`` is required; ``num_pages``,
            ``page_metadata_json`` and ``investigation_round`` are optional
            and treated as 0 / "" / 0 when missing or ``None``.

    Returns:
        A state update with these keys:

        * ``target_pages`` / ``legend_pages``: 0-indexed page numbers, limited
          to the range ``[0, num_pages)``.
        * ``crop_tasks``: crop tasks sorted by ascending ``priority``; tasks
          whose page falls outside the document are dropped.
        * ``code_queries``: code lookup queries in the order returned.
        * ``discussion_log``: a single planner ``AgentMessage``.
        * ``status_message``: a one-element list with a summary string.

        If the reply yields neither target pages nor crop tasks, the first
        five pages (or fewer) are used as ``target_pages``.
    """
    question = state["question"]
    # ``or`` also covers keys that are present but explicitly set to None.
    num_pages = state.get("num_pages") or 0
    page_metadata_json = state.get("page_metadata_json") or ""
    investigation_round = state.get("investigation_round") or 0

    client = genai.Client(api_key=GOOGLE_API_KEY)

    question_text = (
        f"USER COMPLIANCE QUESTION: {question}\n\n"
        f"The PDF has {num_pages} pages (1-indexed, from page 1 to page {num_pages}).\n"
        f"This is investigation round {investigation_round + 1}.\n\n"
    )

    if page_metadata_json:
        question_text += f"PAGE METADATA:\n{page_metadata_json}"
    else:
        question_text += (
            "No page metadata available. Based on the question alone, "
            "plan what code lookups are needed. Crop tasks will use default pages."
        )

    response = client.models.generate_content(
        model=PLANNER_MODEL,
        contents=[types.Content(role="user", parts=[types.Part.from_text(text=question_text)])],
        config=types.GenerateContentConfig(
            system_instruction=COMPLIANCE_PLANNER_SYSTEM_PROMPT,
        ),
    )

    # ``response.text`` is None when the reply carries no text part.
    response_text = (response.text or "").strip()

    target_pages, legend_pages, crop_tasks, code_queries = _parse_planner_response(
        response_text, num_pages
    )

    # Sort crop tasks by priority
    crop_tasks.sort(key=lambda t: t["priority"])

    # Fallback: if nothing identified, use first 5 pages
    if not target_pages and not crop_tasks:
        target_pages = list(range(min(num_pages, 5)))

    # Build discussion log message
    crop_summary = f"{len(crop_tasks)} crop tasks"
    if target_pages:
        # Report pages 1-indexed, capped at five to keep the summary short.
        crop_summary += f" on pages {', '.join(str(p + 1) for p in target_pages[:5])}"
    code_summary = f"{len(code_queries)} code queries"
    if code_queries:
        code_summary += f" ({', '.join(q['focus_area'] for q in code_queries[:3])})"

    discussion_msg = AgentMessage(
        timestamp=datetime.now().strftime("%H:%M:%S"),
        agent="planner",
        action="plan",
        summary=f"Planned {crop_summary} and {code_summary}.",
        detail=response_text,
        evidence_refs=[],
    )

    return {
        "target_pages": target_pages,
        "legend_pages": legend_pages,
        "crop_tasks": crop_tasks,
        "code_queries": code_queries,
        "discussion_log": [discussion_msg],
        "status_message": [
            f"Selected {len(target_pages)} pages ({len(legend_pages)} legends), "
            f"planned {len(crop_tasks)} crop tasks, {len(code_queries)} code queries."
        ],
    }