Spaces:
Runtime error
Runtime error
| import argparse | |
| import numpy as np | |
| from matplotlib import pyplot as plt | |
def parse_args():
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace with ``log_file`` (str, required) and
        ``fake_weight`` (float, default 1.4).
    """
    cli = argparse.ArgumentParser(description="Plot losses from log")
    cli.add_argument(
        "--log-file",
        required=True,
        help="path to log file",
    )
    cli.add_argument(
        "--fake-weight",
        type=float,
        default=1.4,
        help="weight for fake loss",
    )
    return cli.parse_args()
def _read_losses(log_file):
    """Collect fake and real loss values from a text log, in file order.

    A line is a loss entry when, after stripping surrounding whitespace, it
    starts with ``fake_loss`` or ``real_loss``; its value is the last
    whitespace-separated token, parsed as a float.

    Returns:
        (real_losses, fake_losses): two lists of floats.

    Raises:
        ValueError: if a matching line's last token is not a number.
    """
    real_losses = []
    fake_losses = []
    with open(log_file, "r", encoding="utf-8") as f:
        for raw in f:
            line = raw.strip()
            # split() (not split(" ")) tolerates tabs and repeated spaces
            # between the label and the value.
            if line.startswith("fake_loss"):
                fake_losses.append(float(line.split()[-1]))
            elif line.startswith("real_loss"):
                real_losses.append(float(line.split()[-1]))
    return real_losses, fake_losses


def _best_indices(loss, count=5, min_index=17):
    """Return indices of the ``count`` smallest values of ``loss`` at or after ``min_index``.

    Early indices are excluded *before* ranking so they cannot crowd later
    candidates out of the top ``count``. The result is ordered from lowest to
    highest loss and is empty when ``loss`` has no index >= ``min_index``.
    """
    candidates = np.arange(min_index, loss.size)
    order = np.argsort(loss[candidates], kind="stable")
    return candidates[order[:count]]


def main():
    """Plot the weighted combination of fake and real losses from a log.

    The plotted curve is ``(fake_weight * fake_loss + real_loss) / 2`` per
    entry. The lowest points (skipping the first 17 entries) are marked in
    red and annotated with their index.
    """
    args = parse_args()
    real_losses, fake_losses = _read_losses(args.log_file)

    # A log cut off mid-epoch can contain one more entry of one kind than the
    # other. Pair up only the entries present for both, instead of failing
    # (or silently broadcasting) in the numpy arithmetic below.
    n = min(len(real_losses), len(fake_losses))
    real = np.asarray(real_losses[:n], dtype=float)
    fake = np.asarray(fake_losses[:n], dtype=float)
    loss = (fake * args.fake_weight + real) / 2

    plt.title("Weighted loss (({}*fake_loss + real_loss)/2)".format(args.fake_weight))

    # Ignore early epochs: the loss is quite noisy there and may spike.
    # Indices 0..16 are excluded (same cut-off as before), but now before
    # picking the top 5, so up to 5 later points are always shown.
    best_loss_idx = _best_indices(loss, count=5, min_index=17)
    plt.scatter(best_loss_idx, loss[best_loss_idx], c="red")
    for idx in best_loss_idx:
        plt.annotate(str(idx), (idx, loss[idx]))
    plt.plot(loss)
    plt.show()
if __name__ == "__main__":
    # Run the plot only when executed as a script, not on import.
    main()