"""Print progressive PPL stats as CSV from progressive_metadata.json.

Expected (current) metadata shape:
- data["eval"]["pre_ppl"]
- data["cycles"][i]["redistrib_post_ppl"] (optional; legacy key)
- data["cycles"][i]["comm_post_ppl"] (optional; current key, reported in the
  redistrib_post_ppl column when the legacy key is absent)
- data["cycles"][i]["distill_post_ppl"]
- data["cycles"][i]["lora_post_ppl"] (typically only set on the last cycle)
- data["cycles"][i]["post_ppl"]

If data["cycles"] is missing or empty, data["cycle_summaries"] is read instead.
"""
|
|
| import argparse |
| import csv |
| import json |
| import os |
| import shlex |
| import sys |
| from typing import Any, List, Optional |
|
|
|
|
| def _cell(value: Any) -> str: |
| if value is None: |
| return "" |
| if isinstance(value, dict): |
| if not value: |
| return "" |
| return ";".join(str(value[key]) for key in sorted(value)) |
| if isinstance(value, (list, tuple)): |
| return ";".join(str(item) for item in value) |
| return str(value) |
|
|
|
|
| def _read_run_command_tokens(metadata_path: str) -> Optional[List[str]]: |
| meta_dir = os.path.dirname(os.path.abspath(metadata_path)) |
| run_args_path = os.path.join(meta_dir, "run_args.txt") |
| if not os.path.exists(run_args_path): |
| return None |
|
|
| try: |
| with open(run_args_path, "r", encoding="utf-8") as handle: |
| lines = handle.read().splitlines() |
| except OSError: |
| return None |
|
|
| cmd_line = None |
| for idx, line in enumerate(lines): |
| if line.strip() == "command:": |
| if idx + 1 < len(lines): |
| cmd_line = lines[idx + 1].strip() |
| break |
|
|
| if not cmd_line: |
| return None |
|
|
| try: |
| return shlex.split(cmd_line) |
| except ValueError: |
| return None |
|
|
|
|
| def _parse_exclude_pairs_from_tokens(tokens: List[str]) -> Optional[List[int]]: |
| start = None |
| for idx, tok in enumerate(tokens): |
| if tok in ("--exclude_pairs", "--exclude_layers"): |
| start = idx + 1 |
| break |
| if start is None: |
| return None |
|
|
| raw: List[int] = [] |
| for tok in tokens[start:]: |
| if tok.startswith("--"): |
| break |
| |
| if tok == "python": |
| continue |
| try: |
| raw.append(int(tok)) |
| except ValueError: |
| continue |
| return raw |
|
|
|
|
| def _normalize_excluded_pairs(raw: List[int], num_pairs: int) -> List[int]: |
| exclude: List[int] = [] |
| for idx in raw: |
| if idx < 0: |
| idx = num_pairs + idx |
| if 0 <= idx < num_pairs: |
| exclude.append(idx) |
| return sorted(set(exclude)) |
|
|
|
|
| def _read_excluded_pairs_from_cycle_meta(meta_dir: str, cycle_idx: int) -> Optional[List[int]]: |
| path = os.path.join(meta_dir, f"cycle_{cycle_idx}", "cycle_metadata.json") |
| try: |
| with open(path, "r", encoding="utf-8") as handle: |
| cycle_meta = json.load(handle) |
| except (FileNotFoundError, json.JSONDecodeError, OSError): |
| return None |
|
|
| dwce_meta = cycle_meta.get("dwce_meta") or {} |
| excluded = dwce_meta.get("excluded_pairs") |
| if isinstance(excluded, list) and all(isinstance(x, int) for x in excluded): |
| return excluded |
| return None |
|
|
|
|
| def _num_pairs_for_cycle(data: dict, meta_dir: str, cycle_idx: int) -> Optional[int]: |
| num_progressive = data.get("num_progressive") |
| final_num_layers = data.get("final_num_layers") |
| if isinstance(num_progressive, int) and isinstance(final_num_layers, int): |
| initial_layers = final_num_layers + num_progressive |
| return max(initial_layers - cycle_idx, 0) |
|
|
| cycle_meta_path = os.path.join(meta_dir, f"cycle_{cycle_idx}", "cycle_metadata.json") |
| try: |
| with open(cycle_meta_path, "r", encoding="utf-8") as handle: |
| cycle_meta = json.load(handle) |
| except (FileNotFoundError, json.JSONDecodeError, OSError): |
| return None |
|
|
| num_layers_before = cycle_meta.get("num_layers_before") |
| if isinstance(num_layers_before, int): |
| return max(num_layers_before - 1, 0) |
| return None |
|
|
|
|
def main() -> None:
    """Print the PPL values of a progressive run as CSV on stdout.

    Takes one positional command-line argument, the path to
    ``progressive_metadata.json``. The output starts with a header row,
    followed by an optional ``pre`` row (only when ``eval.pre_ppl`` is
    present) and one row per entry of ``cycles`` (or ``cycle_summaries`` when
    ``cycles`` is missing or empty).

    For each cycle the excluded pairs are taken from that cycle's
    ``cycle_metadata.json`` when available; otherwise they are derived from
    the exclude flag of the command recorded in ``run_args.txt``.

    Raises:
        SystemExit: If the file does not exist, is not valid JSON, or its top
            level is not a JSON object.
    """
    parser = argparse.ArgumentParser(
        description="Print progressive PPL values as CSV from progressive_metadata.json"
    )
    parser.add_argument("path", help="Path to progressive_metadata.json")
    args = parser.parse_args()

    try:
        with open(args.path, "r", encoding="utf-8") as handle:
            data = json.load(handle)
    except FileNotFoundError as exc:
        raise SystemExit(f"File not found: {args.path}") from exc
    except json.JSONDecodeError as exc:
        raise SystemExit(f"Invalid JSON: {args.path}") from exc
    if not isinstance(data, dict):
        # Everything below uses ``data.get``; fail with a clear message
        # rather than an AttributeError traceback.
        raise SystemExit(f"Invalid metadata (expected a JSON object): {args.path}")

    meta_dir = os.path.dirname(os.path.abspath(args.path))
    run_tokens = _read_run_command_tokens(args.path)
    raw_exclude = (
        _parse_exclude_pairs_from_tokens(run_tokens) if run_tokens is not None else None
    )

    writer = csv.writer(sys.stdout)
    writer.writerow(
        [
            "cycle",
            "layer_merged",
            "layer_pair",
            "excluded_pairs",
            "redistrib_post_ppl",
            "distill_post_ppl",
            "lora_post_ppl",
            "post_ppl",
        ]
    )

    # ``or {}`` also covers an explicit ``"eval": null`` in the JSON, which a
    # plain ``data.get("eval", {})`` would pass through as None.
    pre_ppl = (data.get("eval") or {}).get("pre_ppl")
    if pre_ppl is not None:
        writer.writerow(["pre", "", "", "", "", "", "", _cell(pre_ppl)])

    cycles = data.get("cycles") or data.get("cycle_summaries") or []
    for cycle in cycles:
        cycle_idx = cycle.get("cycle", "")
        layer_merged = cycle.get("layer_merged")
        layer_pair = ""
        if isinstance(layer_merged, int):
            layer_pair = f"{layer_merged}-{layer_merged + 1}"

        # A non-int cycle id is looked up as "cycle_-1"; when no such folder
        # exists the lookup yields None.
        excluded_pairs = _read_excluded_pairs_from_cycle_meta(
            meta_dir, cycle_idx if isinstance(cycle_idx, int) else -1
        )
        if excluded_pairs is None and raw_exclude is not None and isinstance(cycle_idx, int):
            # Fall back to the exclude flag of the recorded command,
            # resolved against this cycle's pair count.
            num_pairs = _num_pairs_for_cycle(data, meta_dir, cycle_idx)
            if isinstance(num_pairs, int):
                excluded_pairs = _normalize_excluded_pairs(raw_exclude, num_pairs)

        # The legacy key wins; the current key is reported in the same column.
        redistrib_post_ppl = cycle.get("redistrib_post_ppl")
        if redistrib_post_ppl is None:
            redistrib_post_ppl = cycle.get("comm_post_ppl")

        writer.writerow(
            [
                cycle_idx,
                layer_merged if layer_merged is not None else "",
                layer_pair,
                _cell(excluded_pairs),
                _cell(redistrib_post_ppl),
                _cell(cycle.get("distill_post_ppl")),
                _cell(cycle.get("lora_post_ppl")),
                _cell(cycle.get("post_ppl")),
            ]
        )


if __name__ == "__main__":
    main()
|
|