| import os |
| from PIL import Image |
| import sys |
| from matplotlib import pyplot as plt |
| import torch |
|
|
| sys.path.append("/home/ubuntu/Desktop/Domain_Adaptation_Project/repos/SVDSAM/") |
| from utils import * |
|
|
| |
# Root directory holding one sub-folder per evaluated dataset.
test_path = "endovis17_lora16"

# Instrument groups, each scored as a single binary class. Every tuple lists the
# per-side folder names that appear inside a dataset directory.
instruments = [
    ("Left Grasping Retractor", "Right Grasping Retractor"),
    ("Left Large Needle Driver", "Right Large Needle Driver"),
    ("Left Prograsp Forceps", "Right Prograsp Forceps"),
]
|
|
| for dataset in sorted(os.listdir(test_path)): |
| for instrument in instruments: |
| dices = [] |
| ious = [] |
| if len(instrument)==3: |
| gt_path1 = os.path.join(test_path, dataset,instrument[0],'rescaled_gt') |
| gt_path2 = os.path.join(test_path, dataset,instrument[2],'rescaled_gt') |
| extra_preds_path = os.path.join(test_path, dataset,instrument[2],'rescaled_preds') |
| else: |
| gt_path = os.path.join(test_path, dataset,instrument[0],'rescaled_gt') |
| left_preds_path = os.path.join(test_path, dataset,instrument[0],'rescaled_preds') |
| right_preds_path = os.path.join(test_path, dataset,instrument[1],'rescaled_preds') |
| for frame in sorted(os.listdir(left_preds_path)): |
| if len(instrument)==3: |
| gold1 = ((plt.imread(os.path.join(gt_path1,frame))[:,:,0][58:-52,143:-126])>=0.5)+0 |
| gold2 = ((plt.imread(os.path.join(gt_path2,frame))[:,:,0][58:-52,143:-126])>=0.5)+0 |
| extra_pred = ((plt.imread(os.path.join(extra_preds_path, frame))[:,:,0][58:-52,143:-126])>=0.5) |
| gold = (gold1 | gold2)+0 |
| else: |
| gold = ((plt.imread(os.path.join(gt_path,frame))[:,:,0][58:-52,143:-126])>=0.5)+0 |
| left_pred = ((plt.imread(os.path.join(left_preds_path, frame))[:,:,0][58:-52,143:-126])>=0.5) |
| right_pred = ((plt.imread(os.path.join(right_preds_path, frame))[:,:,0][58:-52,143:-126])>=0.5) |
| |
| pred = (left_pred | right_pred) |
| if len(instrument)==3: |
| pred = (pred | extra_pred) |
| pred = pred + 0 |
| gold = torch.Tensor(gold).unsqueeze(0) |
| pred = torch.Tensor(pred).unsqueeze(0) |
| dices.append(dice_coef(gold, pred)) |
| ious.append(iou_coef(gold, pred)) |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| print(f"Dataset: {dataset}, instrument: {instrument}, dice: {torch.mean(torch.Tensor(dices))}, iou: {torch.mean(torch.Tensor(ious))}") |
| print('\n') |