# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch.nn as nn
from mmengine.hooks import Hook
from mmengine.model import is_model_wrapper
from mmengine.runner import Runner

from mmdet.registry import HOOKS


@HOOKS.register_module()
class MeanTeacherHook(Hook):
    """Mean Teacher Hook.

    Mean Teacher is an efficient semi-supervised learning method proposed in
    `Mean Teacher <https://arxiv.org/abs/1703.01780>`_.
    This method requires two models with exactly the same structure:
    a student model and a teacher model.
    The student model updates its parameters through gradient descent,
    while the teacher model updates its parameters as an exponential
    moving average (EMA) of the student's parameters.
    Compared with the student model, the teacher model
    is smoother and accumulates more knowledge.

    Args:
        momentum (float): The momentum used for updating the teacher's
            parameters. The teacher's parameters are updated with the formula:
            `teacher = (1 - momentum) * teacher + momentum * student`.
            Defaults to 0.001.
        interval (int): Update the teacher's parameters every `interval`
            iterations. Defaults to 1.
        skip_buffers (bool): Whether to skip the model buffers, such as
            batchnorm running stats (`running_mean`, `running_var`), so that
            the EMA operation is not performed on them. Defaults to True.
    """

    def __init__(self,
                 momentum: float = 0.001,
                 interval: int = 1,
                 skip_buffers: bool = True) -> None:
        assert 0 < momentum < 1
        self.momentum = momentum
        self.interval = interval
        self.skip_buffers = skip_buffers

    def before_train(self, runner: Runner) -> None:
        """Check that the teacher and student models exist."""
        model = runner.model
        if is_model_wrapper(model):
            model = model.module
        assert hasattr(model, 'teacher')
        assert hasattr(model, 'student')
        # Only at the initial stage: a momentum of 1 copies the student's
        # weights into the teacher wholesale.
        if runner.iter == 0:
            self.momentum_update(model, 1)

    def after_train_iter(self,
                         runner: Runner,
                         batch_idx: int,
                         data_batch: Optional[dict] = None,
                         outputs: Optional[dict] = None) -> None:
        """Update the teacher's parameters every `self.interval` iterations."""
        if (runner.iter + 1) % self.interval != 0:
            return
        model = runner.model
        if is_model_wrapper(model):
            model = model.module
        self.momentum_update(model, self.momentum)

    def momentum_update(self, model: nn.Module, momentum: float) -> None:
        """Update the teacher's parameters as an exponential moving average
        of the student's parameters."""
        if self.skip_buffers:
            for (src_name, src_param), (dst_name, dst_param) in zip(
                    model.student.named_parameters(),
                    model.teacher.named_parameters()):
                dst_param.data.mul_(1 - momentum).add_(
                    src_param.data, alpha=momentum)
        else:
            for (src_param,
                 dst_param) in zip(model.student.state_dict().values(),
                                   model.teacher.state_dict().values()):
                # Skip non-floating-point entries such as the integer
                # `num_batches_tracked` buffer, which cannot be averaged.
                if dst_param.dtype.is_floating_point:
                    dst_param.data.mul_(1 - momentum).add_(
                        src_param.data, alpha=momentum)
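

# --- A minimal, hypothetical usage sketch (not part of the original file).
# In a real config the hook is enabled through MMEngine's `custom_hooks`,
# e.g. `custom_hooks = [dict(type='MeanTeacherHook', momentum=0.001)]`.
# Below, `ToyDetector` is an invented stand-in for a semi-supervised
# detector exposing `student`/`teacher` submodules; it only demonstrates
# the EMA mechanics of `momentum_update` in isolation.
if __name__ == '__main__':
    import copy

    import torch

    class ToyDetector(nn.Module):
        """Stand-in with the `student`/`teacher` attributes the hook expects."""

        def __init__(self) -> None:
            super().__init__()
            self.student = nn.Linear(4, 2)
            self.teacher = copy.deepcopy(self.student)

    model = ToyDetector()
    hook = MeanTeacherHook(momentum=0.5)  # large momentum makes the shift visible

    with torch.no_grad():
        model.student.weight.add_(1.0)  # pretend a gradient step moved the student

    old_teacher = model.teacher.weight.clone()
    hook.momentum_update(model, hook.momentum)
    # The teacher moved halfway toward the student:
    # new_teacher = (1 - 0.5) * old_teacher + 0.5 * student
    expected = 0.5 * old_teacher + 0.5 * model.student.weight.detach()
    assert torch.allclose(model.teacher.weight, expected)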