# Way-sft-plamo-3-8b-chat / modeling_plamo.py
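"""PyTorch implementation of the PLaMo 3 architecture.

Defines `Plamo3Config`, `Plamo3Cache`, the attention and decoder layers (interleaving
sliding-window and full attention), `Plamo3Model`, and `Plamo3ForCausalLM`.
"""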
import enum
import os
import warnings
from typing import Any, Dict, List, Literal, NamedTuple, Optional, Tuple, Union
import torch
from torch import nn
from torch.nn import functional as F
from transformers import GenerationMixin, PretrainedConfig, PreTrainedModel
from transformers.cache_utils import DynamicCache
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
# Check if Flash Attention should be enabled
USE_FLASH_ATTENTION_FOR_POST_TRAINING = (
os.environ.get("PLAMO3_MODELING_PLAMO_USE_FLASH_ATTENTION_FOR_POST_TRAINING", "0") == "1"
)
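# Example: enable flash-attn for the sliding-window layers during post-training. The flag is
# read at import time, so set the variable before this module is imported (the script name
# below is only illustrative):
#   PLAMO3_MODELING_PLAMO_USE_FLASH_ATTENTION_FOR_POST_TRAINING=1 python train.py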
if USE_FLASH_ATTENTION_FOR_POST_TRAINING:
try:
from flash_attn import flash_attn_func
except ImportError:
warnings.warn(
"PLAMO3_MODELING_PLAMO_USE_FLASH_ATTENTION_FOR_POST_TRAINING is set but flash_attn is not installed. "
"Falling back to scaled_dot_product_attention. "
"Install it via `pip install flash-attn` to use Flash Attention.",
stacklevel=2,
)
USE_FLASH_ATTENTION_FOR_POST_TRAINING = False
def _swiglu(h: torch.Tensor) -> torch.Tensor:
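    """SwiGLU activation: split the last dimension in half and return silu(h0) * h1."""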
h0, h1 = h.chunk(2, dim=-1)
return torch.nn.functional.silu(h0) * h1
class RotaryEmbedding(torch.nn.Module):
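    """Rotary position embedding (RoPE) with a cos/sin cache that is rebuilt whenever a longer sequence is seen."""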
def __init__(
self, dim: int, max_position_embeddings: int = 2048, base: int = 10000, device: Optional[torch.device] = None
) -> None:
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
)
def _set_cos_sin_cache(self, seq_len: int, device: Any, dtype: Any) -> None:
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) # type: ignore
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
def forward(self, x: torch.Tensor, seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]:
# x: [bs, num_attention_heads, seq_len, head_size]
if seq_len > self.max_seq_len_cached:
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
return (
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), # type: ignore
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), # type: ignore
)
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def _rotary_pos_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor:
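    """Apply rotary position embeddings to `x`, gathering the cos/sin tables at `position_ids`."""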
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
x_embed = (x * cos) + (_rotate_half(x) * sin)
return x_embed
class LinearType(str, enum.Enum):
Normal = "normal"
Fp8 = "fp8"
def is_full_attn(sliding_window_pattern: int, layer_idx: int) -> bool:
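    """Return True for every `sliding_window_pattern`-th layer (1-indexed); those layers use full attention."""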
return not bool((layer_idx + 1) % sliding_window_pattern)
class Plamo3Config(PretrainedConfig): # type: ignore
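    """Configuration for PLaMo 3.

    Attention layers are interleaved: every `sliding_window_pattern`-th layer uses full
    attention with `rope_theta`, while the remaining layers use sliding-window attention of
    size `window_size` with `rope_local_theta`.
    """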
model_type: str = "plamo3"
def __init__(
self,
hidden_size: int = 4096,
num_hidden_layers: int = 32,
rms_norm_eps: float = 1e-6,
tie_word_embeddings: bool = True,
# Attention
num_attention_heads: int = 32,
num_key_value_heads: int = 4,
head_dim: int = 128,
max_position_embeddings: int = 2048,
window_size: int = 2048,
sliding_window_pattern: int = 8,
rope_theta: int = 1000000,
rope_local_theta: int = 10000,
# MLP
intermediate_size: int = 13312,
# Tokenizer
vocab_size: int = 32000,
tokenizer_class: str = "Plamo3Tokenizer",
pad_token_id: Optional[int] = None,
bos_token_id: int = 1,
eos_token_id: int = 2,
# Multimodal
image_token_id: Optional[int] = None,
image_feature_size: Optional[int] = None,
image_proj_type: Literal["linear", "mlp"] = "linear",
# FP8
linear_type: LinearType = LinearType.Normal,
# Evaluation
use_cache: bool = True,
**kwargs: Any,
) -> None:
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.rms_norm_eps = rms_norm_eps
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.head_dim = head_dim
self.num_key_value_heads = num_key_value_heads
self.window_size = window_size
self.sliding_window_pattern = sliding_window_pattern
self.rope_theta = rope_theta
self.rope_local_theta = rope_local_theta
self.intermediate_size = intermediate_size
self.vocab_size = vocab_size
self.image_token_id = image_token_id
self.image_feature_size = image_feature_size
self.image_proj_type = image_proj_type
self.linear_type = linear_type
self.use_cache = use_cache
self.interleaved_sliding_window: list[int | None] = []
for i in range(self.num_hidden_layers):
if is_full_attn(self.sliding_window_pattern, i):
self.interleaved_sliding_window.append(None)
else:
self.interleaved_sliding_window.append(self.window_size)
assert len(self.interleaved_sliding_window) == self.num_hidden_layers
super().__init__(
tokenizer_class=tokenizer_class,
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
@property
def layer_types(self) -> list[str]:
return [
"full_attention" if sliding_window_size is None else "sliding_attention"
for sliding_window_size in self.interleaved_sliding_window
]
@property
def layers_block_type(self) -> list[str]:
return ["attention" for i in range(self.num_hidden_layers)]
@property
def rope_local_base_freq(self) -> int:
return self.rope_local_theta
class Plamo3Cache(DynamicCache): # type: ignore
def __init__(self, config: Plamo3Config) -> None:
super().__init__()
self.config = config
def finalize(self, layer_idx: int) -> None:
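        """Trim the cache of a sliding-window layer to its last `window_size` positions.

        Full-attention layers keep their entire cache.
        """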
full_attn = self.config.layer_types[layer_idx] == "full_attention"
if full_attn:
return
window_size = self.config.window_size
assert self[layer_idx] is not None
key, value = self[layer_idx]
self.layers[layer_idx].keys = key[:, :, -window_size:, :]
self.layers[layer_idx].values = value[:, :, -window_size:, :]
def get_seq_length(self, layer_idx: Optional[int] = None) -> int:
if layer_idx is not None:
k, _ = self[layer_idx]
return k.shape[2] # type: ignore
sequence_length: int | None = None
for layer_cache in iter(self):
key = layer_cache[0]
sequence_length = max(key.shape[2], sequence_length) if sequence_length is not None else key.shape[2]
if sequence_length is None:
return 0
return sequence_length
class DecoderInput(NamedTuple):
hidden_states: torch.Tensor
attention_mask: Optional[torch.Tensor] = None
past_states: Optional[Plamo3Cache] = None
output_hidden_states: Optional[bool] = False
output_attentions: Optional[bool] = False
gradient_checkpointing: bool = False
input_ids: Optional[torch.Tensor] = None
class DecoderOutput(NamedTuple):
hidden_states: torch.Tensor
all_hidden_states: Optional[Tuple[torch.Tensor, ...]]
all_self_attns: Optional[Tuple[torch.Tensor, ...]]
def _make_causal_mask(
input_ids_shape: Tuple[int, int],
dtype: torch.dtype,
device: torch.device,
seq_len: int,
cache_position: torch.Tensor,
) -> torch.Tensor:
"""
Make causal mask used for bi-directional self-attention.
Follows the logic in `LlamaModel._prepare_4d_causal_attention_mask_with_cache_position`
https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L664
    NOTE(murai): seq_len (sequence_length) and tgt_len (target_length) are swapped in the original code.
    Our implementation:
    - seq_len: the total length of the sequence, i.e. tokens already processed plus those being processed
    - tgt_len: the length of the tokens currently being processed
Original (Llama) implementation:
- sequence_length: "The sequence length being processed"
- target_length: "when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet."
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, seq_len), float("-inf"), device=device)
if tgt_len != 1:
# TODO(murai): is this necessary?
mask = torch.triu(mask, diagonal=1)
mask = torch.where(torch.arange(seq_len, device=device) > cache_position.reshape(-1, 1), mask, 0.0)
mask = mask.to(dtype)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, seq_len)
# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None) -> torch.Tensor:
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), float("-inf")) # type: ignore
def _rms_norm(
hidden_states: torch.Tensor, weight: Optional[torch.Tensor], eps: float, offset: float = 1.0
) -> torch.Tensor:
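    """RMS normalization computed in float32; the effective scale is `offset + weight`."""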
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + eps)
hidden_states = hidden_states.to(input_dtype)
if weight is not None:
hidden_states = (offset + weight) * hidden_states
return hidden_states
class RMSNorm(nn.Module):
def __init__(
self,
hidden_size: int,
eps: float = 1e-6,
offset: float = 1.0,
device: Optional[Union[torch.device, str]] = None,
) -> None:
super().__init__()
self.weight = nn.Parameter(torch.zeros(hidden_size, device=device))
self.variance_epsilon = eps
self.offset = offset
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return _rms_norm(hidden_states, self.weight, self.variance_epsilon, offset=self.offset)
def swa_mask(q_len: int, kv_len: int, device: torch.device, window_size: int) -> torch.Tensor:
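    """Boolean band mask of shape (q_len, kv_len): True where the key position lies within
    `window_size` of the query position. Causality is enforced by the additive attention mask.
    """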
max_len = max(q_len, kv_len)
mask = (
torch.ones(max_len, max_len, dtype=torch.bool, device=device)
.triu(diagonal=-window_size)
.tril(diagonal=window_size)
)
return mask[-q_len:, -kv_len:]
class Attention(torch.nn.Module):
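    """Grouped-query attention with RMSNorm on queries/keys and rotary position embeddings.

    Depending on `config.layer_types[layer_idx]`, a layer uses either full attention or
    sliding-window attention of size `config.window_size`.
    """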
def __init__(self, config: Plamo3Config, layer_idx: int) -> None:
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.hidden_size = config.hidden_size
head_dim = config.head_dim
self.max_position_embeddings = config.max_position_embeddings
self.q_num_heads = config.num_attention_heads
self.qk_dim = self.v_dim = head_dim
self.k_num_heads = self.v_num_heads = config.num_key_value_heads
assert self.q_num_heads % self.k_num_heads == 0
self.n_group = self.q_num_heads // self.k_num_heads
self.q_proj_dim = self.q_num_heads * self.qk_dim
self.k_proj_dim = self.k_num_heads * self.qk_dim
self.v_proj_dim = self.v_num_heads * self.v_dim
self.qkv_proj = nn.Linear(self.hidden_size, self.q_proj_dim + self.k_proj_dim + self.v_proj_dim, bias=False)
self.o_proj = nn.Linear(self.q_num_heads * self.v_dim, self.hidden_size, bias=False)
self.q_norm = RMSNorm(self.qk_dim, eps=self.config.rms_norm_eps, offset=1.0)
self.k_norm = RMSNorm(self.qk_dim, eps=self.config.rms_norm_eps, offset=1.0)
self.full_attn = config.layer_types[layer_idx] == "full_attention"
base = self.config.rope_theta if self.full_attn else self.config.rope_local_theta
self.rotary_emb = RotaryEmbedding(
self.qk_dim, max_position_embeddings=self.config.max_position_embeddings, base=base
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_states: Optional[Plamo3Cache] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Plamo3Cache]]:
bsz, q_len, _ = hidden_states.size()
qkv = self.qkv_proj(hidden_states)
query_states, key_states, value_states = torch.split(
qkv, [self.q_proj_dim, self.k_proj_dim, self.v_proj_dim], dim=-1
)
query_states = query_states.view(bsz, q_len, self.q_num_heads, self.qk_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.k_num_heads, self.qk_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.v_num_heads, self.v_dim).transpose(1, 2)
attn_dtype = query_states.dtype
query_states = self.q_norm(query_states)
key_states = self.k_norm(key_states)
if past_states is not None:
key_states, value_states = past_states.update(key_states, value_states, self.layer_idx)
past_states.finalize(self.layer_idx)
kv_seq_len = key_states.shape[-2]
device = hidden_states.device
position_ids = torch.arange(kv_seq_len, dtype=torch.long, device=device)[None]
q_position_ids = position_ids[:, -query_states.shape[2] :]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states = _rotary_pos_emb(query_states, cos, sin, q_position_ids)
key_states = _rotary_pos_emb(key_states, cos, sin, position_ids)
# [bsz, nh, t, hd]
query_states = query_states.to(attn_dtype)
key_states = key_states.to(attn_dtype)
value_states = value_states.to(attn_dtype)
if attention_mask is not None and attention_mask.dtype != torch.bool:
attention_mask = attention_mask.to(attn_dtype)
if USE_FLASH_ATTENTION_FOR_POST_TRAINING:
# It is assumed that there's no padding on the left side.
# attention_mask is ignored.
if self.full_attn:
attn_output = F.scaled_dot_product_attention(
query_states, key_states, value_states, is_causal=True, enable_gqa=True
)
else:
# Use Flash Attention for sliding window attention
# Flash attention output is (N, L, H, C), transpose to (N, H, L, C) for consistency
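                # flash_attn's `window_size` is (left, right): each query attends to at most
                # `window_size` preceding tokens and to nothing on its right (causal).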
attn_output = flash_attn_func(
query_states.transpose(1, 2),
key_states.transpose(1, 2),
value_states.transpose(1, 2),
window_size=(self.config.window_size, 0),
causal=True,
).transpose(1, 2)
elif attention_mask is None:
assert self.full_attn or key_states.shape[2] <= self.config.window_size + 1
attn_output = F.scaled_dot_product_attention(
query_states, key_states, value_states, is_causal=True, enable_gqa=True
)
else:
if attention_mask.dtype == torch.bool:
attention_mask = torch.where(attention_mask, torch.tensor(0.0, dtype=torch.float), float("-inf"))
if len(attention_mask.shape) == 2:
attention_mask = attention_mask[None, None]
assert len(attention_mask.shape) == 4
if not self.full_attn:
m_swa = swa_mask(
query_states.shape[2], key_states.shape[2], query_states.device, self.config.window_size
)
                # `generate` builds an attention mask that does not account for the sliding window, so apply it here.
m_swa = m_swa[None, None]
attention_mask = attention_mask[:, :, -query_states.shape[2] :, -key_states.shape[2] :]
attention_mask = torch.where(m_swa, attention_mask, float("-inf"))
            # Like AttentionMaskConverter._unmask_unattended in huggingface/transformers,
            # we need to attend to all tokens in fully masked rows for `scaled_dot_product_attention`.
bool_mask = torch.logical_not(torch.isneginf(attention_mask))
valid_tokens = torch.sum(bool_mask, dim=-1).bool() # (..., q_len)
attention_mask = torch.where(valid_tokens[..., None], attention_mask, float(0.0))
attn_output = F.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
enable_gqa=True,
)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, q_len, self.q_num_heads * self.v_dim)
attn_output = self.o_proj(attn_output)
        # Neither `scaled_dot_product_attention` nor flash-attn materializes attention
        # weights, so `None` is returned even when `output_attentions` is requested.
        attn_weights = None
return attn_output, attn_weights, past_states
class MLP(nn.Module):
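    """Gated feed-forward block: a single fused projection produces both halves consumed by SwiGLU."""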
def __init__(self, config: Plamo3Config) -> None:
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_up_proj = torch.nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False)
self.down_proj = torch.nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
h = self.gate_up_proj(x)
h = _swiglu(h)
return self.down_proj(h) # type: ignore
class Plamo3DecoderLayer(torch.nn.Module):
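    """Decoder block: pre-/post-normalized attention ("mixer") and MLP sub-blocks with residual connections."""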
def __init__(self, config: Plamo3Config, layer_idx: int) -> None:
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.mixer: torch.nn.Module
self.mixer = Attention(config, layer_idx)
self.mlp = MLP(config)
"""
Notes: The model performance was degraded when setting all offsets to 1.
"""
self.pre_mixer_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, offset=1.0)
self.post_mixer_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, offset=1.0 / 5)
self.pre_mlp_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, offset=1.0)
self.post_mlp_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, offset=1.0 / (5**1.5))
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_state: Optional[Plamo3Cache] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[Any, ...]:
# from LlamaDecoder
residual = hidden_states
hidden_states = self.pre_mixer_norm(hidden_states)
# Self Attention
hidden_states_sa, self_attn_weights, present_key_value = self.mixer(
hidden_states=hidden_states,
attention_mask=attention_mask,
past_states=past_state,
output_attentions=output_attentions,
)
hidden_states_sa = self.post_mixer_norm(hidden_states_sa)
hidden_states = residual + hidden_states_sa
residual = hidden_states
hidden_states = self.pre_mlp_norm(hidden_states)
# Fully Connected
hidden_states_mlp = self.mlp(hidden_states)
# Residual
hidden_states_mlp = self.post_mlp_norm(hidden_states_mlp)
hidden_states = residual + hidden_states_mlp
outputs: Any = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
return outputs # type: ignore
class Plamo3Decoder(torch.nn.Module):
def __init__(self, config: Plamo3Config) -> None:
super().__init__()
self.layers = torch.nn.ModuleList(
[Plamo3DecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]
)
self.gradient_checkpointing = False
def forward(self, x: DecoderInput) -> DecoderOutput:
all_hidden_states: Optional[Tuple[torch.Tensor, ...]] = () if x.output_hidden_states else None
all_self_attns: Optional[Tuple[torch.Tensor, ...]] = () if x.output_attentions else None
hidden_states = x.hidden_states
for decoder_layer in self.layers:
if x.output_hidden_states:
assert all_hidden_states is not None
all_hidden_states += (hidden_states,)
if self.training and x.gradient_checkpointing:
layer_outputs = self._gradient_checkpointing_func( # type: ignore
decoder_layer.__call__,
hidden_states,
x.attention_mask,
x.past_states,
x.output_attentions,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=x.attention_mask,
past_state=x.past_states,
output_attentions=x.output_attentions,
)
hidden_states = layer_outputs[0]
if x.output_attentions:
assert layer_outputs[1] is not None
assert all_self_attns is not None
all_self_attns += (layer_outputs[1],)
return DecoderOutput(hidden_states, all_hidden_states, all_self_attns)
class Plamo3PreTrainedModel(PreTrainedModel): # type: ignore
config_class = Plamo3Config
_no_split_modules: List[str]
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["PlamoDecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
def _init_weights(self, module: torch.nn.Module) -> None:
std = 0.02
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
class Plamo3Model(Plamo3PreTrainedModel):
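    """Decoder-only backbone: token embeddings (plus an optional image projection), the stacked
    decoder layers, and a final RMSNorm.
    """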
def __init__(self, config: Plamo3Config):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
if config.image_feature_size is not None:
if config.image_proj_type == "mlp":
self.image_proj = MLPImageProjector(config) # type: ignore
elif config.image_proj_type == "linear":
self.image_proj = nn.Linear(config.image_feature_size, config.hidden_size, bias=False) # type: ignore
else:
raise ValueError(f"Unknown image_proj_type: {config.image_proj_type}")
self.layers = Plamo3Decoder(config) # type: ignore
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self) -> torch.nn.Embedding:
return self.embed_tokens
def set_input_embeddings(self, value: torch.nn.Embedding) -> None:
self.embed_tokens = value
def _prepare_decoder_attention_mask(
self,
attention_mask: torch.Tensor,
input_shape: Tuple[int, int],
inputs_embeds: torch.Tensor,
cache_position: torch.LongTensor,
) -> Optional[torch.Tensor]:
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = _make_causal_mask(
input_shape,
inputs_embeds.dtype,
device=inputs_embeds.device,
seq_len=attention_mask.shape[-1],
cache_position=cache_position,
)
input_shape = (input_shape[0], combined_attention_mask.shape[2])
if attention_mask.dim() == 4:
# Custom 4D attention mask
expanded_attn_mask = attention_mask
else:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
inputs_embeds.device
)
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
past_key_values: Optional[Plamo3Cache | DynamicCache] = None,
inputs_embeds: Optional[torch.Tensor] = None,
image_features: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Any,
) -> BaseModelOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
# retrieve input_ids and inputs_embeds
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if self.gradient_checkpointing and self.training and use_cache:
use_cache = False
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
batch_size, seq_length, _ = inputs_embeds.shape
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
# In some `transformers` versions, `past_key_values` may be a `DynamicCache` object.
if not isinstance(past_key_values, Plamo3Cache):
past_key_values_prev = past_key_values
past_key_values = Plamo3Cache(self.config)
for layer_idx in range(len(past_key_values_prev)):
layer = past_key_values_prev.layers[layer_idx]
if layer.keys is not None and layer.values is not None:
past_key_values.update(layer.keys, layer.values, layer_idx=layer_idx)
assert isinstance(past_key_values, Plamo3Cache)
past_key_values_length = past_key_values.get_seq_length()
seq_length_with_past = seq_length_with_past + past_key_values_length
if cache_position is None:
cache_position = torch.arange(
past_key_values_length,
past_key_values_length + seq_length,
device=inputs_embeds.device,
) # type: ignore
if image_features is not None:
assert self.config.image_token_id is not None
image_embeds = self.image_proj(image_features)
assert image_embeds.shape == inputs_embeds.shape, (image_embeds.shape, inputs_embeds.shape)
mask = input_ids == self.config.image_token_id
inputs_embeds[mask] = image_embeds[mask]
# embed positions
require_attn_mask = False
if not self.training or past_key_values is not None:
require_attn_mask = True
if seq_length_with_past > self.config.window_size + 1:
require_attn_mask = True
if require_attn_mask and attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
)
if attention_mask is not None:
attention_mask = self._prepare_decoder_attention_mask(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
cache_position, # type: ignore
)
hidden_states = inputs_embeds
if use_cache and past_key_values is None:
past_key_values = Plamo3Cache(self.config)
# decoder layers
out = self.layers(
DecoderInput(
hidden_states,
attention_mask,
past_key_values,
output_hidden_states,
output_attentions,
self.gradient_checkpointing,
)
)
assert isinstance(out, DecoderOutput)
hidden_states = out.hidden_states
all_hidden_states = out.all_hidden_states
all_self_attns = out.all_self_attns
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
assert all_hidden_states is not None
all_hidden_states += (hidden_states,)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
class Plamo3ForCausalLM(Plamo3PreTrainedModel, GenerationMixin): # type: ignore
_tied_weights_keys = ["lm_head.weight"]
# Without this, the model cannot be loaded into a meta device.
# Relevant code:
# https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/modeling_utils.py#L4376-L4381
# https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/modeling_utils.py#L356
# https://github.com/pytorch/pytorch/blob/v2.4.1/torch/nn/modules/module.py#L2068
_supports_param_buffer_assignment = False
def __init__(self, config: Plamo3Config) -> None:
super().__init__(config)
self.model = Plamo3Model(config)
self.vocab_size = config.vocab_size
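        # The LM head is padded up to a multiple of 16 (e.g. for hardware-friendly matrix
        # shapes); logits are sliced back to `vocab_size` in `forward`.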
vocab_size = ((self.vocab_size + 15) // 16) * 16
self.lm_head: torch.nn.Module = nn.Linear(config.hidden_size, vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self) -> torch.nn.Embedding:
return self.model.embed_tokens
def set_input_embeddings(self, value: torch.nn.Embedding) -> None:
self.model.embed_tokens = value
def get_output_embeddings(self) -> torch.nn.Module:
return self.lm_head
def set_output_embeddings(self, new_embeddings: torch.nn.Module) -> None:
self.lm_head = new_embeddings
def set_decoder(self, decoder: Plamo3Model) -> None:
self.model = decoder
def get_decoder(self) -> Plamo3Model:
return self.model
def forward( # type: ignore
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
past_key_values: Optional[Plamo3Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
image_features: Optional[torch.Tensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: int | torch.Tensor = 0,
**kwargs: Any,
) -> CausalLMOutputWithPast:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Example:
```python
        >>> from transformers import AutoModelForCausalLM, AutoTokenizer
        >>> model = AutoModelForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS, trust_remote_code=True)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER, trust_remote_code=True)
        >>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
image_features=image_features,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = logits[:, slice_indices, : self.vocab_size]
loss = None
if labels is not None:
if len(kwargs) > 0 and set(kwargs.keys()) != set(["ignore_index"]):
warnings.warn(
f"The following kwargs may not be supported: {', '.join(kwargs.keys())}. ",
stacklevel=2,
)
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def prepare_inputs_for_generation(
self,
input_ids: torch.Tensor,
past_key_values: Optional[Plamo3Cache] = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
image_features: Optional[torch.Tensor] = None,
**kwargs: Any,
) -> Dict[str, Any]:
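        """Assemble inputs for `generate`: once the cache holds keys/values, only the last
        token (and its position id / image feature) is fed to the model.
        """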
if past_key_values and all(k.keys is not None for k in past_key_values.layers):
input_ids = input_ids[:, -1:]
if image_features is not None:
image_features = image_features[:, -1:, :]
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs: Dict[str, Any] = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"output_attentions": kwargs.get("output_attentions"),
"output_hidden_states": kwargs.get("output_hidden_states"),
"logits_to_keep": kwargs.get("logits_to_keep"),
"attention_mask": attention_mask,
"image_features": image_features,
}
)
return model_inputs
@staticmethod
def _reorder_cache(past_key_values: Plamo3Cache, beam_idx: torch.Tensor) -> Plamo3Cache:
past_key_values.reorder_cache(beam_idx)
return past_key_values
class MLPImageProjector(nn.Module):
def __init__(self, config: Plamo3Config) -> None:
super().__init__()
self.config = config
assert config.image_feature_size is not None # for typing
# nn.LayerNorm is not supported by PFVM, so use RMSNorm + Bias instead to approximate this.
self.norm0 = RMSNorm(config.image_feature_size, eps=config.rms_norm_eps)
self.bias0 = Bias(config.image_feature_size)
# PFVM doesn't support Linear with bias, so add bias manually afterwards.
self.linear1 = nn.Linear(config.image_feature_size, config.hidden_size, bias=False)
self.bias1 = Bias(config.hidden_size)
self.act1 = nn.GELU()
self.linear2 = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
self.bias2 = Bias(config.hidden_size)
def forward(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor:
hidden_states = self.norm0(hidden_states)
hidden_states = self.bias0(hidden_states)
hidden_states = self.linear1(hidden_states)
hidden_states = self.bias1(hidden_states)
hidden_states = self.act1(hidden_states)
hidden_states = self.linear2(hidden_states)
hidden_states = self.bias2(hidden_states)
return hidden_states
class Bias(nn.Module):
def __init__(self, num_features: int) -> None:
super().__init__()
self._bias = nn.Parameter(torch.zeros((num_features,)))
def forward(
self,
x: torch.Tensor,
) -> torch.Tensor:
return x + self._bias