# hackathon_code4change/src/monitoring/ripeness_metrics.py
"""Ripeness classification accuracy tracking and reporting.
Tracks predictions and actual outcomes to measure false positive/negative rates
and enable data-driven threshold calibration.
"""
from __future__ import annotations

from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Optional

import pandas as pd

from src.core.ripeness import RipenessStatus
@dataclass
class RipenessPrediction:
"""Single ripeness classification prediction and outcome."""
case_id: str
predicted_status: RipenessStatus
prediction_date: datetime
# Actual outcome (filled in after hearing)
actual_outcome: Optional[str] = None
was_adjourned: Optional[bool] = None
outcome_date: Optional[datetime] = None
class RipenessMetrics:
"""Tracks ripeness classification accuracy for feedback loop calibration."""
    def __init__(self) -> None:
        """Initialize metrics tracker."""
        # Pending predictions keyed by case_id, awaiting a recorded outcome.
        self.predictions: dict[str, RipenessPrediction] = {}
        # Predictions whose hearing outcome has been recorded.
        self.completed_predictions: list[RipenessPrediction] = []
def record_prediction(
self,
case_id: str,
predicted_status: RipenessStatus,
prediction_date: datetime,
) -> None:
"""Record a ripeness classification prediction.
Args:
case_id: Case identifier
predicted_status: Predicted ripeness status
prediction_date: When prediction was made
"""
self.predictions[case_id] = RipenessPrediction(
case_id=case_id,
predicted_status=predicted_status,
prediction_date=prediction_date,
)
def record_outcome(
self,
case_id: str,
actual_outcome: str,
was_adjourned: bool,
outcome_date: datetime,
) -> None:
"""Record actual hearing outcome for a predicted case.
Args:
case_id: Case identifier
actual_outcome: Actual hearing outcome (e.g., "ADJOURNED", "ARGUMENTS")
was_adjourned: Whether hearing was adjourned
outcome_date: When outcome occurred
"""
if case_id in self.predictions:
pred = self.predictions[case_id]
pred.actual_outcome = actual_outcome
pred.was_adjourned = was_adjourned
pred.outcome_date = outcome_date
# Move to completed
self.completed_predictions.append(pred)
del self.predictions[case_id]
def get_accuracy_metrics(self) -> dict[str, float]:
"""Compute classification accuracy metrics.
Returns:
Dictionary with accuracy metrics:
- total_predictions: Total predictions made
- completed_predictions: Predictions with outcomes
- false_positive_rate: RIPE cases that adjourned
- false_negative_rate: UNRIPE cases that progressed
- unknown_rate: Cases classified as UNKNOWN
- ripe_precision: P(progressed | predicted RIPE)
- unripe_recall: P(predicted UNRIPE | adjourned)
"""
if not self.completed_predictions:
return {
"total_predictions": 0,
"completed_predictions": 0,
"false_positive_rate": 0.0,
"false_negative_rate": 0.0,
"unknown_rate": 0.0,
"ripe_precision": 0.0,
"unripe_recall": 0.0,
}
total = len(self.completed_predictions)
# Count predictions by status
ripe_predictions = [
p
for p in self.completed_predictions
if p.predicted_status == RipenessStatus.RIPE
]
unripe_predictions = [
p for p in self.completed_predictions if p.predicted_status.is_unripe()
]
unknown_predictions = [
p
for p in self.completed_predictions
if p.predicted_status == RipenessStatus.UNKNOWN
]
# Count actual outcomes
adjourned_cases = [p for p in self.completed_predictions if p.was_adjourned]
[p for p in self.completed_predictions if not p.was_adjourned]
# False positives: predicted RIPE but adjourned
false_positives = [p for p in ripe_predictions if p.was_adjourned]
false_positive_rate = (
len(false_positives) / len(ripe_predictions) if ripe_predictions else 0.0
)
# False negatives: predicted UNRIPE but progressed
false_negatives = [p for p in unripe_predictions if not p.was_adjourned]
false_negative_rate = (
len(false_negatives) / len(unripe_predictions)
if unripe_predictions
else 0.0
)
# Precision: of predicted RIPE, how many progressed?
ripe_correct = [p for p in ripe_predictions if not p.was_adjourned]
ripe_precision = (
len(ripe_correct) / len(ripe_predictions) if ripe_predictions else 0.0
)
# Recall: of actually adjourned cases, how many did we predict UNRIPE?
unripe_correct = [p for p in unripe_predictions if p.was_adjourned]
unripe_recall = (
len(unripe_correct) / len(adjourned_cases) if adjourned_cases else 0.0
)
return {
"total_predictions": total + len(self.predictions),
"completed_predictions": total,
"false_positive_rate": false_positive_rate,
"false_negative_rate": false_negative_rate,
"unknown_rate": len(unknown_predictions) / total,
"ripe_precision": ripe_precision,
"unripe_recall": unripe_recall,
}
def get_confusion_matrix(self) -> dict[str, dict[str, int]]:
"""Generate confusion matrix of predictions vs outcomes.
Returns:
Nested dict: predicted_status -> actual_outcome -> count
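
        Example shape (illustrative counts only):
            {"RIPE": {"progressed": 10, "adjourned": 2}, ...}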
"""
matrix: dict[str, dict[str, int]] = {
"RIPE": {"progressed": 0, "adjourned": 0},
"UNRIPE": {"progressed": 0, "adjourned": 0},
"UNKNOWN": {"progressed": 0, "adjourned": 0},
}
for pred in self.completed_predictions:
if pred.predicted_status == RipenessStatus.RIPE:
key = "RIPE"
elif pred.predicted_status.is_unripe():
key = "UNRIPE"
else:
key = "UNKNOWN"
outcome_key = "adjourned" if pred.was_adjourned else "progressed"
matrix[key][outcome_key] += 1
return matrix
def to_dataframe(self) -> pd.DataFrame:
"""Export predictions to DataFrame for analysis.
Returns:
DataFrame with columns: case_id, predicted_status, prediction_date,
actual_outcome, was_adjourned, outcome_date
"""
records = []
for pred in self.completed_predictions:
records.append(
{
"case_id": pred.case_id,
"predicted_status": pred.predicted_status.value,
"prediction_date": pred.prediction_date,
"actual_outcome": pred.actual_outcome,
"was_adjourned": pred.was_adjourned,
"outcome_date": pred.outcome_date,
"correct_prediction": (
(
pred.predicted_status == RipenessStatus.RIPE
and not pred.was_adjourned
)
or (pred.predicted_status.is_unripe() and pred.was_adjourned)
),
}
)
return pd.DataFrame(records)
def save_report(self, output_path: Path) -> None:
"""Save accuracy report and predictions to files.
Args:
output_path: Path to output directory
"""
output_path.mkdir(parents=True, exist_ok=True)
# Save metrics summary
metrics = self.get_accuracy_metrics()
metrics_df = pd.DataFrame([metrics])
metrics_df.to_csv(output_path / "ripeness_accuracy.csv", index=False)
# Save confusion matrix
matrix = self.get_confusion_matrix()
matrix_df = pd.DataFrame(matrix).T
matrix_df.to_csv(output_path / "ripeness_confusion_matrix.csv")
# Save detailed predictions
if self.completed_predictions:
predictions_df = self.to_dataframe()
            predictions_df.to_csv(
                output_path / "ripeness_predictions.csv", index=False
            )
# Generate human-readable report
report_lines = [
"Ripeness Classification Accuracy Report",
"=" * 60,
f"Total predictions: {metrics['total_predictions']}",
f"Completed predictions: {metrics['completed_predictions']}",
"",
"Accuracy Metrics:",
f" False positive rate (RIPE but adjourned): {metrics['false_positive_rate']:.1%}",
f" False negative rate (UNRIPE but progressed): {metrics['false_negative_rate']:.1%}",
f" UNKNOWN rate: {metrics['unknown_rate']:.1%}",
f" RIPE precision (progressed | predicted RIPE): {metrics['ripe_precision']:.1%}",
f" UNRIPE recall (predicted UNRIPE | adjourned): {metrics['unripe_recall']:.1%}",
"",
"Confusion Matrix:",
f" RIPE -> Progressed: {matrix['RIPE']['progressed']}, Adjourned: {matrix['RIPE']['adjourned']}",
f" UNRIPE -> Progressed: {matrix['UNRIPE']['progressed']}, Adjourned: {matrix['UNRIPE']['adjourned']}",
f" UNKNOWN -> Progressed: {matrix['UNKNOWN']['progressed']}, Adjourned: {matrix['UNKNOWN']['adjourned']}",
"",
"Interpretation:",
]
# Add interpretation
if metrics["false_positive_rate"] > 0.20:
report_lines.append(
" - HIGH false positive rate: Consider increasing MIN_SERVICE_HEARINGS"
)
if metrics["false_negative_rate"] > 0.15:
report_lines.append(
" - HIGH false negative rate: Consider decreasing MIN_STAGE_DAYS"
)
if metrics["unknown_rate"] < 0.05:
report_lines.append(
" - LOW UNKNOWN rate: System may be overconfident, add uncertainty"
)
if metrics["ripe_precision"] > 0.85:
report_lines.append(
" - GOOD RIPE precision: Most RIPE predictions are correct"
)
if metrics["unripe_recall"] < 0.60:
report_lines.append(
" - LOW UNRIPE recall: Missing many bottlenecks, refine detection"
)
report_text = "\n".join(report_lines)
(output_path / "ripeness_report.txt").write_text(report_text)
print(f"Ripeness accuracy report saved to {output_path}")