Spaces:
Runtime error
Runtime error
| import torch | |
| from models import MultiGPUModelWrapper | |
| from swapae.optimizers.swapping_autoencoder_optimizer import SwappingAutoencoderOptimizer | |
| import swapae.util | |
class ClassifierOptimizer(SwappingAutoencoderOptimizer):
    """Optimizer that runs classifier-only training steps on the wrapped model.

    The helpers and attributes used here (``prepare_images``,
    ``adjust_lr_if_necessary``, ``set_requires_grad``, ``Gparams``,
    ``optimizer_C`` and ``model``) are not defined in this class; they are
    presumably provided by ``SwappingAutoencoderOptimizer`` -- confirm there.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Register command-line options by deferring to the parent class.

        Declared as a static method because it takes no ``self``; this keeps
        class-level calls working and makes instance-level calls safe too.

        Args:
            parser: Option parser passed through to the parent implementation.
            is_train: Flag passed through to the parent implementation.

        Returns:
            Whatever the parent implementation returns for ``parser``.
        """
        return SwappingAutoencoderOptimizer.modify_commandline_options(
            parser, is_train
        )

    def train_one_step(self, data_i, total_steps_so_far):
        """Run one classifier update and return the resulting losses.

        Args:
            data_i: One batch, in whatever form ``self.prepare_images`` accepts.
            total_steps_so_far: Step counter forwarded to
                ``self.adjust_lr_if_necessary``.

        Returns:
            The losses/metrics mapping converted by ``swapae.util.to_numpy``.
        """
        images_minibatch, labels = self.prepare_images(data_i)
        c_losses = self.train_classifier_one_step(images_minibatch, labels)
        self.adjust_lr_if_necessary(total_steps_so_far)
        # The file does `import swapae.util`, which binds only the name
        # `swapae`; a bare `util` would raise NameError here.
        return swapae.util.to_numpy(c_losses)

    def train_classifier_one_step(self, images, labels):
        """Perform a single backward pass and ``optimizer_C`` step.

        Args:
            images: Batch of inputs forwarded to the model.
            labels: Matching labels forwarded to the model.

        Returns:
            The losses mapping returned by the model, updated in place with
            the metrics mapping the model returned alongside it.
        """
        # Turn off gradients for self.Gparams (presumably the generator-side
        # parameters -- confirm in the parent class) before this update.
        self.set_requires_grad(self.Gparams, False)
        self.optimizer_C.zero_grad()
        losses, metrics = self.model(
            images, labels, command="compute_classifier_losses"
        )
        # Reduce each loss entry to its mean and add them into one scalar.
        loss = sum(v.mean() for v in losses.values())
        loss.backward()
        self.optimizer_C.step()
        # Metrics are merged only after the total is computed, so they are
        # reported to the caller but never contribute to the optimized loss.
        losses.update(metrics)
        return losses

    def get_visuals_for_snapshot(self, data_i):
        """Return the model's snapshot visuals for a batch, without gradients.

        Args:
            data_i: One batch, in whatever form ``self.prepare_images`` accepts.

        Returns:
            The result of calling the model with
            ``command="get_visuals_for_snapshot"``.
        """
        images, labels = self.prepare_images(data_i)
        with torch.no_grad():
            return self.model(images, labels, command="get_visuals_for_snapshot")

    def save(self, total_steps_so_far):
        """Delegate saving to ``self.model.save`` with the current step count."""
        self.model.save(total_steps_so_far)