Florian Valade committed on
Commit
7848d77
·
1 Parent(s): 45e00e6

fix cache position for newer version of transformers

Browse files
Files changed (2) hide show
  1. src/inference.py +4 -0
  2. src/model_adapters.py +3 -0
src/inference.py CHANGED
@@ -522,6 +522,9 @@ class DSSDecoder:
522
  0
523
  )
524
 
 
 
 
525
  # Get embeddings
526
  hidden_states = self.adapter.get_embed_tokens(input_ids)
527
 
@@ -544,6 +547,7 @@ class DSSDecoder:
544
  past_key_value=None,
545
  position_embeddings=position_embeddings,
546
  use_cache=False,
 
547
  )
548
 
549
  # Check if this is a head checkpoint
 
522
  0
523
  )
524
 
525
+ # Cache position (required by newer transformers for Qwen3)
526
+ cache_position = torch.arange(seq_len, dtype=torch.long, device=device)
527
+
528
  # Get embeddings
529
  hidden_states = self.adapter.get_embed_tokens(input_ids)
530
 
 
547
  past_key_value=None,
548
  position_embeddings=position_embeddings,
549
  use_cache=False,
550
+ cache_position=cache_position,
551
  )
552
 
553
  # Check if this is a head checkpoint
src/model_adapters.py CHANGED
@@ -36,6 +36,7 @@ class ModelAdapter(ABC):
36
  past_key_value: Optional[Tuple],
37
  position_embeddings: Optional[Tuple],
38
  use_cache: bool = True,
 
39
  ) -> Tuple[Tensor, Optional[Tuple]]:
40
  """Forward through a single layer, returning hidden states and optional KV cache."""
41
  ...
@@ -99,6 +100,7 @@ class LlamaStyleAdapter(ModelAdapter):
99
  past_key_value: Optional[Tuple],
100
  position_embeddings: Optional[Tuple],
101
  use_cache: bool = True,
 
102
  ) -> Tuple[Tensor, Optional[Tuple]]:
103
  """Forward through a decoder layer."""
104
  layer_outputs = layer(
@@ -108,6 +110,7 @@ class LlamaStyleAdapter(ModelAdapter):
108
  past_key_value=past_key_value,
109
  use_cache=use_cache,
110
  position_embeddings=position_embeddings,
 
111
  )
112
  hidden_states = layer_outputs[0]
113
  new_kv = layer_outputs[1] if len(layer_outputs) > 1 else None
 
36
  past_key_value: Optional[Tuple],
37
  position_embeddings: Optional[Tuple],
38
  use_cache: bool = True,
39
+ cache_position: Optional[Tensor] = None,
40
  ) -> Tuple[Tensor, Optional[Tuple]]:
41
  """Forward through a single layer, returning hidden states and optional KV cache."""
42
  ...
 
100
  past_key_value: Optional[Tuple],
101
  position_embeddings: Optional[Tuple],
102
  use_cache: bool = True,
103
+ cache_position: Optional[Tensor] = None,
104
  ) -> Tuple[Tensor, Optional[Tuple]]:
105
  """Forward through a decoder layer."""
106
  layer_outputs = layer(
 
110
  past_key_value=past_key_value,
111
  use_cache=use_cache,
112
  position_embeddings=position_embeddings,
113
+ cache_position=cache_position,
114
  )
115
  hidden_states = layer_outputs[0]
116
  new_kv = layer_outputs[1] if len(layer_outputs) > 1 else None