Kernels
wyldecat Claude Opus 4.6 (1M context) committed on
Commit
a5e85e1
·
1 Parent(s): 4bb42a5

feat: add RMSNorm benchmark scripts and K8s job

Browse files

- run_rms_bench.py: custom RMS benchmark for dims 512/1024/4096/16384
- run_bench.sh / run_and_wait.sh: K8s job apply + log streaming helpers
- benchmark_rms_optim.yaml: TrainJob manifest for B200 benchmarks
- bench_framework.py: minor fixes for benchmark runner

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

benchmarks/benchmark_rms_optim.yaml ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ apiVersion: trainer.kubeflow.org/v1alpha1
2
+ kind: TrainJob
3
+ metadata:
4
+ name: jeesoo-rms-optim-v10
5
+ namespace: kbm-g-np-motif
6
+ spec:
7
+ managedBy: trainer.kubeflow.org/trainjob-controller
8
+ podTemplateOverrides:
9
+ - spec:
10
+ containers:
11
+ - name: node
12
+ volumeMounts:
13
+ - mountPath: /dev/shm
14
+ name: shm
15
+ - mountPath: /mair
16
+ name: mair
17
+ volumes:
18
+ - emptyDir:
19
+ medium: Memory
20
+ sizeLimit: 64Gi
21
+ name: shm
22
+ - name: mair
23
+ persistentVolumeClaim:
24
+ claimName: mair
25
+ targetJobs:
26
+ - name: node
27
+ runtimeRef:
28
+ apiGroup: trainer.kubeflow.org
29
+ kind: ClusterTrainingRuntime
30
+ name: torch-distributed
31
+ suspend: false
32
+ trainer:
33
+ args:
34
+ - /bin/bash
35
+ - '-c'
36
+ - >
37
+ ACTIVATIONPATH=/mair/team-sys/jeesoo/activation
38
+
39
+ pip install triton matplotlib pandas
40
+
41
+ echo "=== Building ==="
42
+
43
+ cd $ACTIVATIONPATH
44
+
45
+ pip uninstall -y activation 2>/dev/null; true
46
+ rm -rf $ACTIVATIONPATH/build/temp.linux-x86_64-cpython-312 $ACTIVATIONPATH/_activation*.so $ACTIVATIONPATH/*.egg-info
47
+
48
+ pip install --no-build-isolation --no-cache-dir -e . -v 2>&1 | tail -100
49
+
50
+ python -c "import _activation; print('Build OK:', _activation)" || { echo "BUILD FAILED"; exit 0; }
51
+
52
+ echo "=== Running RMS tests ==="
53
+
54
+ cd $ACTIVATIONPATH
55
+
56
+ python -m pytest tests/test_rms_norm.py -v 2>&1 | tail -40
57
+
58
+ echo "=== Warmup ==="
59
+
60
+ python -c "import torch; x=torch.randn(8192,1280,device='cuda',dtype=torch.bfloat16); [torch.mm(x.T,x) for _ in range(100)]; torch.cuda.synchronize(); print('warmup done')"
61
+
62
+ echo "=== RMS Benchmark ==="
63
+
64
+ cd $ACTIVATIONPATH/benchmarks
65
+
66
+ python run_rms_bench.py 2>&1 | tee results/rms_optim_log.txt
67
+
68
+ echo "=== Done ==="
69
+
70
+ exit 0;
71
+ env:
72
+ - name: PYTHONUNBUFFERED
73
+ value: '1'
74
+ - name: PYTORCH_ALLOC_CONF
75
+ value: expandable_segments:True
76
+ - name: CUDA_LAUNCH_BLOCKING
77
+ value: '0'
78
+ - name: OMP_NUM_THREADS
79
+ value: '1'
80
+ - name: HF_HOME
81
+ value: /mair/llm-dataset/hf_cache
82
+ image: ghcr.io/motiftechnologies/llm-training:v0.1.3
83
+ numNodes: 1
84
+ numProcPerNode: 1
85
+ resourcesPerNode:
86
+ limits:
87
+ cpu: '16'
88
+ memory: 128Gi
89
+ nvidia.com/gpu: '1'
90
+ requests:
91
+ cpu: '16'
92
+ memory: 128Gi
93
+ nvidia.com/gpu: '1'
benchmarks/common/bench_framework.py CHANGED
@@ -4,8 +4,8 @@ import re
4
  from typing import Any, Dict, Sequence
5
 
6
  import torch
7
- import triton
8
  from torch.profiler import ProfilerActivity, profile
 
9
 
10
  from .diff_engine import DiffCase
11
 
@@ -42,8 +42,8 @@ def _compute_bytes(inputs, forward_fn, obj):
42
  if isinstance(output, torch.Tensor):
43
  output_bytes = output.nbytes
44
  elif isinstance(output, (tuple, list)):
45
- output_bytes = sum(o.nbytes for o in output
46
- if isinstance(o, torch.Tensor))
47
  else:
48
  output_bytes = 0
49
  return input_bytes + output_bytes
@@ -158,9 +158,7 @@ def make_fwd_benchmark_for_case(
158
  key = make_fwd_key(dim, batch_size, seq_len)
159
  I = case.build_inputs(batch_size, seq_len, dim, dtype, eps)
160
  if provider == "speedup":
161
- return round(
162
- timings_ms["naive"][key] /
163
- _get_best_cuda_timing(timings_ms, key), 2)
164
  if provider.endswith("_bw"):
165
  base = provider[:-3]
166
  ms = timings_ms[base][key]
@@ -229,8 +227,7 @@ def make_fwd_benchmark_plot_for_case(
229
  ms = profile_bench(run, total_bytes=nbytes)
230
  timings_ms[provider][config] = ms
231
  if provider == "cuda":
232
- ratio = timings_ms["naive"][config] / _get_best_cuda_timing(
233
- timings_ms, config)
234
  spdup_ratio.append(ratio)
235
  return round(ratio, 2)
236
  else:
@@ -270,9 +267,7 @@ def make_bwd_benchmark_for_case(
270
  key = make_bwd_key(dim, batch_size, seq_len)
271
  I = case.build_inputs(batch_size, seq_len, dim, dtype, eps)
272
  if provider == "speedup":
273
- return round(
274
- timings_ms["naive"][key] /
275
- _get_best_cuda_timing(timings_ms, key), 2)
276
  if provider.endswith("_bw"):
277
  base = provider[:-3]
278
  ms = timings_ms[base][key]
@@ -365,8 +360,7 @@ def make_bwd_benchmark_plot_for_case(
365
  ms = profile_bench(run, total_bytes=nbytes)
366
  timings_ms[provider][config] = ms
367
  if provider == "cuda":
368
- ratio = timings_ms["naive"][config] / _get_best_cuda_timing(
369
- timings_ms, config)
370
  spdup_ratio.append(ratio)
371
  return round(ratio, 2)
372
  else:
 
4
  from typing import Any, Dict, Sequence
5
 
6
  import torch
 
7
  from torch.profiler import ProfilerActivity, profile
8
+ import triton
9
 
10
  from .diff_engine import DiffCase
11
 
 
42
  if isinstance(output, torch.Tensor):
43
  output_bytes = output.nbytes
44
  elif isinstance(output, (tuple, list)):
45
+ output_bytes = sum(
46
+ o.nbytes for o in output if isinstance(o, torch.Tensor))
47
  else:
48
  output_bytes = 0
49
  return input_bytes + output_bytes
 
158
  key = make_fwd_key(dim, batch_size, seq_len)
159
  I = case.build_inputs(batch_size, seq_len, dim, dtype, eps)
160
  if provider == "speedup":
161
+ return round(timings_ms["naive"][key] / _get_best_cuda_timing(timings_ms, key), 2)
 
 
162
  if provider.endswith("_bw"):
163
  base = provider[:-3]
164
  ms = timings_ms[base][key]
 
227
  ms = profile_bench(run, total_bytes=nbytes)
228
  timings_ms[provider][config] = ms
229
  if provider == "cuda":
230
+ ratio = timings_ms["naive"][config] / _get_best_cuda_timing(timings_ms, config)
 
231
  spdup_ratio.append(ratio)
232
  return round(ratio, 2)
233
  else:
 
267
  key = make_bwd_key(dim, batch_size, seq_len)
268
  I = case.build_inputs(batch_size, seq_len, dim, dtype, eps)
269
  if provider == "speedup":
270
+ return round(timings_ms["naive"][key] / _get_best_cuda_timing(timings_ms, key), 2)
 
 
271
  if provider.endswith("_bw"):
272
  base = provider[:-3]
273
  ms = timings_ms[base][key]
 
360
  ms = profile_bench(run, total_bytes=nbytes)
361
  timings_ms[provider][config] = ms
362
  if provider == "cuda":
363
+ ratio = timings_ms["naive"][config] / _get_best_cuda_timing(timings_ms, config)
 
364
  spdup_ratio.append(ratio)
365
  return round(ratio, 2)
366
  else:
benchmarks/run_and_wait.sh ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ # Usage: ./run_and_wait.sh <yaml-file> [poll-interval-seconds]
3
+ # Deletes existing job, applies yaml, waits for build+run, prints results.
4
+
5
+ set -euo pipefail
6
+
7
+ YAML="${1:?Usage: $0 <yaml-file> [poll-interval]}"
8
+ POLL="${2:-10}"
9
+
10
+ JOB_NAME=$(grep -m1 '^\s*name:' "$YAML" | awk '{print $2}')
11
+ NAMESPACE=$(grep -m1 '^\s*namespace:' "$YAML" | awk '{print $2}')
12
+ LABEL="batch.kubernetes.io/job-name=${JOB_NAME}-node-0"
13
+
14
+ echo "=== $JOB_NAME | $NAMESPACE ==="
15
+
16
+ # Remember old pods to ignore them
17
+ OLD_PODS=$(kubectl get pods -n "$NAMESPACE" -l "$LABEL" -o jsonpath='{.items[*].metadata.name}' 2>/dev/null || true)
18
+
19
+ # Delete if exists
20
+ kubectl delete -f "$YAML" 2>/dev/null && {
21
+ echo "Deleted old job. Waiting for cleanup..."
22
+ while kubectl get trainjob -n "$NAMESPACE" "$JOB_NAME" &>/dev/null; do sleep 2; done
23
+ sleep 3
24
+ }
25
+
26
+ # Apply
27
+ kubectl apply -f "$YAML"
28
+
29
+ # Wait for NEW pod (not in OLD_PODS)
30
+ echo -n "Waiting for new pod"
31
+ POD=""
32
+ while true; do
33
+ ALL_PODS=$(kubectl get pods -n "$NAMESPACE" -l "$LABEL" -o jsonpath='{.items[*].metadata.name}' 2>/dev/null || true)
34
+ for p in $ALL_PODS; do
35
+ if [[ ! " $OLD_PODS " =~ " $p " ]]; then
36
+ PHASE=$(kubectl get pod -n "$NAMESPACE" "$p" -o jsonpath='{.status.phase}' 2>/dev/null || true)
37
+ if [[ -n "$PHASE" ]]; then
38
+ POD="$p"
39
+ echo " $POD ($PHASE)"
40
+ break 2
41
+ fi
42
+ fi
43
+ done
44
+ echo -n "."
45
+ sleep "$POLL"
46
+ done
47
+
48
+ # Wait for pod to complete
49
+ echo -n "Running"
50
+ while true; do
51
+ PHASE=$(kubectl get pod -n "$NAMESPACE" "$POD" -o jsonpath='{.status.phase}' 2>/dev/null || echo "Gone")
52
+ if [[ "$PHASE" != "Running" && "$PHASE" != "Pending" ]]; then
53
+ echo " ($PHASE)"
54
+ break
55
+ fi
56
+ echo -n "."
57
+ sleep "$POLL"
58
+ done
59
+
60
+ # Print logs
61
+ echo ""
62
+ echo "=== LOGS ==="
63
+ kubectl logs -n "$NAMESPACE" "$POD" 2>/dev/null || echo "(no logs available)"
64
+
65
+ echo ""
66
+ echo "=== STATUS: $PHASE ==="
benchmarks/run_bench.sh ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Usage: ./run_bench.sh <yaml-file> [poll-interval-seconds]
# Example: ./run_bench.sh benchmark_rms_optim.yaml 10
#
# Deletes any existing TrainJob of the same name, applies the manifest,
# waits for its pod to start, streams the logs, then prints the job's
# final status condition.

set -euo pipefail

YAML="${1:?Usage: $0 <yaml-file> [poll-interval]}"
POLL="${2:-10}"

if [[ ! -f "$YAML" ]]; then
    echo "Error: $YAML not found"
    exit 1
fi

# Extract job name and namespace from yaml.  The first `name:` /
# `namespace:` lines are assumed to be metadata.name / metadata.namespace.
JOB_NAME=$(grep -m1 '^\s*name:' "$YAML" | awk '{print $2}')
NAMESPACE=$(grep -m1 '^\s*namespace:' "$YAML" | awk '{print $2}')

echo "=== Job: $JOB_NAME | Namespace: $NAMESPACE ==="

# Delete if exists.  The failing command is on the left of `&&`, which
# errexit exempts, so a missing job does not abort the script.
kubectl delete trainjob -n "$NAMESPACE" "$JOB_NAME" 2>/dev/null && {
    echo "Deleted existing job, waiting 5s..."
    sleep 5
}

# Apply
kubectl apply -f "$YAML"
echo "Applied. Polling every ${POLL}s..."

# Wait for pod to exist and reach a reportable phase.
echo -n "Waiting for pod..."
while true; do
    POD=$(kubectl get pods -n "$NAMESPACE" -l "batch.kubernetes.io/job-name=${JOB_NAME}-node-0" -o jsonpath='{.items[0].metadata.name}' 2>/dev/null || true)
    if [[ -n "$POD" ]]; then
        STATUS=$(kubectl get pod -n "$NAMESPACE" "$POD" -o jsonpath='{.status.phase}' 2>/dev/null || true)
        if [[ "$STATUS" == "Running" || "$STATUS" == "Succeeded" || "$STATUS" == "Failed" ]]; then
            echo " $POD ($STATUS)"
            break
        fi
    fi
    echo -n "."
    sleep "$POLL"
done

# Stream logs until completion
echo "=== Streaming logs ==="
kubectl logs -n "$NAMESPACE" "$POD" -f 2>/dev/null || true

# Final status
echo ""
echo "=== Final Status ==="
# `|| true` keeps `set -e` from aborting before the trailing echo when the
# job has no status conditions yet (or was deleted in the meantime).
kubectl get trainjob -n "$NAMESPACE" "$JOB_NAME" -o jsonpath='{.status.conditions[0].reason}' 2>/dev/null || true
echo ""
benchmarks/run_rms_bench.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Quick RMS benchmark with custom configs."""
2
+ import os
3
+ import sys
4
+
5
+ import torch
6
+
7
+ from common.bench_framework import (make_bwd_benchmark_for_case,
8
+ make_fwd_benchmark_for_case)
9
+ from common.diff_engine import calculate_diff
10
+
11
+ sys.path.insert(0, os.path.dirname(__file__))
12
+ from cases.rms import CASE
13
+
14
+ torch.set_default_device("cuda")
15
+
16
+ configs = [
17
+ (512, 8, 4096),
18
+ (1024, 8, 4096),
19
+ (4096, 8, 4096),
20
+ (16384, 8, 4096),
21
+ ]
22
+
23
+ # Correctness check
24
+ for dim, bs, sl in configs:
25
+ print(f"Correctness: bs={bs}, sl={sl}, D={dim}...", end=" ")
26
+ calculate_diff(CASE, batch_size=bs, seq_len=sl, hidden_size=dim)
27
+ print("ok")
28
+
29
+ print()
30
+
31
+ line_vals = ("naive", "naive_bw", "cuda", "cuda_bw", "speedup")
32
+ line_names = {
33
+ "naive": "Naive (us)",
34
+ "naive_bw": "Naive (GB/s)",
35
+ "cuda": "CUDA (us)",
36
+ "cuda_bw": "CUDA (GB/s)",
37
+ "speedup": "SpeedUp (ratio)",
38
+ }
39
+
40
+ save_dir = "./results/rms_custom"
41
+ os.makedirs(save_dir, exist_ok=True)
42
+
43
+ bench = make_fwd_benchmark_for_case(
44
+ case=CASE,
45
+ configs=configs,
46
+ plot_name="rms-bf16-fwd",
47
+ dtype=torch.bfloat16,
48
+ line_vals=line_vals,
49
+ line_names=line_names,
50
+ )
51
+ bench.run(print_data=True, save_path=save_dir)
52
+
53
+ bench = make_bwd_benchmark_for_case(
54
+ case=CASE,
55
+ configs=configs,
56
+ plot_name="rms-bf16-bwd",
57
+ dtype=torch.bfloat16,
58
+ line_vals=line_vals,
59
+ line_names=line_names,
60
+ )
61
+ bench.run(print_data=True, save_path=save_dir)