# cali-0.1B / modeling_cali.py
# Sandroeth's picture
# Update modeling_cali.py
# 75b4fe7 verified
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.utils import logging
from .configuration_cali import CALIConfig
logger = logging.get_logger(__name__)
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned gain.

    Args:
        dim: Size of the last dimension; also the length of the ``weight`` vector.
        eps: Constant added to the mean square before taking the reciprocal square root.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x / sqrt(mean(x**2, -1) + eps) * weight``.

        The reduction runs in float32 for float16/bfloat16 inputs and the
        normalised tensor is cast back to ``x.dtype`` before the gain is applied.
        """
        # Squaring float16 values above ~255 overflows to inf, which makes the
        # reciprocal RMS zero and silently wipes out the activations.
        low_precision = x.dtype in (torch.float16, torch.bfloat16)
        compute_dtype = torch.float32 if low_precision else x.dtype
        x_hp = x.to(compute_dtype)
        inv_rms = x_hp.pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt()
        return (x_hp * inv_rms).to(x.dtype) * self.weight
def build_rope_cache(seq_len, head_dim, theta=10000.0, device=None):
    """Precompute rotary-embedding cosine and sine tables.

    Args:
        seq_len: Number of positions (rows) to generate, starting at position 0.
        head_dim: Per-head feature size; ``head_dim // 2`` frequencies are used
            and each table repeats them twice along the last axis.
        theta: Base of the geometric frequency progression.
        device: Device on which the tables are created.

    Returns:
        ``(cos, sin)``, each a float32 tensor of shape
        ``(1, 1, seq_len, 2 * (head_dim // 2))``.
    """
    n_freqs = head_dim // 2
    exponents = torch.arange(0, n_freqs, device=device).float() / n_freqs
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(seq_len, device=device).float()
    # angle[p, i] = p * inv_freq[i]; duplicated so both halves of a head share it.
    angles = torch.outer(positions, inv_freq)
    angles = torch.cat([angles, angles], dim=-1)
    return angles.cos()[None, None], angles.sin()[None, None]
def apply_rope(x, cos, sin):
    """Rotate ``x`` by the rotary position embedding given as ``cos``/``sin`` tables.

    The last dimension of ``x`` is split into two halves ``(x1, x2)`` and the
    result is ``x * cos + concat(-x2, x1) * sin``.

    Args:
        x: Tensor whose last dimension is rotated.
        cos: Cosine table, broadcastable against ``x``.
        sin: Sine table, broadcastable against ``x``.

    Returns:
        Tensor with the same shape and dtype as ``x``.
    """
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    rotated = torch.cat([-x2, x1], dim=-1)
    # build_rope_cache produces float32 tables; without the cast back, float16 /
    # bfloat16 queries and keys are promoted to float32 and no longer match the
    # value tensor's dtype in the attention matmul.
    return (x * cos + rotated * sin).to(x.dtype)
class GroupedQueryAttention(nn.Module):
    """Causal self-attention where several query heads share each key/value head.

    Args:
        config: Model configuration; ``num_heads``, ``num_kv_heads``, ``head_dim``
            and ``hidden_dim`` are read from it.

    Raises:
        ValueError: If ``num_heads`` is not a multiple of ``num_kv_heads``.
    """

    def __init__(self, config: CALIConfig):
        super().__init__()
        if config.num_heads % config.num_kv_heads != 0:
            # Otherwise the repeated K/V heads cannot line up with the query heads
            # and the failure only shows up later as a shape error in matmul.
            raise ValueError(
                f"num_heads ({config.num_heads}) must be divisible by "
                f"num_kv_heads ({config.num_kv_heads})"
            )
        self.num_heads = config.num_heads
        self.num_kv_heads = config.num_kv_heads
        self.head_dim = config.head_dim
        self.groups = config.num_heads // config.num_kv_heads
        self.scale = config.head_dim ** -0.5
        self.q_proj = nn.Linear(config.hidden_dim, config.num_heads * config.head_dim, bias=False)
        self.k_proj = nn.Linear(config.hidden_dim, config.num_kv_heads * config.head_dim, bias=False)
        self.v_proj = nn.Linear(config.hidden_dim, config.num_kv_heads * config.head_dim, bias=False)
        self.o_proj = nn.Linear(config.num_heads * config.head_dim, config.hidden_dim, bias=False)

    def forward(self, x, cos, sin, attention_mask=None, past_key_value=None, use_cache=False):
        """Attend over the current tokens plus any cached keys/values.

        Args:
            x: Input of shape ``(batch, seq, hidden_dim)``.
            cos: Rotary cosine table, broadcastable to ``(batch, heads, seq, head_dim)``.
            sin: Rotary sine table, same layout as ``cos``.
            attention_mask: Either a 2-D ``(batch, key_len)`` mask where 1 keeps a
                key and 0 masks it, or a 4-D mask that is added to the scores
                as-is. Masks of any other rank are ignored.
            past_key_value: Optional ``(key, value)`` pair from earlier steps,
                concatenated in front of the new keys/values along dim 2.
            use_cache: When true, the concatenated ``(key, value)`` pair (before
                head repetition) is returned for reuse.

        Returns:
            ``(output, present)`` where ``output`` has shape
            ``(batch, seq, hidden_dim)`` and ``present`` is the cache pair or ``None``.
        """
        batch, q_len, _ = x.shape
        q = self.q_proj(x).view(batch, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(batch, q_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(batch, q_len, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # Rotary tables may be higher precision than the activations; keep q/k in
        # v's dtype so the cached tensors and the final matmul agree on dtype.
        q = apply_rope(q, cos, sin).to(v.dtype)
        k = apply_rope(k, cos, sin).to(v.dtype)

        if past_key_value is not None:
            k = torch.cat([past_key_value[0], k], dim=2)
            v = torch.cat([past_key_value[1], v], dim=2)
        present = (k, v) if use_cache else None

        # Expand each K/V head so it is shared by `groups` consecutive query heads.
        k = k.repeat_interleave(self.groups, dim=1)
        v = v.repeat_interleave(self.groups, dim=1)
        kv_len = k.shape[2]

        # Mask and normalise in float32: in float16/bfloat16, `score + finfo.min`
        # can overflow to -inf and turn a fully padded row into NaN after softmax.
        att = (torch.matmul(q, k.transpose(-2, -1)) * self.scale).float()

        # Query i sits at absolute position (kv_len - q_len + i); hide later keys.
        future = torch.triu(
            torch.ones(q_len, kv_len, device=x.device, dtype=torch.bool),
            diagonal=kv_len - q_len + 1,
        )
        att = att.masked_fill(future[None, None], float("-inf"))

        if attention_mask is not None:
            if attention_mask.dim() == 2:
                keep = attention_mask[:, None, None, :kv_len].to(dtype=att.dtype)
                att = att + (1.0 - keep) * torch.finfo(att.dtype).min
            elif attention_mask.dim() == 4:
                att = att + attention_mask

        att = F.softmax(att, dim=-1).to(q.dtype)
        out = torch.matmul(att, v).transpose(1, 2).contiguous().view(batch, q_len, self.num_heads * self.head_dim)
        return self.o_proj(out), present
class GatedFFN(nn.Module):
    """Bias-free gated feed-forward layer: ``down(silu(gate(x)) * up(x))``.

    The inner width is ``int(hidden_dim * ffn_multiplier)`` rounded up to the
    next multiple of 256.

    Args:
        config: Model configuration; ``hidden_dim`` and ``ffn_multiplier`` are read.
    """

    def __init__(self, config: CALIConfig):
        super().__init__()
        hidden = config.hidden_dim
        requested = int(hidden * config.ffn_multiplier)
        n_chunks = (requested + 255) // 256  # ceiling division by 256
        inner = n_chunks * 256
        self.gate_proj = nn.Linear(hidden, inner, bias=False)
        self.up_proj = nn.Linear(hidden, inner, bias=False)
        self.down_proj = nn.Linear(inner, hidden, bias=False)

    def forward(self, x):
        """Project up through the SiLU gate and value branches, then back down."""
        gate = F.silu(self.gate_proj(x))
        value = self.up_proj(x)
        return self.down_proj(gate * value)
class CALIBlock(nn.Module):
    """One decoder layer: pre-norm attention and pre-norm gated FFN, each with a residual add.

    Args:
        config: Model configuration; ``hidden_dim`` and ``rms_norm_eps`` are read
            here and the whole object is passed on to the attention and FFN modules.
    """

    def __init__(self, config: CALIConfig):
        super().__init__()
        width, eps = config.hidden_dim, config.rms_norm_eps
        self.norm1 = RMSNorm(width, eps=eps)
        self.attn = GroupedQueryAttention(config)
        self.norm2 = RMSNorm(width, eps=eps)
        self.ffn = GatedFFN(config)

    def forward(self, x, cos, sin, attention_mask=None, past_key_value=None, use_cache=False):
        """Run the layer and return ``(hidden_states, present)``.

        ``present`` is whatever the attention module returns as its cache entry.
        All arguments other than ``x`` are forwarded to the attention module unchanged.
        """
        normed = self.norm1(x)
        attn_out, present = self.attn(
            normed,
            cos,
            sin,
            attention_mask=attention_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
        )
        after_attn = x + attn_out
        after_ffn = after_attn + self.ffn(self.norm2(after_attn))
        return after_ffn, present
class CALIPreTrainedModel(PreTrainedModel):
    """Common ``PreTrainedModel`` base for the CALI classes.

    Binds the config class and supplies weight initialisation and the
    gradient-checkpointing toggle.
    """

    config_class = CALIConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["CALIBlock"]

    def _init_weights(self, module):
        """Initialise ``nn.Linear`` / ``nn.Embedding`` weights from N(0, ``initializer_range``).

        Linear biases, when present, are zeroed. Other module types are left untouched.
        """
        std = self.config.initializer_range
        if not isinstance(module, (nn.Linear, nn.Embedding)):
            return
        nn.init.normal_(module.weight, mean=0.0, std=std)
        if isinstance(module, nn.Linear) and module.bias is not None:
            nn.init.zeros_(module.bias)

    def _set_gradient_checkpointing(self, module, value=False):
        """Set ``gradient_checkpointing`` on ``module`` when it is a ``CALIModel``."""
        if not isinstance(module, CALIModel):
            return
        module.gradient_checkpointing = value
class CALIModel(CALIPreTrainedModel):
    """Decoder backbone: token embedding, ``num_layers`` ``CALIBlock`` layers and a final ``RMSNorm``.

    Args:
        config: Model configuration; ``vocab_size``, ``hidden_dim``, ``num_layers``
            and ``rms_norm_eps`` are read at construction, and ``head_dim``,
            ``rope_theta``, ``use_cache`` and ``use_return_dict`` at call time.
    """

    def __init__(self, config: CALIConfig):
        super().__init__(config)
        self.gradient_checkpointing = False
        self.embed = nn.Embedding(config.vocab_size, config.hidden_dim)
        self.layers = nn.ModuleList([CALIBlock(config) for _ in range(config.num_layers)])
        self.norm = RMSNorm(config.hidden_dim, eps=config.rms_norm_eps)
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed

    def set_input_embeddings(self, value):
        """Replace the token embedding module."""
        self.embed = value

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        past_key_values=None,
        inputs_embeds=None,
        use_cache=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Run the layer stack.

        Args:
            input_ids: Token ids; embedded only when ``inputs_embeds`` is not given.
            attention_mask: Passed unchanged to every layer.
            past_key_values: Per-layer ``(key, value)`` pairs from earlier steps;
                dim 2 of the first key is taken as the cached length.
            inputs_embeds: Pre-computed embeddings of shape ``(batch, seq, hidden_dim)``;
                takes precedence over ``input_ids``.
            use_cache: Return per-layer ``(key, value)`` pairs; defaults to
                ``config.use_cache``. Forced off while gradient checkpointing is
                active in training mode.
            output_hidden_states: Also return the input of every layer plus the
                final normalised output; defaults to ``config.output_hidden_states``.
            return_dict: Return a ``BaseModelOutputWithPast`` instead of a tuple;
                defaults to ``config.use_return_dict``.

        Returns:
            ``BaseModelOutputWithPast``, or a tuple holding the non-``None`` items of
            ``(last_hidden_state, past_key_values, hidden_states)``.

        Raises:
            ValueError: If neither ``input_ids`` nor ``inputs_embeds`` is provided.
        """
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        # Fall back to the config default, like use_cache / return_dict do.
        if output_hidden_states is None:
            output_hidden_states = getattr(self.config, "output_hidden_states", False)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if inputs_embeds is None:
            if input_ids is None:
                raise ValueError("You must specify either input_ids or inputs_embeds")
            inputs_embeds = self.embed(input_ids)

        checkpointing = self.gradient_checkpointing and self.training
        if checkpointing:
            # Local import: `torch.utils.checkpoint` is not guaranteed to be
            # loaded by a bare `import torch`.
            from torch.utils.checkpoint import checkpoint

            if use_cache:
                # Checkpointed layers run with use_cache=False, so the cache
                # would only ever contain None entries.
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
                )
                use_cache = False

        B, T, _ = inputs_embeds.shape
        device = inputs_embeds.device
        past_len = past_key_values[0][0].shape[2] if past_key_values else 0

        # Build tables up to the absolute end position, then keep the rows for
        # the tokens processed in this call.
        cos, sin = build_rope_cache(T + past_len, self.config.head_dim, self.config.rope_theta, device)
        cos = cos[:, :, past_len:past_len + T, :]
        sin = sin[:, :, past_len:past_len + T, :]

        hidden_states = inputs_embeds
        all_hidden_states = () if output_hidden_states else None
        present_key_values = () if use_cache else None

        for i, layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            if checkpointing:
                # Positional args map to (x, cos, sin, attention_mask,
                # past_key_value=None, use_cache=False).
                hidden_states, _ = checkpoint(
                    layer,
                    hidden_states, cos, sin, attention_mask, None, False,
                    use_reentrant=False,
                )
                present = None
            else:
                past_kv = past_key_values[i] if past_key_values else None
                hidden_states, present = layer(
                    hidden_states, cos, sin,
                    attention_mask=attention_mask,
                    past_key_value=past_kv,
                    use_cache=use_cache,
                )
            if use_cache:
                present_key_values += (present,)

        hidden_states = self.norm(hidden_states)
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, present_key_values, all_hidden_states] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=present_key_values,
            hidden_states=all_hidden_states,
        )
class CALIForCausalLM(CALIPreTrainedModel, GenerationMixin):
    """``CALIModel`` backbone with a bias-free linear language-model head.

    When ``config.tie_embeddings`` is true the head shares its weight tensor with
    the backbone's token embedding.
    """

    def __init__(self, config: CALIConfig):
        super().__init__(config)
        self.model = CALIModel(config)
        self.lm_head = nn.Linear(config.hidden_dim, config.vocab_size, bias=False)
        if config.tie_embeddings:
            self.lm_head.weight = self.model.embed.weight
        self.post_init()

    def get_input_embeddings(self):
        """Return the backbone's token embedding module."""
        return self.model.embed

    def set_input_embeddings(self, value):
        """Replace the backbone's token embedding module."""
        self.model.embed = value

    def get_tied_weights(self):
        """Return the head-to-embedding parameter-name mapping, or ``{}`` when weights are not tied."""
        return {"lm_head.weight": "model.embed.weight"} if self.config.tie_embeddings else {}

    def get_output_embeddings(self):
        """Return the language-model head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the language-model head."""
        self.lm_head = new_embeddings

    def get_decoder(self):
        """Return the backbone model."""
        return self.model

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        past_key_values=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        """Compute next-token logits and, when ``labels`` is given, the shifted cross-entropy loss.

        Args:
            input_ids, attention_mask, past_key_values, inputs_embeds, use_cache,
            output_hidden_states: Forwarded to the backbone unchanged.
            labels: Target token ids aligned with the inputs; position ``t`` of the
                logits is scored against label ``t + 1``, and entries equal to
                ``-100`` are ignored.
            return_dict: Return a ``CausalLMOutputWithPast`` instead of a tuple;
                defaults to ``config.use_return_dict``.
            **kwargs: Accepted and ignored.

        Returns:
            ``CausalLMOutputWithPast``, or a tuple ``([loss,] logits, *backbone_extras)``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        logits = self.lm_head(outputs[0])

        loss = None
        if labels is not None:
            # Upcast so the log-softmax inside cross_entropy is not evaluated in
            # float16/bfloat16, and make sure the labels sit on the logits' device.
            shift_logits = logits[..., :-1, :].float().contiguous()
            shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
        )

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs):
        """Assemble the keyword arguments for one generation step.

        With a non-empty cache only the last token id is fed. ``inputs_embeds`` is
        used for the first step only (no cache yet); later steps use ``input_ids``.
        """
        if past_key_values:
            input_ids = input_ids[:, -1:]

        # Previously `inputs_embeds` was accepted but dropped, so generation from
        # embeddings silently ran on `input_ids` instead.
        if inputs_embeds is not None and not past_key_values:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Re-index every cached tensor along dim 0 with ``beam_idx``."""
        return tuple(
            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
            for layer_past in past_key_values
        )