|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import copy |
|
|
import logging |
|
|
import os |
|
|
import shutil |
|
|
import tempfile |
|
|
import time |
|
|
from unittest import TestCase |
|
|
from uuid import uuid4 |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torch.distributed import destroy_process_group |
|
|
from torch.utils.data import Dataset |
|
|
|
|
|
import mmengine.hooks |
|
|
import mmengine.optim |
|
|
from mmengine.config import Config |
|
|
from mmengine.dist import is_distributed |
|
|
from mmengine.evaluator import BaseMetric |
|
|
from mmengine.logging import MessageHub, MMLogger |
|
|
from mmengine.model import BaseModel |
|
|
from mmengine.registry import DATASETS, METRICS, MODELS, DefaultScope |
|
|
from mmengine.runner import Runner |
|
|
from mmengine.visualization import Visualizer |
|
|
|
|
|
|
|
|
class ToyModel(BaseModel):
    """Tiny two-layer linear model used to exercise ``Runner`` in tests.

    Maps 2-d inputs to a scalar through ``linear1`` (2 -> 2) and
    ``linear2`` (2 -> 1).
    """

    def __init__(self, data_preprocessor=None):
        super().__init__(data_preprocessor=data_preprocessor)
        self.linear1 = nn.Linear(2, 2)
        self.linear2 = nn.Linear(2, 1)

    def forward(self, inputs, data_samples=None, mode='tensor'):
        """Run the model in ``tensor``, ``loss`` or ``predict`` mode.

        Args:
            inputs: Batched tensor, or a list of tensors to be stacked.
            data_samples: Targets (tensor or list of tensors); only used
                in ``loss`` mode.
            mode (str): One of ``'tensor'``, ``'loss'`` or ``'predict'``.

        Returns:
            Raw outputs in ``tensor``/``predict`` mode, a ``dict`` with a
            ``loss`` entry in ``loss`` mode, and ``None`` for any other
            mode string.
        """
        # Dataloaders may deliver lists of per-sample tensors; batch them.
        if isinstance(inputs, list):
            inputs = torch.stack(inputs)
        if isinstance(data_samples, list):
            data_samples = torch.stack(data_samples)

        feats = self.linear2(self.linear1(inputs))

        if mode == 'loss':
            # Trivial signed-difference "loss" — only needs to be
            # differentiable for the training loop to run.
            return dict(loss=(data_samples - feats).sum())
        if mode in ('tensor', 'predict'):
            return feats
|
|
|
|
|
|
|
|
class ToyDataset(Dataset):
    """Fixed-size toy dataset: 12 random 2-d inputs, all labeled ``1``."""

    METAINFO = dict()
    # Shared class-level fixtures; drawn once at import time.
    data = torch.randn(12, 2)
    label = torch.ones(12)

    @property
    def metainfo(self):
        """Return the (empty) static dataset meta information."""
        return self.METAINFO

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return one sample as the dict shape the toy model expects."""
        sample = self.data[index]
        target = self.label[index]
        return {'inputs': sample, 'data_samples': target}
|
|
|
|
|
|
|
|
class ToyMetric(BaseMetric):
    """Dummy metric that unconditionally reports an accuracy of 1."""

    def __init__(self, collect_device='cpu', dummy_metrics=None):
        super().__init__(collect_device=collect_device)
        # Kept only so configs may pass the argument; never read.
        self.dummy_metrics = dummy_metrics

    def process(self, data_batch, predictions):
        """Record a constant per-batch result regardless of inputs."""
        self.results.append({'acc': 1})

    def compute_metrics(self, results):
        """Aggregate results: always perfect accuracy."""
        return dict(acc=1)
|
|
|
|
|
|
|
|
class RunnerTestCase(TestCase):
    """A test case to build runner easily.

    `RunnerTestCase` will do the following things:

    1. Registers a toy model, a toy metric, and a toy dataset, which can be
       used to run the `Runner` successfully.
    2. Provides epoch based and iteration based cfg to build runner.
    3. Provides `build_runner` method to build runner easily.
    4. Clean the global variable used by the runner.
    """
    # Single-process "distributed" environment defaults. MASTER_PORT is a
    # class-level int deliberately mutated in `setup_dist_env` so that
    # consecutive tests in one process do not reuse the same port.
    dist_cfg = dict(
        MASTER_ADDR='127.0.0.1',
        MASTER_PORT=29600,
        RANK='0',
        WORLD_SIZE='1',
        LOCAL_RANK='0')

    def setUp(self) -> None:
        """Register toy components and build epoch/iter based configs."""
        self.temp_dir = tempfile.TemporaryDirectory()

        # `force=True` so re-registration across test cases does not raise.
        MODELS.register_module(module=ToyModel, force=True)
        METRICS.register_module(module=ToyMetric, force=True)
        DATASETS.register_module(module=ToyDataset, force=True)
        # Minimal but complete epoch-based runner config using only the toy
        # components registered above.
        epoch_based_cfg = dict(
            work_dir=self.temp_dir.name,
            model=dict(type='ToyModel'),
            train_dataloader=dict(
                dataset=dict(type='ToyDataset'),
                sampler=dict(type='DefaultSampler', shuffle=True),
                batch_size=3,
                num_workers=0),
            val_dataloader=dict(
                dataset=dict(type='ToyDataset'),
                sampler=dict(type='DefaultSampler', shuffle=False),
                batch_size=3,
                num_workers=0),
            val_evaluator=[dict(type='ToyMetric')],
            test_dataloader=dict(
                dataset=dict(type='ToyDataset'),
                sampler=dict(type='DefaultSampler', shuffle=False),
                batch_size=3,
                num_workers=0),
            test_evaluator=[dict(type='ToyMetric')],
            optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.1)),
            train_cfg=dict(by_epoch=True, max_epochs=2, val_interval=1),
            val_cfg=dict(),
            test_cfg=dict(),
            default_hooks=dict(logger=dict(type='LoggerHook', interval=1)),
            custom_hooks=[],
            env_cfg=dict(dist_cfg=dict(backend='nccl')),
            experiment_name='test1')
        self.epoch_based_cfg = Config(epoch_based_cfg)

        # Derive the iteration-based config from the epoch-based one and
        # override only what differs: an infinite sampler, iter-based
        # logging/training, and an iter-based checkpoint hook.
        self.iter_based_cfg: Config = copy.deepcopy(self.epoch_based_cfg)
        self.iter_based_cfg.train_dataloader = dict(
            dataset=dict(type='ToyDataset'),
            sampler=dict(type='InfiniteSampler', shuffle=True),
            batch_size=3,
            num_workers=0)
        self.iter_based_cfg.log_processor = dict(by_epoch=False)

        self.iter_based_cfg.train_cfg = dict(by_epoch=False, max_iters=12)
        self.iter_based_cfg.default_hooks = dict(
            logger=dict(type='LoggerHook', interval=1),
            checkpoint=dict(
                type='CheckpointHook', interval=12, by_epoch=False))

    def tearDown(self):
        """Undo registration and clear global state touched by the runner.

        The order matters: logging handlers must be shut down before the
        temporary work dir (which holds log files) is removed, and the
        process group is destroyed last if one was initialized.
        """
        # Release file handlers so the temp dir can be deleted on Windows.
        logging.shutdown()
        MMLogger._instance_dict.clear()
        Visualizer._instance_dict.clear()
        DefaultScope._instance_dict.clear()
        MessageHub._instance_dict.clear()
        # Remove the toy components registered in `setUp`.
        MODELS.module_dict.pop('ToyModel', None)
        METRICS.module_dict.pop('ToyMetric', None)
        DATASETS.module_dict.pop('ToyDataset', None)
        self.temp_dir.cleanup()
        if is_distributed():
            destroy_process_group()

    def build_runner(self, cfg: Config):
        """Build a ``Runner`` from ``cfg`` with a unique experiment name."""
        cfg.experiment_name = self.experiment_name
        runner = Runner.from_cfg(cfg)
        return runner

    @property
    def experiment_name(self):
        # Test name + timestamp + uuid guarantees uniqueness so runner
        # singletons (logger, message hub, ...) never collide across tests.
        return f'{self._testMethodName}_{time.time()} + ' \
            f'{uuid4()}'

    def setup_dist_env(self):
        """Export env vars for a 1-process distributed run.

        Bumps the class-level MASTER_PORT first so repeated calls in the
        same process pick a fresh port.
        """
        self.dist_cfg['MASTER_PORT'] += 1
        os.environ['MASTER_PORT'] = str(self.dist_cfg['MASTER_PORT'])
        os.environ['MASTER_ADDR'] = self.dist_cfg['MASTER_ADDR']
        os.environ['RANK'] = self.dist_cfg['RANK']
        os.environ['WORLD_SIZE'] = self.dist_cfg['WORLD_SIZE']
        os.environ['LOCAL_RANK'] = self.dist_cfg['LOCAL_RANK']

    def clear_work_dir(self):
        """Delete everything inside the temporary work directory."""
        # Close log file handlers first, otherwise removal can fail.
        logging.shutdown()
        for filename in os.listdir(self.temp_dir.name):
            filepath = os.path.join(self.temp_dir.name, filename)
            if os.path.isfile(filepath):
                os.remove(filepath)
            else:
                shutil.rmtree(filepath)
|
|
|