File size: 687 Bytes
bad4ddc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
# /// script
# dependencies = [
#     "torch",
#     "numpy",
# ]
# ///

"""Configuration for MoE benchmarks."""
import torch

# --- Model settings ---------------------------------------------------------
NUM_EXPERTS = 128
HIDDEN_SIZE = 1152
TOP_K = 4

# --- Benchmark settings -----------------------------------------------------
BATCH_SIZE = 8
SEQ_LEN = 512
# Stored as a plain string name, not as a torch.dtype object.
DTYPE = "bfloat16"
# Use CUDA when torch reports it as available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# --- Fixed integer seeds (for reproducibility) ------------------------------
WEIGHT_SEED = 999
EXPERT_SEED = 777
INPUT_SEED = 123
GENERAL_SEED = 42

# Echo the active settings to stdout, one entry per line.
print(
    "Configuration:",
    f"  Experts: {NUM_EXPERTS}",
    f"  Hidden size: {HIDDEN_SIZE}",
    f"  Top-k: {TOP_K}",
    f"  Batch size: {BATCH_SIZE}",
    f"  Sequence length: {SEQ_LEN}",
    f"  Device: {DEVICE}",
    f"  Dtype: {DTYPE}",
    sep="\n",
)