# Source: Preformu/layers/regulatory_decision_engine.py
# (uploaded by Kevinshh, revision 3a19140, verified)
"""
Regulatory Decision Engine - Layer 2: Scientific & Regulatory Decision Layer
This is the CORE of the three-layer architecture. It:
1. Validates data completeness against ICH requirements
2. Selects appropriate analysis methods (zero-order, first-order, pooled)
3. Calls StabilityCalculator for all numerical computations
4. Decides whether to proceed or refuse based on regulatory rules
5. Never uses LLM - pure Python/scipy rule-based logic
CRITICAL: This module must NEVER import or call any LLM-related code.
All decisions are deterministic and auditable.
"""
from dataclasses import dataclass
from typing import Dict, List, Optional, Any, Tuple
from datetime import datetime
import json
import numpy as np
# Local imports
from schemas.analysis_intent import (
AnalysisIntent,
AnalysisType,
AnalysisPurpose,
ExtractedDataSummary,
)
from schemas.decision_result import (
RegulatoryDecisionResult,
RefusalResult,
RefusalSeverity,
DataQuality,
DataQualityReport,
KineticFitSummary,
PredictionSummary,
ArrheniusResult,
TrendTransferResult,
BatchRankingItem,
RegulatoryNotes,
CalculationTrace,
)
from utils.stability_calculator import StabilityCalculator, KineticFitResult, PredictionResult
class RegulatoryDecisionEngine:
    """
    Scientific & Regulatory Decision Engine (Layer 2).

    This class is the gatekeeper between user intent and analysis results.
    It ensures all outputs are scientifically valid and regulatory compliant.

    KEY PRINCIPLES:
    1. NO LLM calls - all logic is deterministic
    2. Refuse gracefully when data is insufficient
    3. All calculations via StabilityCalculator (scipy-based)
    4. Full audit trail for FDA/EMA inspection
    """
    # ICH Q1E Constants
    # Minimum number of timepoints per condition before any kinetic fit is attempted.
    ICH_MIN_DATAPOINTS = 3
    # Predictions beyond 2x the observed time range are flagged invalid downstream.
    ICH_MAX_EXTRAPOLATION_FACTOR = 2.0
    # R-squared floor for submission-grade fits (see _assess_fit_confidence).
    ICH_MIN_R2_SUBMISSION = 0.90
    # Looser R-squared floor accepted for R&D-grade fits.
    ICH_MIN_R2_RD = 0.80

    def __init__(self):
        # All numerical work is delegated to the scipy-backed calculator;
        # per the module contract, no LLM client is ever attached here.
        self.calculator = StabilityCalculator()
        # Engine version string, surfaced in audit output.
        self.version = "2.1.0"  # Updated for Expert Review Fixes
def execute(
    self,
    intent: AnalysisIntent,
    extracted_data: Dict[str, Any]
) -> RegulatoryDecisionResult:
    """
    Execute the regulatory decision pipeline.

    Pipeline:
        1. Assess data quality; hard-refuse immediately if INSUFFICIENT.
        2. Run a global extrapolation-range check (early warning only;
           strict per-prediction validation happens later).
        3. Dispatch to the analysis routine matching intent.analysis_type.
        4. Verify that at least one valid output was produced, else refuse.
        5. Attach regulatory boilerplate notes.

    Any exception raised by the analysis routines is converted into a hard
    refusal rather than propagated, so Layer 3 always receives a
    well-formed result object instead of a traceback.

    Args:
        intent: Structured user intent from Layer 1
        extracted_data: Raw extracted data from file parsing
    Returns:
        RegulatoryDecisionResult for Layer 3 to present
    """
    result = RegulatoryDecisionResult(
        can_proceed=False,  # pessimistic default; flipped only after validation
        timestamp=datetime.now().isoformat()
    )
    # Step 1: Validate data quality
    data_quality = self._assess_data_quality(extracted_data, intent)
    result.data_quality = data_quality
    if data_quality.overall_quality == DataQuality.INSUFFICIENT:
        result.refusal = self._create_refusal(
            severity=RefusalSeverity.HARD_REFUSAL,
            reason="数据不足,无法执行可靠分析",
            reference="ICH Q1E Section 2.1",
            suggestions=["补充稳定性数据点", "确保每个条件至少3个时间点"],
            missing=data_quality.issues
        )
        return result
    # Step 2: Check global constraints (like max extrapolation)
    # Note: We now do per-prediction validation, but the global check is a
    # useful early warning — a failure here adds a limitation note rather
    # than refusing outright.
    extrapolation_check = self._check_extrapolation_limits_global(intent, extracted_data)
    if not extrapolation_check["allowed"]:
        result.regulatory_notes.limitations.append(
            f"外推警告: {extrapolation_check['reason']}"
        )
    # Step 3: Execute analysis based on type
    try:
        if intent.analysis_type == AnalysisType.SHELF_LIFE_PREDICTION:
            self._execute_shelf_life_prediction(intent, extracted_data, result)
        elif intent.analysis_type == AnalysisType.BATCH_COMPARISON:
            self._execute_batch_comparison(intent, extracted_data, result)
        elif intent.analysis_type == AnalysisType.TREND_ASSESSMENT:
            self._execute_trend_assessment(intent, extracted_data, result)
        elif intent.analysis_type == AnalysisType.RISK_EVALUATION:
            self._execute_risk_evaluation(intent, extracted_data, result)
        # Post-check: If we have no valid predictions/rankings, we might need to refuse
        self._validate_final_result_state(result)
    except Exception as e:
        # Deliberate catch-all: the contract is that Layer 3 always gets a
        # structured refusal, never a raw exception.
        result.refusal = self._create_refusal(
            severity=RefusalSeverity.HARD_REFUSAL,
            reason=f"计算过程发生错误: {str(e)}",
            reference="N/A",
            suggestions=["检查数据格式", "联系技术支持"]
        )
        result.can_proceed = False
    # Step 4: Add regulatory notes
    self._add_regulatory_notes(intent, result)
    return result
def _validate_final_result_state(self, result: RegulatoryDecisionResult):
    """Confirm that at least one usable output exists; hard-refuse otherwise.

    "Usable" means: a prediction that survived validity filtering, a
    non-empty batch ranking, or at least one kinetic fit.
    """
    produced_something = (
        any(pred.is_valid for pred in result.predictions.values())
        or bool(result.batch_ranking)
        or bool(result.kinetic_fits)
    )
    if produced_something:
        result.can_proceed = True
        return
    # Everything was filtered out by quality/rule checks — refuse.
    result.can_proceed = False
    result.refusal = self._create_refusal(
        severity=RefusalSeverity.HARD_REFUSAL,
        reason="没有产生有效的分析结果(所有结果均因数据质量或规则限制被过滤)",
        reference="ICH Q1E",
        suggestions=["补充更多数据点", "检查数据质量(R2)"]
    )
def _assess_data_quality(
    self,
    data: Dict[str, Any],
    intent: AnalysisIntent
) -> DataQualityReport:
    """Grade the extracted dataset against ICH minimum-data expectations.

    Counts batches/conditions/timepoints and buckets each condition by
    whether it meets ICH_MIN_DATAPOINTS. Overall quality:
      - INSUFFICIENT: no batches, or no fittable condition (unless the
        request is a batch comparison, which may run on T=0 data alone).
      - MARGINAL: some, but not all, conditions are fittable.
      - SUFFICIENT: everything else.
    """
    report = DataQualityReport(
        overall_quality=DataQuality.SUFFICIENT,
        n_batches=0,
        n_conditions=0,
        n_total_datapoints=0
    )
    batch_list = data.get("batches", [])
    report.n_batches = len(batch_list)
    if not batch_list:
        report.overall_quality = DataQuality.INSUFFICIENT
        report.issues.append("未检测到任何批次数据")
        return report
    condition_count = 0
    point_count = 0
    for batch in batch_list:
        for condition in batch.get("conditions", []):
            condition_count += 1
            cid = condition.get("condition_id", "unknown")
            # Only non-None entries count as real observations.
            usable = [tp for tp in condition.get("timepoints", []) if tp is not None]
            point_count += len(usable)
            if len(usable) >= self.ICH_MIN_DATAPOINTS:
                report.conditions_with_sufficient_data.append(cid)
                continue
            report.conditions_with_insufficient_data.append(cid)
            report.warnings.append(
                f"条件 {cid}: 仅 {len(usable)} 个有效时间点 (最低要求: {self.ICH_MIN_DATAPOINTS})"
            )
    report.n_conditions = condition_count
    report.n_total_datapoints = point_count
    # Determine the overall grade.
    if not report.conditions_with_sufficient_data:
        if intent.analysis_type == AnalysisType.BATCH_COMPARISON:
            # Comparison of initial values is still meaningful without trends.
            report.overall_quality = DataQuality.MARGINAL
            report.warnings.append("无趋势分析数据,仅能进行初始点对比")
        else:
            # Kinetic analysis strictly requires >=3 points somewhere.
            report.overall_quality = DataQuality.INSUFFICIENT
            report.issues.append("所有条件的数据点均不足以进行动力学分析(需>=3点)")
    elif report.conditions_with_insufficient_data:
        report.overall_quality = DataQuality.MARGINAL
    return report
def _check_extrapolation_limits_global(
    self,
    intent: AnalysisIntent,
    data: Dict[str, Any]
) -> Dict[str, Any]:
    """Early-warning check of the latest requested timepoint vs the ICH 2x cap.

    Returns {"allowed": True} or {"allowed": False, "reason": ...}. This is
    a global screen only; strict per-prediction checks happen in
    _generate_predictions_for_fit.
    """
    longest_observed = 0
    for batch in data.get("batches", []):
        for condition in batch.get("conditions", []):
            observed = [tp for tp in condition.get("timepoints", []) if tp is not None]
            if observed:
                longest_observed = max(longest_observed, max(observed))
    if longest_observed == 0:
        return {"allowed": False, "reason": "无法确定最大观测时间点"}
    targets = intent.preferences.target_timepoints
    # Fall back to a 24-month horizon when no targets were requested.
    latest_target = max(targets) if targets else 24
    cap = longest_observed * self.ICH_MAX_EXTRAPOLATION_FACTOR
    if latest_target > cap:
        return {
            "allowed": False,
            "reason": f"目标时间点 {latest_target}M 超过ICH允许的最大外推范围 {cap:.1f}M (2x实测)"
        }
    return {"allowed": True}
def _execute_shelf_life_prediction(
    self,
    intent: AnalysisIntent,
    data: Dict[str, Any],
    result: RegulatoryDecisionResult
):
    """Execute shelf-life prediction analysis.

    For each storage condition of the target batch:
      1. pull the primary-CQA series and drop None pairs,
      2. skip the condition if fewer than ICH_MIN_DATAPOINTS remain,
      3. fit a zero-order kinetic model via StabilityCalculator,
      4. gate predictions on fit quality (low R² is tolerated only when
         the slope is essentially zero, i.e. a stable product),
      5. store the fit summary and, when valid, generate predictions.

    Mutates `result` in place (kinetic_fits, predictions,
    regulatory_notes). Raises ValueError when no target batch exists.
    """
    batches = data.get("batches", [])
    target_batch = self._find_target_batch(batches)
    if target_batch is None:
        raise ValueError("未找到可分析的目标批次")
    # Process each condition
    for cond in target_batch.get("conditions", []):
        cond_id = cond.get("condition_id", "unknown")
        timepoints = cond.get("timepoints", [])
        cqa_data = self._find_cqa_data(cond, intent.constraints.primary_cqa)
        if cqa_data is None:
            continue
        values = cqa_data.get("values", [])
        # 1. Check Data Sufficiency
        valid_t, valid_y = self._clean_data(timepoints, values)
        if len(valid_t) < self.ICH_MIN_DATAPOINTS:
            continue
        # 2. Perform Kinetic Fitting
        fit_result = self.calculator.fit_zero_order(valid_t, valid_y)
        # 3. Validation Rules (Expert Review Fix)
        # R2 Check: below the R&D floor the fit is untrustworthy — unless
        # the slope is near zero (a flat, stable series naturally has low R²).
        is_fit_valid = True
        if fit_result.R2 < self.ICH_MIN_R2_RD:
            # Only allow if slope is effectively zero (stable product)
            if abs(fit_result.k) > 1e-4:  # Arbitrary small threshold
                is_fit_valid = False
                result.regulatory_notes.limitations.append(
                    f"条件 {cond_id}: R²={fit_result.R2:.4f} < 0.8,拟合质量差,结果不可信"
                )
        # Negative slope is suspicious for an impurity-type CQA (which
        # should rise over time); assay ("含量") is expected to fall.
        if fit_result.k < 0 and intent.constraints.primary_cqa != "含量":
            result.regulatory_notes.warnings.append(
                f"条件 {cond_id}: 检测到负降解速率 (k={fit_result.k:.4f}),请检查数据异常或质量守恒"
            )
        # Store the fit even when invalid — it is still part of the audit trail.
        result.kinetic_fits[cond_id] = KineticFitSummary(
            condition_id=cond_id,
            model_type="zero_order",
            k=fit_result.k,
            k_unit=fit_result.k_unit,
            y0=fit_result.y0,
            R2=fit_result.R2,
            SE_k=fit_result.SE_k,
            n_points=fit_result.n,
            equation=fit_result.equation,
            confidence_level=self._assess_fit_confidence(fit_result.R2, intent),
            scipy_method="linregress",
            calculation_timestamp=datetime.now().isoformat()
        )
        # Only generate predictions if the fit passed the gates above.
        if is_fit_valid:
            self._generate_predictions_for_fit(intent, result, fit_result, valid_t)
    # Attempt Arrhenius analysis if multiple temperatures available
    self._attempt_arrhenius_analysis(data, result)
def _execute_batch_comparison(
    self,
    intent: AnalysisIntent,
    data: Dict[str, Any],
    result: RegulatoryDecisionResult
):
    """
    Execute batch comparison/ranking analysis with IMPROVED LOGIC.
    Ref: Expert Review - Don't rank T=0 only batches higher than stable trending batches.

    Scoring: batches with real trend data (>=3 points) get a 50-point base
    plus a rate-dependent bonus; T=0-only batches get 10 points; a low
    initial value adds a small secondary bonus.

    Side effects:
      - result.kinetic_fits gains a "{batch_id}_{cond_id}" entry per fit
      - result.batch_ranking is replaced by the score-sorted ranking

    BUGFIX vs original: when every fitted slope was negative, best_k stayed
    float('inf') and the audit reason string rendered as "k=inf"; the
    reason text is now guarded for that case.
    """
    batches = data.get("batches", [])
    rankings = []
    for batch in batches:
        batch_id = batch.get("batch_id", "unknown")
        batch_name = batch.get("batch_name", batch_id)
        score = 0
        reasons = []
        completeness = "single_point"
        # Track the smallest non-negative degradation rate across conditions
        # (impurity-style CQA: smaller positive k is better).
        best_k = float('inf')
        initial_val = float('inf')
        has_trend_data = False
        for cond in batch.get("conditions", []):
            cond_id = cond.get("condition_id", "unknown")
            timepoints = cond.get("timepoints", [])
            cqa_data = self._find_cqa_data(cond, intent.constraints.primary_cqa)
            if not cqa_data:
                continue
            valid_t, valid_y = self._clean_data(timepoints, cqa_data.get("values", []))
            if valid_y:
                initial_val = min(initial_val, valid_y[0])
            if len(valid_t) >= self.ICH_MIN_DATAPOINTS:
                has_trend_data = True
                fit = self.calculator.fit_zero_order(valid_t, valid_y)
                # Persist the fit so Layer 3 can display it.
                result.kinetic_fits[f"{batch_id}_{cond_id}"] = KineticFitSummary(
                    condition_id=f"{batch_id}_{cond_id}",
                    model_type="zero_order",
                    k=fit.k,
                    k_unit=fit.k_unit,
                    y0=fit.y0,
                    R2=fit.R2,
                    SE_k=fit.SE_k,
                    n_points=fit.n,
                    equation=fit.equation,
                    confidence_level=self._assess_fit_confidence(fit.R2, intent),
                    scipy_method="linregress",
                    calculation_timestamp=datetime.now().isoformat()
                )
                if 0 <= fit.k < best_k:  # Prefer positive k (impurity) but smallest
                    best_k = fit.k
        # --- Scoring Logic ---
        confidence = "low"
        if has_trend_data:
            completeness = "full_trend"
            score += 50  # Base score for having trend data
            confidence = "high"
            # Rate k (best_k == inf means no non-negative slope was found,
            # so only the minimal else-branch bonus applies).
            if best_k < 0.05:
                score += 40
            elif best_k < 0.1:
                score += 20
            else:
                score += 10
            if best_k != float('inf'):
                reasons.append(f"具备稳定性趋势数据 (k={best_k:.4f})")
            else:
                # FIX: avoid rendering "k=inf" when all fitted slopes were negative.
                reasons.append("具备稳定性趋势数据 (所有拟合速率为负,无有效非负k)")
        else:
            score += 10  # Single-point batches carry higher uncertainty
            reasons.append("仅有初始数据,缺乏长期趋势,风险较高")
            confidence = "low"
        # Rate Initial Value (Secondary factor)
        if initial_val < intent.constraints.specification_limit * 0.5:
            score += 10
            reasons.append(f"初始杂质低 ({initial_val:.2f}%)")
        rankings.append(BatchRankingItem(
            rank=0,  # assigned after sorting below
            batch_id=batch_id,
            batch_name=batch_name,
            score=score,
            reason="; ".join(reasons),
            k_best=best_k if best_k != float('inf') else None,
            r2_best=None,  # Simplified: per-batch best R² not tracked yet
            data_completeness=completeness,
            confidence=confidence
        ))
    # Highest score wins rank 1.
    rankings.sort(key=lambda x: x.score, reverse=True)
    for i, r in enumerate(rankings):
        r.rank = i + 1
    result.batch_ranking = rankings
def _execute_trend_assessment(self, intent, data, result):
    """Run the shelf-life pipeline for its kinetic fits, then discard the
    forward-looking predictions: trend assessment reports observed
    degradation behavior only, never an extrapolated shelf life."""
    self._execute_shelf_life_prediction(intent, data, result)
    result.predictions = {}  # keep fits; forecasts are out of scope here
def _execute_risk_evaluation(self, intent, data, result):
    """Risk evaluation currently reuses the full shelf-life pipeline: the
    kinetic fits plus CI-based predictions it produces double as the risk
    picture. NOTE(review): no risk-specific post-processing exists yet."""
    self._execute_shelf_life_prediction(intent, data, result)
def _generate_predictions_for_fit(
    self,
    intent: AnalysisIntent,
    result: RegulatoryDecisionResult,
    fit_result: KineticFitResult,
    observed_timepoints: List[float]
):
    """Turn one kinetic fit into per-timepoint predictions with CI bounds.

    Each prediction beyond 2x the latest observed timepoint is marked
    invalid (ICH extrapolation cap) but still stored for the audit trail,
    and every prediction is logged into result.calculation_trace.
    """
    latest_observed = max(observed_timepoints)
    extrapolation_cap = latest_observed * self.ICH_MAX_EXTRAPOLATION_FACTOR
    computed = self.calculator.predict_with_ci(
        fit_result=fit_result,
        target_times=intent.preferences.target_timepoints,
        specification_limit=intent.constraints.specification_limit,
        confidence=intent.preferences.required_confidence
    )
    for pred in computed:
        key = f"{pred.timepoint}M"
        within_limit = pred.timepoint <= extrapolation_cap
        reason = "" if within_limit else (
            f"外推时间 ({pred.timepoint}M) 超过实测范围 ({latest_observed}M) 的2倍"
        )
        result.predictions[key] = PredictionSummary(
            timepoint_months=pred.timepoint,
            point_estimate=pred.point_estimate,
            CI_lower=pred.CI_lower,
            CI_upper=pred.CI_upper,
            risk_level=self._map_risk_level(pred.risk_level),
            specification_limit=pred.specification_limit,
            margin_to_limit=pred.specification_limit - pred.CI_upper,
            is_valid=within_limit,
            validity_reason=reason
        )
        result.calculation_trace.add(
            step=f"prediction_{key}",
            inputs={"timepoint": pred.timepoint, "k": fit_result.k},
            outputs={"point": pred.point_estimate, "valid": within_limit},
            method="linear_prediction"
        )
# --- Helper Methods ---
def _find_target_batch(self, batches: List[Dict]) -> Optional[Dict]:
    """Prefer the batch tagged batch_type == "target"; otherwise fall back
    to the first batch, or None when the list is empty."""
    tagged = next((b for b in batches if b.get("batch_type") == "target"), None)
    if tagged is not None:
        return tagged
    return batches[0] if batches else None
def _find_cqa_data(self, cond: Dict, cqa_name: str) -> Optional[Dict]:
    """Return the CQA record matching cqa_name in this condition, or None."""
    return next(
        (entry for entry in cond.get("cqa_data", []) if entry.get("cqa_name") == cqa_name),
        None,
    )
def _clean_data(self, t: List, y: List) -> Tuple[List[float], List[float]]:
    """Drop every (time, value) pair where either member is None and cast
    the survivors to float. Extra entries in the longer list are ignored
    (zip semantics)."""
    pairs = [
        (float(ti), float(yi))
        for ti, yi in zip(t, y)
        if ti is not None and yi is not None
    ]
    if not pairs:
        return [], []
    times, values = zip(*pairs)
    return list(times), list(values)
def _assess_fit_confidence(self, r2: float, intent: AnalysisIntent) -> str:
    """Map an R² value onto a qualitative band: >=0.9 high, >=0.8 medium,
    else low. (`intent` is accepted for interface symmetry but unused.)"""
    if r2 >= 0.9:
        return "high"
    return "medium" if r2 >= 0.8 else "low"
def _map_risk_level(self, risk_str: str) -> str:
    """Normalize a free-form risk label into one of
    compliant / marginal / non_compliant (checked in that order)."""
    label = risk_str.lower()
    if any(token in label for token in ("compliant", "合格", "low")):
        return "compliant"
    if "marginal" in label or "临界" in label:
        return "marginal"
    return "non_compliant"
def _create_refusal(
    self,
    severity: RefusalSeverity,
    reason: str,
    reference: str,
    suggestions: Optional[List[str]] = None,
    missing: Optional[List[str]] = None,
) -> RefusalResult:
    """Build a RefusalResult, normalizing None list arguments to [].

    FIX: `suggestions` and `missing` were annotated `List[str]` while
    defaulting to None; the annotations now say Optional[List[str]].
    Behavior is unchanged.
    """
    return RefusalResult(
        severity=severity,
        reason=reason,
        regulatory_reference=reference,
        suggestions=suggestions or [],
        missing_data=missing or [],
    )
def _add_regulatory_notes(self, intent, result):
    """Attach boilerplate statistical-method wording and, when any valid
    prediction exists, an extrapolation statement naming the furthest
    valid timepoint. Mutates result.regulatory_notes in place."""
    result.regulatory_notes.statistical_method_statement = (
        "线性回归采用最小二乘法 (scipy.stats.linregress),"
        "置信区间基于t分布计算 (scipy.stats.t.ppf)。"
        "所有预测均经过ICH Q1E规则有效性验证。"
    )
    if result.predictions:
        # BUGFIX: max() over an empty generator raised ValueError when
        # predictions existed but all were filtered as invalid; default=0
        # preserves the old "no statement" behavior in that case.
        max_pred = max(
            (p.timepoint_months for p in result.predictions.values() if p.is_valid),
            default=0,
        )
        if max_pred > 0:
            result.regulatory_notes.extrapolation_statement = f"本报告包含有效外推预测至 {max_pred} 个月。"
def _attempt_arrhenius_analysis(self, data, result):
    """Attempt Arrhenius (multi-temperature) analysis.

    Currently an intentional no-op placeholder; callers tolerate it doing
    nothing. NOTE(review): the original comment says it mirrors a previous
    implementation (presumably fitting ln(k) vs 1/T across accelerated
    conditions into result's ArrheniusResult) — confirm before relying on it.
    """
    # Placeholder for brevity - similar to previous implementation
    pass