# File size: 24,471 Bytes
# 427d150
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
"""

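Validation script for ProtoSAM / ProtoMedSAM few-shot segmentation.

Builds the configured model, evaluates it on the polyp or CHAOS/SABS test
set against a fixed support set, and logs Dice, IoU, precision and recall
per case and overall.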

"""
import math
import os
import pandas as pd
import csv
import shutil
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.transforms.functional as F
from torch.utils.data import DataLoader
import torch.backends.cudnn as cudnn
import numpy as np
import time
import matplotlib.pyplot as plt
from models.ProtoSAM import ProtoSAM,  ALPNetWrapper, SamWrapperWrapper, InputFactory, ModelWrapper, TYPE_ALPNET, TYPE_SAM
from models.ProtoMedSAM import ProtoMedSAM
from models.grid_proto_fewshot import FewShotSeg
from models.segment_anything.utils.transforms import ResizeLongestSide
from models.SamWrapper import SamWrapper
# from dataloaders.PolypDataset import get_polyp_dataset, get_vps_easy_unseen_dataset, get_vps_hard_unseen_dataset, PolypDataset, KVASIR, CVC300, COLON_DB, ETIS_DB, CLINIC_DB
from dataloaders.PolypDataset import get_polyp_dataset, PolypDataset
from dataloaders.PolypTransforms import get_polyp_transform
from dataloaders.SimpleDataset import SimpleDataset
from dataloaders.ManualAnnoDatasetv2 import get_nii_dataset
from dataloaders.common import ValidationDataset
from config_ssl_upload import ex

import tqdm
from tqdm.auto import tqdm
import cv2
from collections import defaultdict

# Cache directory used by torch for downloaded pre-trained weights.
os.environ["TORCH_HOME"] = "./pretrained_model"

# Dataset identifiers accepted in _config["dataset"] (compared after .lower()).
CHAOS, SABS, POLYPS = "chaos", "sabs", "polyps"

# Datasets loaded through get_nii_dataset; matched as substrings of the
# lower-cased dataset name.
ALP_DS = [CHAOS, SABS]

# Rotation in degrees (not referenced by the code in this file).
ROT_DEG = 0

def get_bounding_box(segmentation_map):
    """
    Return the single bounding box enclosing the non-zero pixels of a mask.

    segmentation_map: torch.Tensor (copied to CPU first) or numpy array. It is
        cast to uint8 before ``cv2.boundingRect``, so fractional values in
        (0, 1) truncate to 0 and are ignored.

    Returns ``(x, y, w, h)`` as produced by ``cv2.boundingRect``.
    """
    mask = segmentation_map
    if isinstance(mask, torch.Tensor):
        mask = mask.cpu().numpy()
    return cv2.boundingRect(mask.astype(np.uint8))

def calc_iou(boxA, boxB):
    """
    Compute the intersection-over-union of two axis-aligned boxes.

    boxA, boxB: boxes as [x, y, w, h] (top-left corner, width, height).

    Returns the IoU as a float. If the union area is zero (both boxes are
    degenerate/zero-area, e.g. boxes of empty masks), 0.0 is returned
    instead of raising ZeroDivisionError.
    """
    x_left = max(boxA[0], boxB[0])
    y_top = max(boxA[1], boxB[1])
    x_right = min(boxA[0] + boxA[2], boxB[0] + boxB[2])
    y_bottom = min(boxA[1] + boxA[3], boxB[1] + boxB[3])

    # Clamp at 0 so non-overlapping boxes contribute no intersection.
    inter_area = max(0, x_right - x_left) * max(0, y_bottom - y_top)
    area_a = boxA[2] * boxA[3]
    area_b = boxB[2] * boxB[3]

    union = area_a + area_b - inter_area
    if union <= 0:
        return 0.0
    return inter_area / float(union)


def eval_detection(pred_list):
    """
    Compute detection precision / recall / F1 over a range of IoU thresholds.

    pred_list: list of dicts with keys 'pred_bbox' and 'gt_bbox' (boxes as
        [x, y, w, h], see ``calc_iou``) and 'score' (prediction confidence).
        'score' is currently not used: each prediction is compared only
        against the ground-truth box stored in the same dict.

    For each IoU threshold in 0.50, 0.55, ..., 0.95 a prediction counts as a
    true positive when its IoU with its paired ground-truth box reaches the
    threshold, and as a false positive otherwise. Every entry carries exactly
    one ground-truth box, so precision and recall are equal.

    Returns a pandas DataFrame with one row per threshold and the columns
    'iou_threshold', 'tp', 'fp', 'n_gt', 'f1', 'precision', 'recall'.
    Ratios with a zero denominator (empty ``pred_list``, or no true positives
    for F1) are reported as 0.0 instead of raising ZeroDivisionError.
    """
    iou_thresholds = np.round(np.arange(0.5, 1.0, 0.05), 2)
    n_gt = len(pred_list)
    # IoU does not depend on the threshold, so compute it once per entry.
    ious = [calc_iou(pred['pred_bbox'], pred['gt_bbox']) for pred in pred_list]

    results = []
    for iou_threshold in iou_thresholds:
        tp = sum(1 for iou in ious if iou >= iou_threshold)
        fp = n_gt - tp

        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / n_gt if n_gt > 0 else 0.0
        pr_sum = precision + recall
        f1 = 2 * (precision * recall) / pr_sum if pr_sum > 0 else 0.0

        results.append({
            'iou_threshold': iou_threshold,
            'tp': tp,
            'fp': fp,
            'n_gt': n_gt,
            'f1': f1,
            'precision': precision,
            'recall': recall,
        })

    return pd.DataFrame(results)


def _normalize_for_display(image):
    """Min-max scale ``image`` into [0, 1]; the 1e-8 avoids 0/0 on constant images."""
    return (image - image.min()) / (image.max() - image.min() + 1e-8)


def _mask_to_rgba(mask, cmap, alpha=0.7):
    """
    Convert a mask tensor into an RGBA overlay for ``plt.imshow``.

    Pixels with mask > 0 get the colormap's colour for 1.0 at opacity
    ``alpha``; every other pixel is fully transparent.
    """
    mask_np = (mask.detach().cpu().float().numpy() > 0).astype(np.float32)
    rgba = cmap(mask_np)
    rgba[..., 3] = mask_np * alpha  # alpha channel: visible only where mask == 1
    return rgba


def _save_borderless(image, is_gray, path, overlay=None):
    """Save ``image`` (optionally with an RGBA ``overlay`` on top) to ``path`` without axes or padding."""
    plt.figure(figsize=(10, 10))
    if is_gray:
        plt.imshow(image, cmap='gray')
    else:
        plt.imshow(image)
    if overlay is not None:
        plt.imshow(overlay)
    plt.axis('off')
    plt.subplots_adjust(left=0, right=1, top=1, bottom=0, wspace=0, hspace=0)
    plt.savefig(path, bbox_inches='tight', pad_inches=0)
    plt.close()


def plot_pred_gt_support(query_image, pred, gt, support_images, support_masks, score=None, save_path="debug/pred_vs_gt"):
    """
    Save 5 key images: support image, support mask, query, ground truth and prediction.

    Grayscale and RGB images are handled consistently, with the same mask
    colour. The files are written inside the directory ``save_path``, which
    is created if missing: query.png, pred.png, gt.png and, when support
    images are given, support_1.png and support_mask.png (first support
    sample only).

    Args:
        query_image: Query image tensor (grayscale or RGB; a 3-D tensor whose
            first dim is <= 3 is treated as CHW and permuted to HWC)
        pred: 2d tensor where 1 represents foreground and 0 represents background
        gt: 2d tensor where 1 represents foreground and 0 represents background
        support_images: Support image tensor, or list of tensors concatenated
            along dim 0; None skips the support images
        support_masks: Support mask tensor(s), same layout as support_images
        score: Accepted for call-site compatibility; currently not used
        save_path: Directory the images are written into
    """
    # Files are written *inside* save_path, so that directory must exist.
    os.makedirs(save_path, exist_ok=True)

    query_image = query_image.clone().detach().cpu()
    if query_image.dim() == 3 and query_image.shape[0] <= 3:  # CHW -> HWC
        query_image = query_image.permute(1, 2, 0)
    is_grayscale = query_image.dim() == 2 or (query_image.dim() == 3 and query_image.shape[2] == 1)
    if is_grayscale and query_image.dim() == 3:
        query_image = query_image.squeeze(2)  # drop the singleton channel
    query_image = _normalize_for_display(query_image)

    # plt.get_cmap works on all matplotlib versions; plt.cm.get_cmap was removed in 3.9.
    mask_cmap = plt.get_cmap('YlOrRd')
    pred_rgba = _mask_to_rgba(pred, mask_cmap)
    gt_rgba = _mask_to_rgba(gt, mask_cmap)

    _save_borderless(query_image, is_grayscale, f"{save_path}/query.png")
    _save_borderless(query_image, is_grayscale, f"{save_path}/pred.png", overlay=pred_rgba)
    _save_borderless(query_image, is_grayscale, f"{save_path}/gt.png", overlay=gt_rgba)

    if support_images is None:
        return

    if isinstance(support_images, list):
        support_images = torch.cat(support_images, dim=0).clone().detach()
    if isinstance(support_masks, list):
        support_masks = torch.cat(support_masks, dim=0).clone().detach()
    support_images = support_images.cpu()
    support_masks = support_masks.cpu()
    if support_images.dim() == 4:  # NCHW -> NHWC
        support_images = support_images.permute(0, 2, 3, 1)
    if support_images.shape[0] == 0:
        return

    # Only the first support sample is visualised.
    support_img = support_images[0].clone()
    support_mask = support_masks[0].clone()
    # A trailing dimension of 3 means RGB; anything else is shown as grayscale.
    support_is_gray = support_img.shape[-1] != 3
    if support_img.shape[-1] == 1:
        support_img = support_img.squeeze(-1)
    support_img = _normalize_for_display(support_img)

    _save_borderless(support_img, support_is_gray, f"{save_path}/support_1.png")
    # The overlay uses the support image's own grayscale flag, not the query's.
    _save_borderless(support_img, support_is_gray, f"{save_path}/support_mask.png",
                     overlay=_mask_to_rgba(support_mask, mask_cmap))




def get_dice_iou_precision_recall(pred: torch.Tensor, gt: torch.Tensor):
    """
    Compute Dice, IoU, precision and recall of a binary prediction.

    pred: 2d tensor of shape (H, W) where 1 represents foreground and 0
        represents background. If its shape differs from ``gt`` it is
        resized to ``gt.shape`` with nearest-neighbour interpolation.
    gt: 2d tensor of shape (H, W) where 1 represents foreground and 0
        represents background.

    Integer and boolean masks are accepted. They are converted to float
    first, because ``1 - mask`` is not defined for bool tensors.

    Returns a dict with the keys "dice", "iou", "precision" and "recall".
    The values are 0-dim tensors. When ``gt`` has no foreground the metrics
    are undefined and every key maps to 0; "iou" is included so callers can
    index all four keys unconditionally.
    """
    if gt.sum() == 0:
        print("gt is all background")
        return {"dice": 0, "iou": 0, "precision": 0, "recall": 0}

    if not pred.is_floating_point():
        pred = pred.float()
    if not gt.is_floating_point():
        gt = gt.float()

    # Resize pred to match gt dimensions if they're different
    if pred.shape != gt.shape:
        print(f"Resizing prediction from {pred.shape} to match ground truth {gt.shape}")
        pred = torch.nn.functional.interpolate(
            pred.unsqueeze(0).unsqueeze(0).float(),
            size=gt.shape,
            mode='nearest'
        ).squeeze(0).squeeze(0)

    tp = (pred * gt).sum()
    fp = (pred * (1 - gt)).sum()
    fn = ((1 - pred) * gt).sum()
    # The 1e-8 terms keep the ratios finite when a denominator is zero.
    dice = 2 * tp / (2 * tp + fp + fn + 1e-8)
    precision = tp / (tp + fp + 1e-8)
    recall = tp / (tp + fn + 1e-8)
    iou = tp / (tp + fp + fn + 1e-8)
    return {"dice": dice, "iou": iou, "precision": precision, "recall": recall}


def get_alpnet_model(_config) -> ModelWrapper:
    """
    Build the FewShotSeg (ALPNet) coarse segmenter on the GPU and wrap it.

    Reads ``input_size[0]``, ``reload_model_path`` and ``model`` from
    ``_config`` and returns the network wrapped in ``ALPNetWrapper``.
    """
    input_size = _config["input_size"][0]
    checkpoint_path = _config["reload_model_path"]
    model_cfg = _config["model"]

    network = FewShotSeg(input_size, checkpoint_path, model_cfg)
    network.cuda()
    return ALPNetWrapper(network)

def get_sam_model(_config) -> ModelWrapper:
    """
    Load SAM ViT-H from ``pretrained_model/sam_vit_h.pth`` onto the GPU.

    ``_config`` is not read. Returns the model wrapped in ``SamWrapperWrapper``.
    """
    checkpoint_args = dict(
        model_type="vit_h",
        sam_checkpoint="pretrained_model/sam_vit_h.pth",
    )
    return SamWrapperWrapper(SamWrapper(sam_args=checkpoint_args).cuda())

def get_model(_config) -> ProtoSAM:
    """
    Build the two-stage segmentation model described by ``_config``.

    The coarse stage must be ALPNet (``_config["base_model"] == TYPE_ALPNET``).
    The refinement stage is selected by ``_config["protosam_sam_ver"]``:
    "sam_h" or "sam_b" builds ``ProtoSAM`` with the matching checkpoint under
    ``pretrained_model/``, and "medsam" builds ``ProtoMedSAM``.

    Raises NotImplementedError for any other base model or SAM version.
    """
    if _config["base_model"] != TYPE_ALPNET:
        raise NotImplementedError(f"base model {_config['base_model']} not implemented")
    base_model = get_alpnet_model(_config)

    sam_version = _config["protosam_sam_ver"]
    if sam_version in ("sam_h", "sam_b"):
        checkpoint = {
            "sam_h": "pretrained_model/sam_vit_h.pth",
            "sam_b": "pretrained_model/sam_vit_b.pth",
        }[sam_version]
        return ProtoSAM(
            image_size=(1024, 1024),
            coarse_segmentation_model=base_model,
            use_bbox=_config["use_bbox"],
            use_points=_config["use_points"],
            use_mask=_config["use_mask"],
            debug=_config["debug"],
            num_points_for_sam=1,
            use_cca=_config["do_cca"],
            point_mode=_config["point_mode"],
            use_sam_trans=True,
            coarse_pred_only=_config["coarse_pred_only"],
            sam_pretrained_path=checkpoint,
            use_neg_points=_config["use_neg_points"],
        )

    if sam_version == "medsam":
        return ProtoMedSAM(
            image_size=(1024, 1024),
            coarse_segmentation_model=base_model,
            debug=_config["debug"],
            use_cca=_config["do_cca"],
        )

    raise NotImplementedError(f"protosam_sam_ver {_config['protosam_sam_ver']} not implemented")


def get_support_set_polyps(_config, dataset:PolypDataset):
    """
    Draw ``_config["n_support"]`` support samples from a polyp dataset.

    Returns the (support_images, support_labels, case) triple produced by
    ``dataset.get_support``.
    """
    images, labels, support_case = dataset.get_support(n_support=_config["n_support"])
    return images, labels, support_case


def get_support_set_alpds(config, dataset:ValidationDataset):
    """
    Fetch the support set of a CHAOS/SABS-style validation dataset.

    Returns ``(support_images, support_labels, support_scan_id)``, taken from
    the dict returned by ``dataset.get_support_set(config)``.
    """
    support = dataset.get_support_set(config)
    masks = support["support_labels"]
    images = support["support_images"]
    scan_ids = support["support_scan_id"]
    return images, masks, scan_ids


def get_support_set(_config, dataset):
    """
    Fetch the support set for the dataset named in ``_config["dataset"]``.

    Returns ``(support_images, support_fg_masks, support_scan_id)``. For the
    polyp dataset the third element is the ``case`` returned by
    ``get_support_set_polyps``. For CHAOS/SABS-style datasets it is the
    support scan id(s).

    Raises NotImplementedError for any other dataset.
    """
    dataset_name = _config["dataset"].lower()
    if dataset_name == POLYPS:
        # The polyp loader has no scan ids, so its `case` fills that slot.
        # Returning directly keeps the third element bound in this branch.
        return get_support_set_polyps(_config, dataset)
    if any(item in dataset_name for item in ALP_DS):
        return get_support_set_alpds(_config, dataset)
    raise NotImplementedError(f"dataset {_config['dataset']} not implemented")


def update_support_set_by_scan_part(support_images, support_labels, qpart):
    """
    Select the support image and label for scan part ``qpart``.

    Each selection is wrapped in a single-element list:
    ``([support_images[qpart]], [support_labels[qpart]])``.
    """
    return [support_images[qpart]], [support_labels[qpart]]


def manage_support_sets(sample_batched, all_support_images, all_support_fg_mask, support_images, support_fg_mask, qpart=None):
    """
    Switch the active support set when the query's scan part changes.

    If ``sample_batched['part_assign'][0]`` differs from ``qpart``, the
    support image/mask for the new part is selected from the ``all_*``
    collections. Otherwise the current support set is kept.

    Returns ``(support_images, support_fg_mask, qpart)``.
    """
    part = sample_batched['part_assign'][0]
    if part != qpart:
        qpart = part
        support_images, support_fg_mask = update_support_set_by_scan_part(
            all_support_images, all_support_fg_mask, qpart)
    return support_images, support_fg_mask, qpart


@ex.automain
def main(_run, _config, _log):
    """
    Run few-shot segmentation validation and log the resulting metrics.

    Builds the model from ``_config`` and loads the polyp or CHAOS/SABS test
    set. Each query slice is predicted against the support set. Per-case and
    overall Dice / IoU / precision / recall are logged through ``_run`` and
    ``_log``. When ``_config["debug"]`` is set, visualisations of every
    prediction go under ``debug/preds/``, and predictions with Dice < 0.6
    also go under ``debug/bad_preds/``.

    Returns 1 on completion.
    """
    if _run.observers:
        os.makedirs(f'{_run.observers[0].dir}/interm_preds', exist_ok=True)
        for source_file, _ in _run.experiment_info['sources']:
            os.makedirs(os.path.dirname(f'{_run.observers[0].dir}/source/{source_file}'),
                        exist_ok=True)
            _run.observers[0].save_file(source_file, f'source/{source_file}')
        print(f"####### created dir:{_run.observers[0].dir} #######")
        shutil.rmtree(f'{_run.observers[0].basedir}/_sources')
    print(f"config do_cca: {_config['do_cca']}, use_bbox: {_config['use_bbox']}")
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.cuda.set_device(device=_config['gpu_id'])
    torch.set_num_threads(1)

    _log.info(f'###### Reload model {_config["reload_model_path"]} ######')
    print(f'###### Reload model {_config["reload_model_path"]} ######')
    model = get_model(_config)
    model = model.to(torch.device("cuda"))
    model.eval()

    dataset_name = _config["dataset"].lower()
    is_alp_ds = any(item in dataset_name for item in ALP_DS)
    is_polyp_ds = dataset_name == POLYPS

    sam_trans = ResizeLongestSide(1024)
    if is_polyp_ds:
        tr_dataset, te_dataset = get_polyp_dataset(sam_trans=sam_trans, image_size=(1024, 1024))
    elif is_alp_ds:
        tr_dataset, te_dataset = get_nii_dataset(_config, _config["input_size"][0])
    else:
        raise NotImplementedError(
            f"dataset {_config['dataset']} not implemented")

    testloader = DataLoader(
        te_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=1,
        pin_memory=False,
        drop_last=False
    )

    _log.info('###### Starting validation ######')

    mean_dice = []
    mean_prec = []
    mean_rec = []
    mean_iou = []

    mean_dice_cases = {}
    mean_iou_cases = {}
    bboxes_w_scores = []

    qpart = None
    support_images = support_fg_mask = None
    all_support_images, all_support_fg_mask, support_scan_id = None, None, None

    if is_alp_ds:
        all_support_images, all_support_fg_mask, support_scan_id = get_support_set(_config, te_dataset)
    elif is_polyp_ds:
        support_images, support_fg_mask, _ = get_support_set_polyps(_config, tr_dataset)

    # A single attempt per query; n_try is kept because it appears in debug paths.
    n_try = 1
    with tqdm(testloader) as pbar:
        # Iterate pbar itself so the bar that receives the postfix is the one that advances.
        for idx, sample_batched in enumerate(pbar):
            case = sample_batched['case'][0]
            if is_alp_ds:
                support_images, support_fg_mask, qpart = manage_support_sets(
                    sample_batched,
                    all_support_images,
                    all_support_fg_mask,
                    support_images,
                    support_fg_mask,
                    qpart,
                )
                # Scans used as support are not evaluated as queries.
                if sample_batched["scan_id"][0] in support_scan_id:
                    continue

            query_images = sample_batched['image'].cuda()
            query_labels = sample_batched['label']
            if 1 not in query_labels and _config["skip_no_organ_slices"]:
                continue

            with torch.no_grad():
                coarse_model_input = InputFactory.create_input(
                    input_type=_config["base_model"],
                    query_image=query_images,
                    support_images=support_images,
                    support_labels=support_fg_mask,
                    isval=True,
                    val_wsize=_config["val_wsize"],
                    original_sz=query_images.shape[-2:],
                    img_sz=query_images.shape[-2:],
                    gts=query_labels,
                )
                coarse_model_input.to(torch.device("cuda"))

                query_pred, scores = model(
                    query_images, coarse_model_input, degrees_rotate=0)
            query_pred = query_pred.cpu().detach()
            gt = query_labels[0].cpu()

            if _config["debug"]:
                if is_alp_ds:
                    save_path = f'debug/preds/{case}_{sample_batched["z_id"].item()}_{idx}_{n_try}'
                else:
                    save_path = f'debug/preds/{case}_{idx}_{n_try}'
                os.makedirs(save_path, exist_ok=True)
                plot_pred_gt_support(query_images[0, 0].cpu(), query_pred, gt,
                                     support_images, support_fg_mask, save_path=save_path, score=scores[0])

            metrics = get_dice_iou_precision_recall(query_pred, gt)
            # Store plain floats rather than 0-dim tensors.
            dice = float(metrics["dice"])
            iou = float(metrics["iou"])
            mean_dice.append(dice)
            mean_prec.append(float(metrics["precision"]))
            mean_rec.append(float(metrics["recall"]))
            mean_iou.append(iou)

            bboxes_w_scores.append({"pred_bbox": get_bounding_box(query_pred),
                                    "gt_bbox": get_bounding_box(gt),
                                    "score": np.mean(scores)})

            mean_dice_cases.setdefault(case, []).append(dice)
            mean_iou_cases.setdefault(case, []).append(iou)

            if dice < 0.6 and _config["debug"]:
                # Bad predictions are only saved in debug mode, and always under debug/.
                path = f'debug/bad_preds/case_{case}_idx_{idx}_dice_{dice:.4f}'
                os.makedirs(path, exist_ok=True)
                print(f"saving bad prediction to {path}")
                plot_pred_gt_support(query_images[0, 0].cpu(), query_pred, gt,
                                     support_images, support_fg_mask, save_path=path, score=scores[0])

            pbar.set_postfix({"mdice": f"{np.mean(mean_dice):.4f}",
                              "miou": f"{np.mean(mean_iou):.4f}",
                              "n_try": n_try})

    for k in mean_dice_cases.keys():
        _run.log_scalar(f'mar_val_batches_meanDice_{k}', np.mean(mean_dice_cases[k]))
        _run.log_scalar(f'mar_val_batches_meanIOU_{k}', np.mean(mean_iou_cases[k]))
        _log.info(f'mar_val batches meanDice_{k}: {np.mean(mean_dice_cases[k])}')
        _log.info(f'mar_val batches meanIOU_{k}: {np.mean(mean_iou_cases[k])}')

    # write validation result to log file
    m_meanDice = np.mean(mean_dice)
    m_meanPrec = np.mean(mean_prec)
    m_meanRec = np.mean(mean_rec)
    m_meanIOU = np.mean(mean_iou)

    _run.log_scalar('mar_val_batches_meanDice', m_meanDice)
    _run.log_scalar('mar_val_batches_meanPrec', m_meanPrec)
    _run.log_scalar('mar_val_al_batches_meanRec', m_meanRec)
    _run.log_scalar('mar_val_al_batches_meanIOU', m_meanIOU)
    _log.info(f'mar_val batches meanDice: {m_meanDice}')
    _log.info(f'mar_val batches meanPrec: {m_meanPrec}')
    _log.info(f'mar_val batches meanRec: {m_meanRec}')
    _log.info(f'mar_val batches meanIOU: {m_meanIOU}')
    print("============ ============")
    _log.info(f'End of validation')
    return 1