Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import os | |
| import os.path as osp | |
| from unittest.mock import patch | |
| import pytest | |
| import torchvision | |
| import mmcv | |
| from mmcv.runner.checkpoint import (DEFAULT_CACHE_DIR, ENV_MMCV_HOME, | |
| ENV_XDG_CACHE_HOME, _get_mmcv_home, | |
| _load_checkpoint, | |
| get_deprecated_model_names, | |
| get_external_models) | |
| from mmcv.utils import digit_version | |
| def test_set_mmcv_home(): | |
| os.environ.pop(ENV_MMCV_HOME, None) | |
| mmcv_home = osp.join(osp.dirname(__file__), 'data/model_zoo/mmcv_home/') | |
| os.environ[ENV_MMCV_HOME] = mmcv_home | |
| assert _get_mmcv_home() == mmcv_home | |
| def test_default_mmcv_home(): | |
| os.environ.pop(ENV_MMCV_HOME, None) | |
| os.environ.pop(ENV_XDG_CACHE_HOME, None) | |
| assert _get_mmcv_home() == os.path.expanduser( | |
| os.path.join(DEFAULT_CACHE_DIR, 'mmcv')) | |
| model_urls = get_external_models() | |
| assert model_urls == mmcv.load( | |
| osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json')) | |
| def test_get_external_models(): | |
| os.environ.pop(ENV_MMCV_HOME, None) | |
| mmcv_home = osp.join(osp.dirname(__file__), 'data/model_zoo/mmcv_home/') | |
| os.environ[ENV_MMCV_HOME] = mmcv_home | |
| ext_urls = get_external_models() | |
| assert ext_urls == { | |
| 'train': 'https://localhost/train.pth', | |
| 'test': 'test.pth', | |
| 'val': 'val.pth', | |
| 'train_empty': 'train.pth' | |
| } | |
| def test_get_deprecated_models(): | |
| os.environ.pop(ENV_MMCV_HOME, None) | |
| mmcv_home = osp.join(osp.dirname(__file__), 'data/model_zoo/mmcv_home/') | |
| os.environ[ENV_MMCV_HOME] = mmcv_home | |
| dep_urls = get_deprecated_model_names() | |
| assert dep_urls == { | |
| 'train_old': 'train', | |
| 'test_old': 'test', | |
| } | |
| def load_from_http(url, map_location=None): | |
| return 'url:' + url | |
| def load_url(url, map_location=None, model_dir=None): | |
| return load_from_http(url) | |
| def load(filepath, map_location=None): | |
| return 'local:' + filepath | |
| def test_load_external_url(): | |
| # test modelzoo:// | |
| torchvision_version = torchvision.__version__ | |
| if digit_version(torchvision_version) < digit_version('0.10.0a0'): | |
| assert (_load_checkpoint('modelzoo://resnet50') == | |
| 'url:https://download.pytorch.org/models/resnet50-19c8e' | |
| '357.pth') | |
| assert (_load_checkpoint('torchvision://resnet50') == | |
| 'url:https://download.pytorch.org/models/resnet50-19c8e' | |
| '357.pth') | |
| else: | |
| assert (_load_checkpoint('modelzoo://resnet50') == | |
| 'url:https://download.pytorch.org/models/resnet50-0676b' | |
| 'a61.pth') | |
| assert (_load_checkpoint('torchvision://resnet50') == | |
| 'url:https://download.pytorch.org/models/resnet50-0676b' | |
| 'a61.pth') | |
| if digit_version(torchvision_version) >= digit_version('0.13.0a0'): | |
| # Test load new format torchvision models. | |
| assert ( | |
| _load_checkpoint('torchvision://resnet50.imagenet1k_v1') == | |
| 'url:https://download.pytorch.org/models/resnet50-0676ba61.pth') | |
| assert ( | |
| _load_checkpoint('torchvision://ResNet50_Weights.IMAGENET1K_V1') == | |
| 'url:https://download.pytorch.org/models/resnet50-0676ba61.pth') | |
| _load_checkpoint('torchvision://resnet50.default') | |
| # test open-mmlab:// with default MMCV_HOME | |
| os.environ.pop(ENV_MMCV_HOME, None) | |
| os.environ.pop(ENV_XDG_CACHE_HOME, None) | |
| url = _load_checkpoint('open-mmlab://train') | |
| assert url == 'url:https://localhost/train.pth' | |
| # test open-mmlab:// with deprecated model name | |
| os.environ.pop(ENV_MMCV_HOME, None) | |
| os.environ.pop(ENV_XDG_CACHE_HOME, None) | |
| with pytest.warns( | |
| Warning, | |
| match='open-mmlab://train_old is deprecated in favor of ' | |
| 'open-mmlab://train'): | |
| url = _load_checkpoint('open-mmlab://train_old') | |
| assert url == 'url:https://localhost/train.pth' | |
| # test openmmlab:// with deprecated model name | |
| os.environ.pop(ENV_MMCV_HOME, None) | |
| os.environ.pop(ENV_XDG_CACHE_HOME, None) | |
| with pytest.warns( | |
| Warning, | |
| match='openmmlab://train_old is deprecated in favor of ' | |
| 'openmmlab://train'): | |
| url = _load_checkpoint('openmmlab://train_old') | |
| assert url == 'url:https://localhost/train.pth' | |
| # test open-mmlab:// with user-defined MMCV_HOME | |
| os.environ.pop(ENV_MMCV_HOME, None) | |
| mmcv_home = osp.join(osp.dirname(__file__), 'data/model_zoo/mmcv_home') | |
| os.environ[ENV_MMCV_HOME] = mmcv_home | |
| url = _load_checkpoint('open-mmlab://train') | |
| assert url == 'url:https://localhost/train.pth' | |
| with pytest.raises(FileNotFoundError, match='train.pth can not be found.'): | |
| _load_checkpoint('open-mmlab://train_empty') | |
| url = _load_checkpoint('open-mmlab://test') | |
| assert url == f'local:{osp.join(_get_mmcv_home(), "test.pth")}' | |
| url = _load_checkpoint('open-mmlab://val') | |
| assert url == f'local:{osp.join(_get_mmcv_home(), "val.pth")}' | |
| # test http:// https:// | |
| url = _load_checkpoint('http://localhost/train.pth') | |
| assert url == 'url:http://localhost/train.pth' | |
| # test local file | |
| with pytest.raises(FileNotFoundError, match='train.pth can not be found.'): | |
| _load_checkpoint('train.pth') | |
| url = _load_checkpoint(osp.join(_get_mmcv_home(), 'test.pth')) | |
| assert url == f'local:{osp.join(_get_mmcv_home(), "test.pth")}' | |