|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from unittest.mock import patch |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
from mmengine.model import BaseModel |
|
|
from mmengine.registry import MODELS |
|
|
|
|
|
|
|
|
@MODELS.register_module()
class ToyModel(BaseModel):
    """Minimal registered model stub.

    Holds a single 1x1 conv so the optimizer has a parameter to manage,
    and always reports a constant zero loss, letting a Runner drive a
    training loop without any real computation.
    """

    def __init__(self, *args, **kwargs):
        # Any constructor arguments from the config are accepted and
        # deliberately ignored.
        super().__init__()
        self.conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=1)

    def forward(self, *args, **kwargs):
        # Constant zero loss, independent of the inputs.
        return dict(loss=torch.tensor(0.0))
|
|
|
|
|
|
|
|
def update_params_step(self, loss):
    """No-op stand-in for ``OptimWrapper.update_params``.

    Patched in below so that no backward pass or optimizer step is
    performed; the received loss is simply discarded.
    """
    return None
|
|
|
|
|
|
|
|
def runtimeinfo_step(self, runner, batch_idx, data_batch=None):
    """Slimmed stand-in for ``RuntimeInfoHook.before_train_iter``.

    Records only what is needed here: the current iteration number,
    plus the first learning rate and first momentum of every parameter
    group, written as ``train/<name>`` scalars to the message hub.
    """
    hub = runner.message_hub
    hub.update_info('iter', runner.iter)

    # Learning rates first, then momenta — each dict maps a group name
    # to a list of per-param-group values; only the first is logged.
    for stats in (runner.optim_wrapper.get_lr(),
                  runner.optim_wrapper.get_momentum()):
        for name, values in stats.items():
            hub.update_scalar(f'train/{name}', values[0])
|
|
|
|
|
|
|
|
@patch('mmengine.optim.optimizer.OptimWrapper.update_params',
       update_params_step)
@patch('mmengine.hooks.RuntimeInfoHook.before_train_iter', runtimeinfo_step)
def fake_run(cfg):
    """Train ``cfg`` with a Runner while stubbing out real optimization.

    The two ``@patch`` decorators replace the optimizer step with a
    no-op and the runtime-info hook with a slim version that only logs
    lr/momentum scalars, so only the schedule itself is exercised.

    The config's model/visualizer are swapped for a ToyModel plus a
    Tensorboard backend, and all validation/test sections are dropped.

    Args:
        cfg: An ``mmengine.Config``; mutated in place before the run.
    """
    from mmengine.runner import Runner

    # Remove the sections we replace (model, visualizer) or don't need
    # (val/test). Use a default so configs that lack a val/test section
    # no longer raise KeyError — the bare ``cfg.pop(key)`` calls did.
    for key in ('model', 'visualizer', 'val_dataloader', 'val_evaluator',
                'val_cfg', 'test_dataloader', 'test_evaluator', 'test_cfg'):
        cfg.pop(key, None)

    extra_cfg = dict(
        model=dict(type='ToyModel'),
        visualizer=dict(
            type='Visualizer',
            vis_backends=[
                dict(type='TensorboardVisBackend', save_dir='temp_dir')
            ]),
    )
    cfg.merge_from_dict(extra_cfg)

    runner = Runner.from_cfg(cfg)
    runner.train()
|
|
|