# flexqwen/flexqwen.py
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig, GenerationMixin
from transformers.cache_utils import Cache, DynamicCache
from transformers.utils import ModelOutput
from transformers.modeling_outputs import (
SequenceClassifierOutput,
CausalLMOutputWithPast,
)
from .common import (
FeedForward,
MoEFeedForward,
RMSNorm,
compute_rope_params,
apply_rope,
)
class FlexQwenConfig(PretrainedConfig):
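    """Configuration for FlexQwen models: vocabulary/embedding sizes, grouped-query
    attention settings, optional MoE feed-forward settings, RoPE and init hyper-parameters.
    """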
model_type = "flexqwen"
def __init__(
self,
vocab_size: int = 64000,
embedding_dim: int = 1024,
hidden_dim: int = 2048,
num_attention_heads: int = 8,
num_kv_groups: int = 8,
head_dim: int = 128,
qk_norm: bool = True,
moe_num_experts: int = 0,
moe_num_experts_per_token: int = -1,
moe_hidden_dim: int = 512,
num_hidden_layers: int = 32,
max_position_embeddings: int = 1024,
rms_norm_eps: float = 1e-6,
rope_theta: int = 10000,
initializer_range: float = 0.02,
cls_token_id: int = 1,
pad_token_id: int = 3,
tie_word_embeddings: bool = True,
dropout_rate: float = 0.0,
**kwargs,
):
super().__init__(
cls_token_id=cls_token_id,
pad_token_id=pad_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
# Vocab & Embeddings
self.vocab_size = vocab_size
self.embedding_dim = embedding_dim
self.hidden_dim = hidden_dim
# Attention Mechanism
self.num_attention_heads = num_attention_heads
self.num_kv_groups = num_kv_groups
self.head_dim = head_dim
self.qk_norm = qk_norm
# Feed-Forward & MoE
self.moe_num_experts = moe_num_experts
self.moe_num_experts_per_token = moe_num_experts_per_token
self.moe_hidden_dim = moe_hidden_dim
# General Architecture
self.num_hidden_layers = num_hidden_layers
self.max_position_embeddings = max_position_embeddings
self.rms_norm_eps = rms_norm_eps
self.rope_theta = rope_theta
# Initialization
self.initializer_range = initializer_range
# Standard HF Config params
self.tie_word_embeddings = tie_word_embeddings
self.dropout_rate = dropout_rate
# pyrefly: ignore
class FlexQwenPreTrainedModel(PreTrainedModel):
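    """Shared base class: hooks FlexQwenConfig into the HF save/load machinery and
    initializes Embedding/Linear weights uniformly in [-initializer_range, initializer_range].
    """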
config_class = FlexQwenConfig
base_model_prefix = "model"
_supports_cache_class = True
def _init_weights(self, module):
if isinstance(module, nn.Embedding):
module.weight.data.uniform_(
-self.config.initializer_range, self.config.initializer_range
)
elif isinstance(module, nn.Linear):
module.weight.data.uniform_(
-self.config.initializer_range, self.config.initializer_range
)
if module.bias is not None:
module.bias.data.zero_()
class GroupedQueryAttention(nn.Module):
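    """Grouped-query attention: `num_heads` query heads share `num_kv_groups` key/value
    heads, with optional per-head RMSNorm on queries/keys and RoPE applied before SDPA.
    """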
def __init__(
self,
in_features: int,
num_heads: int,
num_kv_groups: int,
head_dim: int | None = None,
        qk_norm: bool = False,
rms_norm_eps: float = 1e-6,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
layer_idx: int = 0,
):
assert num_heads % num_kv_groups == 0, (
"num_heads must be divisible by num_kv_groups"
)
factory_kwargs = dict(device=device, dtype=dtype)
super().__init__()
self.num_heads = num_heads
self.num_kv_groups = num_kv_groups
self.group_size = num_heads // num_kv_groups
if head_dim is None:
assert in_features % num_heads == 0, (
"input_dim must be divisible by num_heads"
)
head_dim = in_features // num_heads
self.head_dim = head_dim
self.out_features = num_heads * head_dim
self.wq = nn.Linear(
in_features, self.out_features, bias=False, **factory_kwargs
)
self.wkv = nn.Linear(
in_features, 2 * num_kv_groups * head_dim, bias=False, **factory_kwargs
)
self.out_proj = nn.Linear(
self.out_features, in_features, bias=False, **factory_kwargs
)
self.qk_norm = qk_norm
if self.qk_norm:
self.q_norm = RMSNorm(head_dim, eps=rms_norm_eps, **factory_kwargs)
self.k_norm = RMSNorm(head_dim, eps=rms_norm_eps, **factory_kwargs)
self.layer_idx = layer_idx
def forward(
self,
x: torch.FloatTensor,
cos: torch.FloatTensor,
sin: torch.FloatTensor,
attention_mask: Optional[torch.BoolTensor] = None,
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> tuple[torch.FloatTensor, Optional[Cache]]:
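        """Compute grouped-query attention for this layer and return (output, updated KV cache)."""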
batch_size, num_tokens, _ = x.shape
query = self.wq(x)
key, value = self.wkv(x).chunk(2, dim=-1)
query = query.view(
batch_size, num_tokens, self.num_heads, self.head_dim
).transpose(1, 2)
key = key.view(
batch_size, num_tokens, self.num_kv_groups, self.head_dim
).transpose(1, 2)
value = value.view(
batch_size, num_tokens, self.num_kv_groups, self.head_dim
).transpose(1, 2)
if self.qk_norm:
query = self.q_norm(query)
key = self.k_norm(key)
if cache_position is None:
offset = (
past_key_value.get_seq_length(self.layer_idx)
if past_key_value is not None
else 0
)
else:
offset = int(cache_position[0].item())
query = apply_rope(query, cos, sin, offset=offset)
key = apply_rope(key, cos, sin, offset=offset)
if past_key_value is not None:
cache_kwargs = {"cache_position": cache_position}
key, value = past_key_value.update(key, value, self.layer_idx, cache_kwargs)
attn_output = nn.functional.scaled_dot_product_attention(
query,
key,
value,
attn_mask=attention_mask,
dropout_p=0.0,
enable_gqa=True,
)
out = self.out_proj(
attn_output.transpose(1, 2).reshape(
batch_size, num_tokens, self.out_features
)
)
return out, past_key_value
class Transformer(nn.Module):
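    """Pre-norm transformer block: RMSNorm -> grouped-query attention -> residual,
    then RMSNorm -> (MoE or dense) feed-forward -> residual.
    """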
def __init__(
self,
embedding_dim: int,
hidden_dim: int,
num_heads: int,
head_dim: int,
num_kv_groups: int,
        qk_norm: bool = False,
moe_num_experts_per_token: int = 8,
moe_num_experts: int = 0,
moe_hidden_dim: int = 128,
rms_norm_eps: float = 1e-6,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
layer_idx: int = 0,
):
factory_kwargs = dict(device=device, dtype=dtype)
super().__init__()
self.attn = GroupedQueryAttention(
in_features=embedding_dim,
num_heads=num_heads,
head_dim=head_dim,
num_kv_groups=num_kv_groups,
qk_norm=qk_norm,
layer_idx=layer_idx,
**factory_kwargs,
)
if moe_num_experts > 0:
self.ff: MoEFeedForward | FeedForward = MoEFeedForward(
embedding_dim=embedding_dim,
hidden_dim=moe_hidden_dim,
num_experts_per_token=moe_num_experts_per_token,
num_experts=moe_num_experts,
device=device,
dtype=dtype,
)
else:
self.ff = FeedForward(
embedding_dim, hidden_dim=hidden_dim, **factory_kwargs
)
self.norm1 = RMSNorm(embedding_dim, eps=rms_norm_eps, **factory_kwargs)
self.norm2 = RMSNorm(embedding_dim, eps=rms_norm_eps, **factory_kwargs)
def forward(
self,
x: torch.FloatTensor,
cos: torch.FloatTensor,
sin: torch.FloatTensor,
attention_mask: Optional[torch.BoolTensor] = None,
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> tuple[torch.FloatTensor, Optional[Cache]]:
residual = x
x = self.norm1(x)
x, past_key_value = self.attn(
x,
cos,
sin,
attention_mask=attention_mask,
past_key_value=past_key_value,
cache_position=cache_position,
)
x += residual
residual = x
x = self.norm2(x)
x = self.ff(x)
x += residual
return x, past_key_value
@dataclass
class FlexQwenOutputWithPast(ModelOutput):
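    """Output container for the backbone: final hidden states (a one-element tuple)
    plus the KV cache when caching is enabled.
    """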
last_hidden_states: tuple[torch.FloatTensor]
attentions: Optional[tuple[torch.FloatTensor]] = None
past_key_values: Optional[Cache] = None
class FlexQwen(FlexQwenPreTrainedModel):
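    """Decoder backbone: token embedding, stacked Transformer blocks, and a final RMSNorm,
    with precomputed RoPE cos/sin tables registered as buffers.
    """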
config_class = FlexQwenConfig
def __init__(
self,
config: FlexQwenConfig,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
super().__init__(config)
self.embed = nn.Embedding(
config.vocab_size,
config.embedding_dim,
padding_idx=config.pad_token_id,
device=device,
dtype=dtype,
)
self.transformer_blocks = nn.ModuleList(
[
Transformer(
embedding_dim=config.embedding_dim,
hidden_dim=config.hidden_dim,
num_heads=config.num_attention_heads,
head_dim=config.head_dim,
num_kv_groups=config.num_kv_groups,
qk_norm=config.qk_norm,
moe_num_experts_per_token=config.moe_num_experts_per_token,
moe_num_experts=config.moe_num_experts,
moe_hidden_dim=config.moe_hidden_dim,
rms_norm_eps=config.rms_norm_eps,
device=device,
dtype=dtype,
layer_idx=i,
)
for i in range(config.num_hidden_layers)
]
)
self.final_norm = RMSNorm(
config.embedding_dim, eps=config.rms_norm_eps, device=device, dtype=dtype
)
cos, sin = compute_rope_params(
head_dim=config.head_dim,
theta_base=config.rope_theta,
max_position_embeddings=config.max_position_embeddings,
dtype=dtype,
device=device,
)
self.register_buffer("cos", cos, persistent=True)
self.register_buffer("sin", sin, persistent=True)
self.config = config
self.post_init()
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
is_causal: bool = True,
return_dict: bool = True,
**kwargs,
) -> FlexQwenOutputWithPast | tuple:
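        """Embed the inputs (or use `inputs_embeds`), build a boolean attention mask that
        combines causal and padding constraints, run every transformer block, and return
        the final hidden states plus the KV cache when `use_cache` is set.
        """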
if input_ids is not None and inputs_embeds is not None:
raise ValueError("Received both input_ids and input_embeds. Pass only one.")
if input_ids is None and inputs_embeds is None:
raise ValueError("Exactly one of input_ids, input_embds is required.")
if input_ids is not None:
if input_ids.dim() == 1:
input_ids = input_ids.unsqueeze(0)
x = self.embed(input_ids)
else:
x = inputs_embeds
assert x is not None
q_len = x.shape[1]
kv_len = q_len
# If we have a cache, the total key/value length is past_len + current_len
if past_key_values is not None:
kv_len += past_key_values.get_seq_length()
base_mask = torch.ones((q_len, kv_len), dtype=torch.bool, device=x.device)
if is_causal and q_len > 1:
# Shift the tril to account for past tokens
base_mask = torch.tril(base_mask, diagonal=kv_len - q_len)
if attention_mask is not None:
# Padding mask is usually (Batch, kv_len)
padding_mask = (attention_mask == 1).unsqueeze(1).unsqueeze(2)
attention_mask = base_mask.unsqueeze(0).unsqueeze(1) & padding_mask
else:
attention_mask = base_mask.unsqueeze(0).unsqueeze(1)
if use_cache and past_key_values is None:
past_key_values = DynamicCache()
for block in self.transformer_blocks:
x, past_key_values = block(
x,
self.cos,
self.sin,
attention_mask=attention_mask,
past_key_value=past_key_values,
cache_position=cache_position,
)
x = self.final_norm(x)
output = FlexQwenOutputWithPast(
last_hidden_states=(x,),
past_key_values=past_key_values if use_cache else None,
)
if not return_dict:
return output.to_tuple()
return output
class FlexQwenForCausalLM(FlexQwenPreTrainedModel, GenerationMixin):
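    """FlexQwen backbone with a language-modeling head (optionally tied to the input
    embeddings); generation is provided via GenerationMixin.
    """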
config_class = FlexQwenConfig
_tied_weights_keys = {"lm_head.weight": "model.embed.weight"}
def __init__(
self,
config: FlexQwenConfig,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
**kwargs,
):
super().__init__(config)
self.model = FlexQwen(config, device=device, dtype=dtype)
self.lm_head = nn.Linear(
config.embedding_dim,
config.vocab_size,
bias=False,
device=device,
dtype=dtype,
)
self.post_init()
def get_input_embeddings(self):
return self.model.embed
def set_input_embeddings(self, value):
self.model.embed = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def tie_weights(
self, missing_keys: set[str] | None = None, recompute_mapping: bool = True
) -> None:
super().tie_weights(
missing_keys=missing_keys, recompute_mapping=recompute_mapping
)
if getattr(self.config, "tie_word_embeddings", False):
self.lm_head.weight = self.model.embed.weight
print("Weights tied anyway, do not worry, be happy =)")
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.BoolTensor] = None,
labels: Optional[torch.Tensor] = None,
return_dict: Optional[bool] = None,
use_cache: Optional[bool] = None,
is_causal=True,
**kwargs,
) -> CausalLMOutputWithPast | tuple:
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
outputs: FlexQwenOutputWithPast = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
use_cache=use_cache,
return_dict=True,
is_causal=is_causal,
**kwargs,
)
logits = self.lm_head(outputs.last_hidden_states[-1])
loss = None
if labels is not None:
if labels.dim() == 1:
labels = labels.unsqueeze(0)
loss = nn.functional.cross_entropy(
logits.view(-1, logits.size(-1)),
labels.view(-1),
ignore_index=-100,
reduction="mean",
)
output = CausalLMOutputWithPast(
logits=logits,
# pyrefly: ignore
loss=loss,
# TODO: Implement this properly
# pyrefly: ignore
past_key_values=outputs.past_key_values if use_cache else None,
)
if not return_dict:
return output.to_tuple()
return output
def prepare_inputs_for_generation(
self,
input_ids: torch.LongTensor,
next_sequence_length: Optional[int] = None,
past_key_values: Optional[Cache] = None,
attention_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
        is_first_iteration: bool = False,
**kwargs,
) -> dict:
if past_key_values is not None:
if not is_first_iteration:
input_ids = input_ids[:, -1:] # pyrefly: ignore
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
# pyrefly: ignore
model_inputs.update(
{
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache", True),
"attention_mask": attention_mask,
"cache_position": cache_position,
"is_causal": True,
}
)
return model_inputs
class FlexQwenForSequenceClassification(FlexQwenPreTrainedModel):
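    """FlexQwen backbone with a pooling head: last-token pooling for causal attention
    (respecting padding when an attention mask is given), mean pooling otherwise,
    followed by dropout and a linear classifier.
    """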
config_class = FlexQwenConfig
def __init__(
self,
config: FlexQwenConfig,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
super().__init__(config)
self.num_labels = config.num_labels
self.model = FlexQwen(config, device=device, dtype=dtype)
self.dropout = nn.Dropout(p=config.dropout_rate)
self.score = nn.Linear(
config.embedding_dim,
self.num_labels,
bias=True,
device=device,
dtype=dtype,
)
self.loss_fct = nn.CrossEntropyLoss() if config.num_labels > 1 else nn.MSELoss()
self.post_init()
def forward(
self,
input_ids: torch.LongTensor,
# Fix when attention mask is None
attention_mask: Optional[torch.BoolTensor] = None,
labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
is_causal=True,
**kwargs,
) -> SequenceClassifierOutput | tuple:
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# pyrefly: ignore
outputs: FlexQwenOutputWithPast = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
return_dict=True,
is_causal=is_causal,
**kwargs,
)
hidden_states = outputs.last_hidden_states[-1]
if is_causal:
if attention_mask is None:
pooled_states = hidden_states[:, -1]
else:
sequence_lengths = attention_mask.sum(dim=1) - 1
pooled_states = hidden_states[
torch.arange(hidden_states.shape[0], device=hidden_states.device),
sequence_lengths,
]
else:
if attention_mask is None:
pooled_states = hidden_states.mean(dim=1)
else:
mask = attention_mask.unsqueeze(-1).expand(hidden_states.size())
masked_hidden_states = torch.where(mask.bool(), hidden_states, 0.0)
num_valid_tokens = (
attention_mask.sum(dim=1).unsqueeze(-1).clamp(min=1e-9)
)
pooled_states = masked_hidden_states.sum(dim=1) / num_valid_tokens
logits = self.score(self.dropout(pooled_states))
loss = None
if labels is not None:
if self.num_labels == 1:
loss = self.loss_fct(logits.squeeze(), labels.squeeze())
else:
loss = self.loss_fct(
logits.view(-1, self.num_labels),
labels.view(-1),
)
if not return_dict:
output = (logits,) + (outputs.last_hidden_states, outputs.attentions)
return (loss,) + output if loss is not None else output
return SequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.last_hidden_states,
attentions=outputs.attentions,
)
def load_model(
checkpoint_dir: str | Path, device: str | torch.device = "cpu"
) -> FlexQwenForCausalLM:
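    """Load a FlexQwenForCausalLM from a local checkpoint directory: registers the config
    with AutoConfig, loads weights from model.safetensors (non-strict), re-ties the
    embedding/LM-head weights, and moves the model to `device`.
    """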
checkpoint_dir = Path(checkpoint_dir)
from transformers import AutoConfig
from safetensors.torch import load_file
AutoConfig.register("flexqwen", FlexQwenConfig)
config = AutoConfig.from_pretrained(checkpoint_dir)
model = FlexQwenForCausalLM(config) # pyrefly: ignore
safetensors_path = checkpoint_dir / "model.safetensors"
if not safetensors_path.exists():
raise FileNotFoundError(f"Could not find {safetensors_path}.")
disk_dict = load_file(safetensors_path)
model.load_state_dict(disk_dict, strict=False)
model.tie_weights()
return model.to(device)
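# Minimal usage sketch. The "./checkpoint" path and the tokenizer are placeholders,
# not part of this repository; it assumes the directory contains config.json,
# model.safetensors, and compatible tokenizer files.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    model = load_model("./checkpoint", device="cpu")
    tokenizer = AutoTokenizer.from_pretrained("./checkpoint")
    batch = tokenizer("Hello, world!", return_tensors="pt")
    with torch.no_grad():
        out = model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
        )
    # Logits have shape (batch, seq_len, vocab_size).
    print(out.logits.shape)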