# multi-ai-agentic-system/agents/critic_agent.py
# Deploy Multi-AI Agentic System v1.0.0 - Production Ready (commit d4b6ffc, by CHKIM79)
"""
Critic Agent - Validates and quality-checks results
"""
import asyncio
from typing import Dict, Any, List
from core.base_agent import BaseAgent, AgentMessage, TaskResult, TaskStatus
class CriticAgent(BaseAgent):
    """Agent specialized in validation and quality assurance.

    Handles three message types:
      * "validation"    - rule-based checks of agent results against a
                          domain's validation rules.
      * "quality_check" - scores content on accuracy / completeness /
                          consistency / relevance.
      * "verification"  - cross-result consistency checks.
    """

    def __init__(self):
        super().__init__("critic_agent", ["validation", "quality_check", "verification"])
        # Per-domain rule sets consumed by _validate_results/_validate_single_item.
        self.validation_rules = self._initialize_validation_rules()
        # Names of the metrics scored by _quality_check.
        self.quality_metrics = ["accuracy", "completeness", "consistency", "relevance"]

    def _initialize_validation_rules(self) -> Dict[str, Any]:
        """Initialize validation rules for different domains.

        Returns a mapping of domain name -> rule dict; "general" is the
        fallback used when an unknown domain is requested.
        """
        return {
            "flight_booking": {
                "required_fields": ["airline", "price", "duration", "route"],
                "price_range": {"min": 50, "max": 2000},
                "duration_max": 24,  # hours
                "valid_airlines": ["Delta", "United", "American", "JetBlue", "Southwest"],
            },
            "general": {
                "min_response_length": 10,
                "max_response_length": 5000,
                "required_sections": ["findings", "recommendation"],
            },
        }

    async def process_task(self, message: AgentMessage) -> TaskResult:
        """Dispatch a validation/quality-check/verification task.

        Returns a COMPLETED TaskResult carrying the handler's result dict,
        or a FAILED TaskResult with the error message; never raises.
        """
        # get_running_loop() is the non-deprecated way to reach the loop's
        # monotonic clock from inside a coroutine; hoisted so both the
        # success and failure paths reuse the same loop object.
        loop = asyncio.get_running_loop()
        start_time = loop.time()
        try:
            if message.message_type == "validation":
                result = await self._validate_results(message.data)
            elif message.message_type == "quality_check":
                result = await self._quality_check(message.data)
            elif message.message_type == "verification":
                result = await self._verify_consistency(message.data)
            else:
                raise ValueError(f"Unknown task type: {message.message_type}")
            return TaskResult(
                task_id=message.task_id,
                agent_id=self.agent_id,
                status=TaskStatus.COMPLETED,
                result=result,
                execution_time=loop.time() - start_time,
            )
        except Exception as e:
            # Convert any handler failure into a FAILED result so callers
            # always receive a TaskResult instead of an exception.
            return TaskResult(
                task_id=message.task_id,
                agent_id=self.agent_id,
                status=TaskStatus.FAILED,
                result={},
                error_message=str(e),
                execution_time=loop.time() - start_time,
            )

    async def _validate_results(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Validate results from other agents.

        data["results"] is the list of items to check; data["domain"]
        selects the rule set (falls back to "general").
        """
        results_to_validate = data.get("results", [])
        domain = data.get("domain", "general")
        validation_report = {
            "overall_status": "passed",
            "issues_found": [],
            "warnings": [],
            "validated_items": 0,
            "failed_items": 0,
        }
        rules = self.validation_rules.get(domain, self.validation_rules["general"])
        for i, result in enumerate(results_to_validate):
            item_validation = await self._validate_single_item(result, rules, f"item_{i}")
            validation_report["validated_items"] += 1
            if not item_validation["passed"]:
                # A single failing item fails the whole batch.
                validation_report["failed_items"] += 1
                validation_report["overall_status"] = "failed"
                validation_report["issues_found"].extend(item_validation["issues"])
            validation_report["warnings"].extend(item_validation["warnings"])
        return {
            "findings": [
                f"Validated {validation_report['validated_items']} items",
                f"Overall status: {validation_report['overall_status']}",
                f"Issues found: {len(validation_report['issues_found'])}",
            ],
            "validation_report": validation_report,
            "recommendation": self._generate_validation_recommendation(validation_report),
        }

    async def _validate_single_item(self, item: Dict[str, Any], rules: Dict[str, Any], item_id: str) -> Dict[str, Any]:
        """Validate a single item against rules.

        Issues mark the item as failed; warnings are advisory only.
        """
        validation = {
            "item_id": item_id,
            "passed": True,
            "issues": [],
            "warnings": [],
        }
        # Check required fields.
        if "required_fields" in rules:
            for field in rules["required_fields"]:
                if field not in item:
                    validation["passed"] = False
                    validation["issues"].append(f"Missing required field: {field}")
        # Check price range (e.g. for flight bookings). Guard the type so a
        # string price becomes a validation issue instead of a TypeError.
        if "price_range" in rules and "price" in item:
            price = item["price"]
            if not isinstance(price, (int, float)):
                validation["passed"] = False
                validation["issues"].append(f"Price {price!r} is not numeric")
            elif not (rules["price_range"]["min"] <= price <= rules["price_range"]["max"]):
                validation["passed"] = False
                validation["issues"].append(f"Price ${price} outside valid range")
        # Check airline validity (warning only — not a hard failure).
        if "valid_airlines" in rules and "airline" in item:
            if item["airline"] not in rules["valid_airlines"]:
                validation["warnings"].append(f"Airline {item['airline']} not in preferred list")
        # Check duration (warning only).
        if "duration_max" in rules and "duration" in item:
            duration_str = item["duration"]
            duration_hours = self._parse_duration_to_hours(duration_str)
            if duration_hours > rules["duration_max"]:
                validation["warnings"].append(f"Duration {duration_str} exceeds typical maximum")
        return validation

    def _parse_duration_to_hours(self, duration_str: str) -> float:
        """Parse a duration string like "6h 15m", "6h", or "45m" into hours.

        Returns 0 for unparseable input rather than raising, matching the
        lenient behavior callers rely on.
        """
        try:
            hours = 0.0
            minutes = 0.0
            for token in duration_str.split():
                if token.endswith("h"):
                    hours = float(token[:-1])
                elif token.endswith("m"):
                    # BUG FIX: a minute-only value such as "45m" used to be
                    # misread as 45 hours because the old parser always
                    # treated the first token as hours.
                    minutes = float(token[:-1])
                else:
                    # Bare number with no unit: interpret as hours.
                    hours = float(token)
            return hours + (minutes / 60)
        except (ValueError, AttributeError):
            # Narrowed from a bare `except:` — only swallow parse/type
            # errors, not KeyboardInterrupt/SystemExit.
            return 0

    async def _quality_check(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Score data["content"] on the four quality metrics.

        The overall score is the unweighted mean; below 0.4 the content is
        flagged as below the acceptable threshold.
        """
        content = data.get("content", {})
        quality_scores = {}
        overall_issues = []
        # Score each metric (mock implementations — see individual checkers).
        quality_scores["accuracy"] = await self._check_accuracy(content)
        quality_scores["completeness"] = await self._check_completeness(content)
        quality_scores["consistency"] = await self._check_consistency(content)
        quality_scores["relevance"] = await self._check_relevance(content)
        # Unweighted mean of the four metric scores.
        overall_score = sum(quality_scores.values()) / len(quality_scores)
        # Map score to a coarse quality label.
        if overall_score >= 0.8:
            quality_level = "excellent"
        elif overall_score >= 0.6:
            quality_level = "good"
        elif overall_score >= 0.4:
            quality_level = "fair"
        else:
            quality_level = "poor"
            overall_issues.append("Content quality below acceptable threshold")
        return {
            "findings": [
                f"Overall quality score: {overall_score:.2f}",
                f"Quality level: {quality_level}",
                f"Issues identified: {len(overall_issues)}",
            ],
            "quality_scores": quality_scores,
            "quality_level": quality_level,
            "issues": overall_issues,
            "recommendation": f"Content quality is {quality_level}. " +
                              ("Consider improvements." if overall_score < 0.6 else "Quality acceptable."),
        }

    async def _check_accuracy(self, content: Dict[str, Any]) -> float:
        """Check accuracy of content (mock implementation).

        In a real implementation this would verify facts against reliable
        sources; here a fixed score is penalized for obvious errors.
        """
        score = 0.85  # Mock baseline score
        # Negative prices are an obvious inconsistency.
        if "price" in content and isinstance(content["price"], (int, float)):
            if content["price"] < 0:
                score -= 0.3
        return max(0, min(1, score))

    async def _check_completeness(self, content: Dict[str, Any]) -> float:
        """Return the fraction of expected sections present and non-empty."""
        expected_keys = ["findings", "recommendation"]
        present_keys = sum(1 for key in expected_keys if key in content and content[key])
        return present_keys / len(expected_keys)

    async def _check_consistency(self, content: Dict[str, Any]) -> float:
        """Check internal consistency (mock implementation)."""
        score = 0.9  # Mock baseline score
        # Simple contradiction detection between findings and recommendation.
        if "findings" in content and "recommendation" in content:
            findings_str = str(content["findings"]).lower()
            recommendation_str = str(content["recommendation"]).lower()
            if "failed" in findings_str and "success" in recommendation_str:
                score -= 0.2
        return max(0, min(1, score))

    async def _check_relevance(self, content: Dict[str, Any]) -> float:
        """Check relevance to the user request (mock: fixed score)."""
        return 0.8

    async def _verify_consistency(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Verify consistency across multiple results in data["results"]."""
        results = data.get("results", [])
        consistency_issues = []
        # Price spread check — only numeric prices are comparable, so skip
        # anything else instead of raising TypeError in max()/min().
        prices = [r["price"] for r in results if isinstance(r.get("price"), (int, float))]
        if prices and (max(prices) - min(prices)) > 1000:
            consistency_issues.append("Large price variations detected")
        # Airline diversity check. Require at least two results: the old
        # `len(set(...)) > len(...) * 0.8` test wrongly flagged a single
        # result as inconsistent (1 > 0.8).
        airlines = [r.get("airline") for r in results if "airline" in r]
        if len(airlines) > 1 and len(set(airlines)) > len(airlines) * 0.8:
            consistency_issues.append("Inconsistent airline recommendations")
        # Each issue costs 0.2; clamp so the score stays within [0, 1].
        consistency_score = max(0.0, 1.0 - (len(consistency_issues) * 0.2))
        return {
            "findings": [
                f"Consistency score: {consistency_score:.2f}",
                f"Issues found: {len(consistency_issues)}",
            ],
            "consistency_score": consistency_score,
            "issues": consistency_issues,
            "recommendation": "Results are consistent" if consistency_score > 0.7 else "Review for consistency issues",
        }

    def _generate_validation_recommendation(self, validation_report: Dict[str, Any]) -> str:
        """Generate a human-readable recommendation from a validation report."""
        if validation_report["overall_status"] == "passed":
            if validation_report["warnings"]:
                return f"Validation passed with {len(validation_report['warnings'])} warnings. Review recommended."
            else:
                return "All validations passed successfully."
        else:
            return f"Validation failed with {len(validation_report['issues_found'])} critical issues. Requires attention."

    def get_agent_info(self) -> Dict[str, Any]:
        """Return critic agent information for discovery/introspection."""
        return {
            "agent_id": self.agent_id,
            "type": "critic",
            "capabilities": self.capabilities,
            "specialization": "Quality assurance and validation",
            "validation_domains": list(self.validation_rules.keys()),
            "quality_metrics": self.quality_metrics,
        }