Spaces:
Sleeping
Sleeping
| """Ripeness classification accuracy tracking and reporting. | |
| Tracks predictions and actual outcomes to measure false positive/negative rates | |
| and enable data-driven threshold calibration. | |
| """ | |
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| from datetime import datetime | |
| from pathlib import Path | |
| from typing import Optional | |
| import pandas as pd | |
| from src.core.ripeness import RipenessStatus | |
@dataclass
class RipenessPrediction:
    """Single ripeness classification prediction and its eventual outcome.

    The prediction fields are set when the classification is made; the
    outcome fields stay ``None`` until the hearing result is recorded.
    """

    # Prediction (set at classification time).
    case_id: str
    predicted_status: RipenessStatus
    prediction_date: datetime

    # Actual outcome (filled in after hearing).
    actual_outcome: Optional[str] = None
    was_adjourned: Optional[bool] = None
    outcome_date: Optional[datetime] = None
class RipenessMetrics:
    """Tracks ripeness classification accuracy for feedback loop calibration.

    Workflow: record each classification with :meth:`record_prediction`, then
    attach the real hearing outcome with :meth:`record_outcome`.  Completed
    prediction/outcome pairs drive the accuracy metrics, confusion matrix,
    and the on-disk report files.
    """

    def __init__(self) -> None:
        """Initialize an empty metrics tracker."""
        # Pending predictions (no outcome yet), keyed by case_id.
        self.predictions: dict[str, RipenessPrediction] = {}
        # Predictions whose hearing outcome has been recorded.
        self.completed_predictions: list[RipenessPrediction] = []

    def record_prediction(
        self,
        case_id: str,
        predicted_status: RipenessStatus,
        prediction_date: datetime,
    ) -> None:
        """Record a ripeness classification prediction.

        A second prediction for the same case_id overwrites the pending one.

        Args:
            case_id: Case identifier
            predicted_status: Predicted ripeness status
            prediction_date: When prediction was made
        """
        self.predictions[case_id] = RipenessPrediction(
            case_id=case_id,
            predicted_status=predicted_status,
            prediction_date=prediction_date,
        )

    def record_outcome(
        self,
        case_id: str,
        actual_outcome: str,
        was_adjourned: bool,
        outcome_date: datetime,
    ) -> None:
        """Record actual hearing outcome for a predicted case.

        Outcomes for case_ids without a pending prediction are silently
        ignored (the prediction may never have been recorded).

        Args:
            case_id: Case identifier
            actual_outcome: Actual hearing outcome (e.g., "ADJOURNED", "ARGUMENTS")
            was_adjourned: Whether hearing was adjourned
            outcome_date: When outcome occurred
        """
        # Atomically remove the pending entry; single lookup instead of
        # check-then-index-then-delete.
        pred = self.predictions.pop(case_id, None)
        if pred is None:
            return
        pred.actual_outcome = actual_outcome
        pred.was_adjourned = was_adjourned
        pred.outcome_date = outcome_date
        self.completed_predictions.append(pred)

    def get_accuracy_metrics(self) -> dict[str, float]:
        """Compute classification accuracy metrics.

        Returns:
            Dictionary with accuracy metrics:
                - total_predictions: Total predictions made (pending + completed)
                - completed_predictions: Predictions with outcomes
                - false_positive_rate: RIPE cases that adjourned
                - false_negative_rate: UNRIPE cases that progressed
                - unknown_rate: Cases classified as UNKNOWN
                - ripe_precision: P(progressed | predicted RIPE)
                - unripe_recall: P(predicted UNRIPE | adjourned)
        """
        if not self.completed_predictions:
            # No outcomes yet: report zeros rather than dividing by zero.
            return {
                "total_predictions": 0,
                "completed_predictions": 0,
                "false_positive_rate": 0.0,
                "false_negative_rate": 0.0,
                "unknown_rate": 0.0,
                "ripe_precision": 0.0,
                "unripe_recall": 0.0,
            }

        total = len(self.completed_predictions)

        # Partition completed predictions by predicted status.
        ripe_predictions = [
            p
            for p in self.completed_predictions
            if p.predicted_status == RipenessStatus.RIPE
        ]
        unripe_predictions = [
            p for p in self.completed_predictions if p.predicted_status.is_unripe()
        ]
        unknown_predictions = [
            p
            for p in self.completed_predictions
            if p.predicted_status == RipenessStatus.UNKNOWN
        ]

        # Actually adjourned cases (denominator for UNRIPE recall).
        adjourned_cases = [p for p in self.completed_predictions if p.was_adjourned]

        # False positives: predicted RIPE but adjourned.
        false_positives = [p for p in ripe_predictions if p.was_adjourned]
        false_positive_rate = (
            len(false_positives) / len(ripe_predictions) if ripe_predictions else 0.0
        )

        # False negatives: predicted UNRIPE but progressed.
        false_negatives = [p for p in unripe_predictions if not p.was_adjourned]
        false_negative_rate = (
            len(false_negatives) / len(unripe_predictions)
            if unripe_predictions
            else 0.0
        )

        # Precision: of predicted RIPE, how many progressed?
        ripe_correct = [p for p in ripe_predictions if not p.was_adjourned]
        ripe_precision = (
            len(ripe_correct) / len(ripe_predictions) if ripe_predictions else 0.0
        )

        # Recall: of actually adjourned cases, how many did we predict UNRIPE?
        unripe_correct = [p for p in unripe_predictions if p.was_adjourned]
        unripe_recall = (
            len(unripe_correct) / len(adjourned_cases) if adjourned_cases else 0.0
        )

        return {
            "total_predictions": total + len(self.predictions),
            "completed_predictions": total,
            "false_positive_rate": false_positive_rate,
            "false_negative_rate": false_negative_rate,
            "unknown_rate": len(unknown_predictions) / total,
            "ripe_precision": ripe_precision,
            "unripe_recall": unripe_recall,
        }

    def get_confusion_matrix(self) -> dict[str, dict[str, int]]:
        """Generate confusion matrix of predictions vs outcomes.

        UNRIPE collapses every ``is_unripe()`` status into one row; anything
        that is neither RIPE nor unripe counts as UNKNOWN.

        Returns:
            Nested dict: predicted_status -> actual_outcome -> count
        """
        matrix: dict[str, dict[str, int]] = {
            "RIPE": {"progressed": 0, "adjourned": 0},
            "UNRIPE": {"progressed": 0, "adjourned": 0},
            "UNKNOWN": {"progressed": 0, "adjourned": 0},
        }
        for pred in self.completed_predictions:
            if pred.predicted_status == RipenessStatus.RIPE:
                key = "RIPE"
            elif pred.predicted_status.is_unripe():
                key = "UNRIPE"
            else:
                key = "UNKNOWN"
            outcome_key = "adjourned" if pred.was_adjourned else "progressed"
            matrix[key][outcome_key] += 1
        return matrix

    def to_dataframe(self) -> pd.DataFrame:
        """Export completed predictions to DataFrame for analysis.

        Pending predictions (no outcome yet) are excluded.

        Returns:
            DataFrame with columns: case_id, predicted_status, prediction_date,
            actual_outcome, was_adjourned, outcome_date, correct_prediction
        """
        records = []
        for pred in self.completed_predictions:
            # A prediction is "correct" when RIPE progressed or UNRIPE adjourned;
            # UNKNOWN predictions are always marked incorrect here.
            correct = (
                pred.predicted_status == RipenessStatus.RIPE
                and not pred.was_adjourned
            ) or (pred.predicted_status.is_unripe() and pred.was_adjourned)
            records.append(
                {
                    "case_id": pred.case_id,
                    "predicted_status": pred.predicted_status.value,
                    "prediction_date": pred.prediction_date,
                    "actual_outcome": pred.actual_outcome,
                    "was_adjourned": pred.was_adjourned,
                    "outcome_date": pred.outcome_date,
                    "correct_prediction": correct,
                }
            )
        return pd.DataFrame(records)

    def save_report(self, output_path: Path) -> None:
        """Save accuracy report and predictions to files.

        Writes ripeness_accuracy.csv, ripeness_confusion_matrix.csv,
        ripeness_predictions.csv (only when outcomes exist), and a
        human-readable ripeness_report.txt into the output directory.

        Args:
            output_path: Path to output directory (created if missing)
        """
        output_path.mkdir(parents=True, exist_ok=True)

        # Save metrics summary
        metrics = self.get_accuracy_metrics()
        metrics_df = pd.DataFrame([metrics])
        metrics_df.to_csv(output_path / "ripeness_accuracy.csv", index=False)

        # Save confusion matrix (transposed so predicted statuses are rows).
        matrix = self.get_confusion_matrix()
        matrix_df = pd.DataFrame(matrix).T
        matrix_df.to_csv(output_path / "ripeness_confusion_matrix.csv")

        # Save detailed predictions
        if self.completed_predictions:
            predictions_df = self.to_dataframe()
            predictions_df.to_csv(output_path / "ripeness_predictions.csv", index=False)

        # Generate human-readable report
        report_lines = [
            "Ripeness Classification Accuracy Report",
            "=" * 60,
            f"Total predictions: {metrics['total_predictions']}",
            f"Completed predictions: {metrics['completed_predictions']}",
            "",
            "Accuracy Metrics:",
            f"  False positive rate (RIPE but adjourned): {metrics['false_positive_rate']:.1%}",
            f"  False negative rate (UNRIPE but progressed): {metrics['false_negative_rate']:.1%}",
            f"  UNKNOWN rate: {metrics['unknown_rate']:.1%}",
            f"  RIPE precision (progressed | predicted RIPE): {metrics['ripe_precision']:.1%}",
            f"  UNRIPE recall (predicted UNRIPE | adjourned): {metrics['unripe_recall']:.1%}",
            "",
            "Confusion Matrix:",
            f"  RIPE -> Progressed: {matrix['RIPE']['progressed']}, Adjourned: {matrix['RIPE']['adjourned']}",
            f"  UNRIPE -> Progressed: {matrix['UNRIPE']['progressed']}, Adjourned: {matrix['UNRIPE']['adjourned']}",
            f"  UNKNOWN -> Progressed: {matrix['UNKNOWN']['progressed']}, Adjourned: {matrix['UNKNOWN']['adjourned']}",
            "",
            "Interpretation:",
        ]

        # Threshold-based calibration hints; tuning knob names refer to the
        # classifier's configuration (MIN_SERVICE_HEARINGS, MIN_STAGE_DAYS).
        if metrics["false_positive_rate"] > 0.20:
            report_lines.append(
                "  - HIGH false positive rate: Consider increasing MIN_SERVICE_HEARINGS"
            )
        if metrics["false_negative_rate"] > 0.15:
            report_lines.append(
                "  - HIGH false negative rate: Consider decreasing MIN_STAGE_DAYS"
            )
        if metrics["unknown_rate"] < 0.05:
            report_lines.append(
                "  - LOW UNKNOWN rate: System may be overconfident, add uncertainty"
            )
        if metrics["ripe_precision"] > 0.85:
            report_lines.append(
                "  - GOOD RIPE precision: Most RIPE predictions are correct"
            )
        if metrics["unripe_recall"] < 0.60:
            report_lines.append(
                "  - LOW UNRIPE recall: Missing many bottlenecks, refine detection"
            )

        report_text = "\n".join(report_lines)
        (output_path / "ripeness_report.txt").write_text(report_text)

        print(f"Ripeness accuracy report saved to {output_path}")