| from dataclasses import dataclass |
| from pathlib import Path |
| from typing import Optional |
| import torch |
| import torch.nn as nn |
|
|
| from transformers import PreTrainedModel, PretrainedConfig, GenerationMixin |
| from transformers.cache_utils import Cache, DynamicCache |
| from transformers.utils import ModelOutput |
| from transformers.modeling_outputs import ( |
| SequenceClassifierOutput, |
| CausalLMOutputWithPast, |
| ) |
|
|
| from .common import ( |
| FeedForward, |
| MoEFeedForward, |
| RMSNorm, |
| compute_rope_params, |
| apply_rope, |
| ) |
|
|
|
|
class FlexQwenConfig(PretrainedConfig):
    """Configuration for the FlexQwen model family.

    Stores embedding/attention/MoE/depth hyper-parameters; token ids and
    weight tying are forwarded to ``PretrainedConfig``.
    """

    model_type = "flexqwen"

    def __init__(
        self,
        vocab_size: int = 64000,
        embedding_dim: int = 1024,
        hidden_dim: int = 2048,
        num_attention_heads: int = 8,
        num_kv_groups: int = 8,
        head_dim: int = 128,
        qk_norm: bool = True,
        moe_num_experts: int = 0,
        moe_num_experts_per_token: int = -1,
        moe_hidden_dim: int = 512,
        num_hidden_layers: int = 32,
        max_position_embeddings: int = 1024,
        rms_norm_eps: float = 1e-6,
        rope_theta: int = 10000,
        initializer_range: float = 0.02,
        cls_token_id: int = 1,
        pad_token_id: int = 3,
        tie_word_embeddings: bool = True,
        dropout_rate: float = 0.0,
        **kwargs,
    ):
        super().__init__(
            cls_token_id=cls_token_id,
            pad_token_id=pad_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

        # Record every model hyper-parameter on the instance so that
        # PretrainedConfig (de)serialisation picks them all up.
        hyperparams = dict(
            # Embedding / projection sizes.
            vocab_size=vocab_size,
            embedding_dim=embedding_dim,
            hidden_dim=hidden_dim,
            # Attention layout.
            num_attention_heads=num_attention_heads,
            num_kv_groups=num_kv_groups,
            head_dim=head_dim,
            qk_norm=qk_norm,
            # Mixture-of-experts feed-forward (disabled when
            # moe_num_experts == 0).
            moe_num_experts=moe_num_experts,
            moe_num_experts_per_token=moe_num_experts_per_token,
            moe_hidden_dim=moe_hidden_dim,
            # Depth / positions / normalisation.
            num_hidden_layers=num_hidden_layers,
            max_position_embeddings=max_position_embeddings,
            rms_norm_eps=rms_norm_eps,
            rope_theta=rope_theta,
            initializer_range=initializer_range,
            # Also set by super().__init__; re-assigned here exactly as the
            # original implementation did.
            tie_word_embeddings=tie_word_embeddings,
            dropout_rate=dropout_rate,
        )
        for attr_name, attr_value in hyperparams.items():
            setattr(self, attr_name, attr_value)
|
|
|
|
| |
class FlexQwenPreTrainedModel(PreTrainedModel):
    """Shared HF base class: config wiring, prefix, and weight init."""

    config_class = FlexQwenConfig
    base_model_prefix = "model"
    _supports_cache_class = True

    def _init_weights(self, module):
        """Initialise embeddings/linears uniformly in +/- initializer_range."""
        bound = self.config.initializer_range
        if isinstance(module, (nn.Embedding, nn.Linear)):
            module.weight.data.uniform_(-bound, bound)
            # Only nn.Linear may carry a bias here; zero it when present.
            bias = getattr(module, "bias", None)
            if bias is not None:
                bias.data.zero_()
|
|
|
|
class GroupedQueryAttention(nn.Module):
    """Grouped-query attention (GQA) with rotary position embeddings.

    ``num_heads`` query heads share ``num_kv_groups`` key/value heads, so
    each KV head serves ``num_heads // num_kv_groups`` query heads. The
    replication of KV heads is delegated to PyTorch SDPA via
    ``enable_gqa=True`` (requires a recent PyTorch release).
    """

    def __init__(
        self,
        in_features: int,
        num_heads: int,
        num_kv_groups: int,
        head_dim: int | None = None,
        # Fixed annotation: qk_norm is a boolean flag (was annotated ``int``).
        qk_norm: bool = False,
        rms_norm_eps: float = 1e-6,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
        layer_idx: int = 0,
    ):
        """Build the projections for grouped-query attention.

        Args:
            in_features: Model (embedding) dimension of the input.
            num_heads: Number of query heads.
            num_kv_groups: Number of key/value heads; must divide num_heads.
            head_dim: Per-head dimension; defaults to in_features // num_heads.
            qk_norm: If True, RMS-normalize queries and keys per head.
            rms_norm_eps: Epsilon for the optional q/k RMSNorm layers.
            layer_idx: Index of this layer inside the HF ``Cache`` object.
        """
        assert num_heads % num_kv_groups == 0, (
            "num_heads must be divisible by num_kv_groups"
        )
        factory_kwargs = dict(device=device, dtype=dtype)
        super().__init__()

        self.num_heads = num_heads
        self.num_kv_groups = num_kv_groups
        self.group_size = num_heads // num_kv_groups

        if head_dim is None:
            assert in_features % num_heads == 0, (
                "input_dim must be divisible by num_heads"
            )
            head_dim = in_features // num_heads

        self.head_dim = head_dim
        self.out_features = num_heads * head_dim

        # Queries use num_heads * head_dim features; keys and values share a
        # single fused projection of 2 * num_kv_groups * head_dim features.
        self.wq = nn.Linear(
            in_features, self.out_features, bias=False, **factory_kwargs
        )
        self.wkv = nn.Linear(
            in_features, 2 * num_kv_groups * head_dim, bias=False, **factory_kwargs
        )

        self.out_proj = nn.Linear(
            self.out_features, in_features, bias=False, **factory_kwargs
        )

        self.qk_norm = qk_norm
        if self.qk_norm:
            self.q_norm = RMSNorm(head_dim, eps=rms_norm_eps, **factory_kwargs)
            self.k_norm = RMSNorm(head_dim, eps=rms_norm_eps, **factory_kwargs)

        self.layer_idx = layer_idx

    def forward(
        self,
        x: torch.FloatTensor,
        cos: torch.FloatTensor,
        sin: torch.FloatTensor,
        attention_mask: Optional[torch.BoolTensor] = None,
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> tuple[torch.FloatTensor, Optional[Cache]]:
        """Run attention over ``x`` of shape (batch, seq, in_features).

        Returns:
            The attention output (same shape as ``x``) and the possibly
            updated KV cache.
        """
        batch_size, num_tokens, _ = x.shape

        query = self.wq(x)
        key, value = self.wkv(x).chunk(2, dim=-1)

        # Reshape to (batch, heads, seq, head_dim) for SDPA.
        query = query.view(
            batch_size, num_tokens, self.num_heads, self.head_dim
        ).transpose(1, 2)

        key = key.view(
            batch_size, num_tokens, self.num_kv_groups, self.head_dim
        ).transpose(1, 2)

        value = value.view(
            batch_size, num_tokens, self.num_kv_groups, self.head_dim
        ).transpose(1, 2)

        if self.qk_norm:
            query = self.q_norm(query)
            key = self.k_norm(key)

        # RoPE offset: position of the first new token. Prefer the explicit
        # cache_position; otherwise fall back to the cached sequence length.
        if cache_position is None:
            offset = (
                past_key_value.get_seq_length(self.layer_idx)
                if past_key_value is not None
                else 0
            )
        else:
            offset = int(cache_position[0].item())

        query = apply_rope(query, cos, sin, offset=offset)
        key = apply_rope(key, cos, sin, offset=offset)

        # Append the new K/V to the cache; returns the full K/V history.
        if past_key_value is not None:
            cache_kwargs = {"cache_position": cache_position}
            key, value = past_key_value.update(key, value, self.layer_idx, cache_kwargs)

        # enable_gqa lets SDPA broadcast the num_kv_groups KV heads across
        # the num_heads query heads without materializing copies.
        attn_output = nn.functional.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=attention_mask,
            dropout_p=0.0,
            enable_gqa=True,
        )
        out = self.out_proj(
            attn_output.transpose(1, 2).reshape(
                batch_size, num_tokens, self.out_features
            )
        )
        return out, past_key_value
|
|
|
|
class Transformer(nn.Module):
    """Pre-norm transformer block: GQA attention then a (MoE) feed-forward.

    Both sub-layers use RMSNorm applied *before* the sub-layer with a
    residual connection around it.
    """

    def __init__(
        self,
        embedding_dim: int,
        hidden_dim: int,
        num_heads: int,
        head_dim: int,
        num_kv_groups: int,
        # Fixed annotation: qk_norm is a boolean flag (was annotated ``int``).
        qk_norm: bool = False,
        moe_num_experts_per_token: int = 8,
        moe_num_experts: int = 0,
        moe_hidden_dim: int = 128,
        rms_norm_eps: float = 1e-6,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
        layer_idx: int = 0,
    ):
        """Build one transformer layer.

        A MoE feed-forward is used when ``moe_num_experts > 0``; otherwise a
        dense FeedForward with ``hidden_dim`` is used.
        """
        factory_kwargs = dict(device=device, dtype=dtype)
        super().__init__()
        self.attn = GroupedQueryAttention(
            in_features=embedding_dim,
            num_heads=num_heads,
            head_dim=head_dim,
            num_kv_groups=num_kv_groups,
            qk_norm=qk_norm,
            layer_idx=layer_idx,
            **factory_kwargs,
        )

        if moe_num_experts > 0:
            self.ff: MoEFeedForward | FeedForward = MoEFeedForward(
                embedding_dim=embedding_dim,
                hidden_dim=moe_hidden_dim,
                num_experts_per_token=moe_num_experts_per_token,
                num_experts=moe_num_experts,
                device=device,
                dtype=dtype,
            )
        else:
            self.ff = FeedForward(
                embedding_dim, hidden_dim=hidden_dim, **factory_kwargs
            )
        self.norm1 = RMSNorm(embedding_dim, eps=rms_norm_eps, **factory_kwargs)
        self.norm2 = RMSNorm(embedding_dim, eps=rms_norm_eps, **factory_kwargs)

    def forward(
        self,
        x: torch.FloatTensor,
        cos: torch.FloatTensor,
        sin: torch.FloatTensor,
        attention_mask: Optional[torch.BoolTensor] = None,
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> tuple[torch.FloatTensor, Optional[Cache]]:
        """Apply attention and feed-forward sub-layers with residuals.

        Returns the transformed hidden states and the (possibly updated)
        KV cache.
        """
        # Attention sub-layer (pre-norm + residual).
        residual = x
        x = self.norm1(x)
        x, past_key_value = self.attn(
            x,
            cos,
            sin,
            attention_mask=attention_mask,
            past_key_value=past_key_value,
            cache_position=cache_position,
        )
        x += residual

        # Feed-forward sub-layer (pre-norm + residual).
        residual = x
        x = self.norm2(x)
        x = self.ff(x)
        x += residual

        return x, past_key_value
|
|
|
|
@dataclass
class FlexQwenOutputWithPast(ModelOutput):
    """Output of the FlexQwen base model: hidden states plus optional cache."""

    # Tuple of final (post-norm) hidden states; heads read the last element.
    last_hidden_states: tuple[torch.FloatTensor]
    # Never populated by FlexQwen.forward in this file; kept for interface
    # compatibility with downstream heads.
    attentions: Optional[tuple[torch.FloatTensor]] = None
    # KV cache, present only when use_cache was requested.
    past_key_values: Optional[Cache] = None
|
|
|
|
class FlexQwen(FlexQwenPreTrainedModel):
    """FlexQwen base model: token embedding, transformer stack, final norm.

    Precomputes RoPE cos/sin tables up to ``max_position_embeddings`` and
    registers them as (persistent) buffers.
    """

    config_class = FlexQwenConfig

    def __init__(
        self,
        config: FlexQwenConfig,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        super().__init__(config)

        self.embed = nn.Embedding(
            config.vocab_size,
            config.embedding_dim,
            padding_idx=config.pad_token_id,
            device=device,
            dtype=dtype,
        )

        self.transformer_blocks = nn.ModuleList(
            [
                Transformer(
                    embedding_dim=config.embedding_dim,
                    hidden_dim=config.hidden_dim,
                    num_heads=config.num_attention_heads,
                    head_dim=config.head_dim,
                    num_kv_groups=config.num_kv_groups,
                    qk_norm=config.qk_norm,
                    moe_num_experts_per_token=config.moe_num_experts_per_token,
                    moe_num_experts=config.moe_num_experts,
                    moe_hidden_dim=config.moe_hidden_dim,
                    rms_norm_eps=config.rms_norm_eps,
                    device=device,
                    dtype=dtype,
                    layer_idx=i,
                )
                for i in range(config.num_hidden_layers)
            ]
        )

        self.final_norm = RMSNorm(
            config.embedding_dim, eps=config.rms_norm_eps, device=device, dtype=dtype
        )

        cos, sin = compute_rope_params(
            head_dim=config.head_dim,
            theta_base=config.rope_theta,
            max_position_embeddings=config.max_position_embeddings,
            dtype=dtype,
            device=device,
        )

        # Persistent so the tables travel with state_dict / device moves.
        self.register_buffer("cos", cos, persistent=True)
        self.register_buffer("sin", sin, persistent=True)
        self.config = config

        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        # Fixed annotation: use_cache is a boolean flag (was Optional[int]).
        use_cache: Optional[bool] = None,
        is_causal: bool = True,
        return_dict: bool = True,
        **kwargs,
    ) -> FlexQwenOutputWithPast | tuple:
        """Encode tokens (or embeddings) through the transformer stack.

        Args:
            input_ids: Token ids of shape (batch, seq) or (seq,); mutually
                exclusive with ``inputs_embeds``.
            inputs_embeds: Precomputed embeddings of shape (batch, seq, dim).
            attention_mask: 1/0 padding mask over the key positions.
            past_key_values: Optional HF cache of previous K/V states.
            cache_position: Absolute positions of the new tokens.
            use_cache: If truthy, create/extend a DynamicCache.
            is_causal: Apply a causal (lower-triangular) mask when decoding
                more than one token.
            return_dict: Return a FlexQwenOutputWithPast instead of a tuple.

        Raises:
            ValueError: If both or neither of input_ids / inputs_embeds are
                provided.
        """
        # Fixed error messages: they previously referenced the non-existent
        # parameter names "input_embeds" / "input_embds".
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("Received both input_ids and inputs_embeds. Pass only one.")
        if input_ids is None and inputs_embeds is None:
            raise ValueError("Exactly one of input_ids, inputs_embeds is required.")

        if input_ids is not None:
            if input_ids.dim() == 1:
                input_ids = input_ids.unsqueeze(0)
            x = self.embed(input_ids)
        else:
            x = inputs_embeds

        assert x is not None

        q_len = x.shape[1]
        kv_len = q_len

        # With a cache, keys span the cached prefix plus the new tokens.
        if past_key_values is not None:
            kv_len += past_key_values.get_seq_length()

        base_mask = torch.ones((q_len, kv_len), dtype=torch.bool, device=x.device)

        if is_causal and q_len > 1:
            # Shift the triangle right by the cached-prefix length so every
            # query may attend to all cached positions plus itself.
            base_mask = torch.tril(base_mask, diagonal=kv_len - q_len)

        if attention_mask is not None:
            # Broadcast the (batch, kv_len) padding mask against the
            # (1, 1, q_len, kv_len) causal mask.
            padding_mask = (attention_mask == 1).unsqueeze(1).unsqueeze(2)
            attention_mask = base_mask.unsqueeze(0).unsqueeze(1) & padding_mask
        else:
            attention_mask = base_mask.unsqueeze(0).unsqueeze(1)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        for block in self.transformer_blocks:
            x, past_key_values = block(
                x,
                self.cos,
                self.sin,
                attention_mask=attention_mask,
                past_key_value=past_key_values,
                cache_position=cache_position,
            )

        x = self.final_norm(x)

        output = FlexQwenOutputWithPast(
            last_hidden_states=(x,),
            past_key_values=past_key_values if use_cache else None,
        )

        if not return_dict:
            return output.to_tuple()

        return output
|
|
|
|
class FlexQwenForCausalLM(FlexQwenPreTrainedModel, GenerationMixin):
    """FlexQwen with a language-modeling head, usable with HF generation."""

    config_class = FlexQwenConfig
    _tied_weights_keys = {"lm_head.weight": "model.embed.weight"}

    def __init__(
        self,
        config: FlexQwenConfig,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        **kwargs,
    ):
        super().__init__(config)
        self.model = FlexQwen(config, device=device, dtype=dtype)
        self.lm_head = nn.Linear(
            config.embedding_dim,
            config.vocab_size,
            bias=False,
            device=device,
            dtype=dtype,
        )

        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed

    def set_input_embeddings(self, value):
        self.model.embed = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def tie_weights(
        self, missing_keys: set[str] | None = None, recompute_mapping: bool = True
    ) -> None:
        """Tie lm_head to the input embedding when the config requests it."""
        super().tie_weights(
            missing_keys=missing_keys, recompute_mapping=recompute_mapping
        )

        # Re-tie explicitly in case the base implementation did not cover
        # this model's custom attribute layout. (Removed a stray debug
        # print that polluted stdout on every call.)
        if getattr(self.config, "tie_word_embeddings", False):
            self.lm_head.weight = self.model.embed.weight

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.BoolTensor] = None,
        labels: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
        use_cache: Optional[bool] = None,
        is_causal: bool = True,
        **kwargs,
    ) -> CausalLMOutputWithPast | tuple:
        """Compute LM logits and, when ``labels`` is given, a CE loss.

        NOTE(review): labels are used as-is — no internal one-position
        shift is applied. Callers are expected to pass already-aligned
        labels; confirm against the training loop.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs: FlexQwenOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            use_cache=use_cache,
            return_dict=True,
            is_causal=is_causal,
            **kwargs,
        )

        logits = self.lm_head(outputs.last_hidden_states[-1])
        loss = None
        if labels is not None:
            if labels.dim() == 1:
                labels = labels.unsqueeze(0)
            loss = nn.functional.cross_entropy(
                logits.view(-1, logits.size(-1)),
                labels.view(-1),
                ignore_index=-100,
                reduction="mean",
            )

        output = CausalLMOutputWithPast(
            logits=logits,
            loss=loss,
            past_key_values=outputs.past_key_values if use_cache else None,
        )

        if not return_dict:
            return output.to_tuple()

        return output

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        next_sequence_length: Optional[int] = None,
        past_key_values: Optional[Cache] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        is_first_iteration: Optional[bool] = False,
        **kwargs,
    ) -> dict:
        """Assemble model kwargs for one HF generation step.

        After the first iteration, only the newest token is fed because the
        earlier positions are already represented in the KV cache.
        """
        if past_key_values is not None:
            if not is_first_iteration:
                input_ids = input_ids[:, -1:]

        # inputs_embeds can only be used for the prompt (first) pass.
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache", True),
                "attention_mask": attention_mask,
                "cache_position": cache_position,
                "is_causal": True,
            }
        )
        return model_inputs
|
|
|
|
class FlexQwenForSequenceClassification(FlexQwenPreTrainedModel):
    """FlexQwen with a pooled classification/regression head.

    Uses last-token pooling when running causally and mean pooling
    otherwise; CrossEntropyLoss for num_labels > 1, MSELoss for regression.
    """

    config_class = FlexQwenConfig

    def __init__(
        self,
        config: FlexQwenConfig,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = FlexQwen(config, device=device, dtype=dtype)
        self.dropout = nn.Dropout(p=config.dropout_rate)
        self.score = nn.Linear(
            config.embedding_dim,
            self.num_labels,
            bias=True,
            device=device,
            dtype=dtype,
        )
        self.loss_fct = nn.CrossEntropyLoss() if config.num_labels > 1 else nn.MSELoss()

        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.BoolTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        # Fixed annotation: return_dict is a boolean flag (was Optional[int]).
        return_dict: Optional[bool] = None,
        is_causal: bool = True,
        **kwargs,
    ) -> SequenceClassifierOutput | tuple:
        """Classify (or regress on) a pooled sequence representation.

        NOTE(review): the last-token pooling path indexes position
        ``attention_mask.sum(dim=1) - 1``, which assumes right-padding —
        confirm the tokenizer's padding side.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs: FlexQwenOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
            is_causal=is_causal,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_states[-1]

        if is_causal:
            # Causal: only the last real token has seen the whole sequence.
            if attention_mask is None:
                pooled_states = hidden_states[:, -1]
            else:
                sequence_lengths = attention_mask.sum(dim=1) - 1
                pooled_states = hidden_states[
                    torch.arange(hidden_states.shape[0], device=hidden_states.device),
                    sequence_lengths,
                ]
        else:
            # Bidirectional: mean-pool over the non-padded positions.
            if attention_mask is None:
                pooled_states = hidden_states.mean(dim=1)
            else:
                mask = attention_mask.unsqueeze(-1).expand(hidden_states.size())
                masked_hidden_states = torch.where(mask.bool(), hidden_states, 0.0)
                # Token counts are integers; clamp to 1 (not a float epsilon)
                # to guard against fully-masked rows. Equivalent results:
                # such rows have an all-zero sum and still pool to 0.
                num_valid_tokens = (
                    attention_mask.sum(dim=1).unsqueeze(-1).clamp(min=1)
                )
                pooled_states = masked_hidden_states.sum(dim=1) / num_valid_tokens

        logits = self.score(self.dropout(pooled_states))

        loss = None
        if labels is not None:
            if self.num_labels == 1:
                # Regression: MSE on the squeezed scalar predictions.
                loss = self.loss_fct(logits.squeeze(), labels.squeeze())
            else:
                loss = self.loss_fct(
                    logits.view(-1, self.num_labels),
                    labels.view(-1),
                )

        if not return_dict:
            output = (logits,) + (outputs.last_hidden_states, outputs.attentions)
            return (loss,) + output if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.last_hidden_states,
            attentions=outputs.attentions,
        )
|
|
|
|
def load_model(
    checkpoint_dir: str | Path, device: str | torch.device = "cpu"
) -> FlexQwenForCausalLM:
    """Load a FlexQwenForCausalLM from a local safetensors checkpoint.

    Args:
        checkpoint_dir: Directory containing ``config.json`` and
            ``model.safetensors``.
        device: Device to move the loaded model to.

    Returns:
        The model with weights loaded and embeddings tied.

    Raises:
        FileNotFoundError: If ``model.safetensors`` is missing.
    """
    import warnings

    from transformers import AutoConfig
    from safetensors.torch import load_file

    checkpoint_dir = Path(checkpoint_dir)

    # exist_ok prevents a ValueError when load_model is called more than
    # once in the same process (register raises on duplicate model_type).
    AutoConfig.register("flexqwen", FlexQwenConfig, exist_ok=True)

    config = AutoConfig.from_pretrained(checkpoint_dir)
    model = FlexQwenForCausalLM(config)

    safetensors_path = checkpoint_dir / "model.safetensors"
    if not safetensors_path.exists():
        raise FileNotFoundError(f"Could not find {safetensors_path}.")

    disk_dict = load_file(safetensors_path)

    # strict=False tolerates the tied lm_head weight being absent on disk,
    # but any other mismatch is surfaced instead of silently ignored.
    load_result = model.load_state_dict(disk_dict, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        warnings.warn(
            f"load_model: missing keys {load_result.missing_keys}, "
            f"unexpected keys {load_result.unexpected_keys}"
        )

    # Re-tie embeddings; tied tensors are stored only once in safetensors.
    model.tie_weights()

    return model.to(device)
|
|