import math
from contextlib import nullcontext
from dataclasses import dataclass
from typing import Tuple, Optional, Literal, Union

import tiktoken
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from torch.utils.data import Dataset, DataLoader

from kernel import act_quant, weight_dequant, fp8_gemm


@dataclass
class ModelArgs:
    # Runtime limits and numeric formats.
    max_batch_size: int = 8
    max_seq_len: int = 2048
    dtype: Literal["bf16", "fp8"] = "bf16"
    scale_fmt: Optional[str] = None

    # Core transformer dimensions.
    vocab_size: int = 102400
    dim: int = 1024
    inter_dim: int = 4096
    moe_inter_dim: int = 1024
    n_layers: int = 20
    n_dense_layers: int = 3
    n_heads: int = 12

    # Mixture-of-experts routing.
    n_routed_experts: int = 6
    n_shared_experts: int = 1
    n_activated_experts: int = 2
    route_scale: float = 1.
    use_routing_bias: bool = True

    # Multi-head latent attention (MLA).
    q_lora_rank: int = 0
    kv_lora_rank: int = 512
    qk_nope_head_dim: int = 128
    qk_rope_head_dim: int = 64
    v_head_dim: int = 128

    # RoPE and YaRN long-context extension.
    original_seq_len: int = 4096
    rope_theta: float = 10000.0
    rope_factor: float = 40
    beta_fast: int = 32
    beta_slow: int = 1
    mscale: float = 1.

    tokenizer_name: str = "gpt2"


# Module-level distributed/quantization settings, read by the linear helpers below.
world_size = 1
rank = 0
block_size = 128
gemm_impl: Literal["bf16", "fp8"] = "bf16"


def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
    """Precompute the complex rotary embeddings, applying YaRN frequency
    interpolation when max_seq_len exceeds the original training length."""
    dim = args.qk_rope_head_dim
    seqlen = args.max_seq_len
    beta_fast = args.beta_fast
    beta_slow = args.beta_slow
    base = args.rope_theta
    factor = args.rope_factor

    def find_correction_dim(num_rotations, dim, base, max_seq_len):
        # Dimension index whose wavelength completes `num_rotations` over max_seq_len.
        return dim * math.log(max_seq_len / (num_rotations * 2 * math.pi)) / (2 * math.log(base))

    def find_correction_range(low_rot, high_rot, dim, base, max_seq_len):
        low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len))
        high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len))
        return max(low, 0), min(high, dim - 1)

    def linear_ramp_factor(min_val, max_val, dim):
        if min_val == max_val:
            max_val += 0.001
        linear_func = (torch.arange(dim, dtype=torch.float32) - min_val) / (max_val - min_val)
        ramp_func = torch.clamp(linear_func, 0, 1)
        return ramp_func

    freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    if seqlen > args.original_seq_len:
        low, high = find_correction_range(beta_fast, beta_slow, dim, base, args.original_seq_len)
        smooth = 1 - linear_ramp_factor(low, high, dim // 2)
        # Blend interpolated (long-context) and original frequencies per dimension.
        freqs = freqs / factor * (1 - smooth) + freqs * smooth

    t = torch.arange(seqlen)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis

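
# A minimal sanity check for the function above (illustrative only; the shapes
# follow from the definitions in this file, not from any external API):
# precompute_freqs_cis returns one complex rotation per position per frequency
# pair, i.e. shape (max_seq_len, qk_rope_head_dim // 2).
def _check_freqs_cis_shape():
    args = ModelArgs()
    freqs_cis = precompute_freqs_cis(args)
    assert freqs_cis.shape == (args.max_seq_len, args.qk_rope_head_dim // 2)
    assert freqs_cis.dtype == torch.complex64
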
def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    dtype = x.dtype
    x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2))
    freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1))
    y = torch.view_as_real(x * freqs_cis).flatten(3)
    return y.to(dtype)

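
# Hedged usage sketch: rotary embeddings act on (batch, seq, heads, head_dim)
# tensors whose last dimension is even, pairing features (0,1), (2,3), ... as
# complex components. All names below are local to this example.
def _demo_rotary():
    args = ModelArgs()
    freqs_cis = precompute_freqs_cis(args)
    q_pe = torch.randn(2, 16, args.n_heads, args.qk_rope_head_dim)
    out = apply_rotary_emb(q_pe, freqs_cis[:16])
    assert out.shape == q_pe.shape  # rotation preserves shape
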
def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, scale_fmt: Optional[str] = None) -> torch.Tensor:
    """Dispatch a matmul based on weight precision and the global gemm_impl:
    full-precision weights use F.linear directly; FP8 weights are either
    dequantized to full precision or multiplied with a blockwise FP8 GEMM."""
    if weight.element_size() > 1:
        # Full-precision weight: plain dense matmul.
        return F.linear(x, weight, bias)
    elif gemm_impl == "bf16":
        # FP8 weight, full-precision compute: dequantize with its blockwise scales.
        weight = weight_dequant(weight, weight.scale)
        return F.linear(x, weight, bias)
    else:
        # FP8 compute: quantize activations per block, then run the FP8 GEMM.
        x, scale = act_quant(x, block_size, scale_fmt)
        y = fp8_gemm(x, scale, weight, weight.scale)
        if bias is not None:
            y += bias
        return y

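
# Quick illustration of the dispatch above, exercising only the full-precision
# path (the FP8 branches require the custom `kernel` module's quantized tensors):
def _demo_linear_dispatch():
    w = torch.randn(8, 4)  # element_size() == 4, so the F.linear path is taken
    x = torch.randn(2, 4)
    y = linear(x, w)
    assert y.shape == (2, 8)
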
class Linear(nn.Module):
    dtype = torch.float32
    scale_fmt: Optional[str] = None

    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        param_dtype = dtype or Linear.dtype
        self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=param_dtype))

        if self.weight.element_size() == 1:
            # FP8 weights carry one float32 scale per (block_size x block_size)
            # tile; normal_ init is not defined for fp8 dtypes, so scales start
            # at 1 and real values are expected to come from a checkpoint.
            scale_out_features = (out_features + block_size - 1) // block_size
            scale_in_features = (in_features + block_size - 1) // block_size
            self.weight.scale = self.scale = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=torch.float32))
            nn.init.ones_(self.scale)
        else:
            nn.init.normal_(self.weight, mean=0.0, std=0.02 / math.sqrt(in_features))
            self.register_parameter("scale", None)

        if bias:
            self.bias = nn.Parameter(torch.empty(out_features, dtype=param_dtype))
            nn.init.zeros_(self.bias)
        else:
            self.register_parameter("bias", None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return linear(x, self.weight, self.bias, self.scale_fmt)


class ColumnParallelLinear(Linear):
    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
        assert out_features % world_size == 0, f"Output features must be divisible by world size (world_size={world_size})"
        self.part_out_features = out_features // world_size
        super().__init__(in_features, self.part_out_features, bias, dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Each rank computes its own slice of the output features; no reduction needed.
        y = linear(x, self.weight, self.bias, self.scale_fmt)
        return y


class RowParallelLinear(Linear):
    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
        assert in_features % world_size == 0, f"Input features must be divisible by world size (world_size={world_size})"
        self.part_in_features = in_features // world_size
        super().__init__(self.part_in_features, out_features, bias, dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = linear(x, self.weight, scale_fmt=self.scale_fmt)
        if world_size > 1:
            # Partial products over the sharded input dimension are summed across ranks.
            dist.all_reduce(y)
        if self.bias is not None:
            # Bias is added once, after the reduction, so it is not added world_size times.
            y += self.bias
        return y


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))

    def forward(self, x: torch.Tensor):
        output = F.rms_norm(x.float(), (self.dim,), self.weight, self.eps)
        return output.to(x.dtype)


class MultiHeadLatentAttention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.dim = args.dim
        self.n_heads = args.n_heads
        self.n_local_heads = args.n_heads // world_size
        self.q_lora_rank = args.q_lora_rank
        self.kv_lora_rank = args.kv_lora_rank
        self.qk_nope_head_dim = args.qk_nope_head_dim
        self.qk_rope_head_dim = args.qk_rope_head_dim
        self.qk_head_dim = args.qk_nope_head_dim + args.qk_rope_head_dim
        self.v_head_dim = args.v_head_dim

        # Query projection: direct, or low-rank (LoRA-style) when q_lora_rank > 0.
        if self.q_lora_rank == 0:
            self.wq = ColumnParallelLinear(self.dim, self.n_heads * self.qk_head_dim)
        else:
            self.wq_a = Linear(self.dim, self.q_lora_rank)
            self.q_norm = RMSNorm(self.q_lora_rank)
            self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.qk_head_dim)

        # Keys/values are compressed into a shared latent of rank kv_lora_rank,
        # plus a small decoupled rotary component of qk_rope_head_dim.
        self.wkv_a = Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)
        self.kv_norm = RMSNorm(self.kv_lora_rank)
        self.wkv_b = ColumnParallelLinear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))
        self.wo = RowParallelLinear(self.n_heads * self.v_head_dim, self.dim)
        self.softmax_scale = self.qk_head_dim ** -0.5

        if args.max_seq_len > args.original_seq_len:
            # YaRN attention-entropy correction for extended context.
            mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0
            self.softmax_scale = self.softmax_scale * mscale * mscale

        # Caches hold the compressed latent and the rotary key component, so cache
        # size scales with kv_lora_rank rather than n_heads * head_dim.
        self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank, dtype=Linear.dtype), persistent=False)
        self.register_buffer("pe_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim, dtype=Linear.dtype), persistent=False)

    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
        bsz, seqlen, _ = x.size()
        end_pos = start_pos + seqlen
        if self.q_lora_rank == 0:
            q = self.wq(x)
        else:
            q = self.wq_b(self.q_norm(self.wq_a(x)))
        q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim)
        q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        q_pe = apply_rotary_emb(q_pe, freqs_cis)
        kv = self.wkv_a(x)
        kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis)

        # "Absorbed" attention: fold the key half of wkv_b into the queries so
        # scores can be computed directly against the compressed latent cache.
        wkv_b = self.wkv_b.weight if self.wkv_b.scale is None else weight_dequant(self.wkv_b.weight, self.wkv_b.scale, block_size)
        wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)
        q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
        # Note: the caches are written detached, so gradients do not flow back
        # through cached keys/values during training.
        self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv).detach()
        self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2).detach()
        scores = (torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) +
                  torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])) * self.softmax_scale

        if mask is not None:
            scores += mask.unsqueeze(1)
        scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x)

        # Attend over the latent cache, then expand to per-head values via wkv_b.
        x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos])
        x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])
        x = self.wo(x.flatten(2))
        return x

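
# Hedged smoke test for the attention block above: a single prefill step with a
# small batch and the default dimensions. Everything here is derived from this
# file; no external API is assumed.
def _demo_mla_prefill():
    args = ModelArgs(max_batch_size=2, max_seq_len=32)
    attn = MultiHeadLatentAttention(args)
    freqs_cis = precompute_freqs_cis(args)
    x = torch.randn(2, 8, args.dim)
    mask = torch.triu(torch.full((8, 8), float("-inf")), diagonal=1)
    out = attn(x, start_pos=0, freqs_cis=freqs_cis[:8], mask=mask)
    assert out.shape == (2, 8, args.dim)
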
class Gate(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.dim = args.dim
        self.n_routed_experts = args.n_routed_experts
        self.n_activated_experts = args.n_activated_experts
        self.route_scale = args.route_scale

        self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim, dtype=Linear.dtype))
        nn.init.normal_(self.weight, mean=0.0, std=0.02 / math.sqrt(args.dim))

        # Optional per-expert bias for load balancing: it shifts which experts
        # are selected but does not enter the combine weights.
        if args.use_routing_bias:
            self.bias = nn.Parameter(torch.zeros(args.n_routed_experts, dtype=torch.float32))
        else:
            self.register_parameter("bias", None)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        scores = linear(x, self.weight)
        scores = scores.sigmoid()
        original_scores = scores

        # Bias influences selection only; combine weights are gathered from the
        # pre-bias scores below.
        if self.bias is not None:
            scores = scores + self.bias

        indices = torch.topk(scores, self.n_activated_experts, dim=-1)[1]
        weights = original_scores.gather(1, indices)

        weights = weights / weights.sum(dim=-1, keepdim=True)
        weights = weights * self.route_scale

        return weights.type_as(x), indices

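
# Illustrative check of the routing contract (names local to this sketch): for
# T tokens the gate returns T x n_activated_experts combine weights summing to
# route_scale per token, plus the matching expert indices.
def _demo_gate():
    args = ModelArgs()
    gate = Gate(args)
    x = torch.randn(5, args.dim)
    weights, indices = gate(x)
    assert weights.shape == (5, args.n_activated_experts)
    assert indices.shape == (5, args.n_activated_experts)
    assert torch.allclose(weights.sum(dim=-1), torch.full((5,), args.route_scale))
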
class Expert(nn.Module):
    def __init__(self, dim: int, inter_dim: int):
        super().__init__()
        self.w1 = Linear(dim, inter_dim, bias=False)
        self.w2 = Linear(inter_dim, dim, bias=False)
        self.w3 = Linear(dim, inter_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU: w1 gates through SiLU, w3 scales, w2 projects back down.
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class MoE(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.dim = args.dim
        self.n_routed_experts = args.n_routed_experts
        self.n_activated_experts = args.n_activated_experts
        self.active_expert_idx = None

        self.gate = Gate(args)
        self.experts = nn.ModuleList([
            Expert(args.dim, args.moe_inter_dim)
            for _ in range(args.n_routed_experts)
        ])
        # Shared experts are applied to every token, alongside the routed ones.
        self.shared_experts = MLP(args.dim, args.n_shared_experts * args.moe_inter_dim)
        self.ffn_norm = RMSNorm(args.dim)  # unused here; Block applies its own ffn_norm

        # Coefficient for weighting the load-balance loss (not applied within this module).
        self.lb_loss_coef = 0.01

    def set_active_expert(self, expert_idx: Optional[int]):
        """Freeze all but the active expert to save optimizer memory."""
        self.active_expert_idx = expert_idx

        for i, expert in enumerate(self.experts):
            requires_grad = (expert_idx is None) or (i == expert_idx)
            for param in expert.parameters():
                param.requires_grad = requires_grad

    def compute_load_balance_loss(self, router_probs, expert_indices):
        """Encourage uniform expert utilization: a Switch-style product of the
        fraction of tokens routed to each expert and its mean router probability."""
        # Fraction of top-k assignments that landed on each expert.
        tokens_per_expert = torch.zeros(self.n_routed_experts, device=router_probs.device)
        indices_flat = expert_indices.view(-1)
        ones = torch.ones_like(indices_flat, dtype=torch.float32)
        tokens_per_expert.scatter_add_(0, indices_flat, ones)
        tokens_per_expert = tokens_per_expert / (indices_flat.numel() + 1e-8)

        # Mean router probability assigned to each expert across tokens.
        router_prob_per_expert = router_probs.mean(dim=0)

        loss = torch.mean(tokens_per_expert * router_prob_per_expert) * self.n_routed_experts
        return loss

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        original_shape = x.size()
        x = x.view(-1, self.dim)

        # Route inline (rather than via Gate.forward) so the pre-bias
        # probabilities remain available for the load-balance loss. As in Gate,
        # the bias shifts expert selection only; combine weights come from the
        # unbiased probabilities.
        router_probs = linear(x, self.gate.weight).sigmoid()
        selection_scores = router_probs if self.gate.bias is None else router_probs + self.gate.bias
        indices = torch.topk(selection_scores, self.n_activated_experts, dim=-1)[1]
        weights = router_probs.gather(1, indices)

        weights = weights / (weights.sum(dim=-1, keepdim=True) + 1e-8)
        weights = weights * self.gate.route_scale

        if self.training and self.active_expert_idx is not None:
            # Sequential expert training: only the active expert contributes.
            y = torch.zeros_like(x)
            i = self.active_expert_idx

            # Tokens whose top-k selection includes the active expert.
            mask = (indices == i)
            idx = torch.where(mask.any(dim=1))[0]

            if idx.numel() > 0:
                top_positions = torch.argmax(mask[idx].int(), dim=1)
                expert_weights = weights[idx, top_positions].unsqueeze(-1)
                expert_out = self.experts[i](x[idx])
                y[idx] = expert_out * expert_weights

            lb_loss = self.compute_load_balance_loss(router_probs, indices)

            z = self.shared_experts(x)
            return (y + z).view(original_shape), lb_loss

        else:
            y = torch.zeros_like(x)
            for i in range(self.n_routed_experts):
                mask = (indices == i)
                idx = torch.where(mask.any(dim=1))[0]

                if idx.numel() == 0:
                    continue

                top_positions = torch.argmax(mask[idx].int(), dim=1)
                expert_weights = weights[idx, top_positions].unsqueeze(-1)
                expert_out = self.experts[i](x[idx])
                y[idx] += expert_out * expert_weights

            z = self.shared_experts(x)
            output = (y + z).view(original_shape)

            if self.training:
                lb_loss = self.compute_load_balance_loss(router_probs, indices)
                return output, lb_loss
            else:
                return output, None

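
# Hedged sketch of how set_active_expert() composes with training: cycle through
# experts, freezing all but one per phase. The loop shape, `model`, `batches`,
# and `steps_per_expert` are illustrative assumptions, not part of this module
# (optimizer steps omitted for brevity).
def _sequential_expert_training_sketch(model, batches, steps_per_expert=100):
    for expert_idx in range(model.args.n_routed_experts):
        model.set_active_expert(expert_idx)  # freeze the other routed experts
        for _, tokens in zip(range(steps_per_expert), batches):
            logits, lb_loss = model(tokens)  # MoE layers return an aux loss in training mode
            loss = F.cross_entropy(logits[:, :-1].flatten(0, 1), tokens[:, 1:].flatten()) + lb_loss
            loss.backward()
    model.set_active_expert(None)  # unfreeze everything afterwards
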

class MLP(nn.Module):
    def __init__(self, dim: int, inter_dim: int):
        super().__init__()
        self.fc1 = Linear(dim, inter_dim, bias=False)
        self.fc2 = Linear(dim, inter_dim, bias=False)
        self.fc3 = Linear(inter_dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc3(F.silu(self.fc1(x)) * self.fc2(x))


class Block(nn.Module):
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.attn = MultiHeadLatentAttention(args)

        # The first n_dense_layers use a dense MLP; the rest are MoE layers.
        self.ffn = MLP(args.dim, args.inter_dim) if layer_id < args.n_dense_layers else MoE(args)
        self.attn_norm = RMSNorm(args.dim)
        self.ffn_norm = RMSNorm(args.dim)

    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        x = x + self.attn(self.attn_norm(x), start_pos, freqs_cis, mask)

        # MoE layers return (output, lb_loss); dense MLP layers return a tensor.
        ffn_result = self.ffn(self.ffn_norm(x))
        if isinstance(ffn_result, tuple):
            ffn_out, lb_loss = ffn_result
        else:
            ffn_out = ffn_result
            lb_loss = None

        x = x + ffn_out
        return x, lb_loss


class ismail(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.n_layers = args.n_layers

        self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim, dtype=Linear.dtype)
        nn.init.normal_(self.tok_embeddings.weight, mean=0.0, std=0.02)

        self.layers = nn.ModuleList([Block(i, args) for i in range(args.n_layers)])
        self.norm = RMSNorm(args.dim)
        self.output = Linear(args.dim, args.vocab_size, bias=False, dtype=Linear.dtype)
        self.use_checkpointing = False

        self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=False)

    def set_active_expert(self, expert_idx: Optional[int]):
        """Set the active expert for all MoE layers (for sequential training)."""
        for layer in self.layers:
            if isinstance(layer.ffn, MoE):
                layer.ffn.set_active_expert(expert_idx)

    def forward(self, tokens: torch.Tensor, start_pos: int = 0) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        bsz, seqlen = tokens.shape
        h = self.tok_embeddings(tokens).to(Linear.dtype)
        freqs_cis = self.freqs_cis[start_pos:start_pos + seqlen]

        # A fresh sequence invalidates any cached keys/values from a previous one.
        if start_pos == 0:
            for layer in self.layers:
                if hasattr(layer.attn, 'kv_cache'):
                    layer.attn.kv_cache.zero_()
                if hasattr(layer.attn, 'pe_cache'):
                    layer.attn.pe_cache.zero_()

        # Causal mask over the new tokens, with full visibility of cached positions.
        mask = None
        if seqlen > 1:
            mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device, dtype=h.dtype)
            mask = torch.triu(mask, diagonal=1)
            mask = torch.hstack([torch.zeros((seqlen, start_pos), device=tokens.device, dtype=h.dtype), mask])

        total_lb_loss = 0.0

        for layer in self.layers:
            h, lb_loss = layer(h, start_pos, freqs_cis, mask)
            if lb_loss is not None:
                total_lb_loss += lb_loss

        h = self.norm(h)
        output = self.output(h)

        # In training mode, surface the accumulated load-balance loss with the logits.
        if self.training and total_lb_loss > 0:
            return output, total_lb_loss
        else:
            return output
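

# Minimal smoke test, runnable as `python model.py` (assuming the `kernel`
# module imported at the top is available). The scaled-down config values here
# are chosen purely for the test; they are not the training configuration.
if __name__ == "__main__":
    torch.manual_seed(0)
    args = ModelArgs(max_batch_size=2, max_seq_len=64, vocab_size=1000,
                     dim=256, inter_dim=512, moe_inter_dim=128,
                     n_layers=4, n_dense_layers=1, n_heads=4)
    model = ismail(args)
    tokens = torch.randint(0, args.vocab_size, (2, 16))

    model.train()
    logits, lb_loss = model(tokens)
    print("train logits:", tuple(logits.shape), "lb_loss:", float(lb_loss))

    model.eval()
    with torch.no_grad():
        logits = model(tokens)
    print("eval logits:", tuple(logits.shape))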