Spaces:
Runtime error
Runtime error
| import os | |
| import torch | |
def save_model(model, optimizer, state, path):
    '''
    Serialize model weights, optimizer state and training-loop state to ``path``.

    A ``torch.nn.DataParallel`` wrapper is unwrapped first so the saved keys belong
    to the inner module (no ``module.`` prefix). Missing parent directories are created.

    :param model: Model to save (optionally wrapped in torch.nn.DataParallel)
    :param optimizer: Optimizer whose ``state_dict()`` is saved
    :param state: State of the training loop, stored under the ``'state'`` key
    :param path: Destination file path
    '''
    if isinstance(model, torch.nn.DataParallel):
        model = model.module  # save state dict of wrapped module
    directory = os.path.dirname(path)
    if directory:
        # exist_ok avoids the check-then-create race of exists() + makedirs()
        os.makedirs(directory, exist_ok=True)
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'state': state,  # state of training loop (was 'step')
    }, path)
def load_model(model, optimizer, path, cuda):
    '''
    Restore model (and optionally optimizer) state from a checkpoint file.

    The checkpoint is expected to hold ``'model_state_dict'``, ``'optimizer_state_dict'``
    and either ``'state'`` or (older checkpoints) ``'step'``.

    :param model: Model to load into (a torch.nn.DataParallel wrapper is unwrapped first)
    :param optimizer: Optimizer to restore, or None to skip optimizer state
    :param path: Checkpoint file path
    :param cuda: If False, all tensors are mapped onto the CPU while loading
    :return: Training-loop state dict; for older checkpoints only ``{'step': ...}``
    :raises RuntimeError: if the weights do not fit the model even after stripping a
        ``module.`` prefix from the keys
    '''
    if isinstance(model, torch.nn.DataParallel):
        model = model.module  # load state dict of wrapped module
    # map_location=None keeps tensors on their saved devices; 'cpu' remaps them for CPU-only runs.
    # NOTE(review): recent PyTorch releases default torch.load to weights_only=True, so a
    # 'state' holding arbitrary Python objects may fail to load — confirm for the installed version.
    checkpoint = torch.load(path, map_location=None if cuda else 'cpu')
    model_state_dict = checkpoint['model_state_dict']
    try:
        model.load_state_dict(model_state_dict)
    except RuntimeError:
        # Work-around for checkpoints where the DataParallel wrapper was saved instead of
        # the inner module: strip the 'module.' prefix from every key and retry.
        from collections import OrderedDict
        prefix = 'module.'
        model_state_dict_fixed = OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v)
            for k, v in model_state_dict.items()
        )
        model.load_state_dict(model_state_dict_fixed)
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    if 'state' in checkpoint:
        return checkpoint['state']
    # older checkpoints only store step, rest of state won't be there
    return {'step': checkpoint['step']}
def compute_loss(model, inputs, targets, criterion, compute_grad=False):
    '''
    Evaluate ``criterion`` on the model's per-source outputs, optionally backpropagating.

    The procedure is selected by ``model.separate``:
    - separate: the model is called once per name in ``model.instruments`` as
      ``model(inputs, inst)``; each source's loss is backpropagated on its own and the
      returned output for that source is a detached clone.
    - joint: the model is called once as ``model(inputs)``; the per-source losses are
      summed and that sum is backpropagated once. Outputs are returned as produced.

    :param model: Model to train with; exposes ``separate`` (and ``instruments`` if separate)
    :param inputs: Input mixture
    :param targets: Target sources, indexed by source name
    :param criterion: Loss function to use (L1, L2, ..)
    :param compute_grad: Whether to call ``backward()`` on the loss(es)
    :return: Model outputs keyed by source name, loss averaged over sources (float)
    '''
    if model.separate:
        outputs = {}
        per_source_losses = []
        for inst in model.instruments:
            prediction = model(inputs, inst)[inst]
            source_loss = criterion(prediction, targets[inst])
            if compute_grad:
                source_loss.backward()
            per_source_losses.append(source_loss.item())
            outputs[inst] = prediction.detach().clone()
        return outputs, sum(per_source_losses) / float(len(per_source_losses))

    outputs = model(inputs)
    total_loss = sum(criterion(outputs[inst], targets[inst]) for inst in outputs)
    if compute_grad:
        total_loss.backward()
    return outputs, total_loss.item() / float(len(outputs))
class DataParallel(torch.nn.DataParallel):
    '''
    ``torch.nn.DataParallel`` that forwards unknown attribute lookups to the wrapped
    module, so custom attributes defined on the model stay reachable through the wrapper.
    '''

    def __init__(self, module, device_ids=None, output_device=None, dim=0):
        super().__init__(module, device_ids, output_device, dim)

    def __getattr__(self, name):
        try:
            return super().__getattr__(name)
        except AttributeError:
            # Without this guard, a missing 'module' attribute (e.g. on a bare or
            # partially unpickled instance) makes self.module re-enter __getattr__ forever.
            if name == 'module':
                raise
            return getattr(self.module, name)