import math
from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class SmolConfig:
    """Configuration for a SmolLM2-135M-style LLaMA decoder-only model."""

    vocab_size: int = 49152
    hidden_size: int = 576
    intermediate_size: int = 1536
    num_hidden_layers: int = 30
    num_attention_heads: int = 9
    num_key_value_heads: int = 3
    max_position_embeddings: int = 8192
    rope_theta: float = 100000.0
    rms_norm_eps: float = 1e-5
    attention_bias: bool = False
    mlp_bias: bool = False
    dtype: torch.dtype = torch.bfloat16

    @property
    def head_dim(self) -> int:
        return self.hidden_size // self.num_attention_heads

    @classmethod
    def from_hf(cls, hf_config) -> "SmolConfig":
        """
        Build this config from a transformers LlamaConfig, the config class
        used by the HuggingFace SmolLM2-135M checkpoint.

        Example:
            from transformers import AutoConfig
            hf = AutoConfig.from_pretrained("HuggingFaceTB/SmolLM2-135M")
            cfg = SmolConfig.from_hf(hf)

        The resulting config can then be passed to SmolLM2(cfg).
        """
        return cls(
            vocab_size=hf_config.vocab_size,
            hidden_size=hf_config.hidden_size,
            intermediate_size=hf_config.intermediate_size,
            num_hidden_layers=hf_config.num_hidden_layers,
            num_attention_heads=hf_config.num_attention_heads,
            num_key_value_heads=getattr(
                hf_config, "num_key_value_heads", hf_config.num_attention_heads
            ),
            max_position_embeddings=hf_config.max_position_embeddings,
            rope_theta=getattr(hf_config, "rope_theta", 10000.0),
            rms_norm_eps=hf_config.rms_norm_eps,
            attention_bias=getattr(hf_config, "attention_bias", False),
            mlp_bias=getattr(hf_config, "mlp_bias", False),
            dtype=torch.bfloat16,
        )


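# Note on the defaults above: head_dim = 576 // 9 = 64, and the 9 query heads
# share 3 key/value heads, i.e. grouped-query attention with 3 query heads per
# KV head. The numeric defaults are taken as given here; from_hf() will pull
# the exact values from the published checkpoint if they differ upstream.

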
class RMSNorm(nn.Module):
    """
    Root Mean Square Layer Normalization (RMSNorm), used in LLaMA / SmolLM2
    in place of LayerNorm:

        y = x / sqrt(mean(x^2) + eps) * weight
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Mean of squares over the feature dimension, then rescale.
        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(mean_sq + self.eps)
        return self.weight * x


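# Worked example for RMSNorm (illustration only): for a feature vector
# x = [3.0, 4.0], mean(x^2) = 12.5 and rsqrt(12.5) ≈ 0.2828, so the normalized
# vector is roughly [0.85, 1.13] before the learned per-channel weight.

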
def rope_freqs(head_dim: int, base: float, device, dtype):
    """
    Compute the inverse frequencies for RoPE: base^(-2i/head_dim) for
    i = 0 .. head_dim/2 - 1.
    """
    half_dim = head_dim // 2
    freq_seq = torch.arange(half_dim, device=device, dtype=dtype)
    inv_freq = 1.0 / (base ** (freq_seq / half_dim))
    return inv_freq


def build_rope_cache(
    seq_len: int,
    head_dim: int,
    base: float,
    device,
    dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Build cosine and sine caches for RoPE.

    Returns:
        cos: (1, 1, seq_len, head_dim/2)
        sin: (1, 1, seq_len, head_dim/2)
    """
    inv_freq = rope_freqs(head_dim, base, device, dtype)

    # The rotation angle for position t and frequency i is t * inv_freq[i].
    t = torch.arange(seq_len, device=device, dtype=dtype)
    freqs = torch.outer(t, inv_freq)
    cos = freqs.cos()[None, None, :, :]
    sin = freqs.sin()[None, None, :, :]
    return cos, sin


def apply_rope(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> torch.Tensor:
    """
    Apply RoPE to x of shape (B, n_heads, T, head_dim).

    The head dimension is split into two halves (x1, x2) and each pair
    (x1[i], x2[i]) is rotated by the position-dependent angle; cos and sin
    broadcast to (..., T, head_dim/2).
    """
    _, _, t, d = x.shape
    half = d // 2

    x1 = x[..., :half]
    x2 = x[..., half:]

    # Truncate the caches to the current sequence length.
    cos_t = cos[..., :t, :]
    sin_t = sin[..., :t, :]

    x1_rot = x1 * cos_t - x2 * sin_t
    x2_rot = x1 * sin_t + x2 * cos_t

    return torch.cat([x1_rot, x2_rot], dim=-1)


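# Minimal usage sketch for the two helpers above (illustrative shapes only,
# not executed at import time):
#
#   cos, sin = build_rope_cache(seq_len=16, head_dim=64, base=100000.0,
#                               device="cpu", dtype=torch.float32)
#   q = torch.randn(2, 9, 16, 64)        # (batch, heads, positions, head_dim)
#   q_rot = apply_rope(q, cos, sin)      # same shape, positions 0..15 encoded
#
# Relative-position property: the dot product between a rotated query at
# position m and a rotated key at position n depends only on (m - n).

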
class MultiHeadSelfAttention(nn.Module):
    """
    LLaMA / SmolLM2-style attention with:
      - Q heads   = num_attention_heads
      - K/V heads = num_key_value_heads (GQA/MQA)
      - RoPE applied to Q and K
      - Causal masking

    `attention_mask`, if given, is expected to be an additive mask
    (0.0 for positions to keep, a large negative value for positions to drop)
    that is broadcastable to the attention scores.
    """

    def __init__(self, config: SmolConfig):
        super().__init__()
        self.config = config
        self.n_heads = config.num_attention_heads
        self.n_kv_heads = config.num_key_value_heads
        self.head_dim = config.head_dim
        self.hidden_size = config.hidden_size

        assert self.hidden_size == self.n_heads * self.head_dim

        self.q_proj = nn.Linear(
            self.hidden_size,
            self.n_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.k_proj = nn.Linear(
            self.hidden_size,
            self.n_kv_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.v_proj = nn.Linear(
            self.hidden_size,
            self.n_kv_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.o_proj = nn.Linear(
            self.n_heads * self.head_dim,
            self.hidden_size,
            bias=config.attention_bias,
        )

    def forward(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        B, T, C = x.shape

        # Project and reshape to (B, heads, T, head_dim).
        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)

        # Rotary position embedding on queries and keys. During cached decoding
        # the caller passes cos/sin already sliced to the new positions.
        q = apply_rope(q, cos, sin)
        k = apply_rope(k, cos, sin)

        # Prepend cached keys/values along the sequence dimension.
        if past_key_value is not None:
            past_k, past_v = past_key_value
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)
            seq_len = k.shape[2]
        else:
            seq_len = T

        # Cache the un-expanded (n_kv_heads) tensors to keep the cache small.
        k_cache = k
        v_cache = v

        # GQA: replicate each KV head so it lines up with its group of Q heads.
        if self.n_kv_heads != self.n_heads:
            repeat_factor = self.n_heads // self.n_kv_heads
            k = k.repeat_interleave(repeat_factor, dim=1)
            v = v.repeat_interleave(repeat_factor, dim=1)

        # Scaled dot-product attention scores: (B, n_heads, T, seq_len).
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if past_key_value is None:
            # Full (prefill) pass: standard causal mask over the T x T scores.
            causal_mask = torch.full(
                (T, T), float("-inf"), device=x.device, dtype=x.dtype
            ).triu(1)
            scores = scores + causal_mask.unsqueeze(0).unsqueeze(0)
        else:
            # Cached decoding: the new token(s) are appended after the cache, so
            # each new query may attend to every cached position and no causal
            # mask is needed for the single-token (T == 1) decode step.
            pass

        if attention_mask is not None:
            # Broadcast a (B, S) mask to (B, 1, 1, S).
            if attention_mask.dim() == 2:
                attention_mask = attention_mask[:, None, None, :]

            # If the mask only covers the current input, extend it over the
            # cached positions (which are assumed valid, i.e. additive 0).
            if attention_mask.shape[-1] != seq_len:
                if past_key_value is not None:
                    past_len = past_k.shape[2]
                    extended_mask = torch.zeros(
                        B, 1, 1, seq_len,
                        device=attention_mask.device, dtype=attention_mask.dtype,
                    )
                    extended_mask[..., past_len:] = attention_mask[..., -T:]
                    attention_mask = extended_mask
            scores = scores + attention_mask

        probs = F.softmax(scores, dim=-1)

        out = torch.matmul(probs, v)

        # Back to (B, T, hidden_size) and project out.
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        out = self.o_proj(out)

        present_key_value = None
        if use_cache:
            present_key_value = (k_cache, v_cache)

        return out, present_key_value


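# Cache and shape note for the attention module above (values assume the
# default config): each KV cache entry has shape (B, 3, T_total, 64), i.e. the
# un-expanded num_key_value_heads. With repeat_factor = 9 // 3 = 3, a cached
# key of shape (B, 3, T, 64) is expanded to (B, 9, T, 64) only at score time,
# trading a little compute for a 3x smaller cache.

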
class SmolMLP(nn.Module):
    """
    SwiGLU MLP:
        z = fc1(x) -> split into (x1, x2)
        out = fc2(SiLU(x1) * x2)

    The gate and up projections are fused into a single fc1 of width
    2 * intermediate_size.
    """

    def __init__(self, config: SmolConfig):
        super().__init__()
        self.fc1 = nn.Linear(
            config.hidden_size,
            2 * config.intermediate_size,
            bias=config.mlp_bias,
        )
        self.fc2 = nn.Linear(
            config.intermediate_size,
            config.hidden_size,
            bias=config.mlp_bias,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc1(x)
        x1, x2 = x.chunk(2, dim=-1)
        return self.fc2(F.silu(x1) * x2)


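# Porting note (an assumption about the HF checkpoint layout; verify against
# the actual state dict): the HuggingFace LLaMA MLP keeps gate_proj, up_proj
# and down_proj as separate matrices. With the fused layout above, fc1.weight
# would correspond to torch.cat([gate_proj.weight, up_proj.weight], dim=0)
# and fc2 to down_proj, since SiLU(x1) * x2 plays the role of
# SiLU(gate_proj(x)) * up_proj(x).

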
class SmolBlock(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual."""

    def __init__(self, config: SmolConfig):
        super().__init__()
        self.attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.attn = MultiHeadSelfAttention(config)
        self.mlp_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.mlp = SmolMLP(config)

    def forward(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        attn_out, present_key_value = self.attn(
            self.attn_norm(x), cos, sin, attention_mask, past_key_value, use_cache
        )
        x = x + attn_out
        x = x + self.mlp(self.mlp_norm(x))
        return x, present_key_value


class SmolLM2(nn.Module):
    """
    SmolLM2-135M-style LLaMA decoder-only language model.

    Usage:
        cfg = SmolConfig()
        model = SmolLM2(cfg)

        input_ids: LongTensor of shape (B, T)
        logits, _ = model(input_ids)
    """

    def __init__(self, config: SmolConfig):
        super().__init__()
        self.config = config

        self.embed_tokens = nn.Embedding(
            config.vocab_size,
            config.hidden_size,
        )

        self.layers = nn.ModuleList(
            [SmolBlock(config) for _ in range(config.num_hidden_layers)]
        )

        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.lm_head = nn.Linear(
            config.hidden_size,
            config.vocab_size,
            bias=False,
        )

        # Weight tying: the output projection shares the embedding matrix,
        # matching the tied-embedding setup used by SmolLM2.
        self.lm_head.weight = self.embed_tokens.weight

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]:
        B, T = input_ids.shape

        # Work out the total sequence length (cached positions + new tokens)
        # and check it against the model's context limit.
        if past_key_values is None:
            assert T <= self.config.max_position_embeddings, (
                f"Sequence length {T} exceeds max_position_embeddings "
                f"{self.config.max_position_embeddings}"
            )
            seq_len = T
        else:
            past_len = (
                past_key_values[0][0].shape[2] if past_key_values[0] is not None else 0
            )
            seq_len = past_len + T
            assert seq_len <= self.config.max_position_embeddings, (
                f"Total sequence length {seq_len} exceeds max_position_embeddings "
                f"{self.config.max_position_embeddings}"
            )

        x = self.embed_tokens(input_ids)

        # RoPE caches covering every position up to seq_len.
        cos, sin = build_rope_cache(
            seq_len=seq_len,
            head_dim=self.config.head_dim,
            base=self.config.rope_theta,
            device=x.device,
            dtype=x.dtype,
        )

        # With a cache, only the new tokens pass through the layers, so keep
        # only the cos/sin rows for their absolute positions.
        if past_key_values is not None:
            past_len = (
                past_key_values[0][0].shape[2] if past_key_values[0] is not None else 0
            )
            cos = cos[..., past_len:, :]
            sin = sin[..., past_len:, :]

        present_key_values = [] if use_cache else None
        for i, layer in enumerate(self.layers):
            past_kv = past_key_values[i] if past_key_values is not None else None
            x, present_kv = layer(x, cos, sin, attention_mask, past_kv, use_cache)
            if use_cache:
                present_key_values.append(present_kv)

        x = self.norm(x)
        logits = self.lm_head(x)
        return logits, present_key_values

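    # Incremental decoding contract (describes the code above and below): each
    # entry of past_key_values is a per-layer (k, v) pair of shape
    # (B, num_key_value_heads, T_past, head_dim). On cached steps, pass only
    # the new token ids in input_ids together with the cache returned by the
    # previous call; generate() below follows exactly this pattern.
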
    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 100,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        eos_token_id: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Generate text autoregressively, using the KV cache for efficient inference.

        Args:
            input_ids: (B, T) input token ids
            max_new_tokens: maximum number of new tokens to generate
            temperature: sampling temperature
            top_k: top-k sampling (keep only the k highest-probability tokens)
            top_p: nucleus sampling (keep the smallest set of tokens whose
                cumulative probability exceeds top_p)
            eos_token_id: end-of-sequence token id; generation stops once every
                sequence in the batch has produced it

        Returns:
            generated_ids: (B, T + new_tokens) token ids, where
                new_tokens <= max_new_tokens if generation stops early
        """
        self.eval()

        generated_ids = input_ids.clone()
        past_key_values = None

        for _ in range(max_new_tokens):
            # Prefill on the first step, then feed only the newest token.
            if past_key_values is None:
                current_input = generated_ids
            else:
                current_input = generated_ids[:, -1:]

            logits, past_key_values = self.forward(
                input_ids=current_input,
                past_key_values=past_key_values,
                use_cache=True,
            )

            # Only the last position's logits are needed for the next token.
            next_token_logits = logits[:, -1, :] / temperature

            # Top-k filtering: drop everything below the k-th largest logit.
            if top_k is not None:
                indices_to_remove = (
                    next_token_logits
                    < torch.topk(next_token_logits, top_k)[0][..., -1, None]
                )
                next_token_logits[indices_to_remove] = float("-inf")

            # Top-p (nucleus) filtering: drop the tail of the sorted distribution.
            if top_p is not None:
                sorted_logits, sorted_indices = torch.sort(
                    next_token_logits, descending=True
                )
                cumulative_probs = torch.cumsum(
                    F.softmax(sorted_logits, dim=-1), dim=-1
                )

                # Shift right so the first token crossing the threshold is kept.
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0

                indices_to_remove = sorted_indices_to_remove.scatter(
                    1, sorted_indices, sorted_indices_to_remove
                )
                next_token_logits[indices_to_remove] = float("-inf")

            # Sample the next token and append it.
            probs = F.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            generated_ids = torch.cat([generated_ids, next_token], dim=1)

            # Stop early once every sequence in the batch has emitted EOS.
            if eos_token_id is not None and (next_token == eos_token_id).all():
                break

        return generated_ids


if __name__ == "__main__":
    # Smoke test: push random token ids through an untrained model.
    cfg = SmolConfig()
    model = SmolLM2(cfg)

    B, T = 2, 16
    x = torch.randint(0, cfg.vocab_size, (B, T))

    with torch.no_grad():
        logits, _ = model(x)

    print("Input shape :", x.shape)
    print("Logits shape:", logits.shape)
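
    # Extra check (a minimal sketch, not a benchmark): exercise generate() so
    # the cached decode path runs end to end. The sampled tokens are noise from
    # an untrained model; only the output shape is meaningful here.
    prompt = torch.randint(0, cfg.vocab_size, (1, 8))
    out_ids = model.generate(prompt, max_new_tokens=4, top_k=50)
    print("Generated shape:", out_ids.shape)  # expected: torch.Size([1, 12])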