Florian Valade committed · Commit 7848d77 · 1 Parent(s): 45e00e6

fix cache position for newer version of transformers

Files changed:
- src/inference.py +4 -0
- src/model_adapters.py +3 -0
src/inference.py CHANGED

@@ -522,6 +522,9 @@ class DSSDecoder:
             0
         )
 
+        # Cache position (required by newer transformers for Qwen3)
+        cache_position = torch.arange(seq_len, dtype=torch.long, device=device)
+
         # Get embeddings
         hidden_states = self.adapter.get_embed_tokens(input_ids)
 
@@ -544,6 +547,7 @@ class DSSDecoder:
                 past_key_value=None,
                 position_embeddings=position_embeddings,
                 use_cache=False,
+                cache_position=cache_position,
             )
 
         # Check if this is a head checkpoint
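For context, a minimal sketch of what the added plumbing does. Newer transformers releases no longer derive token positions inside the decoder layer, so the caller builds an explicit cache_position tensor. The names seq_len and device follow the diff; the concrete values and the decode-time variant below are assumptions for illustration:

    # Minimal sketch, not the repository's code: newer transformers releases
    # expect the caller to supply cache_position explicitly.
    import torch

    seq_len, device = 8, "cpu"  # assumed values for illustration

    # Full (non-cached) forward pass: positions are simply 0..seq_len-1,
    # exactly as constructed in the diff above.
    cache_position = torch.arange(seq_len, dtype=torch.long, device=device)

    # If a KV cache of length past_len were in play (not the case here,
    # since the diff passes past_key_value=None), the indices would
    # instead start after the cached tokens:
    past_len = 4  # assumed
    decode_position = torch.arange(past_len, past_len + seq_len,
                                   dtype=torch.long, device=device)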
src/model_adapters.py CHANGED

@@ -36,6 +36,7 @@ class ModelAdapter(ABC):
         past_key_value: Optional[Tuple],
         position_embeddings: Optional[Tuple],
         use_cache: bool = True,
+        cache_position: Optional[Tensor] = None,
     ) -> Tuple[Tensor, Optional[Tuple]]:
         """Forward through a single layer, returning hidden states and optional KV cache."""
         ...
@@ -99,6 +100,7 @@ class LlamaStyleAdapter(ModelAdapter):
         past_key_value: Optional[Tuple],
         position_embeddings: Optional[Tuple],
         use_cache: bool = True,
+        cache_position: Optional[Tensor] = None,
     ) -> Tuple[Tensor, Optional[Tuple]]:
         """Forward through a decoder layer."""
         layer_outputs = layer(
@@ -108,6 +110,7 @@ class LlamaStyleAdapter(ModelAdapter):
             past_key_value=past_key_value,
             use_cache=use_cache,
             position_embeddings=position_embeddings,
+            cache_position=cache_position,
         )
         hidden_states = layer_outputs[0]
         new_kv = layer_outputs[1] if len(layer_outputs) > 1 else None
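Taken together, the adapter's layer-forward signature after this commit looks roughly like the sketch below. Only the parts visible in the diff are confirmed (past_key_value, position_embeddings, use_cache, the new cache_position, the return annotation, the docstrings, and the layer(...) call); the method name, the leading parameters, and the rest of the class bodies are assumptions:

    # Hedged reconstruction of the adapter interface; method name and
    # leading parameters are assumed, not shown in the diff.
    from abc import ABC, abstractmethod
    from typing import Optional, Tuple

    from torch import Tensor


    class ModelAdapter(ABC):
        @abstractmethod
        def forward_layer(  # method name assumed
            self,
            layer,
            hidden_states: Tensor,            # leading parameters assumed
            position_ids: Optional[Tensor],
            past_key_value: Optional[Tuple],
            position_embeddings: Optional[Tuple],
            use_cache: bool = True,
            cache_position: Optional[Tensor] = None,  # new in this commit
        ) -> Tuple[Tensor, Optional[Tuple]]:
            """Forward through a single layer, returning hidden states and optional KV cache."""
            ...


    class LlamaStyleAdapter(ModelAdapter):
        def forward_layer(
            self,
            layer,
            hidden_states: Tensor,
            position_ids: Optional[Tensor],
            past_key_value: Optional[Tuple],
            position_embeddings: Optional[Tuple],
            use_cache: bool = True,
            cache_position: Optional[Tensor] = None,
        ) -> Tuple[Tensor, Optional[Tuple]]:
            """Forward through a decoder layer."""
            layer_outputs = layer(
                hidden_states,                  # kwargs before the diff window assumed
                position_ids=position_ids,
                past_key_value=past_key_value,
                use_cache=use_cache,
                position_embeddings=position_embeddings,
                cache_position=cache_position,  # forwarded to the HF decoder layer
            )
            hidden_states = layer_outputs[0]
            new_kv = layer_outputs[1] if len(layer_outputs) > 1 else None
            return hidden_states, new_kv

Because cache_position defaults to None, existing call sites of the adapter API stay source-compatible; DSSDecoder in src/inference.py is the one caller updated here to pass a real tensor.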