| from typing import Callable |
| import torch |
| from transformers import Qwen3Model |
| from transformers.cache_utils import Cache |
| from transformers.masking_utils import create_causal_mask |
| from transformers.modeling_outputs import BaseModelOutputWithPooling |
| from transformers.processing_utils import Unpack |
| from transformers.utils import TransformersKwargs |
| from .configuration import PPLXQwen3Config |
|
|
|
|
| |
| def bidirectional_mask_function(attention_mask: torch.Tensor | None) -> Callable: |
| """ |
| This creates bidirectional attention mask. |
| """ |
|
|
| def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool: |
| if attention_mask is None: |
| return torch.ones((), dtype=torch.bool) |
| return attention_mask[batch_idx, kv_idx].to(torch.bool) |
|
|
| return inner_mask |
|
|
|
|
class PPLXQwen3Model(Qwen3Model):
    """Qwen3 backbone with bidirectional (non-causal) self-attention.

    Each decoder layer's attention is switched to non-causal in
    ``post_init``, and ``forward`` rebuilds the attention mask so that
    every token may attend to every non-padding token in the sequence.
    """

    _supports_flash_attn = True
    _supports_sdpa = True

    config_class = PPLXQwen3Config

    def __init__(self, config):
        super().__init__(config)
        self.post_init()

    def post_init(self):
        """Run the standard post-init, then disable causality in every layer."""
        super().post_init()
        for decoder_layer in self.layers:
            decoder_layer.self_attn.is_causal = False

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        cache_position: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPooling:
        # NOTE(review): annotated as BaseModelOutputWithPooling, but the
        # value is whatever Qwen3Model.forward returns — confirm intended.

        # Embed eagerly so the mask below can be shaped from inputs_embeds.
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
            input_ids = None

        seq_len = inputs_embeds.shape[1]
        # Positions 0..seq_len-1 for mask construction. NOTE(review): the
        # caller-supplied cache_position is deliberately not used here —
        # the mask always covers the full uncached sequence.
        mask_positions = torch.arange(
            seq_len, device=inputs_embeds.device, dtype=torch.long
        )

        # OR-ing the bidirectional function into the causal mask lets every
        # query see every valid (non-padding) key: full bidirectional
        # attention that still respects padding.
        full_mask = create_causal_mask(
            config=self.config,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=mask_positions,
            past_key_values=None,
            position_ids=position_ids,
            or_mask_function=bidirectional_mask_function(attention_mask),
        )
        attention_mask = {"full_attention": full_mask}

        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
|
|