from typing import Callable

import torch
from transformers import Qwen3Model
from transformers.cache_utils import Cache
from transformers.masking_utils import create_causal_mask
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs

from .configuration import PPLXQwen3Config


# From modeling_t5gemma.py
def bidirectional_mask_function(attention_mask: torch.Tensor | None) -> Callable:
    """
    Creates a bidirectional attention mask: every query position may attend to
    every key position that is not masked out (e.g. padding) by `attention_mask`.
    """

    def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
        if attention_mask is None:
            return torch.ones((), dtype=torch.bool)
        return attention_mask[batch_idx, kv_idx].to(torch.bool)

    return inner_mask


class PPLXQwen3Model(Qwen3Model):
    _supports_flash_attn = True
    _supports_sdpa = True
    config_class = PPLXQwen3Config

    def __init__(self, config):
        super().__init__(config)
        self.post_init()

    def post_init(self):
        super().post_init()
        # Override to set all layers to non-causal attention. This works with
        # attn_implementation="flash_attention_2" or "sdpa".
        for layer in self.layers:
            layer.self_attn.is_causal = False

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        cache_position: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
            input_ids = None

        # Construct a dummy tensor imitating the initial cache positions so the
        # mask is built over the full input sequence.
        dummy_cache_position = torch.arange(
            inputs_embeds.shape[1], device=inputs_embeds.device, dtype=torch.long
        )

        # Pre-build the mask mapping: OR-ing the causal mask with the
        # bidirectional mask yields full (non-causal) attention over all
        # non-padding positions.
        attention_mask = {
            "full_attention": create_causal_mask(
                config=self.config,
                input_embeds=inputs_embeds,
                attention_mask=attention_mask,
                cache_position=dummy_cache_position,
                past_key_values=None,
                position_ids=position_ids,
                or_mask_function=bidirectional_mask_function(attention_mask),
            )
        }

        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        return outputs
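
# Minimal usage sketch (not part of the module's API): instantiates the model
# from a tiny, illustrative config and runs a padded batch through it. This
# assumes PPLXQwen3Config accepts standard Qwen3Config constructor kwargs
# (hidden_size, num_hidden_layers, ...); the values below are hypothetical and
# only show the call shape. Real checkpoints would instead be loaded with
# PPLXQwen3Model.from_pretrained(...).
if __name__ == "__main__":
    config = PPLXQwen3Config(
        vocab_size=1024,
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=2,
        head_dim=16,
    )
    model = PPLXQwen3Model(config).eval()

    input_ids = torch.randint(0, config.vocab_size, (2, 8))
    attention_mask = torch.ones(2, 8, dtype=torch.long)
    attention_mask[1, 6:] = 0  # second sequence is padded after 6 tokens

    with torch.no_grad():
        out = model(input_ids=input_ids, attention_mask=attention_mask)

    # With bidirectional attention, each hidden state can depend on later
    # non-padding tokens as well as earlier ones.
    print(out.last_hidden_state.shape)  # torch.Size([2, 8, 64])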