| | |
| | import tempfile |
| | from unittest import TestCase |
| | from unittest.mock import Mock |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from mmengine.evaluator import Evaluator |
| | from mmengine.model import BaseModel |
| | from mmengine.optim import OptimWrapper |
| | from mmengine.runner import Runner |
| | from torch.utils.data import Dataset |
| |
|
| | from mmdet.registry import DATASETS |
| | from mmdet.utils import register_all_modules |
| |
|
# Executed at import time, before any test builds a Runner.
# NOTE(review): presumably this populates the mmdet registries so that the
# string ``type`` names used in the configs below can be resolved - confirm.
register_all_modules()
| |
|
| |
|
class ToyModel(nn.Module):
    """Minimal network for the loop tests: one ``nn.Linear(2, 1)`` layer."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 1)

    def forward(self, inputs, data_samples, mode='tensor'):
        """Apply the linear layer and optionally reduce the result to a loss.

        Args:
            inputs: Sequence of tensors. They are stacked along a new
                leading dimension and fed to ``self.linear``.
            data_samples: Sequence of label tensors. Only read when
                ``mode == 'loss'``; may be ``None`` for every other mode.
            mode (str): ``'loss'`` returns a loss dict; any other value
                returns the raw layer output. Defaults to ``'tensor'``.

        Returns:
            torch.Tensor | dict: The output of ``self.linear`` or, in
            ``'loss'`` mode, ``dict(loss=<scalar tensor>)``.
        """
        outputs = self.linear(torch.stack(inputs))
        if mode != 'loss':
            # 'tensor' and every unrecognised mode share the same result.
            return outputs

        # Labels are stacked lazily so that non-loss calls do not require
        # (or touch) ``data_samples`` at all.
        labels = torch.stack(data_samples)
        # NOTE(review): ``outputs`` has a trailing dimension of size 1; if
        # ``labels`` is 1-D the subtraction broadcasts - confirm intended.
        return dict(loss=(labels - outputs).sum())
| |
|
| |
|
class ToyModel1(BaseModel, ToyModel):
    """``ToyModel`` combined with ``BaseModel`` through multiple inheritance."""

    def __init__(self):
        super(ToyModel1, self).__init__()

    def forward(self, *args, **kwargs):
        """Forward all arguments to the ``forward`` found after ``BaseModel``.

        The method lookup deliberately starts *after* ``BaseModel`` in the
        MRO (presumably reaching ``ToyModel.forward`` - confirm).
        """
        after_base_model = super(BaseModel, self)
        return after_base_model.forward(*args, **kwargs)
| |
|
| |
|
class ToyModel2(BaseModel):
    """Toy model holding ``teacher`` and ``student`` sub-models.

    It also carries a ``semi_test_cfg`` dict whose ``predict_on`` entry is
    initialised to ``'teacher'``.
    """

    def __init__(self):
        super(ToyModel2, self).__init__()
        # Teacher is created before student; keep this order so parameter
        # initialisation draws from the RNG in the same sequence.
        self.teacher = ToyModel1()
        self.student = ToyModel1()
        self.semi_test_cfg = {'predict_on': 'teacher'}

    def forward(self, *args, **kwargs):
        """Delegate every call to the student sub-model."""
        student = self.student
        return student(*args, **kwargs)
| |
|
| |
|
@DATASETS.register_module(force=True)
class DummyDataset(Dataset):
    """In-memory dataset of 12 samples shared by all instances.

    ``data`` and ``label`` are class attributes, created once when the
    class body is executed.
    """

    METAINFO = {}
    data = torch.randn(12, 2)  # 12 samples with 2 features each
    label = torch.ones(12)  # one label per sample, all ones

    @property
    def metainfo(self):
        """Return the class-level ``METAINFO`` dict."""
        return self.METAINFO

    def __len__(self):
        """Return the number of rows in ``data``."""
        return self.data.shape[0]

    def __getitem__(self, index):
        """Return the sample at ``index`` as an inputs/data_samples dict."""
        features = self.data[index]
        target = self.label[index]
        return {'inputs': features, 'data_samples': target}
| |
|
| |
|
class TestTeacherStudentValLoop(TestCase):
    """Smoke test: train a toy teacher/student model with
    ``TeacherStudentValLoop`` configured as the validation loop."""

    def setUp(self):
        # Fresh work directory for every test; removed again in tearDown.
        self.temp_dir = tempfile.TemporaryDirectory()

    def tearDown(self):
        self.temp_dir.cleanup()

    def test_teacher_student_val_loop(self):
        """Run two training epochs with validation after each epoch."""
        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        model = ToyModel2().to(device)

        evaluator = Mock()
        evaluator.evaluate = Mock(return_value=dict(acc=0.5))
        # Make the mock report ``Evaluator`` as its class.
        # NOTE(review): presumably needed so the Runner accepts the mock as
        # an already-built evaluator - confirm.
        evaluator.__class__ = Evaluator

        runner = Runner(
            model=model,
            train_dataloader=dict(
                dataset=dict(type='DummyDataset'),
                sampler=dict(type='DefaultSampler', shuffle=True),
                batch_size=3,
                num_workers=0),
            val_dataloader=dict(
                dataset=dict(type='DummyDataset'),
                sampler=dict(type='DefaultSampler', shuffle=False),
                batch_size=3,
                num_workers=0),
            val_evaluator=evaluator,
            work_dir=self.temp_dir.name,
            default_scope='mmdet',
            # Optimise the parameters of the model that is actually being
            # trained (built after ``.to(device)``), not those of an
            # unrelated throwaway model.
            optim_wrapper=OptimWrapper(
                torch.optim.Adam(model.parameters())),
            train_cfg=dict(by_epoch=True, max_epochs=2, val_interval=1),
            val_cfg=dict(type='TeacherStudentValLoop'),
            default_hooks=dict(logger=dict(type='LoggerHook', interval=1)),
            experiment_name='test1')
        runner.train()

        # Validation must have asked the evaluator for metrics at least once.
        evaluator.evaluate.assert_called()
| |
|