bawolf's picture
wip
c850c95
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, precision_recall_curve, average_precision_score, roc_curve, auc
from torch.utils.data import DataLoader
import pandas as pd
import numpy as np
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
from src.dataset.dataset import VideoDataset
from src.utils.utils import get_latest_model_path, get_latest_run_dir, get_config
from src.models.model import load_model
def plot_training_curves(log_file, output_dir):
    """Plot loss and accuracy curves from a training log and save them as a PNG.

    Args:
        log_file (str): CSV log read with pandas; must contain the columns
            ``epoch``, ``train_loss``, ``val_loss``, ``train_accuracy`` and
            ``val_accuracy``.
        output_dir (str): Directory that receives ``training_curves.png``.
    """
    history = pd.read_csv(log_file)
    epochs = history['epoch']

    # One panel per metric:
    # (train column, val column, train label, val label, y-axis label, title).
    panels = (
        ('train_loss', 'val_loss', 'Train Loss', 'Validation Loss',
         'Loss', 'Training and Validation Loss'),
        ('train_accuracy', 'val_accuracy', 'Train Accuracy', 'Validation Accuracy',
         'Accuracy', 'Training and Validation Accuracy'),
    )

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    for ax, (train_col, val_col, train_label, val_label, y_label, title) in zip(axes, panels):
        ax.plot(epochs, history[train_col], label=train_label)
        ax.plot(epochs, history[val_col], label=val_label)
        ax.set_xlabel('Epochs')
        ax.set_ylabel(y_label)
        ax.set_title(title)
        ax.legend()

    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, 'training_curves.png'))
    plt.close(fig)
# Palette for the per-class precision-recall and ROC curves. It is indexed modulo
# its length so that more than eight classes no longer raises IndexError.
_CURVE_COLORS = ['blue', 'red', 'green', 'yellow', 'purple', 'orange', 'pink', 'cyan']


def _collect_predictions(model, data_loader, device):
    """Run ``model`` over every batch and gather labels, predictions and probabilities.

    Each batch from ``data_loader`` is unpacked as ``(frames, labels, filenames)``.

    Returns:
        tuple: ``(labels, preds, probs, files)``. ``labels`` and ``preds`` are
        1-D numpy arrays, ``probs`` is the row-wise softmax of the model outputs
        stacked into a 2-D numpy array, and ``files`` is a list of filenames in
        iteration order.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    label_chunks, pred_chunks, prob_chunks, files = [], [], [], []
    with torch.no_grad():
        for frames, labels, filenames in data_loader:
            outputs = model(frames.to(device))
            probs = torch.softmax(outputs, dim=1)
            # Labels are only consumed by numpy/sklearn on the CPU, so there is
            # no need to move them to ``device`` and back.
            label_chunks.append(labels.cpu().numpy())
            pred_chunks.append(outputs.argmax(dim=1).cpu().numpy())
            prob_chunks.append(probs.cpu().numpy())
            files.extend(filenames)

    if not label_chunks:
        # Without this check an empty loader fails later with an opaque
        # IndexError when slicing the (then 1-D) probability array.
        raise ValueError("data_loader yielded no batches; nothing to evaluate")

    return (np.concatenate(label_chunks), np.concatenate(pred_chunks),
            np.concatenate(prob_chunks), files)


def _write_error_analysis(path, labels, preds, probs, files, class_labels, data_info):
    """Write overall and per-class accuracy plus a table of misclassified files to ``path``."""
    with open(path, 'w', encoding='utf-8') as f:
        f.write(f"Error Analysis for {data_info}\n")
        f.write("=" * 80 + "\n\n")

        accuracy = (labels == preds).mean()
        f.write(f"Overall Accuracy: {accuracy:.2%}\n\n")

        f.write("Per-Class Accuracy:\n")
        for class_idx, class_name in enumerate(class_labels):
            class_mask = labels == class_idx
            n_samples = int(class_mask.sum())
            if n_samples > 0:  # classes absent from the test set are omitted
                class_acc = (preds[class_mask] == class_idx).mean()
                f.write(f"{class_name}: {class_acc:.2%} ({n_samples} samples)\n")
        f.write("\n")

        f.write("Misclassified Videos:\n")
        f.write("-" * 80 + "\n")
        f.write(f"{'Filename':<40} {'True Class':<20} {'Predicted Class':<20} Confidence\n")
        f.write("-" * 80 + "\n")
        for true_label, pred_label, row_probs, filename in zip(labels, preds, probs, files):
            if true_label != pred_label:
                # Confidence is the softmax probability of the (wrong) predicted class.
                f.write(
                    f"{str(filename):<40} {str(class_labels[true_label]):<20} "
                    f"{str(class_labels[pred_label]):<20} {row_probs[pred_label]:.2%}\n"
                )


def _plot_confusion_matrix(path, labels, preds, class_labels, data_info):
    """Save a confusion-matrix heatmap to ``path`` and return the matrix.

    Rows and columns are pinned to ``range(len(class_labels))``. The matrix is
    therefore always square with ``len(class_labels)`` entries per side and
    aligned with the class names, even when a class never appears in
    ``labels`` or ``preds``.
    """
    cm = confusion_matrix(labels, preds, labels=list(range(len(class_labels))))
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=list(class_labels), yticklabels=list(class_labels))
    plt.xlabel('Predicted Value')
    plt.ylabel('Actual Value')
    plt.title(f'Confusion Matrix\n{data_info}')
    plt.savefig(path)
    plt.close()
    return cm


def _plot_precision_recall(path, labels, probs, class_labels, data_info):
    """Save one-vs-rest precision-recall curves (with average precision) to ``path``."""
    plt.figure(figsize=(10, 8))
    for class_idx, class_label in enumerate(class_labels):
        is_class = labels == class_idx
        if not is_class.any():
            # Recall is undefined without positive samples; skip rather than plot noise.
            continue
        scores = probs[:, class_idx]
        precision, recall, _ = precision_recall_curve(is_class, scores)
        average_precision = average_precision_score(is_class, scores)
        plt.plot(recall, precision, color=_CURVE_COLORS[class_idx % len(_CURVE_COLORS)], lw=2,
                 label=f'{class_label} (AP = {average_precision:.2f})')
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title(f'Precision-Recall Curve\n{data_info}')
    plt.legend(loc="lower left")
    plt.savefig(path)
    plt.close()


def _plot_roc(path, labels, probs, class_labels, data_info):
    """Save one-vs-rest ROC curves (with AUC) to ``path``."""
    plt.figure(figsize=(10, 8))
    for class_idx, class_label in enumerate(class_labels):
        is_class = labels == class_idx
        if not is_class.any() or is_class.all():
            # TPR (no positives) or FPR (no negatives) would be undefined.
            continue
        fpr, tpr, _ = roc_curve(is_class, probs[:, class_idx])
        roc_auc = auc(fpr, tpr)
        plt.plot(fpr, tpr, color=_CURVE_COLORS[class_idx % len(_CURVE_COLORS)], lw=2,
                 label=f'{class_label} (AUC = {roc_auc:.2f})')
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')  # chance diagonal
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(f'Receiver Operating Characteristic (ROC) Curve\n{data_info}')
    plt.legend(loc="lower right")
    plt.savefig(path)
    plt.close()


def generate_evaluation_metrics(model, data_loader, device, output_dir, class_labels, data_info):
    """Evaluate ``model`` on ``data_loader`` and write reports and plots to ``output_dir``.

    Files written: ``error_analysis.txt``, ``confusion_matrix.png``,
    ``precision_recall_curve.png`` and ``roc_curve.png``.

    Args:
        model: Classifier called as ``model(frames)``; its outputs are treated as
            per-class logits (softmax and argmax are taken over dim 1).
        data_loader: Iterable of ``(frames, labels, filenames)`` batches.
        device: Device the frames are moved to before the forward pass.
        output_dir (str): Existing directory for the output files.
        class_labels (sequence): Class names indexed by integer class id.
        data_info (str): Description embedded in report and chart titles.

    Returns:
        numpy.ndarray: Confusion matrix of shape
        ``(len(class_labels), len(class_labels))``, with rows as true classes
        and columns as predicted classes.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    model.eval()
    labels, preds, probs, files = _collect_predictions(model, data_loader, device)

    _write_error_analysis(os.path.join(output_dir, 'error_analysis.txt'),
                          labels, preds, probs, files, class_labels, data_info)
    cm = _plot_confusion_matrix(os.path.join(output_dir, 'confusion_matrix.png'),
                                labels, preds, class_labels, data_info)
    _plot_precision_recall(os.path.join(output_dir, 'precision_recall_curve.png'),
                           labels, probs, class_labels, data_info)
    _plot_roc(os.path.join(output_dir, 'roc_curve.png'),
              labels, probs, class_labels, data_info)
    return cm
def run_visualization(run_dir, data_path=None, test_csv=None, batch_size=8):
    """Run visualization for a specific training run.

    Plots the run's training curves (when its ``training_log.csv`` exists) and
    evaluates the run's latest model on a test CSV. All outputs go into a
    ``visualization_<data dir>_<csv stem>`` subdirectory of ``run_dir``.

    Args:
        run_dir (str): Path to the run directory.
        data_path (str, optional): Override the data path from config. The
            override is also written into the loaded config before the dataset
            is built.
        test_csv (str, optional): Override the test CSV path. Defaults to
            ``<data_path>/test.csv``.
        batch_size (int, optional): Batch size for the test DataLoader.

    Returns:
        tuple: ``(vis_dir, cm)``, the output directory and the confusion matrix
        returned by ``generate_evaluation_metrics``.
    """
    # Load configuration
    config = get_config(run_dir)
    class_labels = config['class_labels']
    num_classes = config['num_classes']

    # Update the config's data_path if provided, so VideoDataset sees the override too.
    if data_path:
        config['data_path'] = data_path
    data_path = config['data_path']

    # Paths
    log_file = os.path.join(run_dir, 'training_log.csv')
    model_path = get_latest_model_path(run_dir)
    if test_csv is None:
        test_csv = os.path.join(data_path, 'test.csv')

    # Name outputs after the last directory of data_path and the CSV's stem.
    # splitext keeps multi-dot names like "test.v2.csv" distinct from "test.csv".
    last_dir = os.path.basename(os.path.normpath(data_path))
    file_name = os.path.basename(test_csv)
    file_stem = os.path.splitext(file_name)[0]
    print(f"Running visualization for {data_path} with {test_csv} from CWD {os.getcwd()}")

    # Create a directory for visualization outputs
    vis_dir = os.path.join(run_dir, f'visualization_{last_dir}_{file_stem}')
    os.makedirs(vis_dir, exist_ok=True)

    # Create data_info string for chart headers
    data_info = f'Data: {last_dir}, File: {file_name}'

    # Plot training curves; a missing log should not block test-set evaluation.
    if os.path.exists(log_file):
        plot_training_curves(log_file, vis_dir)
    else:
        print(f"Training log not found at {log_file}; skipping training curves")

    # Load model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = load_model(num_classes, model_path, device, config['clip_model'])
    model.eval()

    # Create test dataset and dataloader; shuffle=False keeps a stable file order in reports.
    test_dataset = VideoDataset(test_csv, config)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Generate evaluation metrics
    cm = generate_evaluation_metrics(model, test_loader, device, vis_dir, class_labels, data_info)

    print(f"Visualization complete! Check the output directory: {vis_dir}")
    return vis_dir, cm
if __name__ == "__main__":
    import argparse

    # Optional CLI overrides; with no arguments this visualizes the most recent
    # run against the data path stored in its config.
    parser = argparse.ArgumentParser(
        description="Visualize training curves and test-set metrics for a training run.")
    parser.add_argument('--run-dir', default=None,
                        help="Run directory to visualize (default: the most recent run)")
    parser.add_argument('--data-path', default=None,
                        help="Override the data path stored in the run's config")
    parser.add_argument('--test-csv', default=None,
                        help="Override the test CSV (default: <data_path>/test.csv)")
    args = parser.parse_args()

    run_dir = args.run_dir or get_latest_run_dir()
    run_visualization(run_dir, data_path=args.data_path, test_csv=args.test_csv)