# INV / helium / block.py
# Uploaded by Fred808 ("Upload 256 files"), commit 7a0c684 (verified).
import numpy as np
from .layer_norm import layer_norm
from .gelu import gelu
from .multihead_attention import multihead_attention
def transformer_block(x, weights, num_heads, mask=None, driver=None, chip_id=0, sm_id=0, scheduler=None):
    """Run one pre-LayerNorm transformer block through the virtual GPU driver.

    Computes::

        h   = x + MHA(LN1(x))
        out = h + FF2(GELU(FF1(LN2(h))))

    Each GPU op is placed on the ``(chip_id, sm_id)`` pair returned by
    ``scheduler`` for that op.

    Args:
        x: Input activations, expected shape (batch, seq_len, hidden_dim).
        weights: Mapping containing the keys 'ln1.weight', 'ln1.bias',
            'attn.q_proj.weight', 'attn.k_proj.weight', 'attn.v_proj.weight',
            'attn.out_proj.weight', 'ln2.weight', 'ln2.bias', 'ff1.weight',
            'ff1.bias', 'ff2.weight' and 'ff2.bias'.
        num_heads: Number of attention heads, forwarded to multihead_attention.
        mask: Optional attention mask, forwarded to multihead_attention.
        driver: VirtualGPUDriver instance providing ``layernorm``, ``matmul``,
            ``gelu`` and ``hardware_config``. Required despite the ``None``
            default (kept only so the signature stays unchanged).
        chip_id, sm_id: Accepted for backward compatibility only. They are
            overwritten by the scheduler before the first op, so the values
            passed here are never used.
        scheduler: Callable ``scheduler(op_name) -> (chip_id, sm_id)``. It is
            also handed to multihead_attention. When ``None``, a round-robin
            scheduler is built from ``driver.hardware_config['num_chips']`` and
            ``driver.hardware_config['num_sms_per_chip']`` (each defaulting to
            1). It cycles through chips first and advances the SM index after
            each full pass. Its counter starts at 0 on every call of this
            function.

    Returns:
        The block output ``x + attn_out + ff2``.

    Raises:
        ValueError: If ``driver`` is None, or if the default scheduler is used
            and the configured chip or SM count is less than 1.
    """
    # Fail fast with a clear message instead of an AttributeError on None.
    if driver is None:
        raise ValueError("transformer_block requires a driver (VirtualGPUDriver instance)")

    if scheduler is None:
        hw_config = driver.hardware_config
        num_chips = hw_config.get('num_chips', 1)
        num_sms = hw_config.get('num_sms_per_chip', 1)
        # Validate up front: a zero count would otherwise surface later as a
        # ZeroDivisionError from the modulo inside the scheduler.
        if num_chips < 1 or num_sms < 1:
            raise ValueError(
                f"hardware_config needs num_chips >= 1 and num_sms_per_chip >= 1, "
                f"got num_chips={num_chips}, num_sms_per_chip={num_sms}"
            )
        next_op = 0  # closure state replacing the old mutable-default counter

        def scheduler(op_name, chips=num_chips, sms=num_sms):
            """Round-robin placement: walk chips first, then advance the SM."""
            nonlocal next_op
            idx = next_op
            next_op += 1
            return idx % chips, (idx // chips) % sms

    # --- Attention sub-layer (pre-LN) ---
    chip_id, sm_id = scheduler('layernorm1')
    x_norm1 = driver.layernorm(x, weights['ln1.weight'], weights['ln1.bias'], chip_id=chip_id, sm_id=sm_id)

    chip_id, sm_id = scheduler('multihead_attention')
    # The scheduler is passed down too, so multihead_attention may draw further
    # placements from it. Its second return value is discarded.
    attn_out, _ = multihead_attention(
        x_norm1,
        weights['attn.q_proj.weight'],
        weights['attn.k_proj.weight'],
        weights['attn.v_proj.weight'],
        weights['attn.out_proj.weight'],
        num_heads,
        mask,
        driver=driver,
        chip_id=chip_id,
        sm_id=sm_id,
        scheduler=scheduler,
    )
    x2 = x + attn_out  # residual 1

    # --- Feed-forward sub-layer (pre-LN) ---
    chip_id, sm_id = scheduler('layernorm2')
    x_norm2 = driver.layernorm(x2, weights['ln2.weight'], weights['ln2.bias'], chip_id=chip_id, sm_id=sm_id)

    chip_id, sm_id = scheduler('ff1')
    ff1 = driver.matmul(x_norm2, weights['ff1.weight'], chip_id=chip_id, sm_id=sm_id) + weights['ff1.bias']

    chip_id, sm_id = scheduler('gelu')
    ff1 = driver.gelu(ff1, chip_id=chip_id, sm_id=sm_id)

    chip_id, sm_id = scheduler('ff2')
    ff2 = driver.matmul(ff1, weights['ff2.weight'], chip_id=chip_id, sm_id=sm_id) + weights['ff2.bias']

    return x2 + ff2  # residual 2