| |
|
|
| import glob |
| import json |
| from functools import partial |
| from multiprocessing import Pool |
| from pathlib import Path |
|
|
| import pandas as pd |
| import plotly.express as px |
| from omegaconf import OmegaConf |
|
|
|
|
| def parallel(func, files, num_workers=16): |
| with Pool(num_workers) as p: |
| results = p.map(func, files) |
| results = list(results) |
| |
| if len(results) > 0 and isinstance(results[0], list): |
| results = [item for sublist in results for item in sublist] |
|
|
| return results |
|
|
|
|
| def parallel_from_glob(func, glob_pattern, num_workers=16): |
| files = glob.glob(glob_pattern, recursive=True) |
| return parallel(partial(func), files, num_workers=num_workers) |
|
|
|
|
| def load_raw_json(path): |
| with open(path, "r") as f: |
| return json.load(f) |
|
|
|
|
| def load_raw_jsonl(jsonl_file): |
| metrics = [] |
|
|
| with open(jsonl_file, "r") as f: |
| for i, line in enumerate(f): |
| try: |
| json_obj = json.loads(line) |
| except json.decoder.JSONDecodeError as e: |
| print(f"Error decoding line {i+1} in file {jsonl_file}") |
|
|
| metrics.append(json_obj) |
|
|
| return metrics |
|
|
|
|
| def get_metrics(path): |
| results_dir = Path(path) |
|
|
| results = load_raw_jsonl(results_dir) |
| params = OmegaConf.load(results_dir.parent / "config.yaml") |
| params = OmegaConf.to_container(params, resolve=True) |
| df = pd.json_normalize( |
| [{"params": params, "metrics": res} for res in results], sep="/" |
| ) |
| return df |
|
|
|
|
| def get_merged_df(path): |
| dfs = parallel_from_glob(get_metrics, path, num_workers=80) |
| return pd.concat(dfs) |
|
|
|
|
| |
| df = get_merged_df("/path/to/metrics.jsonl") |
| fig = px.line( |
| df, |
| x="metrics/global_step", |
| y="metrics/loss/out", |
| ) |
| fig.update_yaxes(type="log") |
| fig.show() |
|
|
| |
|
|