# AgenticAI-RAG / src/planning/cot_planner.py
"""Chain-of-Thought (CoT) planner implementation."""
import logging
from typing import List, Dict, Any, Optional, Callable
from src.core.config import get_settings
logger = logging.getLogger(__name__)
class CoTPlanner:
    """Chain-of-Thought planner for multi-step reasoning.

    Repeatedly prompts an LLM for one reasoning step at a time, optionally
    interleaving reflection steps, until the model emits a conclusion or
    ``max_steps`` iterations have run.
    """

    def __init__(
        self,
        max_steps: int = 10,
        enable_reflection: bool = True,
    ):
        """
        Initialize CoT planner.

        Args:
            max_steps: Maximum number of reasoning steps.
            enable_reflection: Whether to periodically reflect on progress.
        """
        self.max_steps = max_steps
        self.enable_reflection = enable_reflection

    def plan(
        self,
        query: str,
        context: Optional[str] = None,
        llm_call: Optional[Callable] = None,
    ) -> Dict[str, Any]:
        """
        Generate a plan using Chain-of-Thought methodology.

        Args:
            query: User query.
            context: Optional context information embedded in prompts.
            llm_call: Callable taking a prompt string and returning the
                LLM's text response. Required.

        Returns:
            Plan dictionary with the reasoning chain. Contains ``conclusion``
            (str) and ``total_steps`` when the model concluded; otherwise
            ``conclusion`` is None and a ``status="max_steps_reached"`` key
            is present.

        Raises:
            ValueError: If ``llm_call`` is not provided.
        """
        if not llm_call:
            raise ValueError("llm_call function is required")

        reasoning_steps: List[Dict[str, Any]] = []
        current_understanding = context or ""

        prompt = self._build_initial_prompt(query, context)

        for step_num in range(self.max_steps):
            try:
                response = llm_call(prompt)
                step = self._parse_cot_response(response, step_num)
                reasoning_steps.append(step)

                # Stop as soon as the model reaches a conclusion.
                if step.get("is_conclusion", False):
                    return {
                        "query": query,
                        "reasoning_steps": reasoning_steps,
                        "conclusion": step.get("content", ""),
                        "total_steps": step_num + 1,
                    }

                prompt = self._build_next_step_prompt(
                    query=query,
                    context=current_understanding,
                    previous_steps=reasoning_steps,
                    current_step=step,
                )

                # Every third step, optionally ask the LLM to reflect on
                # the chain so far; reflection is best-effort.
                if self.enable_reflection and step_num > 0 and step_num % 3 == 0:
                    reflection = self._reflect_on_progress(reasoning_steps, llm_call)
                    if reflection:
                        reasoning_steps.append({
                            "step": step_num + 1,
                            "type": "reflection",
                            "content": reflection,
                        })
            except Exception as e:
                # Record the failure as a step and keep going: the next
                # iteration retries with the most recently built prompt.
                logger.error(f"Error in CoT step {step_num}: {e}")
                reasoning_steps.append({
                    "step": step_num + 1,
                    "type": "error",
                    "content": f"Error: {str(e)}",
                })

        # No conclusion within the step budget.
        return {
            "query": query,
            "reasoning_steps": reasoning_steps,
            "conclusion": None,
            "total_steps": self.max_steps,
            "status": "max_steps_reached",
        }

    def _build_initial_prompt(
        self,
        query: str,
        context: Optional[str],
    ) -> str:
        """Build the first CoT prompt, optionally embedding context."""
        prompt_parts = [
            "You are a helpful assistant that uses Chain-of-Thought reasoning.",
            "Break down complex problems into smaller steps and reason through them step by step.",
            "",
            "Format your reasoning as:",
            "Step 1: <your reasoning>",
            "Step 2: <your reasoning>",
            "...",
            "Conclusion: <final answer>",
            "",
        ]
        if context:
            prompt_parts.extend([
                f"Context: {context}",
                "",
            ])
        prompt_parts.extend([
            f"Question: {query}",
            "",
            "Begin your reasoning:",
        ])
        return "\n".join(prompt_parts)

    def _build_next_step_prompt(
        self,
        query: str,
        context: str,
        previous_steps: List[Dict[str, Any]],
        current_step: Dict[str, Any],
    ) -> str:
        """Build the prompt for the next reasoning step.

        Only the last three steps are replayed (reflections excluded) to
        keep the prompt short.
        """
        prompt_parts = [
            "Continue your Chain-of-Thought reasoning.",
            "",
            f"Question: {query}",
            "",
        ]
        if context:
            prompt_parts.append(f"Context: {context}\n")
        prompt_parts.append("Previous reasoning steps:")
        for step in previous_steps[-3:]:  # Last 3 steps
            if step.get("type") != "reflection":
                prompt_parts.append(f"Step {step.get('step', '?')}: {step.get('content', '')}")
        prompt_parts.extend([
            "",
            "What is the next step in your reasoning?",
        ])
        return "\n".join(prompt_parts)

    def _parse_cot_response(
        self,
        response: str,
        step_num: int,
    ) -> Dict[str, Any]:
        """Parse an LLM response into a structured step dict.

        A response containing/starting with "Conclusion" terminates the
        chain; otherwise any "Step N:" prefix is stripped and the remainder
        becomes the step content.
        """
        response = response.strip()

        # A conclusion ends the chain.
        if "Conclusion:" in response:
            conclusion_part = response.split("Conclusion:")[-1].strip()
            return {
                "step": step_num + 1,
                "type": "conclusion",
                "content": conclusion_part,
                "is_conclusion": True,
            }
        if response.startswith("Conclusion"):
            # "Conclusion" with no colon: strip the keyword so it does not
            # leak into the conclusion text (previously the full response,
            # prefix included, was returned as the conclusion).
            conclusion_part = response[len("Conclusion"):].lstrip(" :").strip()
            return {
                "step": step_num + 1,
                "type": "conclusion",
                "content": conclusion_part,
                "is_conclusion": True,
            }

        # Strip an explicit "Step N:" prefix when present.
        step_content = response
        if response.startswith(f"Step {step_num + 1}:"):
            step_content = response.split(":", 1)[1].strip()
        elif "Step" in response:
            # Fall back to any step marker, even if its number is stale.
            parts = response.split("Step", 1)
            if len(parts) > 1:
                step_part = parts[1].split(":", 1)
                if len(step_part) > 1:
                    step_content = step_part[1].strip()

        return {
            "step": step_num + 1,
            "type": "reasoning",
            "content": step_content,
            "is_conclusion": False,
        }

    def _reflect_on_progress(
        self,
        steps: List[Dict[str, Any]],
        llm_call: Callable,
    ) -> Optional[str]:
        """Ask the LLM to reflect on the last few reasoning steps.

        Returns:
            The reflection text, or None when there are no steps or the LLM
            call fails (reflection is best-effort and must not abort planning).
        """
        if not steps:
            return None
        reflection_prompt = (
            "Review the following reasoning steps and provide a brief reflection:\n\n"
        )
        for step in steps[-3:]:
            reflection_prompt += f"Step {step.get('step', '?')}: {step.get('content', '')}\n"
        reflection_prompt += "\nReflection:"
        try:
            reflection = llm_call(reflection_prompt)
            return reflection.strip()
        except Exception as e:
            logger.error(f"Error in reflection: {e}")
            return None

    def generate_execution_plan(
        self,
        reasoning_steps: List[Dict[str, Any]],
    ) -> List[Dict[str, Any]]:
        """
        Generate an execution plan from reasoning steps.

        Keyword matching is intentionally simple: "search" maps to a search
        action, "calculate"/"compute" to a calculate action, and conclusion
        steps to a return_answer action; all other steps are skipped.

        Args:
            reasoning_steps: List of reasoning steps.

        Returns:
            List of execution actions.
        """
        execution_plan: List[Dict[str, Any]] = []
        for step in reasoning_steps:
            step_type = step.get("type")
            content = step.get("content", "")
            if step_type == "conclusion":
                execution_plan.append({
                    "action": "return_answer",
                    "content": content,
                })
            elif step_type == "reasoning":
                # Hoist the lowercase conversion; it was computed per keyword.
                lowered = content.lower()
                if "search" in lowered:
                    execution_plan.append({
                        "action": "search",
                        "reasoning": content,
                    })
                elif "calculate" in lowered or "compute" in lowered:
                    execution_plan.append({
                        "action": "calculate",
                        "reasoning": content,
                    })
        return execution_plan