File size: 5,458 Bytes
2d05890
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
#!/usr/bin/env python3
"""Real-time tok/s + RSS sampler for run_thermal.sh.

Polls llama-server.log incrementally and the RSS log every `interval` seconds,
appending one row per interval to thermal_curve.csv:

    t_sec, tok_s_mean, tok_s_median, tok_s_p10, tok_s_n, rss_gb

Exits cleanly after `duration` seconds. Writes a final summary to
thermal_curve.json with cold/sustained/throttle stats.
"""
from __future__ import annotations
import argparse
import json
import re
import statistics
import time
from pathlib import Path

# Matches llama-server per-request eval summary lines, e.g.
#   "eval time = 1234.5 ms / 128 tokens (9.6 ms per token, 103.7 tokens per second)"
# group(1) = token count, group(2) = tokens/sec.
# NOTE: "prompt eval time = ..." lines also contain this substring and would
# match; callers must filter those out before searching (see main loop).
EVAL_RE = re.compile(
    r"eval time\s*=\s*[\d.]+\s*ms\s*/\s*(\d+)\s*(?:tokens|runs)\s*"
    r"\(\s*[\d.]+\s*ms per token,\s*([\d.]+)\s*tokens per second\)",
    re.IGNORECASE,
)


def latest_rss_gb(rss_log: Path) -> float:
    """Return the most recent RSS reading from *rss_log* in GiB.

    Scans the last three lines (newest first) for a row whose second
    whitespace-separated field is an integer KiB value. Returns 0.0 when the
    file is missing, unreadable, or contains no parseable row.
    """
    if not rss_log.exists():
        return 0.0
    try:
        recent = rss_log.read_text().splitlines()[-3:]
        for entry in reversed(recent):
            fields = entry.split()
            if len(fields) >= 2 and fields[1].isdigit():
                # Value is recorded in KiB; convert to GiB.
                return int(fields[1]) / (1024 * 1024)
    except Exception:
        # Best-effort sampler: a torn write or transient read error just
        # yields 0.0 for this window rather than killing the run.
        pass
    return 0.0


def percentile(values, p):
    """Nearest-rank percentile of *values* (no interpolation).

    Returns 0.0 for an empty input. *p* is in the range 0..100; the rank is
    rounded to the closest index and clamped into bounds.
    """
    if not values:
        return 0.0
    ordered = sorted(values)
    rank = int(round((p / 100.0) * (len(ordered) - 1)))
    rank = max(0, min(rank, len(ordered) - 1))
    return ordered[rank]


def _parse_args() -> argparse.Namespace:
    """Parse command-line options for the sampler."""
    p = argparse.ArgumentParser()
    p.add_argument("--llama-log", type=Path, required=True)
    p.add_argument("--rss-log", type=Path, required=True)
    p.add_argument("--out-csv", type=Path, required=True)
    p.add_argument("--out-json", type=Path, required=True)
    p.add_argument("--interval", type=int, default=30, help="seconds per sample window")
    p.add_argument("--duration", type=int, default=2700, help="seconds to sample (default 45 min)")
    p.add_argument("--min-tokens", type=int, default=8,
                   help="filter eval lines with fewer tokens than this (skip trivial bursts)")
    return p.parse_args()


def _window_rates(chunk: str, min_tokens: int) -> list[float]:
    """Extract decode tok/s readings from newly appended llama-server log text.

    Prompt-eval lines are skipped explicitly: EVAL_RE matches them too and
    they would inflate decode tok/s by roughly 10x. Bursts shorter than
    *min_tokens* are dropped as statistically trivial.
    """
    rates: list[float] = []
    for line in chunk.splitlines():
        if "prompt eval time" in line:
            continue
        m = EVAL_RE.search(line)
        if m and int(m.group(1)) >= min_tokens:
            rates.append(float(m.group(2)))
    return rates


def _summarize(rows: list[dict], interval: int, duration: int) -> dict:
    """Build the cold/sustained/throttle summary dict from per-window rows."""
    # rows[-5:] clamps naturally for short runs (equivalent to the verbose
    # rows[-min(len(rows), 5):] form).
    late = [r for r in rows[-5:] if r["tok_s_n"] > 0]
    all_rates = [r["tok_s_mean"] for r in rows if r["tok_s_n"] > 0]
    # "Cold" = best of the first three windows, before the machine heat-soaks.
    cold = max((r["tok_s_mean"] for r in rows[:3] if r["tok_s_n"] > 0), default=0.0)
    sustained = statistics.median([r["tok_s_mean"] for r in late]) if late else 0.0
    overall = statistics.median(all_rates) if all_rates else 0.0
    throttle_pct = (1 - sustained / cold) * 100 if cold > 0 else 0.0
    peak_rss = max((r["rss_gb"] for r in rows), default=0.0)
    return {
        "duration_sec": duration,
        "interval_sec": interval,
        "n_samples": len(rows),
        "tok_s_cold": round(cold, 2),
        "tok_s_sustained_last5": round(sustained, 2),
        "tok_s_median_overall": round(overall, 2),
        "throttle_pct_cold_to_sustained": round(throttle_pct, 1),
        "peak_rss_gb": round(peak_rss, 3),
        "samples": rows,
    }


def main():
    """Sample decode tok/s and RSS every --interval seconds for --duration.

    Appends one row per window to --out-csv (tail-able while running) and
    writes a final summary with cold/sustained/throttle stats to --out-json.
    """
    args = _parse_args()

    # Wait up to 60 s for llama-server to create its log before giving up.
    deadline = time.time() + 60
    while not args.llama_log.exists() and time.time() < deadline:
        time.sleep(1)
    if not args.llama_log.exists():
        print(f"llama log never appeared at {args.llama_log}", flush=True)
        return

    # Start reading at EOF so startup lines (model load, etc.) are skipped.
    with args.llama_log.open() as f:
        f.seek(0, 2)
        last_pos = f.tell()

    args.out_csv.parent.mkdir(parents=True, exist_ok=True)
    rows: list[dict] = []
    start = time.time()
    next_sample = start + args.interval

    # buffering=1 keeps the CSV line-buffered so it can be tailed live; the
    # `with` guarantees the handle is closed even if sampling raises (the
    # previous version leaked the handle on any exception before close()).
    with args.out_csv.open("w", buffering=1) as out:
        out.write("t_sec,tok_s_mean,tok_s_median,tok_s_p10,tok_s_n,rss_gb\n")
        while time.time() - start < args.duration:
            # Sleep to the next window boundary; never spin faster than 0.5 s.
            time.sleep(max(0.5, next_sample - time.time()))

            # Incremental read: only content appended since the last poll.
            try:
                with args.llama_log.open() as f:
                    f.seek(last_pos)
                    chunk = f.read()
                    last_pos = f.tell()
            except FileNotFoundError:
                chunk = ""

            rates = _window_rates(chunk, args.min_tokens)
            rss = latest_rss_gb(args.rss_log)
            t = round(time.time() - start, 1)
            if rates:
                mean = statistics.mean(rates)
                med = statistics.median(rates)
                p10 = percentile(rates, 10)
            else:
                mean = med = p10 = 0.0
            out.write(f"{t:.1f},{mean:.2f},{med:.2f},{p10:.2f},{len(rates)},{rss:.3f}\n")
            rows.append({"t_sec": t, "tok_s_mean": mean, "tok_s_median": med,
                         "tok_s_p10": p10, "tok_s_n": len(rates), "rss_gb": rss})
            next_sample += args.interval

    summary = _summarize(rows, args.interval, args.duration)
    args.out_json.write_text(json.dumps(summary, indent=2))
    print(f"Wrote {args.out_csv} and {args.out_json}")
    print(f"  cold:      {summary['tok_s_cold']:.1f} tok/s")
    print(f"  sustained: {summary['tok_s_sustained_last5']:.1f} tok/s (last 5 samples)")
    print(f"  throttle:  {summary['throttle_pct_cold_to_sustained']:+.1f}%  (cold → sustained)")
    print(f"  peak rss:  {summary['peak_rss_gb']:.2f} GB")


# Script entry point: importing this module has no side effects.
if __name__ == "__main__":
    main()