cryogenic22 commited on
Commit
c310b6b
·
verified ·
1 Parent(s): 218251e

Update agents/planning_agent.py

Browse files
Files changed (1) hide show
  1. agents/planning_agent.py +77 -99
agents/planning_agent.py CHANGED
@@ -1,13 +1,15 @@
 
 
 
 
 
1
  import os
2
- from typing import Dict, List, Optional, Any, Tuple
3
- from langchain_core.prompts import ChatPromptTemplate
4
- from langchain_core.output_parsers import StrOutputParser, JsonOutputParser
5
- from langchain_core.runnables import RunnablePassthrough
6
- from langchain_anthropic import ChatAnthropic
7
- from pydantic import BaseModel, Field
8
  import json
 
 
 
9
 
10
- # Define task types and output schema
11
  class AnalysisPlan(BaseModel):
12
  """Planning agent output with analysis plan details"""
13
  problem_statement: str = Field(description="Refined problem statement based on the alert")
@@ -24,21 +26,20 @@ class PlanningAgent:
24
  """Agent responsible for planning the analysis workflow"""
25
 
26
  def __init__(self):
27
- """Initialize the planning agent with Claude API"""
28
- # Set up Claude API client
29
  api_key = os.getenv("ANTHROPIC_API_KEY")
30
  if not api_key:
31
  raise ValueError("ANTHROPIC_API_KEY not found in environment variables")
32
 
33
- self.llm = ChatAnthropic(
34
- model="claude-3-haiku-20240307", # Changed to haiku to use less tokens
35
- anthropic_api_key=api_key,
36
- temperature=0.1
37
- )
 
38
 
39
- # Create planning prompt - fixed variable issue in the template
40
- self.planning_prompt = ChatPromptTemplate.from_messages([
41
- ("system", """You are an expert pharmaceutical analytics planning agent.
42
  Your task is to create a detailed analysis plan to investigate sales anomalies.
43
 
44
  For pharmaceutical sales analysis:
@@ -81,28 +82,62 @@ Your output should be a complete JSON-formatted analysis plan following this str
81
  ]
82
  }
83
 
84
- Be thorough in your planning but focus on creating a practical analysis workflow.
85
- Tasks should follow a logical sequence with proper dependencies.
86
- """),
87
- ("human", "{input}") # Changed from {alert} to {input}
88
- ])
89
 
90
- # Set up the planning chain
91
- self.planning_chain = (
92
- {"input": RunnablePassthrough()} # Use input as the key
93
- | self.planning_prompt
94
- | self.llm
95
- | StrOutputParser()
96
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
  def extract_json_from_text(self, text: str) -> Dict:
99
  """Extract JSON from text that might contain additional content"""
100
  try:
101
- # First, try to parse the entire text as JSON
102
  return json.loads(text)
103
  except json.JSONDecodeError:
104
- # If that fails, look for JSON block
105
- import re
106
  json_pattern = r'```json\s*([\s\S]*?)\s*```'
107
  match = re.search(json_pattern, text)
108
  if match:
@@ -111,7 +146,7 @@ Tasks should follow a logical sequence with proper dependencies.
111
  except json.JSONDecodeError:
112
  pass
113
 
114
- # Try a more aggressive approach to find JSON-like content
115
  json_pattern = r'({[\s\S]*})'
116
  match = re.search(json_pattern, text)
117
  if match:
@@ -120,70 +155,13 @@ Tasks should follow a logical sequence with proper dependencies.
120
  except json.JSONDecodeError:
121
  pass
122
 
 
123
  raise ValueError(f"Could not extract JSON from response: {text}")
124
-
125
- def create_analysis_plan(self, alert_description: str) -> Tuple[AnalysisPlan, Dict]:
126
- """Generate an analysis plan based on the alert description"""
127
- print("Planning Agent: Creating analysis plan...")
128
-
129
- try:
130
- # Execute the planning chain with the alert as input
131
- response = self.planning_chain.invoke(alert_description)
132
-
133
- # Extract and parse the response as JSON
134
- plan_dict = self.extract_json_from_text(response)
135
-
136
- # Convert to Pydantic model for validation and structure
137
- analysis_plan = AnalysisPlan.model_validate(plan_dict)
138
-
139
- return analysis_plan, plan_dict
140
- except Exception as e:
141
- print(f"Error creating analysis plan: {e}")
142
- raise
143
-
144
- def visualize_plan(self, plan: AnalysisPlan) -> Dict:
145
- """Generate visualization data for the analysis plan"""
146
- # Create nodes representing tasks
147
- nodes = []
148
- edges = []
149
-
150
- for task in plan.tasks:
151
- nodes.append({
152
- "id": f"task_{task['id']}",
153
- "label": task['name'],
154
- "type": "task",
155
- "agent": task['agent']
156
- })
157
-
158
- # Create edges based on dependencies
159
- for dep in task.get('dependencies', []):
160
- edges.append({
161
- "source": f"task_{dep}",
162
- "target": f"task_{task['id']}",
163
- "label": "depends on"
164
- })
165
-
166
- # Add data source nodes
167
- for i, src in enumerate(plan.required_data_sources):
168
- src_id = f"data_{i}"
169
- nodes.append({
170
- "id": src_id,
171
- "label": src['table'],
172
- "type": "data_source"
173
- })
174
-
175
- # Connect data sources to the data acquisition task
176
- data_task = next((t for t in plan.tasks if t['agent'] == 'data_agent'), None)
177
- if data_task:
178
- edges.append({
179
- "source": src_id,
180
- "target": f"task_{data_task['id']}",
181
- "label": "input"
182
- })
183
-
184
- return {
185
- "nodes": nodes,
186
- "edges": edges,
187
- "problem_statement": plan.problem_statement,
188
- "expected_insights": plan.expected_insights
189
- }
 
1
+ """
2
+ Simplified Planning Agent for Pharmaceutical Analytics
3
+ This version uses direct API calls instead of LangChain components
4
+ """
5
+
6
  import os
 
 
 
 
 
 
7
  import json
8
+ import re
9
+ from typing import Dict, List, Any, Tuple
10
+ from pydantic import BaseModel, Field
11
 
12
+ # Define analysis plan schema
13
  class AnalysisPlan(BaseModel):
14
  """Planning agent output with analysis plan details"""
15
  problem_statement: str = Field(description="Refined problem statement based on the alert")
 
26
  """Agent responsible for planning the analysis workflow"""
27
 
28
  def __init__(self):
29
+ """Initialize the planning agent"""
 
30
  api_key = os.getenv("ANTHROPIC_API_KEY")
31
  if not api_key:
32
  raise ValueError("ANTHROPIC_API_KEY not found in environment variables")
33
 
34
+ self.api_key = api_key
35
+ print("Planning Agent initialized successfully")
36
+
37
+ def create_analysis_plan(self, alert_description: str) -> Tuple[AnalysisPlan, Dict]:
38
+ """Generate an analysis plan based on the alert description"""
39
+ print("Planning Agent: Creating analysis plan...")
40
 
41
+ # Create the system prompt and user message
42
+ system_prompt = """You are an expert pharmaceutical analytics planning agent.
 
43
  Your task is to create a detailed analysis plan to investigate sales anomalies.
44
 
45
  For pharmaceutical sales analysis:
 
82
  ]
83
  }
84
 
85
+ Be thorough but focus on creating a practical analysis workflow.
86
+ """
87
+
88
+ user_message = f"Create an analysis plan for the following alert: {alert_description}"
 
89
 
90
+ # Make direct API call to Claude
91
+ try:
92
+ import anthropic
93
+ client = anthropic.Anthropic(api_key=self.api_key)
94
+
95
+ # Use the correct API structure based on the Anthropic Python SDK version
96
+ try:
97
+ # For newer versions of the Anthropic SDK
98
+ response = client.messages.create(
99
+ model="claude-3-haiku-20240307",
100
+ max_tokens=2000,
101
+ temperature=0.2,
102
+ system=system_prompt,
103
+ messages=[
104
+ {"role": "user", "content": user_message}
105
+ ]
106
+ )
107
+ except TypeError:
108
+ # Fallback for older versions of the Anthropic SDK
109
+ response = client.messages.create(
110
+ model="claude-3-haiku-20240307",
111
+ max_tokens=2000,
112
+ temperature=0.2,
113
+ messages=[
114
+ {"role": "system", "content": system_prompt},
115
+ {"role": "user", "content": user_message}
116
+ ]
117
+ )
118
+
119
+ # Extract response content
120
+ response_text = response.content[0].text
121
+
122
+ # Extract JSON from the response
123
+ plan_dict = self.extract_json_from_text(response_text)
124
+
125
+ # Convert to Pydantic model for validation
126
+ analysis_plan = AnalysisPlan.model_validate(plan_dict)
127
+
128
+ return analysis_plan, plan_dict
129
+
130
+ except Exception as e:
131
+ print(f"Error creating analysis plan: {e}")
132
+ raise
133
 
134
  def extract_json_from_text(self, text: str) -> Dict:
135
  """Extract JSON from text that might contain additional content"""
136
  try:
137
+ # First try to parse the entire text as JSON
138
  return json.loads(text)
139
  except json.JSONDecodeError:
140
+ # Try to find JSON block with regex
 
141
  json_pattern = r'```json\s*([\s\S]*?)\s*```'
142
  match = re.search(json_pattern, text)
143
  if match:
 
146
  except json.JSONDecodeError:
147
  pass
148
 
149
+ # Try to find anything that looks like JSON
150
  json_pattern = r'({[\s\S]*})'
151
  match = re.search(json_pattern, text)
152
  if match:
 
155
  except json.JSONDecodeError:
156
  pass
157
 
158
+ # If all extraction attempts fail
159
  raise ValueError(f"Could not extract JSON from response: {text}")
160
+
161
+ # For testing
162
+ if __name__ == "__main__":
163
+ # Get API key from environment
164
+ agent = PlanningAgent()
165
+ alert = "Sales of DrugX down 15% in Northeast region over past 30 days compared to forecast."
166
+ plan, plan_dict = agent.create_analysis_plan(alert)
167
+ print(json.dumps(plan_dict, indent=2))