# src/sllm/model.py
from __future__ import annotations
import math
from dataclasses import asdict
import torch
import torch.nn.functional as F
from torch import nn
from sllm.config import ModelConfig
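

# RMSNorm: root-mean-square layer normalization (scale-only, no mean
# subtraction and no bias), as used in LLaMA-style decoders.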
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float) -> None:
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
variance = hidden_states.pow(2).mean(dim=-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
return self.weight * hidden_states
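

# Helper for rotary position embeddings: splits the last dimension in half
# and rotates it, so that (x * cos) + (rotate_half(x) * sin) applies RoPE.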
def rotate_half(x: torch.Tensor) -> torch.Tensor:
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
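

# Precomputes cos/sin tables for rotary position embeddings (RoPE) up to
# max_seq_len; forward() applies the rotation at the given position_ids.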
class RotaryEmbedding(nn.Module):
def __init__(self, dim: int, max_seq_len: int, theta: float) -> None:
super().__init__()
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
positions = torch.arange(max_seq_len, dtype=torch.float32)
freqs = torch.outer(positions, inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos(), persistent=False)
self.register_buffer("sin_cached", emb.sin(), persistent=False)
def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor:
cos = self.cos_cached[position_ids].unsqueeze(1).to(dtype=x.dtype, device=x.device)
sin = self.sin_cached[position_ids].unsqueeze(1).to(dtype=x.dtype, device=x.device)
return (x * cos) + (rotate_half(x) * sin)
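

# Multi-head causal self-attention with RoPE, backed by
# F.scaled_dot_product_attention. A padding mask is merged with the causal
# mask only when it actually removes keys; otherwise the fused
# is_causal=True path is used.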
class CausalSelfAttention(nn.Module):
def __init__(self, config: ModelConfig) -> None:
super().__init__()
if config.d_model % config.n_heads != 0:
raise ValueError("d_model must be divisible by n_heads.")
self.n_heads = config.n_heads
self.head_dim = config.d_model // config.n_heads
self.scale = self.head_dim ** -0.5
self.q_proj = nn.Linear(config.d_model, config.d_model, bias=config.bias)
self.k_proj = nn.Linear(config.d_model, config.d_model, bias=config.bias)
self.v_proj = nn.Linear(config.d_model, config.d_model, bias=config.bias)
self.o_proj = nn.Linear(config.d_model, config.d_model, bias=config.bias)
self.rotary = RotaryEmbedding(self.head_dim, config.max_seq_len, config.rope_theta)
self.dropout = config.dropout
def _shape(self, x: torch.Tensor) -> torch.Tensor:
batch_size, seq_len, _ = x.shape
return x.view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
def forward(
self,
hidden_states: torch.Tensor,
position_ids: torch.Tensor,
attention_mask: torch.Tensor | None = None,
) -> torch.Tensor:
query = self._shape(self.q_proj(hidden_states))
key = self._shape(self.k_proj(hidden_states))
value = self._shape(self.v_proj(hidden_states))
query = self.rotary(query, position_ids)
key = self.rotary(key, position_ids)
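        # Default to SDPA's fused causal kernel; build an explicit boolean
        # mask only when the padding mask masks out at least one position.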
attn_mask = None
is_causal = True
if attention_mask is not None:
key_padding_mask = attention_mask[:, None, None, :].to(dtype=torch.bool, device=query.device)
if not torch.all(key_padding_mask):
seq_len = query.size(-2)
causal_mask = torch.ones(
(1, 1, seq_len, seq_len),
dtype=torch.bool,
device=query.device,
).tril()
attn_mask = causal_mask & key_padding_mask
is_causal = False
attn_output = F.scaled_dot_product_attention(
query,
key,
value,
attn_mask=attn_mask,
dropout_p=self.dropout if self.training else 0.0,
is_causal=is_causal,
scale=self.scale,
)
attn_output = attn_output.transpose(1, 2).contiguous().view(hidden_states.shape)
return self.o_proj(attn_output)
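

# SwiGLU feed-forward block: down_proj(silu(gate_proj(x)) * up_proj(x)).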
class SwiGLU(nn.Module):
def __init__(self, config: ModelConfig) -> None:
super().__init__()
self.gate_proj = nn.Linear(config.d_model, config.ffn_hidden_dim, bias=config.bias)
self.up_proj = nn.Linear(config.d_model, config.ffn_hidden_dim, bias=config.bias)
self.down_proj = nn.Linear(config.ffn_hidden_dim, config.d_model, bias=config.bias)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return self.down_proj(F.silu(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))
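

# Pre-norm transformer block: RMSNorm -> attention -> residual,
# then RMSNorm -> SwiGLU MLP -> residual.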
class TransformerBlock(nn.Module):
def __init__(self, config: ModelConfig) -> None:
super().__init__()
self.input_norm = RMSNorm(config.d_model, config.rms_norm_eps)
self.attention = CausalSelfAttention(config)
self.post_attn_norm = RMSNorm(config.d_model, config.rms_norm_eps)
self.mlp = SwiGLU(config)
def forward(
self,
hidden_states: torch.Tensor,
position_ids: torch.Tensor,
attention_mask: torch.Tensor | None = None,
) -> torch.Tensor:
hidden_states = hidden_states + self.attention(
self.input_norm(hidden_states),
position_ids=position_ids,
attention_mask=attention_mask,
)
hidden_states = hidden_states + self.mlp(self.post_attn_norm(hidden_states))
return hidden_states
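

# Decoder-only causal language model: token embeddings, a stack of
# TransformerBlocks, a final RMSNorm, and a linear LM head (optionally
# weight-tied to the token embeddings).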
class SLLMForCausalLM(nn.Module):
def __init__(self, config: ModelConfig) -> None:
super().__init__()
self.config = config
self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
self.layers = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)])
self.norm = RMSNorm(config.d_model, config.rms_norm_eps)
self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=True)
if config.tie_word_embeddings:
self.lm_head.weight = self.embed_tokens.weight
self.apply(self._init_weights)
def _init_weights(self, module: nn.Module) -> None:
if isinstance(module, nn.Linear):
nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
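
    # Note: labels are used as-is (no internal shift) with ignore_index=-100,
    # so the data pipeline is expected to supply labels already aligned with
    # the logits.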
def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor | None = None,
labels: torch.Tensor | None = None,
) -> dict[str, torch.Tensor]:
batch_size, seq_len = input_ids.shape
if seq_len > self.config.max_seq_len:
raise ValueError(
f"Input length {seq_len} exceeds model context window {self.config.max_seq_len}."
)
position_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0).expand(batch_size, -1)
hidden_states = self.embed_tokens(input_ids)
for layer in self.layers:
hidden_states = layer(hidden_states, position_ids=position_ids, attention_mask=attention_mask)
hidden_states = self.norm(hidden_states)
logits = self.lm_head(hidden_states)
output = {"logits": logits}
if labels is not None:
loss = F.cross_entropy(
logits.view(-1, logits.size(-1)),
labels.view(-1),
ignore_index=-100,
)
output["loss"] = loss
return output
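
    # Autoregressive sampling loop: greedy when temperature <= 0, otherwise
    # temperature scaling with optional top-k filtering. Re-runs the full
    # context window each step; there is no KV cache.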
@torch.no_grad()
def generate(
self,
input_ids: torch.Tensor,
max_new_tokens: int,
temperature: float = 1.0,
top_k: int | None = 50,
eos_token_id: int | None = None,
) -> torch.Tensor:
generated = input_ids
for _ in range(max_new_tokens):
context = generated[:, -self.config.max_seq_len :]
outputs = self(context)
next_token_logits = outputs["logits"][:, -1, :]
if temperature <= 0:
next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
else:
next_token_logits = next_token_logits / temperature
if top_k is not None and top_k > 0:
top_k = min(top_k, next_token_logits.size(-1))
values, _ = torch.topk(next_token_logits, top_k)
cutoff = values[:, [-1]]
next_token_logits = next_token_logits.masked_fill(next_token_logits < cutoff, float("-inf"))
probs = F.softmax(next_token_logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
generated = torch.cat([generated, next_token], dim=1)
if eos_token_id is not None and torch.all(next_token.squeeze(-1) == eos_token_id):
break
return generated
def export_config(self) -> dict:
return asdict(self.config)
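

# --- Usage sketch (not part of the original module) --------------------------
# A minimal smoke test, assuming ModelConfig is a dataclass whose required
# fields are exactly the ones this module reads; any additional required
# fields would have to be supplied. All values below are illustrative only.
if __name__ == "__main__":
    config = ModelConfig(
        vocab_size=32000,
        d_model=512,
        n_layers=8,
        n_heads=8,
        ffn_hidden_dim=1376,
        max_seq_len=1024,
        rope_theta=10000.0,
        rms_norm_eps=1e-6,
        dropout=0.0,
        bias=False,
        tie_word_embeddings=True,
        initializer_range=0.02,
    )
    model = SLLMForCausalLM(config)
    tokens = torch.randint(0, config.vocab_size, (2, 16))
    # Labels are passed as-is, since forward() does not shift them internally.
    out = model(tokens, labels=tokens)
    print(out["logits"].shape, out["loss"].item())
    print(model.generate(tokens[:, :4], max_new_tokens=8, temperature=0.8).shape)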