| from typing import List, Optional |
| import pathlib |
| import pandas as pd |
| import numpy as np |
| import numba |
| import click |
| import time |
| import collections |
| import json |
| import wandb |
| import yaml |
| import numbers |
| import scipy.ndimage as sn |
| from diffusion_policy.common.json_logger import read_json_log, JsonLogger |
| import logging |
|
|
@numba.jit(nopython=True)
def get_indexed_window_average(
        arr: np.ndarray, idxs: np.ndarray, window_size: int):
    """Average a fixed-size slice of ``arr`` for every entry of ``idxs``.

    For each index ``idx`` the slice starts at ``max(idx - window_size, 0)``
    and ends at ``min(start + window_size, len(arr))``, i.e. the window sits
    just before ``idx`` and is clamped to the array bounds.

    Args:
        arr: Array to average over (sliced along its first axis).
        idxs: Positions into ``arr``; one average is produced per entry.
        window_size: Maximum number of elements in each window.

    Returns:
        Array shaped like ``idxs`` with the dtype of ``arr``.
    """
    n_samples = arr.shape[0]
    averages = np.zeros(idxs.shape, dtype=arr.dtype)
    for pos, anchor in enumerate(idxs):
        # Clamp the window to [0, n_samples).
        lo = max(anchor - window_size, 0)
        hi = min(lo + window_size, n_samples)
        averages[pos] = arr[lo:hi].mean()
    return averages
|
|
|
|
def compute_metrics(log_df: pd.DataFrame, key: str,
        end_step: Optional[int]=None,
        k_min_loss: int=10,
        k_around_max: int=10,
        max_k_window: int=10,
        replace_slash: bool=True,
        ):
    """Summarize one logged metric column of a single training run.

    Only rows where ``key`` is non-null are considered. The function also
    reads the ``train_loss`` and ``val_loss`` columns at those same rows.

    Args:
        log_df: Training log, one row per logged step.
        key: Name of the metric column to summarize.
        end_step: If given, only the first ``end_step`` rows are used.
        k_min_loss: Number of lowest-loss rows averaged for the
            ``k_min_*_loss`` statistics.
        k_around_max: Width of the window around the maximum that is
            averaged for ``k_around_max``.
        max_k_window: Width of the moving-average window for
            ``max_k_window``.
        replace_slash: If true, ``/`` in ``key`` is replaced by ``_`` in
            the names of the returned statistics.

    Returns:
        Dict mapping ``'<statistic>/<key>'`` to a value. Empty if ``key`` is
        not a column of ``log_df`` or has no non-null entry in range.
    """
    if key not in log_df:
        return dict()

    if end_step is not None:
        log_df = log_df.iloc[:end_step]
    is_key = ~pd.isnull(log_df[key])
    is_key_idxs = is_key.index[is_key].to_numpy()
    if len(is_key_idxs) == 0:
        return dict()

    # Metric values and the losses logged on the same rows.
    key_data = log_df[key][is_key].to_numpy()
    # NOTE(review): np.argmin returns the position of the first NaN when the
    # array contains NaN -- confirm both losses are always logged on the rows
    # where `key` is logged.
    train_loss = log_df['train_loss'][is_key].to_numpy()
    val_loss = log_df['val_loss'][is_key].to_numpy()

    result = dict()

    log_key = key
    if replace_slash:
        log_key = key.replace('/', '_')

    # Best value seen so far.
    max_value = np.max(key_data)
    result['max/'+log_key] = max_value

    # Mean over a window of up to k_around_max points around the maximum.
    # The right half is rounded up so the maximum itself is always inside
    # the window (with floor division a window of width 1 excluded it);
    # even widths are unaffected.
    max_idx = np.argmax(key_data)
    end = min(max_idx + (k_around_max + 1) // 2, len(key_data))
    start = max(end - k_around_max, 0)
    k_around_max_value = np.mean(key_data[start:end])
    result['k_around_max/'+log_key] = k_around_max_value

    # Maximum of the moving average. uniform_filter1d returns the dtype of
    # its input, so non-float data is cast first to avoid truncated means.
    if np.issubdtype(key_data.dtype, np.floating):
        filter_input = key_data
    else:
        filter_input = key_data.astype(np.float64)
    k_window_value = sn.uniform_filter1d(
        filter_input, size=max_k_window, axis=0, mode='nearest')
    max_k_window_value = np.max(k_window_value)
    result['max_k_window/'+log_key] = max_k_window_value

    # Metric value at the row with the lowest train loss.
    min_idx = np.argmin(train_loss)
    min_train_loss_value = key_data[min_idx]
    result['min_train_loss/'+log_key] = min_train_loss_value

    # Metric value at the row with the lowest validation loss.
    min_idx = np.argmin(val_loss)
    min_val_loss_value = key_data[min_idx]
    result['min_val_loss/'+log_key] = min_val_loss_value

    # Mean metric over the k_min_loss rows with the lowest train loss.
    min_loss_idxs = np.argsort(train_loss)[:k_min_loss]
    k_min_train_loss_value = np.mean(key_data[min_loss_idxs])
    result['k_min_train_loss/'+log_key] = k_min_train_loss_value

    # Mean metric over the k_min_loss rows with the lowest validation loss.
    min_loss_idxs = np.argsort(val_loss)[:k_min_loss]
    k_min_val_loss_value = np.mean(key_data[min_loss_idxs])
    result['k_min_val_loss/'+log_key] = k_min_val_loss_value

    # Most recent metric value.
    result['last/'+log_key] = key_data[-1]

    # DataFrame index label of the most recent row that has the metric.
    result['metric_global_step/'+log_key] = is_key_idxs[-1]
    return result
|
|
|
|
def compute_metrics_agg(
        log_dfs: List[pd.DataFrame],
        key: str, end_step:int,
        **kwargs):
    """Average the per-run statistics of ``key`` across several runs.

    Args:
        log_dfs: One training log per run.
        key: Metric column passed on to ``compute_metrics``.
        end_step: Row cut-off passed on to ``compute_metrics``.
        **kwargs: Extra keyword arguments forwarded to ``compute_metrics``.

    Returns:
        Dict mapping each statistic name to its mean over the runs that
        reported it. Statistics whose name starts with
        ``metric_global_step`` are converted to ``int``.
    """
    # Gather every run's value for each statistic name.
    per_stat = collections.defaultdict(list)
    for run_df in log_dfs:
        run_stats = compute_metrics(
            run_df, key=key, end_step=end_step, **kwargs)
        for stat_name, stat_value in run_stats.items():
            per_stat[stat_name].append(stat_value)

    def _reduce(stat_name, values):
        # Step counters stay integral after averaging.
        mean_value = np.mean(values)
        if stat_name.startswith('metric_global_step'):
            return int(mean_value)
        return mean_value

    return {
        stat_name: _reduce(stat_name, values)
        for stat_name, values in per_stat.items()
    }
|
|
|
|
@click.command()
@click.option('--input', '-i', required=True, help='Root logging dir, contains train_* dirs')
@click.option('--key', '-k', multiple=True, default=['test/mean_score'])
@click.option('--interval', default=10, type=float)
@click.option('--replace_slash', default=True, type=bool)
@click.option('--index_key', '-ik', default='global_step')
@click.option('--use_wandb', '-w', is_flag=True, default=False)
@click.option('--project', default=None)
@click.option('--name', default=None)
@click.option('--id', default=None)
@click.option('--group', default=None)
def main(
        input,
        key,
        interval,
        replace_slash,
        index_key,
        use_wandb,
        # optional overrides for the wandb settings read from config.yaml
        project,
        name,
        id,
        group):
    """Monitor the train_* runs under --input and log aggregated metrics.

    Polls each run's logs.json.txt every --interval seconds, averages the
    summary statistics of every --key across runs and appends them to
    metrics/logs.json.txt. The latest values are also written to
    metrics/metrics.json and, with --use_wandb, logged to wandb.
    Runs until interrupted.
    """
    root_dir = pathlib.Path(input)
    assert root_dir.is_dir()
    metrics_dir = root_dir.joinpath('metrics')
    metrics_dir.mkdir(exist_ok=True)

    # Log both to metrics/metrics.log and to the console.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(message)s",
        handlers=[
            logging.FileHandler(str(metrics_dir.joinpath("metrics.log"))),
            logging.StreamHandler()
        ]
    )

    # The run directories are discovered once, at start-up.
    train_dirs = list(root_dir.glob('train_*'))
    if not train_dirs:
        # Without any run there is nothing to aggregate; fail with a clear
        # message instead of an opaque error further down.
        raise click.ClickException(
            f"No train_* directories found in {root_dir}")
    log_files = [x.joinpath('logs.json.txt') for x in train_dirs]
    logging.info("Monitor waiting for log files!")
    while True:
        # Block until every run has created its log file.
        if all(log_file.is_file() for log_file in log_files):
            break
        time.sleep(1.0)
    logging.info("All log files ready!")

    # Output and config paths.
    metric_log_path = metrics_dir.joinpath('logs.json.txt')
    metric_path = metrics_dir.joinpath('metrics.json')
    config_path = root_dir.joinpath('config.yaml')

    # Load the config; the context manager closes the file handle.
    with config_path.open('r') as f:
        config = yaml.safe_load(f)

    # Optionally start a wandb run; command line values override the
    # 'logging' section of the config.
    wandb_run = None
    if use_wandb:
        wandb_kwargs = config['logging']
        if project is not None:
            wandb_kwargs['project'] = project
        if id is not None:
            wandb_kwargs['id'] = id
        if name is not None:
            wandb_kwargs['name'] = name
        if group is not None:
            wandb_kwargs['group'] = group
        wandb_kwargs['resume'] = True
        wandb_run = wandb.init(
            dir=str(metrics_dir),
            config=config,
            **wandb_kwargs
        )
        wandb.config.update(
            {
                "output_dir": str(root_dir),
            }
        )

    with JsonLogger(metric_log_path) as json_logger:
        # Last entry already written to the metric log, if any.
        last_log = json_logger.get_last_log()
        while True:
            # Re-read every run's log on each poll.
            log_dfs = [read_json_log(str(x), required_keys=key) for x in log_files]

            # Position (in the first run's log) of the last data point that
            # was already written to the metric log.
            last_log_idx = -1
            if last_log is not None:
                last_log_idx = log_dfs[0].index[log_dfs[0][index_key] <= last_log[index_key]][-1]

            start_idx = last_log_idx + 1
            # Only rows present in every run can be aggregated. A generator
            # is used because min(*[n]) fails when there is a single run.
            end_idx = min(len(x) for x in log_dfs)

            # Emit one aggregated entry per not-yet-logged row.
            for this_idx in range(start_idx, end_idx):
                all_metrics = dict()
                global_step = log_dfs[0]['global_step'][this_idx]
                epoch = log_dfs[0]['epoch'][this_idx]
                all_metrics['global_step'] = global_step
                all_metrics['epoch'] = epoch
                for k in key:
                    metrics = compute_metrics_agg(
                        log_dfs=log_dfs, key=k, end_step=this_idx+1,
                        replace_slash=replace_slash)
                    all_metrics.update(metrics)

                # Convert numeric values to plain int/float so they can be
                # serialized to JSON; non-numeric values are dropped.
                old_metrics = all_metrics
                all_metrics = dict()
                for k, v in old_metrics.items():
                    if isinstance(v, numbers.Integral):
                        all_metrics[k] = int(v)
                    elif isinstance(v, numbers.Number):
                        all_metrics[k] = float(v)

                has_update = all_metrics != last_log
                if has_update:
                    last_log = all_metrics
                    json_logger.log(all_metrics)

                    # metrics.json always holds the most recent entry.
                    with metric_path.open('w') as f:
                        json.dump(all_metrics, f, sort_keys=True, indent=2)

                    if wandb_run is not None:
                        wandb_run.log(all_metrics, step=all_metrics[index_key])

                    logging.info(f"Metrics logged at step {all_metrics[index_key]}")

            time.sleep(interval)
|
|
|
|
# Run the click command only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
|