# Spaces: Running on Zero
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import numpy as np
@dataclass
class PredictionAnalysis:
    """Analysis results for a set of predictions.

    Produced by ``analyze_predictions``; consumed by
    ``format_analysis_report``.
    """

    total_samples: int  # number of samples analyzed
    correct: int  # count of predictions matching labels
    incorrect: int  # count of mismatches (total_samples - correct)
    accuracy: float  # correct / total_samples (0.0 when empty)
    by_class: Dict[str, Dict[str, int]]  # {class: {correct, incorrect, total}}
    by_uncertainty: Dict[str, Dict[str, float]]  # {low/medium/high: {accuracy, count}}
    high_confidence_errors: List[Dict[str, Any]]  # Samples with high confidence but wrong
    low_confidence_correct: List[Dict[str, Any]]  # Samples with low confidence but correct
def analyze_predictions(
    predictions: np.ndarray,
    labels: np.ndarray,
    probabilities: Optional[np.ndarray] = None,
    uncertainties: Optional[np.ndarray] = None,
    class_names: Optional[List[str]] = None,
) -> PredictionAnalysis:
    """Detailed analysis of model predictions.

    Args:
        predictions: 1-D array of predicted class indices.
        labels: 1-D array of ground-truth class indices, same length.
        probabilities: optional per-sample class probabilities; assumed
            (n_samples, n_classes) since confidence is ``max(axis=1)``.
            Enables the high/low-confidence sample listings.
        uncertainties: optional per-sample uncertainty scores; bucketed at
            0.3 and 0.7 into low/medium/high. Enables the by-uncertainty
            accuracy breakdown.
        class_names: display name per class index. Defaults to the
            project's three labels ["clean", "suspicious", "cheating"],
            preserving the previous hard-coded behavior.

    Returns:
        A populated PredictionAnalysis with plain builtin ints/floats
        (numpy scalars are cast so they match the dataclass field types).
    """
    if class_names is None:
        class_names = ["clean", "suspicious", "cheating"]

    correct_mask = predictions == labels

    # Basic stats — cast numpy scalars to builtins up front.
    total = len(predictions)
    correct = int(correct_mask.sum())
    incorrect = total - correct
    accuracy = correct / total if total > 0 else 0.0

    # Per-class correct/incorrect/total counts.
    by_class = {}
    for c, class_name in enumerate(class_names):
        class_mask = labels == c
        by_class[class_name] = {
            "correct": int((correct_mask & class_mask).sum()),
            "incorrect": int((~correct_mask & class_mask).sum()),
            "total": int(class_mask.sum()),
        }

    # Accuracy within uncertainty buckets (only if uncertainties given).
    by_uncertainty = {}
    if uncertainties is not None:
        low_mask = uncertainties < 0.3
        med_mask = (uncertainties >= 0.3) & (uncertainties < 0.7)
        high_mask = uncertainties >= 0.7
        for name, mask in [("low", low_mask), ("medium", med_mask), ("high", high_mask)]:
            count = int(mask.sum())
            if count > 0:  # skip empty buckets: mean() of empty is NaN
                by_uncertainty[name] = {
                    "accuracy": float(correct_mask[mask].mean()),
                    "count": count,
                }

    # Error analysis: confidently-wrong and hesitantly-right samples.
    high_conf_errors = []
    low_conf_correct = []
    if probabilities is not None:
        confidences = probabilities.max(axis=1)

        def _sample(idx: int) -> Dict[str, Any]:
            # Shared record shape for both listings below.
            return {
                "idx": int(idx),
                "predicted": int(predictions[idx]),
                "actual": int(labels[idx]),
                "confidence": float(confidences[idx]),
            }

        # First 10 samples that were wrong despite >0.9 confidence.
        high_conf_wrong = (~correct_mask) & (confidences > 0.9)
        high_conf_errors = [_sample(i) for i in np.where(high_conf_wrong)[0][:10]]

        # First 10 samples that were right despite <0.5 confidence.
        low_conf_right = correct_mask & (confidences < 0.5)
        low_conf_correct = [_sample(i) for i in np.where(low_conf_right)[0][:10]]

    return PredictionAnalysis(
        total_samples=total,
        correct=correct,
        incorrect=incorrect,
        accuracy=accuracy,
        by_class=by_class,
        by_uncertainty=by_uncertainty,
        high_confidence_errors=high_conf_errors,
        low_confidence_correct=low_conf_correct,
    )
def compute_feature_importance(
    model_outputs: Dict[str, np.ndarray],
    method: str = "gradient",
) -> Dict[str, float]:
    """Estimate per-feature importance scores.

    Currently a stub: the inputs are ignored and a fixed placeholder
    mapping is returned. A real implementation would derive scores from
    *model_outputs* according to *method* (e.g. gradient attribution).
    """
    # TODO: replace placeholder with a gradient-based implementation.
    return {"placeholder": 1.0}
def format_analysis_report(analysis: PredictionAnalysis) -> str:
    """Render *analysis* as a human-readable, multi-line report string."""
    rule = "=" * 50
    out = [
        rule,
        "PREDICTION ANALYSIS REPORT",
        rule,
        f"Total Samples: {analysis.total_samples}",
        f"Accuracy: {analysis.accuracy:.4f} "
        f"({analysis.correct}/{analysis.total_samples})",
        "",
        "By Class:",
    ]

    # Per-class accuracy; guard against empty classes (0/0).
    for name, stats in analysis.by_class.items():
        n = stats["total"]
        hit_rate = stats["correct"] / n if n > 0 else 0
        out.append(f"  {name}: {hit_rate:.4f} ({stats['correct']}/{n})")

    # Uncertainty-bucket section only when data is present.
    if analysis.by_uncertainty:
        out.extend(["", "By Uncertainty:"])
        for level, stats in analysis.by_uncertainty.items():
            out.append(f"  {level}: acc={stats['accuracy']:.4f}, n={stats['count']}")

    # Show at most 3 of the collected high-confidence errors.
    if analysis.high_confidence_errors:
        out.extend(["", f"High Confidence Errors ({len(analysis.high_confidence_errors)}):"])
        for err in analysis.high_confidence_errors[:3]:
            out.append(
                f"  idx={err['idx']}: pred={err['predicted']}, "
                f"actual={err['actual']}, conf={err['confidence']:.3f}"
            )

    out.append(rule)
    return "\n".join(out)