cryogenic22 commited on
Commit
72ca2bc
·
verified ·
1 Parent(s): ca8485b

Create insight_agent.py

Browse files
Files changed (1) hide show
  1. agents/insight_agent.py +419 -0
agents/insight_agent.py ADDED
@@ -0,0 +1,419 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import pandas as pd
4
+ import numpy as np
5
+ from typing import Dict, List, Any, Tuple, Optional
6
+ from pydantic import BaseModel, Field
7
+ from langchain_anthropic import ChatAnthropic
8
+ from langchain_core.prompts import ChatPromptTemplate
9
+ from langchain_core.output_parsers import StrOutputParser
10
+ import re
11
+ from datetime import datetime
12
+
13
+ class InsightRequest(BaseModel):
14
+ """Structure for an insight generation request"""
15
+ request_id: str
16
+ original_problem: str
17
+ analysis_results: Dict[str, Any]
18
+ validation_results: Dict[str, Any]
19
+ target_audience: str = "executive" # Options: executive, analyst, data scientist
20
+
21
+ class InsightCard(BaseModel):
22
+ """Structure for an insight card"""
23
+ card_id: str
24
+ title: str
25
+ description: str
26
+ key_findings: List[Dict[str, Any]]
27
+ charts: List[str] = None
28
+ metrics: Dict[str, Any] = None
29
+ action_items: List[Dict[str, Any]] = None
30
+ confidence: float
31
+ timestamp: datetime
32
+
33
+ class InsightsAgent:
34
+ """Agent responsible for generating insight cards and visualizations"""
35
+
36
+ def __init__(self):
37
+ """Initialize the insights agent"""
38
+ # Set up Claude API client
39
+ api_key = os.getenv("ANTHROPIC_API_KEY")
40
+ if not api_key:
41
+ raise ValueError("ANTHROPIC_API_KEY not found in environment variables")
42
+
43
+ self.llm = ChatAnthropic(
44
+ model="claude-3-haiku-20240307",
45
+ anthropic_api_key=api_key,
46
+ temperature=0.2
47
+ )
48
+
49
+ # Create insight generation prompt
50
+ self.insight_prompt = ChatPromptTemplate.from_messages([
51
+ ("system", """You are an expert pharmaceutical analytics insights generator.
52
+ Your task is to create clear, actionable insights from analysis results.
53
+
54
+ For each insight request:
55
+ 1. Synthesize analysis findings into clear, concise insights
56
+ 2. Prioritize insights based on business impact
57
+ 3. Tailor communication style to the target audience
58
+ 4. Suggest concrete action items based on the findings
59
+ 5. Present balanced view including confidence levels and limitations
60
+
61
+ Output your insights in JSON format with the following structure:
62
+ ```json
63
+ {
64
+ "title": "DrugX Sales Decline Analysis",
65
+ "description": "Analysis of the 15% sales decline in the Northeast region",
66
+ "key_findings": [
67
+ {
68
+ "finding": "Competitor Launch Impact",
69
+ "details": "The launch of CompDrug2 by MedCorp 45 days ago has captured approximately 60% of our market share in the Northeast region.",
70
+ "evidence": "Strong correlation between sales decline and competitor sales growth, with 85% confidence.",
71
+ "impact": "Estimated $2.4M quarterly revenue impact"
72
+ },
73
+ {
74
+ "finding": "Supply Chain Issues",
75
+ "details": "Inventory shortages at 3 distribution centers in the Northeast have led to unfilled orders.",
76
+ "evidence": "25% of pharmacies experienced stockouts in the last 30 days.",
77
+ "impact": "Estimated $1.0M quarterly revenue impact"
78
+ },
79
+ {
80
+ "finding": "Seasonal Factors",
81
+ "details": "Normal seasonal variation accounts for a portion of the observed decline.",
82
+ "evidence": "Historical patterns show 5-7% seasonal decline in this period.",
83
+ "impact": "Estimated $0.6M quarterly revenue impact"
84
+ }
85
+ ],
86
+ "charts": [
87
+ "sales_trend_chart",
88
+ "competitor_comparison_chart",
89
+ "supply_chain_impact_chart"
90
+ ],
91
+ "metrics": {
92
+ "total_impact": "$4.0M quarterly",
93
+ "market_share_loss": "8.5 percentage points",
94
+ "affected_prescribers": "217 out of 934 (23%)",
95
+ "affected_territories": "3 out of 4 Northeast territories"
96
+ },
97
+ "action_items": [
98
+ {
99
+ "action": "Launch targeted co-pay program",
100
+ "owner": "Marketing",
101
+ "timeline": "Immediate (0-15 days)",
102
+ "expected_impact": "Recover 30-40% of lost prescriptions",
103
+ "priority": "High"
104
+ },
105
+ {
106
+ "action": "Resolve supply chain bottlenecks",
107
+ "owner": "Operations",
108
+ "timeline": "Short-term (15-45 days)",
109
+ "expected_impact": "Eliminate 90% of stockouts",
110
+ "priority": "High"
111
+ },
112
+ {
113
+ "action": "Develop competitive response strategy",
114
+ "owner": "Commercial Strategy",
115
+ "timeline": "Medium-term (30-90 days)",
116
+ "expected_impact": "Position for market share recovery",
117
+ "priority": "Medium"
118
+ }
119
+ ],
120
+ "confidence": 0.85
121
+ }
122
+ ```
123
+
124
+ Adapt your insights to the target audience:
125
+ - For executives: Focus on business impact, actions, and strategic implications
126
+ - For analysts: Include more detailed findings and evidence
127
+ - For data scientists: Add methodological details and statistical significance
128
+
129
+ Be concise but comprehensive, highlighting the most important insights first.
130
+ """),
131
+ ("human", """
132
+ Original Problem Statement: {original_problem}
133
+
134
+ Analysis Results:
135
+ {analysis_results}
136
+
137
+ Validation Results:
138
+ {validation_results}
139
+
140
+ Target Audience: {target_audience}
141
+
142
+ Please generate actionable insights based on these results.
143
+ """)
144
+ ])
145
+
146
+ # Set up the insight generation chain
147
+ self.insight_chain = (
148
+ self.insight_prompt
149
+ | self.llm
150
+ | StrOutputParser()
151
+ )
152
+
153
+ # Create visualization prompt
154
+ self.visualization_prompt = ChatPromptTemplate.from_messages([
155
+ ("system", """You are an expert data visualization designer specializing in pharmaceutical analytics.
156
+ Your task is to generate Python code to create clear, insightful visualizations based on analysis results.
157
+
158
+ For each visualization request:
159
+ 1. Create professional, publication-quality visualizations
160
+ 2. Choose appropriate chart types for the data and insights
161
+ 3. Use a consistent color scheme and styling
162
+ 4. Add clear labels, titles, and annotations
163
+ 5. Focus on communicating the key insights effectively
164
+
165
+ The visualizations should tell a compelling story about the data.
166
+ Make sure to include all the necessary code for styling and formatting.
167
+
168
+ Format your response with a code block:
169
+ ```python
170
+ # Visualization code
171
+ import pandas as pd
172
+ import numpy as np
173
+ import matplotlib.pyplot as plt
174
+ import seaborn as sns
175
+
176
+ def create_visualizations(data_sources):
177
+ # Your visualization code here
178
+ # Create multiple figures as needed
179
+
180
+ # Return a list of figure objects
181
+ return [fig1, fig2, fig3]
182
+ ```
183
+
184
+ The code should be complete and ready to execute with the provided data sources.
185
+ """),
186
+ ("human", """
187
+ Visualization Request: {description}
188
+
189
+ Key Insights:
190
+ {key_insights}
191
+
192
+ Available data sources:
193
+ {data_sources}
194
+
195
+ Target audience: {target_audience}
196
+
197
+ Please generate Python code to create visualizations for these insights.
198
+ """)
199
+ ])
200
+
201
+ # Set up the visualization chain
202
+ self.visualization_chain = (
203
+ self.visualization_prompt
204
+ | self.llm
205
+ | StrOutputParser()
206
+ )
207
+
208
+ def extract_json_from_response(self, response: str) -> Dict:
209
+ """Extract JSON from text that might contain additional content"""
210
+ try:
211
+ # First, try to parse the entire text as JSON
212
+ return json.loads(response)
213
+ except json.JSONDecodeError:
214
+ # If that fails, look for JSON block
215
+ import re
216
+ json_pattern = r'```json\s*([\s\S]*?)\s*```'
217
+ match = re.search(json_pattern, response, re.DOTALL)
218
+ if match:
219
+ try:
220
+ return json.loads(match.group(1))
221
+ except json.JSONDecodeError:
222
+ pass
223
+
224
+ # Try a more aggressive approach to find JSON-like content
225
+ json_pattern = r'({[\s\S]*})'
226
+ match = re.search(json_pattern, response)
227
+ if match:
228
+ try:
229
+ return json.loads(match.group(1))
230
+ except json.JSONDecodeError:
231
+ pass
232
+
233
+ raise ValueError(f"Could not extract JSON from response: {response}")
234
+
235
+ def extract_python_from_response(self, response: str) -> str:
236
+ """Extract Python code from LLM response"""
237
+ # Extract Python between ```python and ``` markers
238
+ python_match = re.search(r'```python\s*(.*?)\s*```', response, re.DOTALL)
239
+ if python_match:
240
+ return python_match.group(1).strip()
241
+
242
+ # If not found with python tag, try generic code block
243
+ python_match = re.search(r'```\s*(.*?)\s*```', response, re.DOTALL)
244
+ if python_match:
245
+ return python_match.group(1).strip()
246
+
247
+ # If all else fails, return empty string
248
+ return ""
249
+
250
+ def generate_insights(self, request: InsightRequest) -> InsightCard:
251
+ """Generate insights based on analysis and validation results"""
252
+ print(f"Insights Agent: Generating insights for problem: {request.original_problem}")
253
+
254
+ # Format analysis results for the prompt
255
+ analysis_results_str = json.dumps(request.analysis_results, indent=2)
256
+
257
+ # Format validation results for the prompt
258
+ validation_results_str = json.dumps(request.validation_results, indent=2)
259
+
260
+ # Format the request for the prompt
261
+ request_data = {
262
+ "original_problem": request.original_problem,
263
+ "analysis_results": analysis_results_str,
264
+ "validation_results": validation_results_str,
265
+ "target_audience": request.target_audience
266
+ }
267
+
268
+ # Generate insights
269
+ response = self.insight_chain.invoke(request_data)
270
+
271
+ # Extract and parse insights JSON
272
+ insights_dict = self.extract_json_from_response(response)
273
+
274
+ # Add missing fields
275
+ insights_dict["card_id"] = f"insight_{request.request_id}"
276
+ insights_dict["timestamp"] = datetime.now().isoformat()
277
+
278
+ # Ensure confidence exists
279
+ if "confidence" not in insights_dict:
280
+ # Use validation score if available, otherwise default to 0.7
281
+ insights_dict["confidence"] = request.validation_results.get("validation_score", 0.7)
282
+
283
+ return InsightCard(**insights_dict)
284
+
285
+ def generate_visualizations(self, insight_card: InsightCard, data_sources: Dict[str, Any]) -> List[str]:
286
+ """Generate visualizations based on insights"""
287
+ print(f"Insights Agent: Generating visualizations for insight card: {insight_card.title}")
288
+
289
+ # Extract key insights for visualization context
290
+ key_insights_str = json.dumps(insight_card.key_findings, indent=2)
291
+
292
+ # Format data sources description for the prompt
293
+ data_sources_desc = ""
294
+ for source_id, source in data_sources.items():
295
+ df = source.content
296
+ data_sources_desc += f"Data source '{source_id}' ({source.name}):\n"
297
+ data_sources_desc += f"- Shape: {df.shape[0]} rows, {df.shape[1]} columns\n"
298
+ data_sources_desc += f"- Columns: {', '.join(df.columns)}\n"
299
+ data_sources_desc += f"- Sample data:\n{df.head(3).to_string()}\n\n"
300
+
301
+ # Format the request for the prompt
302
+ request_data = {
303
+ "description": insight_card.title,
304
+ "key_insights": key_insights_str,
305
+ "data_sources": data_sources_desc,
306
+ "target_audience": "executive" # Default to executive-level visualizations
307
+ }
308
+
309
+ # Generate visualization code
310
+ response = self.visualization_chain.invoke(request_data)
311
+
312
+ # Extract Python code
313
+ python_code = self.extract_python_from_response(response)
314
+
315
+ # Execute visualization code (with safety checks)
316
+ visualization_files = []
317
+
318
+ if not python_code:
319
+ print("Warning: No visualization code generated.")
320
+ else:
321
+ try:
322
+ # Prepare data sources for the visualizations
323
+ viz_data_sources = {src_id: src.content for src_id, src in data_sources.items()}
324
+
325
+ # Create a local namespace with access to pandas, numpy, etc.
326
+ local_namespace = {
327
+ "pd": pd,
328
+ "np": np,
329
+ "plt": plt,
330
+ "sns": sns,
331
+ "data_sources": viz_data_sources
332
+ }
333
+
334
+ # Execute the code
335
+ exec(python_code, local_namespace)
336
+
337
+ # Look for a create_visualizations function and execute it
338
+ if "create_visualizations" in local_namespace:
339
+ figures = local_namespace["create_visualizations"](viz_data_sources)
340
+
341
+ # Save figures to files
342
+ for i, fig in enumerate(figures):
343
+ if hasattr(fig, 'savefig'):
344
+ fig_filename = f"viz_{insight_card.card_id}_{i}.png"
345
+ fig.savefig(fig_filename, dpi=300, bbox_inches='tight')
346
+ visualization_files.append(fig_filename)
347
+
348
+ except Exception as e:
349
+ print(f"Visualization execution error: {e}")
350
+
351
+ return visualization_files
352
+
353
+ # For testing
354
+ if __name__ == "__main__":
355
+ import matplotlib.pyplot as plt
356
+ import seaborn as sns
357
+
358
+ # Set API key for testing
359
+ os.environ["ANTHROPIC_API_KEY"] = "your_api_key_here"
360
+
361
+ # Create mock insight request
362
+ class MockInsightRequest:
363
+ def __init__(self):
364
+ self.request_id = "test"
365
+ self.original_problem = "Sales of DrugX down 15% in Northeast region over past 30 days"
366
+ self.analysis_results = {
367
+ "insights": [
368
+ {"finding": "Competitor launch impact", "details": "New competing drug launched", "impact": "Estimated 60% of decline"},
369
+ {"finding": "Supply chain issues", "details": "Inventory shortages in key distribution centers", "impact": "Estimated 25% of decline"}
370
+ ],
371
+ "attribution": {
372
+ "competitor_launch": 0.60,
373
+ "supply_issues": 0.25,
374
+ "seasonal_factors": 0.15
375
+ },
376
+ "confidence": 0.85
377
+ }
378
+ self.validation_results = {
379
+ "validation_score": 0.82,
380
+ "critical_issues": [],
381
+ "recommendations": ["Consider analyzing prescriber-level data"]
382
+ }
383
+ self.target_audience = "executive"
384
+
385
+ # Create mock data sources
386
+ from dataclasses import dataclass
387
+
388
+ @dataclass
389
+ class MockDataSource:
390
+ content: pd.DataFrame
391
+ name: str
392
+
393
+ sales_df = pd.DataFrame({
394
+ 'date': pd.date_range(start='2023-01-01', periods=12, freq='M'),
395
+ 'region': ['Northeast'] * 12,
396
+ 'sales': [100, 110, 105, 115, 120, 115, 110, 105, 95, 85, 80, 70],
397
+ 'target': [100, 105, 110, 115, 120, 125, 130, 135, 140, 145, 150, 155]
398
+ })
399
+
400
+ competitor_df = pd.DataFrame({
401
+ 'date': pd.date_range(start='2023-10-01', periods=3, freq='M'),
402
+ 'competitor': ['CompDrug2'] * 3,
403
+ 'launch_region': ['Northeast'] * 3,
404
+ 'estimated_sales': [0, 50, 70]
405
+ })
406
+
407
+ data_sources = {
408
+ "sales_data": MockDataSource(content=sales_df, name="Monthly sales data"),
409
+ "competitor_data": MockDataSource(content=competitor_df, name="Competitor launch data")
410
+ }
411
+
412
+ agent = InsightsAgent()
413
+ insight_card = agent.generate_insights(MockInsightRequest())
414
+ print(f"Insight card title: {insight_card.title}")
415
+ print(f"Key findings: {json.dumps(insight_card.key_findings, indent=2)}")
416
+ print(f"Action items: {json.dumps(insight_card.action_items, indent=2)}")
417
+
418
+ visualizations = agent.generate_visualizations(insight_card, data_sources)
419
+ print(f"Generated visualizations: {visualizations}")