Spaces:
Runtime error
Runtime error
| # Copyright (C) 2021. Huawei Technologies Co., Ltd. All rights reserved. | |
| # This program is free software; you can redistribute it and/or modify | |
| # it under the terms of the MIT License. | |
| # This program is distributed in the hope that it will be useful, | |
| # but WITHOUT ANY WARRANTY; without even the implied warranty of | |
| # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |
| # MIT License for more details. | |
| import os | |
| import glob | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import torch | |
| def intersperse(lst, item): | |
| # Adds blank symbol | |
| result = [item] * (len(lst) * 2 + 1) | |
| result[1::2] = lst | |
| return result | |
| def parse_filelist(filelist_path, split_char="|"): | |
| with open(filelist_path, encoding='utf-8') as f: | |
| filepaths_and_text = [line.strip().split(split_char) for line in f] | |
| return filepaths_and_text | |
| def latest_checkpoint_path(dir_path, regex="grad_*.pt"): | |
| f_list = glob.glob(os.path.join(dir_path, regex)) | |
| f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f)))) | |
| x = f_list[-1] | |
| return x | |
| def load_checkpoint(logdir, model, num=None): | |
| if num is None: | |
| model_path = latest_checkpoint_path(logdir, regex="grad_*.pt") | |
| else: | |
| model_path = os.path.join(logdir, f"grad_{num}.pt") | |
| print(f'Loading checkpoint {model_path}...') | |
| model_dict = torch.load(model_path, map_location=lambda loc, storage: loc) | |
| model.load_state_dict(model_dict, strict=False) | |
| return model | |
| def save_figure_to_numpy(fig): | |
| data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') | |
| data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) | |
| return data | |
| def plot_tensor(tensor): | |
| plt.style.use('default') | |
| fig, ax = plt.subplots(figsize=(12, 3)) | |
| im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation='none') | |
| plt.colorbar(im, ax=ax) | |
| plt.tight_layout() | |
| fig.canvas.draw() | |
| data = save_figure_to_numpy(fig) | |
| plt.close() | |
| return data | |
| def save_plot(tensor, savepath): | |
| plt.style.use('default') | |
| fig, ax = plt.subplots(figsize=(12, 3)) | |
| im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation='none') | |
| plt.colorbar(im, ax=ax) | |
| plt.tight_layout() | |
| fig.canvas.draw() | |
| plt.savefig(savepath) | |
| plt.close() | |
| return | |