# qwen_nc_classifier / calibration_analysis.py
# Provenance: Hugging Face upload "MITI 4.2 Not Coded Classifier with
# calibration results" (commit 6f2fb04, verified, by Lekhansh).
#!/usr/bin/env python3
"""
Calibration analysis for MITI 4.2 Not Coded Classifier
This script analyzes model calibration and finds optimal thresholds for different use cases.
It performs:
1. Probability calibration assessment
2. ROC curve analysis
3. Precision-Recall curve analysis
4. Optimal threshold finding for various metrics
5. Per-annotator threshold analysis (if applicable)
Usage:
python calibration_analysis.py
"""
import json
import numpy as np
import torch
from sklearn.calibration import calibration_curve
from sklearn.metrics import (
roc_curve, auc, precision_recall_curve, average_precision_score,
accuracy_score, precision_recall_fscore_support, confusion_matrix
)
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from datasets import Dataset
from typing import Dict, List, Tuple
import os
def load_model_and_data(
    model_name: str = "Lekhansh/qwen_nc_classifier",
    data_path: str = "multilabel_classifier_dataset.json"
):
    """Load the classifier, its tokenizer, and the evaluation dataset.

    Args:
        model_name: HuggingFace model name or local path
        data_path: Path to multilabel dataset JSON file

    Returns:
        Tuple of (model, tokenizer, data) where data is the parsed JSON
        list of annotated examples.
    """
    print(f"Loading model from {model_name}...")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        device_map="auto",
    )
    # Inference only: switch off dropout / training-mode layers.
    model.eval()

    print(f"Loading data from {data_path}...")
    with open(data_path, 'r') as fh:
        dataset = json.load(fh)

    return model, tokenizer, dataset
def reframe_as_binary(data: List[Dict], invert_labels: bool = True) -> List[Dict]:
    """Reframe multilabel data as binary classification (same as training script).

    Each example's text becomes: task prefix + optional annotator line +
    original input. With invert_labels=True, the 'not_coded' flag is
    flipped so that label 1 means "coded".
    """
    # The prefix is constant, so build it once outside the loop.
    task_prefix = "Task: Decide if the last therapist utterance should be coded or not.\\n"
    reframed: List[Dict] = []
    for record in data:
        annotator = record.get('annotated_by', '')
        annotator_line = f"Annotated by: {annotator}\\n" if annotator else ""
        target = record['not_coded']
        if invert_labels:
            target = 1 - target
        reframed.append({
            'text': task_prefix + annotator_line + record['input'],
            'label': target,
            'annotated_by': annotator,
            'seq_id': record['seq_id'],
            'id': record['id'],
            'unique_id': record['unique_id'],
        })
    return reframed
def split_data(data: List[Dict], random_state=42):
    """Split data same way as training (80/10/10).

    Both splits are stratified on the binary label so class balance is
    preserved; the fixed random_state reproduces the training-time split.
    """
    from sklearn.model_selection import train_test_split

    # First split: 80% train, 20% held out.
    train_part, holdout = train_test_split(
        data, train_size=0.8, random_state=random_state,
        stratify=[d['label'] for d in data]
    )
    # Second split: held-out 20% divided evenly into validation and test.
    val_part, test_part = train_test_split(
        holdout, train_size=0.5, random_state=random_state,
        stratify=[d['label'] for d in holdout]
    )
    return train_part, val_part, test_part
def get_predictions(model, tokenizer, data: List[Dict], max_length=3000, batch_size=16):
    """Get model predictions and probabilities for all examples.

    Returns:
        Tuple of (probs, labels, annotators): probs is a numpy array of
        P(coded) per example, labels a numpy array of gold labels, and
        annotators a list of annotator ids ('' when absent).
    """
    print(f"Getting predictions for {len(data)} examples...")
    probs_out: List[float] = []
    labels_out: List[int] = []
    annotators_out: List[str] = []

    for start in range(0, len(data), batch_size):
        chunk = data[start:start + batch_size]

        encoded = tokenizer(
            [d['text'] for d in chunk],
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        )
        # Move every input tensor onto the model's device.
        encoded = {name: tensor.to(model.device) for name, tensor in encoded.items()}

        with torch.no_grad():
            logits = model(**encoded).logits
            # Column 1 of the softmax is the probability of the "coded" class.
            batch_probs = torch.softmax(logits, dim=1)[:, 1].float().cpu().numpy()

        probs_out.extend(batch_probs)
        labels_out.extend(d['label'] for d in chunk)
        annotators_out.extend(d.get('annotated_by', '') for d in chunk)

    return np.array(probs_out), np.array(labels_out), annotators_out
def plot_calibration_curve(y_true, y_prob, n_bins=10, save_path="calibration_curve.png"):
    """Plot calibration curve to assess probability calibration.

    Saves the figure to save_path; nothing is returned.
    """
    # frac_positive = observed positive rate per bin; mean_pred = mean score per bin.
    frac_positive, mean_pred = calibration_curve(
        y_true, y_prob, n_bins=n_bins, strategy='uniform'
    )

    plt.figure(figsize=(10, 8))
    plt.plot(mean_pred, frac_positive, marker='o', linewidth=2, label='Model')
    # The diagonal is a perfectly calibrated classifier.
    plt.plot([0, 1], [0, 1], linestyle='--', label='Perfect calibration', color='gray')
    plt.xlabel('Mean predicted probability', fontsize=12)
    plt.ylabel('Fraction of positives (True probability)', fontsize=12)
    plt.title('Calibration Curve - MITI Not Coded Classifier', fontsize=14)
    plt.legend(fontsize=11)
    plt.grid(alpha=0.3)
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"Calibration curve saved to {save_path}")
    plt.close()
def plot_roc_curve(y_true, y_prob, save_path="roc_curve.png"):
    """Plot ROC curve.

    Returns:
        (fpr, tpr, thresholds, roc_auc) so callers can reuse the curve data.
    """
    fpr, tpr, thresholds = roc_curve(y_true, y_prob)
    roc_auc = auc(fpr, tpr)

    plt.figure(figsize=(10, 8))
    plt.plot(fpr, tpr, linewidth=2, label=f'ROC curve (AUC = {roc_auc:.4f})')
    # A random classifier traces the diagonal (AUC = 0.5).
    plt.plot([0, 1], [0, 1], linestyle='--', color='gray', label='Random classifier')
    plt.xlabel('False Positive Rate', fontsize=12)
    plt.ylabel('True Positive Rate (Recall)', fontsize=12)
    plt.title('ROC Curve - MITI Not Coded Classifier', fontsize=14)
    plt.legend(fontsize=11)
    plt.grid(alpha=0.3)
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"ROC curve saved to {save_path}")
    plt.close()

    return fpr, tpr, thresholds, roc_auc
def plot_precision_recall_curve(y_true, y_prob, save_path="precision_recall_curve.png"):
    """Plot Precision-Recall curve.

    Returns:
        (precision, recall, thresholds, avg_precision) for downstream use.
    """
    precision, recall, thresholds = precision_recall_curve(y_true, y_prob)
    avg_precision = average_precision_score(y_true, y_prob)

    # A no-skill classifier's precision equals the positive-class rate.
    baseline = np.sum(y_true) / len(y_true)

    plt.figure(figsize=(10, 8))
    plt.plot(recall, precision, linewidth=2, label=f'PR curve (AP = {avg_precision:.4f})')
    plt.axhline(y=baseline, linestyle='--', color='gray', label=f'Baseline ({baseline:.3f})')
    plt.xlabel('Recall', fontsize=12)
    plt.ylabel('Precision', fontsize=12)
    plt.title('Precision-Recall Curve - MITI Not Coded Classifier', fontsize=14)
    plt.legend(fontsize=11)
    plt.grid(alpha=0.3)
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"Precision-Recall curve saved to {save_path}")
    plt.close()

    return precision, recall, thresholds, avg_precision
def find_optimal_thresholds(y_true, y_prob):
    """Find optimal thresholds for different optimization objectives.

    Args:
        y_true: array of gold binary labels (1 = positive class).
        y_prob: array of predicted positive-class probabilities.

    Returns:
        Dict mapping strategy name ('max_f1', 'max_youden', 'balanced_pr',
        'high_recall_95', 'high_precision_95', 'default_0.5') to a dict
        with the chosen 'threshold' and the metrics achieved at it. The
        two 95%-target strategies are omitted when unattainable.

    NOTE: precision_recall_curve returns len(thresholds) + 1 precision and
    recall points (the final point has no associated threshold), so the
    computations below slice precision[:-1] / recall[:-1] to stay index-
    aligned with pr_thresholds.
    """
    precision, recall, pr_thresholds = precision_recall_curve(y_true, y_prob)
    fpr, tpr, roc_thresholds = roc_curve(y_true, y_prob)
    results = {}
    # 1. Maximize F1 score (the 1e-10 term avoids division by zero when
    #    precision and recall are both 0).
    f1_scores = 2 * (precision[:-1] * recall[:-1]) / (precision[:-1] + recall[:-1] + 1e-10)
    best_f1_idx = np.argmax(f1_scores)
    results['max_f1'] = {
        'threshold': pr_thresholds[best_f1_idx],
        'f1': f1_scores[best_f1_idx],
        'precision': precision[best_f1_idx],
        'recall': recall[best_f1_idx]
    }
    # 2. Maximize Youden's J statistic (TPR - FPR); indices here align with
    #    roc_thresholds, which has the same length as fpr/tpr.
    j_scores = tpr - fpr
    best_j_idx = np.argmax(j_scores)
    results['max_youden'] = {
        'threshold': roc_thresholds[best_j_idx],
        'j_statistic': j_scores[best_j_idx],
        'tpr': tpr[best_j_idx],
        'fpr': fpr[best_j_idx]
    }
    # 3. Balance precision and recall (closest to precision = recall)
    pr_diff = np.abs(precision[:-1] - recall[:-1])
    balanced_idx = np.argmin(pr_diff)
    results['balanced_pr'] = {
        'threshold': pr_thresholds[balanced_idx],
        'precision': precision[balanced_idx],
        'recall': recall[balanced_idx],
        'f1': f1_scores[balanced_idx]
    }
    # 4. High recall (95% recall threshold). Recall is non-increasing as the
    #    threshold rises, so the last qualifying index is the highest
    #    threshold that still achieves >= 95% recall.
    recall_95_idx = np.where(recall[:-1] >= 0.95)[0]
    if len(recall_95_idx) > 0:
        recall_95_idx = recall_95_idx[-1]  # Highest threshold with 95% recall
        results['high_recall_95'] = {
            'threshold': pr_thresholds[recall_95_idx],
            'recall': recall[recall_95_idx],
            'precision': precision[recall_95_idx],
            'f1': f1_scores[recall_95_idx]
        }
    # 5. High precision (95% precision threshold). Precision trends upward
    #    with the threshold but is not strictly monotonic; taking the first
    #    qualifying index picks the earliest (lowest-threshold) crossing.
    precision_95_idx = np.where(precision[:-1] >= 0.95)[0]
    if len(precision_95_idx) > 0:
        precision_95_idx = precision_95_idx[0]  # Lowest threshold with 95% precision
        results['high_precision_95'] = {
            'threshold': pr_thresholds[precision_95_idx],
            'precision': precision[precision_95_idx],
            'recall': recall[precision_95_idx],
            'f1': f1_scores[precision_95_idx]
        }
    # 6. Default 0.5 threshold for comparison
    default_pred = (y_prob >= 0.5).astype(int)
    default_acc = accuracy_score(y_true, default_pred)
    default_prec, default_rec, default_f1, _ = precision_recall_fscore_support(
        y_true, default_pred, average='binary'
    )
    results['default_0.5'] = {
        'threshold': 0.5,
        'accuracy': default_acc,
        'precision': default_prec,
        'recall': default_rec,
        'f1': default_f1
    }
    return results
def evaluate_at_threshold(y_true, y_prob, threshold):
    """Get detailed metrics at a specific threshold.

    Returns a dict with overall accuracy, binary metrics for the positive
    ("coded") class, macro-averaged metrics, per-class metrics, and the
    confusion matrix (as a nested list, JSON-serializable).
    """
    y_pred = (y_prob >= threshold).astype(int)

    # Binary metrics with respect to the positive ("coded") class.
    prec_bin, rec_bin, f1_bin, _ = precision_recall_fscore_support(
        y_true, y_pred, average='binary', pos_label=1
    )
    # Macro metrics (unweighted mean over both classes).
    prec_macro, rec_macro, f1_macro, _ = precision_recall_fscore_support(
        y_true, y_pred, average='macro'
    )
    # Per-class metrics: index 0 = "not coded", index 1 = "coded".
    prec_cls, rec_cls, f1_cls, _ = precision_recall_fscore_support(
        y_true, y_pred, average=None, labels=[0, 1]
    )

    return {
        'accuracy': accuracy_score(y_true, y_pred),
        'precision': prec_bin,
        'recall': rec_bin,
        'f1': f1_bin,
        'precision_macro': prec_macro,
        'recall_macro': rec_macro,
        'f1_macro': f1_macro,
        'precision_not_coded': prec_cls[0],
        'recall_not_coded': rec_cls[0],
        'f1_not_coded': f1_cls[0],
        'precision_coded': prec_cls[1],
        'recall_coded': rec_cls[1],
        'f1_coded': f1_cls[1],
        'confusion_matrix': confusion_matrix(y_true, y_pred).tolist()
    }
def analyze_per_annotator(y_true, y_prob, annotators):
    """Analyze performance separately for each annotator.

    Annotators with fewer than 10 examples (or an empty id) are skipped,
    since threshold estimates on tiny samples are unreliable.
    """
    per_annotator = {}
    for name in set(annotators):
        if not name:
            continue  # no annotator metadata on these examples
        # Collect the positions belonging to this annotator.
        idx = [i for i, who in enumerate(annotators) if who == name]
        if len(idx) < 10:
            continue  # too few examples for a stable estimate
        subset_true = y_true[idx]
        subset_prob = y_prob[idx]
        per_annotator[name] = {
            'n_examples': len(idx),
            'n_coded': int(np.sum(subset_true)),
            'n_not_coded': int(len(subset_true) - np.sum(subset_true)),
            'optimal_thresholds': find_optimal_thresholds(subset_true, subset_prob)
        }
    return per_annotator
def main():
    """Main calibration analysis.

    Pipeline: load model and data -> reframe as binary -> reproduce the
    training 80/10/10 split -> tune thresholds on the validation split ->
    evaluate every threshold strategy on the held-out test split -> save
    plots and a JSON summary under ./calibration_analysis/.
    """
    print("="*80)
    print("CALIBRATION ANALYSIS - MITI NOT CODED CLASSIFIER")
    print("="*80)
    print()
    # Create output directory
    output_dir = "calibration_analysis"
    os.makedirs(output_dir, exist_ok=True)
    # Load model and data
    model, tokenizer, data = load_model_and_data()
    # Reframe as binary; invert_labels=True flips 'not_coded' so label 1 = "coded"
    binary_data = reframe_as_binary(data, invert_labels=True)
    # Split data the same way as training so the test split stays untouched
    train_data, val_data, test_data = split_data(binary_data)
    print(f"\\nData splits:")
    print(f"  Train: {len(train_data)}")
    print(f"  Val: {len(val_data)}")
    print(f"  Test: {len(test_data)}")
    # Analyze on validation set (for finding optimal thresholds) — thresholds
    # are tuned here, NOT on the test set, to avoid optimistic bias.
    print("\\n" + "="*80)
    print("VALIDATION SET ANALYSIS")
    print("="*80)
    val_probs, val_labels, val_annotators = get_predictions(model, tokenizer, val_data)
    # Plot calibration curve
    print("\\nGenerating calibration curve...")
    plot_calibration_curve(val_labels, val_probs,
                           save_path=f"{output_dir}/val_calibration_curve.png")
    # Plot ROC curve
    print("Generating ROC curve...")
    val_fpr, val_tpr, val_roc_thresh, val_auc = plot_roc_curve(
        val_labels, val_probs, save_path=f"{output_dir}/val_roc_curve.png"
    )
    # Plot PR curve
    print("Generating Precision-Recall curve...")
    val_prec, val_rec, val_pr_thresh, val_ap = plot_precision_recall_curve(
        val_labels, val_probs, save_path=f"{output_dir}/val_pr_curve.png"
    )
    # Find optimal thresholds (one per strategy: max F1, Youden's J, etc.)
    print("\\nFinding optimal thresholds...")
    optimal_thresholds = find_optimal_thresholds(val_labels, val_probs)
    print("\\n" + "-"*80)
    print("OPTIMAL THRESHOLDS (Validation Set)")
    print("-"*80)
    for strategy, metrics in optimal_thresholds.items():
        print(f"\\n{strategy.upper().replace('_', ' ')}:")
        for key, value in metrics.items():
            print(f"  {key}: {value:.4f}")
    # Analyze per annotator (only annotators with >= 10 validation examples)
    print("\\n" + "-"*80)
    print("PER-ANNOTATOR ANALYSIS (Validation Set)")
    print("-"*80)
    annotator_results = analyze_per_annotator(val_labels, val_probs, val_annotators)
    for annotator, results in annotator_results.items():
        print(f"\\nAnnotator: {annotator}")
        print(f"  Examples: {results['n_examples']}")
        print(f"  Coded: {results['n_coded']} ({results['n_coded']/results['n_examples']*100:.1f}%)")
        print(f"  Not Coded: {results['n_not_coded']} ({results['n_not_coded']/results['n_examples']*100:.1f}%)")
        print(f"  Optimal threshold (max F1): {results['optimal_thresholds']['max_f1']['threshold']:.4f}")
    # Test on test set with the thresholds chosen on the validation set
    print("\\n" + "="*80)
    print("TEST SET EVALUATION WITH DIFFERENT THRESHOLDS")
    print("="*80)
    test_probs, test_labels, test_annotators = get_predictions(model, tokenizer, test_data)
    test_results = {}
    print("\\nEvaluating different threshold strategies on test set...")
    for strategy, val_metrics in optimal_thresholds.items():
        threshold = val_metrics['threshold']
        test_metrics = evaluate_at_threshold(test_labels, test_probs, threshold)
        test_results[strategy] = {
            'threshold': threshold,
            **test_metrics
        }
        print(f"\\n{strategy.upper().replace('_', ' ')} (threshold={threshold:.4f}):")
        print(f"  Accuracy: {test_metrics['accuracy']:.4f}")
        print(f"  F1 Macro: {test_metrics['f1_macro']:.4f}")
        print(f"  F1 Coded: {test_metrics['f1_coded']:.4f}")
        print(f"  F1 Not Coded: {test_metrics['f1_not_coded']:.4f}")
        print(f"  Precision Coded: {test_metrics['precision_coded']:.4f}")
        print(f"  Recall Coded: {test_metrics['recall_coded']:.4f}")
        print(f"  Precision Not Coded: {test_metrics['precision_not_coded']:.4f}")
        print(f"  Recall Not Coded: {test_metrics['recall_not_coded']:.4f}")
    # Save all results - convert numpy types to native Python types for JSON
    # serialization (json.dump cannot handle np.float32/np.int64 scalars).
    def convert_numpy_types(obj):
        """Recursively convert numpy types to native Python types."""
        if isinstance(obj, dict):
            return {key: convert_numpy_types(value) for key, value in obj.items()}
        elif isinstance(obj, list):
            return [convert_numpy_types(item) for item in obj]
        elif isinstance(obj, np.integer):
            return int(obj)
        elif isinstance(obj, np.floating):
            return float(obj)
        elif isinstance(obj, np.ndarray):
            return obj.tolist()
        else:
            return obj
    output_data = {
        'validation': {
            'optimal_thresholds': optimal_thresholds,
            'roc_auc': val_auc,
            'average_precision': val_ap,
            'per_annotator': annotator_results
        },
        'test': test_results
    }
    # Convert all numpy types
    output_data = convert_numpy_types(output_data)
    output_file = f"{output_dir}/calibration_results.json"
    with open(output_file, 'w') as f:
        json.dump(output_data, f, indent=2)
    print(f"\\n{'='*80}")
    print("ANALYSIS COMPLETE")
    print(f"{'='*80}")
    print(f"\\nResults saved to: {output_dir}/")
    print(f"  - calibration_results.json")
    print(f"  - val_calibration_curve.png")
    print(f"  - val_roc_curve.png")
    print(f"  - val_pr_curve.png")
    # Recommendation
    print(f"\\n{'='*80}")
    print("RECOMMENDATIONS")
    print(f"{'='*80}")
    print("\\nBased on the analysis:")
    print(f"\\n1. Default threshold (0.5): Currently used by the model")
    print(f"   - F1 Macro: {test_results['default_0.5']['f1_macro']:.4f}")
    print(f"\\n2. Max F1 threshold: Optimizes overall F1 score")
    print(f"   - Threshold: {optimal_thresholds['max_f1']['threshold']:.4f}")
    print(f"   - F1 Macro: {test_results['max_f1']['f1_macro']:.4f}")
    print(f"\\n3. Balanced P/R threshold: Equalizes precision and recall")
    print(f"   - Threshold: {optimal_thresholds['balanced_pr']['threshold']:.4f}")
    print(f"   - F1 Macro: {test_results['balanced_pr']['f1_macro']:.4f}")
    print("\\nConsider using different thresholds for different use cases:")
    print("  - Training/Education: Use high recall threshold to catch all codeable utterances")
    print("  - Research: Use max F1 or balanced threshold for optimal overall performance")
    print("  - Quality Assurance: Use high precision threshold to minimize false positives")
# Run the full analysis only when executed as a script (not on import).
if __name__ == "__main__":
    main()