"""compliance_planner node — dual-plan generation (crops + code queries)."""
from __future__ import annotations
import json
import re
from datetime import datetime
from google import genai
from google.genai import types
from config import GOOGLE_API_KEY, PLANNER_MODEL
from prompts.compliance_planner import COMPLIANCE_PLANNER_SYSTEM_PROMPT
from state import AgentMessage, CodeQuery, ComplianceState, CropTask
def _as_int(value: object, default: int) -> int:
    """Return *value* as an int, or *default* if it is missing or not numeric."""
    try:
        return int(value)  # type: ignore[arg-type]
    except (TypeError, ValueError, OverflowError):
        return default


def _as_str(value: object) -> str:
    """Return *value* as a string, mapping a missing/null value to ""."""
    return "" if value is None else str(value)


def _as_bool(value: object) -> bool:
    """Interpret a model-supplied flag.

    ``bool("false")`` is True, so string flags are matched explicitly.
    """
    if isinstance(value, str):
        return value.strip().lower() in {"true", "yes", "y", "1"}
    return bool(value)


def _as_list(value: object) -> list:
    """Return *value* if it is a list, else an empty list (guards null/wrong types)."""
    return value if isinstance(value, list) else []


def _extract_json_object(text: str) -> dict | None:
    """Return the JSON object embedded in *text*, or None if there is none.

    The span from the first "{" to the last "}" is decoded, which tolerates
    prose or Markdown fences around the JSON. Returns None when no such span
    exists or it is not valid JSON.
    """
    match = re.search(r"\{.*\}", text, re.DOTALL)
    if match is None:
        return None
    try:
        return json.loads(match.group())
    except json.JSONDecodeError:
        return None


def _parse_pages(raw: object, num_pages: int) -> list[int]:
    """Convert the model's 1-indexed page numbers to 0-indexed pages.

    Non-numeric and out-of-range entries are dropped one by one (rather than
    aborting the whole plan); duplicates are removed, keeping first occurrence.
    """
    candidates = (_as_int(p, 0) - 1 for p in _as_list(raw))
    return list(dict.fromkeys(p for p in candidates if 0 <= p < num_pages))


def _build_question_text(
    question: str,
    num_pages: int,
    page_metadata_json: str,
    investigation_round: int,
) -> str:
    """Compose the user turn sent to the planner model."""
    header = (
        f"USER COMPLIANCE QUESTION: {question}\n\n"
        f"The PDF has {num_pages} pages (1-indexed, from page 1 to page {num_pages}).\n"
        f"This is investigation round {investigation_round + 1}.\n\n"
    )
    if page_metadata_json:
        return header + f"PAGE METADATA:\n{page_metadata_json}"
    return header + (
        "No page metadata available. Based on the question alone, "
        "plan what code lookups are needed. Crop tasks will use default pages."
    )


def _parse_crop_tasks(raw: object, num_pages: int) -> list[CropTask]:
    """Build crop tasks from the model's plan, sorted by ascending priority.

    ``page_num`` is converted from 1-indexed to 0-indexed. Tasks whose page is
    missing-but-garbled, below page 1, or (when ``num_pages`` is non-zero) past
    the last page are dropped, matching the validation applied to target and
    legend pages. A zero page count (the default when the state has none) is
    treated as unknown, so only the lower bound is enforced then.
    """
    tasks: list[CropTask] = []
    for item in _as_list(raw):
        if not isinstance(item, dict):
            continue
        # An absent page_num still defaults to page 1; an unparseable one -> 0 -> dropped.
        raw_page = _as_int(item.get("page_num", 1), 0)
        if raw_page < 1 or (num_pages and raw_page > num_pages):
            continue
        tasks.append(
            CropTask(
                page_num=raw_page - 1,
                crop_instruction=_as_str(item.get("crop_instruction", "")),
                annotate=_as_bool(item.get("annotate", False)),
                annotation_prompt=_as_str(item.get("annotation_prompt", "")),
                label=_as_str(item.get("label") or f"Page {raw_page} crop"),
                priority=_as_int(item.get("priority", 1), 1),
            )
        )
    tasks.sort(key=lambda task: task["priority"])
    return tasks


def _parse_code_queries(raw: object) -> list[CodeQuery]:
    """Build code-lookup queries from the model's plan, keeping the model's order."""
    return [
        CodeQuery(
            query=_as_str(item.get("query", "")),
            focus_area=_as_str(item.get("focus_area", "")),
            context=_as_str(item.get("context", "")),
            priority=_as_int(item.get("priority", 0), 0),
        )
        for item in _as_list(raw)
        if isinstance(item, dict)
    ]


def compliance_planner(state: ComplianceState) -> dict:
    """Analyze page metadata + user question and produce dual plans for
    image cropping AND code lookup.

    Reads ``question`` (required) and ``num_pages``, ``page_metadata_json``,
    ``investigation_round`` (optional) from *state*, sends them to
    ``PLANNER_MODEL`` with ``COMPLIANCE_PLANNER_SYSTEM_PROMPT``, and parses the
    JSON object in the reply. Malformed individual entries in the reply are
    skipped rather than discarding the whole plan.

    Returns:
        A dict of state updates: ``target_pages`` / ``legend_pages``
        (0-indexed, in range, de-duplicated), ``crop_tasks`` (sorted by
        priority), ``code_queries``, a one-entry ``discussion_log`` and a
        one-entry ``status_message``. If the reply yields neither target pages
        nor crop tasks, the first (up to) five pages become the targets.

    Raises:
        KeyError: if *state* has no ``"question"``. Errors raised by the
        Gemini client are not caught here.
    """
    question = state["question"]
    num_pages = state.get("num_pages", 0)
    page_metadata_json = state.get("page_metadata_json", "")
    investigation_round = state.get("investigation_round", 0)

    client = genai.Client(api_key=GOOGLE_API_KEY)
    question_text = _build_question_text(
        question, num_pages, page_metadata_json, investigation_round
    )
    response = client.models.generate_content(
        model=PLANNER_MODEL,
        contents=[types.Content(role="user", parts=[types.Part.from_text(text=question_text)])],
        config=types.GenerateContentConfig(
            system_instruction=COMPLIANCE_PLANNER_SYSTEM_PROMPT,
        ),
    )
    # response.text is Optional in the SDK (None when the reply has no text parts).
    response_text = (response.text or "").strip()

    plan = _extract_json_object(response_text)
    parse_failed = plan is None
    if parse_failed:
        plan = {}

    target_pages = _parse_pages(plan.get("target_pages"), num_pages)
    legend_pages = _parse_pages(plan.get("legend_pages"), num_pages)
    crop_tasks = _parse_crop_tasks(plan.get("crop_tasks"), num_pages)
    code_queries = _parse_code_queries(plan.get("code_queries"))

    # Fallback: if nothing identified, use first 5 pages
    if not target_pages and not crop_tasks:
        target_pages = list(range(min(num_pages, 5)))

    # Build discussion log message
    crop_summary = f"{len(crop_tasks)} crop tasks on pages {', '.join(str(p + 1) for p in target_pages[:5])}"
    code_summary = f"{len(code_queries)} code queries"
    if code_queries:
        code_summary += f" ({', '.join(q['focus_area'] for q in code_queries[:3])})"
    summary = f"Planned {crop_summary} and {code_summary}."
    if parse_failed:
        # Surface the failure instead of silently swallowing it.
        summary += " (Planner response contained no parseable JSON plan; using fallback pages.)"

    discussion_msg = AgentMessage(
        timestamp=datetime.now().strftime("%H:%M:%S"),
        agent="planner",
        action="plan",
        summary=summary,
        detail=response_text,
        evidence_refs=[],
    )
    return {
        "target_pages": target_pages,
        "legend_pages": legend_pages,
        "crop_tasks": crop_tasks,
        "code_queries": code_queries,
        "discussion_log": [discussion_msg],
        "status_message": [
            f"Selected {len(target_pages)} pages ({len(legend_pages)} legends), "
            f"planned {len(crop_tasks)} crop tasks, {len(code_queries)} code queries."
        ],
    }