|
|
from typing import Callable

import torch

from transformers import Qwen3Model
from transformers.cache_utils import Cache
from transformers.masking_utils import create_causal_mask
from transformers.modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs

from .configuration import PPLXQwen3Config
|
|
|
|
|
|
|
|
|
|
|
def bidirectional_mask_function(attention_mask: torch.Tensor | None) -> Callable: |
|
|
""" |
|
|
This creates bidirectional attention mask. |
|
|
""" |
|
|
|
|
|
def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool: |
|
|
if attention_mask is None: |
|
|
return torch.ones((), dtype=torch.bool) |
|
|
return attention_mask[batch_idx, kv_idx].to(torch.bool) |
|
|
|
|
|
return inner_mask |
|
|
|
|
|
|
|
|
class PPLXQwen3Model(Qwen3Model):
    """Qwen3 backbone with bidirectional (non-causal) self-attention.

    Intended for embedding-style use: every token may attend to every other
    non-padded token, instead of the causal masking used for generation.
    """

    _supports_flash_attn = True
    _supports_sdpa = True

    config_class = PPLXQwen3Config

    def __init__(self, config):
        super().__init__(config)
        self.post_init()

    def post_init(self):
        """Run the standard post-init, then mark every attention layer non-causal.

        SDPA / flash-attention kernels consult ``is_causal`` to decide whether
        to apply their built-in causal masking, so it must be disabled for
        bidirectional attention to take effect.
        """
        super().post_init()

        for layer in self.layers:
            layer.self_attn.is_causal = False

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        cache_position: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        """Encode the input with a bidirectional attention mask.

        Args:
            input_ids: Token ids; ignored when ``inputs_embeds`` is given.
            attention_mask: Optional 2D padding mask (batch, seq_len).
            position_ids: Optional explicit position indices.
            past_key_values: Passed through to the base model; the
                bidirectional mask itself is built over the current
                sequence only.
            inputs_embeds: Precomputed token embeddings.
            use_cache: Passed through to the base model.
            cache_position: Passed through to the base model.

        Returns:
            The base ``Qwen3Model`` output (``super().forward`` is returned
            unchanged, so this is a ``BaseModelOutputWithPast``).

        Raises:
            ValueError: If both ``input_ids`` and ``inputs_embeds`` are None.
        """
        if inputs_embeds is None:
            if input_ids is None:
                raise ValueError("You must specify either input_ids or inputs_embeds")
            inputs_embeds = self.embed_tokens(input_ids)
            input_ids = None

        # Positions 0..seq_len-1 over the current sequence; past_key_values is
        # deliberately not consulted here, since the bidirectional mask covers
        # only the tokens being encoded now.
        dummy_cache_position = torch.arange(
            inputs_embeds.shape[1], device=inputs_embeds.device, dtype=torch.long
        )
        # Start from the standard causal mask and OR in the bidirectional mask
        # so every non-padded position becomes mutually visible. Note the dict
        # literal is evaluated before the rebinding, so the original 2D
        # padding mask is what reaches create_causal_mask and the OR function.
        attention_mask = {
            "full_attention": create_causal_mask(
                config=self.config,
                input_embeds=inputs_embeds,
                attention_mask=attention_mask,
                cache_position=dummy_cache_position,
                past_key_values=None,
                position_ids=position_ids,
                or_mask_function=bidirectional_mask_function(attention_mask),
            )
        }

        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        return outputs