"""HybriKo Model - Hugging Face Compatible

A hybrid RNN-Attention language model optimized for Korean.
Uses a 2:1 ratio of RNN (Griffin) blocks to Attention blocks.
"""

import math
from typing import Optional, Dict, Any, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast

# Support both package-relative and standalone imports.
try:
    from .configuration_hybridko import HybriKoConfig
except ImportError:
    from configuration_hybridko import HybriKoConfig


class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization."""

    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
        return x / rms * self.weight


class GeGLU(nn.Module):
    """Gated GELU Feed-Forward Network."""

    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)
        self.w2 = nn.Linear(d_model, d_ff, bias=False)
        self.w3 = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # GeGLU(x) = W3(GELU(W1 x) * W2 x)
        return self.w3(F.gelu(self.w1(x)) * self.w2(x))


class RGLRU(nn.Module):
    """Real-Gated Linear Recurrent Unit (Griffin/LFM2 style)."""

    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.d_model = d_model
        self.eps = eps

        self.input_proj = nn.Linear(d_model, d_model * 2)
        self.gate_proj = nn.Linear(d_model, d_model * 2)
        self.a_param = nn.Parameter(torch.zeros(d_model))
        self.out_proj = nn.Linear(d_model, d_model)

        self._init_weights()

    def _init_weights(self):
        nn.init.xavier_uniform_(self.input_proj.weight)
        nn.init.xavier_uniform_(self.gate_proj.weight)
        nn.init.xavier_uniform_(self.out_proj.weight)
        nn.init.uniform_(self.a_param, -0.5, 0.5)

    def forward(
        self, x: torch.Tensor, h_prev: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        batch, seq_len, _ = x.shape

        # GLU-style input branch: half the projection gates the other half.
        input_gate = self.input_proj(x)
        x_in, x_gate = input_gate.chunk(2, dim=-1)
        x_in = x_in * torch.sigmoid(x_gate)

        # Recurrence gate r and input gate i, both in (0, 1).
        gates = self.gate_proj(x)
        r, i = gates.chunk(2, dim=-1)
        r = torch.sigmoid(r)
        i = torch.sigmoid(i)

        # Per-channel decay a_t in (0, 1); the sqrt(1 - a_t^2) factor rescales
        # the input so the hidden state's magnitude stays roughly constant.
        a_base = torch.sigmoid(F.softplus(self.a_param))
        a = a_base.unsqueeze(0).unsqueeze(0) * r
        sqrt_1_minus_a2 = torch.sqrt(torch.clamp(1 - a ** 2, min=self.eps))

        h = h_prev if h_prev is not None else torch.zeros(
            batch, self.d_model, device=x.device, dtype=x.dtype
        )

        # Sequential scan: h_t = a_t * h_{t-1} + sqrt(1 - a_t^2) * (i_t * x_t).
        outputs = []
        for t in range(seq_len):
            h = a[:, t] * h + sqrt_1_minus_a2[:, t] * (i[:, t] * x_in[:, t])
            outputs.append(h)

        h_seq = torch.stack(outputs, dim=1)
        return self.out_proj(h_seq), h
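
# Note: RGLRU.forward also supports chunked, stateful processing by feeding
# the returned final state back in as h_prev. A minimal sketch of that usage
# (the blocks below always start from a zero state, so this is optional, not
# the model's default path):
#
#     rglru = RGLRU(d_model=64)
#     x = torch.randn(2, 16, 64)           # [batch, seq_len, d_model]
#     y1, h = rglru(x[:, :8])              # first chunk
#     y2, h = rglru(x[:, 8:], h_prev=h)    # continues from the saved state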


class RotaryEmbedding(nn.Module):
    """Rotary Positional Embedding (RoPE)."""

    def __init__(self, d_head: int, max_seq_len: int = 2048):
        super().__init__()
        # max_seq_len is kept for config compatibility; the cos/sin cache
        # below grows on demand.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, d_head, 2).float() / d_head))
        self.register_buffer("inv_freq", inv_freq)
        self._cache = None

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        seq_len = x.shape[2]
        # Rebuild the cache when the sequence grows or the device changes.
        if (
            self._cache is None
            or self._cache[0].shape[2] < seq_len
            or self._cache[0].device != x.device
        ):
            # Compute angles in float32 for numerical stability, regardless
            # of the model dtype.
            t = torch.arange(seq_len, device=x.device, dtype=torch.float32)
            freqs = torch.outer(t, self.inv_freq.to(x.device))
            emb = torch.cat([freqs, freqs], dim=-1)
            self._cache = (
                emb.cos().unsqueeze(0).unsqueeze(0),
                emb.sin().unsqueeze(0).unsqueeze(0),
            )
        cos, sin = self._cache
        return (
            cos[:, :, :seq_len].to(dtype=x.dtype),
            sin[:, :, :seq_len].to(dtype=x.dtype),
        )


def apply_rope(
    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> torch.Tensor:
    """Apply Rotary Positional Embedding to input tensor."""
    # Split the head dimension in half and rotate each (x1, x2) pair by the
    # position-dependent angle: a 2-D rotation applied channel-wise.
    d_half = x.shape[-1] // 2
    x1, x2 = x[..., :d_half], x[..., d_half:]
    cos = cos[..., :d_half]
    sin = sin[..., :d_half]
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
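
# Because q and k are each rotated by their absolute positions before the dot
# product, the attention score between positions m and n depends only on the
# offset m - n; that relative-position property is the point of RoPE.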


class GQAttention(nn.Module):
    """Grouped Query Attention with RoPE."""

    def __init__(
        self,
        d_model: int,
        n_heads: int = 8,
        n_kv_heads: int = 2,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.d_head = d_model // n_heads
        self.scale = 1.0 / math.sqrt(self.d_head)
        self.dropout = dropout

        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, n_kv_heads * self.d_head, bias=False)
        self.v_proj = nn.Linear(d_model, n_kv_heads * self.d_head, bias=False)
        self.o_proj = nn.Linear(d_model, d_model, bias=False)
        self.rope = RotaryEmbedding(self.d_head)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, L, _ = x.shape

        # Project to per-head queries, keys, and values.
        q = self.q_proj(x).view(B, L, self.n_heads, self.d_head)
        k = self.k_proj(x).view(B, L, self.n_kv_heads, self.d_head)
        v = self.v_proj(x).view(B, L, self.n_kv_heads, self.d_head)

        # [B, heads, L, d_head]
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Rotary position embedding on queries and keys.
        cos, sin = self.rope(q)
        q = apply_rope(q, cos, sin)
        k = apply_rope(k, cos, sin)

        # Expand the K/V heads so each query head has a matching K/V head.
        n_rep = self.n_heads // self.n_kv_heads
        k = k.repeat_interleave(n_rep, dim=1)
        v = v.repeat_interleave(n_rep, dim=1)

        # Scaled dot-product attention with a causal mask.
        attn = (q @ k.transpose(-2, -1)) * self.scale
        mask = torch.triu(
            torch.ones(L, L, device=q.device, dtype=torch.bool), diagonal=1
        )
        attn = attn.masked_fill(mask, float("-inf"))
        attn = F.softmax(attn, dim=-1)

        if self.training and self.dropout > 0:
            attn = F.dropout(attn, p=self.dropout)

        out = (attn @ v).transpose(1, 2).contiguous()
        return self.o_proj(out.view(B, L, -1))
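
# With the defaults above (n_heads=8, n_kv_heads=2), each key/value head is
# shared by n_heads // n_kv_heads = 4 query heads, shrinking the K/V
# projections (and any K/V cache) 4x relative to full multi-head attention.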


class GriffinBlock(nn.Module):
    """RNN-based block using RGLRU."""

    def __init__(self, d_model: int, ff_mult: int = 3):
        super().__init__()
        self.norm1 = RMSNorm(d_model)
        self.rglru = RGLRU(d_model)
        self.norm2 = RMSNorm(d_model)
        self.ffn = GeGLU(d_model, d_model * ff_mult)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pre-norm residual structure: RNN mixer, then feed-forward.
        rnn_out, _ = self.rglru(self.norm1(x))
        x = x + rnn_out
        x = x + self.ffn(self.norm2(x))
        return x


class AttentionBlock(nn.Module):
    """Attention-based block using GQA."""

    def __init__(
        self, d_model: int, n_heads: int = 8, n_kv_heads: int = 2, ff_mult: int = 3
    ):
        super().__init__()
        self.norm1 = RMSNorm(d_model)
        self.attn = GQAttention(d_model, n_heads, n_kv_heads)
        self.norm2 = RMSNorm(d_model)
        self.ffn = GeGLU(d_model, d_model * ff_mult)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.norm1(x))
        x = x + self.ffn(self.norm2(x))
        return x


class HybriKoPreTrainedModel(PreTrainedModel):
    """Base class for HybriKo models."""

    config_class = HybriKoConfig
    base_model_prefix = "hybridko"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, std=0.02)


class HybriKoModel(HybriKoPreTrainedModel):
    """HybriKo: Hybrid RNN-Attention Language Model for Korean.

    Uses a 2:1 ratio of RNN (Griffin) blocks to Attention blocks:
    - Layers 1, 2: GriffinBlock (RNN)
    - Layer 3: AttentionBlock
    - The pattern repeats, e.g. with n_layers=6: G G A G G A.
    """

    def __init__(self, config: HybriKoConfig):
        super().__init__(config)
        self.config = config
        self.gradient_checkpointing = False

        self.embed = nn.Embedding(config.vocab_size, config.d_model)

        # Every third layer is attention; the rest are Griffin (RNN) blocks.
        self.layers = nn.ModuleList()
        for i in range(config.n_layers):
            if (i + 1) % 3 == 0:
                self.layers.append(
                    AttentionBlock(
                        config.d_model, config.n_heads, config.n_kv_heads, config.ff_mult
                    )
                )
            else:
                self.layers.append(GriffinBlock(config.d_model, config.ff_mult))

        self.norm = RMSNorm(config.d_model)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        # Tie the output projection to the input embedding.
        self.lm_head.weight = self.embed.weight

        self.post_init()

    def _forward_layer(self, layer: nn.Module, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through a single layer (for checkpointing)."""
        return layer(x)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        **kwargs,
    ) -> Union[Dict[str, Any], CausalLMOutputWithPast]:
        """Forward pass.

        Args:
            input_ids: Token IDs [batch, seq_len]
            attention_mask: Attention mask (unused; accepted for HF compatibility)
            labels: Target token IDs for loss computation
            return_dict: If True, return CausalLMOutputWithPast; otherwise a dict

        Returns:
            CausalLMOutputWithPast or dict with 'logits' and optionally 'loss'
        """
        x = self.embed(input_ids)

        for layer in self.layers:
            if self.gradient_checkpointing and self.training:
                x = checkpoint(
                    self._forward_layer,
                    layer,
                    x,
                    use_reentrant=False,
                )
            else:
                x = layer(x)

        x = self.norm(x)
        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            # Shift so that the logits at position t predict token t + 1.
            loss = F.cross_entropy(
                logits[:, :-1].contiguous().view(-1, self.config.vocab_size),
                labels[:, 1:].contiguous().view(-1),
                ignore_index=-100,
            )

        if return_dict:
            return CausalLMOutputWithPast(
                loss=loss,
                logits=logits,
            )
        return {"logits": logits, "loss": loss}

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 50,
        temperature: float = 0.8,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Generate text tokens.

        Args:
            input_ids: Prompt token IDs [batch, seq_len]
            max_new_tokens: Number of tokens to generate
            temperature: Sampling temperature
            top_k: If set, sample only from the k highest-probability tokens
            top_p: If set, use nucleus sampling with this cumulative probability

        Returns:
            Generated token IDs, including the prompt
        """
        self.eval()
        for _ in range(max_new_tokens):
            # Crop the context to the model's maximum sequence length.
            idx = input_ids[:, -self.config.max_seq_len:]
            outputs = self(idx)
            logits = outputs.logits[:, -1] / temperature

            # Top-k filtering: mask out everything below the k-th logit.
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = float("-inf")

            # Nucleus (top-p) filtering: drop the tail of the sorted
            # distribution once cumulative probability exceeds top_p,
            # always keeping at least the top token.
            if top_p is not None:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
                sorted_indices_to_remove[:, 0] = 0
                indices_to_remove = sorted_indices_to_remove.scatter(
                    1, sorted_indices, sorted_indices_to_remove
                )
                logits[indices_to_remove] = float("-inf")

            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, 1)
            input_ids = torch.cat([input_ids, next_token], dim=1)
        return input_ids

    def get_num_params(self, non_embedding: bool = True) -> int:
        """Return the number of parameters in the model.

        Note: lm_head is tied to embed, so parameters() counts the shared
        weight matrix only once; subtracting it once removes it entirely.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.embed.weight.numel()
        return n_params


HybriKoConfig.register_for_auto_class()
HybriKoModel.register_for_auto_class("AutoModel")
HybriKoModel.register_for_auto_class("AutoModelForCausalLM")
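

# A minimal smoke test, runnable as a script. The HybriKoConfig keyword
# arguments below (vocab_size, d_model, n_layers, n_heads, n_kv_heads,
# ff_mult, max_seq_len) mirror the attributes this module reads off the
# config; the actual constructor signature lives in configuration_hybridko
# and may differ, so treat the values here as an illustrative sketch.
if __name__ == "__main__":
    config = HybriKoConfig(
        vocab_size=32000,
        d_model=256,
        n_layers=6,
        n_heads=8,
        n_kv_heads=2,
        ff_mult=3,
        max_seq_len=512,
    )
    model = HybriKoModel(config)
    print(f"parameters (non-embedding): {model.get_num_params():,}")

    # A forward pass with labels computes the shifted cross-entropy loss.
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    out = model(input_ids, labels=input_ids)
    print(f"logits: {tuple(out.logits.shape)}, loss: {out.loss.item():.4f}")

    # Sampling a few tokens exercises the generation path.
    generated = model.generate(input_ids[:, :4], max_new_tokens=8, top_k=50)
    print(f"generated shape: {tuple(generated.shape)}")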