Spaces:
Runtime error
Runtime error
| import os | |
| import numpy as np | |
| import torch | |
| import time | |
class IterationCounter:
    """Count training progress in images and schedule periodic actions.

    ``steps_so_far`` is advanced by ``batch_size`` on every
    :meth:`record_one_iteration` call and compared against the
    ``save_freq`` / ``evaluation_freq`` / ``print_freq`` / ``display_freq``
    and ``total_nimgs`` options to decide when to save, evaluate, print,
    display and stop.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Add the counter's command-line options to ``parser`` and return it.

        Declared static so it can be called on the class (as before) or
        through an instance. ``is_train`` is not used here.
        """
        parser.add_argument("--total_nimgs", default=25 *
                            (1000 ** 2), type=int)
        parser.add_argument("--save_freq", default=50000, type=int)
        parser.add_argument("--evaluation_freq", default=50000, type=int)
        parser.add_argument("--print_freq", default=480, type=int)
        parser.add_argument("--display_freq", default=1600, type=int)
        return parser

    def __init__(self, opt):
        """Initialise the counter from the parsed options ``opt``.

        Reads ``opt.checkpoints_dir``, ``opt.name``, ``opt.dataset_mode``,
        ``opt.batch_size``, ``opt.isTrain``, ``opt.continue_train``,
        ``opt.resume_iter`` and ``opt.pretrained_name``. When continuing
        training, the starting count is either loaded from
        ``<checkpoints_dir>/<name>/iter.txt`` (``resume_iter == "latest"`` and
        no pretrained model) or parsed from ``resume_iter`` itself, where a
        ``"k"`` multiplies by 1000 (e.g. ``"50k"`` -> 50000). Otherwise the
        count starts at 0.
        """
        self.opt = opt
        self.iter_record_path = os.path.join(
            self.opt.checkpoints_dir, self.opt.name, 'iter.txt')
        self.steps_so_far = 0
        # "unaligned" modes count twice the configured batch size per
        # iteration — presumably one image per domain; TODO confirm.
        if "unaligned" in opt.dataset_mode:
            self.batch_size = opt.batch_size * 2
        else:
            self.batch_size = opt.batch_size
        # name -> exponential moving average of seconds per image,
        # maintained by TimeMeasurement.
        self.time_measurements = {}

        automatically_find_resume_iter = (
            opt.isTrain and opt.continue_train
            and opt.resume_iter == "latest" and opt.pretrained_name is None)
        # isdecimal() accepts exactly the digits int() can parse;
        # isnumeric() also accepts e.g. "½", which would make int() raise.
        resume_at_specified_iter = (
            opt.isTrain and opt.continue_train
            and opt.resume_iter.replace("k", "").isdecimal())

        if automatically_find_resume_iter:
            try:
                # int() turns loadtxt's 0-d array into a plain int and
                # rejects records holding zero or several values.
                self.steps_so_far = int(np.loadtxt(
                    self.iter_record_path, delimiter=',', dtype=int))
                print('Resuming from iteration %d' % (self.steps_so_far))
            except Exception:
                # Deliberate best effort: a missing or unreadable record
                # restarts training from 0.
                print('Could not load iteration record at %s. '
                      'Starting from beginning.' % self.iter_record_path)
        elif resume_at_specified_iter:
            steps = int(opt.resume_iter.replace("k", ""))
            if "k" in opt.resume_iter:
                steps *= 1000
            self.steps_so_far = steps
        else:
            self.steps_so_far = 0

    def record_one_iteration(self):
        """Advance the counter by one batch.

        When :meth:`needs_saving` is true, the count *before* advancing is
        first written to ``iter_record_path``.
        """
        if self.needs_saving():
            np.savetxt(self.iter_record_path,
                       [self.steps_so_far], delimiter=',', fmt='%d')
            print("Saved current iter count at %s" % self.iter_record_path)
        self.steps_so_far += self.batch_size

    def _at_period(self, freq):
        """Return True while ``steps_so_far`` is less than ``batch_size``
        images past a multiple of ``freq``.

        Because the counter advances by ``batch_size``, this is true for
        exactly one iteration per ``freq`` images when
        ``freq >= batch_size``.
        """
        return (self.steps_so_far % freq) < self.batch_size

    def needs_saving(self):
        """Return True on the iteration that reaches a ``save_freq`` boundary
        (including the very first one, at count 0)."""
        return self._at_period(self.opt.save_freq)

    def needs_evaluation(self):
        """Return True on each ``evaluation_freq`` boundary, skipping count 0."""
        return (self.steps_so_far >= self.opt.evaluation_freq) and \
            self._at_period(self.opt.evaluation_freq)

    def needs_printing(self):
        """Return True on each ``print_freq`` boundary."""
        return self._at_period(self.opt.print_freq)

    def needs_displaying(self):
        """Return True on each ``display_freq`` boundary."""
        return self._at_period(self.opt.display_freq)

    def completed_training(self):
        """Return True once at least ``total_nimgs`` images have been counted."""
        return (self.steps_so_far >= self.opt.total_nimgs)

    class TimeMeasurement:
        """Context manager timing its body for ``parent``.

        On exit it records seconds per image (elapsed time divided by
        ``parent.batch_size``) under ``name`` in ``parent.time_measurements``.
        The first sample is stored as-is; later samples are blended in as an
        exponential moving average with weight 0.02.
        """

        def __init__(self, name, parent):
            self.name = name
            self.parent = parent

        def __enter__(self):
            # perf_counter is monotonic and high-resolution, unlike
            # time.time(), which can jump when the wall clock is adjusted.
            self.start_time = time.perf_counter()
            return self

        def __exit__(self, exc_type, exc_value, traceback):
            # CUDA work is asynchronous, so wait for it to finish before
            # reading the clock. torch.cuda.synchronize() raises on hosts
            # without CUDA, so only call it when CUDA is available.
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            elapsed_time = ((time.perf_counter() - self.start_time)
                            / self.parent.batch_size)
            measurements = self.parent.time_measurements
            if self.name not in measurements:
                measurements[self.name] = elapsed_time
            else:
                prev_time = measurements[self.name]
                measurements[self.name] = prev_time * 0.98 + elapsed_time * 0.02
            # Returning None lets any exception raised in the body propagate.

    def time_measurement(self, name):
        """Return a context manager that times its body under ``name``."""
        return IterationCounter.TimeMeasurement(name, self)