from detectron2.engine import AMPTrainer
import torch
import time


def cycle(iterable):
    # Loop over the data loader indefinitely so the iterator never exhausts.
    while True:
        for x in iterable:
            yield x


class MattingTrainer(AMPTrainer):
    def __init__(self, model, data_loader, optimizer, grad_scaler=None):
        # Pass the caller's grad_scaler through instead of discarding it;
        # AMPTrainer creates a default GradScaler only when it receives None.
        super().__init__(model, data_loader, optimizer, grad_scaler=grad_scaler)
        self.data_loader_iter = iter(cycle(self.data_loader))

    def run_step(self):
        """
        Implement the AMP training logic.
        """
        assert self.model.training, "[AMPTrainer] model was changed to eval mode!"
        assert torch.cuda.is_available(), "[AMPTrainer] CUDA is required for AMP training!"
        from torch.cuda.amp import autocast

        # Matting pass: fetch the next batch and record data-loading time.
        start = time.perf_counter()
        data = next(self.data_loader_iter)
        data_time = time.perf_counter() - start

        # Forward pass under autocast so eligible ops run in mixed precision.
        with autocast():
            loss_dict = self.model(data)
            if isinstance(loss_dict, torch.Tensor):
                losses = loss_dict
                loss_dict = {"total_loss": loss_dict}
            else:
                losses = sum(loss_dict.values())

        # Scale the loss before backward to avoid fp16 gradient underflow;
        # step() unscales the gradients and update() adjusts the scale factor.
        self.optimizer.zero_grad()
        self.grad_scaler.scale(losses).backward()

        self._write_metrics(loss_dict, data_time)

        self.grad_scaler.step(self.optimizer)
        self.grad_scaler.update()
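

# ---------------------------------------------------------------------------
# Minimal usage sketch (assumptions, not part of the original trainer):
# `build_matting_model`, `train_dataset`, and the hyperparameters below are
# hypothetical placeholders. Any detectron2-compatible model whose forward()
# returns a loss dict (or a single loss tensor) should work here, since
# run_step() handles both cases.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    model = build_matting_model().train().cuda()  # hypothetical model factory
    data_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    trainer = MattingTrainer(model, data_loader, optimizer)
    # train() is inherited from detectron2's TrainerBase and calls run_step()
    # once per iteration between start_iter and max_iter.
    trainer.train(start_iter=0, max_iter=10000)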