# File size: 24,821 Bytes
# 322161a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
import os
import cv2
import numpy as np
import torch
from torchvision import transforms
from PIL import Image, ImageDraw
import torch.nn.functional as F

def norm(t):
    """Min-max rescale ``t`` to the range [0, 1].

    Works on any array type exposing ``min()``/``max()`` and elementwise
    arithmetic (this file uses it with both torch tensors and NumPy arrays).
    A constant input has zero range; it is mapped to all zeros instead of the
    NaNs that a 0/0 division would produce.
    """
    lo = t.min()
    span = t.max() - lo
    if span == 0:
        # Dividing by 1 keeps the dtype promotion of the regular path.
        span = 1
    return (t - lo) / span


def denorm(t, min_, max_):
    """Inverse of ``norm``: map values from [0, 1] back to [min_, max_]."""
    return t * (max_ - min_) + min_


def percentilerange(t, perc):
    """Return the value lying at fraction ``perc`` between ``t.min()`` and ``t.max()``.

    Note: this interpolates over the value range; it is not a statistical
    percentile of the data.
    """
    return t.min() + perc * (t.max() - t.min())


def midrange(t):
    """Return the midpoint between ``t.min()`` and ``t.max()``."""
    return percentilerange(t, .5)


def downsample_mask(mask, H, W):
    """Bilinearly resize ``mask`` to spatial size ``(H, W)``.

    A channel dimension is inserted at position 1 for ``F.interpolate`` and
    removed again afterwards, so the result has the same rank as the input.
    """
    resized = F.interpolate(mask.unsqueeze(1), size=(H, W), mode='bilinear',
                            align_corners=False)
    return resized.squeeze(1)


# downsampled_smask: [bsz,vecs] (or expandable to it), vecs can be H*W for example
# sfeat_volume: [bsz,c,vecs]
# returns (fg_proto, bg_proto), each of shape [bsz,c]
def fg_bg_proto(sfeat_volume, downsampled_smask):
    """Compute mask-weighted foreground and background prototypes.

    Args:
        sfeat_volume: features of shape [B, C, vecs].
        downsampled_smask: mask expandable to [B, vecs]; it is used as the
            foreground weight and ``1 - mask`` as the background weight.

    Returns:
        Tuple ``(fg_proto, bg_proto)``, each of shape [B, C]: the weighted
        mean of the feature vectors under the respective weights.
    """
    batch, channels, n_vecs = sfeat_volume.shape
    fg_weight = downsampled_smask.expand(batch, n_vecs).unsqueeze(1)  # ->[B,1,vecs]
    bg_weight = 1 - fg_weight

    def _weighted_mean(weight):
        # The small epsilon keeps the division finite when the weights sum to 0.
        weighted = weight * sfeat_volume
        return torch.sum(weighted, dim=-1) / (torch.sum(weight, dim=-1) + 1e-8)

    fg_proto = _weighted_mean(fg_weight)
    bg_proto = _weighted_mean(bg_weight)
    assert fg_proto.shape == (batch, channels), ":o"
    return fg_proto, bg_proto


# intersection = lambda pred, target: (pred * target).float().sum()
# union = lambda pred, target: (pred + target).clamp(0, 1).float().sum()
#
#
# def iou(pred, target):  # binary only, input bsz,h,w
#     i, u = intersection(pred, target), union(pred, target)
#     iou = (i + 1e-8) / (u + 1e-8)
#     return iou.item()
#
#
# class SimpleAvgMeter:
#     def __init__(self, n_classes, device=torch.device('cuda')):
#         self.n_lasses = n_classes
#         self.intersection_buf = torch.zeros(n_classes).to(device)
#         self.union_buf = torch.zeros(n_classes).to(device)
#
#     def update(self, pred, target, class_id):
#         self.intersection_buf[class_id] += intersection(pred, target)
#         self.union_buf[class_id] += union(pred, target)
#
#     def IoU(self, class_id):
#         return self.intersection_buf[class_id] / self.union_buf[class_id] * 100
#
#     def cls_mIoU(self, class_ids):
#         return (self.intersection_buf[class_ids] / self.union_buf[class_ids]).mean() * 100
#
#     def compute_mIoU(self):
#         noentry = self.union_buf == 0
#         if noentry.sum() > 0: print("SimpleAvgMeter warning: ", noentry.sum(), "elements of", self.nclasses,
#                                     "have no empty.")
#         return self.cls_mIoU(~noentry)

# class KMeans():
#     # expects input to be in shape [bsz, -1]
#     def __init__(self, data, k=2, num_iterations=10):
#         self.k = k
#         self.device = data.device
#         self.centroids = self._init_centroids(data)
#
#         for _ in range(num_iterations):
#             labels = self._assign_clusters(data)
#             self._update_centroids(data, labels)
#
#         self.labels = self._assign_clusters(data)  # Final cluster assignment
#
#     def _init_centroids(self, data):
#         # Randomly initialize centroids
#         centroids = []
#         min_values = data.min(dim=1, keepdim=True).values
#         range_values = (data.max(dim=1, keepdim=True).values - min_values)
#
#         for _ in range(self.k):
#             random_values = torch.rand((data.shape[0], 1)).to(self.device)
#             centroids.append(min_values + random_values * range_values)
#
#         return torch.cat(centroids, dim=1)
#
#     def _assign_clusters(self, data):
#         # Calculate distances between data points and centroids
#         distances = torch.abs(data.unsqueeze(2) - self.centroids)  # Expand data tensor to calculate distances
#         # Determine the closest centroid for each data point
#         labels = torch.argmin(distances, dim=2)
#         # Sort labels so that the largest mean data point has the highest label
#         cluster_means = [data[labels == k].mean() for k in range(self.k)]
#         sorted_labels = {k: rank for rank, k in enumerate(sorted(range(self.k), key=lambda k: cluster_means[k]))}
#         labels = torch.tensor([sorted_labels[label.item()] for label in labels.flatten()]).reshape_as(labels).to(
#             self.device)
#
#         return labels
#
#     def _update_centroids(self, data, labels):
#         # Calculate new centroids as the mean of the data points closest to each centroid
#         mask = torch.nn.functional.one_hot(labels, num_classes=self.k).to(torch.float32)
#         summed_data = torch.bmm(mask.transpose(1, 2), data.unsqueeze(2))  # Sum data points per centroid
#         self.centroids = summed_data.squeeze() / mask.sum(dim=1, keepdim=True)  # Normalize to get the mean
#
#     def compute_thresholds(self):
#         # Flatten the centroids along the middle dimension
#         flat_centroids = self.centroids.view(self.centroids.size(0), -1)
#
#         # Sort the flattened centroids
#         sorted_centroids, _ = torch.sort(flat_centroids, dim=1)
#
#         # Compute the midpoints between consecutive centroids
#         thresholds = (sorted_centroids[:, :-1] + sorted_centroids[:, 1:]) / 2.0
#
#         return thresholds
#
#     def inference(self, data):
#         # Assign data points to the nearest centroid
#         return self._assign_clusters(data)

# def iterative_triclass_thresholding(image, max_iterations=100, tolerance=25):
#     # Ensure image is grayscale
#     if len(image.shape) == 3:
#         image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
#
#     # Initialize iteration parameters
#     TBD_region = image.copy()
#     iteration = 0
#     prev_threshold = 0
#
#     while iteration < max_iterations:
#         iteration += 1
#
#         # Step 1: Apply Otsu's thresholding on the TBD region
#         current_threshold, _ = cv2.threshold(TBD_region, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
#
#         # Check stopping criteria
#         if abs(current_threshold - prev_threshold) < tolerance:
#             break
#         prev_threshold = current_threshold
#
#         # Step 2: Calculate means for upper and lower regions
#         upper_region = TBD_region[TBD_region > current_threshold]
#         lower_region = TBD_region[TBD_region <= current_threshold]
#
#         if len(upper_region) == 0 or len(lower_region) == 0:
#             break  # No further division possible
#
#         mean_upper = np.mean(upper_region)
#         mean_lower = np.mean(lower_region)
#
#         # Step 3: Update temporary foreground, background, and TBD regions
#         TBD_region[(TBD_region > mean_upper)] = 255  # Temporary foreground F
#         TBD_region[(TBD_region < mean_lower)] = 0  # Temporary background B
#
#         # Extracting the new TBD region (between mean_lower and mean_upper)
#         mask = (TBD_region > mean_lower) & (TBD_region < mean_upper)
#         TBD_region = TBD_region[mask]  # Apply mask to extract region
#
#     # Final classification after convergence or max iterations
#     final_foreground = (image > current_threshold).astype(np.uint8) * 255
#     final_background = (image <= current_threshold).astype(np.uint8) * 255
#
#     return current_threshold, final_foreground

def otsus(batched_tensor_image, drop_least=0.05, mode='ordinary'):
    """Binarise every image of a batch with Otsu's threshold (via OpenCV).

    Each image is min-max rescaled to uint8 [0, 255]. Values below
    ``int(255 * drop_least)`` are left out of the data handed to Otsu's
    method; the resulting threshold is then applied to the whole rescaled
    image.

    Args:
        batched_tensor_image: tensor whose first dimension is the batch.
        drop_least: fraction of the 0-255 range below which values are
            ignored when estimating the threshold.
        mode: ``'via_triclass'`` delegates to
            ``iterative_triclass_thresholding``; any other value uses plain
            Otsu thresholding.

    Returns:
        Tuple ``(thresh_batch, binary_tensor_batch)``:
        ``thresh_batch`` holds one threshold per image, mapped back to that
        image's original value range, on the input's device and dtype;
        ``binary_tensor_batch`` stacks the 0./1. float masks along dim 0
        (created from NumPy, i.e. on the CPU).

    Raises:
        NotImplementedError: if ``mode == 'via_triclass'`` but no
            ``iterative_triclass_thresholding`` is defined in this module.
    """
    triclass = None
    if mode == 'via_triclass':
        # Resolve the helper up front so a missing definition fails with a
        # clear message instead of a NameError in the middle of the loop.
        triclass = globals().get('iterative_triclass_thresholding')
        if triclass is None:
            raise NotImplementedError(
                "mode='via_triclass' requires iterative_triclass_thresholding, "
                "which is not defined in this module")

    binary_tensors = []
    thresholds = []

    for image in batched_tensor_image:
        numpy_image = image.cpu().numpy()

        # Rescale to [0, 255] and convert to uint8 for OpenCV compatibility.
        npmin, npmax = numpy_image.min(), numpy_image.max()
        if npmax > npmin:
            numpy_image = (norm(numpy_image) * 255).astype(np.uint8)
        else:
            # Constant image: min-max scaling would divide by zero.
            numpy_image = np.zeros(numpy_image.shape, dtype=np.uint8)

        # Drop values that are in the lowest part of the range.
        truncated_vals = numpy_image[numpy_image >= int(255 * drop_least)]
        if truncated_vals.size == 0:
            # Nothing survived the cut; fall back to all pixels so that
            # OpenCV is never handed an empty array.
            truncated_vals = numpy_image.ravel()

        if triclass is not None:
            thresh_value, _ = triclass(truncated_vals)
        else:
            thresh_value, _ = cv2.threshold(truncated_vals, 0, 255,
                                            cv2.THRESH_BINARY + cv2.THRESH_OTSU)

        # Apply the computed threshold to the full rescaled image.
        binary_image = (numpy_image > thresh_value).astype(np.float32)
        binary_tensors.append(torch.from_numpy(binary_image))

        # Express the uint8 threshold in the image's original value range.
        original_scale_thresh = denorm(thresh_value / 255, npmin, npmax)
        thresholds.append(torch.tensor(original_scale_thresh).to(
            batched_tensor_image.device, dtype=batched_tensor_image.dtype))

    binary_tensor_batch = torch.stack(binary_tensors, dim=0)
    thresh_batch = torch.stack(thresholds, dim=0)
    return thresh_batch, binary_tensor_batch


def iterative_otsus(probab_mask, s_mask, maxiters=5, mode='ordinary',
                    debug=False):
    """Raise the Otsu threshold iteratively until it reaches ``s_mask.mean()``.

    In every round the probabilities below the current threshold are set to
    0 and Otsu's threshold is recomputed on the clipped map. The loop stops
    as soon as the threshold is at least the mean of ``s_mask``. After
    ``maxiters`` unsuccessful rounds the mean of ``s_mask`` itself is used as
    the threshold.

    NOTE: ``otsus`` returns one threshold per image and the result is used in
    a plain ``if``, so this only works for a batch of size 1.

    Args:
        probab_mask: probabilities in [0, 1].
        s_mask: mask whose mean is the lower bound for the threshold.
        maxiters: number of rounds before falling back.
        mode: forwarded to ``otsus``.
        debug: print statistics (and show images when the notebook helpers
            ``display``/``pilImageRow`` are available) on fallback.

    Returns:
        Tuple ``(threshold, binary_mask)``. On success these come from
        ``otsus`` (moved to ``probab_mask``'s device); on fallback the
        threshold is ``s_mask.mean()`` and the mask is
        ``probab_mask > s_mask.mean()`` as float.
    """
    assert probab_mask.min() >= 0 and probab_mask.max() <= 1, 'you should pass probabilites'
    target = s_mask.mean()
    it = 1
    otsuthresh = 0
    while True:
        clipped = torch.where(probab_mask < otsuthresh, 0, probab_mask)
        otsuthresh, newmask = otsus(clipped.detach(), drop_least=.02, mode=mode)
        if otsuthresh >= target:
            return otsuthresh.to(probab_mask.device), newmask.to(probab_mask.device)
        if it >= maxiters:
            if debug:
                print('reached maxiter:', it, 'with thresh', otsuthresh.item(),
                      'removed', int(((clipped == 0).sum() / clipped.numel()).item() * 10000) / 100,
                      '% at lower and and new min,max is', clipped[clipped > 0].min().item(), clipped.max().item())
                try:
                    # These helpers only exist in a notebook session that
                    # defines them; outside of it the preview is skipped.
                    display(pilImageRow(norm(probab_mask[0]), s_mask[0], maxwidth=300))
                except NameError:
                    print('display/pilImageRow not available, skipping preview')
            return target, (probab_mask > target).float()  # otsuthresh
        it += 1


# def upgrade_scipy():
#     os.system('!pip install - -upgrade scipy')
#
#
# def slicRGB(q_img, n_segments=50, compactness=10., sigma=1, mask=None, debug=False):
#     import skimage.segmentation as skseg
#
#     rgb_labels = skseg.slic(q_img, n_segments=n_segments, compactness=compactness, sigma=sigma, mask=mask,
#                             enforce_connectivity=True)
#
#     if debug:
#         plt.imshow(skseg.mark_boundaries(q_img, rgb_labels))
#         plt.show()
#
#     return rgb_labels
#
#
#
# def slicRGBP(q_img, fg_pred, n_segments=30, compactness=0.1, sigma=1, mask=None, debug=False):
#     import skimage.segmentation as skseg
#
#     def concat_rgb_pred(rgbimg, pred):
#         h, w = rgbimg.shape[:2]
#         return np.concatenate((rgbimg, pred.reshape(h, w, 1)), axis=-1)
#
#     rgbp_img = concat_rgb_pred(q_img, fg_pred)
#     rgbp_labels = skseg.slic(rgbp_img, n_segments=n_segments, compactness=compactness, mask=mask, sigma=sigma,
#                              enforce_connectivity=True)
#
#     if debug:
#         rgb_labels = skseg.slic(q_img, n_segments=n_segments, compactness=10., sigma=sigma, mask=mask,
#                                 enforce_connectivity=True)
#         pred_labels = skseg.slic(fg_pred, n_segments=n_segments, compactness=compactness, sigma=sigma, mask=mask,
#                                  channel_axis=None, enforce_connectivity=True)
#
#         rows, cols = 1, 3
#         fig, ax = plt.subplots(rows, cols, figsize=(10, 10), sharex=True, sharey=True)
#         ax[0].imshow(skseg.mark_boundaries(q_img, rgbp_labels))
#         ax[1].imshow(skseg.mark_boundaries(q_img, rgb_labels))
#         ax[2].imshow(skseg.mark_boundaries(q_img, pred_labels))
#         plt.show()
#
#     return rgbp_labels
#
#
# def calc_cluster_means(label_id_map, fg_prob):
#     fg_pred_clustered = np.zeros_like(fg_prob)
#     label_ids = np.unique(label_id_map)
#     for lab_id in label_ids:
#         cluster = fg_prob[label_id_map == lab_id]
#         fg_pred_clustered[label_id_map == lab_id] = cluster.mean()
#     return fg_pred_clustered


def install_pydensecrf():
    """Install pydensecrf from GitHub into the running interpreter.

    pip is invoked as ``<this python> -m pip`` so the package lands in the
    environment that is executing this code, not in whichever ``pip``
    happens to be first on PATH.

    Returns:
        The exit code of the pip process (0 on success). Failures are not
        raised here; callers detect them when the import is retried.
    """
    import subprocess
    import sys

    completed = subprocess.run(
        [sys.executable, '-m', 'pip', 'install',
         'git+https://github.com/lucasb-eyer/pydensecrf.git'],
        check=False,
    )
    return completed.returncode


class CRF:
    """DenseCRF post-processing for two-class (background/foreground) maps.

    Wraps ``pydensecrf``: builds a ``DenseCRF2D`` with one Gaussian and one
    bilateral pairwise term and runs ``self.iters`` inference steps.
    """

    def __init__(self, gaussian_stdxy=(3, 3), gaussian_compat=3,
                 bilateral_stdxy=(80, 80), bilateral_compat=10, stdrgb=(13, 13, 13)):
        """Store the settings of the two pairwise potentials.

        Args:
            gaussian_stdxy: ``sxy`` passed to ``addPairwiseGaussian``.
            gaussian_compat: ``compat`` passed to ``addPairwiseGaussian``.
            bilateral_stdxy: ``sxy`` passed to ``addPairwiseBilateral``.
            bilateral_compat: ``compat`` passed to ``addPairwiseBilateral``.
            stdrgb: ``srgb`` passed to ``addPairwiseBilateral``.
        """
        self.gaussian_stdxy = gaussian_stdxy
        self.gaussian_compat = gaussian_compat
        self.bilateral_stdxy = bilateral_stdxy
        self.bilateral_compat = bilateral_compat
        self.stdrgb = stdrgb
        # Number of steps passed to DenseCRF2D.inference.
        self.iters = 5
        # When True, refine() prints intermediate statistics.
        self.debug = False

    def refine(self, image_tensor, fg_probs, soft_thresh=None, T=1):
        """
        Refine foreground probabilities using DenseCRF.

        Args:
            - image_tensor (tensor): Original image, shape [bsz, 3, H, W]. It is
              multiplied by 255 and cast to uint8, so values are presumably
              expected in [0, 1].
            - fg_probs (tensor): Fg probabilities from the network, shape [bsz, H, W].
            - soft_thresh: The value of fg_probs that should map to probability
              0.5. Tensor or plain number; a 1-D tensor with one entry per image
              is applied per image. Defaults to the Otsu threshold of fg_probs.
            - T: a temperature (slope) for the sigmoid that re-centres fg_probs.

        Returns:
            - CPU tensor of shape [bsz, 2, H, W] with the CRF output per class;
              channel 0 is background, channel 1 is foreground.

        Raises:
            - ImportError: if pydensecrf is missing and could not be imported
              even after the on-the-fly installation attempt.
        """
        try:
            import pydensecrf.densecrf as dcrf
            from pydensecrf.utils import unary_from_softmax
        except ImportError:
            print("pydensecrf not found. Installing...")
            install_pydensecrf()
            # Retry once after the installation attempt.
            try:
                import pydensecrf.densecrf as dcrf
                from pydensecrf.utils import unary_from_softmax
            except ImportError:
                print("Failed to import after installation. Please check the installation of pydensecrf.")
                raise

        # We find the segmentation threshold that splits fg-bg
        if soft_thresh is None:
            soft_thresh, _ = otsus(fg_probs)
        image_tensor, fg_probs = image_tensor.cpu(), fg_probs.cpu()
        # as_tensor lets callers pass a plain Python/NumPy number as well.
        soft_thresh = torch.as_tensor(soft_thresh).cpu()

        centre = soft_thresh
        if centre.dim() == 1 and centre.shape[0] == fg_probs.shape[0]:
            # One threshold per image: align it with the batch dimension so
            # that it is not broadcast along the width instead.
            centre = centre.reshape(-1, 1, 1)

        # Re-centre the probabilities so that the threshold maps to 0.5;
        # T controls how steep the transition around the threshold is.
        fg_probs = torch.sigmoid(T * (fg_probs - centre))
        probs = torch.stack([1 - fg_probs, fg_probs], dim=1)  # crf expects both classes as input
        if self.debug:
            print('softthresh', soft_thresh)
            print('fg_probs min max', fg_probs.min(), fg_probs.max())
        # C: Number of classes
        bsz, C, H, W = probs.shape

        refined_masks = []
        # pydensecrf wants channel-last, C-contiguous uint8 images.
        image_numpy = np.ascontiguousarray(
            (255 * image_tensor.permute(0, 2, 3, 1)).numpy().astype(np.uint8))
        probs_numpy = probs.numpy()
        for image, prob in zip(image_numpy, probs_numpy):
            # Unary potentials
            unary = np.ascontiguousarray(unary_from_softmax(prob))
            d = dcrf.DenseCRF2D(W, H, C)
            d.setUnaryEnergy(unary)

            # Add pairwise potentials
            d.addPairwiseGaussian(sxy=self.gaussian_stdxy, compat=self.gaussian_compat)
            d.addPairwiseBilateral(sxy=self.bilateral_stdxy, srgb=self.stdrgb,
                                   rgbim=image, compat=self.bilateral_compat)

            # Perform inference
            Q = d.inference(self.iters)
            if self.debug:
                print('Q:', np.array(Q).shape, np.array(Q)[0].mean(), np.array(Q).mean())
            # Keep the per-class output instead of taking the argmax.
            refined_masks.append(np.reshape(Q, (C, H, W)))

        return torch.from_numpy(np.stack(refined_masks, axis=0))

    def iterrefine(self, iters, q_img, fg_probs, thresh_fn, debug=False):
        """Apply ``refine`` repeatedly, re-estimating the threshold each round.

        Args:
            iters: number of refinement rounds; 0 returns the input expanded
                to two channels without calling ``refine``.
            q_img: image passed to ``refine`` as ``image_tensor``.
            fg_probs: foreground probabilities, shape [bsz, H, W].
            thresh_fn: callable applied to the current foreground channel; the
                first element of its result is used as ``soft_thresh``.
            debug: every 10th round, print the threshold and show the current
                foreground map when the notebook helpers are available.

        Returns:
            Tensor of shape [bsz, 2, H, W]; channel 1 is the foreground.
        """
        # -1 keeps the batch size of the input instead of assuming it is 1.
        pred = fg_probs.unsqueeze(1).expand(-1, 2, *fg_probs.shape[-2:])
        for it in range(iters):
            thresh = thresh_fn(pred[:, 1])[0]

            if debug and it % 10 == 0:
                print('thresh', thresh)
                try:
                    # These helpers only exist in a notebook session that
                    # defines them; outside of it the preview is skipped.
                    display(to_pil(pred[0, 1]))
                except NameError:
                    print('display/to_pil not available, skipping preview')

            pred = self.refine(q_img, pred[:, 1], soft_thresh=thresh)
        return pred


#
# class Subplot:
#     def __init__(self):
#         self.vertical_lines = []
#         self.histograms = []
#         self.gaussian_curves = []
#         self.colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
#         self.title = ''
#
#     class Element:
#         def __init__(self, x=None, y=None, label=''):
#             if x is not None:
#                 self.x = Subplot.to_np(x)
#             if y is not None:
#                 self.y = Subplot.to_np(y)
#
#             self.label = label
#
#     @staticmethod
#     def to_np(t):
#         return t.detach().cpu().numpy()
#
#     def add_vertical(self, x, label=''):
#         self.vertical_lines.append(Subplot.Element(x=x, label=label))
#         return self
#
#     def add_histogram(self, samples, label=''):
#         self.histograms.append(Subplot.Element(x=samples, label=label))
#         return self
#
#     def add_gaussian(self, gaussian):
#         samples, mu, var = gaussian.samples, gaussian.mean, gaussian.covs
#         # Generate a range of x values
#         x_values = np.linspace(samples.min(), samples.max(), 100)
#         x_values = np.linspace(samples.min(), samples.max(), 100)
#
#         # Compute Gaussian values for these x values
#         gaussian1_values = gaussian.gaussian_pdf(x_values, mu[0].item(), var[0].item())
#         gaussian2_values = gaussian.gaussian_pdf(x_values, mu[1].item(), var[1].item())
#         self.gaussian_curves.append(Subplot.Element(x_values, gaussian1_values))
#         self.gaussian_curves.append(Subplot.Element(x_values, gaussian2_values))
#         return self
#
#
# class PredHistos2():
#     def __init__(self, n_cols=1):
#         self.fig, self.axes = plt.subplots(nrows=1, ncols=n_cols, figsize=(10, 4))
#         self.n_cols = n_cols
#         if n_cols == 1:
#             self.builder = Subplot()
#         self.subplots = [Subplot() for x in range(n_cols)]
#         self.alpha = 0.5
#         self.bins = 200
#
#     def reload(self, n_cols=1):
#         self.fig, self.axes = plt.subplots(nrows=1, ncols=n_cols, figsize=(10, 4))
#
#     def aggr(self, ax, sub):
#         for hist, col in zip(sub.histograms, sub.colors):
#             ax.hist(hist.x, self.bins, density=True, color=col, alpha=self.alpha, label=hist.label)
#         for vline, col in zip(sub.vertical_lines, sub.colors):
#             ax.axvline(x=vline.x, color=col, label=vline.label, linestyle='--')
#         for gaussian, col in zip(sub.gaussian_curves, sub.colors):
#             ax.plot(gaussian.x, gaussian.y, gaussian.label, col)
#         ax.legend()
#
#     def plot(self, name=''):
#
#         if self.n_cols == 1:
#             self.aggr(plt, self.builder)
#         else:
#             for ax, sub in zip(self.axes, self.subplots):
#                 self.aggr(ax, sub)
#                 ax.set_title(sub.title)
#
#         plt.legend()
#         plt.title(name)
#         plt.show()
#
#
# from sklearn.mixture import GaussianMixture
# import scipy.optimize as opt
# from scipy.optimize import fsolve
#
#
# class GMM:
#     def __init__(self, q_pred_coarse, name='gaussian', n_components=2):
#         samples = q_pred_coarse.detach().cpu().numpy()
#         self.samples = samples.reshape(-1, 1)
#
#         # Fit a mixture of 2 Gaussians using EM
#         gmm = GaussianMixture(n_components)
#         gmm.fit(samples)
#         self.means = gmm.means_.flatten()
#         self.covs = gmm.covariances_.flatten()
#         self.weights = gmm.weights_
#         self.label = name
#
#     def intersect(self):
#         # Use fsolve to find the intersection
#         gaussian_intersect, = fsolve(difference, self.means.mean(), args=(
#         self.means[0].item(), self.covs[0].item(), self.means[1].item(), self.means[1].item()))
#         return gaussian_intersect
#
#
# class PredHistoSNS():
#     def __init__(self, n_cols=1):
#         import seaborn as sns
#         sns.set_theme(style="whitegrid")  # Set the Seaborn theme. You can change the style as needed.
#         self.fig, self.axes = plt.subplots(nrows=1, ncols=n_cols, figsize=(10, 4))
#         self.n_cols = n_cols
#         if n_cols == 1:
#             self.axes = [self.axes]  # Wrap the single axis in a list to simplify the loop logic later.
#             self.builder = Subplot()  # This is assuming Subplot is a properly defined class.
#         self.subplots = [Subplot() for _ in range(n_cols)]  # Use underscore for unused loop variable.
#         self.alpha = 0.5
#         self.bins = 200
#
#     def reload(self, n_cols=1):
#         self.fig, self.axes = plt.subplots(nrows=1, ncols=n_cols, figsize=(10, 4))
#
#     def aggr(self, ax, sub):
#         import seaborn as sns
#         for hist, col in zip(sub.histograms, sub.colors):
#             sns.histplot(hist.x, bins=self.bins, kde=False, color=col, ax=ax, alpha=self.alpha, label=hist.label)
#         for vline, col in zip(sub.vertical_lines, sub.colors):
#             ax.axvline(x=vline.x, color=col, label=vline.label, linestyle='--')
#         for gaussian, col in zip(sub.gaussian_curves, sub.colors):
#             sns.lineplot(x=gaussian.x, y=gaussian.y, label=gaussian.label, color=col, ax=ax)
#         ax.legend()
#
#     def plot(self, name=''):
#
#         if self.n_cols == 1:
#             self.aggr(self.axes[0], self.builder)
#         else:
#             for ax, sub in zip(self.axes, self.subplots):
#                 self.aggr(ax, sub)
#                 ax.set_title(sub.title)
#
#         plt.show()
#
#
# def overlay_mask(image, mask, color=[255, 0, 0], alpha=0.2):
#     """
#     Apply an overlay of a binary mask onto an image using a specified color.
#
#     :param image: A PyTorch tensor of the image (C x H x W) with pixel values in [0, 1].
#     :param mask: A PyTorch tensor of the mask (H x W) with binary values (0 or 1).
#     :param color: A list of 3 elements representing the RGB values of the overlay color.
#     :param alpha: A float representing the transparency of the overlay (0 to 1).
#     :return: An overlayed image tensor.
#     """
#     # Ensure the mask is binary
#     mask = (mask > 0).float()
#
#     # Create an RGB version of the mask
#     mask_rgb = torch.tensor(color).view(3, 1, 1) / 255.0  # Normalize the color vector
#     mask_rgb = mask_rgb * mask
#
#     # Overlay the mask onto the image
#     overlayed_image = (1 - alpha) * image + alpha * mask_rgb
#
#     # Ensure the resulting tensor values are between 0 and 1
#     overlayed_image = torch.clamp(overlayed_image, 0, 1)
#
#     return overlayed_image
#
#
# import pandas as pd

# to_pil = lambda t: transforms.ToPILImage()(t) if t.shape[-1] > 4 else transforms.ToPILImage()(t.permute(2, 0, 1))
#
#
# def pilImageRow(*imgs, maxwidth=800, bordercolor=0x000000):
#     imgs = [to_pil(im.float()) for im in imgs]
#     dst = Image.new('RGB', (sum(im.width for im in imgs), imgs[0].height))
#     for i, im in enumerate(imgs):
#         loc = [x0, y0, x1, y1] = [i * im.width, 0, (i + 1) * im.width, im.height]
#         dst.paste(im, (x0, y0))
#         ImageDraw.Draw(dst).rectangle(loc, width=2, outline=bordercolor)
#     factorToBig = dst.width / maxwidth
#     dst = dst.resize((int(dst.width / factorToBig), int(dst.height / factorToBig)))
#     return dst
#
#
# def tensor_table(**kwargs):
#     tensor_overview = {}
#     for name, tensor in kwargs.items():
#         if callable(tensor):
#             print(name, [tensor(t) for _, t in kwargs.items() if isinstance(t, torch.Tensor)])
#         else:
#             tensor_overview[name] = {
#                 'min': tensor.min().item(),
#                 'max': tensor.max().item(),
#                 'shape': tensor.shape,
#             }
#     return pd.DataFrame.from_dict(tensor_overview, orient='index')