| """ | |
| MediGuard AI RAG-Helper | |
| Confidence Assessor Agent - Evaluates prediction reliability | |
| """ | |
| from typing import Any | |
| from src.biomarker_validator import BiomarkerValidator | |
| from src.llm_config import llm_config | |
| from src.state import AgentOutput, GuildState | |
| class ConfidenceAssessorAgent: | |
| """Agent that assesses the reliability and limitations of the prediction""" | |
| def __init__(self): | |
| self.llm = llm_config.analyzer | |
| def assess(self, state: GuildState) -> GuildState: | |
| """ | |
| Assess prediction confidence and identify limitations. | |
| Args: | |
| state: Current guild state | |
| Returns: | |
| Updated state with confidence assessment | |
| """ | |
| print("\n" + "=" * 70) | |
| print("EXECUTING: Confidence Assessor Agent") | |
| print("=" * 70) | |
| model_prediction = state["model_prediction"] | |
| disease = model_prediction["disease"] | |
| ml_confidence = model_prediction["confidence"] | |
| probabilities = model_prediction.get("probabilities", {}) | |
| biomarkers = state["patient_biomarkers"] | |
| # Collect previous agent findings | |
| biomarker_analysis = state.get("biomarker_analysis") or {} | |
| disease_explanation = self._get_agent_findings(state, "Disease Explainer") | |
| linker_findings = self._get_agent_findings(state, "Biomarker-Disease Linker") | |
| print(f"\nAssessing confidence for {disease} prediction...") | |
| # Evaluate evidence strength | |
| evidence_strength = self._evaluate_evidence_strength(biomarker_analysis, disease_explanation, linker_findings) | |
| # Identify limitations | |
| limitations = self._identify_limitations(biomarkers, biomarker_analysis, probabilities) | |
| # Calculate aggregate reliability | |
| reliability = self._calculate_reliability(ml_confidence, evidence_strength, len(limitations)) | |
| # Generate assessment summary | |
| assessment_summary = self._generate_assessment( | |
| disease, ml_confidence, reliability, evidence_strength, limitations | |
| ) | |
| # Create agent output | |
| output = AgentOutput( | |
| agent_name="Confidence Assessor", | |
| findings={ | |
| "prediction_reliability": reliability, | |
| "ml_confidence": ml_confidence, | |
| "evidence_strength": evidence_strength, | |
| "limitations": limitations, | |
| "assessment_summary": assessment_summary, | |
| "recommendation": self._get_recommendation(reliability), | |
| "alternative_diagnoses": self._get_alternatives(probabilities), | |
| }, | |
| ) | |
| # Update state | |
| print("\nConfidence assessment complete") | |
| print(f" - Prediction reliability: {reliability}") | |
| print(f" - Evidence strength: {evidence_strength}") | |
| print(f" - Limitations identified: {len(limitations)}") | |
| return {"agent_outputs": [output]} | |
    def _get_agent_findings(self, state: GuildState, agent_name: str) -> dict:
        """Extract findings from a specific agent"""
        for output in state.get("agent_outputs", []):
            if output.agent_name == agent_name:
                return output.findings
        return {}

    def _evaluate_evidence_strength(
        self, biomarker_analysis: dict, disease_explanation: dict, linker_findings: dict
    ) -> str:
        """Evaluate the strength of supporting evidence"""
        score = 0  # ranges 0-5 across the checks below

        # Check biomarker validation quality
        flags = biomarker_analysis.get("biomarker_flags", [])
        abnormal_count = len([f for f in flags if f.get("status") != "NORMAL"])
        if abnormal_count >= 3:
            score += 1
        if abnormal_count >= 5:
            score += 1

        # Check disease explanation quality
        if disease_explanation.get("retrieval_quality", 0) >= 3:
            score += 1

        # Check biomarker-disease linking
        key_drivers = linker_findings.get("key_drivers", [])
        if len(key_drivers) >= 2:
            score += 1
        if len(key_drivers) >= 4:
            score += 1

        # Map score to categorical rating
        if score >= 4:
            return "STRONG"
        elif score >= 2:
            return "MODERATE"
        else:
            return "WEAK"
    def _identify_limitations(
        self, biomarkers: dict[str, float], biomarker_analysis: dict, probabilities: dict[str, float]
    ) -> list[str]:
        """Identify limitations and uncertainties"""
        limitations = []

        # Check for missing biomarkers
        expected_biomarkers = BiomarkerValidator().expected_biomarker_count()
        if len(biomarkers) < expected_biomarkers:
            missing = expected_biomarkers - len(biomarkers)
            limitations.append(f"Missing data: {missing} biomarker(s) not provided")

        # Check for close alternative predictions
        sorted_probs = sorted(probabilities.items(), key=lambda x: x[1], reverse=True)
        if len(sorted_probs) >= 2:
            top2, prob2 = sorted_probs[1]
            if prob2 > 0.15:  # Alternative is significant
                limitations.append(f"Differential diagnosis: {top2} also possible ({prob2:.1%} probability)")

        # Check for normal biomarkers despite prediction
        flags = biomarker_analysis.get("biomarker_flags", [])
        relevant = biomarker_analysis.get("relevant_biomarkers", [])
        normal_relevant = [f for f in flags if f.get("name") in relevant and f.get("status") == "NORMAL"]
        if len(normal_relevant) >= 2:
            limitations.append("Some disease-relevant biomarkers are within normal range")

        # Check for safety alerts (indicates complexity)
        alerts = biomarker_analysis.get("safety_alerts", [])
        if len(alerts) >= 2:
            limitations.append("Multiple critical values detected; professional evaluation essential")

        return limitations
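    # Hypothetical example: if the validator expects 10 biomarkers but only 8
    # were provided, and the runner-up class has probability 0.22, this returns:
    #   ["Missing data: 2 biomarker(s) not provided",
    #    "Differential diagnosis: Hepatitis also possible (22.0% probability)"]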
    def _calculate_reliability(self, ml_confidence: float, evidence_strength: str, limitation_count: int) -> str:
        """Calculate overall prediction reliability"""
        score = 0

        # ML confidence contribution
        if ml_confidence >= 0.8:
            score += 3
        elif ml_confidence >= 0.6:
            score += 2
        elif ml_confidence >= 0.4:
            score += 1

        # Evidence strength contribution
        if evidence_strength == "STRONG":
            score += 3
        elif evidence_strength == "MODERATE":
            score += 2
        else:
            score += 1

        # Limitation penalty
        score -= min(limitation_count, 3)

        # Map to categorical
        if score >= 5:
            return "HIGH"
        elif score >= 3:
            return "MODERATE"
        else:
            return "LOW"
    def _generate_assessment(
        self, disease: str, ml_confidence: float, reliability: str, evidence_strength: str, limitations: list[str]
    ) -> str:
        """Generate human-readable assessment summary"""
        prompt = f"""As a medical AI assessment system, provide a brief confidence statement about this prediction:

Disease Predicted: {disease}
ML Model Confidence: {ml_confidence:.1%}
Overall Reliability: {reliability}
Evidence Strength: {evidence_strength}
Limitations: {len(limitations)} identified

Write a 2-3 sentence assessment that:
1. States the overall reliability
2. Mentions key strengths or weaknesses
3. Emphasizes the need for professional medical consultation

Be honest about uncertainty. Patient safety is paramount."""

        try:
            response = self.llm.invoke(prompt)
            return response.content.strip()
        except Exception as e:
            print(f"Warning: Assessment generation failed: {e}")
            return (
                f"The {disease} prediction has {reliability.lower()} reliability based on available data. "
                "Professional medical evaluation is strongly recommended for accurate diagnosis."
            )

    def _get_recommendation(self, reliability: str) -> str:
        """Get action recommendation based on reliability"""
        if reliability == "HIGH":
            return "High confidence prediction. Schedule medical consultation to confirm diagnosis and discuss treatment options."
        elif reliability == "MODERATE":
            return "Moderate confidence prediction. Medical consultation recommended for professional evaluation and additional testing if needed."
        else:
            return "Low confidence prediction. Professional medical assessment essential. Additional tests may be required for accurate diagnosis."

    def _get_alternatives(self, probabilities: dict[str, float]) -> list[dict[str, Any]]:
        """Get alternative diagnoses to consider"""
        sorted_probs = sorted(probabilities.items(), key=lambda x: x[1], reverse=True)
        alternatives = []
        for disease, prob in sorted_probs[1:4]:  # Top 3 alternatives
            if prob > 0.05:  # Only significant alternatives
                alternatives.append(
                    {"disease": disease, "probability": prob, "note": "Consider discussing with healthcare provider"}
                )
        return alternatives
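    # Hypothetical example: {"Cirrhosis": 0.55, "Hepatitis": 0.30, "Fibrosis": 0.10,
    # "No Disease": 0.05} returns entries for Hepatitis and Fibrosis only; the
    # top prediction is skipped and 0.05 fails the strict > 0.05 threshold.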


# Create agent instance for import
confidence_assessor_agent = ConfidenceAssessorAgent()
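

# Minimal usage sketch (assumes GuildState is dict-like with the keys read in
# assess(); run from the repo root so the src package imports resolve). The
# values below are hypothetical. If the LLM call fails, _generate_assessment
# falls back to a canned summary, so the flow still completes.
if __name__ == "__main__":
    demo_state = {
        "model_prediction": {
            "disease": "Hepatitis",  # hypothetical label
            "confidence": 0.72,
            "probabilities": {"Hepatitis": 0.72, "Fibrosis": 0.18, "No Disease": 0.10},
        },
        "patient_biomarkers": {"ALT": 88.0, "AST": 95.0},  # hypothetical values
        "biomarker_analysis": {},
        "agent_outputs": [],
    }
    update = confidence_assessor_agent.assess(demo_state)
    print(update["agent_outputs"][0].findings["recommendation"])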