|
|
|
|
|
import tempfile |
|
|
from unittest import TestCase |
|
|
from unittest.mock import Mock |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from mmengine.evaluator import Evaluator |
|
|
from mmengine.model import BaseModel |
|
|
from mmengine.optim import OptimWrapper |
|
|
from mmengine.runner import Runner |
|
|
from torch.utils.data import Dataset |
|
|
|
|
|
from mmdet.registry import DATASETS |
|
|
from mmdet.utils import register_all_modules |
|
|
|
|
|
# Populate mmdet's registries at import time. Presumably this is what lets
# the string ``type=...`` entries in the Runner config of this test module
# (e.g. 'TeacherStudentValLoop', 'DefaultSampler') be resolved — confirm
# against mmdet.utils.register_all_modules.
register_all_modules()
|
|
|
|
|
|
|
|
class ToyModel(nn.Module):
    """Minimal model with a single ``nn.Linear(2, 1)`` layer.

    Used as the computational core of the toy teacher/student models in
    this test module.
    """

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 1)

    def forward(self, inputs, data_samples, mode='tensor'):
        """Apply the linear layer to a batch of samples.

        Args:
            inputs (Sequence[Tensor]): Per-sample inputs, stacked into one
                batch; the last dimension must be 2 to match ``self.linear``.
            data_samples (Sequence[Tensor]): Per-sample labels, stacked into
                one batch. Each label is assumed to hold a single value.
            mode (str): ``'loss'`` returns ``dict(loss=...)``. Any other value
                (``'tensor'``, ``'predict'``, ...) returns the raw outputs.

        Returns:
            dict | Tensor: ``dict(loss=scalar)`` in ``'loss'`` mode, otherwise
            the (B, 1) linear outputs.

        Raises:
            RuntimeError: In ``'loss'`` mode, if the stacked labels cannot be
                reshaped to the outputs' shape.
        """
        labels = torch.stack(data_samples)
        inputs = torch.stack(inputs)
        outputs = self.linear(inputs)
        if mode == 'loss':
            # ``outputs`` is (B, 1) while scalar labels stack to (B,).
            # Subtracting them directly would broadcast to (B, B) and sum
            # every label/output pair, so match the shapes first.
            loss = (labels.reshape_as(outputs) - outputs).sum()
            return dict(loss=loss)
        return outputs
|
|
|
|
|
|
|
|
class ToyModel1(BaseModel, ToyModel):
    """``ToyModel`` exposed through mmengine's ``BaseModel`` interface.

    The runner can treat this class as a ``BaseModel``, while the layers and
    forward computation come from ``ToyModel``.
    """

    def __init__(self):
        # Cooperative initialisation along the MRO. This presumably relies on
        # BaseModel's __init__ chaining via super() so that ToyModel.__init__
        # (which creates ``self.linear``) also runs.
        super(ToyModel1, self).__init__()

    def forward(self, *args, **kwargs):
        # Start the attribute lookup *after* BaseModel in the MRO; this is
        # intended to resolve to ToyModel.forward.
        toy_forward = super(BaseModel, self).forward
        return toy_forward(*args, **kwargs)
|
|
|
|
|
|
|
|
class ToyModel2(BaseModel):
    """Toy semi-supervised wrapper holding a teacher and a student model.

    It exposes ``teacher``, ``student`` and a ``semi_test_cfg`` dict with a
    ``predict_on`` key. Presumably these are what ``TeacherStudentValLoop``
    inspects — confirm against the loop implementation. ``forward`` always
    delegates to the student, whatever ``semi_test_cfg`` says.
    """

    def __init__(self):
        super(ToyModel2, self).__init__()
        # The teacher is built before the student; the order is kept so that
        # submodule registration and random initialisation are unchanged.
        self.teacher = ToyModel1()
        self.student = ToyModel1()
        self.semi_test_cfg = {'predict_on': 'teacher'}

    def forward(self, *args, **kwargs):
        student = self.student
        return student(*args, **kwargs)
|
|
|
|
|
|
|
|
# ``force=True`` lets this registration replace any existing 'DummyDataset'
# entry in the registry instead of failing on the duplicate name.
@DATASETS.register_module(force=True)
class DummyDataset(Dataset):
    """Twelve random 2-d samples, each labelled ``1.0``.

    ``data`` and ``label`` are class attributes, so they are generated once,
    when the class body executes, and shared by every instance.
    """

    METAINFO = {}
    data = torch.randn(12, 2)
    label = torch.ones(12)

    @property
    def metainfo(self):
        """Return the (empty) class-level metainfo dict."""
        return self.METAINFO

    def __len__(self):
        # Number of rows in ``data``.
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data[index]
        target = self.label[index]
        return {'inputs': sample, 'data_samples': target}
|
|
|
|
|
|
|
|
class TestTeacherStudentValLoop(TestCase):
    """Run a short training job whose validation uses TeacherStudentValLoop."""

    def setUp(self):
        # Isolated work_dir for everything the runner writes.
        self.temp_dir = tempfile.TemporaryDirectory()

    def tearDown(self):
        self.temp_dir.cleanup()

    def test_teacher_student_val_loop(self):
        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        model = ToyModel2().to(device)
        # A Mock whose ``__class__`` is Evaluator, so isinstance checks treat
        # it as a real evaluator. ``evaluate`` returns a fixed metric dict.
        evaluator = Mock()
        evaluator.evaluate = Mock(return_value=dict(acc=0.5))
        evaluator.__class__ = Evaluator
        runner = Runner(
            model=model,
            train_dataloader=dict(
                dataset=dict(type='DummyDataset'),
                sampler=dict(type='DefaultSampler', shuffle=True),
                batch_size=3,
                num_workers=0),
            val_dataloader=dict(
                dataset=dict(type='DummyDataset'),
                sampler=dict(type='DefaultSampler', shuffle=False),
                batch_size=3,
                num_workers=0),
            val_evaluator=evaluator,
            work_dir=self.temp_dir.name,
            default_scope='mmdet',
            # The optimizer must own the parameters of the model being
            # trained. Building it from a fresh ``ToyModel()`` meant the
            # runner's model was never updated.
            optim_wrapper=OptimWrapper(
                torch.optim.Adam(model.parameters())),
            train_cfg=dict(by_epoch=True, max_epochs=2, val_interval=1),
            val_cfg=dict(type='TeacherStudentValLoop'),
            default_hooks=dict(logger=dict(type='LoggerHook', interval=1)),
            experiment_name='test1')
        runner.train()
        # Validation ran (val_interval=1), so the evaluator must have been
        # asked for metrics.
        self.assertTrue(evaluator.evaluate.called)
|
|
|