| | from typing import Callable |
| | import torch |
| | from transformers import Qwen3Model |
| | from transformers.cache_utils import Cache |
| | from transformers.masking_utils import create_causal_mask |
| | from transformers.modeling_outputs import BaseModelOutputWithPooling |
| | from transformers.processing_utils import Unpack |
| | from transformers.utils import TransformersKwargs |
| | from .configuration import PPLXQwen3Config |
| |
|
| |
|
| | |
| | def bidirectional_mask_function(attention_mask: torch.Tensor | None) -> Callable: |
| | """ |
| | This creates bidirectional attention mask. |
| | """ |
| |
|
| | def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool: |
| | if attention_mask is None: |
| | return torch.ones((), dtype=torch.bool) |
| | return attention_mask[batch_idx, kv_idx].to(torch.bool) |
| |
|
| | return inner_mask |
| |
|
| |
|
class PPLXQwen3Model(Qwen3Model):
    """Qwen3 backbone patched for bidirectional (non-causal) attention.

    Every attention layer is switched to non-causal in ``post_init``, and
    ``forward`` replaces the default causal mask with one that ORs in
    ``bidirectional_mask_function`` so each token can attend to all
    non-padded positions. Typical use is embedding-style encoding rather
    than autoregressive generation.
    """

    # Advertised backend support, consumed by the transformers loading
    # machinery when selecting an attention implementation.
    _supports_flash_attn = True
    _supports_sdpa = True

    config_class = PPLXQwen3Config

    def __init__(self, config):
        super().__init__(config)
        # Explicitly re-run post_init so the is_causal patch below is applied
        # even though the transformers base __init__ already calls it.
        self.post_init()

    def post_init(self):
        """Run the standard transformers post-init, then disable causality.

        Each layer's attention module is flagged ``is_causal = False`` so the
        attention backends (SDPA / flash-attention) do not apply their own
        built-in causal masking on top of the explicit mask built in
        ``forward``.
        """
        super().post_init()

        for layer in self.layers:
            layer.self_attn.is_causal = False

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        cache_position: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPooling:
        # NOTE(review): the annotation says BaseModelOutputWithPooling, but
        # this method returns whatever Qwen3Model.forward returns, unmodified
        # — presumably BaseModelOutputWithPast; confirm and align.
        if inputs_embeds is None:
            # Embed here so the mask can be built from the embedded sequence
            # length; clear input_ids to avoid re-embedding in super().
            inputs_embeds = self.embed_tokens(input_ids)
            input_ids = None

        # Positions 0..seq_len-1 for mask construction only.
        # NOTE(review): this deliberately ignores the `cache_position`
        # argument, so the mask is always built for a from-scratch sequence —
        # confirm this is intended when past_key_values / cache_position are
        # supplied by a caller.
        dummy_cache_position = torch.arange(
            inputs_embeds.shape[1], device=inputs_embeds.device, dtype=torch.long
        )
        # Replace the incoming 2-D padding mask with the dict-of-masks form
        # consumed by the base model. The causal mask is OR-ed with the
        # bidirectional function, which effectively lifts the causal
        # constraint wherever the padding mask allows attention.
        # NOTE(review): only the "full_attention" key is populated; if the
        # config enables sliding-window layers they would look up a
        # "sliding_attention" entry — verify against the config used here.
        attention_mask = {
            "full_attention": create_causal_mask(
                config=self.config,
                input_embeds=inputs_embeds,
                attention_mask=attention_mask,
                cache_position=dummy_cache_position,
                past_key_values=None,
                position_ids=position_ids,
                or_mask_function=bidirectional_mask_function(attention_mask),
            )
        }

        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        return outputs