Spaces:
Runtime error
Runtime error
| ''' | |
| Save SAM mask predictions | |
| ''' | |
| from segment_anything import sam_model_registry, SamAutomaticMaskGenerator | |
| import torch.multiprocessing as mp | |
| import pickle | |
| from tqdm import tqdm | |
| import torch | |
| import cv2 | |
| import os | |
| import json | |
| import argparse | |
| import numpy as np | |
# Dataset registry. Each key is a valid --data_name value and maps to a
# two-item list: [image directory, annotation JSON], both relative to the
# dataset root. An empty annotation entry means the image directory itself
# is listed instead of reading an image index from a JSON file.
img_anno = dict(
    ade20k_val=[
        'ADEChallengeData2016/images/validation',
        'ADEChallengeData2016/ade20k_panoptic_val.json',
    ],
    pc_val=['pascal_ctx_d2/images/validation', ''],
    pas_val=['pascal_voc_d2/images/validation', ''],
)
# SAM backbone name (the --sam_model value) -> checkpoint file to load.
# All checkpoints live under the same relative directory.
sam_checkpoint_dict = {
    backbone: f'pretrained_checkpoint/{weights_file}'
    for backbone, weights_file in (
        ('vit_b', 'sam_vit_b_01ec64.pth'),
        ('vit_h', 'sam_vit_h_4b8939.pth'),
        ('vit_l', 'sam_vit_l_0b3195.pth'),
        ('vit_t', 'mobile_sam.pt'),
    )
}
def _image_stem(image_info):
    """Return the extension-less image name for one dataset entry.

    Dict entries are resolved from 'coco_url' (last URL component) or, failing
    that, 'file_name'; plain strings are used as-is. Returns None for a dict
    that carries neither key so the caller can skip it.
    """
    if isinstance(image_info, dict):
        if 'coco_url' in image_info:
            source = image_info['coco_url'].split('/')[-1]
        elif 'file_name' in image_info:
            source = image_info['file_name']
        else:
            return None
    elif isinstance(image_info, str):
        source = image_info
    else:
        raise TypeError(
            f"Expected a dict or str image entry, got {type(image_info).__name__}"
        )
    # splitext strips only the final extension, so stems that contain dots
    # are kept intact.
    return os.path.splitext(source)[0]


def process_images(args, gpu, data_chunk, save_path, if_parallel):
    """Generate SAM masks for every image in ``data_chunk`` and pickle them.

    For each image the automatic mask generator is run, the masks are sorted
    by descending 'area', at most the 50 largest are kept, and the list is
    pickled to ``<save_path>/<id>.pkl`` (``id`` is the image stem with '/'
    replaced by '_'). Images whose output file already exists are skipped,
    and a failure on one image is printed without stopping the loop.

    Args:
        args: Parsed CLI namespace; ``args.sam_model`` selects the model and
            checkpoint, ``args.data_name`` selects the image directory.
        gpu: CUDA device index this worker pins itself to when
            ``if_parallel`` is false.
        data_chunk: Sequence of image entries, each either a dict with a
            'coco_url' or 'file_name' key or a plain file-name string.
        save_path: Directory that receives the ``.pkl`` files.
        if_parallel: When true, the worker does not pin a device and uses the
            current CUDA device instead.
    """
    # Nothing to do (e.g. more workers than images): avoid loading the model.
    if len(data_chunk) == 0:
        return

    sam = sam_model_registry[args.sam_model](
        checkpoint=sam_checkpoint_dict[args.sam_model]
    )
    if not if_parallel:
        torch.cuda.set_device(gpu)
    # The earlier DataParallel wrap was immediately unwrapped via `.module`,
    # so both modes simply use the bare model on the current device.
    sam = sam.cuda()

    mask_generator = SamAutomaticMaskGenerator(
        model=sam,
        pred_iou_thresh=0.8,
        stability_score_thresh=0.7,
        crop_n_layers=0,
        crop_n_points_downscale_factor=2,
        min_mask_region_area=100,
        output_mode='coco_rle'
    )

    # Resolve the dataset root here instead of relying on a global that only
    # exists when this module runs as a script; a worker started with the
    # "spawn" method re-imports the module and would not see that global.
    dataset_root = os.getenv("DETECTRON2_DATASETS", "/users/cx_xchen/DATASETS/")
    image_dir = os.path.join(dataset_root, img_anno[args.data_name][0])

    for image_info in tqdm(data_chunk):
        stem = _image_stem(image_info)
        if stem is None:
            print(f"Skipping entry without 'coco_url' or 'file_name': {image_info}")
            continue

        image_path = f'{image_dir}/{stem}.jpg'
        mask_id = stem.replace('/', '_')
        savepath = f'{save_path}/{mask_id}.pkl'
        if os.path.exists(savepath):
            continue  # already processed in an earlier run

        try:
            image = cv2.imread(image_path)
            if image is None:
                # cv2.imread signals an unreadable/missing file by returning None.
                raise ValueError("cv2.imread returned None")
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            # NOTE(review): `train=False` is passed through unchanged;
            # presumably a project-specific extension of generate() - confirm.
            everything_mask = mask_generator.generate(image, train=False)
            everything_mask = sorted(
                everything_mask, key=lambda m: m['area'], reverse=True
            )[:50]
            # Write to a temporary file and rename, so an interrupted dump
            # never leaves a truncated .pkl that later runs would skip.
            tmp_path = f'{savepath}.tmp'
            with open(tmp_path, 'wb') as f:
                pickle.dump(everything_mask, f)
            os.replace(tmp_path, savepath)
        except Exception as e:
            # Best effort per image: report and move on to the next one.
            print(f"Failed to load or convert image at {image_path}. Error: {e}")
def count_visible_gpus(visible_devices):
    """Return the number of worker processes to launch.

    Counts the numeric, comma-separated ids in ``visible_devices`` (the value
    of CUDA_VISIBLE_DEVICES). When no numeric id is present, e.g. the variable
    is unset, this falls back to 1 so a single worker runs on device 0 instead
    of the data being split into zero chunks.
    """
    device_ids = [
        token for token in (part.strip() for part in visible_devices.split(","))
        if token.isdigit()
    ]
    return len(device_ids) or 1


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Restrict both options to known keys so a typo gives a usage error
    # instead of a KeyError further down.
    parser.add_argument('--data_name', type=str, default='pas_val',
                        choices=list(img_anno))
    parser.add_argument('--sam_model', type=str, default='vit_h',
                        choices=list(sam_checkpoint_dict))
    argss = parser.parse_args()

    gpus = os.getenv("CUDA_VISIBLE_DEVICES", "")
    dataset_path = os.getenv("DETECTRON2_DATASETS", "/users/cx_xchen/DATASETS/")
    num_gpus = count_visible_gpus(gpus)
    print(f"Using {num_gpus} GPUs")

    image_subdir, anno_file = img_anno[argss.data_name]
    if anno_file != '':
        # The image list comes from the 'images' array of the annotation JSON.
        json_file_path = os.path.join(dataset_path, anno_file)
        with open(json_file_path, 'r') as file:
            data = json.load(file)
        entries = data['images']
    else:
        # No annotation file: process every file found in the image directory.
        image_dir = os.path.join(dataset_path, image_subdir)
        entries = os.listdir(image_dir)

    # One chunk per worker; worker i uses CUDA device index i.
    data_chunks = np.array_split(entries, num_gpus)

    save_path = f'output/SAM_masks_pred/{argss.sam_model}_{argss.data_name}'
    os.makedirs(save_path, exist_ok=True)

    processes = []
    for gpu in range(num_gpus):
        p = mp.Process(
            target=process_images,
            args=(argss, gpu, data_chunks[gpu], save_path, False),
        )
        p.start()
        processes.append(p)
    for p in processes:
        p.join()

    # Surface worker crashes instead of exiting as if everything succeeded.
    failed = [idx for idx, p in enumerate(processes) if p.exitcode != 0]
    if failed:
        print(f"Workers for GPU index(es) {failed} exited with an error")
        raise SystemExit(1)