File size: 5,751 Bytes
ae1180d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a9aa4ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ae1180d
 
 
 
 
 
 
 
a9aa4ae
ae1180d
 
 
 
 
a9aa4ae
 
 
 
ae1180d
 
 
 
 
 
 
 
 
 
 
 
a9aa4ae
 
 
 
 
ae1180d
a9aa4ae
 
 
 
 
 
ae1180d
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
"""Shared runtime helpers for the GPU Goblin workload scripts.

Why this exists: ``goblin_runner.sh`` invokes user scripts with
``--max_steps=<N>`` and ``--torch_profile_out=<path>`` so rocprofv3 can
capture only a handful of training steps and ``profile_parser`` can read
real metrics back. Without ``--max_steps`` honored, scripts run for hours
and trip LiveRunner's timeout. Without the profile JSON, profile_parser
zeroes out ``tokens_per_sec`` / ``step_time_seconds`` and the agent has
nothing to reason about beyond config-shape alone.

Each workload script (``train_qwen_lora.py`` and the scenarios under
``scenarios/``) imports this module rather than copy-pasting the
argparse + profile-write boilerplate.

Usage:

    from workloads._runtime import parse_runtime_args, emit_torch_profile

    runtime_args = parse_runtime_args()

    ta_kwargs = dict(...)
    if runtime_args.max_steps > 0:
        ta_kwargs["max_steps"] = runtime_args.max_steps
        ta_kwargs["num_train_epochs"] = 1   # max_steps wins, but be explicit
    training_args = TrainingArguments(**ta_kwargs)

    if __name__ == "__main__":
        import time
        t0 = time.perf_counter()  # monotonic: immune to wall-clock adjustments
        trainer.train()
        emit_torch_profile(
            runtime_args.torch_profile_out,
            elapsed=time.perf_counter() - t0,
            n_steps=int(trainer.state.global_step or runtime_args.max_steps),
            per_device_batch=training_args.per_device_train_batch_size,
            grad_accum=training_args.gradient_accumulation_steps,
            seq_len_cap=512,
        )
"""

from __future__ import annotations

import argparse
import json
import math
from dataclasses import dataclass


@dataclass
class RuntimeArgs:
    """Runtime flags that ``goblin_runner.sh`` passes to every workload script."""

    # When > 0, the optimizer-step budget to force into TrainingArguments;
    # 0 (the parser default) means "don't override".
    max_steps: int
    # Where emit_torch_profile should write its JSON; "" (the parser default)
    # means "don't write a profile".
    torch_profile_out: str


def parse_runtime_args(argv: list[str] | None = None) -> RuntimeArgs:
    """Parse ``--max_steps`` and ``--torch_profile_out`` from the command line.

    Uses ``parse_known_args`` so unrelated flags from libraries (HF Trainer,
    accelerate, deepspeed) pass through untouched. Prefix abbreviations are
    disabled (``allow_abbrev=False``). Otherwise, with ``parse_known_args``, a
    foreign flag that merely happens to be a prefix of one of ours (e.g.
    ``--max=...`` or ``--torch_...``) would be silently consumed here instead
    of being passed through.

    Args:
        argv: Arguments to parse, without the program name. ``None`` (the
            default) parses ``sys.argv[1:]``, exactly as before; pass an
            explicit list to parse something else (e.g. in tests).

    Returns:
        A ``RuntimeArgs`` with ``max_steps`` defaulting to 0 and
        ``torch_profile_out`` defaulting to ``""`` when the flags are absent.
    """
    parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
    parser.add_argument(
        "--max_steps",
        type=int,
        default=0,
        help=(
            "When >0, override TrainingArguments.max_steps so the script "
            "stops after this many optimization steps. Passed in by "
            "goblin_runner.sh — without it, profiling runs train for "
            "hours and time out."
        ),
    )
    parser.add_argument(
        "--torch_profile_out",
        type=str,
        default="",
        help=(
            "Path to write a minimal torch_profile.json (tokens/sec + step "
            "time) so runner/profile_parser populates RunMetrics with real "
            "numbers."
        ),
    )
    # Leftover flags belong to other parsers (HF, accelerate, ...); ignore them.
    args, _unknown = parser.parse_known_args(argv)
    return RuntimeArgs(
        max_steps=args.max_steps,
        torch_profile_out=args.torch_profile_out,
    )


# MI300X (CDNA3) peak matrix throughput, dense, bf16/fp16 — both formats
# arrive at the same number on this arch since they share the same matrix
# engine. Source: AMD Instinct MI300X datasheet (~1.307 PFLOPS dense; ~2.6
# PFLOPS with sparsity). Transformer training rarely hits the sparse path,
# so dense is used as the realistic peak.
_MI300X_PEAK_FLOPS_DENSE_BF16 = 1.307e15

# FLOPs per token for forward + backward, as a multiple of the parameter
# count N. The standard 6N approximation (forward 2N + backward 4N for full
# fine-tuning) overestimates LoRA: LoRA backward computes weight gradients
# only for the small adapter matrices, not the frozen base, so true LoRA
# FLOPs/token is closer to 4N. MFU scales linearly with this factor, so
# using 6N for LoRA *inflates* the reported MFU by roughly 6/4 = 1.5x
# (an optimistic figure, not a pessimistic one). We keep 6N as the
# conventional choice; the value is still useful as a relative metric
# run-to-run for the same workload.
_FLOPS_PER_TOKEN_FACTOR = 6


def emit_torch_profile(
    path: str,
    *,
    elapsed: float,
    n_steps: int,
    per_device_batch: int,
    grad_accum: int = 1,
    seq_len_cap: int = 512,
    model_params: int = 0,
) -> None:
    """Write the smallest torch_profile-shape JSON profile_parser will read.

    profile_parser._read_torch_profile looks for these top-level fields under
    ``metadata``: tokens_per_sec, mfu_pct, step_time_seconds, pytorch_version.

    Token throughput is an approximation: every optimizer step is assumed to
    consume ``per_device_batch * grad_accum`` sequences (each factor floored
    at 1) of exactly ``seq_len_cap`` tokens.

    Args:
        path: Output file. Empty means the script was run outside
            goblin_runner.sh, and this function is a no-op.
        elapsed: Seconds spent training. Negative or non-finite values (e.g.
            a ``time.time()`` delta taken across a clock adjustment) are
            treated as 0. That reports tokens_per_sec=0 and
            step_time_seconds=0 instead of a negative step time or bare
            ``NaN``/``Infinity`` tokens, which are not valid strict JSON.
        n_steps: Optimizer steps completed. ``<= 0`` makes this a no-op
            (training crashed before finishing a step).
        per_device_batch: Sequences per device per micro-batch.
        grad_accum: Gradient-accumulation steps per optimizer step.
        seq_len_cap: Tokens assumed per sequence.
        model_params: Optional. Pass ``sum(p.numel() for p in
            model.parameters())`` from the workload to get a populated
            ``mfu_pct``. Without it, mfu_pct stays unset (profile_parser will
            default to 0).

    Any exception while building or writing the profile is printed and
    swallowed, so a profiling hiccup never fails the training run.
    """
    if not path or n_steps <= 0:
        return
    try:
        import torch  # local import — workload owns its own torch

        # Clamp inside the try so a bad (e.g. non-numeric) elapsed is still
        # reported via the diagnostic print rather than crashing the run.
        if not math.isfinite(elapsed) or elapsed < 0:
            elapsed = 0.0

        global_batch = max(1, per_device_batch) * max(1, grad_accum)
        approx_tokens = n_steps * global_batch * seq_len_cap
        tokens_per_sec = approx_tokens / elapsed if elapsed > 0 else 0.0
        metadata = {
            "tokens_per_sec": round(tokens_per_sec, 2),
            "step_time_seconds": round(elapsed / n_steps, 4),
            "pytorch_version": torch.__version__,
            "n_steps": n_steps,
        }
        if model_params > 0 and tokens_per_sec > 0:
            # MFU = achieved FLOP/s as a percentage of the single-GPU peak.
            flops_per_token = _FLOPS_PER_TOKEN_FACTOR * model_params
            mfu_pct = (flops_per_token * tokens_per_sec) / _MI300X_PEAK_FLOPS_DENSE_BF16 * 100
            metadata["mfu_pct"] = round(mfu_pct, 2)
            metadata["model_params"] = model_params
        payload = {"metadata": metadata}
        with open(path, "w", encoding="utf-8") as f:
            json.dump(payload, f)
    except Exception as exc:  # pragma: no cover — diagnostic only
        # Don't tank the run on a profile-emit failure; the agent will
        # just see "fake" metrics for this step instead of "live".
        print(f"[workloads._runtime] failed to write {path}: {exc}")