# smallm_140_rope / model.py
# (Hugging Face Hub page residue preserved as comments so the module parses:
#  "Azrail's picture", "Training in progress, step 31000", commit 889071d verified)
import torch
from torch import nn
import torch.nn.functional as F
from transformers import PreTrainedModel, GenerationMixin
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from .config import SmalLmConfig
from typing import Optional
import logging
from einops import rearrange, repeat
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from einops._torch_specific import allow_ops_in_compiled_graph

# Let einops ops (rearrange/repeat) be traced through torch.compile graphs.
allow_ops_in_compiled_graph()
from transformers.utils import is_flash_attn_2_available

# Use the FlashAttention-2 varlen kernels when the package is installed;
# CausalSelfAttention falls back to torch SDPA when this stays False.
USE_FLASH = False
if is_flash_attn_2_available():
    from flash_attn.flash_attn_interface import flash_attn_varlen_func
    from flash_attn.bert_padding import unpad_input, pad_input

    USE_FLASH = True
logger = logging.getLogger(__name__)
logger.info(f"USE FLASH: {USE_FLASH=}")
class SwiGLU(nn.Module):
    """Gated feed-forward block: down_proj(silu(gate) * up).

    The value ("up") and gating branches come out of a single fused matmul,
    interleaved pairwise along the last dimension.
    """

    def __init__(
        self, input_size: int, hidden_size: int, bias: bool = False, *args, **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.input_size = input_size
        self.hidden_size = hidden_size
        # One projection produces both branches (2 * hidden_size features).
        self.up_proj = nn.Linear(input_size, hidden_size * 2, bias=bias)
        self.down_proj = nn.Linear(hidden_size, input_size, bias=bias)

    def forward(self, x):
        fused = self.up_proj(x)
        # Fused features interleave (value, gate) pairs: even slots are the
        # value branch, odd slots the gate branch.
        paired = fused.view(*fused.shape[:-1], self.hidden_size, 2)
        up, gate = paired.unbind(dim=-1)
        return self.down_proj(F.silu(gate) * up)
class Router(nn.Module):
    """Top-k expert router with sigmoid gating and bias-based load balancing.

    A persistent ``bias`` buffer is added to the gate scores *only* for expert
    selection (mixing weights use the unbiased scores). A non-persistent
    ``expert_counts`` accumulator tracks how often each expert was picked;
    ``update_bias`` consumes it to nudge under-used experts up and over-used
    experts down, balancing load without an auxiliary loss.
    """

    def __init__(self, config: SmalLmConfig, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.config = config
        # Routed experts chosen per token (shared experts are always active).
        self.experts_to_select = self.config.token_experts - self.config.shared_experts
        self.gate = nn.Linear(config.hidden_size, config.routed_experts, bias=False)
        # Optional learned, input-dependent noise scale (noisy top-k routing).
        self.gate_noise = (
            nn.Linear(config.hidden_size, config.routed_experts, bias=False)
            if config.noisy_experts is True
            else None
        )
        # Step size for the balancing-bias update in update_bias().
        self.bias_coef = config.balancing_coef
        self.register_buffer(
            "bias", torch.zeros(config.routed_experts), persistent=True
        )
        self.register_buffer(
            "expert_counts", torch.zeros(config.routed_experts), persistent=False
        )

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, list[int]]:
        """Route flattened tokens ``x`` of shape (num_tokens, hidden_size).

        Returns ``(top_experts_idx, top_experts_weights, counts)`` where
        ``counts[i]`` is the number of tokens that selected expert ``i``.
        """
        gate_logits = self.gate(x)
        if self.gate_noise is not None:
            # softplus keeps the learned noise scale positive.
            gate_logits_noise = F.softplus(self.gate_noise(x))
            gate_logits_noise = torch.randn_like(gate_logits_noise) * gate_logits_noise
            gate_logits = gate_logits + gate_logits_noise
        gate_weights = gate_logits.sigmoid()
        original_weights = gate_weights
        # The balancing bias influences which experts win the top-k ...
        gate_weights = gate_weights + self.bias
        _, top_experts_idx = torch.topk(gate_weights, self.experts_to_select, dim=-1)
        counts = torch.bincount(
            top_experts_idx.flatten(), minlength=self.config.routed_experts
        ).detach()
        if self.training:
            # Accumulated across steps until update_bias() resets it.
            self.expert_counts += counts
        # ... but the mixing weights come from the unbiased sigmoid scores.
        top_experts_weights = original_weights.gather(1, top_experts_idx)
        # Renormalize the selected weights to sum to 1 per token.
        top_experts_weights = top_experts_weights / top_experts_weights.sum(
            dim=-1, keepdim=True
        )
        return top_experts_idx, top_experts_weights.type_as(x), counts.tolist()

    def update_bias(self):
        """Nudge the selection bias toward balanced load; reset the counters."""
        mean = self.expert_counts.float().mean()
        # Under-used experts (count < mean) get a positive delta, over-used a negative one.
        delta = self.bias_coef * torch.sign(mean - self.expert_counts)
        self.bias += delta
        self.expert_counts.zero_()
class MoE(nn.Module):
    """Mixture-of-experts FFN: routed SwiGLU experts plus always-on shared experts."""

    def __init__(self, config: SmalLmConfig, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.config = config
        # Shared experts run on every token, sized as a single wider SwiGLU.
        self.shared_experts = SwiGLU(
            config.hidden_size,
            config.shared_experts * config.expert_size,
            config.moe_bias,
        )
        self.routed_experts = nn.ModuleList(
            [
                SwiGLU(config.hidden_size, config.expert_size, config.moe_bias)
                for _ in range(config.routed_experts)
            ]
        )
        self.router = Router(config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply routed + shared experts; preserves the input shape."""
        orig_shape = x.size()
        tokens = x.view(-1, self.config.hidden_size)
        expert_ids, expert_weights, load = self.router(tokens)
        routed_out = torch.zeros_like(tokens)
        for expert_id, expert in enumerate(self.routed_experts):
            # Skip experts that received no tokens this step.
            if load[expert_id] == 0:
                continue
            rows, slots = torch.where(expert_ids == expert_id)
            routed_out[rows] += expert(tokens[rows]) * expert_weights[rows, slots, None]
        return (routed_out + self.shared_experts(tokens)).view(orig_shape)
class ComboMoe(nn.Module):
    """Experimental three-stage routed MLP plus always-on shared SwiGLU experts.

    Stage 1 (input_router / routed_experts): hidden_size -> expert_size.
    Stage 2 (middle_router / middle_routed_experts): multiplicative SiLU gate.
    Stage 3 (out_router / out_routed_experts): expert_size -> hidden_size.

    NOTE(review): stage 2 applies Linear(expert_size, hidden_size) to the raw
    hidden-size input and multiplies the result into the expert_size-wide
    ``iout``, and the out_router (a hidden_size gate) is fed ``iout`` — all of
    this only shape-checks when expert_size == hidden_size. Confirm the config
    enforces that constraint.
    """

    def __init__(self, config: SmalLmConfig, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.config = config
        self.shared_experts = SwiGLU(
            config.hidden_size,
            config.shared_experts * config.expert_size,
            config.moe_bias,
        )
        self.input_router = Router(config)
        self.middle_router = Router(config)
        self.out_router = Router(config)
        self.routed_experts = nn.ModuleList(
            [
                nn.Linear(config.hidden_size, config.expert_size, bias=config.moe_bias)
                for _ in range(config.routed_experts)
            ]
        )
        self.middle_routed_experts = nn.ModuleList(
            [
                nn.Linear(config.expert_size, config.hidden_size, bias=config.moe_bias)
                for _ in range(config.routed_experts)
            ]
        )
        self.out_routed_experts = nn.ModuleList(
            [
                nn.Linear(config.expert_size, config.hidden_size, bias=config.moe_bias)
                for _ in range(config.routed_experts)
            ]
        )
        # offset == len(routed_experts): the [: self.offset] slice below is
        # currently a no-op, presumably kept for a planned expert split.
        self.offset = config.routed_experts

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the three routed stages plus shared experts; preserves shape."""
        shape = x.size()
        x = x.view(-1, self.config.hidden_size)
        # Stage 1: routed projection into expert space.
        iexpert_idx, iexpert_weights, icounts = self.input_router(x)
        # Bug fix: match x's dtype — a default fp32 buffer would later be fed
        # into half-precision expert Linears under fp16/bf16 and fail.
        iout = torch.zeros(
            (*x.shape[:-1], self.config.expert_size), device=x.device, dtype=x.dtype
        )
        for i, expert in enumerate(self.routed_experts[: self.offset]):
            if icounts[i] == 0:
                continue
            idx, pos = torch.where(iexpert_idx == i)
            iout[idx] += expert(x[idx]) * iexpert_weights[idx, pos, None]
        # Stage 2: routed multiplicative SiLU gating of the stage-1 features.
        mexpert_idx, mexpert_weights, mcounts = self.middle_router(x)
        for i, expert in enumerate(self.middle_routed_experts):
            if mcounts[i] == 0:
                continue
            idx, pos = torch.where(mexpert_idx == i)
            iout[idx] *= F.silu(expert(x[idx]) * mexpert_weights[idx, pos, None])
        # Stage 3: routed projection back to hidden size (routing on iout).
        out = torch.zeros_like(x)
        oexpert_idx, oexpert_weights, ocounts = self.out_router(iout)
        for i, expert in enumerate(self.out_routed_experts):
            if ocounts[i] == 0:
                continue
            idx, pos = torch.where(oexpert_idx == i)
            out[idx] += expert(iout[idx]) * oexpert_weights[idx, pos, None]
        shared_out = self.shared_experts(x)
        return (out + shared_out).view(shape)
def build_alibi_bias(config: SmalLmConfig) -> torch.Tensor:
    """Build per-head ALiBi slope values.

    Head h (1-based) gets slope h * 2**-8 / num_heads — a linear ramp.
    NOTE(review): the ALiBi paper uses geometric slopes 2**(-8*h/num_heads);
    this linear variant is preserved as-is for checkpoint compatibility —
    confirm it is intentional.

    Returns:
        Tensor of slopes, shape [num_attention_heads].
    """
    num_heads = config.num_attention_heads
    head_ids = torch.arange(1, num_heads + 1).float()
    return 2**-8 / num_heads * head_ids
def calc_rotation(num_rotations, dim, base, seq_len):
    """Invert the RoPE frequency formula: return the (fractional) pair index
    whose component completes ``num_rotations`` full rotations over ``seq_len``.

    From theta_i = base**(-i/dim) and wavelength lambda_i = 2*pi / theta_i,
    solving seq_len / lambda_i = num_rotations for i gives
    i = dim * log(seq_len / (2*pi*num_rotations)) / log(base).

    (Parameter renamed from the misspelled ``num_rotaitions``; both in-file
    call sites pass it positionally.)
    """
    return (
        dim
        * torch.log(torch.tensor(seq_len).float() / (num_rotations * 2 * torch.pi))
        / torch.log(torch.tensor(base))
    )
def get_ramp_interpolation(min_idx, max_idx, thetas_dim, eps=1e-6):
    """Return a reversed linear ramp over ``thetas_dim`` indices.

    Output is 1 for indices <= min_idx, 0 for indices >= max_idx, and falls
    linearly in between. ``eps`` guards the min_idx == max_idx division.
    """
    if min_idx == max_idx:
        max_idx = max_idx + eps  # avoid division by zero on a degenerate ramp
    ramp = (torch.arange(thetas_dim) - min_idx) / (max_idx - min_idx)
    return 1 - ramp.clamp(0, 1)
def build_rope_bias(config: SmalLmConfig) -> torch.Tensor:
    """Precompute RoPE rotations as complex exponentials.

    Returns a complex tensor of shape (max_seq_len, head_size // 2) consumed
    by apply_rope_bias. When max_seq_len exceeds original_seq_len, applies an
    NTK-by-parts correction: low-frequency components are interpolated
    (theta / scale), high-frequency components are kept as-is, with a linear
    ramp between the two cutoff indices derived from the configured low/high
    rotation counts.
    """
    dim = config.head_size
    theta = 1.0 / (config.rope_base ** (torch.arange(0, dim, 2).float() / dim))
    # neural tangent kernel by-parts correction
    if config.max_seq_len > config.original_seq_len:
        scale = config.max_seq_len / config.original_seq_len
        # From lambda_i = 2*pi / theta_i and lambda = seq_len / num_rotations
        # (lambda is the wavelength): calc_rotation inverts this to an index.
        low_interpolation_idx = max(
            0,
            torch.ceil(
                calc_rotation(
                    config.high_rotations,
                    dim,
                    config.rope_base,
                    config.original_seq_len,
                )
            ).item(),
        )
        high_interpolation_idx = min(
            dim - 1,
            torch.floor(
                calc_rotation(
                    config.low_rotations, dim, config.rope_base, config.original_seq_len
                )
            ).item(),
        )
        # Ramp is 1 below low_idx (high-frequency dims keep original theta)
        # and 0 above high_idx (low-frequency dims are fully interpolated).
        interpolation_mult = get_ramp_interpolation(
            low_interpolation_idx, high_interpolation_idx, dim // 2
        )
        theta = (1 - interpolation_mult) * theta / scale + interpolation_mult * theta
    seq_idx = torch.arange(config.max_seq_len)
    seq_theta = torch.outer(seq_idx, theta)
    # Complex exponentials e^{i * pos * theta_j}.
    bias = torch.polar(torch.ones_like(seq_theta), seq_theta)
    return bias
def apply_rope_bias(x: torch.Tensor, precompute_bias: torch.Tensor) -> torch.Tensor:
    """Rotate ``x`` by the precomputed complex RoPE table.

    Args:
        x: real tensor of shape (..., seq, head_size) whose last dimension
            holds interleaved (real, imag) pairs. Generalized from the
            previous 4-D-only implementation: any number of leading
            dimensions is accepted.
        precompute_bias: complex tensor broadcastable against
            (..., seq, head_size // 2), e.g. a slice of build_rope_bias().

    Returns:
        Rotated tensor with the same shape and dtype as ``x``.
    """
    ini_dtype = x.dtype
    # Do the complex math in fp32 for stability; torch.view_as_complex also
    # requires a float/double input.
    pairs = x.float().reshape(*x.shape[:-1], -1, 2).contiguous()
    rotated = torch.view_as_complex(pairs) * precompute_bias
    return torch.view_as_real(rotated).reshape(x.shape).to(ini_dtype)
class CausalSelfAttention(nn.Module):
    """Causal self-attention with grouped-query KV heads (GQA).

    Positional information arrives through ``bias``: a complex RoPE table
    (applied to q/k here) or ALiBi values. Uses the FlashAttention-2 varlen
    kernel on CUDA when available, otherwise torch SDPA.
    """

    def __init__(self, config: SmalLmConfig, layer_idx: int, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if config.num_attention_heads % config.num_kv_heads != 0:
            raise ValueError("num_attention_heads must be divisible by num_kv_heads")
        self.config = config
        self.layer_idx = layer_idx
        self.head_per_group = config.num_attention_heads // config.num_kv_heads
        self.q_proj = nn.Linear(
            config.hidden_size,
            config.head_size * config.num_attention_heads,
            bias=config.attention_bias,
        )
        # Fused K and V projection; split into two tensors in forward().
        self.kv_proj = nn.Linear(
            config.hidden_size,
            config.head_size * config.num_kv_heads * 2,
            bias=config.attention_bias,
        )
        self.out_proj = nn.Linear(
            config.head_size * config.num_attention_heads,
            config.hidden_size,
            bias=config.attention_bias,
        )

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: torch.Tensor,
        past_key_values: Optional[Cache | torch.FloatTensor],
        cache_position: Optional[torch.LongTensor],
        bias: torch.Tensor,
    ):
        """Attend over ``x`` of shape (batch, seq, hidden).

        Returns ``(projected_output, raw_attention_output)`` — note the second
        value is the pre-projection attention result, not attention probabilities.
        """
        q = self.q_proj(x)
        kv = self.kv_proj(x)
        q = rearrange(q, "b s (n d) -> b n s d", n=self.config.num_attention_heads)
        k, v = rearrange(kv, "b s (n d q) -> q b n s d", q=2, d=self.config.head_size)
        if self.config.positional_bias_type == "rope":
            # Rotate before caching so cached keys are already position-encoded.
            k = apply_rope_bias(k, bias)
            q = apply_rope_bias(q, bias)
        if past_key_values is not None:
            # cache_position is required by StaticCache to place new entries.
            cache_kwargs = {"cache_position": cache_position}
            k, v = past_key_values.update(
                key_states=k,
                value_states=v,
                layer_idx=self.layer_idx,
                cache_kwargs=cache_kwargs,
            )
        # Let SDPA apply its own causal mask only when no explicit mask exists
        # and we are not decoding a single cached token.
        is_causal = attention_mask is None and q.size(-2) > 1
        if USE_FLASH and x.is_cuda:
            # Flash varlen path: flatten padded batches into packed sequences.
            # NOTE(review): assumes attention_mask is not None — unpad_input
            # would fail otherwise; confirm callers always pass a mask on CUDA.
            q = rearrange(q, "b n s d -> b s n d")
            k = rearrange(k, "b n s d -> b s n d")
            v = rearrange(v, "b n s d -> b s n d")
            q, idx_q, cu_seqlens_q, max_seqlen_q, _ = unpad_input(q, attention_mask)
            k, _, cu_seqlens_k, max_seqlen_k, _ = unpad_input(k, attention_mask)
            v, _, _, _, _ = unpad_input(v, attention_mask)
            k = k.contiguous()
            v = v.contiguous()
            q = q.contiguous()
            attention_probs = flash_attn_varlen_func(
                q,
                k,
                v,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_q,
                max_seqlen_k=max_seqlen_k,
                dropout_p=self.config.attention_dropout if self.training else 0.0,
                causal=True,
                # Bug fix: gate ALiBi slopes on positional_bias_type —
                # ``attention_bias`` is the Linear-bias bool and could never
                # equal "alibi", so slopes were silently dropped before.
                alibi_slopes=(
                    bias if self.config.positional_bias_type == "alibi" else None
                ),
            )
            # Scatter packed outputs back to the padded (b, s, n, d) layout.
            attention_probs = pad_input(attention_probs, idx_q, x.size(0), x.size(1))
            out = rearrange(attention_probs, "b s n d -> b s (n d)")
        else:
            attention_probs = F.scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=attention_mask,
                enable_gqa=True,
                is_causal=is_causal,
                dropout_p=self.config.attention_dropout if self.training else 0.0,
            )
            out = rearrange(attention_probs, "b n s d -> b s (n d)")
        out = self.out_proj(out)
        return out, attention_probs
class WeightedResidual(nn.Module):
    """Residual add with a per-channel scale on the skip branch: weight * short + long."""

    def __init__(self, config: SmalLmConfig, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # NOTE(review): requires_grad is set *to* static_residual — if "static"
        # means frozen, this flag looks inverted; confirm intended semantics.
        self.weight = nn.Parameter(
            torch.ones(config.hidden_size), requires_grad=config.static_residual
        )

    def forward(self, short, long):
        """Scale the skip path and add the transformed branch."""
        scaled_skip = self.weight * short
        return scaled_skip + long
class Block(nn.Module):
    """One decoder layer: pre-norm attention and a pre-norm SwiGLU/MoE FFN,
    each wrapped in a weighted residual connection."""

    def __init__(self, config: SmalLmConfig, layer_idx: int, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.attn_norm = nn.RMSNorm(
            config.hidden_size,
            eps=config.rms_norm_eps,
            elementwise_affine=config.rms_affine,
        )
        self.ffn_norm = nn.RMSNorm(
            config.hidden_size,
            eps=config.rms_norm_eps,
            elementwise_affine=config.rms_affine,
        )
        self.dropout1 = nn.Dropout(config.layer_dropout)
        self.dropout2 = nn.Dropout(config.layer_dropout)
        self.attention = CausalSelfAttention(config, layer_idx)
        # A layer gets a MoE FFN when MoE is enabled, the (1-based) index hits
        # the MoE period, and enough dense warm-up layers have passed.
        moe_class = MoE if config.moe_type == "default" else ComboMoe
        is_moe_layer = (
            config.use_moe
            and layer_idx % config.moe_period == 0
            and layer_idx > config.no_moe_layers
        )
        if is_moe_layer:
            self.mlp = moe_class(config)
        else:
            self.mlp = SwiGLU(
                config.hidden_size, config.intermediate_size, config.mlp_bias
            )
        self.attention_residual = WeightedResidual(config)
        self.ffn_residual = WeightedResidual(config)

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        attention_mask: torch.Tensor,
        past_key_values: Optional[Cache | torch.FloatTensor],
        output_attentions: bool,
        cache_position: Optional[torch.LongTensor],
        bias: torch.Tensor,
    ) -> tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
        """Run the layer; returns (hidden,) or (hidden, attention) depending
        on ``output_attentions``."""
        residual = inputs_embeds
        # --- attention sub-block (pre-norm) ---
        normed = self.attn_norm(inputs_embeds)
        attn_out, attention_probs = self.attention(
            normed, attention_mask, past_key_values, cache_position, bias
        )
        residual = self.attention_residual(residual, self.dropout1(attn_out))
        # --- feed-forward sub-block (SwiGLU or MoE, pre-norm) ---
        ffn_out = self.dropout2(self.mlp(self.ffn_norm(residual)))
        hidden = self.ffn_residual(residual, ffn_out)
        if output_attentions:
            return hidden, attention_probs
        return (hidden,)
class SmalLmPreTrainedModel(PreTrainedModel):
    """Base class wiring SmalLm modules into the HF save/load machinery."""

    config_class = SmalLmConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Block"]
    _skip_keys_device_placement = "past_key_values"

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(self, module):
        """Normal-init Linear/Embedding weights; re-zero the padding row."""
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            # Bug fix: read the pad id from config — ``self`` here may be
            # SmalLmForCausalLM, which never defines ``pad_idx``. The normal_
            # above overwrites the zeros nn.Embedding placed at padding_idx,
            # so the pad row must be re-zeroed explicitly.
            if self.config.pad_token_id is not None:
                module.weight.data[self.config.pad_token_id].zero_()
class SmalLmModel(SmalLmPreTrainedModel):
    """Decoder-only transformer backbone: embedding -> N Blocks -> RMSNorm."""

    def __init__(self, config: SmalLmConfig, *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        self.config = config
        self.pad_idx = config.pad_token_id
        self.pad_token_id = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.config = config
        # Positional bias table: ALiBi per-head slopes [heads] or a complex
        # RoPE table [max_seq_len, head_size // 2].
        precompute_bias = (
            build_alibi_bias(config)
            if config.positional_bias_type == "alibi"
            else build_rope_bias(config)
        )
        self.register_buffer("precompute_bias", precompute_bias, persistent=False)
        # Remember weight sharing with the output head: self.embedding.weight = self.output.weight
        self.embedding = nn.Embedding(
            self.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
        )
        self.embedding_dropout = nn.Dropout(config.embedding_dropout)
        # Layers are indexed from 1 so the MoE placement rules in Block
        # (moe_period, no_moe_layers) are 1-based.
        self.layers = nn.ModuleList(
            [Block(config, idx) for idx in range(1, config.num_hidden_layers + 1)]
        )
        self.out_norm = nn.RMSNorm(
            config.hidden_size,
            eps=config.rms_norm_eps,
            elementwise_affine=config.rms_affine,
        )
        self.gradient_checkpointing = False
        self.post_init()

    def get_input_embeddings(self):
        return self.embedding

    def set_input_embeddings(self, value):
        self.embedding = value

    def forward(
        self,
        # input options
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        # output options
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        # cache options
        use_cache: Optional[bool] = None,
        past_key_values: Optional[Cache | torch.FloatTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> tuple | BaseModelOutputWithPast:
        """Run the decoder stack; mirrors the HF decoder-model forward contract."""
        # check additional parameters
        # NOTE(review): attention outputs are force-disabled regardless of the
        # argument (the flash path has no probabilities to return) — confirm intent.
        output_attentions = False
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = (
            use_cache
            if use_cache is not None
            else (False if self.training else self.config.use_cache)
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.return_dict
        )
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You must specify only input_ids or inputs_embeds, not both"
            )
        if self.training and use_cache:
            use_cache = False
        if inputs_embeds is None:
            inputs_embeds = self.embedding(input_ids)
        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()
        # calculating position for StaticCache
        if cache_position is None:
            last_position = (
                past_key_values.get_seq_length() if past_key_values is not None else 0
            )
            cache_position = torch.arange(
                last_position,
                last_position + inputs_embeds.size(1),
                device=inputs_embeds.device,
            )
        causal_mask = self._get_causal_masks(
            attention_mask, inputs_embeds, past_key_values, cache_position
        )
        if self.config.positional_bias_type == "rope":
            # Slice the precomputed rotations to the positions being processed.
            end_pos = (
                inputs_embeds.size(1)
                if past_key_values is None
                else cache_position[-1] + 1
            )
            start_pos = 0 if past_key_values is None else cache_position[0]
            bias = self.precompute_bias[start_pos:end_pos]
        elif self.config.positional_bias_type == "alibi":
            if USE_FLASH and input_ids.is_cuda:
                # Flash kernel takes the per-head slopes directly.
                # NOTE(review): input_ids may be None when inputs_embeds is
                # passed — this would raise; verify callers.
                bias = self.precompute_bias
            else:
                # Dense [batch, heads, q, k] distance matrix scaled by slopes.
                i = torch.arange(
                    (
                        inputs_embeds.size(1)
                        if past_key_values is None
                        else cache_position[-1] + 1
                    ),
                    device=inputs_embeds.device,
                )
                bias = i[:, None] - i[None, :]
                # NOTE(review): the bias grows *positive* with distance while
                # the ALiBi paper applies a negative penalty — confirm the sign.
                bias = torch.tril(bias).expand(
                    inputs_embeds.size(0), self.config.num_attention_heads, -1, -1
                ) * rearrange(self.precompute_bias, "n -> 1 n 1 1")
                # Fold the additive bias into the SDPA mask.
                # NOTE(review): when the mask's target_length is past+seq+1 the
                # shapes may be off by one against this bias — verify.
                if causal_mask is not None:
                    causal_mask = causal_mask + bias
                else:
                    causal_mask = bias
        hidden_state = inputs_embeds
        hidden_states = [hidden_state] if output_hidden_states else None
        attentions = [] if output_attentions else None
        for idx, layer in enumerate(self.layers, 1):
            if self.gradient_checkpointing:
                # for details see:
                # https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L3107
                # https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L3149
                layer_out = self._gradient_checkpointing_func(
                    layer.__call__,
                    hidden_state,
                    causal_mask,
                    past_key_values,
                    output_attentions,
                    cache_position,
                    bias,
                )
            else:
                layer_out = layer(
                    hidden_state,
                    causal_mask,
                    past_key_values,
                    output_attentions,
                    cache_position,
                    bias,
                )
            hidden_state = layer_out[0]
            if output_hidden_states:
                hidden_states.append(hidden_state)
            if output_attentions:
                attentions.append(layer_out[1])
        hidden_state = self.out_norm(hidden_state)
        out = BaseModelOutputWithPast(
            last_hidden_state=hidden_state,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=tuple(hidden_states) if hidden_states is not None else None,
            attentions=tuple(attentions) if attentions is not None else None,
        )
        return out if return_dict else out.to_tuple()

    def _get_causal_masks(
        self,
        attention_mask: Optional[torch.Tensor],
        inputs_embeds: torch.Tensor,
        past_key_values: Optional[torch.Tensor],
        cache_position: Optional[torch.Tensor],
    ):
        """Return the additive 4D SDPA mask, the raw mask (flash path), or None."""
        if USE_FLASH and inputs_embeds.is_cuda:
            # Flash varlen handles padding itself via unpad_input.
            return attention_mask
        dtype, device = inputs_embeds.dtype, inputs_embeds.device
        past_token = (
            past_key_values.get_seq_length() if past_key_values is not None else 0
        )
        # NOTE(review): an all-zeros mask (every token masked) yields None,
        # i.e. "attend everywhere" — confirm this short-circuit is intended.
        if attention_mask is not None and torch.all(attention_mask == 0.0):
            return None
        if AttentionMaskConverter._ignore_causal_mask_sdpa(
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_token,
            is_training=self.training,
        ):
            return None
        sequence_length = inputs_embeds.size(1)
        target_length = (
            attention_mask.size(-1)
            if isinstance(attention_mask, torch.Tensor)
            else past_token + sequence_length + 1
        )
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask=attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            device=device,
            cache_position=cache_position,
            batch_size=inputs_embeds.size(0),
        )
        min_dtype = torch.finfo(dtype).min
        # Re-enable rows that attend to nothing (avoids NaNs in SDPA).
        causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
        return causal_mask

    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: Optional[torch.Tensor],
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        device: torch.device,
        cache_position: Optional[torch.Tensor],
        batch_size: int,
    ):
        """Build a (batch, 1, seq, target) additive mask honoring cache positions."""
        if attention_mask is not None and attention_mask.dim() == 4:
            # Caller already supplied a 4D additive mask — trust it.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
                (sequence_length, target_length),
                fill_value=min_dtype,
                dtype=dtype,
                device=device,
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            # Mask out key positions beyond each query's cache position.
            causal_mask *= torch.arange(
                target_length, device=device
            ) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()
                mask_length = attention_mask.shape[-1]
                # Positions where both the causal mask allows (0) and the
                # padding mask forbids (0) become fully masked.
                padding_mask = (
                    causal_mask[:, :, :, :mask_length]
                    + attention_mask[:, None, None, :]
                )
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[
                    :, :, :, :mask_length
                ].masked_fill(padding_mask, min_dtype)
        return causal_mask
class SmalLmForCausalLM(SmalLmPreTrainedModel, GenerationMixin):
    """SmalLm backbone plus an LM head for next-token prediction.

    ``lm_head.weight`` is listed in ``_tied_weights_keys`` so HF can tie it
    to the input embedding when the config requests weight tying.
    """

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: SmalLmConfig, *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        self.config = config
        self.model = SmalLmModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def forward(
        self,
        # input options
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        # output options
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        # cache options
        use_cache: Optional[bool] = None,
        past_key_values: Optional[Cache | torch.FloatTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        # generation options
        labels: Optional[torch.Tensor] = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs,
    ) -> tuple | CausalLMOutputWithPast:
        """Compute logits and, when ``labels`` is given, the LM loss.

        labels are shifted internally (position t predicts token t+1);
        positions to ignore should be set to -100 (F.cross_entropy's default
        ignore_index).
        NOTE(review): ``logits_to_keep`` is accepted for GenerationMixin
        compatibility but is currently unused — the head runs on every position.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.return_dict
        )
        model_outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = model_outputs[0]
        logits = self.lm_head(hidden_states)
        loss = None
        if labels is not None:
            # Shift so that position t predicts token t+1.
            shifted_logits = logits[:, :-1].contiguous()
            shifted_labels = labels[:, 1:].contiguous()
            loss = F.cross_entropy(
                shifted_logits.view(-1, self.config.vocab_size), shifted_labels.view(-1)
            )
        if not return_dict:
            # Bug fix: concatenate the backbone outputs instead of nesting the
            # whole tuple as a single element (previously (logits, (...,))).
            output = (logits,) + tuple(model_outputs[1:])
            return (loss,) + output if loss is not None else output
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=model_outputs.past_key_values,
            hidden_states=model_outputs.hidden_states,
            attentions=model_outputs.attentions,
        )
# Public re-export surface of this module.
__all__ = ["SmalLmForCausalLM", "SmalLmModel", "SmalLmPreTrainedModel"]