AnikS22 commited on
Commit
28564a3
·
verified ·
1 Parent(s): e52f2ac

Upload src/visualize.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/visualize.py +244 -0
src/visualize.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Visualization utilities for QC at every pipeline stage.
3
+
4
+ Generates overlay images showing predictions on raw EM images:
5
+ - Cyan circles for 6nm particles
6
+ - Yellow circles for 12nm particles
7
+ """
8
+
9
+ import numpy as np
10
+ import matplotlib
11
+ matplotlib.use("Agg")
12
+ import matplotlib.pyplot as plt
13
+ import matplotlib.patches as mpatches
14
+ from pathlib import Path
15
+ from typing import Dict, List, Optional
16
+
17
+
18
+ # Color scheme
19
+ COLORS = {
20
+ "6nm": (0, 255, 255), # cyan
21
+ "12nm": (255, 255, 0), # yellow
22
+ "6nm_pred": (0, 200, 200),
23
+ "12nm_pred": (200, 200, 0),
24
+ }
25
+
26
+ RADII = {"6nm": 6, "12nm": 12}
27
+
28
+
29
+ def overlay_annotations(
30
+ image: np.ndarray,
31
+ annotations: Dict[str, np.ndarray],
32
+ title: str = "",
33
+ save_path: Optional[Path] = None,
34
+ predictions: Optional[List[dict]] = None,
35
+ figsize: tuple = (12, 12),
36
+ ) -> plt.Figure:
37
+ """
38
+ Overlay ground truth annotations (and optional predictions) on image.
39
+
40
+ Args:
41
+ image: (H, W) grayscale image
42
+ annotations: {'6nm': Nx2, '12nm': Mx2} pixel coordinates
43
+ title: figure title
44
+ save_path: if provided, save figure here
45
+ predictions: optional list of {'x', 'y', 'class', 'conf'}
46
+ figsize: figure size
47
+
48
+ Returns:
49
+ matplotlib Figure
50
+ """
51
+ fig, ax = plt.subplots(1, 1, figsize=figsize)
52
+ ax.imshow(image, cmap="gray")
53
+
54
+ # Ground truth circles (solid)
55
+ for cls, coords in annotations.items():
56
+ if len(coords) == 0:
57
+ continue
58
+ color_rgb = np.array(COLORS[cls]) / 255.0
59
+ radius = RADII[cls]
60
+ for x, y in coords:
61
+ circle = plt.Circle(
62
+ (x, y), radius, fill=False,
63
+ edgecolor=color_rgb, linewidth=1.5,
64
+ )
65
+ ax.add_patch(circle)
66
+
67
+ # Predictions (dashed)
68
+ if predictions:
69
+ for det in predictions:
70
+ cls = det["class"]
71
+ color_rgb = np.array(COLORS.get(f"{cls}_pred", COLORS[cls])) / 255.0
72
+ radius = RADII[cls]
73
+ circle = plt.Circle(
74
+ (det["x"], det["y"]), radius, fill=False,
75
+ edgecolor=color_rgb, linewidth=1.0, linestyle="--",
76
+ )
77
+ ax.add_patch(circle)
78
+ # Confidence label
79
+ ax.text(
80
+ det["x"] + radius + 2, det["y"],
81
+ f'{det["conf"]:.2f}',
82
+ color=color_rgb, fontsize=6,
83
+ )
84
+
85
+ # Legend
86
+ legend_elements = [
87
+ mpatches.Patch(facecolor="none", edgecolor="cyan", label=f'6nm GT ({len(annotations.get("6nm", []))})', linewidth=1.5),
88
+ mpatches.Patch(facecolor="none", edgecolor="yellow", label=f'12nm GT ({len(annotations.get("12nm", []))})', linewidth=1.5),
89
+ ]
90
+ if predictions:
91
+ n_pred_6 = sum(1 for d in predictions if d["class"] == "6nm")
92
+ n_pred_12 = sum(1 for d in predictions if d["class"] == "12nm")
93
+ legend_elements.extend([
94
+ mpatches.Patch(facecolor="none", edgecolor="darkcyan", label=f"6nm pred ({n_pred_6})", linewidth=1.0),
95
+ mpatches.Patch(facecolor="none", edgecolor="goldenrod", label=f"12nm pred ({n_pred_12})", linewidth=1.0),
96
+ ])
97
+ ax.legend(handles=legend_elements, loc="upper right", fontsize=8)
98
+
99
+ ax.set_title(title, fontsize=10)
100
+ ax.axis("off")
101
+
102
+ if save_path:
103
+ save_path = Path(save_path)
104
+ save_path.parent.mkdir(parents=True, exist_ok=True)
105
+ fig.savefig(str(save_path), dpi=150, bbox_inches="tight")
106
+ plt.close(fig)
107
+
108
+ return fig
109
+
110
+
111
+ def plot_heatmap_overlay(
112
+ image: np.ndarray,
113
+ heatmap: np.ndarray,
114
+ title: str = "",
115
+ save_path: Optional[Path] = None,
116
+ ) -> plt.Figure:
117
+ """
118
+ Overlay predicted heatmap on image for QC.
119
+
120
+ Args:
121
+ image: (H, W) grayscale
122
+ heatmap: (2, H/2, W/2) predicted heatmap
123
+ """
124
+ fig, axes = plt.subplots(1, 3, figsize=(18, 6))
125
+
126
+ axes[0].imshow(image, cmap="gray")
127
+ axes[0].set_title("Raw Image")
128
+ axes[0].axis("off")
129
+
130
+ # Upsample heatmap to image size for overlay
131
+ h, w = image.shape[:2]
132
+
133
+ for idx, (cls, color) in enumerate([("6nm", "hot"), ("12nm", "cool")]):
134
+ hm = heatmap[idx]
135
+ # Resize to image dims
136
+ from skimage.transform import resize
137
+ hm_up = resize(hm, (h, w), order=1)
138
+
139
+ axes[idx + 1].imshow(image, cmap="gray")
140
+ axes[idx + 1].imshow(hm_up, cmap=color, alpha=0.5, vmin=0, vmax=1)
141
+ axes[idx + 1].set_title(f"{cls} heatmap")
142
+ axes[idx + 1].axis("off")
143
+
144
+ fig.suptitle(title, fontsize=12)
145
+
146
+ if save_path:
147
+ save_path = Path(save_path)
148
+ save_path.parent.mkdir(parents=True, exist_ok=True)
149
+ fig.savefig(str(save_path), dpi=150, bbox_inches="tight")
150
+ plt.close(fig)
151
+
152
+ return fig
153
+
154
+
155
+ def plot_training_curves(
156
+ metrics: dict,
157
+ save_path: Optional[Path] = None,
158
+ ) -> plt.Figure:
159
+ """Plot training loss and F1 curves."""
160
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
161
+
162
+ epochs = range(1, len(metrics["train_loss"]) + 1)
163
+
164
+ # Loss
165
+ ax1.plot(epochs, metrics["train_loss"], label="Train Loss")
166
+ if "val_loss" in metrics:
167
+ ax1.plot(epochs, metrics["val_loss"], label="Val Loss")
168
+ ax1.set_xlabel("Epoch")
169
+ ax1.set_ylabel("Loss")
170
+ ax1.set_title("Training Loss")
171
+ ax1.legend()
172
+ ax1.grid(True, alpha=0.3)
173
+
174
+ # F1
175
+ if "val_f1_6nm" in metrics:
176
+ ax2.plot(epochs, metrics["val_f1_6nm"], label="6nm F1")
177
+ if "val_f1_12nm" in metrics:
178
+ ax2.plot(epochs, metrics["val_f1_12nm"], label="12nm F1")
179
+ if "val_f1_mean" in metrics:
180
+ ax2.plot(epochs, metrics["val_f1_mean"], label="Mean F1", linewidth=2)
181
+ ax2.set_xlabel("Epoch")
182
+ ax2.set_ylabel("F1 Score")
183
+ ax2.set_title("Validation F1")
184
+ ax2.legend()
185
+ ax2.grid(True, alpha=0.3)
186
+
187
+ if save_path:
188
+ save_path = Path(save_path)
189
+ save_path.parent.mkdir(parents=True, exist_ok=True)
190
+ fig.savefig(str(save_path), dpi=150, bbox_inches="tight")
191
+ plt.close(fig)
192
+
193
+ return fig
194
+
195
+
196
+ def plot_precision_recall_curve(
197
+ detections: List[dict],
198
+ gt_coords: np.ndarray,
199
+ match_radius: float,
200
+ cls_name: str = "",
201
+ save_path: Optional[Path] = None,
202
+ ) -> plt.Figure:
203
+ """Plot precision-recall curve for one class."""
204
+ sorted_dets = sorted(detections, key=lambda d: d["conf"], reverse=True)
205
+
206
+ tp_list = []
207
+ matched_gt = set()
208
+
209
+ for det in sorted_dets:
210
+ det_coord = np.array([det["x"], det["y"]])
211
+ if len(gt_coords) > 0:
212
+ dists = np.sqrt(np.sum((gt_coords - det_coord) ** 2, axis=1))
213
+ min_idx = np.argmin(dists)
214
+ if dists[min_idx] <= match_radius and min_idx not in matched_gt:
215
+ tp_list.append(1)
216
+ matched_gt.add(min_idx)
217
+ else:
218
+ tp_list.append(0)
219
+ else:
220
+ tp_list.append(0)
221
+
222
+ tp_cumsum = np.cumsum(tp_list)
223
+ fp_cumsum = np.cumsum([1 - t for t in tp_list])
224
+ n_gt = max(len(gt_coords), 1)
225
+
226
+ precision = tp_cumsum / (tp_cumsum + fp_cumsum)
227
+ recall = tp_cumsum / n_gt
228
+
229
+ fig, ax = plt.subplots(figsize=(6, 6))
230
+ ax.plot(recall, precision, linewidth=2)
231
+ ax.set_xlabel("Recall")
232
+ ax.set_ylabel("Precision")
233
+ ax.set_title(f"PR Curve — {cls_name}")
234
+ ax.set_xlim(0, 1)
235
+ ax.set_ylim(0, 1)
236
+ ax.grid(True, alpha=0.3)
237
+
238
+ if save_path:
239
+ save_path = Path(save_path)
240
+ save_path.parent.mkdir(parents=True, exist_ok=True)
241
+ fig.savefig(str(save_path), dpi=150, bbox_inches="tight")
242
+ plt.close(fig)
243
+
244
+ return fig