feat: add RMSNorm benchmark scripts and K8s job
Browse files- run_rms_bench.py: custom RMS benchmark for dims 512/1024/4096/16384
- run_bench.sh / run_and_wait.sh: K8s job apply + log streaming helpers
- benchmark_rms_optim.yaml: TrainJob manifest for B200 benchmarks
- bench_framework.py: minor fixes for benchmark runner
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- benchmarks/benchmark_rms_optim.yaml +93 -0
- benchmarks/common/bench_framework.py +7 -13
- benchmarks/run_and_wait.sh +66 -0
- benchmarks/run_bench.sh +54 -0
- benchmarks/run_rms_bench.py +61 -0
benchmarks/benchmark_rms_optim.yaml
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
apiVersion: trainer.kubeflow.org/v1alpha1
|
| 2 |
+
kind: TrainJob
|
| 3 |
+
metadata:
|
| 4 |
+
name: jeesoo-rms-optim-v10
|
| 5 |
+
namespace: kbm-g-np-motif
|
| 6 |
+
spec:
|
| 7 |
+
managedBy: trainer.kubeflow.org/trainjob-controller
|
| 8 |
+
podTemplateOverrides:
|
| 9 |
+
- spec:
|
| 10 |
+
containers:
|
| 11 |
+
- name: node
|
| 12 |
+
volumeMounts:
|
| 13 |
+
- mountPath: /dev/shm
|
| 14 |
+
name: shm
|
| 15 |
+
- mountPath: /mair
|
| 16 |
+
name: mair
|
| 17 |
+
volumes:
|
| 18 |
+
- emptyDir:
|
| 19 |
+
medium: Memory
|
| 20 |
+
sizeLimit: 64Gi
|
| 21 |
+
name: shm
|
| 22 |
+
- name: mair
|
| 23 |
+
persistentVolumeClaim:
|
| 24 |
+
claimName: mair
|
| 25 |
+
targetJobs:
|
| 26 |
+
- name: node
|
| 27 |
+
runtimeRef:
|
| 28 |
+
apiGroup: trainer.kubeflow.org
|
| 29 |
+
kind: ClusterTrainingRuntime
|
| 30 |
+
name: torch-distributed
|
| 31 |
+
suspend: false
|
| 32 |
+
trainer:
|
| 33 |
+
args:
|
| 34 |
+
- /bin/bash
|
| 35 |
+
- '-c'
|
| 36 |
+
- >
|
| 37 |
+
ACTIVATIONPATH=/mair/team-sys/jeesoo/activation
|
| 38 |
+
|
| 39 |
+
pip install triton matplotlib pandas
|
| 40 |
+
|
| 41 |
+
echo "=== Building ==="
|
| 42 |
+
|
| 43 |
+
cd $ACTIVATIONPATH
|
| 44 |
+
|
| 45 |
+
pip uninstall -y activation 2>/dev/null; true
|
| 46 |
+
rm -rf $ACTIVATIONPATH/build/temp.linux-x86_64-cpython-312 $ACTIVATIONPATH/_activation*.so $ACTIVATIONPATH/*.egg-info
|
| 47 |
+
|
| 48 |
+
pip install --no-build-isolation --no-cache-dir -e . -v 2>&1 | tail -100
|
| 49 |
+
|
| 50 |
+
python -c "import _activation; print('Build OK:', _activation)" || { echo "BUILD FAILED"; exit 0; }
|
| 51 |
+
|
| 52 |
+
echo "=== Running RMS tests ==="
|
| 53 |
+
|
| 54 |
+
cd $ACTIVATIONPATH
|
| 55 |
+
|
| 56 |
+
python -m pytest tests/test_rms_norm.py -v 2>&1 | tail -40
|
| 57 |
+
|
| 58 |
+
echo "=== Warmup ==="
|
| 59 |
+
|
| 60 |
+
python -c "import torch; x=torch.randn(8192,1280,device='cuda',dtype=torch.bfloat16); [torch.mm(x.T,x) for _ in range(100)]; torch.cuda.synchronize(); print('warmup done')"
|
| 61 |
+
|
| 62 |
+
echo "=== RMS Benchmark ==="
|
| 63 |
+
|
| 64 |
+
cd $ACTIVATIONPATH/benchmarks
|
| 65 |
+
|
| 66 |
+
python run_rms_bench.py 2>&1 | tee results/rms_optim_log.txt
|
| 67 |
+
|
| 68 |
+
echo "=== Done ==="
|
| 69 |
+
|
| 70 |
+
exit 0;
|
| 71 |
+
env:
|
| 72 |
+
- name: PYTHONUNBUFFERED
|
| 73 |
+
value: '1'
|
| 74 |
+
- name: PYTORCH_ALLOC_CONF
|
| 75 |
+
value: expandable_segments:True
|
| 76 |
+
- name: CUDA_LAUNCH_BLOCKING
|
| 77 |
+
value: '0'
|
| 78 |
+
- name: OMP_NUM_THREADS
|
| 79 |
+
value: '1'
|
| 80 |
+
- name: HF_HOME
|
| 81 |
+
value: /mair/llm-dataset/hf_cache
|
| 82 |
+
image: ghcr.io/motiftechnologies/llm-training:v0.1.3
|
| 83 |
+
numNodes: 1
|
| 84 |
+
numProcPerNode: 1
|
| 85 |
+
resourcesPerNode:
|
| 86 |
+
limits:
|
| 87 |
+
cpu: '16'
|
| 88 |
+
memory: 128Gi
|
| 89 |
+
nvidia.com/gpu: '1'
|
| 90 |
+
requests:
|
| 91 |
+
cpu: '16'
|
| 92 |
+
memory: 128Gi
|
| 93 |
+
nvidia.com/gpu: '1'
|
benchmarks/common/bench_framework.py
CHANGED
|
@@ -4,8 +4,8 @@ import re
|
|
| 4 |
from typing import Any, Dict, Sequence
|
| 5 |
|
| 6 |
import torch
|
| 7 |
-
import triton
|
| 8 |
from torch.profiler import ProfilerActivity, profile
|
|
|
|
| 9 |
|
| 10 |
from .diff_engine import DiffCase
|
| 11 |
|
|
@@ -42,8 +42,8 @@ def _compute_bytes(inputs, forward_fn, obj):
|
|
| 42 |
if isinstance(output, torch.Tensor):
|
| 43 |
output_bytes = output.nbytes
|
| 44 |
elif isinstance(output, (tuple, list)):
|
| 45 |
-
output_bytes = sum(
|
| 46 |
-
|
| 47 |
else:
|
| 48 |
output_bytes = 0
|
| 49 |
return input_bytes + output_bytes
|
|
@@ -158,9 +158,7 @@ def make_fwd_benchmark_for_case(
|
|
| 158 |
key = make_fwd_key(dim, batch_size, seq_len)
|
| 159 |
I = case.build_inputs(batch_size, seq_len, dim, dtype, eps)
|
| 160 |
if provider == "speedup":
|
| 161 |
-
return round(
|
| 162 |
-
timings_ms["naive"][key] /
|
| 163 |
-
_get_best_cuda_timing(timings_ms, key), 2)
|
| 164 |
if provider.endswith("_bw"):
|
| 165 |
base = provider[:-3]
|
| 166 |
ms = timings_ms[base][key]
|
|
@@ -229,8 +227,7 @@ def make_fwd_benchmark_plot_for_case(
|
|
| 229 |
ms = profile_bench(run, total_bytes=nbytes)
|
| 230 |
timings_ms[provider][config] = ms
|
| 231 |
if provider == "cuda":
|
| 232 |
-
ratio = timings_ms["naive"][config] / _get_best_cuda_timing(
|
| 233 |
-
timings_ms, config)
|
| 234 |
spdup_ratio.append(ratio)
|
| 235 |
return round(ratio, 2)
|
| 236 |
else:
|
|
@@ -270,9 +267,7 @@ def make_bwd_benchmark_for_case(
|
|
| 270 |
key = make_bwd_key(dim, batch_size, seq_len)
|
| 271 |
I = case.build_inputs(batch_size, seq_len, dim, dtype, eps)
|
| 272 |
if provider == "speedup":
|
| 273 |
-
return round(
|
| 274 |
-
timings_ms["naive"][key] /
|
| 275 |
-
_get_best_cuda_timing(timings_ms, key), 2)
|
| 276 |
if provider.endswith("_bw"):
|
| 277 |
base = provider[:-3]
|
| 278 |
ms = timings_ms[base][key]
|
|
@@ -365,8 +360,7 @@ def make_bwd_benchmark_plot_for_case(
|
|
| 365 |
ms = profile_bench(run, total_bytes=nbytes)
|
| 366 |
timings_ms[provider][config] = ms
|
| 367 |
if provider == "cuda":
|
| 368 |
-
ratio = timings_ms["naive"][config] / _get_best_cuda_timing(
|
| 369 |
-
timings_ms, config)
|
| 370 |
spdup_ratio.append(ratio)
|
| 371 |
return round(ratio, 2)
|
| 372 |
else:
|
|
|
|
| 4 |
from typing import Any, Dict, Sequence
|
| 5 |
|
| 6 |
import torch
|
|
|
|
| 7 |
from torch.profiler import ProfilerActivity, profile
|
| 8 |
+
import triton
|
| 9 |
|
| 10 |
from .diff_engine import DiffCase
|
| 11 |
|
|
|
|
| 42 |
if isinstance(output, torch.Tensor):
|
| 43 |
output_bytes = output.nbytes
|
| 44 |
elif isinstance(output, (tuple, list)):
|
| 45 |
+
output_bytes = sum(
|
| 46 |
+
o.nbytes for o in output if isinstance(o, torch.Tensor))
|
| 47 |
else:
|
| 48 |
output_bytes = 0
|
| 49 |
return input_bytes + output_bytes
|
|
|
|
| 158 |
key = make_fwd_key(dim, batch_size, seq_len)
|
| 159 |
I = case.build_inputs(batch_size, seq_len, dim, dtype, eps)
|
| 160 |
if provider == "speedup":
|
| 161 |
+
return round(timings_ms["naive"][key] / _get_best_cuda_timing(timings_ms, key), 2)
|
|
|
|
|
|
|
| 162 |
if provider.endswith("_bw"):
|
| 163 |
base = provider[:-3]
|
| 164 |
ms = timings_ms[base][key]
|
|
|
|
| 227 |
ms = profile_bench(run, total_bytes=nbytes)
|
| 228 |
timings_ms[provider][config] = ms
|
| 229 |
if provider == "cuda":
|
| 230 |
+
ratio = timings_ms["naive"][config] / _get_best_cuda_timing(timings_ms, config)
|
|
|
|
| 231 |
spdup_ratio.append(ratio)
|
| 232 |
return round(ratio, 2)
|
| 233 |
else:
|
|
|
|
| 267 |
key = make_bwd_key(dim, batch_size, seq_len)
|
| 268 |
I = case.build_inputs(batch_size, seq_len, dim, dtype, eps)
|
| 269 |
if provider == "speedup":
|
| 270 |
+
return round(timings_ms["naive"][key] / _get_best_cuda_timing(timings_ms, key), 2)
|
|
|
|
|
|
|
| 271 |
if provider.endswith("_bw"):
|
| 272 |
base = provider[:-3]
|
| 273 |
ms = timings_ms[base][key]
|
|
|
|
| 360 |
ms = profile_bench(run, total_bytes=nbytes)
|
| 361 |
timings_ms[provider][config] = ms
|
| 362 |
if provider == "cuda":
|
| 363 |
+
ratio = timings_ms["naive"][config] / _get_best_cuda_timing(timings_ms, config)
|
|
|
|
| 364 |
spdup_ratio.append(ratio)
|
| 365 |
return round(ratio, 2)
|
| 366 |
else:
|
benchmarks/run_and_wait.sh
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# Usage: ./run_and_wait.sh <yaml-file> [poll-interval-seconds]
|
| 3 |
+
# Deletes existing job, applies yaml, waits for build+run, prints results.
|
| 4 |
+
|
| 5 |
+
set -euo pipefail
|
| 6 |
+
|
| 7 |
+
YAML="${1:?Usage: $0 <yaml-file> [poll-interval]}"
|
| 8 |
+
POLL="${2:-10}"
|
| 9 |
+
|
| 10 |
+
JOB_NAME=$(grep -m1 '^\s*name:' "$YAML" | awk '{print $2}')
|
| 11 |
+
NAMESPACE=$(grep -m1 '^\s*namespace:' "$YAML" | awk '{print $2}')
|
| 12 |
+
LABEL="batch.kubernetes.io/job-name=${JOB_NAME}-node-0"
|
| 13 |
+
|
| 14 |
+
echo "=== $JOB_NAME | $NAMESPACE ==="
|
| 15 |
+
|
| 16 |
+
# Remember old pods to ignore them
|
| 17 |
+
OLD_PODS=$(kubectl get pods -n "$NAMESPACE" -l "$LABEL" -o jsonpath='{.items[*].metadata.name}' 2>/dev/null || true)
|
| 18 |
+
|
| 19 |
+
# Delete if exists
|
| 20 |
+
kubectl delete -f "$YAML" 2>/dev/null && {
|
| 21 |
+
echo "Deleted old job. Waiting for cleanup..."
|
| 22 |
+
while kubectl get trainjob -n "$NAMESPACE" "$JOB_NAME" &>/dev/null; do sleep 2; done
|
| 23 |
+
sleep 3
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
# Apply
|
| 27 |
+
kubectl apply -f "$YAML"
|
| 28 |
+
|
| 29 |
+
# Wait for NEW pod (not in OLD_PODS)
|
| 30 |
+
echo -n "Waiting for new pod"
|
| 31 |
+
POD=""
|
| 32 |
+
while true; do
|
| 33 |
+
ALL_PODS=$(kubectl get pods -n "$NAMESPACE" -l "$LABEL" -o jsonpath='{.items[*].metadata.name}' 2>/dev/null || true)
|
| 34 |
+
for p in $ALL_PODS; do
|
| 35 |
+
if [[ ! " $OLD_PODS " =~ " $p " ]]; then
|
| 36 |
+
PHASE=$(kubectl get pod -n "$NAMESPACE" "$p" -o jsonpath='{.status.phase}' 2>/dev/null || true)
|
| 37 |
+
if [[ -n "$PHASE" ]]; then
|
| 38 |
+
POD="$p"
|
| 39 |
+
echo " $POD ($PHASE)"
|
| 40 |
+
break 2
|
| 41 |
+
fi
|
| 42 |
+
fi
|
| 43 |
+
done
|
| 44 |
+
echo -n "."
|
| 45 |
+
sleep "$POLL"
|
| 46 |
+
done
|
| 47 |
+
|
| 48 |
+
# Wait for pod to complete
|
| 49 |
+
echo -n "Running"
|
| 50 |
+
while true; do
|
| 51 |
+
PHASE=$(kubectl get pod -n "$NAMESPACE" "$POD" -o jsonpath='{.status.phase}' 2>/dev/null || echo "Gone")
|
| 52 |
+
if [[ "$PHASE" != "Running" && "$PHASE" != "Pending" ]]; then
|
| 53 |
+
echo " ($PHASE)"
|
| 54 |
+
break
|
| 55 |
+
fi
|
| 56 |
+
echo -n "."
|
| 57 |
+
sleep "$POLL"
|
| 58 |
+
done
|
| 59 |
+
|
| 60 |
+
# Print logs
|
| 61 |
+
echo ""
|
| 62 |
+
echo "=== LOGS ==="
|
| 63 |
+
kubectl logs -n "$NAMESPACE" "$POD" 2>/dev/null || echo "(no logs available)"
|
| 64 |
+
|
| 65 |
+
echo ""
|
| 66 |
+
echo "=== STATUS: $PHASE ==="
|
benchmarks/run_bench.sh
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# Usage: ./run_bench.sh <yaml-file> [poll-interval-seconds]
|
| 3 |
+
# Example: ./run_bench.sh benchmark_rms_optim.yaml 10
|
| 4 |
+
|
| 5 |
+
set -euo pipefail
|
| 6 |
+
|
| 7 |
+
YAML="${1:?Usage: $0 <yaml-file> [poll-interval]}"
|
| 8 |
+
POLL="${2:-10}"
|
| 9 |
+
|
| 10 |
+
if [[ ! -f "$YAML" ]]; then
|
| 11 |
+
echo "Error: $YAML not found"
|
| 12 |
+
exit 1
|
| 13 |
+
fi
|
| 14 |
+
|
| 15 |
+
# Extract job name and namespace from yaml
|
| 16 |
+
JOB_NAME=$(grep -m1 '^\s*name:' "$YAML" | awk '{print $2}')
|
| 17 |
+
NAMESPACE=$(grep -m1 '^\s*namespace:' "$YAML" | awk '{print $2}')
|
| 18 |
+
|
| 19 |
+
echo "=== Job: $JOB_NAME | Namespace: $NAMESPACE ==="
|
| 20 |
+
|
| 21 |
+
# Delete if exists
|
| 22 |
+
kubectl delete trainjob -n "$NAMESPACE" "$JOB_NAME" 2>/dev/null && {
|
| 23 |
+
echo "Deleted existing job, waiting 5s..."
|
| 24 |
+
sleep 5
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
# Apply
|
| 28 |
+
kubectl apply -f "$YAML"
|
| 29 |
+
echo "Applied. Polling every ${POLL}s..."
|
| 30 |
+
|
| 31 |
+
# Wait for pod
|
| 32 |
+
echo -n "Waiting for pod..."
|
| 33 |
+
while true; do
|
| 34 |
+
POD=$(kubectl get pods -n "$NAMESPACE" -l "batch.kubernetes.io/job-name=${JOB_NAME}-node-0" -o jsonpath='{.items[0].metadata.name}' 2>/dev/null || true)
|
| 35 |
+
if [[ -n "$POD" && "$POD" != "" ]]; then
|
| 36 |
+
STATUS=$(kubectl get pod -n "$NAMESPACE" "$POD" -o jsonpath='{.status.phase}' 2>/dev/null || true)
|
| 37 |
+
if [[ "$STATUS" == "Running" || "$STATUS" == "Succeeded" || "$STATUS" == "Failed" ]]; then
|
| 38 |
+
echo " $POD ($STATUS)"
|
| 39 |
+
break
|
| 40 |
+
fi
|
| 41 |
+
fi
|
| 42 |
+
echo -n "."
|
| 43 |
+
sleep "$POLL"
|
| 44 |
+
done
|
| 45 |
+
|
| 46 |
+
# Stream logs until completion
|
| 47 |
+
echo "=== Streaming logs ==="
|
| 48 |
+
kubectl logs -n "$NAMESPACE" "$POD" -f 2>/dev/null || true
|
| 49 |
+
|
| 50 |
+
# Final status
|
| 51 |
+
echo ""
|
| 52 |
+
echo "=== Final Status ==="
|
| 53 |
+
kubectl get trainjob -n "$NAMESPACE" "$JOB_NAME" -o jsonpath='{.status.conditions[0].reason}' 2>/dev/null
|
| 54 |
+
echo ""
|
benchmarks/run_rms_bench.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Quick RMS benchmark with custom configs."""
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
from common.bench_framework import (make_bwd_benchmark_for_case,
|
| 8 |
+
make_fwd_benchmark_for_case)
|
| 9 |
+
from common.diff_engine import calculate_diff
|
| 10 |
+
|
| 11 |
+
sys.path.insert(0, os.path.dirname(__file__))
|
| 12 |
+
from cases.rms import CASE
|
| 13 |
+
|
| 14 |
+
torch.set_default_device("cuda")
|
| 15 |
+
|
| 16 |
+
configs = [
|
| 17 |
+
(512, 8, 4096),
|
| 18 |
+
(1024, 8, 4096),
|
| 19 |
+
(4096, 8, 4096),
|
| 20 |
+
(16384, 8, 4096),
|
| 21 |
+
]
|
| 22 |
+
|
| 23 |
+
# Correctness check
|
| 24 |
+
for dim, bs, sl in configs:
|
| 25 |
+
print(f"Correctness: bs={bs}, sl={sl}, D={dim}...", end=" ")
|
| 26 |
+
calculate_diff(CASE, batch_size=bs, seq_len=sl, hidden_size=dim)
|
| 27 |
+
print("ok")
|
| 28 |
+
|
| 29 |
+
print()
|
| 30 |
+
|
| 31 |
+
line_vals = ("naive", "naive_bw", "cuda", "cuda_bw", "speedup")
|
| 32 |
+
line_names = {
|
| 33 |
+
"naive": "Naive (us)",
|
| 34 |
+
"naive_bw": "Naive (GB/s)",
|
| 35 |
+
"cuda": "CUDA (us)",
|
| 36 |
+
"cuda_bw": "CUDA (GB/s)",
|
| 37 |
+
"speedup": "SpeedUp (ratio)",
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
save_dir = "./results/rms_custom"
|
| 41 |
+
os.makedirs(save_dir, exist_ok=True)
|
| 42 |
+
|
| 43 |
+
bench = make_fwd_benchmark_for_case(
|
| 44 |
+
case=CASE,
|
| 45 |
+
configs=configs,
|
| 46 |
+
plot_name="rms-bf16-fwd",
|
| 47 |
+
dtype=torch.bfloat16,
|
| 48 |
+
line_vals=line_vals,
|
| 49 |
+
line_names=line_names,
|
| 50 |
+
)
|
| 51 |
+
bench.run(print_data=True, save_path=save_dir)
|
| 52 |
+
|
| 53 |
+
bench = make_bwd_benchmark_for_case(
|
| 54 |
+
case=CASE,
|
| 55 |
+
configs=configs,
|
| 56 |
+
plot_name="rms-bf16-bwd",
|
| 57 |
+
dtype=torch.bfloat16,
|
| 58 |
+
line_vals=line_vals,
|
| 59 |
+
line_names=line_names,
|
| 60 |
+
)
|
| 61 |
+
bench.run(print_data=True, save_path=save_dir)
|