Zorrojurro commited on
Commit
114dec9
·
verified ·
1 Parent(s): d9ca01e

Upload src/evaluation/visualize.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/evaluation/visualize.py +270 -0
src/evaluation/visualize.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Visualization utilities for the Thermal Pattern Analysis project.
3
+
4
+ Provides:
5
+ - Preprocessing step visualisation
6
+ - Confusion matrix heatmap
7
+ - ROC curve
8
+ - Attention weights over a sequence
9
+ - Grad-CAM heatmap overlay
10
+ - Training history plots
11
+ """
12
+
13
+ import os
14
+ import numpy as np
15
+ import matplotlib.pyplot as plt
16
+ import seaborn as sns
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from pathlib import Path
20
+ from typing import Optional, List
21
+ from sklearn.metrics import confusion_matrix, roc_curve, auc
22
+
23
+
24
+ class Visualizer:
25
+ """Static visualisation helpers; all methods save to disk."""
26
+
27
+ def __init__(self, output_dir: str = "results/visualizations"):
28
+ self.output_dir = Path(output_dir)
29
+ self.output_dir.mkdir(parents=True, exist_ok=True)
30
+ plt.style.use("seaborn-v0_8-darkgrid")
31
+
32
+ # ------------------------------------------------------------------
33
+ # Preprocessing
34
+ # ------------------------------------------------------------------
35
+
36
+ def plot_preprocessing_steps(
37
+ self,
38
+ original: np.ndarray,
39
+ resized: np.ndarray,
40
+ denoised: np.ndarray,
41
+ enhanced: np.ndarray,
42
+ normalized: np.ndarray,
43
+ filename: str = "preprocessing_steps.png",
44
+ ):
45
+ """Visual comparison of each preprocessing stage."""
46
+ stages = [
47
+ ("Original", original),
48
+ ("Resized", resized),
49
+ ("Denoised", denoised),
50
+ ("CLAHE Enhanced", enhanced),
51
+ ("Normalized", normalized),
52
+ ]
53
+ fig, axes = plt.subplots(1, len(stages), figsize=(20, 4))
54
+ for ax, (title, img) in zip(axes, stages):
55
+ ax.imshow(img, cmap="inferno")
56
+ ax.set_title(title, fontsize=12)
57
+ ax.axis("off")
58
+
59
+ plt.suptitle("Image Preprocessing Pipeline", fontsize=14, y=1.02)
60
+ plt.tight_layout()
61
+ plt.savefig(self.output_dir / filename, dpi=150, bbox_inches="tight")
62
+ plt.close()
63
+
64
+ # ------------------------------------------------------------------
65
+ # Confusion Matrix
66
+ # ------------------------------------------------------------------
67
+
68
+ def plot_confusion_matrix(
69
+ self,
70
+ y_true: list,
71
+ y_pred: list,
72
+ labels: list = None,
73
+ filename: str = "confusion_matrix.png",
74
+ ):
75
+ """Plot a confusion matrix heatmap."""
76
+ if labels is None:
77
+ labels = ["Normal", "Abnormal"]
78
+
79
+ cm = confusion_matrix(y_true, y_pred)
80
+
81
+ fig, ax = plt.subplots(figsize=(8, 6))
82
+ sns.heatmap(
83
+ cm,
84
+ annot=True,
85
+ fmt="d",
86
+ cmap="Blues",
87
+ xticklabels=labels,
88
+ yticklabels=labels,
89
+ ax=ax,
90
+ )
91
+ ax.set_xlabel("Predicted", fontsize=12)
92
+ ax.set_ylabel("Actual", fontsize=12)
93
+ ax.set_title("Confusion Matrix", fontsize=14)
94
+ plt.tight_layout()
95
+ plt.savefig(self.output_dir / filename, dpi=150)
96
+ plt.close()
97
+
98
+ # ------------------------------------------------------------------
99
+ # ROC Curve
100
+ # ------------------------------------------------------------------
101
+
102
+ def plot_roc_curve(
103
+ self,
104
+ y_true: list,
105
+ y_scores: list,
106
+ filename: str = "roc_curve.png",
107
+ ):
108
+ """Plot the receiver operating characteristic curve."""
109
+ fpr, tpr, _ = roc_curve(y_true, y_scores)
110
+ roc_auc = auc(fpr, tpr)
111
+
112
+ fig, ax = plt.subplots(figsize=(8, 6))
113
+ ax.plot(fpr, tpr, color="#4C72B0", lw=2, label=f"AUC = {roc_auc:.4f}")
114
+ ax.plot([0, 1], [0, 1], "k--", lw=1, alpha=0.5)
115
+ ax.set_xlim([0, 1])
116
+ ax.set_ylim([0, 1.05])
117
+ ax.set_xlabel("False Positive Rate", fontsize=12)
118
+ ax.set_ylabel("True Positive Rate", fontsize=12)
119
+ ax.set_title("ROC Curve", fontsize=14)
120
+ ax.legend(loc="lower right", fontsize=12)
121
+ plt.tight_layout()
122
+ plt.savefig(self.output_dir / filename, dpi=150)
123
+ plt.close()
124
+
125
+ # ------------------------------------------------------------------
126
+ # Attention Weights
127
+ # ------------------------------------------------------------------
128
+
129
+ def plot_attention_weights(
130
+ self,
131
+ images: list,
132
+ weights: np.ndarray,
133
+ filename: str = "attention_weights.png",
134
+ ):
135
+ """
136
+ Visualise attention weights over a sequence of images.
137
+
138
+ Args:
139
+ images: List of (H, W) numpy arrays.
140
+ weights: 1-D array of attention weights, len = len(images).
141
+ """
142
+ n = len(images)
143
+ fig, axes = plt.subplots(2, 1, figsize=(max(n * 2, 12), 6), gridspec_kw={"height_ratios": [3, 1]})
144
+
145
+ # Top: images
146
+ ax_img = axes[0]
147
+ concat = np.concatenate(images, axis=1)
148
+ ax_img.imshow(concat, cmap="inferno")
149
+ ax_img.set_title("Sequence Frames", fontsize=12)
150
+ ax_img.axis("off")
151
+
152
+ # Bottom: bar chart of weights
153
+ ax_bar = axes[1]
154
+ colors = plt.cm.RdYlGn_r(weights / (weights.max() + 1e-8))
155
+ ax_bar.bar(range(n), weights, color=colors, edgecolor="black", linewidth=0.5)
156
+ ax_bar.set_xlabel("Frame Index", fontsize=11)
157
+ ax_bar.set_ylabel("Attention", fontsize=11)
158
+ ax_bar.set_title("Attention Weights (higher = more important)", fontsize=12)
159
+
160
+ plt.tight_layout()
161
+ plt.savefig(self.output_dir / filename, dpi=150)
162
+ plt.close()
163
+
164
+ # ------------------------------------------------------------------
165
+ # Grad-CAM
166
+ # ------------------------------------------------------------------
167
+
168
+ def plot_gradcam(
169
+ self,
170
+ original_image: np.ndarray,
171
+ heatmap: np.ndarray,
172
+ filename: str = "gradcam.png",
173
+ ):
174
+ """
175
+ Overlay a Grad-CAM heatmap on the original image.
176
+
177
+ Args:
178
+ original_image: (H, W) normalised float image.
179
+ heatmap: (H, W) Grad-CAM activation map.
180
+ """
181
+ fig, axes = plt.subplots(1, 3, figsize=(15, 5))
182
+
183
+ axes[0].imshow(original_image, cmap="gray")
184
+ axes[0].set_title("Original", fontsize=12)
185
+ axes[0].axis("off")
186
+
187
+ axes[1].imshow(heatmap, cmap="jet")
188
+ axes[1].set_title("Grad-CAM Heatmap", fontsize=12)
189
+ axes[1].axis("off")
190
+
191
+ axes[2].imshow(original_image, cmap="gray")
192
+ axes[2].imshow(heatmap, cmap="jet", alpha=0.5)
193
+ axes[2].set_title("Overlay", fontsize=12)
194
+ axes[2].axis("off")
195
+
196
+ plt.suptitle("Grad-CAM Visualization", fontsize=14)
197
+ plt.tight_layout()
198
+ plt.savefig(self.output_dir / filename, dpi=150, bbox_inches="tight")
199
+ plt.close()
200
+
201
+ # ------------------------------------------------------------------
202
+ # Training history
203
+ # ------------------------------------------------------------------
204
+
205
+ def plot_training_history(
206
+ self,
207
+ train_losses: list,
208
+ val_losses: list,
209
+ train_accs: list = None,
210
+ val_accs: list = None,
211
+ filename: str = "training_history.png",
212
+ ):
213
+ """Plot loss and accuracy curves over epochs."""
214
+ n_plots = 2 if train_accs else 1
215
+ fig, axes = plt.subplots(1, n_plots, figsize=(7 * n_plots, 5))
216
+ if n_plots == 1:
217
+ axes = [axes]
218
+
219
+ # Loss
220
+ axes[0].plot(train_losses, label="Train", linewidth=2)
221
+ axes[0].plot(val_losses, label="Validation", linewidth=2)
222
+ axes[0].set_xlabel("Epoch")
223
+ axes[0].set_ylabel("Loss")
224
+ axes[0].set_title("Training & Validation Loss")
225
+ axes[0].legend()
226
+
227
+ # Accuracy
228
+ if train_accs:
229
+ axes[1].plot(train_accs, label="Train", linewidth=2)
230
+ axes[1].plot(val_accs, label="Validation", linewidth=2)
231
+ axes[1].set_xlabel("Epoch")
232
+ axes[1].set_ylabel("Accuracy")
233
+ axes[1].set_title("Training & Validation Accuracy")
234
+ axes[1].legend()
235
+
236
+ plt.tight_layout()
237
+ plt.savefig(self.output_dir / filename, dpi=150)
238
+ plt.close()
239
+
240
+ # ------------------------------------------------------------------
241
+ # Anomaly score distribution
242
+ # ------------------------------------------------------------------
243
+
244
+ def plot_anomaly_distribution(
245
+ self,
246
+ normal_scores: list,
247
+ abnormal_scores: list,
248
+ threshold: float = 0.7,
249
+ filename: str = "anomaly_distribution.png",
250
+ ):
251
+ """
252
+ Plot the distribution of anomaly scores for normal vs abnormal
253
+ sequences with the decision threshold.
254
+ """
255
+ fig, ax = plt.subplots(figsize=(10, 6))
256
+
257
+ ax.hist(normal_scores, bins=30, alpha=0.6, label="Normal", color="#4C72B0")
258
+ ax.hist(abnormal_scores, bins=30, alpha=0.6, label="Abnormal", color="#C44E52")
259
+ ax.axvline(
260
+ x=threshold, color="black", linestyle="--",
261
+ linewidth=2, label=f"Threshold ({threshold})"
262
+ )
263
+
264
+ ax.set_xlabel("Similarity Score", fontsize=12)
265
+ ax.set_ylabel("Frequency", fontsize=12)
266
+ ax.set_title("Anomaly Score Distribution", fontsize=14)
267
+ ax.legend(fontsize=11)
268
+ plt.tight_layout()
269
+ plt.savefig(self.output_dir / filename, dpi=150)
270
+ plt.close()