Spaces:
Runtime error
Runtime error
| import argparse | |
| import json | |
| import math | |
| import os | |
| import numpy as np | |
| import skimage | |
| import torch | |
| from torch.nn import DataParallel | |
| from torch.utils.data import DataLoader | |
| from torchvision import ops | |
| from torchvision.transforms import Resize | |
| from tqdm import tqdm | |
| from models.counter_infer import build_model | |
| from models.matcher import build_matcher | |
| from utils.arg_parser import get_argparser | |
| from utils.box_ops import BoxList | |
| from utils.data import FSC147DATASET, pad_collate_test | |
| from utils.losses import SetCriterion | |
def _load_model(args, device, gpu):
    """Build the counting model, wrap it in DataParallel and load its weights.

    Args:
        args: Parsed options; ``model_path`` and ``model_name`` locate the
            ``<model_name>.pth`` checkpoint, and the whole namespace is
            forwarded to ``build_model``.
        device: Device the model is moved to.
        gpu: Device index used for ``device_ids`` / ``output_device``.

    Returns:
        The DataParallel-wrapped model with the checkpoint's ``'model'``
        entry loaded non-strictly.
    """
    model = DataParallel(
        build_model(args).to(device),
        device_ids=[gpu],
        output_device=gpu,
    )
    checkpoint_path = os.path.join(args.model_path, f'{args.model_name}.pth')
    # Load onto the CPU first: without map_location, torch.load restores
    # tensors to the device they were saved from, which fails when that
    # device index does not exist here. load_state_dict copies the values
    # into the parameters already living on `device`.
    state_dict = torch.load(checkpoint_path, map_location='cpu')['model']
    # DataParallel exposes the wrapped network as `.module`, so its keys carry
    # a leading "module." prefix. Test the prefix, not a substring: a key such
    # as "fusion_module.weight" must still be prefixed, otherwise strict=False
    # would silently skip it.
    state_dict = {
        (key if key.startswith('module.') else 'module.' + key): value
        for key, value in state_dict.items()
    }
    model.load_state_dict(state_dict, strict=False)
    return model


def _filter_detections(output, img_w, img_h, pad_wh, device,
                       thr=1 / 0.11, nms_iou=0.5):
    """Threshold, NMS-filter and pad-filter the predictions of one image.

    Args:
        output: Per-image prediction mapping with ``'pred_boxes'``,
            ``'box_v'`` and ``'scores'`` entries.
        img_w: Width of the (padded) input batch, ``img.shape[-1]``.
        img_h: Height of the (padded) input batch, ``img.shape[-2]``.
        pad_wh: Indexable whose items 0 and 1 are subtracted from the width
            and height respectively (presumably the padding sizes; confirm
            against ``pad_collate_test``).
        device: Device the padding limits are moved to before comparison.
        thr: Divisor applied to the maximum ``box_v`` value; predictions with
            ``box_v`` above ``max / thr`` are kept.
        nms_iou: IoU threshold passed to ``torchvision.ops.nms``.

    Returns:
        ``(boxes, scores)`` for the surviving predictions. Boxes are clamped
        to ``[0, 1]``. Both are empty CPU tensors when the model produced no
        boxes.
    """
    pred_boxes = output["pred_boxes"]
    if len(pred_boxes[-1]) == 0:
        return torch.zeros((0, 4)), torch.zeros((0))

    # Keep predictions whose box_v value is within a fixed ratio of the best one.
    v = output["box_v"]
    mask = v > v.max() / thr
    candidates = pred_boxes[mask]
    keep = ops.nms(candidates, v[mask], nms_iou)
    boxes = torch.clamp(candidates[keep], 0, 1)
    scores = output["scores"][mask][keep]

    # Remove boxes whose centre falls in the padded area.
    max_w = (img_w - pad_wh[0]).to(device)
    max_h = (img_h - pad_wh[1]).to(device)
    center = (boxes[:, :2] + boxes[:, 2:]) / 2
    # NOTE(review): x is scaled by the height and y by the width, exactly as
    # in the original code. This is equivalent only for square inputs; confirm
    # the intended normalisation before changing it.
    valid = (center[:, 0] * img_h < max_w) & (center[:, 1] * img_w < max_h)
    return boxes[valid], scores[valid]


@torch.no_grad()
def evaluate(args):
    """Evaluate the model on the FSC147 'val' and 'test' splits.

    For each split this runs the model over the whole dataset, keeps the
    detections that survive ``_filter_detections``, prints the MAE and RMSE
    between the number of kept boxes and the ground-truth count (the sum of
    each density map), and writes the detections to ``geco2_<split>.json`` in
    the current working directory as a dict with ``'categories'``,
    ``'images'`` and ``'annotations'`` lists.

    Runs on CUDA device 0 with autograd disabled.

    Args:
        args: Parsed options. Reads ``model_path``, ``model_name``,
            ``data_path``, ``image_size``, ``num_objects``, ``tiling_p``,
            ``batch_size`` and ``num_workers``; the namespace is also passed
            to ``build_model``.
    """
    gpu = 0
    torch.cuda.set_device(gpu)
    device = torch.device(gpu)

    model = _load_model(args, device, gpu)
    model.eval()

    for split in ['val', 'test']:
        test = FSC147DATASET(
            args.data_path,
            args.image_size,
            split=split,
            num_objects=args.num_objects,
            tiling_p=args.tiling_p,
            return_ids=True,
            training=False,
        )
        test_loader = DataLoader(
            test,
            batch_size=args.batch_size,
            drop_last=False,
            num_workers=args.num_workers,
            collate_fn=pad_collate_test,
        )
        ae = torch.tensor(0.0).to(device)
        se = torch.tensor(0.0).to(device)

        # Looked up once per split instead of once per image and annotation;
        # assumes the method returns the same mapping on every call.
        name_to_ori_id = test.map_img_name_to_ori_id()

        predictions = {
            "categories": [{"name": "fg", "id": 1}],
            "images": [],
            "annotations": [],
        }
        anno_id = 1

        for img, bboxes, density_map, ids, _gt_bboxes, scaling_factor, padwh in test_loader:
            img = img.to(device)
            bboxes = bboxes.to(device)
            outputs, _, _, _, _ = model(img, bboxes)
            img_h, img_w = img.shape[-2], img.shape[-1]

            num_objects_pred = []
            for idx in range(img.shape[0]):
                boxes, scores = _filter_detections(
                    outputs[idx], img_w, img_h, padwh[idx], device
                )
                num_objects_pred.append(len(boxes))

                image_id = name_to_ori_id[test.image_names[ids[idx].item()]]

                # Convert to xywh and rescale by the batch width and the
                # per-image scaling factor, then move to the CPU once.
                xywh = ops.box_convert(boxes, 'xyxy', 'xywh')
                xywh = (xywh * img_w / scaling_factor[idx]).cpu()
                areas = xywh[:, 2] * xywh[:, 3]

                for box, area, score in zip(xywh.tolist(), areas.tolist(), scores.cpu()):
                    predictions["annotations"].append({
                        "id": anno_id,
                        "image_id": image_id,
                        "area": int(area),
                        "bbox": [int(coord) for coord in box],
                        "category_id": 1,
                        "score": float(score.item()),
                    })
                    anno_id += 1
                predictions["images"].append({
                    "id": image_id,
                    "file_name": "None",
                })

            # Ground-truth count is the integral of the density map.
            num_objects_gt = density_map.flatten(1).sum(dim=1)
            count_error = num_objects_gt - torch.tensor(num_objects_pred)
            ae += torch.abs(count_error).sum()
            se += torch.pow(count_error, 2).sum()

        print(
            f"{split.capitalize()} set",
            f"MAE: {ae.item() / len(test):.2f}",
            f"RMSE: {torch.sqrt(se / len(test)).item():.2f}",
        )
        with open("geco2_" + split + ".json", "w") as handle:
            json.dump(predictions, handle)
if __name__ == '__main__':
    # Script entry point: extend the project's shared argument parser, echo
    # the parsed options and the selected model name, then run evaluation.
    cli = argparse.ArgumentParser('GECO2', parents=[get_argparser()])
    cli_args = cli.parse_args()
    print(cli_args)
    print("model_name: ", cli_args.model_name)
    evaluate(cli_args)