| import os.path as osp | |
| import tempfile | |
| from copy import deepcopy | |
| import pytest | |
| from mmengine.config import Config | |
| from mmdet.utils import replace_cfg_vals | |
def test_replace_cfg_vals():
    """Check that ``replace_cfg_vals`` resolves ``${key}`` placeholders.

    Builds a config whose values reference other keys through ``${...}``
    placeholders:

    * a whole-string reference to a dict (``'${model}'``);
    * dotted whole-string references to nested scalars;
    * placeholders embedded inside longer strings.

    It then checks the resolved values. Embedding a dict-, list- or
    tuple-valued placeholder inside a longer string is expected to raise
    ``AssertionError``.
    """
    # The config is associated with a real file path. Creating it inside a
    # temporary directory guarantees both the file and the directory are
    # removed when the test finishes, even if an assertion fails.
    with tempfile.TemporaryDirectory() as tmp_dir:
        cfg_name = 'replace_cfg_vals_test'
        cfg_path = osp.join(tmp_dir, f'{cfg_name}.py')
        with open(cfg_path, 'w') as f:
            f.write('configs')

        ori_cfg_dict = dict()
        ori_cfg_dict['cfg_name'] = cfg_name
        ori_cfg_dict['work_dir'] = 'work_dirs/${cfg_name}/${percent}/${fold}'
        ori_cfg_dict['percent'] = 5
        ori_cfg_dict['fold'] = 1
        ori_cfg_dict['model_wrapper'] = dict(
            type='SoftTeacher', detector='${model}')
        ori_cfg_dict['model'] = dict(
            type='FasterRCNN',
            backbone=dict(type='ResNet'),
            neck=dict(type='FPN'),
            rpn_head=dict(type='RPNHead'),
            roi_head=dict(type='StandardRoIHead'),
            train_cfg=dict(
                rpn=dict(
                    assigner=dict(type='MaxIoUAssigner'),
                    sampler=dict(type='RandomSampler'),
                ),
                rpn_proposal=dict(nms=dict(type='nms', iou_threshold=0.7)),
                rcnn=dict(
                    assigner=dict(type='MaxIoUAssigner'),
                    sampler=dict(type='RandomSampler'),
                ),
            ),
            test_cfg=dict(
                rpn=dict(nms=dict(type='nms', iou_threshold=0.7)),
                rcnn=dict(nms=dict(type='nms', iou_threshold=0.5)),
            ),
        )
        # Dotted placeholders that point at nested scalar values.
        ori_cfg_dict['iou_threshold'] = dict(
            rpn_proposal_nms='${model.train_cfg.rpn_proposal.nms.iou_threshold}',
            test_rpn_nms='${model.test_cfg.rpn.nms.iou_threshold}',
            test_rcnn_nms='${model.test_cfg.rcnn.nms.iou_threshold}',
        )
        # Targets of the "placeholder embedded in a longer string" cases.
        ori_cfg_dict['str'] = 'Hello, world!'
        ori_cfg_dict['dict'] = {'Hello': 'world!'}
        ori_cfg_dict['list'] = [
            'Hello, world!',
        ]
        ori_cfg_dict['tuple'] = ('Hello, world!', )
        ori_cfg_dict['test_str'] = 'xxx${str}xxx'

        ori_cfg = Config(ori_cfg_dict, filename=cfg_path)
        # Resolve a copy so ``ori_cfg`` keeps the unresolved reference
        # values that the assertions below compare against.
        updated_cfg = replace_cfg_vals(deepcopy(ori_cfg))

        # Several placeholders inside one string are each substituted.
        assert updated_cfg.work_dir == f'work_dirs/{cfg_name}/5/1'
        # NOTE(review): this relies on replace_cfg_vals promoting
        # 'model_wrapper' to 'model' (so 'model.detector' holds the original
        # model dict) -- confirm against the replace_cfg_vals implementation.
        assert updated_cfg.model.detector == ori_cfg.model

        # A placeholder that forms the whole string resolves to the
        # referenced value itself (a float here).
        train_cfg = ori_cfg.model.train_cfg
        test_cfg = ori_cfg.model.test_cfg
        assert updated_cfg.iou_threshold.rpn_proposal_nms \
            == train_cfg.rpn_proposal.nms.iou_threshold
        assert updated_cfg.iou_threshold.test_rpn_nms \
            == test_cfg.rpn.nms.iou_threshold
        assert updated_cfg.iou_threshold.test_rcnn_nms \
            == test_cfg.rcnn.nms.iou_threshold

        # A scalar placeholder embedded in a longer string is substituted
        # textually.
        assert updated_cfg.test_str == 'xxxHello, world!xxx'

        # A dict/list/tuple placeholder embedded in a longer string must be
        # rejected with an AssertionError.
        invalid_cases = (
            ('test_dict', 'xxx${dict}xxx'),
            ('test_list', 'xxx${list}xxx'),
            ('test_tuple', 'xxx${tuple}xxx'),
        )
        for key, placeholder in invalid_cases:
            cfg = deepcopy(ori_cfg)
            cfg[key] = placeholder
            # Keep only the call under test inside pytest.raises, so an
            # AssertionError raised during setup cannot make the test pass.
            with pytest.raises(AssertionError):
                replace_cfg_vals(cfg)