cryogenic22 commited on
Commit
9524b20
·
verified ·
1 Parent(s): bbe7455

Create agents/planning_agent.py

Browse files
Files changed (1) hide show
  1. agents/planning_agent.py +215 -0
agents/planning_agent.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Dict, List, Optional, Any, Tuple
3
+ from langchain_core.prompts import ChatPromptTemplate
4
+ from langchain_core.output_parsers import StrOutputParser, JsonOutputParser
5
+ from langchain_core.runnables import RunnablePassthrough
6
+ from langchain_anthropic import ChatAnthropic
7
+ from pydantic import BaseModel, Field
8
+ import json
9
+
10
+ # Define task types and output schema
11
+ class AnalysisPlan(BaseModel):
12
+ """Planning agent output with analysis plan details"""
13
+ problem_statement: str = Field(description="Refined problem statement based on the alert")
14
+ required_data_sources: List[Dict[str, str]] = Field(
15
+ description="List of data sources needed with table name and purpose")
16
+ analysis_approaches: List[Dict[str, str]] = Field(
17
+ description="List of analytical approaches to be used with type and purpose")
18
+ tasks: List[Dict[str, Any]] = Field(
19
+ description="Ordered list of tasks to execute with dependencies")
20
+ expected_insights: List[str] = Field(
21
+ description="List of expected insights that would answer the problem")
22
+
23
+ class PlanningAgent:
24
+ """Agent responsible for planning the analysis workflow"""
25
+
26
+ def __init__(self):
27
+ """Initialize the planning agent with Claude API"""
28
+ # Set up Claude API client
29
+ api_key = os.getenv("ANTHROPIC_API_KEY")
30
+ if not api_key:
31
+ raise ValueError("ANTHROPIC_API_KEY not found in environment variables")
32
+
33
+ self.llm = ChatAnthropic(
34
+ model="claude-3-haiku-20240307",
35
+ anthropic_api_key=api_key,
36
+ temperature=0.1
37
+ )
38
+
39
+ # Create planning prompt
40
+ self.planning_prompt = ChatPromptTemplate.from_messages([
41
+ ("system", """You are an expert pharmaceutical analytics planning agent.
42
+ Your task is to create a detailed analysis plan to investigate sales anomalies.
43
+
44
+ For pharmaceutical sales analysis:
45
+ - Consider product performance, competitor activities, prescriber behavior
46
+ - Include geographic, temporal, and demographic dimensions in your analysis
47
+ - Consider both internal factors (supply, marketing) and external factors (market events, seasonality)
48
+
49
+ Your output should be a complete JSON-formatted analysis plan following this structure:
50
+ {
51
+ "problem_statement": "Clear definition of the problem to solve",
52
+ "required_data_sources": [
53
+ {"table": "sales", "purpose": "Core sales metrics analysis"},
54
+ ...
55
+ ],
56
+ "analysis_approaches": [
57
+ {"type": "time_series_decomposition", "purpose": "Separate trend from seasonality"},
58
+ ...
59
+ ],
60
+ "tasks": [
61
+ {
62
+ "id": 1,
63
+ "name": "Data acquisition",
64
+ "description": "Pull relevant data from sources",
65
+ "agent": "data_agent",
66
+ "dependencies": [],
67
+ "expected_output": "Cleaned datasets for analysis"
68
+ },
69
+ ...
70
+ ],
71
+ "expected_insights": [
72
+ "Primary factors contributing to sales decline",
73
+ ...
74
+ ]
75
+ }
76
+
77
+ Be thorough in your planning but focus on creating a practical analysis workflow.
78
+ Tasks should follow a logical sequence with proper dependencies.
79
+ """),
80
+ ("human", "{input}")
81
+ ])
82
+
83
+ # Set up the planning chain
84
+ self.planning_chain = (
85
+ {"input": RunnablePassthrough()}
86
+ | self.planning_prompt
87
+ | self.llm
88
+ | StrOutputParser()
89
+ )
90
+
91
+ def extract_json_from_text(self, text: str) -> Dict:
92
+ """Extract JSON from text that might contain additional content"""
93
+ try:
94
+ # First, try to parse the entire text as JSON
95
+ return json.loads(text)
96
+ except json.JSONDecodeError:
97
+ # If that fails, look for JSON block
98
+ import re
99
+ json_pattern = r'```json\s*([\s\S]*?)\s*```'
100
+ match = re.search(json_pattern, text)
101
+ if match:
102
+ try:
103
+ return json.loads(match.group(1))
104
+ except json.JSONDecodeError:
105
+ pass
106
+
107
+ # Try a more aggressive approach to find JSON-like content
108
+ json_pattern = r'({[\s\S]*})'
109
+ match = re.search(json_pattern, text)
110
+ if match:
111
+ try:
112
+ return json.loads(match.group(1))
113
+ except json.JSONDecodeError:
114
+ pass
115
+
116
+ raise ValueError(f"Could not extract JSON from response: {text}")
117
+
118
+ def create_analysis_plan(self, alert_description: str) -> Tuple[AnalysisPlan, Dict]:
119
+ """Generate an analysis plan based on the alert description"""
120
+ print("Planning Agent: Creating analysis plan...")
121
+
122
+ # Format the input for the planning prompt
123
+ input_text = f"""
124
+ Alert: {alert_description}
125
+
126
+ Create a detailed analysis plan to investigate this issue. Include:
127
+ 1. A clear problem statement
128
+ 2. Required data sources from our pharma database
129
+ 3. Analytical approaches to identify root causes
130
+ 4. A sequence of tasks with dependencies
131
+ 5. Expected insights that would solve the problem
132
+
133
+ Available data tables:
134
+ - sales: Daily sales data (sale_date, product_id, region_id, territory_id, prescriber_id, pharmacy_id, units_sold, revenue, cost, margin)
135
+ - products: Product information (product_id, product_name, therapeutic_area, molecule, launch_date, status, list_price)
136
+ - regions: Geographic regions (region_id, region_name, country, division, population)
137
+ - territories: Sales territories (territory_id, territory_name, region_id, sales_rep_id)
138
+ - prescribers: Physician information (prescriber_id, name, specialty, practice_type, territory_id, decile)
139
+ - pharmacies: Pharmacy information (pharmacy_id, name, address, territory_id, pharmacy_type, monthly_rx_volume)
140
+ - competitor_products: Competitor information (competitor_product_id, product_name, manufacturer, therapeutic_area, molecule, launch_date, list_price, competing_with_product_id)
141
+ - marketing_campaigns: Marketing activities (campaign_id, campaign_name, start_date, end_date, product_id, campaign_type, target_audience, channels, budget, spend)
142
+ - market_events: Industry events (event_id, event_date, event_type, description, affected_products, affected_regions, impact_score)
143
+ - sales_targets: Performance targets (target_id, product_id, region_id, period, target_units, target_revenue)
144
+ - distribution_centers: Supply chain (dc_id, dc_name, region_id, inventory_capacity)
145
+ - inventory: Stock levels (inventory_id, product_id, dc_id, date, units_available, units_allocated, units_in_transit, days_of_supply)
146
+ - external_factors: External influences (factor_id, date, region_id, factor_type, factor_value, description)
147
+ """
148
+
149
+ # Execute the planning chain
150
+ response = self.planning_chain.invoke(input_text)
151
+
152
+ # Extract and parse the response as JSON
153
+ plan_dict = self.extract_json_from_text(response)
154
+
155
+ # Convert to Pydantic model for validation and structure
156
+ analysis_plan = AnalysisPlan.model_validate(plan_dict)
157
+
158
+ return analysis_plan, plan_dict
159
+
160
+ def visualize_plan(self, plan: AnalysisPlan) -> Dict:
161
+ """Generate visualization data for the analysis plan"""
162
+ # Create nodes representing tasks
163
+ nodes = []
164
+ edges = []
165
+
166
+ for task in plan.tasks:
167
+ nodes.append({
168
+ "id": f"task_{task['id']}",
169
+ "label": task['name'],
170
+ "type": "task",
171
+ "agent": task['agent']
172
+ })
173
+
174
+ # Create edges based on dependencies
175
+ for dep in task.get('dependencies', []):
176
+ edges.append({
177
+ "source": f"task_{dep}",
178
+ "target": f"task_{task['id']}",
179
+ "label": "depends on"
180
+ })
181
+
182
+ # Add data source nodes
183
+ for i, src in enumerate(plan.required_data_sources):
184
+ src_id = f"data_{i}"
185
+ nodes.append({
186
+ "id": src_id,
187
+ "label": src['table'],
188
+ "type": "data_source"
189
+ })
190
+
191
+ # Connect data sources to the data acquisition task
192
+ data_task = next((t for t in plan.tasks if t['agent'] == 'data_agent'), None)
193
+ if data_task:
194
+ edges.append({
195
+ "source": src_id,
196
+ "target": f"task_{data_task['id']}",
197
+ "label": "input"
198
+ })
199
+
200
+ return {
201
+ "nodes": nodes,
202
+ "edges": edges,
203
+ "problem_statement": plan.problem_statement,
204
+ "expected_insights": plan.expected_insights
205
+ }
206
+
207
+ # For testing
208
+ if __name__ == "__main__":
209
+ # Set API key for testing
210
+ os.environ["ANTHROPIC_API_KEY"] = "your_api_key_here"
211
+
212
+ agent = PlanningAgent()
213
+ alert = "Sales of DrugX down 15% in Northeast region over past 30 days compared to forecast."
214
+ plan, _ = agent.create_analysis_plan(alert)
215
+ print(json.dumps(plan.model_dump(), indent=2))