Florian valade commited on
Commit
687049b
·
1 Parent(s): 33efa44

Fix transformers compatibility: pin versions and rename past_key_value to past_key_values

Browse files
Files changed (3) hide show
  1. requirements.txt +4 -4
  2. src/inference.py +1 -1
  3. src/model_adapters.py +3 -3
requirements.txt CHANGED
@@ -1,6 +1,6 @@
1
- torch>=2.0.0
2
- transformers>=4.37.0
3
- gradio>=4.0.0
4
  bitsandbytes>=0.41.0
5
- accelerate>=0.25.0
6
  huggingface_hub>=0.19.0
 
1
+ torch>=2.0.0,<3.0.0
2
+ transformers>=4.51.0,<4.55.0
3
+ gradio>=4.0.0,<5.0.0
4
  bitsandbytes>=0.41.0
5
+ accelerate>=0.25.0,<1.0.0
6
  huggingface_hub>=0.19.0
src/inference.py CHANGED
@@ -717,7 +717,7 @@ class DSSDecoder:
717
  hidden_states=hidden_states,
718
  position_ids=position_ids,
719
  attention_mask=None,
720
- past_key_value=None,
721
  position_embeddings=position_embeddings,
722
  use_cache=False,
723
  cache_position=cache_position,
 
717
  hidden_states=hidden_states,
718
  position_ids=position_ids,
719
  attention_mask=None,
720
+ past_key_values=None,
721
  position_embeddings=position_embeddings,
722
  use_cache=False,
723
  cache_position=cache_position,
src/model_adapters.py CHANGED
@@ -33,7 +33,7 @@ class ModelAdapter(ABC):
33
  hidden_states: Tensor,
34
  position_ids: Tensor,
35
  attention_mask: Optional[Tensor],
36
- past_key_value: Optional[Tuple],
37
  position_embeddings: Optional[Tuple],
38
  use_cache: bool = True,
39
  cache_position: Optional[Tensor] = None,
@@ -97,7 +97,7 @@ class LlamaStyleAdapter(ModelAdapter):
97
  hidden_states: Tensor,
98
  position_ids: Tensor,
99
  attention_mask: Optional[Tensor],
100
- past_key_value: Optional[Tuple],
101
  position_embeddings: Optional[Tuple],
102
  use_cache: bool = True,
103
  cache_position: Optional[Tensor] = None,
@@ -107,7 +107,7 @@ class LlamaStyleAdapter(ModelAdapter):
107
  hidden_states,
108
  attention_mask=attention_mask,
109
  position_ids=position_ids,
110
- past_key_value=past_key_value,
111
  use_cache=use_cache,
112
  position_embeddings=position_embeddings,
113
  cache_position=cache_position,
 
33
  hidden_states: Tensor,
34
  position_ids: Tensor,
35
  attention_mask: Optional[Tensor],
36
+ past_key_values: Optional[Tuple],
37
  position_embeddings: Optional[Tuple],
38
  use_cache: bool = True,
39
  cache_position: Optional[Tensor] = None,
 
97
  hidden_states: Tensor,
98
  position_ids: Tensor,
99
  attention_mask: Optional[Tensor],
100
+ past_key_values: Optional[Tuple],
101
  position_embeddings: Optional[Tuple],
102
  use_cache: bool = True,
103
  cache_position: Optional[Tensor] = None,
 
107
  hidden_states,
108
  attention_mask=attention_mask,
109
  position_ids=position_ids,
110
+ past_key_values=past_key_values,
111
  use_cache=use_cache,
112
  position_embeddings=position_embeddings,
113
  cache_position=cache_position,