|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import math |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from dataclasses import dataclass |
|
|
from typing import Dict, Optional, Tuple |
|
|
|
|
|
|
|
|
# Column indices of the stacked [B, T, 9] feature/target tensor.
# The same order is used by dict_to_tensor() and OptimizedMultiTaskLoss.
F_SYLLABLE = 0      # syllable id (semantic stream; vocab = syllable_vocab)
F_ONSET = 1         # syllable onset id (vocab = onset_vocab)
F_NUCLEUS = 2       # syllable nucleus id (vocab = nucleus_vocab)
F_CODA = 3          # syllable coda id (vocab = coda_vocab)
F_POSITION = 4      # position class (position_vocab = 4 classes)
F_CAPITALIZED = 5   # capitalization flag (0/1)
F_TOKEN_TYPE = 6    # token-type class (token_type_vocab = 4 classes)
F_SPACE_AFTER = 7   # whether a space follows the token (0/1)
F_WORD_END = 8      # whether the token ends a word (0/1)
N_FEATURES = 9      # total number of feature channels
|
|
|
|
|
@dataclass
class LunaConfig:
    """Configuration for Luna.

    Groups: component vocabulary sizes, auxiliary-feature vocabulary sizes,
    per-feature embedding widths, transformer backbone hyper-parameters,
    and optimization toggles.

    Raises:
        ValueError: if ``n_embd`` is not divisible by ``n_head`` (attention
            splits the embedding evenly across heads).
    """

    # --- Vocabulary sizes for the dual-stream syllable components ---
    syllable_vocab: int = 32768
    onset_vocab: int = 2048
    nucleus_vocab: int = 512
    coda_vocab: int = 2048

    # --- Vocabulary sizes for the small categorical side features ---
    position_vocab: int = 4
    capitalized_vocab: int = 2
    token_type_vocab: int = 4
    space_vocab: int = 2
    word_end_vocab: int = 2

    # --- Embedding widths (one per feature channel) ---
    syllable_dim: int = 256
    onset_dim: int = 64
    nucleus_dim: int = 64
    coda_dim: int = 64
    position_dim: int = 32
    cap_dim: int = 16
    type_dim: int = 16
    space_dim: int = 32
    word_end_dim: int = 16

    # --- Transformer backbone ---
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768
    dropout: float = 0.1
    max_seq_len: int = 1024

    # --- Optimization toggles ---
    fuse_output_heads: bool = True
    use_flash_attention: bool = True

    def __post_init__(self):
        # Attention computes head_dim = n_embd // n_head; an uneven split
        # would silently drop channels and produce wrong shapes downstream,
        # so fail fast at construction time instead.
        if self.n_embd % self.n_head != 0:
            raise ValueError(
                f"n_embd ({self.n_embd}) must be divisible by n_head ({self.n_head})"
            )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization.

    Scales each vector along the last dimension by the reciprocal of its
    RMS (no mean subtraction, no bias), then applies a learned per-channel
    gain initialized to ones.
    """

    __constants__ = ['eps']

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # rsqrt fuses 1/sqrt into a single op; eps guards against division
        # by zero for all-zero vectors.
        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_sq + self.eps)
        return x * inv_rms * self.weight
|
|
|
|
|
class RotaryEmbedding(nn.Module):
    """Rotary positional embedding (RoPE) with a lazily grown cos/sin cache."""

    def __init__(self, dim: int, max_seq_len: int = 2048):
        super().__init__()
        # Standard RoPE spectrum: one frequency per (even, odd) channel pair.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._build_cache(max_seq_len)

    def _build_cache(self, seq_len: int):
        # Angle table of shape [seq_len, dim // 2]: positions x frequencies.
        positions = torch.arange(seq_len, device=self.inv_freq.device)
        angles = torch.outer(positions, self.inv_freq)
        # Re-registering replaces any previous (non-persistent) cache.
        self.register_buffer("cos_cached", angles.cos(), persistent=False)
        self.register_buffer("sin_cached", angles.sin(), persistent=False)

    def forward(self, seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]:
        # Grow the cache on demand when a longer sequence arrives.
        if seq_len > self.cos_cached.shape[0]:
            self._build_cache(seq_len)
        return self.cos_cached[:seq_len], self.sin_cached[:seq_len]
|
|
|
|
|
@torch.jit.script
def apply_rotary_emb_fused(q: torch.Tensor, k: torch.Tensor,
                           cos: torch.Tensor, sin: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """JIT-compiled rotary embedding application.

    Args:
        q, k: [B, n_head, T, head_dim] query/key tensors (as produced by
            Attention.forward after the head reshape/transpose).
        cos, sin: [T, head_dim // 2] angle tables from RotaryEmbedding.

    Returns:
        Rotated (q, k), same shapes as the inputs.

    NOTE: the inputs are split into interleaved even/odd channels, but the
    rotated halves are concatenated (not re-interleaved), so the output
    channel order is a fixed permutation of the input order. Since the
    identical permutation is applied to both q and k, q.k dot products —
    and therefore attention scores — are unaffected.
    """
    # Broadcast the [T, head_dim/2] tables over the batch and head dims.
    cos = cos.unsqueeze(0).unsqueeze(0)
    sin = sin.unsqueeze(0).unsqueeze(0)

    # Interleaved pairing: channel 2i rotates against channel 2i+1.
    q_even, q_odd = q[..., 0::2], q[..., 1::2]
    k_even, k_odd = k[..., 0::2], k[..., 1::2]

    # 2-D rotation of each (even, odd) pair by the per-position angle.
    q_rot = torch.cat([q_even * cos - q_odd * sin, q_even * sin + q_odd * cos], dim=-1)
    k_rot = torch.cat([k_even * cos - k_odd * sin, k_even * sin + k_odd * cos], dim=-1)

    return q_rot, k_rot
|
|
|
|
|
class Attention(nn.Module):
    """Causal multi-head self-attention with a fused QKV projection and RoPE."""

    def __init__(self, config: LunaConfig):
        super().__init__()
        self.n_head = config.n_head
        self.head_dim = config.n_embd // config.n_head
        self.dropout_p = config.dropout

        # One matmul emits Q, K and V together; no biases anywhere.
        self.wqkv = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False)
        self.wo = nn.Linear(config.n_embd, config.n_embd, bias=False)

    def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        batch, seq, width = x.shape

        # Project once, then carve out the three streams.
        q, k, v = self.wqkv(x).split(width, dim=-1)

        # [B, T, C] -> [B, n_head, T, head_dim] for each stream.
        shape = (batch, seq, self.n_head, self.head_dim)
        q = q.view(shape).transpose(1, 2)
        k = k.view(shape).transpose(1, 2)
        v = v.view(shape).transpose(1, 2)

        # Rotary position information goes into q and k only.
        q, k = apply_rotary_emb_fused(q, k, cos, sin)

        # Let PyTorch select the fastest SDPA backend; causal mask built in.
        attn = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=None,
            dropout_p=self.dropout_p if self.training else 0.0,
            is_causal=True
        )

        # Merge heads back and project out.
        merged = attn.transpose(1, 2).contiguous().view(batch, seq, width)
        return self.wo(merged)
|
|
|
|
|
class FeedForward(nn.Module):
    """SwiGLU feed-forward: w2(silu(gate) * value), gate/value fused in one matmul."""

    def __init__(self, config: LunaConfig):
        super().__init__()
        hidden = int(4 * config.n_embd)
        # Single projection produces both the gate and value branches.
        self.w13 = nn.Linear(config.n_embd, 2 * hidden, bias=False)
        self.w2 = nn.Linear(hidden, config.n_embd, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # First half gates (through SiLU), second half carries the values.
        gate, value = self.w13(x).chunk(2, dim=-1)
        return self.dropout(self.w2(F.silu(gate) * value))
|
|
|
|
|
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: residual attention, then residual FFN."""

    def __init__(self, config: LunaConfig):
        super().__init__()
        self.norm1 = RMSNorm(config.n_embd)
        self.attn = Attention(config)
        self.norm2 = RMSNorm(config.n_embd)
        self.ffn = FeedForward(config)

    def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        # Normalize inside each residual branch (pre-norm).
        h = x + self.attn(self.norm1(x), cos, sin)
        return h + self.ffn(self.norm2(h))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class OptimizedDualStreamFusion(nn.Module):
    """Embeds the [B, T, 9] integer feature tensor into [B, T, n_embd].

    Two streams are fused per token: a "semantic" whole-syllable embedding
    and a "phonetic" embedding built from onset/nucleus/coda pieces, mixed
    by a learned per-token sigmoid gate. Five small auxiliary embeddings
    (position, capitalization, token type, spacing, word-end) are then
    concatenated before the final projection and RMSNorm.
    """

    def __init__(self, config: LunaConfig):
        super().__init__()
        self.config = config

        # Semantic stream: one embedding per whole syllable id.
        self.syllable_embed = nn.Embedding(config.syllable_vocab, config.syllable_dim)

        # Phonetic stream: separate tables for onset / nucleus / coda.
        self.onset_embed = nn.Embedding(config.onset_vocab, config.onset_dim)
        self.nucleus_embed = nn.Embedding(config.nucleus_vocab, config.nucleus_dim)
        self.coda_embed = nn.Embedding(config.coda_vocab, config.coda_dim)

        # Project the concatenated phonetic pieces up to the syllable width
        # so the two streams can be mixed element-wise.
        phonetic_dim = config.onset_dim + config.nucleus_dim + config.coda_dim
        self.phonetic_proj = nn.Linear(phonetic_dim, config.syllable_dim, bias=False)

        # Scalar gate per token: alpha in (0, 1), weighting semantic vs phonetic.
        self.gate = nn.Sequential(
            nn.Linear(config.syllable_dim * 2, config.syllable_dim // 2, bias=False),
            nn.SiLU(),
            nn.Linear(config.syllable_dim // 2, 1, bias=False),
            nn.Sigmoid()
        )

        # Small categorical side features.
        self.aux_embeddings = nn.ModuleDict({
            'position': nn.Embedding(config.position_vocab, config.position_dim),
            'cap': nn.Embedding(config.capitalized_vocab, config.cap_dim),
            'tok_type': nn.Embedding(config.token_type_vocab, config.type_dim),
            'space': nn.Embedding(config.space_vocab, config.space_dim),
            'word_end': nn.Embedding(config.word_end_vocab, config.word_end_dim),
        })

        self.aux_dim = config.position_dim + config.cap_dim + config.type_dim + config.space_dim + config.word_end_dim

        # Final projection of [fused | aux] into the model width.
        total_dim = config.syllable_dim + self.aux_dim
        self.output_proj = nn.Linear(total_dim, config.n_embd, bias=False)
        self.output_norm = RMSNorm(config.n_embd)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """
        Args:
            features: [B, T, 9] stacked feature tensor
        Returns:
            [B, T, n_embd] embedded representation
        """
        # Unpack the nine feature channels by their F_* column indices.
        syl_ids = features[:, :, F_SYLLABLE]
        onset_ids = features[:, :, F_ONSET]
        nucleus_ids = features[:, :, F_NUCLEUS]
        coda_ids = features[:, :, F_CODA]
        pos_ids = features[:, :, F_POSITION]
        cap_ids = features[:, :, F_CAPITALIZED]
        type_ids = features[:, :, F_TOKEN_TYPE]
        space_ids = features[:, :, F_SPACE_AFTER]
        word_end_ids = features[:, :, F_WORD_END]

        # Semantic stream.
        semantic = self.syllable_embed(syl_ids)

        # Phonetic stream: embed parts, concatenate, project to syllable width.
        onset = self.onset_embed(onset_ids)
        nucleus = self.nucleus_embed(nucleus_ids)
        coda = self.coda_embed(coda_ids)
        phonetic = self.phonetic_proj(torch.cat([onset, nucleus, coda], dim=-1))

        # Learned convex combination of the two streams per token.
        gate_in = torch.cat([semantic, phonetic], dim=-1)
        alpha = self.gate(gate_in)
        fused = alpha * semantic + (1 - alpha) * phonetic

        # Auxiliary categorical features, concatenated in a fixed order.
        aux = torch.cat([
            self.aux_embeddings['position'](pos_ids),
            self.aux_embeddings['cap'](cap_ids),
            self.aux_embeddings['tok_type'](type_ids),
            self.aux_embeddings['space'](space_ids),
            self.aux_embeddings['word_end'](word_end_ids),
        ], dim=-1)

        # Project [fused | aux] into the model width and normalize.
        combined = torch.cat([fused, aux], dim=-1)
        return self.output_norm(self.output_proj(combined))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FusedOutputHeads(nn.Module):
    """All eight prediction heads as a single fused linear projection.

    One [n_embd -> sum(vocab sizes)] matmul replaces eight separate linear
    layers; the fused output is sliced back into named per-head logits.
    """

    def __init__(self, config: LunaConfig):
        super().__init__()

        # Per-head vocabulary sizes; insertion order fixes the slice order.
        self.head_sizes = {
            'syllable': config.syllable_vocab,
            'onset': config.onset_vocab,
            'nucleus': config.nucleus_vocab,
            'coda': config.coda_vocab,
            'position': config.position_vocab,
            'is_capitalized': config.capitalized_vocab,
            'token_type': config.token_type_vocab,
            'has_space_after': config.space_vocab,
        }

        self.head_names = list(self.head_sizes.keys())
        self.total_output = sum(self.head_sizes.values())

        # Single projection covering every head's vocabulary.
        self.fused_head = nn.Linear(config.n_embd, self.total_output, bias=False)

        # Widths for splitting the fused output back per head.
        self.split_sizes = [self.head_sizes[name] for name in self.head_names]

        # Buffer copy of the split sizes (kept for state_dict compatibility;
        # not read in forward).
        self.register_buffer('_split_sizes_tensor', torch.tensor(self.split_sizes))

    def forward(self, h: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Args:
            h: [B, T, n_embd]
        Returns:
            Dict of logits for each head
        """
        # One big matmul, then slice the last dim back into per-head logits.
        fused = self.fused_head(h)
        per_head = fused.split(self.split_sizes, dim=-1)
        return dict(zip(self.head_names, per_head))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class OptimizedMultiTaskLoss(nn.Module):
    """Vectorized multi-task loss computation.

    Weighted sum of cross-entropy losses over all eight prediction heads,
    normalized by the sum of head weights. The syllable head additionally
    re-weights each token's loss by its position class and token-type class.
    """

    # Head names in the same order as loss_weights and target_indices.
    # Hoisted to a class constant so it is not rebuilt on every forward call.
    HEAD_NAMES = ('syllable', 'onset', 'nucleus', 'coda',
                  'position', 'is_capitalized', 'token_type', 'has_space_after')

    def __init__(self, config: LunaConfig):
        super().__init__()

        # Relative weight of each head's loss, in HEAD_NAMES order.
        self.register_buffer('loss_weights', torch.tensor([
            1.0,   # syllable
            0.2,   # onset
            0.2,   # nucleus
            0.2,   # coda
            0.3,   # position
            0.1,   # is_capitalized
            0.15,  # token_type
            0.4,   # has_space_after
        ]))

        # Normalizer so the combined loss is a weighted average.
        self.weight_sum = self.loss_weights.sum().item()

        # Per-class token weights applied to the syllable loss only.
        self.register_buffer('position_weights', torch.tensor([0.8, 1.0, 1.5, 1.2]))
        self.register_buffer('type_weights', torch.tensor([1.0, 1.2, 2.5, 1.0]))

        # Column of the [B, T, 9] target tensor per head, in HEAD_NAMES order.
        self.target_indices = [F_SYLLABLE, F_ONSET, F_NUCLEUS, F_CODA,
                               F_POSITION, F_CAPITALIZED, F_TOKEN_TYPE, F_SPACE_AFTER]

    def forward(self, logits: Dict[str, torch.Tensor], targets: torch.Tensor) -> torch.Tensor:
        """
        Args:
            logits: Dict of [B, T, V] tensors, keyed by head name
            targets: [B, T, 9] integer target tensor
        Returns:
            Scalar combined loss (weighted average over heads).
        """
        total_loss = 0.0

        # Class ids used to re-weight the syllable loss per token.
        pos_targets = targets[:, :, F_POSITION]
        type_targets = targets[:, :, F_TOKEN_TYPE]

        for i, name in enumerate(self.HEAD_NAMES):
            logit = logits[name]
            target = targets[:, :, self.target_indices[i]]
            weight = self.loss_weights[i]

            if name == 'syllable':
                # Unreduced CE so each token can carry its own weight.
                B, T, V = logit.shape
                per_token = F.cross_entropy(
                    logit.view(-1, V), target.view(-1), reduction='none'
                ).view(B, T)

                pos_w = self.position_weights[pos_targets]
                type_w = self.type_weights[type_targets]
                head_loss = (per_token * pos_w * type_w).mean()
            else:
                head_loss = F.cross_entropy(
                    logit.view(-1, logit.size(-1)), target.view(-1)
                )

            total_loss = total_loss + weight * head_loss

        return total_loss / self.weight_sum
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Luna(nn.Module):
    """Syllable-feature causal language model.

    Pipeline: dual-stream feature embedding -> dropout -> stack of
    rotary-attention pre-norm transformer blocks -> final RMSNorm ->
    (optionally fused) per-feature output heads. Computes the multi-task
    loss when targets are provided.
    """

    def __init__(self, config: LunaConfig):
        super().__init__()
        self.config = config

        # [B, T, 9] integer features -> [B, T, n_embd].
        self.embedding = OptimizedDualStreamFusion(config)

        # One rotary table shared by all layers, sized per attention head.
        self.rotary = RotaryEmbedding(config.n_embd // config.n_head, config.max_seq_len)
        self.layers = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layer)])
        self.norm = RMSNorm(config.n_embd)

        # Output heads: single fused linear, or one linear per feature head.
        if config.fuse_output_heads:
            self.heads = FusedOutputHeads(config)
        else:
            self.heads = nn.ModuleDict({
                'syllable': nn.Linear(config.n_embd, config.syllable_vocab, bias=False),
                'onset': nn.Linear(config.n_embd, config.onset_vocab, bias=False),
                'nucleus': nn.Linear(config.n_embd, config.nucleus_vocab, bias=False),
                'coda': nn.Linear(config.n_embd, config.coda_vocab, bias=False),
                'position': nn.Linear(config.n_embd, config.position_vocab, bias=False),
                'is_capitalized': nn.Linear(config.n_embd, config.capitalized_vocab, bias=False),
                'token_type': nn.Linear(config.n_embd, config.token_type_vocab, bias=False),
                'has_space_after': nn.Linear(config.n_embd, config.space_vocab, bias=False),
            })

        self.dropout = nn.Dropout(config.dropout)
        self.loss_fn = OptimizedMultiTaskLoss(config)

        # Recursive init over all submodules, then a one-time summary print.
        self.apply(self._init_weights)
        self._print_info()

    def _init_weights(self, module):
        # N(0, 0.02) for linear/embedding weights, zeros for biases.
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def _print_info(self):
        # Parameter-count summary printed once at construction time.
        n_params = sum(p.numel() for p in self.parameters())
        embed_params = sum(p.numel() for p in self.embedding.parameters())

        if isinstance(self.heads, FusedOutputHeads):
            head_params = self.heads.fused_head.weight.numel()
        else:
            head_params = sum(p.numel() for p in self.heads.parameters())

        print(f"\n{'='*60}")
        print("Luna Summary")
        print(f"{'='*60}")
        print(f"Total parameters: {n_params:,}")
        print(f"Embedding parameters: {embed_params:,}")
        print(f"Output head parameters: {head_params:,}")
        print(f"Transformer backbone: {n_params - embed_params - head_params:,}")
        print(f"\nOptimizations enabled:")
        print(f" - Fused QKV projection")
        print(f" - Fused FFN gate")
        print(f" - Fused output heads: {self.config.fuse_output_heads}")
        print(f" - JIT rotary embeddings")
        print(f" - RMSNorm everywhere")
        print(f" - Vectorized loss")
        print(f"{'='*60}\n")

    def forward(
        self,
        features: torch.Tensor,
        targets: Optional[torch.Tensor] = None
    ) -> Tuple[Dict[str, torch.Tensor], Optional[torch.Tensor]]:
        """
        Args:
            features: [B, T, 9] input features
            targets: [B, T, 9] targets (optional)
        Returns:
            logits: Dict of output logits
            loss: Combined loss (if targets provided)
        """
        B, T, _ = features.shape

        # Embed the feature tensor and apply input dropout.
        h = self.embedding(features)
        h = self.dropout(h)

        # Rotary tables computed once per call, shared by all layers.
        cos, sin = self.rotary(T)
        for layer in self.layers:
            h = layer(h, cos, sin)
        h = self.norm(h)

        # Per-head logits from either the fused head or the ModuleDict.
        if isinstance(self.heads, FusedOutputHeads):
            logits = self.heads(h)
        else:
            logits = {name: head(h) for name, head in self.heads.items()}

        # Loss is only computed when training targets are supplied.
        loss = None
        if targets is not None:
            loss = self.loss_fn(logits, targets)

        return logits, loss
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def dict_to_tensor(features_dict: Dict[str, torch.Tensor]) -> torch.Tensor:
    """Stack named per-token feature tensors into one [..., 9] tensor.

    The channel order matches the F_* feature indices used by the model
    (syllable, onset, nucleus, coda, position, capitalization, token type,
    space-after, word-end).
    """
    key_order = (
        'syllable_id',
        'onset_id',
        'nucleus_id',
        'coda_id',
        'position',
        'is_capitalized',
        'token_type',
        'has_space_after',
        'is_word_end',
    )
    return torch.stack([features_dict[key] for key in key_order], dim=-1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    import time

    print("Luna - Speed Test")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_cuda = device.type == "cuda"

    def _sync():
        # torch.cuda.synchronize() raises on CUDA-less builds, but the
        # device falls back to CPU above — so only sync when on CUDA.
        if use_cuda:
            torch.cuda.synchronize()

    config = LunaConfig(
        syllable_vocab=32768,
        onset_vocab=2048,
        nucleus_vocab=512,
        coda_vocab=2048,
        max_seq_len=1024,
        fuse_output_heads=True,
    )

    model = Luna(config).to(device)

    # Random feature ids within each channel's valid range.
    B, T = 8, 1024
    features = torch.stack([
        torch.randint(0, 1000, (B, T)),  # syllable
        torch.randint(0, 100, (B, T)),   # onset
        torch.randint(0, 50, (B, T)),    # nucleus
        torch.randint(0, 100, (B, T)),   # coda
        torch.randint(0, 4, (B, T)),     # position
        torch.randint(0, 2, (B, T)),     # is_capitalized
        torch.randint(0, 4, (B, T)),     # token_type
        torch.randint(0, 2, (B, T)),     # has_space_after
        torch.randint(0, 2, (B, T)),     # is_word_end
    ], dim=-1).to(device)

    targets = features.clone()

    # Warmup. torch.cuda.amp.autocast is deprecated; torch.autocast takes
    # the device type explicitly and also supports bfloat16 on CPU.
    for _ in range(3):
        with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
            logits, loss = model(features, targets)

    _sync()

    # Timed forward + backward loop.
    n_iters = 50
    start = time.time()

    for _ in range(n_iters):
        # Clear grads each step so they don't accumulate across iterations.
        model.zero_grad(set_to_none=True)
        with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
            logits, loss = model(features, targets)
        loss.backward()

    _sync()
    elapsed = time.time() - start

    tokens_per_iter = B * T
    tok_per_sec = (n_iters * tokens_per_iter) / elapsed

    print(f"\nBenchmark Results:")
    print(f" Batch: {B} x {T} = {B*T:,} tokens")
    print(f" Iterations: {n_iters}")
    print(f" Time: {elapsed:.2f}s")
    print(f" Throughput: {tok_per_sec:,.0f} tok/s")
    print(f" Loss: {loss.item():.4f}")

    # Repeat the measurement with torch.compile.
    print("\nTesting torch.compile()...")
    compiled_model = torch.compile(model, mode="reduce-overhead")

    # Extra warmup iterations to amortize compilation.
    for _ in range(5):
        with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
            logits, loss = compiled_model(features, targets)

    _sync()

    start = time.time()
    for _ in range(n_iters):
        model.zero_grad(set_to_none=True)
        with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
            logits, loss = compiled_model(features, targets)
        loss.backward()

    _sync()
    elapsed_compiled = time.time() - start
    tok_per_sec_compiled = (n_iters * tokens_per_iter) / elapsed_compiled

    print(f"\nCompiled Results:")
    print(f" Throughput: {tok_per_sec_compiled:,.0f} tok/s")
    print(f" Speedup: {tok_per_sec_compiled/tok_per_sec:.2f}x")
    print(f"\n✓ All tests passed!")