Compatibility with v5

#4
by RaushanTurganbay HF Staff - opened
Files changed (1) hide show
  1. modular_isaac.py +9 -155
modular_isaac.py CHANGED
@@ -102,8 +102,8 @@ from transformers import (
102
  Qwen3ForCausalLM,
103
  Qwen3PreTrainedModel,
104
  )
105
- from transformers.cache_utils import SlidingWindowCache, StaticCache
106
  from transformers.generation.utils import GenerationMixin
 
107
  from transformers.image_processing_utils_fast import (
108
  BaseImageProcessorFast,
109
  DefaultFastImageProcessorKwargs,
@@ -1897,10 +1897,14 @@ class IsaacModel(Qwen3PreTrainedModel):
1897
  sin = sin.to(inputs_embeds.dtype)
1898
 
1899
  # Prepare attention mask
1900
- if attention_mask is not None:
1901
- attention_mask = self._update_causal_mask(
1902
- attention_mask, inputs_embeds, cache_position, past_key_values, False
1903
- )
 
 
 
 
1904
 
1905
  # Initialize hidden states
1906
  hidden_states = inputs_embeds
@@ -1927,156 +1931,6 @@ class IsaacModel(Qwen3PreTrainedModel):
1927
  past_key_values=past_key_values,
1928
  )
1929
 
1930
- def _update_causal_mask(
1931
- self,
1932
- attention_mask: torch.Tensor,
1933
- input_tensor: torch.Tensor,
1934
- cache_position: torch.Tensor,
1935
- past_key_values: Cache,
1936
- output_attentions: bool = False,
1937
- ):
1938
- if self.config._attn_implementation == "flash_attention_2":
1939
- if attention_mask is not None and past_key_values is not None:
1940
- is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
1941
- if is_padding_right:
1942
- raise ValueError(
1943
- "You are attempting to perform batched generation with padding_side='right'"
1944
- " this may lead to unexpected behaviour for Flash Attention version of Qwen3. Make sure to "
1945
- " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
1946
- )
1947
- if attention_mask is not None and 0.0 in attention_mask:
1948
- return attention_mask
1949
- return None
1950
-
1951
- # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
1952
- # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
1953
- # to infer the attention mask.
1954
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
1955
- using_static_cache = isinstance(past_key_values, StaticCache)
1956
- using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
1957
-
1958
- # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
1959
- if (
1960
- self.config._attn_implementation == "sdpa"
1961
- and not (using_static_cache or using_sliding_window_cache)
1962
- and not output_attentions
1963
- ):
1964
- if AttentionMaskConverter._ignore_causal_mask_sdpa(
1965
- attention_mask,
1966
- inputs_embeds=input_tensor,
1967
- past_key_values_length=past_seen_tokens,
1968
- sliding_window=self.config.sliding_window,
1969
- is_training=self.training,
1970
- ):
1971
- return None
1972
-
1973
- dtype, device = input_tensor.dtype, input_tensor.device
1974
- min_dtype = torch.finfo(dtype).min
1975
- sequence_length = input_tensor.shape[1]
1976
- # SlidingWindowCache or StaticCache
1977
- if using_sliding_window_cache or using_static_cache:
1978
- target_length = past_key_values.get_max_cache_shape()
1979
- # DynamicCache or no cache
1980
- else:
1981
- target_length = (
1982
- attention_mask.shape[-1]
1983
- if isinstance(attention_mask, torch.Tensor)
1984
- else past_seen_tokens + sequence_length + 1
1985
- )
1986
-
1987
- # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
1988
- causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
1989
- attention_mask,
1990
- sequence_length=sequence_length,
1991
- target_length=target_length,
1992
- dtype=dtype,
1993
- device=device,
1994
- cache_position=cache_position,
1995
- batch_size=input_tensor.shape[0],
1996
- config=self.config,
1997
- past_key_values=past_key_values,
1998
- )
1999
-
2000
- if (
2001
- self.config._attn_implementation == "sdpa"
2002
- and attention_mask is not None
2003
- and attention_mask.device.type in ["cuda", "xpu", "npu"]
2004
- and not output_attentions
2005
- ):
2006
- # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
2007
- # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
2008
- # Details: https://github.com/pytorch/pytorch/issues/110213
2009
- causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
2010
-
2011
- return causal_mask
2012
-
2013
- @staticmethod
2014
- def _prepare_4d_causal_attention_mask_with_cache_position(
2015
- attention_mask: torch.Tensor,
2016
- sequence_length: int,
2017
- target_length: int,
2018
- dtype: torch.dtype,
2019
- device: torch.device,
2020
- cache_position: torch.Tensor,
2021
- batch_size: int,
2022
- config: Qwen3Config,
2023
- past_key_values: Cache,
2024
- ):
2025
- """
2026
- Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
2027
- `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
2028
-
2029
- Args:
2030
- attention_mask (`torch.Tensor`):
2031
- A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
2032
- sequence_length (`int`):
2033
- The sequence length being processed.
2034
- target_length (`int`):
2035
- The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
2036
- dtype (`torch.dtype`):
2037
- The dtype to use for the 4D attention mask.
2038
- device (`torch.device`):
2039
- The device to place the 4D attention mask on.
2040
- cache_position (`torch.Tensor`):
2041
- Indices depicting the position of the input sequence tokens in the sequence.
2042
- batch_size (`torch.Tensor`):
2043
- Batch size.
2044
- config (`Qwen3Config`):
2045
- The model's configuration class
2046
- past_key_values (`Cache`):
2047
- The cache class that is being used currently to generate
2048
- """
2049
- if attention_mask is not None and attention_mask.dim() == 4:
2050
- # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
2051
- causal_mask = attention_mask
2052
- else:
2053
- min_dtype = torch.finfo(dtype).min
2054
- causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
2055
- diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
2056
- if config.sliding_window is not None:
2057
- # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
2058
- # the check is needed to verify is current checkpoint was trained with sliding window or not
2059
- if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
2060
- sliding_attend_mask = torch.arange(target_length, device=device) <= (
2061
- cache_position.reshape(-1, 1) - config.sliding_window
2062
- )
2063
- diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
2064
- causal_mask *= diagonal_attend_mask
2065
- causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
2066
- if attention_mask is not None:
2067
- causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
2068
- if attention_mask.shape[-1] > target_length:
2069
- attention_mask = attention_mask[:, :target_length]
2070
- mask_length = attention_mask.shape[-1]
2071
- padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
2072
- causal_mask.device
2073
- )
2074
- padding_mask = padding_mask == 0
2075
- causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
2076
- padding_mask, min_dtype
2077
- )
2078
- return causal_mask
2079
-
2080
 
2081
  class IsaacForConditionalGeneration(Qwen3ForCausalLM, GenerationMixin):
2082
  """Isaac multimodal model for conditional generation."""
 
102
  Qwen3ForCausalLM,
103
  Qwen3PreTrainedModel,
104
  )
 
105
  from transformers.generation.utils import GenerationMixin
106
+ from transformers.masking_utils import create_causal_mask
107
  from transformers.image_processing_utils_fast import (
108
  BaseImageProcessorFast,
109
  DefaultFastImageProcessorKwargs,
 
1897
  sin = sin.to(inputs_embeds.dtype)
1898
 
1899
  # Prepare attention mask
1900
+ attention_mask = create_causal_mask(
1901
+ config=self.config,
1902
+ input_embeds=inputs_embeds,
1903
+ attention_mask=attention_mask,
1904
+ past_key_values=past_key_values,
1905
+ position_ids=position_ids,
1906
+ cache_position=cache_position,
1907
+ )
1908
 
1909
  # Initialize hidden states
1910
  hidden_states = inputs_embeds
 
1931
  past_key_values=past_key_values,
1932
  )
1933
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1934
 
1935
  class IsaacForConditionalGeneration(Qwen3ForCausalLM, GenerationMixin):
1936
  """Isaac multimodal model for conditional generation."""