| import argparse |
| import os |
| import warnings |
|
|
| import cv2 |
| import torch |
| import torch.nn.parallel |
| import torch.utils.data |
| from loguru import logger |
|
|
| import deepspeed |
| import utils.config as config |
| from engine.engine import inference |
| from model import build_segmenter |
| from utils.dataset_mosaic import RefDataset |
| from utils.misc import setup_logger |
|
|
| from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint |
|
|
|
|
# Process-wide runtime setup, executed once at import time.
# Ignore every Python warning raised from here on.
warnings.filterwarnings(action="ignore")
# Ask OpenCV not to use its own worker threads (0 = no threading
# optimizations); presumably to avoid oversubscription alongside the
# DataLoader worker — TODO confirm intent.
cv2.setNumThreads(0)
|
|
|
|
def get_parser(argv=None):
    """Parse command-line arguments and return the merged configuration.

    Despite its name, this returns the loaded config object, not the
    ``ArgumentParser``. The file named by ``--config`` is loaded with
    ``config.load_cfg_from_cfg_file``. Any trailing ``--opts`` overrides are
    then merged on top with ``config.merge_cfg_from_list``.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) parses ``sys.argv[1:]``, exactly as before.

    Returns:
        The config object produced by ``utils.config``.

    Raises:
        SystemExit: via ``parser.error`` (exit status 2) when ``--config``
            does not name an existing file.
    """
    parser = argparse.ArgumentParser(
        description='Pytorch Referring Expression Segmentation')
    parser.add_argument('--config',
                        default='path to xxx.yaml',
                        type=str,
                        help='config file')
    parser.add_argument('--opts',
                        default=None,
                        nargs=argparse.REMAINDER,
                        help='override some settings in the config.')
    args = parser.parse_args(argv)
    # The default above is a placeholder string, so a `not None` check (let
    # alone an `assert`, which `python -O` strips) never catches a missing
    # config. Validate the path explicitly and report it as a usage error.
    if not args.config or not os.path.isfile(args.config):
        parser.error(f"config file not found: {args.config!r}")
    cfg = config.load_cfg_from_cfg_file(args.config)
    if args.opts is not None:
        cfg = config.merge_cfg_from_list(cfg, args.opts)
    return cfg
|
|
|
|
@logger.catch
def main():
    """Evaluate an experiment's ``best_model`` DeepSpeed checkpoint on the test split.

    Steps:
      1. Parse the config with ``get_parser``. Derive
         ``args.output_dir = <output_folder>/<exp_name>``. When
         ``args.visualize`` is set, also create ``<output_dir>/vis`` and
         store it as ``args.vis_dir``.
      2. Append logs to ``<output_dir>/test.log``.
      3. Verify that the checkpoint tag directory ``<output_dir>/best_model``
         exists before doing any expensive setup.
      4. Build the test ``RefDataset`` and a non-shuffled ``DataLoader``
         with batch size 1.
      5. Build the segmenter and load the ZeRO checkpoint into it with
         ``load_state_dict_from_zero_checkpoint``. Move the result to CUDA.
      6. Run ``inference`` over the test loader.

    Raises:
        ValueError: if ``<output_dir>/best_model`` is not a directory.
    """
    args = get_parser()
    args.output_dir = os.path.join(args.output_folder, args.exp_name)
    if args.visualize:
        args.vis_dir = os.path.join(args.output_dir, "vis")
        os.makedirs(args.vis_dir, exist_ok=True)

    setup_logger(args.output_dir,
                 distributed_rank=0,
                 filename="test.log",
                 mode="a")
    logger.info(args.test_split)

    # Check the checkpoint *tag* directory, not `output_dir` itself.
    # `output_dir` may already have been created above: by `os.makedirs` for
    # `vis`, and presumably by `setup_logger` for `test.log`. A check on it
    # could therefore pass with no checkpoint present. Failing here also
    # avoids building the dataset and model for nothing.
    ckpt_tag = "best_model"
    ckpt_dir = os.path.join(args.output_dir, ckpt_tag)
    if not os.path.isdir(ckpt_dir):
        raise ValueError(
            "=> resume failed! no checkpoint found at '{}'. Please check args.output_dir again!"
            .format(ckpt_dir))

    test_data = RefDataset(lmdb_dir=args.test_lmdb,
                           mask_dir=args.mask_root,
                           dataset=args.dataset,
                           split=args.test_split,
                           mode='test',
                           input_size=args.input_size,
                           word_length=args.word_len,
                           args=args)
    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=1,
                                              shuffle=False,
                                              num_workers=1,
                                              pin_memory=True)

    model = build_segmenter(args, DDP=False)

    logger.info(f"=> loading checkpoint '{ckpt_dir}'")
    # Consolidates the ZeRO-partitioned weights under <output_dir>/<tag>
    # into the in-memory model and returns it.
    model = load_state_dict_from_zero_checkpoint(
        model, args.output_dir, tag=ckpt_tag).cuda()
    logger.info(f"=> loaded checkpoint '{ckpt_dir}'")

    inference(test_loader, model, args)
|
|
|
|
# Script entry point: run the evaluation only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
|
|