MukulRay commited on
Commit
0fb281f
·
1 Parent(s): 22e9366

Phase 4: planner agent with temporal decomposition and session context

Browse files
Files changed (2) hide show
  1. src/agents/planner.py +107 -0
  2. test_phase4.py +57 -0
src/agents/planner.py CHANGED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from dotenv import load_dotenv
3
+ from langchain_groq import ChatGroq
4
+ from langchain_core.messages import SystemMessage, HumanMessage
5
+
6
+ from src.state import ResearchState, SessionContext
7
+
8
+ load_dotenv()
9
+ logger = logging.getLogger(__name__)
10
+
11
+ # ---------------------------------------------------------------------------
12
+ # LLM — shared across agents, lazy init
13
+ # ---------------------------------------------------------------------------
14
+ _llm: ChatGroq | None = None
15
+
16
+ def get_llm() -> ChatGroq:
17
+ global _llm
18
+ if _llm is None:
19
+ _llm = ChatGroq(model="llama-3.3-70b-versatile", temperature=0.2)
20
+ return _llm
21
+
22
+
23
+ # ---------------------------------------------------------------------------
24
+ # Planner system prompt
25
+ # ---------------------------------------------------------------------------
26
+ PLANNER_SYSTEM = """You are the Planner agent in RECON, a multi-agent ML research navigator.
27
+
28
+ Your job: decompose the user's research query into exactly 2-3 sub-questions.
29
+ Each sub-question must target a DIFFERENT temporal range:
30
+ - Foundational: seminal/classic work that established the field
31
+ - Recent advances: work from the last 2-3 years
32
+ - Open/contested: where the field actively disagrees or has open problems
33
+
34
+ Rules:
35
+ - Output ONLY a numbered list. No preamble, no explanation.
36
+ - Each sub-question must be self-contained and searchable on its own.
37
+ - If session context is provided, do NOT re-ask questions already answered.
38
+ - Keep each sub-question under 20 words.
39
+
40
+ Example output:
41
+ 1. What are the foundational methods for KV cache compression in transformers?
42
+ 2. What are the most effective KV cache compression techniques published in 2023-2024?
43
+ 3. What are the open challenges and contested approaches in KV cache compression?"""
44
+
45
+
46
+ # ---------------------------------------------------------------------------
47
+ # Planner node — called by LangGraph
48
+ # ---------------------------------------------------------------------------
49
+ def planner_node(state: ResearchState) -> ResearchState:
50
+ """
51
+ Reads: original_query, session_context
52
+ Writes: sub_questions
53
+ """
54
+ query = state["original_query"]
55
+ session_ctx: SessionContext = state.get("session_context") or SessionContext()
56
+
57
+ # Build context block for the LLM
58
+ context_block = ""
59
+ if session_ctx.prior_queries:
60
+ prior = "\n".join(f"- {q}" for q in session_ctx.prior_queries[-3:])
61
+ context_block = f"\n\nAlready answered in this session:\n{prior}\nDo not repeat these."
62
+
63
+ user_prompt = f"Research query: {query}{context_block}"
64
+
65
+ logger.info(f"Planner decomposing: {query[:60]}")
66
+
67
+ try:
68
+ response = get_llm().invoke([
69
+ SystemMessage(content=PLANNER_SYSTEM),
70
+ HumanMessage(content=user_prompt),
71
+ ])
72
+ raw = response.content.strip()
73
+ except Exception as e:
74
+ logger.error(f"Planner LLM call failed: {e}")
75
+ # Fallback: use the original query as a single sub-question
76
+ return {**state, "sub_questions": [query]}
77
+
78
+ # Parse numbered list
79
+ sub_questions = _parse_numbered_list(raw)
80
+
81
+ if not sub_questions:
82
+ logger.warning("Planner returned unparseable output, using raw query")
83
+ sub_questions = [query]
84
+
85
+ logger.info(f"Planner produced {len(sub_questions)} sub-questions")
86
+ for i, q in enumerate(sub_questions, 1):
87
+ logger.info(f" {i}. {q}")
88
+
89
+ return {**state, "sub_questions": sub_questions}
90
+
91
+
92
+ def _parse_numbered_list(text: str) -> list[str]:
93
+ """Parse '1. question\n2. question' into a list of strings."""
94
+ import re
95
+ lines = text.strip().split("\n")
96
+ questions = []
97
+ for line in lines:
98
+ line = line.strip()
99
+ if not line:
100
+ continue
101
+ # Match lines starting with a number and period/dot
102
+ match = re.match(r"^\d+[\.\)]\s*(.+)$", line)
103
+ if match:
104
+ q = match.group(1).strip()
105
+ if len(q) > 10: # skip very short lines
106
+ questions.append(q)
107
+ return questions[:3] # max 3 sub-questions
test_phase4.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys, logging
2
+ sys.path.insert(0, ".")
3
+ logging.basicConfig(level=logging.WARNING)
4
+
5
+ from src.state import ResearchState, SessionContext
6
+ from src.agents.planner import planner_node
7
+ from src.memory import init_db
8
+
9
+ init_db()
10
+ print("=== Phase 4: Planner Agent ===\n")
11
+
12
+ # Test 1: Fresh query, no session context
13
+ state: ResearchState = {
14
+ "original_query": "What is the current state of speculative decoding in LLMs?",
15
+ "session_id": "test-session-001",
16
+ "session_context": None,
17
+ "sub_questions": [],
18
+ "retrieved_papers": [],
19
+ "citation_graph": {},
20
+ "web_results": [],
21
+ "critic_verdict": "",
22
+ "critic_notes": "",
23
+ "rewritten_questions": [],
24
+ "retry_count": 0,
25
+ "synthesized_position": "",
26
+ "claim_confidences": [],
27
+ "session_update": None,
28
+ "export_md": "",
29
+ "decay_config": "linear",
30
+ "calibration_bin": "",
31
+ "latency_ms": 0.0,
32
+ }
33
+
34
+ result = planner_node(state)
35
+ print("Query: What is the current state of speculative decoding in LLMs?")
36
+ print(f"Sub-questions generated: {len(result['sub_questions'])}")
37
+ for i, q in enumerate(result['sub_questions'], 1):
38
+ print(f" {i}. {q}")
39
+
40
+ # Test 2: Query with session context (should avoid repeating)
41
+ print("\n--- With session context ---")
42
+ ctx = SessionContext(
43
+ prior_queries=["What is the current state of speculative decoding in LLMs?"],
44
+ prior_positions=["Speculative decoding reduces latency by 2-3x..."],
45
+ flagged_contradictions=[]
46
+ )
47
+ state2 = {**state,
48
+ "original_query": "What are the limitations of speculative decoding?",
49
+ "session_context": ctx,
50
+ }
51
+ result2 = planner_node(state2)
52
+ print("Query: What are the limitations of speculative decoding?")
53
+ print(f"Sub-questions generated: {len(result2['sub_questions'])}")
54
+ for i, q in enumerate(result2['sub_questions'], 1):
55
+ print(f" {i}. {q}")
56
+
57
+ print("\n✅ Phase 4 complete")