BioRAG / src /bio_rag /risk_scorer.py
aseelflihan's picture
Deploy Bio-RAG
2a2c039
from __future__ import annotations
import logging
from dataclasses import dataclass
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
@dataclass
class RiskProfile:
    """Three multiplicative risk factors describing a single claim.

    ``RiskScorer.calculate_profile`` builds instances of this class and
    ``RiskScorer.compute_weighted_risk`` multiplies the factors together.
    """

    # Severity weight; the scorer assigns one of 0.3, 0.5, 0.7, 0.8 or 1.0.
    severity: float
    # 1.0 when the scorer finds absolute / fabrication-style wording, else 0.5.
    type_score: float
    # 1.0 when the scorer finds warning / avoidance wording, else 0.5.
    omission: float
class RiskScorer:
    """Rule-based risk scoring for medical claims.

    ``calculate_profile`` turns the wording of a claim into a ``RiskProfile``
    (severity, type and omission factors) using keyword lookups, and
    ``compute_weighted_risk`` combines that profile with an NLI probability
    into a single score in ``[0.0, 1.0]``.

    NOTE(review): every keyword test is plain substring containment on the
    lower-cased claim, so short entries also match inside longer words
    (e.g. ``"cure"`` matches ``"secure"``, ``"avoid"`` matches
    ``"unavoidable"``).
    """

    # Substrings that raise severity to 1.0, unless an exception phrase is
    # also present in the claim.
    HIGH_SEVERITY_KEYWORDS = {
        "dosage", "dose", "mg", "units/kg", "contraindicated",
        "hypoglycemia", "ketoacidosis", "renal failure", "dialysis",
        "surgery", "emergency", "fatal", "toxic", "lactic acidosis",
        "gfr", "egfr", "creatinine", "nephropathy",
        "insulin dose", "insulin dosage", "insulin regimen",
        "insulin therapy", "insulin administration",
        "drug interaction", "overdose",
        "discontinue insulin", "stop insulin", "stopping insulin",
        "discontinue therapy", "stop taking insulin",
        "severe renal",
        "glipizide", "glimepiride", "sulfonylurea", "pioglitazone",
        "sitagliptin", "dapagliflozin", "empagliflozin", "liraglutide", "semaglutide",
        "primary treatment", "first-line", "first line", "drug of choice",
        "treatment of choice", "recommended treatment",
    }

    # These terms sound medical but are general concepts, not dangerous claims.
    # Any one of them anywhere in the claim disables the high-severity check.
    SEVERITY_EXCEPTIONS = {
        "insulin sensitivity", "insulin resistance", "insulin secretion",
        "insulin signaling", "insulin receptor", "insulin levels",
        "kidney function", "renal function",
        "dose adjustment", "dose adjusted", "adjust the dose",
        "careful monitoring", "close monitoring", "closely monitored",
        "dose-dependent", "dosage adjustment",
    }

    # Substrings that raise severity to 0.7 when the high-severity rule did
    # not fire.
    MED_SEVERITY_KEYWORDS = {
        "diet", "hba1c", "monitoring", "lifestyle", "exercise",
        "target", "frequency", "guidance",
    }

    # Absolute / over-confident wording; presence sets the type factor to 1.0.
    FABRICATION_SIGNALS = (
        "causes", "proven", "always", "never", "cures", "eliminates",
        "guarantees", "completely safe", "no risk", "fully effective",
        "definitely",
    )

    # Warning / avoidance wording; presence sets the omission factor to 1.0.
    OMISSION_SIGNALS = (
        "not recommended", "avoid", "contraindicated", "warning", "caution",
        "do not", "should not",
    )

    # An NLI probability within the tolerance of this value is treated as
    # "unverified" rather than as a contradiction.
    # NOTE(review): presumably a sentinel emitted upstream when no evidence
    # was found -- confirm against the caller.
    UNVERIFIED_NLI_SENTINEL = 0.8501
    UNVERIFIED_NLI_TOLERANCE = 0.0001

    # NLI probability at or above which the claim counts as contradicted.
    CONTRADICTION_THRESHOLD = 0.7

    @staticmethod
    def _mentions_any(text: str, phrases) -> bool:
        """Return True when any of *phrases* occurs as a substring of *text*."""
        return any(phrase in text for phrase in phrases)

    def calculate_profile(self, claim: str) -> RiskProfile:
        """Derive the severity, type and omission factors for *claim*.

        Matching is case-insensitive substring containment.

        Severity is the first rule that applies:
          * 1.0 -- a high-severity keyword and no exception phrase
          * 0.7 -- a medium-severity keyword
          * 0.5 -- an exception phrase
          * 0.8 -- the substring ``"cure"``
          * 0.3 -- otherwise (base level)

        The type factor is 1.0 when a fabrication signal is present and the
        omission factor is 1.0 when an omission signal is present; each is
        0.5 otherwise.

        Args:
            claim: The claim text to score.

        Returns:
            A ``RiskProfile`` holding the three factors.
        """
        text = claim.lower()

        # 1. Severity. The exception test runs first because it vetoes the
        #    high-severity rule.
        has_exception = self._mentions_any(text, self.SEVERITY_EXCEPTIONS)
        if not has_exception and self._mentions_any(text, self.HIGH_SEVERITY_KEYWORDS):
            severity = 1.0
        elif self._mentions_any(text, self.MED_SEVERITY_KEYWORDS):
            severity = 0.7
        elif has_exception:
            severity = 0.5
        elif "cure" in text:
            severity = 0.8
        else:
            severity = 0.3  # Base low

        # 2. Type
        type_score = 1.0 if self._mentions_any(text, self.FABRICATION_SIGNALS) else 0.5

        # 3. Omission
        omission = 1.0 if self._mentions_any(text, self.OMISSION_SIGNALS) else 0.5

        return RiskProfile(
            severity=severity,
            type_score=type_score,
            omission=omission,
        )

    def compute_weighted_risk(self, nli_prob: float, profile: RiskProfile) -> float:
        """Combine an NLI probability with a risk profile into one score.

        The score is ``min(1, 2 * nli_prob) * severity * type * omission``,
        clamped to ``[0.0, 1.0]``. The type and omission factors are both
        replaced by 1.0 (worst case) when either

          * ``nli_prob`` is at or above ``CONTRADICTION_THRESHOLD`` and is not
            the unverified sentinel, or
          * ``nli_prob`` is the unverified sentinel and ``profile.severity``
            is at least 1.0.

        Args:
            nli_prob: NLI probability for the claim.
            profile: Factors produced by ``calculate_profile``.

        Returns:
            The weighted risk, never below 0.0 or above 1.0.
        """
        nli_adjusted = min(1.0, nli_prob * 2.0)

        is_unverified = (
            abs(nli_prob - self.UNVERIFIED_NLI_SENTINEL) < self.UNVERIFIED_NLI_TOLERANCE
        )
        is_genuine_contradiction = (
            nli_prob >= self.CONTRADICTION_THRESHOLD and not is_unverified
        )

        # Worst case applies when evidence actively contradicts the claim, or
        # when the claim is unverified and already at the highest severity.
        assume_worst_case = is_genuine_contradiction or (
            is_unverified and profile.severity >= 1.0
        )
        if assume_worst_case:
            effective_type = 1.0
            effective_omission = 1.0
        else:
            effective_type = profile.type_score
            effective_omission = profile.omission

        weighted = nli_adjusted * (profile.severity * effective_type * effective_omission)
        # Clamp both ends: an out-of-range (negative) input must not produce
        # a negative risk score.
        return max(0.0, min(1.0, weighted))