# QA agent module: validates pharmaceutical analysis results via an LLM-based validation chain.
import json
import os
import re
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
from langchain_anthropic import ChatAnthropic
from langchain_core.output_parsers import JsonOutputParser, StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from pydantic import BaseModel, Field
class ValidationRequest(BaseModel):
    """Structure for a validation request"""
    request_id: str  # Unique identifier for this request
    original_problem: str  # The problem statement the analysis was meant to address
    analysis_results: Dict[str, Any]  # Analysis output to be validated (insights, attribution, etc.)
    data_sources: List[str]  # IDs of the data sources the analysis used
class ValidationResult(BaseModel):
    """Structure for validation results"""
    result_id: str  # Unique identifier (derived from the request id by the agent)
    # All four scores are constrained by pydantic to the inclusive range [0.0, 1.0].
    validation_score: float = Field(ge=0.0, le=1.0)  # Overall validation score
    data_quality_score: float = Field(ge=0.0, le=1.0)  # Quality of the underlying data
    analysis_quality_score: float = Field(ge=0.0, le=1.0)  # Soundness of the methodology
    insight_quality_score: float = Field(ge=0.0, le=1.0)  # Usefulness of the insights
    validation_checks: List[Dict[str, Any]]  # Per-check dicts: check / result / details / score
    recommendations: List[str]  # Suggested improvements to the analysis
    critical_issues: List[str]  # Issues that could invalidate the findings
    timestamp: datetime  # When the validation was produced
class QAAgent:
    """Agent responsible for quality assurance and validation.

    Wraps a LangChain pipeline (prompt -> Claude -> string parser) that scores
    analysis results for data quality, methodology, and insight quality.
    """

    def __init__(self):
        """Initialize the QA agent.

        Raises:
            ValueError: If ANTHROPIC_API_KEY is not set in the environment.
        """
        # Set up Claude API client
        api_key = os.getenv("ANTHROPIC_API_KEY")
        if not api_key:
            raise ValueError("ANTHROPIC_API_KEY not found in environment variables")
        self.llm = ChatAnthropic(
            model="claude-3-7-sonnet-20250219",
            anthropic_api_key=api_key,
            temperature=0.1  # low temperature: validation should be consistent, not creative
        )
        # Create validation prompt.
        # BUG FIX: all literal braces in the JSON example below are doubled
        # ({{ / }}) because ChatPromptTemplate uses f-string templating by
        # default; single braces would be parsed as input variables and make
        # the chain raise at invoke time. Doubled braces render as single
        # braces in the message actually sent to the model.
        self.validation_prompt = ChatPromptTemplate.from_messages([
            ("system", """You are an expert pharmaceutical analytics validator.
Your task is to thoroughly validate analysis results to ensure quality and accuracy.
For each validation request:
1. Assess data quality (completeness, relevance, potential biases)
2. Evaluate analysis methodology (appropriateness, statistical rigor)
3. Verify that insights address the original problem statement
4. Check for potential alternative explanations
5. Identify any critical issues that could invalidate findings
Output your validation in JSON format with the following structure:
```json
{{
    "validation_score": 0.85,  # Overall validation score (0.0-1.0)
    "data_quality_score": 0.9,  # Data quality score (0.0-1.0)
    "analysis_quality_score": 0.8,  # Analysis methodology score (0.0-1.0)
    "insight_quality_score": 0.85,  # Quality of insights score (0.0-1.0)
    "validation_checks": [
        {{
            "check": "Data completeness",
            "result": "PASS",
            "details": "All required data appears to be present",
            "score": 1.0
        }},
        {{
            "check": "Methodology appropriateness",
            "result": "PARTIAL",
            "details": "Time series approach valid but seasonality not fully addressed",
            "score": 0.7
        }},
        # More validation checks...
    ],
    "recommendations": [
        "Consider adjusting for seasonality in the time series analysis",
        # More recommendations...
    ],
    "critical_issues": [
        # Any issues that could invalidate the findings
    ]
}}
```
Be thorough in your assessment and provide specific details for each check.
"""),
            ("human", """
Original Problem Statement: {original_problem}
Analysis Results:
{analysis_results}
Available Data Sources:
{data_sources}
Please validate these analysis results thoroughly.
""")
        ])
        # Set up the validation chain
        self.validation_chain = (
            self.validation_prompt
            | self.llm
            | StrOutputParser()
        )

    def extract_json_from_response(self, response: str) -> Dict:
        """Extract a JSON object from text that might contain additional content.

        Tries, in order: the whole text as JSON, a fenced ```json block,
        and finally the outermost {...} span (greedy, first '{' to last '}').

        Raises:
            ValueError: If no parseable JSON can be found in the text.
        """
        # First, try to parse the entire text as JSON
        try:
            return json.loads(response)
        except json.JSONDecodeError:
            pass
        # If that fails, look for a fenced ```json ... ``` block
        # (uses the module-level `re`; the redundant local import was removed)
        match = re.search(r'```json\s*([\s\S]*?)\s*```', response, re.DOTALL)
        if match:
            try:
                return json.loads(match.group(1))
            except json.JSONDecodeError:
                pass
        # Try a more aggressive approach to find JSON-like content
        match = re.search(r'({[\s\S]*})', response)
        if match:
            try:
                return json.loads(match.group(1))
            except json.JSONDecodeError:
                pass
        raise ValueError(f"Could not extract JSON from response: {response}")

    def validate_analysis(self, request: ValidationRequest, data_sources: Dict[str, Any]) -> ValidationResult:
        """Validate analysis results against the original problem and source data.

        Args:
            request: The validation request (problem, analysis results, source ids).
            data_sources: Mapping of source id -> object with `.content`
                (a pandas DataFrame) and `.name` attributes.

        Returns:
            A ValidationResult parsed from the LLM's JSON response.

        Raises:
            ValueError: If the LLM response contains no parseable JSON.
        """
        print(f"QA Agent: Validating analysis results for problem: {request.original_problem}")
        # Format data sources description for the prompt
        data_sources_desc = ""
        for source_id, source in data_sources.items():
            df = source.content
            data_sources_desc += f"Data source '{source_id}' ({source.name}):\n"
            data_sources_desc += f"- Shape: {df.shape[0]} rows, {df.shape[1]} columns\n"
            data_sources_desc += f"- Columns: {', '.join(df.columns)}\n"
            data_sources_desc += f"- Sample data:\n{df.head(3).to_string()}\n\n"
        # Format analysis results for the prompt
        analysis_results_desc = json.dumps(request.analysis_results, indent=2)
        # Format the request for the prompt
        request_data = {
            "original_problem": request.original_problem,
            "analysis_results": analysis_results_desc,
            "data_sources": data_sources_desc
        }
        # Generate validation
        response = self.validation_chain.invoke(request_data)
        # Extract and parse validation JSON
        validation_dict = self.extract_json_from_response(response)
        # Stamp result metadata not produced by the model
        validation_dict["timestamp"] = datetime.now()
        validation_dict["result_id"] = f"validation_{request.request_id}"
        return ValidationResult(**validation_dict)

    def get_validation_summary(self, validation: ValidationResult) -> Dict[str, Any]:
        """Generate a human-readable summary of validation results.

        Returns a dict with an overall status label, the four scores, counts
        of PASS/PARTIAL/FAIL checks, issue/recommendation counts, and a
        formatted timestamp.
        """
        # Determine overall validation status from the overall score
        if validation.validation_score >= 0.9:
            status = "EXCELLENT"
        elif validation.validation_score >= 0.75:
            status = "GOOD"
        elif validation.validation_score >= 0.6:
            status = "ACCEPTABLE"
        else:
            status = "NEEDS IMPROVEMENT"
        # Count check results (unknown result labels are ignored)
        check_counts = {"PASS": 0, "PARTIAL": 0, "FAIL": 0}
        for check in validation.validation_checks:
            result = check.get("result", "")
            if result in check_counts:
                check_counts[result] += 1
        # Create summary
        summary = {
            "status": status,
            "overall_score": validation.validation_score,
            "data_quality_score": validation.data_quality_score,
            "analysis_quality_score": validation.analysis_quality_score,
            "insight_quality_score": validation.insight_quality_score,
            "check_counts": check_counts,
            "critical_issues_count": len(validation.critical_issues),
            "recommendations_count": len(validation.recommendations),
            "timestamp": validation.timestamp.strftime("%Y-%m-%d %H:%M:%S")
        }
        return summary
# For testing
if __name__ == "__main__":
    # Set API key for testing (placeholder -- replace with a real key to run)
    os.environ["ANTHROPIC_API_KEY"] = "your_api_key_here"

    # Create mock validation request (duck-types ValidationRequest)
    class MockValidationRequest:
        def __init__(self):
            self.request_id = "test"
            self.original_problem = "Sales of DrugX down 15% in Northeast region over past 30 days"
            self.analysis_results = {
                "insights": [
                    {"finding": "Competitor launch impact", "details": "New competing drug launched", "impact": "Estimated 60% of decline"},
                    {"finding": "Supply chain issues", "details": "Inventory shortages in key distribution centers", "impact": "Estimated 25% of decline"}
                ],
                "attribution": {
                    "competitor_launch": 0.60,
                    "supply_issues": 0.25,
                    "seasonal_factors": 0.15
                },
                "confidence": 0.85
            }
            self.data_sources = ["sales_data", "competitor_data"]

    # Create mock data sources
    from dataclasses import dataclass

    # BUG FIX: the @dataclass decorator was missing. A plain class with bare
    # annotations generates no __init__, so MockDataSource(content=..., name=...)
    # raised TypeError before the agent ever ran.
    @dataclass
    class MockDataSource:
        content: pd.DataFrame
        name: str

    sales_df = pd.DataFrame({
        'date': pd.date_range(start='2023-01-01', periods=12, freq='M'),
        'region': ['Northeast'] * 12,
        'sales': [100, 110, 105, 115, 120, 115, 110, 105, 95, 85, 80, 70],
        'target': [100, 105, 110, 115, 120, 125, 130, 135, 140, 145, 150, 155]
    })
    competitor_df = pd.DataFrame({
        'date': pd.date_range(start='2023-10-01', periods=3, freq='M'),
        'competitor': ['CompDrug2'] * 3,
        'launch_region': ['Northeast'] * 3,
        'estimated_sales': [0, 50, 70]
    })
    data_sources = {
        "sales_data": MockDataSource(content=sales_df, name="Monthly sales data"),
        "competitor_data": MockDataSource(content=competitor_df, name="Competitor launch data")
    }

    # Run a live validation (requires a real API key and network access)
    agent = QAAgent()
    validation = agent.validate_analysis(MockValidationRequest(), data_sources)
    print(f"Validation score: {validation.validation_score}")
    print(f"Critical issues: {validation.critical_issues}")
    print(f"Recommendations: {validation.recommendations}")
    summary = agent.get_validation_summary(validation)
    print(f"Summary: {json.dumps(summary, indent=2)}")