Buckets:
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| """ | |
| Usage: | |
| python probe_animation.py /path/to/xp [/path/to/other/xp ...] | |
| Where `/path/to/xp` is a folder containing `probe/probe.0.jsonl` | |
| """ | |
| import math | |
| from pathlib import Path | |
| import json | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import matplotlib.animation as animation | |
| import matplotlib | |
| from matplotlib import rc | |
| from multiprocessing import Pool | |
| import argparse | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument("probe_folders", help="contains a `probe` folder", nargs="+") | |
| args = parser.parse_args() | |
| rc("animation", html="jshtml") | |
| matplotlib.rcParams["animation.embed_limit"] = 2**128 | |
| DEFAULT_QUANTILES = [0.001, 0.01, 0.05, 0.1, 0.3, 0.5, 0.7, 0.9, 0.95, 0.99, 0.999] | |
| DATAS_PER_FILE = {} | |
| datas = [] | |
| NUM_LAYERS = 0 # TODO: Deduce from files | |
| for f in args.probe_folders: | |
| name = Path(f).name | |
| print("Loading ", name) | |
| file = None | |
| for probe_test in ["probe/probe.0.jsonl", "probe.json"]: | |
| if (Path(f) / probe_test).exists(): | |
| file = Path(f) / probe_test | |
| assert file is not None, "Could not find probe json file" | |
| data = file.read_text() | |
| datas = [] | |
| for line in data.splitlines(): | |
| if line == "": | |
| continue | |
| datas.append(json.loads(line)) | |
| datas[-1].setdefault("quantiles", DEFAULT_QUANTILES) | |
| DATAS_PER_FILE[name] = datas | |
| # Assumes layers have the form | |
| # `FSDP.module.blocks.{LAYER_NUM}...` | |
| NUM_LAYERS = max( | |
| NUM_LAYERS, | |
| 1 | |
| + max( | |
| int(k.split("FSDPTransformer.layers.", 1)[1].split(".")[0]) | |
| for k in datas[0]["data"].keys() | |
| if k.startswith("FSDPTransformer.layers.") and k.endswith("::w") | |
| ), | |
| ) | |
| for d in datas: | |
| d["meta"]["it"] = d["meta"]["global_step"] | |
| assert NUM_LAYERS > 0, "Couldn't deduce the model depth" | |
| def get_mean_quantiles(df, names): | |
| means = [] | |
| quantiles = [[] for _ in df["quantiles"]] | |
| for name in names: | |
| d = df["data"][name] | |
| for qi, qval in enumerate(d["quantiles"]): | |
| quantiles[qi].append(qval) | |
| means.append(d["mean"]) | |
| return np.array(means), np.array(quantiles) | |
| # m, q = get_mean_quantiles(datas[0], ["FSDP.module.blocks.22.mlp.fc2::in", "FSDP.module.blocks.23.mlp.fc2::in"]) | |
| # possible_keys = {k.split("::")[0] for k in datas[0]["data"].keys() if "blocks.0." in k} | |
| # possible_suff = {k.split("::")[1] for k in datas[0]["data"].keys() if "blocks.0." in k} | |
| # print("\n".join([str(x) for x in possible_keys])) | |
| # print("\n".join([str(x) for x in possible_suff])) | |
| COLORS = [f"tab:{x}" for x in ["blue", "orange", "green", "red"]] | |
| class Plotter: | |
| def __init__(self, ax, run, layers, color, timesteps) -> None: | |
| datas = DATAS_PER_FILE[run] | |
| for i in timesteps: | |
| for layer in layers: | |
| if layer not in datas[i]["data"]: | |
| raise ValueError(f"Run `{run}`: layer `{layer}` not found!") | |
| self.x = np.arange(0, len(layers), 1) | |
| (self.mean,) = ax.plot(self.x, self.x, color=color, label=run) | |
| self.fills = [] | |
| self.animate_data = [get_mean_quantiles(datas[i], layers) for i in timesteps] | |
| self.iters = [datas[i]["meta"]["it"] for i in timesteps] | |
| self.minimum = min( | |
| [ | |
| min(np.nanmin(quants[3]), np.nanmin(means)) | |
| for means, quants in self.animate_data | |
| ] | |
| ) | |
| self.maximum = max( | |
| [ | |
| max(np.nanmax(quants[-4]), np.nanmax(means)) | |
| for means, quants in self.animate_data | |
| ] | |
| ) | |
| if not math.isfinite(self.minimum) or not math.isfinite(self.maximum): | |
| raise ValueError( | |
| f"Layer `{layers[0]}`: invalid min/max computed: {self.minimum}/{self.maximum}" | |
| ) | |
| self.ax = ax | |
| self.color = color | |
| self.run_name = run | |
| def animate(self, i): | |
| for f in self.fills: | |
| f.remove() | |
| self.fills.clear() | |
| means, quants = self.animate_data[i] | |
| self.fills += [ | |
| self.ax.fill_between( | |
| x=self.x, y1=quants[j], y2=quants[-1 - j], alpha=0.2, color=self.color | |
| ) | |
| for j in [0, 1, 2, 3, 4, 5] | |
| ] | |
| self.mean.set_ydata(means) | |
| self.mean.set_label(f"{self.run_name} it={self.iters[i]}") | |
| return self.mean | |
| def plot_depth_distr_time(layer_fmt, to_file=None, runs=None, subsample=8): | |
| plt.ioff() | |
| while isinstance(layer_fmt, str) or isinstance(layer_fmt[0], str): | |
| layer_fmt = [layer_fmt] | |
| if layer_fmt[0][0].format(0) not in datas[0]["data"].keys(): | |
| return | |
| if runs is None: | |
| runs = list(DATAS_PER_FILE.keys()) | |
| LAYERS = list(range(NUM_LAYERS)) | |
| nrows = len(layer_fmt) | |
| ncols = len(layer_fmt[0]) | |
| fig, axs = plt.subplots( | |
| nrows=nrows, | |
| ncols=ncols, | |
| squeeze=False, | |
| sharex=True, | |
| figsize=[min(6 * ncols, 14), min(5 * nrows, 8)], | |
| layout="compressed", | |
| ) | |
| timesteps = range(0, len(datas), subsample) | |
| plotters = {} | |
| for i in range(nrows): | |
| for j in range(ncols): | |
| print(layer_fmt[i][j]) | |
| plotters[(i, j)] = [ | |
| Plotter( | |
| axs[i, j], | |
| run, | |
| [layer_fmt[i][j].format(layer) for layer in LAYERS], | |
| color, | |
| timesteps, | |
| ) | |
| for run, color in zip(runs, COLORS) | |
| ] | |
| minimum = min(p.minimum for p in plotters[(i, j)]) | |
| maximum = max(p.maximum for p in plotters[(i, j)]) | |
| axs[i, j].set_ylim([minimum, maximum]) | |
| def animate(t): | |
| out = [] | |
| axs[0, 0].legend(loc="upper right") | |
| for k, v in plotters.items(): | |
| i, j = k | |
| axs[i, j].set_title(layer_fmt[i][j]) | |
| if i == nrows - 1: | |
| axs[i, j].set_xlabel("depth") | |
| for p in v: | |
| out.append(p.animate(t)) | |
| return out | |
| ani = animation.FuncAnimation( | |
| fig, animate, frames=len(timesteps), interval=500, blit=True | |
| ) | |
| if to_file is not None: | |
| print("Writing to", to_file) | |
| Path(to_file).write_text(ani.to_jshtml()) | |
| OUTPUT_FOLDER_NAME = "_AND_".join([Path(f).name for f in args.probe_folders]) | |
| RENDER_OUT_PATH = (Path("render_out") / OUTPUT_FOLDER_NAME).absolute() | |
| RENDER_OUT_PATH.mkdir(parents=True, exist_ok=True) | |
| def _render_attn(): | |
| to_file = RENDER_OUT_PATH / "attention.html" | |
| plot_depth_distr_time( | |
| [ | |
| [ | |
| "FSDPTransformer.layers.{}.attention::attn_logits", | |
| "FSDPTransformer.layers.{}.attention::attn_entropy", | |
| "FSDPTransformer.layers.{}.attention.wo::in", | |
| ], | |
| [ | |
| "FSDPTransformer.layers.{}.attention.wq::out", | |
| "FSDPTransformer.layers.{}.attention.wk::out", | |
| "FSDPTransformer.layers.{}.attention.wv::out", | |
| ], | |
| ], | |
| to_file=to_file, | |
| subsample=1, | |
| ) | |
| def _render_res(): | |
| print("## RESIDUAL") | |
| to_file = RENDER_OUT_PATH / "residual.html" | |
| plot_depth_distr_time( | |
| [ | |
| [ | |
| # "FSDPTransformer.layers.{}::res_ffn", | |
| # "FSDPTransformer.layers.{}::res_attn", | |
| "FSDPTransformer.layers.{}.feed_forward.w3::out", | |
| "FSDPTransformer.layers.{}.attention.wo::out", | |
| ], | |
| [ | |
| "FSDPTransformer.layers.{}::out", | |
| "FSDPTransformer.layers.{}::out.g", | |
| ], | |
| ], | |
| to_file=to_file, | |
| subsample=1, | |
| ) | |
| def _render_to_file(linear_layer: str): | |
| if linear_layer == "__res__": | |
| return _render_res() | |
| if linear_layer == "__attn__": | |
| return _render_attn() | |
| if f"{linear_layer.format(0)}::out" not in datas[0]["data"].keys(): | |
| return | |
| print("## ", linear_layer) | |
| to_file = linear_layer.split("{}", 1)[-1].replace(".", "") + ".html" | |
| to_file = RENDER_OUT_PATH / to_file | |
| plot_depth_distr_time( | |
| [ | |
| [f"{linear_layer}::{suffix}" for suffix in ["in", "w", "out"]], | |
| [f"{linear_layer}::{suffix}.g" for suffix in ["in", "w", "out"]], | |
| ], | |
| to_file=to_file, | |
| subsample=1, | |
| ) | |
| with Pool(10) as p: | |
| p.map( | |
| _render_to_file, | |
| [ | |
| "__attn__", | |
| # DINO | |
| "FSDP.module.blocks.{}.mlp.fc1", | |
| "FSDP.module.blocks.{}.mlp.fc2", | |
| "FSDP.module.blocks.{}.mlp.qkv", | |
| "FSDP.module.blocks.{}.mlp.proj", | |
| # linguas | |
| "FSDPTransformer.layers.{}.attention.wq", | |
| "FSDPTransformer.layers.{}.attention.wk", | |
| "FSDPTransformer.layers.{}.attention.wv", | |
| "FSDPTransformer.layers.{}.attention.wo", | |
| "FSDPTransformer.layers.{}.feed_forward.w1", | |
| "FSDPTransformer.layers.{}.feed_forward.w2", | |
| "FSDPTransformer.layers.{}.feed_forward.w3", | |
| "__res__", | |
| ], | |
| ) | |
Xet Storage Details
- Size:
- 9.1 kB
- Xet hash:
- 073d175bca0015324a05d4ea683d308a631330918e4cf934643040b749d0f67c
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.