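"""Visualize a trained video-classification run.

Plots training/validation curves from the run's training_log.csv, then
evaluates the latest checkpoint on a test split and writes a confusion
matrix, per-class precision-recall and ROC curves, and a plain-text
error analysis into a visualization directory inside the run directory.
"""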
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from sklearn.metrics import (auc, average_precision_score, confusion_matrix,
                             precision_recall_curve, roc_curve)
from torch.utils.data import DataLoader

# Make the project root importable when this file is run directly.
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))

from src.dataset.dataset import VideoDataset
from src.models.model import load_model
from src.utils.utils import get_config, get_latest_model_path, get_latest_run_dir


def plot_training_curves(log_file, output_dir):
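    """Plot loss and accuracy curves from a training-log CSV.

    Expects the columns epoch, train_loss, val_loss, train_accuracy, and
    val_accuracy; saves the figure as training_curves.png in output_dir.
    """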
    data = pd.read_csv(log_file)

    plt.figure(figsize=(12, 5))

    plt.subplot(1, 2, 1)
    plt.plot(data['epoch'], data['train_loss'], label='Train Loss')
    plt.plot(data['epoch'], data['val_loss'], label='Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(data['epoch'], data['train_accuracy'], label='Train Accuracy')
    plt.plot(data['epoch'], data['val_accuracy'], label='Validation Accuracy')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.title('Training and Validation Accuracy')
    plt.legend()

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'training_curves.png'))
    plt.close()


def generate_evaluation_metrics(model, data_loader, device, output_dir, class_labels, data_info):
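    """Evaluate the model on a data loader and write metrics to output_dir.

    Writes error_analysis.txt, confusion_matrix.png,
    precision_recall_curve.png, and roc_curve.png, and returns the
    confusion matrix.
    """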
    model.eval()
    all_preds = []
    all_labels = []
    all_probs = []
    all_files = []

    # Collect predictions, probabilities, and filenames over the whole set.
    with torch.no_grad():
        for frames, labels, filenames in data_loader:
            frames = frames.to(device)
            labels = labels.to(device)

            outputs = model(frames)
            probs = torch.softmax(outputs, dim=1)
            _, predicted = outputs.max(1)

            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())
            all_files.extend(filenames)

    all_labels = np.array(all_labels)
    all_preds = np.array(all_preds)
    all_probs = np.array(all_probs)

    # Plain-text error analysis: overall and per-class accuracy, plus a
    # listing of every misclassified video with the model's confidence.
    error_file = os.path.join(output_dir, 'error_analysis.txt')
    with open(error_file, 'w') as f:
        f.write(f"Error Analysis for {data_info}\n")
        f.write("=" * 80 + "\n\n")

        accuracy = (all_labels == all_preds).mean()
        f.write(f"Overall Accuracy: {accuracy:.2%}\n\n")

        f.write("Per-Class Accuracy:\n")
        for i, class_name in enumerate(class_labels):
            class_mask = all_labels == i
            if class_mask.sum() > 0:
                class_acc = (all_preds[class_mask] == i).mean()
                f.write(f"{class_name}: {class_acc:.2%} ({class_mask.sum()} samples)\n")
        f.write("\n")

        f.write("Misclassified Videos:\n")
        f.write("-" * 80 + "\n")
        f.write(f"{'Filename':<40} {'True Class':<20} {'Predicted Class':<20} Confidence\n")
        f.write("-" * 80 + "\n")

        for true_label, pred_label, probs, filename in zip(all_labels, all_preds, all_probs, all_files):
            if true_label != pred_label:
                true_class = class_labels[true_label]
                pred_class = class_labels[pred_label]
                confidence = probs[pred_label]
                f.write(f"{filename:<40} {true_class:<20} {pred_class:<20} {confidence:.2%}\n")

    # Confusion matrix heatmap, labeled with the class names.
    cm = confusion_matrix(all_labels, all_preds)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_labels, yticklabels=class_labels)
    plt.xlabel('Predicted Class')
    plt.ylabel('True Class')
    plt.title(f'Confusion Matrix\n{data_info}')
    plt.savefig(os.path.join(output_dir, 'confusion_matrix.png'))
    plt.close()

    # Cycle through the palette so more than eight classes do not raise
    # an IndexError.
    colors = ['blue', 'red', 'green', 'yellow', 'purple', 'orange', 'pink', 'cyan']

    # One-vs-rest precision-recall curve per class.
    plt.figure(figsize=(10, 8))
    for i, class_label in enumerate(class_labels):
        precision, recall, _ = precision_recall_curve(all_labels == i, all_probs[:, i])
        average_precision = average_precision_score(all_labels == i, all_probs[:, i])
        plt.plot(recall, precision, color=colors[i % len(colors)], lw=2,
                 label=f'{class_label} (AP = {average_precision:.2f})')

    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title(f'Precision-Recall Curve\n{data_info}')
    plt.legend(loc="lower left")
    plt.savefig(os.path.join(output_dir, 'precision_recall_curve.png'))
    plt.close()

    # One-vs-rest ROC curve per class.
    plt.figure(figsize=(10, 8))
    for i, class_label in enumerate(class_labels):
        fpr, tpr, _ = roc_curve(all_labels == i, all_probs[:, i])
        roc_auc = auc(fpr, tpr)
        plt.plot(fpr, tpr, color=colors[i % len(colors)], lw=2,
                 label=f'{class_label} (AUC = {roc_auc:.2f})')

    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(f'Receiver Operating Characteristic (ROC) Curve\n{data_info}')
    plt.legend(loc="lower right")
    plt.savefig(os.path.join(output_dir, 'roc_curve.png'))
    plt.close()

    return cm


def run_visualization(run_dir, data_path=None, test_csv=None):
""" |
|
|
Run visualization for a specific training run |
|
|
|
|
|
Args: |
|
|
run_dir (str): Path to the run directory |
|
|
data_path (str, optional): Override the data path from config |
|
|
test_csv (str, optional): Override the test CSV path |
|
|
""" |
    config = get_config(run_dir)

    class_labels = config['class_labels']
    num_classes = config['num_classes']

    # Apply the data-path override, then read it back so the config and the
    # local variable stay in sync.
    if data_path:
        config['data_path'] = data_path
    data_path = config['data_path']

    log_file = os.path.join(run_dir, 'training_log.csv')
    model_path = get_latest_model_path(run_dir)

    if test_csv is None:
        test_csv = os.path.join(data_path, 'test.csv')

    last_dir = os.path.basename(os.path.normpath(data_path))
    file_name = os.path.basename(test_csv)

    print(f"Running visualization for {data_path} with {test_csv} from CWD {os.getcwd()}")

    vis_dir = os.path.join(run_dir, f'visualization_{last_dir}_{os.path.splitext(file_name)[0]}')
    os.makedirs(vis_dir, exist_ok=True)

    data_info = f'Data: {last_dir}, File: {file_name}'

    plot_training_curves(log_file, vis_dir)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = load_model(num_classes, model_path, device, config['clip_model'])
    model.eval()

    test_dataset = VideoDataset(test_csv, config)
    test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)

    cm = generate_evaluation_metrics(model, test_loader, device, vis_dir, class_labels, data_info)

    print(f"Visualization complete! Check the output directory: {vis_dir}")
    return vis_dir, cm


if __name__ == "__main__":
    # Visualize the most recent training run by default.
    run_dir = get_latest_run_dir()
    run_visualization(run_dir)
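# A hypothetical override call, e.g. to score the same checkpoint on a
# different split (paths are illustrative, not part of the repo):
#
#     run_visualization(get_latest_run_dir(),
#                       data_path='data/my_dataset',
#                       test_csv='data/my_dataset/val.csv')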