File size: 6,614 Bytes
2c44909
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
#!/usr/bin/env python3
"""Print progressive PPL stats as CSV from progressive_metadata.json.

Expected (current) metadata shape:
  - data["eval"]["pre_ppl"]
  - data["cycles"][i]["redistrib_post_ppl"] (optional; legacy key)
  - data["cycles"][i]["comm_post_ppl"] (optional; current key)
  - data["cycles"][i]["distill_post_ppl"]
  - data["cycles"][i]["lora_post_ppl"] (typically only set on the last cycle)
  - data["cycles"][i]["post_ppl"]
"""

import argparse
import csv
import json
import os
import shlex
import sys
from typing import Any, List, Optional


def _cell(value: Any) -> str:
    if value is None:
        return ""
    if isinstance(value, dict):
        if not value:
            return ""
        return ";".join(str(value[key]) for key in sorted(value))
    if isinstance(value, (list, tuple)):
        return ";".join(str(item) for item in value)
    return str(value)


def _read_run_command_tokens(metadata_path: str) -> Optional[List[str]]:
    meta_dir = os.path.dirname(os.path.abspath(metadata_path))
    run_args_path = os.path.join(meta_dir, "run_args.txt")
    if not os.path.exists(run_args_path):
        return None

    try:
        with open(run_args_path, "r", encoding="utf-8") as handle:
            lines = handle.read().splitlines()
    except OSError:
        return None

    cmd_line = None
    for idx, line in enumerate(lines):
        if line.strip() == "command:":
            if idx + 1 < len(lines):
                cmd_line = lines[idx + 1].strip()
            break

    if not cmd_line:
        return None

    try:
        return shlex.split(cmd_line)
    except ValueError:
        return None


def _parse_exclude_pairs_from_tokens(tokens: List[str]) -> Optional[List[int]]:
    start = None
    for idx, tok in enumerate(tokens):
        if tok in ("--exclude_pairs", "--exclude_layers"):
            start = idx + 1
            break
    if start is None:
        return None

    raw: List[int] = []
    for tok in tokens[start:]:
        if tok.startswith("--"):
            break
        # Legacy bug: run_args.txt used to print "python" before every token.
        if tok == "python":
            continue
        try:
            raw.append(int(tok))
        except ValueError:
            continue
    return raw


def _normalize_excluded_pairs(raw: List[int], num_pairs: int) -> List[int]:
    exclude: List[int] = []
    for idx in raw:
        if idx < 0:
            idx = num_pairs + idx
        if 0 <= idx < num_pairs:
            exclude.append(idx)
    return sorted(set(exclude))


def _read_excluded_pairs_from_cycle_meta(meta_dir: str, cycle_idx: int) -> Optional[List[int]]:
    path = os.path.join(meta_dir, f"cycle_{cycle_idx}", "cycle_metadata.json")
    try:
        with open(path, "r", encoding="utf-8") as handle:
            cycle_meta = json.load(handle)
    except (FileNotFoundError, json.JSONDecodeError, OSError):
        return None

    dwce_meta = cycle_meta.get("dwce_meta") or {}
    excluded = dwce_meta.get("excluded_pairs")
    if isinstance(excluded, list) and all(isinstance(x, int) for x in excluded):
        return excluded
    return None


def _num_pairs_for_cycle(data: dict, meta_dir: str, cycle_idx: int) -> Optional[int]:
    num_progressive = data.get("num_progressive")
    final_num_layers = data.get("final_num_layers")
    if isinstance(num_progressive, int) and isinstance(final_num_layers, int):
        initial_layers = final_num_layers + num_progressive
        return max(initial_layers - cycle_idx, 0)

    cycle_meta_path = os.path.join(meta_dir, f"cycle_{cycle_idx}", "cycle_metadata.json")
    try:
        with open(cycle_meta_path, "r", encoding="utf-8") as handle:
            cycle_meta = json.load(handle)
    except (FileNotFoundError, json.JSONDecodeError, OSError):
        return None

    num_layers_before = cycle_meta.get("num_layers_before")
    if isinstance(num_layers_before, int):
        return max(num_layers_before - 1, 0)
    return None


def main() -> None:
    """CLI entry point: emit one CSV row per progressive cycle to stdout."""
    parser = argparse.ArgumentParser(
        description="Print progressive PPL values as CSV from progressive_metadata.json"
    )
    parser.add_argument("path", help="Path to progressive_metadata.json")
    args = parser.parse_args()

    try:
        with open(args.path, "r", encoding="utf-8") as handle:
            data = json.load(handle)
    except FileNotFoundError as exc:
        raise SystemExit(f"File not found: {args.path}") from exc
    except json.JSONDecodeError as exc:
        raise SystemExit(f"Invalid JSON: {args.path}") from exc

    meta_dir = os.path.dirname(os.path.abspath(args.path))
    run_tokens = _read_run_command_tokens(args.path)
    raw_exclude = (
        None if run_tokens is None else _parse_exclude_pairs_from_tokens(run_tokens)
    )

    out = csv.writer(sys.stdout)
    out.writerow(
        [
            "cycle",
            "layer_merged",
            "layer_pair",
            "excluded_pairs",
            "redistrib_post_ppl",
            "distill_post_ppl",
            "lora_post_ppl",
            "post_ppl",
        ]
    )

    # The pre-run PPL gets its own row, filling only the final column.
    pre_ppl = data.get("eval", {}).get("pre_ppl")
    if pre_ppl is not None:
        out.writerow(["pre", "", "", "", "", "", "", _cell(pre_ppl)])

    for cycle in data.get("cycles") or data.get("cycle_summaries") or []:
        cycle_idx = cycle.get("cycle", "")
        layer_merged = cycle.get("layer_merged")
        layer_pair = (
            f"{layer_merged}-{layer_merged + 1}" if isinstance(layer_merged, int) else ""
        )

        # Prefer the per-cycle metadata written by the run itself; fall back
        # to reconstructing the exclusion list from the recorded command line.
        excluded_pairs = _read_excluded_pairs_from_cycle_meta(
            meta_dir, cycle_idx if isinstance(cycle_idx, int) else -1
        )
        if excluded_pairs is None and raw_exclude is not None and isinstance(cycle_idx, int):
            num_pairs = _num_pairs_for_cycle(data, meta_dir, cycle_idx)
            if isinstance(num_pairs, int):
                excluded_pairs = _normalize_excluded_pairs(raw_exclude, num_pairs)

        # "redistrib_post_ppl" is the legacy key; "comm_post_ppl" is current.
        redistrib = cycle.get("redistrib_post_ppl")
        if redistrib is None:
            redistrib = cycle.get("comm_post_ppl")

        out.writerow(
            [
                cycle_idx,
                layer_merged if layer_merged is not None else "",
                layer_pair,
                _cell(excluded_pairs),
                _cell(redistrib),
                _cell(cycle.get("distill_post_ppl")),
                _cell(cycle.get("lora_post_ppl")),
                _cell(cycle.get("post_ppl")),
            ]
        )


# Standard script guard: run only when executed directly, not on import.
if __name__ == "__main__":
    main()