| '''
|
| Utility functions for visualization
|
| '''
|
|
|
| from pathlib import Path
|
| import random
|
| import cv2
|
| import numpy as np
|
| from sklearn.metrics import confusion_matrix
|
| import matplotlib.pyplot as plt
|
| import seaborn as sns
|
|
|
| from src.data_utils import get_sample_frames
|
|
|
|
|
|
|
def plot_sample_frames(
    root,
    sample_classes=("Ăn", "Nghỉ ngơi", "Chạy"),
    n_frames=5,
    save_path=None
):
    '''
    Plot a grid of sample frames, one row per class

    For every class a random entry of ``root/<class>`` is picked and passed
    to ``get_sample_frames``; the returned frames fill that class's row.
    A class is skipped (its row stays blank) when its directory is missing,
    is empty, or no frames could be read after the retries are used up.

    Args:
        root (str | Path)       : directory containing one sub-directory per class
        sample_classes (iterable): class (sub-directory) names to show,
                                   default: ("Ăn", "Nghỉ ngơi", "Chạy")
        n_frames (int)          : number of frames (columns) per class, default: 5
        save_path (str)         : path to save the plot if provided, default: None

    Returns:
        None
    '''

    # One initial attempt plus up to 10 retries, as in the original logic
    max_retries = 10

    n_sample_classes = len(sample_classes)
    # squeeze=False keeps `axes` 2-D even for a single class or a single frame
    fig, axes = plt.subplots(
        n_sample_classes, n_frames, figsize=(15, 8), squeeze=False
    )

    for row, cls in enumerate(sample_classes):
        # Hide ticks and spines up front so that skipped rows render blank
        for ax in axes[row]:
            ax.axis("off")

        sample_dir = Path(root) / cls
        if not sample_dir.is_dir():
            print(f"The directory \"{cls}\" is not available, skipping")
            continue

        all_path = list(sample_dir.iterdir())
        if not all_path:
            # random.choice would raise IndexError on an empty list
            print(f"The directory \"{cls}\" is empty, skipping")
            continue

        # get_sample_frames is treated as returning None on failure (inferred
        # from the None check here -- confirm against src.data_utils)
        sample_frames = None
        for _ in range(1 + max_retries):
            sample_path = random.choice(all_path)
            sample_frames = get_sample_frames(sample_path, num_frames=n_frames)
            if sample_frames is not None:
                break

        if sample_frames is None:
            print(f"Could not read frames for \"{cls}\", skipping")
            continue

        for ax, frame in zip(axes[row], sample_frames):
            ax.imshow(frame)

        # Class name on the left, vertically centred on its row
        fig.text(
            0.1,
            1 - (row + 0.5) / n_sample_classes,
            cls,
            ha="left",
            va="center",
            fontsize=16
        )

    plt.tight_layout(rect=[0.05, 0, 1, 1])
    plt.suptitle("Sample Frames", fontsize=16)
    # Leave room for the title (top) and the class names (left)
    plt.subplots_adjust(top=0.88, left=0.2)

    if save_path:
        plt.savefig(save_path)

    plt.show()
|
|
|
|
|
def plot_resolution_distribution(all_width, all_height, save_path=None, ref_line_max=250):
    '''
    Scatter plot of video resolutions with a 1:1 aspect-ratio reference line

    Args:
        all_width (array-like) : frame widths, one value per video
        all_height (array-like): frame heights, one value per video
        save_path (str)        : path to save the plot if provided, default: None
        ref_line_max (int)     : exclusive upper end of the dashed 1:1 line;
                                 pass None to extend the line to the largest
                                 width/height in the data, default: 250

    Returns:
        None
    '''

    if ref_line_max is None:
        # Fit the reference line to the data; initial=0 keeps empty input valid
        largest_width = np.asarray(all_width).max(initial=0)
        largest_height = np.asarray(all_height).max(initial=0)
        ref_line_max = int(max(largest_width, largest_height)) + 1

    plt.figure(figsize=(8, 6))

    # Points on the diagonal have width == height (square frames)
    diagonal = np.arange(ref_line_max)
    plt.plot(diagonal, diagonal, "--r", label="Square Frame Ratio (1:1)")
    plt.scatter(all_width, all_height, alpha=0.5)
    plt.title("Resolution Distribution")
    plt.xlabel("Width (pixel)")
    plt.ylabel("Height (pixel)")
    plt.legend()

    if save_path:
        plt.savefig(save_path)

    plt.show()
|
|
|
|
|
def plot_frame_count_distribution(all_frame_count, save_path=None):
    '''
    Bar chart of how many videos share each frame count

    Args:
        all_frame_count: per-video frame counts; must provide value_counts()
                         (presumably a pandas Series -- confirm with caller)
        save_path (str): path to save the plot if provided, default: None

    Returns:
        None
    '''

    plt.figure(figsize=(8, 6))
    videos_per_count = all_frame_count.value_counts()

    bar_style = {"color": "skyblue", "edgecolor": "black", "linewidth": 1.2}
    ax = sns.barplot(
        x=videos_per_count.index, y=videos_per_count.values, **bar_style
    )

    # Draw the horizontal grid lines behind the bars
    ax.set_axisbelow(True)
    ax.grid(axis="y", linestyle="--")

    ax.set_title("Frame Count Distribution")
    ax.set_xlabel("Number of Frames")
    ax.set_ylabel("Number of Videos")

    if save_path:
        plt.savefig(save_path)

    plt.show()
|
|
|
|
|
def plot_class_balance(labels, save_path=None):
    '''
    Bar chart of the number of videos in each class

    Args:
        labels         : per-video class labels; must provide value_counts()
                         (presumably a pandas Series -- confirm with caller)
        save_path (str): path to save the plot if provided, default: None

    Returns:
        None
    '''

    plt.figure(figsize=(18, 6))
    videos_per_class = labels.value_counts()
    class_names = videos_per_class.index

    # hue mirrors x so that the palette colours each bar individually
    ax = sns.barplot(
        x=class_names,
        y=videos_per_class.values,
        hue=class_names,
        palette="flare"
    )

    # Draw the horizontal grid lines behind the bars
    ax.set_axisbelow(True)
    ax.grid(axis="y", linestyle="--")
    plt.xticks(rotation=90, fontsize=10)

    ax.set_title("Number of Videos per Class", fontsize=16)
    ax.set_xlabel("Class", fontsize=12)
    ax.set_ylabel("Number of Videos", fontsize=12)

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches="tight")

    plt.show()
|
|
|
|
|
def plot_confusion_matrix(
    y_true, y_pred, labels, display_labels, top_k=10, figsize=(20, 26),
    normalize="true", save_path=None
):
    '''
    Plot confusion matrix and a table of top_k misclassified pairs

    Args:
        y_true (lst)        : true labels
        y_pred (lst)        : predictions from model
        labels (lst)        : list of integer labels
        display_labels (lst): labels to display
        top_k (int)         : number of most confused (true, predicted) pairs
                              to list in the table, default: 10
        figsize (tuple)     : figure size, default: (20, 26)
        normalize (str)     : option to normalize confusion matrix
                              (same in sklearn.metrics.confusion_matrix),
                              intended values: "true" (normalize by row)
                              and None (no normalization), default: "true"
        save_path (str)     : path to save the plot if provided, default: None

    Returns:
        None
    '''

    cm = confusion_matrix(y_true, y_pred, labels=labels, normalize=normalize)
    value_label = "Proportion" if normalize else "Count"

    # Misclassified pairs are the non-zero cells off the main diagonal.
    # argwhere walks the matrix row by row, so ties keep (i, j) order after
    # the stable sort below.
    off_diagonal = ~np.eye(len(cm), dtype=bool)
    confusions = [
        (i, j, cm[i, j])
        for i, j in np.argwhere(off_diagonal & (cm > 0)).tolist()
    ]
    top_confusions = sorted(confusions, key=lambda x: x[2], reverse=True)[:top_k]

    # Heatmap on top, table of worst pairs below
    fig, axes = plt.subplots(
        nrows=2, figsize=figsize, gridspec_kw={"height_ratios": [4, 1]}
    )

    sns.heatmap(
        cm, cmap="Blues", linewidths=0.5, linecolor="gray",
        xticklabels=display_labels, yticklabels=display_labels,
        cbar_kws={"label": value_label}, ax=axes[0]
    )

    axes[0].set_title("Confusion Matrix", fontsize=16, fontweight="bold", pad=20)
    axes[0].set_xlabel("Predicted Label", fontsize=14)
    axes[0].set_ylabel("True Label", fontsize=14)
    axes[0].tick_params(axis="x", labelsize=13)
    axes[0].tick_params(axis="y", labelsize=13)
    plt.setp(axes[0].get_xticklabels(), rotation=90, ha="right")

    columns = ["Ground Truth", "Predicted", value_label]
    data = [
        [display_labels[i], display_labels[j], f"{v:.2f}" if normalize else int(v)]
        for i, j, v in top_confusions
    ]

    axes[1].axis("off")
    if data:
        table = axes[1].table(
            cellText=data, colLabels=columns, loc="center", cellLoc="center",
            colColours=["#d3d3d3"] * len(columns), bbox=[0, 0, 1, 1]
        )
        table.auto_set_font_size(False)
        table.set_fontsize(14)
        table.scale(1, 2)
    else:
        # Axes.table raises IndexError on empty cellText, so show a note instead
        axes[1].text(
            0.5, 0.5, "No misclassified pairs", ha="center", va="center",
            fontsize=14, transform=axes[1].transAxes
        )
    axes[1].set_title(
        f"Top-{top_k} Misclassified Pairs", fontsize=14, fontweight="bold", pad=10
    )

    plt.tight_layout(h_pad=5)

    if save_path:
        plt.savefig(save_path, bbox_inches="tight")

    plt.show()
|
|
|
|
|
def plot_training_progress(
    avg_training_losses,
    avg_val_losses,
    precision_scores,
    recall_scores,
    f1_scores,
    lr_changes,
    save_path=None
):
    '''
    Plot the training history over epochs as three stacked subplots:
        - Average training and validation loss
        - Precision, recall and macro F1 score on validation data
        - Learning rate

    Epochs are numbered from 1 and their count is taken from the length of
    avg_training_losses.

    Args:
        avg_training_losses (lst): average training loss per epoch
        avg_val_losses (lst)     : average validation loss per epoch
        precision_scores (lst)   : precision on validation data per epoch
        recall_scores (lst)      : recall on validation data per epoch
        f1_scores (lst)          : macro F1 score on validation data per epoch
        lr_changes (lst)         : learning rate per epoch
        save_path (str)          : path to save the plot if provided, default: None

    Returns:
        None
    '''

    fig, (loss_ax, score_ax, lr_ax) = plt.subplots(nrows=3, ncols=1, figsize=(12, 15))
    epochs = list(range(1, len(avg_training_losses) + 1))

    # Loss curves
    loss_curves = (
        (avg_training_losses, "Train", "blue"),
        (avg_val_losses, "Validation", "red"),
    )
    for values, name, colour in loss_curves:
        loss_ax.plot(epochs, values, label=name, color=colour)
    loss_ax.set(
        xlabel="Epoch",
        ylabel="Average Loss",
        title="Average Training vs Validation Loss"
    )
    loss_ax.legend(loc="upper right")
    loss_ax.grid(True)

    # Validation metrics
    score_curves = (
        (precision_scores, "Precision", "blue"),
        (recall_scores, "Recall", "green"),
        (f1_scores, "Macro F1 Score", "red"),
    )
    for values, name, colour in score_curves:
        score_ax.plot(epochs, values, label=name, color=colour)
    score_ax.set(
        xlabel="Epoch",
        ylabel="Score (%)",
        title="Validation Precision, Recall and Macro F1 Score"
    )
    score_ax.legend(loc="lower right")
    score_ax.grid(True)

    # Learning-rate schedule
    lr_ax.plot(epochs, lr_changes)
    lr_ax.set(
        xlabel="Epoch",
        ylabel="Learning Rate",
        title="Learning Rate Changes"
    )
    lr_ax.grid(True)

    plt.suptitle("Training Process", fontsize=16)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, bbox_inches="tight")

    plt.show()