| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| | import torch
|
| | import matplotlib.pyplot as plt
|
| |
|
| |
|
class Chart:
    """Collect loss curves from training checkpoints and plot them on one figure."""

    def __init__(self):
        # One tuple per add_ckpt call:
        # (line_name, train_steps, train_losses, val_steps, val_losses).
        self.loss_list = []

    @staticmethod
    def _subsample(values, stride):
        """Return the first element of *values* plus every *stride*-th element.

        Picks indices 0, stride-1, 2*stride-1, ... (each at most once), so
        ``stride=5`` reproduces ``[values[0]] + values[4::5]``. An empty
        sequence yields an empty list instead of raising IndexError.
        """
        if not values:
            return []
        # The set de-duplicates index 0 when stride == 1.
        indices = sorted({0, *range(stride - 1, len(values), stride)})
        return [values[i] for i in indices]

    def add_ckpt(self, ckpt_path, line_name, val_stride=5):
        """Load one checkpoint's loss history and register it as a curve.

        Args:
            ckpt_path: Path to a file readable by ``torch.load`` containing the
                keys ``train_step_list``, ``train_loss_list``,
                ``val_step_list`` and ``val_loss_list``.
            line_name: Legend label for this run.
            val_stride: The validation curve is thinned to its first point plus
                every ``val_stride``-th point (indices ``val_stride-1``,
                ``2*val_stride-1``, ...). Defaults to 5.

        Raises:
            ValueError: If ``val_stride`` is less than 1.
            KeyError: If the checkpoint lacks one of the expected keys.
        """
        if val_stride < 1:
            raise ValueError(f"val_stride must be >= 1, got {val_stride}")
        # torch.load is pickle-based: only load checkpoints from trusted sources.
        ckpt = torch.load(ckpt_path, map_location="cpu")
        train_step_list = ckpt["train_step_list"]
        train_loss_list = ckpt["train_loss_list"]
        # Steps and losses are thinned with the same stride so they stay paired.
        val_step_list = self._subsample(ckpt["val_step_list"], val_stride)
        val_loss_list = self._subsample(ckpt["val_loss_list"], val_stride)
        self.loss_list.append((line_name, train_step_list, train_loss_list, val_step_list, val_loss_list))

    def draw(self, save_path, plot_val=True):
        """Render every registered curve into one figure and save it.

        Args:
            save_path: Output path passed to ``Figure.savefig``; matplotlib
                infers the format from the extension (e.g. ``.pdf``).
            plot_val: If True, each training curve is drawn thin and
                translucent, and its thinned validation curve is overlaid,
                thicker, in the same color. If False, only the training curves
                are drawn.
        """
        # Scope the font settings to this figure rather than mutating the
        # global rcParams, which would otherwise leak into later plots.
        style = {
            "font.size": 14,
            "font.family": "serif",
            "font.sans-serif": ["Arial", "DejaVu Sans", "Lucida Grande"],
            "font.serif": ["Times New Roman", "DejaVu Serif"],
        }
        with plt.rc_context(style):
            # 7.766 / 4.8 ~= 1.618, a golden-ratio aspect.
            fig, ax = plt.subplots(figsize=(7.766, 4.8))
            try:
                for name, train_steps, train_losses, val_steps, val_losses in self.loss_list:
                    if plot_val:
                        (line,) = ax.plot(train_steps, train_losses, label=name, linewidth=0.5, alpha=0.5)
                        # Unlabeled, so only the training curve gets a legend entry.
                        ax.plot(val_steps, val_losses, linewidth=1.5, color=line.get_color())
                    else:
                        ax.plot(train_steps, train_losses, label=name, linewidth=1)
                ax.set_xlabel("Step")
                ax.set_ylabel("Loss")
                legend = ax.legend()
                # Legend handles copy the thin plot widths; thicken them for legibility.
                for handle in legend.get_lines():
                    handle.set_linewidth(2)
                fig.savefig(save_path, transparent=True)
            finally:
                # Always release the figure, even if saving fails.
                plt.close(fig)
|
| |
|
| |
|
if __name__ == "__main__":
    # Each run in the ablation, as (checkpoint path, legend label), in plotting order.
    runs = (
        ("output/syncnet/train-2024_10_24-21:03:11/checkpoints/checkpoint-10000.pt", "Dim 512"),
        ("output/syncnet/train-2024_10_25-18:21:59/checkpoints/checkpoint-10000.pt", "Dim 2048"),
        ("output/syncnet/train-2024_10_24-22:37:04/checkpoints/checkpoint-10000.pt", "Dim 4096"),
        ("output/syncnet/train-2024_10_25-02:30:17/checkpoints/checkpoint-10000.pt", "Dim 6144"),
    )

    chart = Chart()
    for ckpt_path, label in runs:
        chart.add_ckpt(ckpt_path, label)
    chart.draw("ablation.pdf", plot_val=True)
|
| |
|