#!/usr/bin/env python3
"""Print progressive PPL stats as CSV from progressive_metadata.json.
Expected (current) metadata shape:
- data["eval"]["pre_ppl"]
- data["cycles"][i]["redistrib_post_ppl"] (optional; legacy key)
- data["cycles"][i]["comm_post_ppl"] (optional; current key)
- data["cycles"][i]["distill_post_ppl"]
- data["cycles"][i]["lora_post_ppl"] (typically only set on the last cycle)
- data["cycles"][i]["post_ppl"]
"""
import argparse
import csv
import json
import os
import shlex
import sys
from typing import Any, List, Optional
def _cell(value: Any) -> str:
if value is None:
return ""
if isinstance(value, dict):
if not value:
return ""
return ";".join(str(value[key]) for key in sorted(value))
if isinstance(value, (list, tuple)):
return ";".join(str(item) for item in value)
return str(value)
def _read_run_command_tokens(metadata_path: str) -> Optional[List[str]]:
meta_dir = os.path.dirname(os.path.abspath(metadata_path))
run_args_path = os.path.join(meta_dir, "run_args.txt")
if not os.path.exists(run_args_path):
return None
try:
with open(run_args_path, "r", encoding="utf-8") as handle:
lines = handle.read().splitlines()
except OSError:
return None
cmd_line = None
for idx, line in enumerate(lines):
if line.strip() == "command:":
if idx + 1 < len(lines):
cmd_line = lines[idx + 1].strip()
break
if not cmd_line:
return None
try:
return shlex.split(cmd_line)
except ValueError:
return None
def _parse_exclude_pairs_from_tokens(tokens: List[str]) -> Optional[List[int]]:
start = None
for idx, tok in enumerate(tokens):
if tok in ("--exclude_pairs", "--exclude_layers"):
start = idx + 1
break
if start is None:
return None
raw: List[int] = []
for tok in tokens[start:]:
if tok.startswith("--"):
break
# Legacy bug: run_args.txt used to print "python" before every token.
if tok == "python":
continue
try:
raw.append(int(tok))
except ValueError:
continue
return raw
def _normalize_excluded_pairs(raw: List[int], num_pairs: int) -> List[int]:
exclude: List[int] = []
for idx in raw:
if idx < 0:
idx = num_pairs + idx
if 0 <= idx < num_pairs:
exclude.append(idx)
return sorted(set(exclude))
def _read_excluded_pairs_from_cycle_meta(meta_dir: str, cycle_idx: int) -> Optional[List[int]]:
path = os.path.join(meta_dir, f"cycle_{cycle_idx}", "cycle_metadata.json")
try:
with open(path, "r", encoding="utf-8") as handle:
cycle_meta = json.load(handle)
except (FileNotFoundError, json.JSONDecodeError, OSError):
return None
dwce_meta = cycle_meta.get("dwce_meta") or {}
excluded = dwce_meta.get("excluded_pairs")
if isinstance(excluded, list) and all(isinstance(x, int) for x in excluded):
return excluded
return None
def _num_pairs_for_cycle(data: dict, meta_dir: str, cycle_idx: int) -> Optional[int]:
num_progressive = data.get("num_progressive")
final_num_layers = data.get("final_num_layers")
if isinstance(num_progressive, int) and isinstance(final_num_layers, int):
initial_layers = final_num_layers + num_progressive
return max(initial_layers - cycle_idx, 0)
cycle_meta_path = os.path.join(meta_dir, f"cycle_{cycle_idx}", "cycle_metadata.json")
try:
with open(cycle_meta_path, "r", encoding="utf-8") as handle:
cycle_meta = json.load(handle)
except (FileNotFoundError, json.JSONDecodeError, OSError):
return None
num_layers_before = cycle_meta.get("num_layers_before")
if isinstance(num_layers_before, int):
return max(num_layers_before - 1, 0)
return None
def main() -> None:
    """Read progressive_metadata.json and write its PPL stats to stdout as CSV.

    The output has a header row, an optional ``pre`` row carrying
    ``eval.pre_ppl``, and one row per entry of ``cycles`` (or, when that is
    missing or empty, ``cycle_summaries``).

    For each cycle the excluded pairs come from that cycle's
    ``cycle_metadata.json`` when available. Otherwise they are reconstructed
    from the ``--exclude_pairs``/``--exclude_layers`` values recorded in
    ``run_args.txt``, normalized against the cycle's pair count.

    Raises:
        SystemExit: If the metadata file is missing, unreadable, not valid
            JSON, or its top level is not a JSON object.
    """
    parser = argparse.ArgumentParser(
        description="Print progressive PPL values as CSV from progressive_metadata.json"
    )
    parser.add_argument("path", help="Path to progressive_metadata.json")
    args = parser.parse_args()

    try:
        with open(args.path, "r", encoding="utf-8") as handle:
            data = json.load(handle)
    except FileNotFoundError as exc:
        raise SystemExit(f"File not found: {args.path}") from exc
    except OSError as exc:
        # E.g. permission denied or the path is a directory.
        raise SystemExit(f"Cannot read {args.path}: {exc}") from exc
    except (json.JSONDecodeError, UnicodeDecodeError) as exc:
        raise SystemExit(f"Invalid JSON: {args.path}") from exc
    if not isinstance(data, dict):
        # Everything below looks values up by key.
        raise SystemExit(f"Expected a JSON object at top level: {args.path}")

    meta_dir = os.path.dirname(os.path.abspath(args.path))
    run_tokens = _read_run_command_tokens(args.path)
    raw_exclude = (
        _parse_exclude_pairs_from_tokens(run_tokens) if run_tokens is not None else None
    )

    writer = csv.writer(sys.stdout)
    writer.writerow(
        [
            "cycle",
            "layer_merged",
            "layer_pair",
            "excluded_pairs",
            "redistrib_post_ppl",
            "distill_post_ppl",
            "lora_post_ppl",
            "post_ppl",
        ]
    )

    # "eval" may be present but null, so a default on .get() is not enough.
    eval_meta = data.get("eval")
    pre_ppl = eval_meta.get("pre_ppl") if isinstance(eval_meta, dict) else None
    if pre_ppl is not None:
        writer.writerow(["pre", "", "", "", "", "", "", _cell(pre_ppl)])

    cycles = data.get("cycles") or data.get("cycle_summaries") or []
    for cycle in cycles:
        cycle_idx = cycle.get("cycle", "")
        layer_merged = cycle.get("layer_merged")
        layer_pair = ""
        if isinstance(layer_merged, int):
            layer_pair = f"{layer_merged}-{layer_merged + 1}"

        # Per-cycle files are keyed by an integer index; without one there is
        # nothing to look up or reconstruct.
        excluded_pairs = None
        if isinstance(cycle_idx, int):
            excluded_pairs = _read_excluded_pairs_from_cycle_meta(meta_dir, cycle_idx)
            if excluded_pairs is None and raw_exclude is not None:
                num_pairs = _num_pairs_for_cycle(data, meta_dir, cycle_idx)
                if isinstance(num_pairs, int):
                    excluded_pairs = _normalize_excluded_pairs(raw_exclude, num_pairs)

        # "redistrib_post_ppl" is the legacy key; "comm_post_ppl" is current.
        redistrib_post_ppl = cycle.get("redistrib_post_ppl")
        if redistrib_post_ppl is None:
            redistrib_post_ppl = cycle.get("comm_post_ppl")

        writer.writerow(
            [
                cycle_idx,
                layer_merged if layer_merged is not None else "",
                layer_pair,
                _cell(excluded_pairs),
                _cell(redistrib_post_ppl),
                _cell(cycle.get("distill_post_ppl")),
                _cell(cycle.get("lora_post_ppl")),
                _cell(cycle.get("post_ppl")),
            ]
        )
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|