# sales_analytics/agents/qa_agent.py
import os
import json
import pandas as pd
import numpy as np
from typing import Dict, List, Any, Tuple, Optional
from pydantic import BaseModel, Field
from langchain_anthropic import ChatAnthropic
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser, JsonOutputParser
import re
from datetime import datetime
class ValidationRequest(BaseModel):
    """Structure for a validation request.

    Bundles everything the QA agent needs to judge one analysis run.
    """
    request_id: str  # caller-supplied id; echoed back as "validation_<request_id>" in the result
    original_problem: str  # the problem statement the analysis set out to answer
    analysis_results: Dict[str, Any]  # raw analysis output to be validated (JSON-serializable)
    data_sources: List[str]  # ids of the data sources the analysis drew on
class ValidationResult(BaseModel):
    """Structure for validation results."""
    result_id: str  # "validation_<request_id>" (set by QAAgent.validate_analysis)
    # All four scores are constrained by pydantic to the closed interval [0.0, 1.0].
    validation_score: float = Field(ge=0.0, le=1.0)  # overall validation score
    data_quality_score: float = Field(ge=0.0, le=1.0)  # completeness/relevance of the data
    analysis_quality_score: float = Field(ge=0.0, le=1.0)  # methodology rigor
    insight_quality_score: float = Field(ge=0.0, le=1.0)  # how well insights answer the problem
    validation_checks: List[Dict[str, Any]]  # per-check dicts: check/result/details/score
    recommendations: List[str]  # suggested follow-ups or improvements
    critical_issues: List[str]  # issues that could invalidate the findings (may be empty)
    timestamp: datetime  # when the validation was produced (naive local time from datetime.now())
class QAAgent:
    """Agent responsible for quality assurance and validation.

    Wraps a Claude model behind a LangChain prompt/parse pipeline: callers
    submit a ValidationRequest plus the raw data sources, and get back a
    scored, structured ValidationResult.
    """

    def __init__(self):
        """Initialize the QA agent.

        Raises:
            ValueError: if ANTHROPIC_API_KEY is not set in the environment.
        """
        # Set up Claude API client
        api_key = os.getenv("ANTHROPIC_API_KEY")
        if not api_key:
            raise ValueError("ANTHROPIC_API_KEY not found in environment variables")
        self.llm = ChatAnthropic(
            model="claude-3-7-sonnet-20250219",
            anthropic_api_key=api_key,
            temperature=0.1,  # low temperature: validation should be near-deterministic
        )
        # Create validation prompt.
        # NOTE: ChatPromptTemplate defaults to f-string templating, so the
        # literal braces of the JSON example below are escaped as {{ }}.
        # Only {original_problem}, {analysis_results} and {data_sources}
        # are real template variables.
        self.validation_prompt = ChatPromptTemplate.from_messages([
            ("system", """You are an expert pharmaceutical analytics validator.
Your task is to thoroughly validate analysis results to ensure quality and accuracy.
For each validation request:
1. Assess data quality (completeness, relevance, potential biases)
2. Evaluate analysis methodology (appropriateness, statistical rigor)
3. Verify that insights address the original problem statement
4. Check for potential alternative explanations
5. Identify any critical issues that could invalidate findings
Output your validation in JSON format with the following structure:
```json
{{
"validation_score": 0.85, # Overall validation score (0.0-1.0)
"data_quality_score": 0.9, # Data quality score (0.0-1.0)
"analysis_quality_score": 0.8, # Analysis methodology score (0.0-1.0)
"insight_quality_score": 0.85, # Quality of insights score (0.0-1.0)
"validation_checks": [
{{
"check": "Data completeness",
"result": "PASS",
"details": "All required data appears to be present",
"score": 1.0
}},
{{
"check": "Methodology appropriateness",
"result": "PARTIAL",
"details": "Time series approach valid but seasonality not fully addressed",
"score": 0.7
}},
# More validation checks...
],
"recommendations": [
"Consider adjusting for seasonality in the time series analysis",
# More recommendations...
],
"critical_issues": [
# Any issues that could invalidate the findings
]
}}
```
Be thorough in your assessment and provide specific details for each check.
"""),
            ("human", """
Original Problem Statement: {original_problem}
Analysis Results:
{analysis_results}
Available Data Sources:
{data_sources}
Please validate these analysis results thoroughly.
""")
        ])
        # Set up the validation chain: prompt -> LLM -> raw string output
        self.validation_chain = (
            self.validation_prompt
            | self.llm
            | StrOutputParser()
        )

    def extract_json_from_response(self, response: str) -> Dict:
        """Extract JSON from text that might contain additional content.

        Tries, in order: the whole text as JSON, a fenced ```json block,
        then the outermost brace-delimited span.

        Raises:
            ValueError: if no parseable JSON can be found in the response.
        """
        try:
            # First, try to parse the entire text as JSON
            return json.loads(response)
        except json.JSONDecodeError:
            pass
        # Next, look for an explicit fenced ```json block
        match = re.search(r'```json\s*([\s\S]*?)\s*```', response, re.DOTALL)
        if match:
            try:
                return json.loads(match.group(1))
            except json.JSONDecodeError:
                pass
        # Finally, fall back to the outermost {...} span. The pattern is
        # deliberately greedy so nested objects stay intact.
        match = re.search(r'({[\s\S]*})', response)
        if match:
            try:
                return json.loads(match.group(1))
            except json.JSONDecodeError:
                pass
        raise ValueError(f"Could not extract JSON from response: {response}")

    @staticmethod
    def _describe_data_sources(data_sources: Dict[str, Any]) -> str:
        """Render a human-readable summary of each data source for the prompt.

        Each value is expected to expose `.content` (a pandas DataFrame) and
        `.name` — TODO confirm against the project's data-source type.
        """
        data_sources_desc = ""
        for source_id, source in data_sources.items():
            df = source.content
            data_sources_desc += f"Data source '{source_id}' ({source.name}):\n"
            data_sources_desc += f"- Shape: {df.shape[0]} rows, {df.shape[1]} columns\n"
            data_sources_desc += f"- Columns: {', '.join(df.columns)}\n"
            data_sources_desc += f"- Sample data:\n{df.head(3).to_string()}\n\n"
        return data_sources_desc

    def validate_analysis(self, request: "ValidationRequest", data_sources: Dict[str, Any]) -> "ValidationResult":
        """Validate analysis results.

        Formats the request and data sources into the prompt, invokes the
        LLM chain, and parses the JSON reply into a ValidationResult.

        Raises:
            ValueError: if the LLM reply contains no parseable JSON.
        """
        print(f"QA Agent: Validating analysis results for problem: {request.original_problem}")
        # Assemble the three template variables the prompt expects
        request_data = {
            "original_problem": request.original_problem,
            "analysis_results": json.dumps(request.analysis_results, indent=2),
            "data_sources": self._describe_data_sources(data_sources),
        }
        # Generate validation
        response = self.validation_chain.invoke(request_data)
        # Extract and parse validation JSON
        validation_dict = self.extract_json_from_response(response)
        # Stamp the result with a timestamp and a deterministic id
        validation_dict["timestamp"] = datetime.now()
        validation_dict["result_id"] = f"validation_{request.request_id}"
        return ValidationResult(**validation_dict)

    def get_validation_summary(self, validation: "ValidationResult") -> Dict[str, Any]:
        """Generate a human-readable summary of validation results."""
        # Map the overall score onto a coarse status label
        if validation.validation_score >= 0.9:
            status = "EXCELLENT"
        elif validation.validation_score >= 0.75:
            status = "GOOD"
        elif validation.validation_score >= 0.6:
            status = "ACCEPTABLE"
        else:
            status = "NEEDS IMPROVEMENT"
        # Tally check outcomes; unknown result strings are ignored
        check_counts = {"PASS": 0, "PARTIAL": 0, "FAIL": 0}
        for check in validation.validation_checks:
            result = check.get("result", "")
            if result in check_counts:
                check_counts[result] += 1
        # Create summary
        summary = {
            "status": status,
            "overall_score": validation.validation_score,
            "data_quality_score": validation.data_quality_score,
            "analysis_quality_score": validation.analysis_quality_score,
            "insight_quality_score": validation.insight_quality_score,
            "check_counts": check_counts,
            "critical_issues_count": len(validation.critical_issues),
            "recommendations_count": len(validation.recommendations),
            "timestamp": validation.timestamp.strftime("%Y-%m-%d %H:%M:%S")
        }
        return summary
# For testing
if __name__ == "__main__":
    # Use a key already present in the environment when there is one; the
    # placeholder is only a fallback so QAAgent fails fast with a clear error
    # instead of this script silently clobbering a real key.
    os.environ.setdefault("ANTHROPIC_API_KEY", "your_api_key_here")

    # Duck-typed stand-in for ValidationRequest: validate_analysis only reads
    # attributes, so a plain object is enough here.
    class MockValidationRequest:
        def __init__(self):
            self.request_id = "test"
            self.original_problem = "Sales of DrugX down 15% in Northeast region over past 30 days"
            self.analysis_results = {
                "insights": [
                    {"finding": "Competitor launch impact", "details": "New competing drug launched", "impact": "Estimated 60% of decline"},
                    {"finding": "Supply chain issues", "details": "Inventory shortages in key distribution centers", "impact": "Estimated 25% of decline"}
                ],
                "attribution": {
                    "competitor_launch": 0.60,
                    "supply_issues": 0.25,
                    "seasonal_factors": 0.15
                },
                "confidence": 0.85
            }
            self.data_sources = ["sales_data", "competitor_data"]

    # Minimal data-source holder matching the `.content` / `.name` interface
    # that QAAgent._describe_data_sources expects.
    from dataclasses import dataclass

    @dataclass
    class MockDataSource:
        content: pd.DataFrame
        name: str

    # NOTE: 'M' (month-end) frequency is deprecated in pandas >= 2.2 in
    # favor of 'ME'; kept here for compatibility with older pandas.
    sales_df = pd.DataFrame({
        'date': pd.date_range(start='2023-01-01', periods=12, freq='M'),
        'region': ['Northeast'] * 12,
        'sales': [100, 110, 105, 115, 120, 115, 110, 105, 95, 85, 80, 70],
        'target': [100, 105, 110, 115, 120, 125, 130, 135, 140, 145, 150, 155]
    })
    competitor_df = pd.DataFrame({
        'date': pd.date_range(start='2023-10-01', periods=3, freq='M'),
        'competitor': ['CompDrug2'] * 3,
        'launch_region': ['Northeast'] * 3,
        'estimated_sales': [0, 50, 70]
    })
    data_sources = {
        "sales_data": MockDataSource(content=sales_df, name="Monthly sales data"),
        "competitor_data": MockDataSource(content=competitor_df, name="Competitor launch data")
    }

    # Run a full validation round-trip and print the key outputs.
    agent = QAAgent()
    validation = agent.validate_analysis(MockValidationRequest(), data_sources)
    print(f"Validation score: {validation.validation_score}")
    print(f"Critical issues: {validation.critical_issues}")
    print(f"Recommendations: {validation.recommendations}")
    summary = agent.get_validation_summary(validation)
    print(f"Summary: {json.dumps(summary, indent=2)}")