| |
| |
| |
| |
| |
|
|
| import argparse |
| import os |
| import re |
| from itertools import groupby |
|
|
| import matplotlib.pyplot as plt |
| import numpy as np |
|
|
| from mmpretrain.utils import load_json_log |
|
|
|
|
def cal_train_time(log_dicts, args):
    """Compute the average time per training iteration.

    For every log, print the slowest and fastest epoch (by mean iteration
    time), the standard deviation of the per-epoch means, and the overall
    mean iteration time. Results are printed; nothing is returned.

    Args:
        log_dicts (list[dict]): One parsed log per file. Each dict must have
            a ``'train'`` list whose records carry a ``'time'`` value and,
            optionally, an ``'epoch'`` value.
        args (argparse.Namespace): Must provide ``json_logs`` (paths, in the
            same order as ``log_dicts``) and ``include_outliers`` (bool).
    """
    for i, log_dict in enumerate(log_dicts):
        print(f'{"-" * 5}Analyze train time of {args.json_logs[i]}{"-" * 5}')
        train_logs = log_dict['train']

        # An empty train log has no first record to inspect and no mean to
        # compute, so report it and move on instead of raising IndexError.
        if not train_logs:
            print('no train records found, skip')
            print()
            continue

        if 'epoch' in train_logs[0]:
            epoch_ids = []
            epoch_ave_times = []
            # groupby only merges consecutive records, so the records are
            # expected to be ordered by epoch.
            for epoch, logs in groupby(train_logs, lambda log: log['epoch']):
                all_time = np.array([log['time'] for log in logs])
                # The first record of an epoch is treated as an outlier and
                # dropped, unless it is the only record: an empty slice
                # would turn the epoch mean into NaN.
                if not args.include_outliers and all_time.size > 1:
                    all_time = all_time[1:]
                epoch_ids.append(epoch)
                epoch_ave_times.append(all_time.mean())
            epoch_ave_times = np.array(epoch_ave_times)
            slowest_epoch = epoch_ave_times.argmax()
            fastest_epoch = epoch_ave_times.argmin()
            std_over_epoch = epoch_ave_times.std()
            # Report the logged epoch number rather than the group position
            # so that logs not starting at epoch 1 are labelled correctly.
            print(f'slowest epoch {epoch_ids[slowest_epoch]}, '
                  f'average time is {epoch_ave_times[slowest_epoch]:.4f}')
            print(f'fastest epoch {epoch_ids[fastest_epoch]}, '
                  f'average time is {epoch_ave_times[fastest_epoch]:.4f}')
            print(f'time std over epochs is {std_over_epoch:.4f}')

        # The overall average always uses every record, outliers included.
        avg_iter_time = np.array([log['time'] for log in train_logs]).mean()
        print(f'average iter time: {avg_iter_time:.4f} s/iter')
        print()
|
|
|
|
def get_legends(args):
    """Return one legend label per (log file, metric) pair.

    If ``args.legend`` is None, labels are generated as
    ``{filename}_{key}``, where ``filename`` is the base name of the log
    path with a trailing ``.json`` and then a trailing ``.log`` removed.

    Args:
        args (argparse.Namespace): Must provide ``legend`` (list of str or
            None), ``json_logs`` (list of paths) and ``keys`` (list of
            metric names).

    Returns:
        list[str]: Labels ordered by log file first, then by metric.

    Raises:
        ValueError: If the number of labels is not
            ``len(args.json_logs) * len(args.keys)``.
    """
    legend = args.legend
    if legend is None:
        legend = []
        for json_log in args.json_logs:
            # The base name depends only on the file, so compute it once
            # per file rather than once per metric.
            basename = os.path.basename(json_log)
            # Strip the suffixes only when they are really present instead
            # of blindly cutting a fixed number of characters.
            for suffix in ('.json', '.log'):
                if basename.endswith(suffix):
                    basename = basename[:-len(suffix)]
            legend.extend(f'{basename}_{metric}' for metric in args.keys)

    # Explicit check rather than `assert`, which is removed under `python -O`.
    expected = len(args.json_logs) * len(args.keys)
    if len(legend) != expected:
        raise ValueError(
            f'expected {expected} legend entries (one per log file and '
            f'key), got {len(legend)}')
    return legend
|
|
|
|
def plot_phase_train(metric, train_logs, curve_label):
    """Draw one training curve of ``metric`` on the current figure.

    The x axis uses the logged ``'step'`` values. When the first record has
    an ``'epoch'`` entry, the steps are divided by (last step / last epoch)
    so that the axis reads in epochs instead of iterations.

    Args:
        metric (str): Key of the value to plot in every training record.
        train_logs (list[dict]): Non-empty list of training records.
        curve_label (str): Legend label of the curve.
    """
    steps = np.array([record['step'] for record in train_logs])
    values = np.array([record[metric] for record in train_logs])

    axis_name = 'Iters'
    if 'epoch' in train_logs[0]:
        last_record = train_logs[-1]
        # Number of steps per epoch, estimated from the final record.
        steps_per_epoch = last_record['step'] / last_record['epoch']
        steps = steps / steps_per_epoch
        axis_name = 'Epochs'
    plt.xlabel(axis_name)

    plt.plot(steps, values, label=curve_label, linewidth=0.75)
|
|
|
|
def plot_phase_val(metric, val_logs, curve_label):
    """Draw one validation curve of ``metric`` against the logged step.

    Args:
        metric (str): Key of the value to plot in every validation record.
        val_logs (list[dict]): Validation records, each with a ``'step'``
            entry and a ``metric`` entry.
        curve_label (str): Legend label of the curve.
    """
    steps = np.array([record['step'] for record in val_logs])
    values = np.array([record[metric] for record in val_logs])

    line_options = dict(label=curve_label, linewidth=0.75)
    plt.xlabel('Steps')
    plt.plot(steps, values, **line_options)
|
|
|
|
def plot_curve_helper(log_dicts, metrics, args, legend):
    """Plot one curve per (log, metric) pair on the current figure.

    A metric is looked up in the validation records first and in the
    training records second; the keys available in each phase are taken
    from the first record of that phase.

    Args:
        log_dicts (list[dict]): Parsed logs, each with ``'train'`` and
            ``'val'`` lists of records.
        metrics (list[str]): Names of the metrics to plot.
        args (argparse.Namespace): Must provide ``json_logs``, the paths
            matching ``log_dicts`` (used for progress messages only).
        legend (list[str]): Labels ordered by log first, then by metric.

    Raises:
        ValueError: If a metric is found in neither the validation nor the
            training records of a log.
    """
    num_metrics = len(metrics)
    plotted_any = False
    for i, log_dict in enumerate(log_dicts):
        json_log = args.json_logs[i]
        train_logs = log_dict['train']
        val_logs = log_dict['val']

        # The plottable keys depend only on the log, not on the metric, so
        # build them once per log. Empty phases yield an empty *set* (the
        # bookkeeping fields 'step'/'epoch' are never plottable).
        train_keys = (
            set(train_logs[0].keys()) - {'step', 'epoch'}
            if train_logs else set())
        val_keys = set(val_logs[0].keys()) - {'step'} if val_logs else set()

        for j, key in enumerate(metrics):
            print(f'plot curve of {json_log}, metric is {key}')
            curve_label = legend[i * num_metrics + j]

            if key in val_keys:
                plot_phase_val(key, val_logs, curve_label)
            elif key in train_keys:
                plot_phase_train(key, train_logs, curve_label)
            else:
                raise ValueError(
                    f'Invalid key "{key}", please choose from '
                    f'{train_keys | val_keys}.')
            plotted_any = True

    # Build the legend once, after every labelled curve has been added;
    # skip it when nothing was drawn so no empty legend is requested.
    if plotted_any:
        plt.legend()
|
|
|
|
def plot_curve(log_dicts, args):
    """Plot the requested metric curves and show or save the figure.

    Args:
        log_dicts (list[dict]): Parsed logs, forwarded to
            ``plot_curve_helper``.
        args (argparse.Namespace): Must provide ``style``, ``window_size``
            (``'W*H'`` string, or empty to use matplotlib's default figure
            size), ``keys``, ``legend``, ``json_logs``, ``title`` and
            ``out``.

    Raises:
        ValueError: If ``args.window_size`` is non-empty but is not two
            integers separated by ``*``.
    """
    # seaborn is optional: apply its style when available, otherwise keep
    # matplotlib's default look. Only the import itself is guarded.
    try:
        import seaborn as sns
    except ImportError:
        sns = None
    if sns is not None:
        sns.set_style(args.style)

    # Set the figure size. An empty window size is accepted by the argument
    # parser, so fall back to the default size instead of failing on split.
    if args.window_size:
        try:
            wind_w, wind_h = (
                int(value) for value in args.window_size.split('*'))
        except ValueError as err:
            raise ValueError(
                "'window-size' must be in format 'W*H', "
                f'got "{args.window_size}"') from err
        plt.figure(figsize=(wind_w, wind_h))
    else:
        plt.figure()

    # Resolve the legend labels and the metrics to draw.
    legends = get_legends(args)
    metrics = args.keys

    # Draw every curve.
    plot_curve_helper(log_dicts, metrics, args, legends)

    # Add the title, then display the figure or write it to disk.
    if args.title is not None:
        plt.title(args.title)
    if args.out is None:
        plt.show()
    else:
        print(f'save curve to: {args.out}')
        plt.savefig(args.out)
        # Clear the axes after saving so later plotting starts clean.
        plt.cla()
|
|
|
|
def add_plot_parser(subparsers):
    """Register the ``plot_curve`` sub-command and its arguments.

    Args:
        subparsers: Object returned by ``ArgumentParser.add_subparsers``;
            the new sub-parser is attached to it.
    """
    plot_cmd = subparsers.add_parser(
        'plot_curve', help='parser for plotting curves')

    # (flags, keyword options) of every argument, in declaration order.
    argument_specs = [
        (('json_logs', ),
         dict(type=str, nargs='+', help='path of train log in json format')),
        (('--keys', ),
         dict(
             type=str,
             nargs='+',
             default=['loss'],
             help='the metric that you want to plot')),
        (('--title', ), dict(type=str, help='title of figure')),
        (('--legend', ),
         dict(type=str, nargs='+', default=None, help='legend of each plot')),
        (('--style', ),
         dict(
             type=str,
             default='whitegrid',
             help='style of the figure, need `seaborn` package.')),
        (('--out', ), dict(type=str, default=None)),
        (('--window-size', ),
         dict(
             default='12*7',
             help='size of the window to display images, '
             'in format of "$W*$H".')),
    ]
    for flags, options in argument_specs:
        plot_cmd.add_argument(*flags, **options)
|
|
|
|
def add_time_parser(subparsers):
    """Register the ``cal_train_time`` sub-command and its arguments.

    Args:
        subparsers: Object returned by ``ArgumentParser.add_subparsers``;
            the new sub-parser is attached to it.
    """
    time_cmd = subparsers.add_parser(
        'cal_train_time',
        help='parser for computing the average time per training iteration')

    # Positional argument: one or more log files to analyse.
    positional_options = dict(
        type=str, nargs='+', help='path of train log in json format')
    time_cmd.add_argument('json_logs', **positional_options)

    # Flag: keep the first record of every epoch in the per-epoch average.
    outlier_options = dict(
        action='store_true',
        help='include the first value of every epoch when computing '
        'the average time')
    time_cmd.add_argument('--include-outliers', **outlier_options)
|
|
|
|
def parse_args():
    """Parse and validate the command-line arguments.

    Builds a parser with the ``plot_curve`` and ``cal_train_time``
    sub-commands; the chosen one is stored in ``args.task``.

    Returns:
        argparse.Namespace: The parsed arguments.

    Invalid input (no sub-command, or a malformed ``--window-size``) is
    reported through ``parser.error``, which prints the usage message and
    exits the program.
    """
    parser = argparse.ArgumentParser(description='Analyze Json Log')
    # Each sub-command owns its arguments; `task` records the choice.
    subparsers = parser.add_subparsers(dest='task', help='task parser')
    add_plot_parser(subparsers)
    add_time_parser(subparsers)
    args = parser.parse_args()

    # Sub-commands are optional by default; without one the namespace has
    # no `json_logs`, so reject the call here with a clear message.
    if args.task is None:
        parser.error('a task is required: plot_curve or cal_train_time')

    # Only `plot_curve` defines a window size. An empty value is let
    # through unchanged; anything else must match 'W*H' in full, so
    # trailing characters such as '12*7abc' are rejected.
    window_size = getattr(args, 'window_size', '')
    if window_size and re.fullmatch(r'\d+\*\d+', window_size) is None:
        parser.error("'window-size' must be in format 'W*H'.")
    return args
|
|
|
|
def main():
    """Entry point: parse arguments, load the logs and run the chosen task.

    Raises:
        SystemExit: If no task was selected on the command line.
        ValueError: If a given log path does not end with ``.json``.
    """
    args = parse_args()

    # Map each sub-command to the function implementing it.
    handlers = {
        'cal_train_time': cal_train_time,
        'plot_curve': plot_curve,
    }
    handler = handlers.get(args.task)
    if handler is None:
        # No sub-command means there is no `json_logs` attribute to read.
        raise SystemExit(
            'no task given, choose from: ' + ', '.join(sorted(handlers)))

    json_logs = args.json_logs
    # Explicit check rather than `assert`, which is removed under `python -O`.
    for json_log in json_logs:
        if not json_log.endswith('.json'):
            raise ValueError(
                f'expected a log file ending with ".json", got "{json_log}"')

    # NOTE(review): load_json_log comes from mmpretrain.utils; presumably it
    # returns a dict with 'train' and 'val' record lists -- confirm there.
    log_dicts = [load_json_log(json_log) for json_log in json_logs]

    handler(log_dicts, args)
|
|
|
|
# Run the command-line tool only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
|
|