#!/usr/bin/env python3 """Print progressive PPL stats as CSV from progressive_metadata.json. Expected (current) metadata shape: - data["eval"]["pre_ppl"] - data["cycles"][i]["redistrib_post_ppl"] (optional; legacy key) - data["cycles"][i]["comm_post_ppl"] (optional; current key) - data["cycles"][i]["distill_post_ppl"] - data["cycles"][i]["lora_post_ppl"] (typically only set on the last cycle) - data["cycles"][i]["post_ppl"] """ import argparse import csv import json import os import shlex import sys from typing import Any, List, Optional def _cell(value: Any) -> str: if value is None: return "" if isinstance(value, dict): if not value: return "" return ";".join(str(value[key]) for key in sorted(value)) if isinstance(value, (list, tuple)): return ";".join(str(item) for item in value) return str(value) def _read_run_command_tokens(metadata_path: str) -> Optional[List[str]]: meta_dir = os.path.dirname(os.path.abspath(metadata_path)) run_args_path = os.path.join(meta_dir, "run_args.txt") if not os.path.exists(run_args_path): return None try: with open(run_args_path, "r", encoding="utf-8") as handle: lines = handle.read().splitlines() except OSError: return None cmd_line = None for idx, line in enumerate(lines): if line.strip() == "command:": if idx + 1 < len(lines): cmd_line = lines[idx + 1].strip() break if not cmd_line: return None try: return shlex.split(cmd_line) except ValueError: return None def _parse_exclude_pairs_from_tokens(tokens: List[str]) -> Optional[List[int]]: start = None for idx, tok in enumerate(tokens): if tok in ("--exclude_pairs", "--exclude_layers"): start = idx + 1 break if start is None: return None raw: List[int] = [] for tok in tokens[start:]: if tok.startswith("--"): break # Legacy bug: run_args.txt used to print "python" before every token. if tok == "python": continue try: raw.append(int(tok)) except ValueError: continue return raw def _normalize_excluded_pairs(raw: List[int], num_pairs: int) -> List[int]: exclude: List[int] = [] for idx in raw: if idx < 0: idx = num_pairs + idx if 0 <= idx < num_pairs: exclude.append(idx) return sorted(set(exclude)) def _read_excluded_pairs_from_cycle_meta(meta_dir: str, cycle_idx: int) -> Optional[List[int]]: path = os.path.join(meta_dir, f"cycle_{cycle_idx}", "cycle_metadata.json") try: with open(path, "r", encoding="utf-8") as handle: cycle_meta = json.load(handle) except (FileNotFoundError, json.JSONDecodeError, OSError): return None dwce_meta = cycle_meta.get("dwce_meta") or {} excluded = dwce_meta.get("excluded_pairs") if isinstance(excluded, list) and all(isinstance(x, int) for x in excluded): return excluded return None def _num_pairs_for_cycle(data: dict, meta_dir: str, cycle_idx: int) -> Optional[int]: num_progressive = data.get("num_progressive") final_num_layers = data.get("final_num_layers") if isinstance(num_progressive, int) and isinstance(final_num_layers, int): initial_layers = final_num_layers + num_progressive return max(initial_layers - cycle_idx, 0) cycle_meta_path = os.path.join(meta_dir, f"cycle_{cycle_idx}", "cycle_metadata.json") try: with open(cycle_meta_path, "r", encoding="utf-8") as handle: cycle_meta = json.load(handle) except (FileNotFoundError, json.JSONDecodeError, OSError): return None num_layers_before = cycle_meta.get("num_layers_before") if isinstance(num_layers_before, int): return max(num_layers_before - 1, 0) return None def main() -> None: parser = argparse.ArgumentParser( description="Print progressive PPL values as CSV from progressive_metadata.json" ) parser.add_argument("path", help="Path to progressive_metadata.json") args = parser.parse_args() try: with open(args.path, "r", encoding="utf-8") as handle: data = json.load(handle) except FileNotFoundError as exc: raise SystemExit(f"File not found: {args.path}") from exc except json.JSONDecodeError as exc: raise SystemExit(f"Invalid JSON: {args.path}") from exc meta_dir = os.path.dirname(os.path.abspath(args.path)) run_tokens = _read_run_command_tokens(args.path) raw_exclude = ( _parse_exclude_pairs_from_tokens(run_tokens) if run_tokens is not None else None ) writer = csv.writer(sys.stdout) writer.writerow( [ "cycle", "layer_merged", "layer_pair", "excluded_pairs", "redistrib_post_ppl", "distill_post_ppl", "lora_post_ppl", "post_ppl", ] ) pre_ppl = data.get("eval", {}).get("pre_ppl") if pre_ppl is not None: writer.writerow(["pre", "", "", "", "", "", "", _cell(pre_ppl)]) cycles = data.get("cycles") or data.get("cycle_summaries") or [] for cycle in cycles: cycle_idx = cycle.get("cycle", "") layer_merged = cycle.get("layer_merged") layer_pair = "" if isinstance(layer_merged, int): layer_pair = f"{layer_merged}-{layer_merged + 1}" excluded_pairs = _read_excluded_pairs_from_cycle_meta( meta_dir, cycle_idx if isinstance(cycle_idx, int) else -1 ) if excluded_pairs is None and raw_exclude is not None and isinstance(cycle_idx, int): num_pairs = _num_pairs_for_cycle(data, meta_dir, cycle_idx) if isinstance(num_pairs, int): excluded_pairs = _normalize_excluded_pairs(raw_exclude, num_pairs) redistrib_post_ppl = cycle.get("redistrib_post_ppl") if redistrib_post_ppl is None: redistrib_post_ppl = cycle.get("comm_post_ppl") writer.writerow( [ cycle_idx, layer_merged if layer_merged is not None else "", layer_pair, _cell(excluded_pairs), _cell(redistrib_post_ppl), _cell(cycle.get("distill_post_ppl")), _cell(cycle.get("lora_post_ppl")), _cell(cycle.get("post_ppl")), ] ) if __name__ == "__main__": main()