| |
| import os |
| import sys |
| sys.path.append('/home/yuqian_fu/Projects/ego-exo4d-relation/correspondence/SegSwap/model') |
| import numpy as np |
| from PIL import Image |
| import cv2 |
| import json |
| import random |
| import torch |
| import torch.nn.functional as F |
| import torchvision.models as models |
| import torchvision.transforms as transforms |
|
|
| import transformer |
|
|
| import tqdm |
|
|
| from pycocotools import mask as mask_utils |
| from torch.cuda.amp import autocast |
|
|
|
|
# Threshold applied to the predicted soft mask before binarization/RLE encoding.
MASKThresh = 0.5
# Dataset root and the test-split manifest (HANDAL benchmark).
root_path = "/data/work-gcp-europe-west4-a/yuqian_fu/datasets/HANDAL"
json_path = "/data/work-gcp-europe-west4-a/yuqian_fu/datasets/HANDAL/handal_test_all.json"
# Loaded once at import time. Based on how it is consumed below, each entry
# carries at least: "image", "anns", "new_img_id", "first_frame_image",
# "first_frame_anns" (COCO-RLE segmentations).
with open(json_path, "r") as fp:
    datas = json.load(fp)
|
|
|
|
def reshape_img_war(img, size=(480, 480)):
    """Zero-pad *img* to a centered square, then resize to *size*.

    Accepts both 2-D masks (H, W) and 3-D images (H, W, C). Nearest-neighbour
    interpolation is used so mask/label values are not blended.

    Args:
        img:  uint8 numpy array, shape (H, W) or (H, W, C).
        size: output (width, height) tuple forwarded to ``cv2.resize``.

    Returns:
        uint8 numpy array resized to *size*. Note that ``cv2.resize`` drops a
        trailing singleton channel, so 2-D input comes back 2-D.
    """
    C = 1
    if len(img.shape) == 2:
        H, W = img.shape
        img = img[..., None]
    else:
        H, W, C = img.shape

    side = max(H, W)
    temp = np.zeros((side, side, C), dtype=np.uint8)

    if H > W:
        # Center horizontally. Slicing with L:L+W (instead of L:-L) is correct
        # when H-W is odd: -L would make the slice one column too wide (or
        # empty when H-W == 1) and the assignment would raise.
        L = (H - W) // 2
        temp[:, L:L + W] = img
    elif W > H:
        # Same fix for vertical centering when W-H is odd.
        L = (W - H) // 2
        temp[L:L + H] = img
    else:
        temp = img

    temp = cv2.resize(temp, size, interpolation=cv2.INTER_NEAREST)

    return temp
|
|
def get_model(model_path):
    """Build the SegSwap backbone + transformer encoder and load a checkpoint.

    Args:
        model_path: path to a checkpoint holding 'backbone' and 'encoder'
            state dicts.

    Returns:
        (backbone, netEncoder): both in eval mode on the CUDA device.
    """
    device = torch.device('cuda')

    # ResNet-50 truncated after layer3 -> 1024-channel feature maps.
    backbone = models.resnet50(pretrained=False)
    resnet_feature_layers = ['conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3']
    resnet_module_list = [getattr(backbone, l) for l in resnet_feature_layers]
    last_layer_idx = resnet_feature_layers.index('layer3')
    backbone = torch.nn.Sequential(*resnet_module_list[:last_layer_idx + 1])

    # Encoder hyper-parameters (presumably the released SegSwap "small"
    # configuration — confirm against the training config).
    pos_weight = 0.1
    feat_weight = 1
    dropout = 0.1
    activation = 'relu'
    mode = 'small'
    layer_type = ['I', 'C', 'I', 'C', 'I', 'N']
    drop_feat = 0.1
    feat_dim = 1024

    netEncoder = transformer.TransEncoder(feat_dim,
                                          pos_weight=pos_weight,
                                          feat_weight=feat_weight,
                                          dropout=dropout,
                                          activation=activation,
                                          mode=mode,
                                          layer_type=layer_type,
                                          drop_feat=drop_feat)

    print('Loading net weight from {}'.format(model_path))
    # map_location='cpu' avoids a spurious allocation on the GPU the
    # checkpoint was saved from; modules are moved to `device` below.
    param = torch.load(model_path, map_location='cpu')
    backbone.load_state_dict(param['backbone'])
    netEncoder.load_state_dict(param['encoder'])
    backbone.eval()
    netEncoder.eval()
    # Move once, after loading weights (the original called .to(device) on
    # the encoder twice, once before loading — redundant).
    backbone.to(device)
    netEncoder.to(device)

    return backbone, netEncoder
|
|
def get_tensors(I1np, I2np, M1np):
    """Convert a query image, target image and query mask to CUDA tensors.

    Args:
        I1np: query image (PIL Image or HxWx3 uint8 array; torchvision's
            ToTensor accepts both).
        I2np: target image, same format as I1np.
        M1np: query mask as an HxW numpy array.

    Returns:
        (I1, I2, tensor1, tensor2, tensor3): the two image inputs unchanged,
        plus the ImageNet-normalized image tensors of shape (1, 3, H, W) and
        the float mask tensor of shape (1, H, W), all on the CUDA device.
    """
    # ImageNet normalization statistics.
    norm_mean, norm_std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
    transformINet = transforms.Compose([transforms.ToTensor(),
                                        transforms.Normalize(norm_mean, norm_std)])

    # (The original routed the inputs through no-op aliases such as
    # `I1np = I1np` and `tensor1 = I1`; that dead code is removed.)
    tensor1 = transformINet(I1np).unsqueeze(0).cuda()
    tensor2 = transformINet(I2np).unsqueeze(0).cuda()
    tensor3 = torch.from_numpy(M1np).unsqueeze(0).type(torch.FloatTensor).cuda()

    return I1np, I2np, tensor1, tensor2, tensor3
|
|
def forward_pass(backbone, netEncoder, tensor1, tensor2, tensor3):
    """Run one SegSwap inference step under no-grad + mixed precision.

    Args:
        backbone:   feature extractor applied to both images and the mask.
        netEncoder: cross-image transformer encoder.
        tensor1:    query image tensor (1, 3, H, W).
        tensor2:    target image tensor (1, 3, H, W).
        tensor3:    query mask tensor (1, H, W).

    Returns:
        (m1_final, m2_final, m3_final): channel 2 of the first two encoder
        outputs as numpy arrays, and the third output as a Python float.
    """
    with torch.no_grad(), autocast():
        query_feat = F.normalize(backbone(tensor1), dim=1)
        target_feat = F.normalize(backbone(tensor2), dim=1)

        # Replicate the single-channel mask across 3 channels so it can be
        # pushed through the same RGB backbone.
        mask_feat = F.normalize(backbone(tensor3.unsqueeze(0).repeat(1, 3, 1, 1)), dim=1)

        pred1, pred2, pred3 = netEncoder(query_feat, target_feat, mask_feat)

    return pred1[0, 2].cpu().numpy(), pred2[0, 2].cpu().numpy(), pred3.item()
|
|
|
|
| |
def egoexo(backbone, netEncoder, ego, exo, obj, pred_json):
    """Predict target masks for every sample of object *obj*.

    Results are RLE-encoded and written in place under
    ``pred_json['masks'][obj]['{ego}_{exo}'][sample_index]``.
    Relies on the module-level ``datas`` list and ``root_path``.
    """
    pair_key = f'{ego}_{exo}'
    pred_json['masks'][obj][pair_key] = {}

    # Keep only the samples belonging to this object category.
    obj_samples = [d for d in datas if d["image"].split('/')[0] == obj]

    for i, sample in enumerate(obj_samples):
        # Query: the annotated first frame (cv2 loads BGR -> flip to RGB),
        # squared and resized.
        query_np = cv2.imread(os.path.join(root_path, sample["first_frame_image"]))[..., ::-1]
        query_img = Image.fromarray(reshape_img_war(query_np))
        first_ann = sample["first_frame_anns"][0]
        query_mask = reshape_img_war(mask_utils.decode(first_ann["segmentation"]))

        # Target: the frame we want to segment.
        target_np = cv2.imread(os.path.join(root_path, sample["image"]))[..., ::-1]
        target_img = Image.fromarray(reshape_img_war(target_np))

        Ix, Iy, tensor1, tensor2, tensor3 = get_tensors(query_img, target_img, query_mask)
        mx, my, confidence = forward_pass(backbone, netEncoder, tensor1, tensor2, tensor3)

        binary_pred = (my > MASKThresh)
        rle = mask_utils.encode(np.asfortranarray(binary_pred.astype(np.uint8)))
        rle['counts'] = rle['counts'].decode('ascii')
        pred_json['masks'][obj][pair_key][str(i)] = {'pred_mask': rle, 'confidence': confidence}
|
|
def _load_rgb(rel_path):
    """Read ``root_path/rel_path`` with OpenCV and return it as RGB.

    cv2.imread silently returns None for a missing/corrupt file, which would
    otherwise surface later as an opaque ``TypeError`` on the channel flip —
    fail loudly with the offending path instead.
    """
    full_path = os.path.join(root_path, rel_path)
    img = cv2.imread(full_path)
    if img is None:
        raise FileNotFoundError(f"cannot read image: {full_path}")
    return img[..., ::-1]  # BGR -> RGB


def evaluate_ours(backbone, netEncoder):
    """Run SegSwap inference over every sample in the module-level ``datas``.

    For each sample the annotated first frame is the query and the current
    frame the target; the predicted target mask is thresholded at
    ``MASKThresh`` and RLE-encoded.

    Returns:
        list of dicts with keys 'image', 'anns', 'new_img_id', 'pred_mask'.
    """
    result_list = []
    for data in tqdm.tqdm(datas):
        query_frame = Image.fromarray(reshape_img_war(_load_rgb(data["first_frame_image"])))
        ann_query = data["first_frame_anns"][0]
        query_mask = reshape_img_war(mask_utils.decode(ann_query["segmentation"]))

        target_frame = Image.fromarray(reshape_img_war(_load_rgb(data["image"])))

        # Only the target-side soft mask is needed here; the query-side mask
        # and the confidence score are discarded.
        _, _, tensor1, tensor2, tensor3 = get_tensors(query_frame, target_frame, query_mask)
        _, my, _ = forward_pass(backbone, netEncoder, tensor1, tensor2, tensor3)

        y_step = (my > MASKThresh)
        target_pred = mask_utils.encode(np.asfortranarray(y_step.astype(np.uint8)))
        # pycocotools returns bytes counts; decode so the result is JSON-serializable.
        target_pred['counts'] = target_pred['counts'].decode('ascii')
        result_list.append({
            'image': data["image"],
            'anns': data["anns"],
            'new_img_id': data["new_img_id"],
            'pred_mask': target_pred,
        })
    return result_list
|
|
|
|
|
|
|
|
def main():
    """Build a MoCo-v2-initialized SegSwap model and evaluate it on HANDAL.

    Returns:
        The list of per-sample prediction dicts from ``evaluate_ours``.
    """
    resume_path = '/home/yuqian_fu/Projects/ego-exo4d-relation/correspondence/SegSwap/model/moco_v2_800ep_pretrain_torchvision.pth.tar'
    checkpoint = torch.load(resume_path)['model']
    # Drop the classifier head; only the convolutional trunk is used.
    trunk_state = {k: v for k, v in checkpoint.items() if 'fc' not in k}

    backbone = models.resnet50(pretrained=False)
    backbone.load_state_dict(trunk_state, strict=False)
    # Truncate ResNet-50 after layer3 -> 1024-channel feature maps.
    layer_names = ['conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3']
    trunk_modules = [getattr(backbone, name) for name in layer_names]
    backbone = torch.nn.Sequential(*trunk_modules[:layer_names.index('layer3') + 1])
    feat_dim = 1024
    backbone.cuda()

    netEncoder = transformer.TransEncoder(feat_dim,
                                          pos_weight=0.1,
                                          feat_weight=1,
                                          dropout=0.1,
                                          activation="relu",
                                          mode="small",
                                          layer_type=['I', 'C', 'I', 'C', 'I', 'N'],
                                          drop_feat=0.1)
    netEncoder.cuda()

    return evaluate_ours(backbone, netEncoder)
|
|
if __name__ == '__main__':

    import argparse

    parser = argparse.ArgumentParser()
    # Output directory for the prediction JSON (created if missing).
    parser.add_argument('--out_path', type=str, required=True)

    args = parser.parse_args()

    # Run the full evaluation and dump the predictions.
    results = main()
    print("data_num:", len(results))

    os.makedirs(args.out_path, exist_ok=True)
    with open(f"{args.out_path}/segswap_handal_pred_1.json", "w") as fp:
        json.dump(results, fp)
|
|
|
|