# temp_ss/src/print_progressive_ppl_csv.py
#!/usr/bin/env python3
"""Print progressive PPL stats as CSV from progressive_metadata.json.
Expected (current) metadata shape:
- data["eval"]["pre_ppl"]
- data["cycles"][i]["redistrib_post_ppl"] (optional; legacy key)
- data["cycles"][i]["comm_post_ppl"] (optional; current key)
- data["cycles"][i]["distill_post_ppl"]
- data["cycles"][i]["lora_post_ppl"] (typically only set on the last cycle)
- data["cycles"][i]["post_ppl"]
"""
import argparse
import csv
import json
import os
import shlex
import sys
from typing import Any, List, Optional
def _cell(value: Any) -> str:
if value is None:
return ""
if isinstance(value, dict):
if not value:
return ""
return ";".join(str(value[key]) for key in sorted(value))
if isinstance(value, (list, tuple)):
return ";".join(str(item) for item in value)
return str(value)
def _read_run_command_tokens(metadata_path: str) -> Optional[List[str]]:
meta_dir = os.path.dirname(os.path.abspath(metadata_path))
run_args_path = os.path.join(meta_dir, "run_args.txt")
if not os.path.exists(run_args_path):
return None
try:
with open(run_args_path, "r", encoding="utf-8") as handle:
lines = handle.read().splitlines()
except OSError:
return None
cmd_line = None
for idx, line in enumerate(lines):
if line.strip() == "command:":
if idx + 1 < len(lines):
cmd_line = lines[idx + 1].strip()
break
if not cmd_line:
return None
try:
return shlex.split(cmd_line)
except ValueError:
return None
def _parse_exclude_pairs_from_tokens(tokens: List[str]) -> Optional[List[int]]:
start = None
for idx, tok in enumerate(tokens):
if tok in ("--exclude_pairs", "--exclude_layers"):
start = idx + 1
break
if start is None:
return None
raw: List[int] = []
for tok in tokens[start:]:
if tok.startswith("--"):
break
# Legacy bug: run_args.txt used to print "python" before every token.
if tok == "python":
continue
try:
raw.append(int(tok))
except ValueError:
continue
return raw
def _normalize_excluded_pairs(raw: List[int], num_pairs: int) -> List[int]:
exclude: List[int] = []
for idx in raw:
if idx < 0:
idx = num_pairs + idx
if 0 <= idx < num_pairs:
exclude.append(idx)
return sorted(set(exclude))
def _read_excluded_pairs_from_cycle_meta(meta_dir: str, cycle_idx: int) -> Optional[List[int]]:
path = os.path.join(meta_dir, f"cycle_{cycle_idx}", "cycle_metadata.json")
try:
with open(path, "r", encoding="utf-8") as handle:
cycle_meta = json.load(handle)
except (FileNotFoundError, json.JSONDecodeError, OSError):
return None
dwce_meta = cycle_meta.get("dwce_meta") or {}
excluded = dwce_meta.get("excluded_pairs")
if isinstance(excluded, list) and all(isinstance(x, int) for x in excluded):
return excluded
return None
def _num_pairs_for_cycle(data: dict, meta_dir: str, cycle_idx: int) -> Optional[int]:
num_progressive = data.get("num_progressive")
final_num_layers = data.get("final_num_layers")
if isinstance(num_progressive, int) and isinstance(final_num_layers, int):
initial_layers = final_num_layers + num_progressive
return max(initial_layers - cycle_idx, 0)
cycle_meta_path = os.path.join(meta_dir, f"cycle_{cycle_idx}", "cycle_metadata.json")
try:
with open(cycle_meta_path, "r", encoding="utf-8") as handle:
cycle_meta = json.load(handle)
except (FileNotFoundError, json.JSONDecodeError, OSError):
return None
num_layers_before = cycle_meta.get("num_layers_before")
if isinstance(num_layers_before, int):
return max(num_layers_before - 1, 0)
return None
def main() -> None:
    """Entry point: load progressive_metadata.json and print PPL rows as CSV.

    Emits a header row, an optional "pre" row for the pre-merge PPL, and
    one row per cycle. Excluded pairs come from the per-cycle metadata
    when available, otherwise they are reconstructed from the recorded
    run command.
    """
    parser = argparse.ArgumentParser(
        description="Print progressive PPL values as CSV from progressive_metadata.json"
    )
    parser.add_argument("path", help="Path to progressive_metadata.json")
    args = parser.parse_args()

    try:
        with open(args.path, "r", encoding="utf-8") as handle:
            data = json.load(handle)
    except FileNotFoundError as exc:
        raise SystemExit(f"File not found: {args.path}") from exc
    except json.JSONDecodeError as exc:
        raise SystemExit(f"Invalid JSON: {args.path}") from exc

    meta_dir = os.path.dirname(os.path.abspath(args.path))
    run_tokens = _read_run_command_tokens(args.path)
    raw_exclude = (
        None if run_tokens is None else _parse_exclude_pairs_from_tokens(run_tokens)
    )

    writer = csv.writer(sys.stdout)
    writer.writerow(
        [
            "cycle",
            "layer_merged",
            "layer_pair",
            "excluded_pairs",
            "redistrib_post_ppl",
            "distill_post_ppl",
            "lora_post_ppl",
            "post_ppl",
        ]
    )

    pre_ppl = data.get("eval", {}).get("pre_ppl")
    if pre_ppl is not None:
        writer.writerow(["pre"] + [""] * 6 + [_cell(pre_ppl)])

    # "cycle_summaries" is the legacy key for the per-cycle records.
    for cycle in data.get("cycles") or data.get("cycle_summaries") or []:
        cycle_idx = cycle.get("cycle", "")
        layer_merged = cycle.get("layer_merged")
        layer_pair = (
            f"{layer_merged}-{layer_merged + 1}"
            if isinstance(layer_merged, int)
            else ""
        )

        lookup_idx = cycle_idx if isinstance(cycle_idx, int) else -1
        excluded_pairs = _read_excluded_pairs_from_cycle_meta(meta_dir, lookup_idx)
        if excluded_pairs is None and raw_exclude is not None and isinstance(cycle_idx, int):
            # Reconstruct from the CLI flags when the cycle metadata lacks it.
            num_pairs = _num_pairs_for_cycle(data, meta_dir, cycle_idx)
            if isinstance(num_pairs, int):
                excluded_pairs = _normalize_excluded_pairs(raw_exclude, num_pairs)

        # Prefer the legacy key, then the current "comm_post_ppl" spelling.
        redistrib = cycle.get("redistrib_post_ppl")
        if redistrib is None:
            redistrib = cycle.get("comm_post_ppl")

        writer.writerow(
            [
                cycle_idx,
                "" if layer_merged is None else layer_merged,
                layer_pair,
                _cell(excluded_pairs),
                _cell(redistrib),
                _cell(cycle.get("distill_post_ppl")),
                _cell(cycle.get("lora_post_ppl")),
                _cell(cycle.get("post_ppl")),
            ]
        )
# Script entry point: run only when executed directly, not when imported.
if __name__ == "__main__":
    main()