File size: 7,662 Bytes
31fc7e1 7aa93af 31fc7e1 5acfa1a 31fc7e1 5acfa1a 31fc7e1 5acfa1a 31fc7e1 5acfa1a 31fc7e1 72a4f99 31fc7e1 5acfa1a 31fc7e1 72a4f99 31fc7e1 5acfa1a 31fc7e1 72a4f99 9713221 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 |
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, precision_recall_curve, average_precision_score, roc_curve, auc
from torch.utils.data import DataLoader
import pandas as pd
import numpy as np
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
from src.dataset.dataset import VideoDataset
from src.utils.utils import get_latest_model_path, get_latest_run_dir, get_config
from src.models.model import load_model
def plot_training_curves(log_file, output_dir):
    """Save side-by-side loss and accuracy curves read from a training log.

    Args:
        log_file (str): CSV file (read with ``pandas.read_csv``) containing the
            columns ``epoch``, ``train_loss``, ``val_loss``, ``train_accuracy``
            and ``val_accuracy``.
        output_dir (str): Directory the figure is written to, as
            ``training_curves.png``.
    """
    log = pd.read_csv(log_file)
    epochs = log['epoch']

    # One entry per subplot (left to right): train column, validation column,
    # train legend label, validation legend label, y-axis label, title.
    panels = (
        ('train_loss', 'val_loss', 'Train Loss', 'Validation Loss',
         'Loss', 'Training and Validation Loss'),
        ('train_accuracy', 'val_accuracy', 'Train Accuracy', 'Validation Accuracy',
         'Accuracy', 'Training and Validation Accuracy'),
    )

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    for ax, (train_col, val_col, train_label, val_label, y_label, title) in zip(axes, panels):
        ax.plot(epochs, log[train_col], label=train_label)
        ax.plot(epochs, log[val_col], label=val_label)
        ax.set_xlabel('Epochs')
        ax.set_ylabel(y_label)
        ax.set_title(title)
        ax.legend()

    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, 'training_curves.png'))
    plt.close(fig)
# Line colours for the per-class PR/ROC curves. They are reused cyclically when
# there are more classes than colours, so any class count can be plotted.
_CURVE_COLORS = ['blue', 'red', 'green', 'yellow', 'purple', 'orange', 'pink', 'cyan']


def _curve_color(class_index):
    """Return the plot colour for ``class_index``, cycling through ``_CURVE_COLORS``."""
    return _CURVE_COLORS[class_index % len(_CURVE_COLORS)]


def _collect_predictions(model, data_loader, device):
    """Run ``model`` over every batch of ``data_loader`` without gradients.

    The loader is expected to yield ``(frames, labels, filenames)`` triples.

    Returns:
        tuple: ``(y_true, y_pred, y_prob, files)``. ``y_true`` and ``y_pred``
        are 1-D numpy arrays of class indices. ``y_prob`` is the row-wise
        softmax of the model outputs. ``files`` is the flat list of filenames
        in loader order.
    """
    all_preds, all_labels, all_probs, all_files = [], [], [], []
    with torch.no_grad():
        for frames, labels, filenames in data_loader:
            outputs = model(frames.to(device))
            probs = torch.softmax(outputs, dim=1)
            _, predicted = outputs.max(1)
            all_preds.extend(predicted.cpu().numpy())
            # Labels are only collected, never fed to the model, so they do not
            # need a round trip through the device.
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())
            all_files.extend(filenames)
    return np.array(all_labels), np.array(all_preds), np.array(all_probs), all_files


def _write_error_analysis(path, y_true, y_pred, y_prob, files, class_labels, data_info):
    """Write overall/per-class accuracy and a table of misclassified files to ``path``."""
    with open(path, 'w', encoding='utf-8') as f:
        f.write(f"Error Analysis for {data_info}\n")
        f.write("=" * 80 + "\n\n")

        accuracy = (y_true == y_pred).mean()
        f.write(f"Overall Accuracy: {accuracy:.2%}\n\n")

        f.write("Per-Class Accuracy:\n")
        for i, class_name in enumerate(class_labels):
            class_mask = y_true == i
            n_samples = int(class_mask.sum())
            # Classes with no test samples have no defined accuracy; omit them.
            if n_samples > 0:
                class_acc = (y_pred[class_mask] == i).mean()
                f.write(f"{class_name}: {class_acc:.2%} ({n_samples} samples)\n")
        f.write("\n")

        f.write("Misclassified Videos:\n")
        f.write("-" * 80 + "\n")
        f.write(f"{'Filename':<40} {'True Class':<20} {'Predicted Class':<20} Confidence\n")
        f.write("-" * 80 + "\n")
        for true_label, pred_label, sample_probs, filename in zip(y_true, y_pred, y_prob, files):
            if true_label != pred_label:
                # Confidence is the probability the model gave to its (wrong) prediction.
                confidence = sample_probs[pred_label]
                f.write(f"{str(filename):<40} {class_labels[true_label]:<20} "
                        f"{class_labels[pred_label]:<20} {confidence:.2%}\n")


def _plot_confusion_matrix(cm, class_labels, data_info, output_dir):
    """Save ``cm`` as an annotated heatmap with class names on both axes."""
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_labels, yticklabels=class_labels)
    plt.xlabel('Predicted Value')
    plt.ylabel('Actual Value')
    plt.title(f'Confusion Matrix\n{data_info}')
    plt.savefig(os.path.join(output_dir, 'confusion_matrix.png'))
    plt.close()


def _plot_precision_recall(y_true, y_prob, class_labels, data_info, output_dir):
    """Save one-vs-rest precision-recall curves (with average precision) per class."""
    plt.figure(figsize=(10, 8))
    for i, class_label in enumerate(class_labels):
        positives = y_true == i
        # With no positive samples, recall is undefined; sklearn would only
        # emit warnings and a meaningless curve, so the class is skipped.
        if not positives.any():
            continue
        precision, recall, _ = precision_recall_curve(positives, y_prob[:, i])
        average_precision = average_precision_score(positives, y_prob[:, i])
        plt.plot(recall, precision, color=_curve_color(i), lw=2,
                 label=f'{class_label} (AP = {average_precision:.2f})')
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title(f'Precision-Recall Curve\n{data_info}')
    plt.legend(loc="lower left")
    plt.savefig(os.path.join(output_dir, 'precision_recall_curve.png'))
    plt.close()


def _plot_roc(y_true, y_prob, class_labels, data_info, output_dir):
    """Save one-vs-rest ROC curves (with AUC) per class plus the chance diagonal."""
    plt.figure(figsize=(10, 8))
    for i, class_label in enumerate(class_labels):
        positives = y_true == i
        # ROC needs both positives (for TPR) and negatives (for FPR); otherwise
        # the curve is NaN, so such classes are skipped.
        if not positives.any() or positives.all():
            continue
        fpr, tpr, _ = roc_curve(positives, y_prob[:, i])
        roc_auc = auc(fpr, tpr)
        plt.plot(fpr, tpr, color=_curve_color(i), lw=2,
                 label=f'{class_label} (AUC = {roc_auc:.2f})')
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(f'Receiver Operating Characteristic (ROC) Curve\n{data_info}')
    plt.legend(loc="lower right")
    plt.savefig(os.path.join(output_dir, 'roc_curve.png'))
    plt.close()


def generate_evaluation_metrics(model, data_loader, device, output_dir, class_labels, data_info):
    """Evaluate ``model`` on ``data_loader`` and write metric reports/plots to ``output_dir``.

    Writes ``error_analysis.txt``, ``confusion_matrix.png``,
    ``precision_recall_curve.png`` and ``roc_curve.png``.

    Args:
        model: Classifier whose output for a batch of frames is one score per
            class; ``softmax`` over dim 1 is taken as class probabilities.
        data_loader: Iterable yielding ``(frames, labels, filenames)`` batches.
        device: Device the frames are moved to before the forward pass.
        output_dir (str): Directory the artifacts are written to.
        class_labels (Sequence[str]): Class names; position ``i`` names class index ``i``.
        data_info (str): Description appended to report and plot titles.

    Returns:
        numpy.ndarray: ``len(class_labels) x len(class_labels)`` confusion
        matrix (rows = true class, columns = predicted class).

    Raises:
        ValueError: If ``data_loader`` yields no samples.
    """
    model.eval()
    y_true, y_pred, y_prob, files = _collect_predictions(model, data_loader, device)
    if y_true.size == 0:
        raise ValueError("data_loader yielded no samples; cannot compute evaluation metrics")

    _write_error_analysis(os.path.join(output_dir, 'error_analysis.txt'),
                          y_true, y_pred, y_prob, files, class_labels, data_info)

    # Passing labels explicitly keeps the matrix square and aligned with
    # class_labels even when some class never occurs in y_true or y_pred.
    cm = confusion_matrix(y_true, y_pred, labels=np.arange(len(class_labels)))
    _plot_confusion_matrix(cm, class_labels, data_info, output_dir)
    _plot_precision_recall(y_true, y_prob, class_labels, data_info, output_dir)
    _plot_roc(y_true, y_prob, class_labels, data_info, output_dir)
    return cm
def run_visualization(run_dir, data_path=None, test_csv=None, batch_size=8):
    """
    Run visualization for a specific training run.

    When the run has a ``training_log.csv``, its training curves are plotted.
    The run's latest model checkpoint is then evaluated on the test CSV. All
    outputs go to ``<run_dir>/visualization_<data dir name>_<csv name without extension>``.

    Args:
        run_dir (str): Path to the run directory
        data_path (str, optional): Override the data path from config
        test_csv (str, optional): Override the test CSV path; defaults to
            ``<data_path>/test.csv``
        batch_size (int, optional): Batch size of the evaluation DataLoader (default 8)

    Returns:
        tuple: ``(vis_dir, cm)``, i.e. the output directory and the confusion
        matrix returned by ``generate_evaluation_metrics``.
    """
    # Load configuration
    config = get_config(run_dir)
    class_labels = config['class_labels']
    num_classes = config['num_classes']

    # Write the override back into config so that everything receiving config
    # (e.g. VideoDataset below) sees the same data path.
    if data_path:
        config['data_path'] = data_path
    data_path = config['data_path']

    # Paths
    log_file = os.path.join(run_dir, 'training_log.csv')
    model_path = get_latest_model_path(run_dir)
    if test_csv is None:
        test_csv = os.path.join(data_path, 'test.csv')

    # Name the output directory after the last directory of data_path and the CSV.
    last_dir = os.path.basename(os.path.normpath(data_path))
    file_name = os.path.basename(test_csv)
    # splitext strips only the final extension: "test.v2.csv" -> "test.v2"
    # (split(".")[0] would give "test" and could collide with "test.csv").
    csv_stem = os.path.splitext(file_name)[0]
    print(f"Running visualization for {data_path} with {test_csv} from CWD {os.getcwd()}")

    vis_dir = os.path.join(run_dir, f'visualization_{last_dir}_{csv_stem}')
    os.makedirs(vis_dir, exist_ok=True)

    # Header string shown on every chart and report
    data_info = f'Data: {last_dir}, File: {file_name}'

    # Training curves are optional: a missing log should not block evaluation.
    if os.path.isfile(log_file):
        plot_training_curves(log_file, vis_dir)
    else:
        print(f"No training log found at {log_file}; skipping training curves")

    # Load model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = load_model(num_classes, model_path, device, config['clip_model'])
    model.eval()

    # Test data in file order, so reports line up with the CSV
    test_dataset = VideoDataset(test_csv, config)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    cm = generate_evaluation_metrics(model, test_loader, device, vis_dir, class_labels, data_info)
    print(f"Visualization complete! Check the output directory: {vis_dir}")
    return vis_dir, cm
if __name__ == "__main__":
    import argparse

    # Command-line overrides; with no arguments this visualizes the most
    # recent run using the data path stored in its config.
    parser = argparse.ArgumentParser(
        description="Plot training curves and test-set metrics for a training run.")
    parser.add_argument('--run-dir', default=None,
                        help="Run directory to visualize (default: the most recent run)")
    parser.add_argument('--data-path', default=None,
                        help="Override the data path stored in the run's config")
    parser.add_argument('--test-csv', default=None,
                        help="Override the test CSV (default: <data_path>/test.csv)")
    args = parser.parse_args()

    run_dir = args.run_dir if args.run_dir is not None else get_latest_run_dir()
    run_visualization(run_dir, data_path=args.data_path, test_csv=args.test_csv)
|