| |
| """ |
| Calibration analysis for MITI 4.2 Not Coded Classifier |
| |
| This script analyzes model calibration and finds optimal thresholds for different use cases. |
| It performs: |
| 1. Probability calibration assessment |
| 2. ROC curve analysis |
| 3. Precision-Recall curve analysis |
| 4. Optimal threshold finding for various metrics |
| 5. Per-annotator threshold analysis (if applicable) |
| |
| Usage: |
| python calibration_analysis.py |
| """ |
|
|
| import json |
| import numpy as np |
| import torch |
| from sklearn.calibration import calibration_curve |
| from sklearn.metrics import ( |
| roc_curve, auc, precision_recall_curve, average_precision_score, |
| accuracy_score, precision_recall_fscore_support, confusion_matrix |
| ) |
| import matplotlib.pyplot as plt |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification |
| from datasets import Dataset |
| from typing import Dict, List, Tuple |
| import os |
|
|
|
|
def load_model_and_data(
    model_name: str = "Lekhansh/qwen_nc_classifier",
    data_path: str = "multilabel_classifier_dataset.json"
):
    """Load the classifier, its tokenizer, and the raw multilabel dataset.

    Args:
        model_name: HuggingFace model name or local path
        data_path: Path to multilabel dataset JSON file

    Returns:
        (model, tokenizer, data) — model is set to eval mode; data is the
        parsed JSON content of `data_path`.
    """
    print(f"Loading model from {model_name}...")
    tok = AutoTokenizer.from_pretrained(model_name)
    # bf16 + flash-attention-2, sharded automatically across available devices.
    clf = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        device_map="auto"
    )
    clf.eval()

    print(f"Loading data from {data_path}...")
    with open(data_path, 'r') as fh:
        dataset = json.load(fh)

    return clf, tok, dataset
|
|
|
|
def reframe_as_binary(data: List[Dict], invert_labels: bool = True) -> List[Dict]:
    """Convert multilabel examples into binary ones (mirrors the training script).

    Each example's text gets a fixed task prefix plus, when available, an
    "Annotated by" line. With invert_labels=True the 'not_coded' flag is
    flipped, so label 1 means the utterance was coded.
    """
    task_prefix = "Task: Decide if the last therapist utterance should be coded or not.\\n"

    converted = []
    for item in data:
        annotator = item.get('annotated_by', '')
        annotator_line = f"Annotated by: {annotator}\\n" if annotator else ""

        raw = item['not_coded']
        final_label = (1 - raw) if invert_labels else raw

        converted.append({
            'text': task_prefix + annotator_line + item['input'],
            'label': final_label,
            'annotated_by': annotator,
            'seq_id': item['seq_id'],
            'id': item['id'],
            'unique_id': item['unique_id']
        })

    return converted
|
|
|
|
def split_data(data: List[Dict], random_state=42):
    """Split into train/val/test (80/10/10), stratified on 'label'.

    Uses the same random_state as training so the splits are reproducible.
    """
    from sklearn.model_selection import train_test_split

    labels = [item['label'] for item in data]
    train_part, holdout = train_test_split(
        data, train_size=0.8, random_state=random_state,
        stratify=labels
    )

    # Split the remaining 20% evenly into validation and test.
    holdout_labels = [item['label'] for item in holdout]
    val_part, test_part = train_test_split(
        holdout, train_size=0.5, random_state=random_state,
        stratify=holdout_labels
    )

    return train_part, val_part, test_part
|
|
|
|
def get_predictions(model, tokenizer, data: List[Dict], max_length=3000, batch_size=16):
    """Run the model over `data` in batches.

    Returns:
        probs: numpy array of P(class 1, "coded") per example
        labels: numpy array of gold labels
        annotators: list of per-example annotator ids ('' when absent)
    """
    print(f"Getting predictions for {len(data)} examples...")

    probs_out = []
    labels_out = []
    annotators_out = []

    for start in range(0, len(data), batch_size):
        chunk = data[start:start + batch_size]

        encoded = tokenizer(
            [ex['text'] for ex in chunk],
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt"
        )
        # Move tensors to wherever the model lives.
        encoded = {name: tensor.to(model.device) for name, tensor in encoded.items()}

        with torch.no_grad():
            logits = model(**encoded).logits
        # Softmax over the two classes; keep column 1 (the "coded" class).
        batch_probs = torch.softmax(logits, dim=1)[:, 1].float().cpu().numpy()

        probs_out.extend(batch_probs)
        labels_out.extend(ex['label'] for ex in chunk)
        annotators_out.extend(ex.get('annotated_by', '') for ex in chunk)

    return np.array(probs_out), np.array(labels_out), annotators_out
|
|
|
|
def plot_calibration_curve(y_true, y_prob, n_bins=10, save_path="calibration_curve.png"):
    """Plot a reliability diagram (mean predicted vs. observed positive rate)."""
    frac_positives, mean_predicted = calibration_curve(
        y_true, y_prob, n_bins=n_bins, strategy='uniform'
    )

    plt.figure(figsize=(10, 8))
    plt.plot(mean_predicted, frac_positives, marker='o', linewidth=2, label='Model')
    # The diagonal is where a perfectly calibrated model would sit.
    plt.plot([0, 1], [0, 1], linestyle='--', label='Perfect calibration', color='gray')

    plt.xlabel('Mean predicted probability', fontsize=12)
    plt.ylabel('Fraction of positives (True probability)', fontsize=12)
    plt.title('Calibration Curve - MITI Not Coded Classifier', fontsize=14)
    plt.legend(fontsize=11)
    plt.grid(alpha=0.3)
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"Calibration curve saved to {save_path}")
    plt.close()
|
|
|
|
def plot_roc_curve(y_true, y_prob, save_path="roc_curve.png"):
    """Plot the ROC curve; return (fpr, tpr, thresholds, roc_auc)."""
    fpr, tpr, thresholds = roc_curve(y_true, y_prob)
    area = auc(fpr, tpr)

    plt.figure(figsize=(10, 8))
    plt.plot(fpr, tpr, linewidth=2, label=f'ROC curve (AUC = {area:.4f})')
    # A random classifier traces the diagonal.
    plt.plot([0, 1], [0, 1], linestyle='--', color='gray', label='Random classifier')

    plt.xlabel('False Positive Rate', fontsize=12)
    plt.ylabel('True Positive Rate (Recall)', fontsize=12)
    plt.title('ROC Curve - MITI Not Coded Classifier', fontsize=14)
    plt.legend(fontsize=11)
    plt.grid(alpha=0.3)
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"ROC curve saved to {save_path}")
    plt.close()

    return fpr, tpr, thresholds, area
|
|
|
|
def plot_precision_recall_curve(y_true, y_prob, save_path="precision_recall_curve.png"):
    """Plot the PR curve; return (precision, recall, thresholds, average_precision)."""
    precision, recall, thresholds = precision_recall_curve(y_true, y_prob)
    ap = average_precision_score(y_true, y_prob)

    plt.figure(figsize=(10, 8))
    plt.plot(recall, precision, linewidth=2, label=f'PR curve (AP = {ap:.4f})')

    # A no-skill classifier's precision equals the positive-class prevalence.
    baseline = np.sum(y_true) / len(y_true)
    plt.axhline(y=baseline, linestyle='--', color='gray', label=f'Baseline ({baseline:.3f})')

    plt.xlabel('Recall', fontsize=12)
    plt.ylabel('Precision', fontsize=12)
    plt.title('Precision-Recall Curve - MITI Not Coded Classifier', fontsize=14)
    plt.legend(fontsize=11)
    plt.grid(alpha=0.3)
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"Precision-Recall curve saved to {save_path}")
    plt.close()

    return precision, recall, thresholds, ap
|
|
|
|
def find_optimal_thresholds(y_true, y_prob):
    """Search the PR and ROC curves for thresholds optimizing several objectives.

    Returns a dict keyed by strategy ('max_f1', 'max_youden', 'balanced_pr',
    'high_recall_95', 'high_precision_95', 'default_0.5'); each value holds
    the chosen threshold plus the metrics achieved at it. The two 'high_*'
    entries are omitted when no threshold reaches the 0.95 target.
    """
    precision, recall, pr_thresholds = precision_recall_curve(y_true, y_prob)
    fpr, tpr, roc_thresholds = roc_curve(y_true, y_prob)

    # precision/recall carry one more point than pr_thresholds; drop the last
    # entry so every metric aligns with a concrete threshold value.
    prec_t = precision[:-1]
    rec_t = recall[:-1]
    # Small epsilon avoids 0/0 when both precision and recall are zero.
    f1_scores = 2 * (prec_t * rec_t) / (prec_t + rec_t + 1e-10)

    out = {}

    # Strategy 1: maximize F1.
    best_f1 = np.argmax(f1_scores)
    out['max_f1'] = {
        'threshold': pr_thresholds[best_f1],
        'f1': f1_scores[best_f1],
        'precision': precision[best_f1],
        'recall': recall[best_f1]
    }

    # Strategy 2: maximize Youden's J (TPR - FPR) on the ROC curve.
    youden = tpr - fpr
    best_j = np.argmax(youden)
    out['max_youden'] = {
        'threshold': roc_thresholds[best_j],
        'j_statistic': youden[best_j],
        'tpr': tpr[best_j],
        'fpr': fpr[best_j]
    }

    # Strategy 3: precision and recall as close to equal as possible.
    balanced = np.argmin(np.abs(prec_t - rec_t))
    out['balanced_pr'] = {
        'threshold': pr_thresholds[balanced],
        'precision': precision[balanced],
        'recall': recall[balanced],
        'f1': f1_scores[balanced]
    }

    # Strategy 4: highest threshold still keeping recall >= 0.95.
    candidates = np.where(rec_t >= 0.95)[0]
    if len(candidates) > 0:
        hr = candidates[-1]
        out['high_recall_95'] = {
            'threshold': pr_thresholds[hr],
            'recall': recall[hr],
            'precision': precision[hr],
            'f1': f1_scores[hr]
        }

    # Strategy 5: lowest threshold already reaching precision >= 0.95.
    candidates = np.where(prec_t >= 0.95)[0]
    if len(candidates) > 0:
        hp = candidates[0]
        out['high_precision_95'] = {
            'threshold': pr_thresholds[hp],
            'precision': precision[hp],
            'recall': recall[hp],
            'f1': f1_scores[hp]
        }

    # Baseline: the conventional 0.5 cutoff.
    default_pred = (y_prob >= 0.5).astype(int)
    d_prec, d_rec, d_f1, _ = precision_recall_fscore_support(
        y_true, default_pred, average='binary'
    )
    out['default_0.5'] = {
        'threshold': 0.5,
        'accuracy': accuracy_score(y_true, default_pred),
        'precision': d_prec,
        'recall': d_rec,
        'f1': d_f1
    }

    return out
|
|
|
|
def evaluate_at_threshold(y_true, y_prob, threshold):
    """Compute binary, macro, and per-class metrics when cutting y_prob at `threshold`."""
    y_pred = (y_prob >= threshold).astype(int)

    # Binary metrics treat class 1 ("coded") as the positive class.
    prec_bin, rec_bin, f1_bin, _ = precision_recall_fscore_support(
        y_true, y_pred, average='binary', pos_label=1
    )

    # Macro-averaged across both classes.
    prec_mac, rec_mac, f1_mac, _ = precision_recall_fscore_support(
        y_true, y_pred, average='macro'
    )

    # Per-class breakdown: index 0 = "not coded", index 1 = "coded".
    prec_cls, rec_cls, f1_cls, _ = precision_recall_fscore_support(
        y_true, y_pred, average=None, labels=[0, 1]
    )

    return {
        'accuracy': accuracy_score(y_true, y_pred),
        'precision': prec_bin,
        'recall': rec_bin,
        'f1': f1_bin,
        'precision_macro': prec_mac,
        'recall_macro': rec_mac,
        'f1_macro': f1_mac,
        'precision_not_coded': prec_cls[0],
        'recall_not_coded': rec_cls[0],
        'f1_not_coded': f1_cls[0],
        'precision_coded': prec_cls[1],
        'recall_coded': rec_cls[1],
        'f1_coded': f1_cls[1],
        'confusion_matrix': confusion_matrix(y_true, y_pred).tolist()
    }
|
|
|
|
def analyze_per_annotator(y_true, y_prob, annotators):
    """Run the optimal-threshold search separately per annotator.

    Annotators with an empty id or fewer than 10 examples are skipped so the
    threshold search has enough data to be meaningful.
    """
    summary = {}

    for name in set(annotators):
        if not name:
            continue

        # Collect the positions belonging to this annotator.
        idx = [i for i, who in enumerate(annotators) if who == name]
        if len(idx) < 10:
            continue

        truths = y_true[idx]
        probs = y_prob[idx]
        coded_count = int(np.sum(truths))

        summary[name] = {
            'n_examples': len(idx),
            'n_coded': coded_count,
            'n_not_coded': int(len(truths) - np.sum(truths)),
            'optimal_thresholds': find_optimal_thresholds(truths, probs)
        }

    return summary
|
|
|
|
def main():
    """Main calibration analysis.

    Pipeline: load model + data, rebuild the training-time splits, draw
    calibration/ROC/PR curves on the validation set, pick candidate
    thresholds there, then report each candidate's metrics on the held-out
    test set. All artifacts land in `calibration_analysis/`.
    """

    print("="*80)
    print("CALIBRATION ANALYSIS - MITI NOT CODED CLASSIFIER")
    print("="*80)
    print()

    # Plots and the JSON results file go here.
    output_dir = "calibration_analysis"
    os.makedirs(output_dir, exist_ok=True)

    model, tokenizer, data = load_model_and_data()

    # Same preprocessing as training: invert_labels=True makes label 1 = "coded".
    binary_data = reframe_as_binary(data, invert_labels=True)

    # Reproduces the training 80/10/10 stratified split (same seed).
    train_data, val_data, test_data = split_data(binary_data)

    print(f"\\nData splits:")
    print(f"  Train: {len(train_data)}")
    print(f"  Val: {len(val_data)}")
    print(f"  Test: {len(test_data)}")

    # ---- Validation set: curves and threshold selection ----
    print("\\n" + "="*80)
    print("VALIDATION SET ANALYSIS")
    print("="*80)

    val_probs, val_labels, val_annotators = get_predictions(model, tokenizer, val_data)

    print("\\nGenerating calibration curve...")
    plot_calibration_curve(val_labels, val_probs,
                           save_path=f"{output_dir}/val_calibration_curve.png")

    print("Generating ROC curve...")
    val_fpr, val_tpr, val_roc_thresh, val_auc = plot_roc_curve(
        val_labels, val_probs, save_path=f"{output_dir}/val_roc_curve.png"
    )

    print("Generating Precision-Recall curve...")
    val_prec, val_rec, val_pr_thresh, val_ap = plot_precision_recall_curve(
        val_labels, val_probs, save_path=f"{output_dir}/val_pr_curve.png"
    )

    # Thresholds are chosen on validation data only, to avoid test-set leakage.
    print("\\nFinding optimal thresholds...")
    optimal_thresholds = find_optimal_thresholds(val_labels, val_probs)

    print("\\n" + "-"*80)
    print("OPTIMAL THRESHOLDS (Validation Set)")
    print("-"*80)

    for strategy, metrics in optimal_thresholds.items():
        print(f"\\n{strategy.upper().replace('_', ' ')}:")
        for key, value in metrics.items():
            print(f"  {key}: {value:.4f}")

    # ---- Per-annotator breakdown on validation data ----
    print("\\n" + "-"*80)
    print("PER-ANNOTATOR ANALYSIS (Validation Set)")
    print("-"*80)

    annotator_results = analyze_per_annotator(val_labels, val_probs, val_annotators)

    for annotator, results in annotator_results.items():
        print(f"\\nAnnotator: {annotator}")
        print(f"  Examples: {results['n_examples']}")
        print(f"  Coded: {results['n_coded']} ({results['n_coded']/results['n_examples']*100:.1f}%)")
        print(f"  Not Coded: {results['n_not_coded']} ({results['n_not_coded']/results['n_examples']*100:.1f}%)")
        print(f"  Optimal threshold (max F1): {results['optimal_thresholds']['max_f1']['threshold']:.4f}")

    # ---- Test set: evaluate each validation-chosen threshold ----
    print("\\n" + "="*80)
    print("TEST SET EVALUATION WITH DIFFERENT THRESHOLDS")
    print("="*80)

    test_probs, test_labels, test_annotators = get_predictions(model, tokenizer, test_data)

    test_results = {}
    print("\\nEvaluating different threshold strategies on test set...")

    for strategy, val_metrics in optimal_thresholds.items():
        threshold = val_metrics['threshold']
        test_metrics = evaluate_at_threshold(test_labels, test_probs, threshold)
        test_results[strategy] = {
            'threshold': threshold,
            **test_metrics
        }

        print(f"\\n{strategy.upper().replace('_', ' ')} (threshold={threshold:.4f}):")
        print(f"  Accuracy: {test_metrics['accuracy']:.4f}")
        print(f"  F1 Macro: {test_metrics['f1_macro']:.4f}")
        print(f"  F1 Coded: {test_metrics['f1_coded']:.4f}")
        print(f"  F1 Not Coded: {test_metrics['f1_not_coded']:.4f}")
        print(f"  Precision Coded: {test_metrics['precision_coded']:.4f}")
        print(f"  Recall Coded: {test_metrics['recall_coded']:.4f}")
        print(f"  Precision Not Coded: {test_metrics['precision_not_coded']:.4f}")
        print(f"  Recall Not Coded: {test_metrics['recall_not_coded']:.4f}")

    # json.dump cannot serialize numpy scalars/arrays; convert them first.
    def convert_numpy_types(obj):
        """Recursively convert numpy types to native Python types"""
        if isinstance(obj, dict):
            return {key: convert_numpy_types(value) for key, value in obj.items()}
        elif isinstance(obj, list):
            return [convert_numpy_types(item) for item in obj]
        elif isinstance(obj, np.integer):
            return int(obj)
        elif isinstance(obj, np.floating):
            return float(obj)
        elif isinstance(obj, np.ndarray):
            return obj.tolist()
        else:
            return obj

    output_data = {
        'validation': {
            'optimal_thresholds': optimal_thresholds,
            'roc_auc': val_auc,
            'average_precision': val_ap,
            'per_annotator': annotator_results
        },
        'test': test_results
    }

    output_data = convert_numpy_types(output_data)

    output_file = f"{output_dir}/calibration_results.json"
    with open(output_file, 'w') as f:
        json.dump(output_data, f, indent=2)

    print(f"\\n{'='*80}")
    print("ANALYSIS COMPLETE")
    print(f"{'='*80}")
    print(f"\\nResults saved to: {output_dir}/")
    print(f"  - calibration_results.json")
    print(f"  - val_calibration_curve.png")
    print(f"  - val_roc_curve.png")
    print(f"  - val_pr_curve.png")

    print(f"\\n{'='*80}")
    print("RECOMMENDATIONS")
    print(f"{'='*80}")

    print("\\nBased on the analysis:")
    print(f"\\n1. Default threshold (0.5): Currently used by the model")
    print(f"   - F1 Macro: {test_results['default_0.5']['f1_macro']:.4f}")

    print(f"\\n2. Max F1 threshold: Optimizes overall F1 score")
    print(f"   - Threshold: {optimal_thresholds['max_f1']['threshold']:.4f}")
    print(f"   - F1 Macro: {test_results['max_f1']['f1_macro']:.4f}")

    print(f"\\n3. Balanced P/R threshold: Equalizes precision and recall")
    print(f"   - Threshold: {optimal_thresholds['balanced_pr']['threshold']:.4f}")
    print(f"   - F1 Macro: {test_results['balanced_pr']['f1_macro']:.4f}")

    print("\\nConsider using different thresholds for different use cases:")
    print("  - Training/Education: Use high recall threshold to catch all codeable utterances")
    print("  - Research: Use max F1 or balanced threshold for optimal overall performance")
    print("  - Quality Assurance: Use high precision threshold to minimize false positives")
|
|
|
|
# Script entry point: run the full calibration analysis when executed directly.
if __name__ == "__main__":
    main()
|
|