| |
| |
| |
| |
| |
| |
| |
| |
| import time |
| import torch |
|
|
|
|
def str2bool(s):
    """Interpret a string as a boolean flag.

    Args:
        s: Text to interpret; compared case-insensitively.

    Returns:
        True when the lower-cased text is exactly "true" or "1",
        False for any other text.
    """
    truthy_tokens = {'true', '1'}
    lowered = s.lower()
    return lowered in truthy_tokens
|
|
|
|
class Timer:
    """Keyed stopwatch: call start(key), then end(key) to get elapsed seconds.

    Several timers can run at once as long as they use different keys.
    Readings come from ``time.perf_counter()``, a monotonic clock, so the
    measured interval is not affected by system wall-clock adjustments.
    """

    def __init__(self):
        # Maps each running key to the clock reading captured by start().
        self.clock = {}

    def start(self, key="default"):
        """Begin (or restart) timing for ``key``.

        Args:
            key: Identifier of the timer; an existing entry is overwritten.
        """
        self.clock[key] = time.perf_counter()

    def end(self, key="default"):
        """Stop timing for ``key`` and return the elapsed time.

        Args:
            key: Identifier previously passed to start().

        Returns:
            Seconds elapsed since the matching start() call, as a float.

        Raises:
            Exception: If ``key`` has no running timer.
        """
        # Read the clock first so the dictionary lookup is not timed.
        now = time.perf_counter()
        try:
            # pop() both fetches and clears the entry, so a key can only
            # be ended once per start().
            started = self.clock.pop(key)
        except KeyError:
            raise Exception(f"{key} is not in the clock.") from None
        return now - started
| |
|
|
def save_checkpoint(epoch, net_state_dict, optimizer_state_dict, best_score, checkpoint_path, model_path):
    """Write a training checkpoint and a separate model-only file.

    Two files are produced with ``torch.save``:
      * ``checkpoint_path`` receives a dict with the keys 'epoch', 'model',
        'optimizer' and 'best_score'.
      * ``model_path`` receives ``net_state_dict`` on its own.

    Args:
        epoch: Value stored under 'epoch'.
        net_state_dict: Value stored under 'model' and also saved alone.
        optimizer_state_dict: Value stored under 'optimizer'.
        best_score: Value stored under 'best_score'.
        checkpoint_path: Destination of the full checkpoint dict.
        model_path: Destination of the model-only file.
    """
    training_state = dict(
        epoch=epoch,
        model=net_state_dict,
        optimizer=optimizer_state_dict,
        best_score=best_score,
    )
    torch.save(training_state, checkpoint_path)
    torch.save(net_state_dict, model_path)
| |
| |
def load_checkpoint(checkpoint_path, map_location=None):
    """Load an object previously written with ``torch.save``.

    Args:
        checkpoint_path: Location of the saved file.
        map_location: Optional device remapping forwarded to ``torch.load``
            (e.g. "cpu" to open a checkpoint saved on a GPU on a machine
            without one). The default ``None`` keeps ``torch.load``'s own
            default behaviour.

    Returns:
        Whatever object was stored in the file.
    """
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    return torch.load(checkpoint_path, map_location=map_location)
|
|
|
|
def freeze_net_layers(net):
    """Turn off gradient tracking for every parameter of ``net``.

    Args:
        net: Object exposing ``parameters()``; each yielded parameter has
            its ``requires_grad`` flag set to False in place.
    """
    for weight in net.parameters():
        weight.requires_grad_(False)
|
|
|
|
def store_labels(path, labels):
    """Write ``labels`` to ``path``, one per line.

    The labels are joined with newline characters; no newline is written
    after the last label. An existing file at ``path`` is overwritten.

    Args:
        path: Destination file.
        labels: Iterable of strings to write.
    """
    with open(path, mode="w") as handle:
        # Join inside the with-block: the file is opened before the labels
        # are consumed.
        text = "\n".join(labels)
        handle.write(text)
|
|