'''
Utility functions for visualization
'''
from pathlib import Path
import random
import cv2
import numpy as np
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from src.data_utils import get_sample_frames
# Function for plotting sample frames from videos in training set
def plot_sample_frames(root, sample_classes=("Ăn", "Nghỉ ngơi", "Chạy"), n_frames=5, save_path=None):
    '''
    Plot a grid of sample frames: one row per class, each row taken from one
    randomly chosen entry of that class' directory

    A class is skipped (its row is left blank and a message is printed) when
    its directory is missing or empty, or when no frames could be read after
    several random picks.

    Args:
        root (str or Path)       : directory containing one sub-directory per class
        sample_classes (sequence): names of the class sub-directories to sample,
                                   default: ("Ăn", "Nghỉ ngơi", "Chạy")
        n_frames (int)           : number of frames (columns) per class, default: 5
        save_path (str)          : path to save the plot if provided, default: None
    Returns:
        None
    '''
    # One initial pick plus up to 10 retries with another random video
    max_attempts = 11

    # Set up
    n_sample_classes = len(sample_classes)
    # squeeze=False keeps `axes` 2-D even for a single class or a single frame,
    # so axes[row, col] is always valid
    fig, axes = plt.subplots(
        n_sample_classes, n_frames, figsize=(15, 8), squeeze=False
    )
    # Hide every axis up front so skipped rows do not show empty tick boxes
    for ax in axes.flat:
        ax.axis("off")

    for row, cls in enumerate(sample_classes):
        # Get sample video
        sample_dir = Path(root) / cls
        if not sample_dir.is_dir():
            print(f"The directory \"{cls}\" is not available, skipping")
            continue
        all_path = list(sample_dir.iterdir())
        if not all_path:
            # random.choice would raise IndexError on an empty directory
            print(f"The directory \"{cls}\" is empty, skipping")
            continue

        # Get sample frames; get_sample_frames is expected to return None when
        # it cannot read a video, in which case another video is tried
        sample_frames = None
        for _ in range(max_attempts):
            sample_path = random.choice(all_path)
            sample_frames = get_sample_frames(sample_path, num_frames=n_frames)
            if sample_frames is not None:
                break
        if sample_frames is None:
            print(f"Could not read sample frames for \"{cls}\", skipping")
            continue

        # Plotting (guard against fewer frames being returned than requested)
        for col in range(min(n_frames, len(sample_frames))):
            axes[row, col].imshow(sample_frames[col])

        # Labeling each row
        fig.text(
            0.1,
            1 - (row + 0.5) / n_sample_classes,
            cls,
            ha="left",
            va="center",
            fontsize=16
        )

    plt.tight_layout(rect=[0.05, 0, 1, 1])
    plt.suptitle("Sample Frames", fontsize=16)
    # Leave room for the title (top) and the row labels (left)
    plt.subplots_adjust(top=0.88, left=0.2)
    if save_path:
        plt.savefig(save_path)
    plt.show()
def plot_resolution_distribution(all_width, all_height, save_path=None, max_resolution=250):
    '''
    Scatter plot of frame widths against frame heights, with a dashed 1:1
    line as a reference for square frames

    Args:
        all_width (array-like) : widths in pixels (x-axis)
        all_height (array-like): heights in pixels (y-axis), same length as
                                 all_width
        save_path (str)        : path to save the plot if provided, default: None
        max_resolution (int)   : exclusive upper end of the 1:1 reference line;
                                 raise it when the data exceeds the default,
                                 default: 250
    Returns:
        None
    '''
    # Set up
    plt.figure(figsize=(8, 6))

    # Plotting
    # Points on the diagonal y = x, from 0 up to max_resolution - 1
    diagonal = np.arange(max_resolution)
    plt.plot(diagonal, diagonal, "--r", label="Square Frame Ratio (1:1)")
    plt.scatter(all_width, all_height, alpha=0.5)

    plt.title("Resolution Distribution")
    plt.xlabel("Width (pixel)")
    plt.ylabel("Height (pixel)")
    plt.legend()
    if save_path:
        plt.savefig(save_path)
    plt.show()
def plot_frame_count_distribution(all_frame_count, save_path=None):
    '''
    Bar chart of how often each frame count occurs

    Args:
        all_frame_count (pd.Series): frame counts (presumably one per video);
                                     must provide value_counts()
        save_path (str)            : path to save the plot if provided,
                                     default: None
    Returns:
        None
    '''
    plt.figure(figsize=(8, 6))

    # Occurrences of each distinct frame count
    occurrences = all_frame_count.value_counts()

    bar_style = {"color": "skyblue", "edgecolor": "black", "linewidth": 1.2}
    ax = sns.barplot(x=occurrences.index, y=occurrences.values, **bar_style)

    # Keep the dashed horizontal grid behind the bars
    ax.set_axisbelow(True)
    ax.grid(axis="y", linestyle="--")

    ax.set_title("Frame Count Distribution")
    ax.set_xlabel("Number of Frames")
    ax.set_ylabel("Number of Videos")

    if save_path:
        plt.savefig(save_path)
    plt.show()
def plot_class_balance(labels, save_path=None):
    '''
    Bar chart of the number of samples per class

    Args:
        labels (pd.Series): class labels (presumably one per video); must
                            provide value_counts()
        save_path (str)   : path to save the plot if provided, default: None
    Returns:
        None
    '''
    plt.figure(figsize=(18, 6))

    # Number of occurrences of each class
    per_class = labels.value_counts()
    class_names = per_class.index

    # The class is passed as hue as well so each bar gets its own palette colour
    ax = sns.barplot(
        x=class_names, y=per_class.values, hue=class_names, palette="flare"
    )

    # Keep the dashed horizontal grid behind the bars
    ax.set_axisbelow(True)
    ax.grid(axis="y", linestyle="--")

    # Vertical tick labels so long class names do not overlap
    plt.xticks(rotation=90, fontsize=10)
    ax.set_title("Number of Videos per Class", fontsize=16)
    ax.set_xlabel("Class", fontsize=12)
    ax.set_ylabel("Number of Videos", fontsize=12)

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches="tight")
    plt.show()
def _top_confused_pairs(cm, top_k):
    '''
    Return up to top_k off-diagonal (true_idx, pred_idx, value) entries of a
    square confusion matrix with the largest positive values, highest first
    '''
    cm = np.asarray(cm)
    # Positive entries only; the diagonal holds correct predictions
    confused = cm > 0
    np.fill_diagonal(confused, False)
    # np.nonzero walks the matrix row by row, and the sort is stable, so ties
    # stay in (row, column) order
    pairs = [(i, j, cm[i, j]) for i, j in zip(*np.nonzero(confused))]
    pairs.sort(key=lambda pair: pair[2], reverse=True)
    return pairs[:top_k]


def plot_confusion_matrix(
    y_true, y_pred, labels, display_labels, top_k=10, figsize=(20, 26),
    normalize="true", save_path=None
):
    '''
    Plot confusion matrix and a table of top_k misclassified pairs

    Args:
        y_true (lst)        : true labels
        y_pred (lst)        : predictions from model
        labels (lst)        : list of integer labels, fixes the row/column
                              order of the confusion matrix
        display_labels (lst): labels to display, in the same order as `labels`
        top_k (int)         : maximum number of (true, predicted) pairs with
                              the highest confusion to list in the table,
                              default: 10
        figsize (tuple)     : figure size, default: (20, 26)
        normalize (str)     : option to normalize confusion matrix
                              (same in sklearn.metrics.confusion_matrix),
                              intended values are "true" (normalize by row)
                              and None (no normalization), default: "true"
        save_path (str)     : path to save the plot if provided, default: None
    Returns:
        None
    '''
    # Full confusion matrix
    cm = confusion_matrix(y_true, y_pred, labels=labels, normalize=normalize)

    # Top-k confused (true, predicted) pairs, i != j
    top_confusions = _top_confused_pairs(cm, top_k)

    # Unit shown on the colour bar and in the table
    value_name = "Proportion" if normalize else "Count"

    # Set up plots: heatmap on top, table below
    fig, axes = plt.subplots(
        nrows=2, figsize=figsize, gridspec_kw={"height_ratios": [4, 1]}
    )

    # Plot confusion matrix
    sns.heatmap(
        cm, cmap="Blues", linewidths=0.5, linecolor="gray",
        xticklabels=display_labels, yticklabels=display_labels,
        cbar_kws={"label": value_name}, ax=axes[0]
    )
    axes[0].set_title("Confusion Matrix", fontsize=16, fontweight="bold", pad=20)
    axes[0].set_xlabel("Predicted Label", fontsize=14)
    axes[0].set_ylabel("True Label", fontsize=14)
    axes[0].tick_params(axis="x", labelsize=13)
    axes[0].tick_params(axis="y", labelsize=13)
    plt.setp(axes[0].get_xticklabels(), rotation=90, ha="right")

    # Table of top_k misclassified pairs
    columns = ["Ground Truth", "Predicted", value_name]
    data = [
        [display_labels[i], display_labels[j], f"{v:.2f}" if normalize else int(v)]
        for i, j, v in top_confusions
    ]
    axes[1].axis("off")
    if data:
        table = axes[1].table(
            cellText=data, colLabels=columns, loc="center", cellLoc="center",
            colColours=["#d3d3d3"] * len(columns), bbox=[0, 0, 1, 1]
        )
        table.auto_set_font_size(False)
        table.set_fontsize(14)
        table.scale(1, 2)
    else:
        # Axes.table cannot be built from zero rows (no misclassification at
        # all, or top_k == 0), so show a short note instead
        axes[1].text(
            0.5, 0.5, "No misclassified pairs",
            ha="center", va="center", fontsize=14
        )
    axes[1].set_title(
        f"Top-{top_k} Misclassified Pairs", fontsize=14, fontweight="bold", pad=10
    )

    plt.tight_layout(h_pad=5)
    if save_path:
        plt.savefig(save_path, bbox_inches="tight")
    plt.show()
def plot_training_progress(
    avg_training_losses,
    avg_val_losses,
    precision_scores,
    recall_scores,
    f1_scores,
    lr_changes,
    save_path=None
):
    '''
    Plot the training process over epochs as 3 stacked subplots:
        - Average training and validation loss
        - Precision, recall and macro F1 score on validation data
        - Learning rate

    All series are plotted against epochs 1..len(avg_training_losses), so
    they are expected to have one value per epoch.

    Args:
        avg_training_losses (lst): average training loss
        avg_val_losses (lst)     : average validation loss
        precision_scores (lst)   : precision on validation data
        recall_scores (lst)      : recall on validation data
        f1_scores (lst)          : macro F1 score on validation data
        lr_changes (lst)         : learning rates
        save_path (str)          : path to save the plot if provided, default: None
    Returns:
        None
    '''
    _, axes = plt.subplots(nrows=3, ncols=1, figsize=(12, 15))
    loss_ax, score_ax, lr_ax = axes

    # Epochs are numbered from 1
    epochs = list(range(1, len(avg_training_losses) + 1))

    # Avg Training vs Validation loss
    loss_curves = (
        (avg_training_losses, "Train", "blue"),
        (avg_val_losses, "Validation", "red"),
    )
    for values, name, colour in loss_curves:
        loss_ax.plot(epochs, values, label=name, color=colour)
    loss_ax.set_xlabel("Epoch")
    loss_ax.set_ylabel("Average Loss")
    loss_ax.set_title("Average Training vs Validation Loss")
    loss_ax.legend(loc="upper right")
    loss_ax.grid(True)

    # Precision, Recall and Macro F1 score on validation data
    score_curves = (
        (precision_scores, "Precision", "blue"),
        (recall_scores, "Recall", "green"),
        (f1_scores, "Macro F1 Score", "red"),
    )
    for values, name, colour in score_curves:
        score_ax.plot(epochs, values, label=name, color=colour)
    score_ax.set_xlabel("Epoch")
    score_ax.set_ylabel("Score (%)")
    score_ax.set_title("Validation Precision, Recall and Macro F1 Score")
    score_ax.legend(loc="lower right")
    score_ax.grid(True)

    # Learning rate
    lr_ax.plot(epochs, lr_changes)
    lr_ax.set_xlabel("Epoch")
    lr_ax.set_ylabel("Learning Rate")
    lr_ax.set_title("Learning Rate Changes")
    lr_ax.grid(True)

    plt.suptitle("Training Process", fontsize=16)
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, bbox_inches="tight")
    plt.show()