import itertools

import numpy as np

from .multihead_attention import multihead_attention


def transformer_block(x, weights, num_heads, mask=None, driver=None, chip_id=0, sm_id=0, scheduler=None):
    """
    Pre-norm transformer block: LayerNorm -> multi-head self-attention -> residual,
    then LayerNorm -> feed-forward (GELU) -> residual.

    x: (batch, seq_len, hidden_dim)
    weights: dict with keys for all block weights
             ('ln1.*', 'attn.*_proj.weight', 'ln2.*', 'ff1.*', 'ff2.*')
    driver: VirtualGPUDriver instance that executes the individual ops
    chip_id, sm_id: default hardware location; overridden per op by the scheduler
    scheduler: callable mapping an op name to a (chip_id, sm_id) placement
               (best practice: round-robin, the default below, or load-balance;
               see the hedged load-balancing sketch after this function)

    Returns the block output, same shape as x.
    """
    if scheduler is None:
        # Default placement: round-robin ops across chips first, then across
        # SMs within a chip, driven by a shared op counter.
        chips = driver.hardware_config.get('num_chips', 1)
        sms = driver.hardware_config.get('num_sms_per_chip', 1)
        op_counter = itertools.count()

        def scheduler(op_name):  # op_name is unused here but kept for custom schedulers
            idx = next(op_counter)
            return idx % chips, (idx // chips) % sms

    # Sub-layer 1: pre-attention layer norm, placed on its own (chip, SM).
    chip_id, sm_id = scheduler('layernorm1')
    x_norm1 = driver.layernorm(x, weights['ln1.weight'], weights['ln1.bias'], chip_id=chip_id, sm_id=sm_id)

    # Multi-head self-attention over the normalized input, followed by the
    # first residual connection.
    chip_id, sm_id = scheduler('multihead_attention')
    attn_out, _ = multihead_attention(
        x_norm1,
        weights['attn.q_proj.weight'],
        weights['attn.k_proj.weight'],
        weights['attn.v_proj.weight'],
        weights['attn.out_proj.weight'],
        num_heads,
        mask,
        driver=driver,
        chip_id=chip_id,
        sm_id=sm_id,
        scheduler=scheduler,
    )
    x2 = x + attn_out

    # Sub-layer 2: pre-feed-forward layer norm.
    chip_id, sm_id = scheduler('layernorm2')
    x_norm2 = driver.layernorm(x2, weights['ln2.weight'], weights['ln2.bias'], chip_id=chip_id, sm_id=sm_id)

    # Feed-forward network: linear -> GELU -> linear. Each op gets its own
    # placement; the bias adds happen outside the driver, as plain array addition.
    chip_id, sm_id = scheduler('ff1')
    ff1 = driver.matmul(x_norm2, weights['ff1.weight'], chip_id=chip_id, sm_id=sm_id) + weights['ff1.bias']
    chip_id, sm_id = scheduler('gelu')
    ff1 = driver.gelu(ff1, chip_id=chip_id, sm_id=sm_id)
    chip_id, sm_id = scheduler('ff2')
    ff2 = driver.matmul(ff1, weights['ff2.weight'], chip_id=chip_id, sm_id=sm_id) + weights['ff2.bias']

    # Second residual connection.
    output = x2 + ff2
    return output
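

# Hedged sketch of the load-balancing alternative the docstring mentions.
# It assumes a hypothetical driver accessor `driver.sm_queue_depth(chip_id, sm_id)`
# (not part of the code above) that reports the pending-work depth of one SM;
# each op is then dispatched to the least-loaded (chip, SM) pair.
def make_load_balancing_scheduler(driver):
    chips = driver.hardware_config.get('num_chips', 1)
    sms = driver.hardware_config.get('num_sms_per_chip', 1)

    def scheduler(op_name):
        # Scan every (chip, sm) pair and pick the shallowest queue.
        return min(
            ((c, s) for c in range(chips) for s in range(sms)),
            key=lambda loc: driver.sm_queue_depth(*loc),  # hypothetical accessor
        )

    return scheduler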
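

# Minimal usage sketch. The VirtualGPUDriver import path, its constructor
# arguments, and the weight shapes below are assumptions for illustration,
# not the verified API; adapt them to the actual driver and model config.
if __name__ == "__main__":
    from virtual_gpu_driver import VirtualGPUDriver  # assumed module name

    batch, seq_len, hidden, ff = 2, 16, 64, 256
    rng = np.random.default_rng(0)

    def dense(m, n):
        # Small random weights, shaped for the x @ W convention used above.
        return rng.standard_normal((m, n)) * 0.02

    weights = {
        'ln1.weight': np.ones(hidden), 'ln1.bias': np.zeros(hidden),
        'attn.q_proj.weight': dense(hidden, hidden),
        'attn.k_proj.weight': dense(hidden, hidden),
        'attn.v_proj.weight': dense(hidden, hidden),
        'attn.out_proj.weight': dense(hidden, hidden),
        'ln2.weight': np.ones(hidden), 'ln2.bias': np.zeros(hidden),
        'ff1.weight': dense(hidden, ff), 'ff1.bias': np.zeros(ff),
        'ff2.weight': dense(ff, hidden), 'ff2.bias': np.zeros(hidden),
    }

    driver = VirtualGPUDriver(num_chips=2, num_sms_per_chip=4)  # assumed signature
    x = rng.standard_normal((batch, seq_len, hidden))
    out = transformer_block(x, weights, num_heads=4, driver=driver)
    print(out.shape)  # expected: (2, 16, 64)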