add sink cache options
Browse files- README.md +5 -0
- custom_generate/generate.py +2 -2
- custom_generate/requirements.txt +1 -0
README.md
CHANGED
|
@@ -17,12 +17,17 @@ This implementation should match the `SinkCache` class present in `transformers<
|
|
| 17 |
|
| 18 |
|
| 19 |
## Model compatibility
|
|
|
|
| 20 |
|
| 21 |
|
| 22 |
## Additional Arguments
|
|
|
|
|
|
|
| 23 |
|
| 24 |
|
| 25 |
## Output Type changes
|
|
|
|
|
|
|
| 26 |
|
| 27 |
|
| 28 |
## Example usage
|
|
|
|
| 17 |
|
| 18 |
|
| 19 |
## Model compatibility
|
| 20 |
+
- Decoder-only models
|
| 21 |
|
| 22 |
|
| 23 |
## Additional Arguments
|
| 24 |
+
- `window_length` (`int`, defaults to `256`): The length of the context window.
|
| 25 |
+
- `num_sink_tokens` (`int`, defaults to `4`): The number of sink tokens. See the original paper for more information.
|
| 26 |
|
| 27 |
|
| 28 |
## Output Type changes
|
| 29 |
+
- When `return_dict_in_generate=True`, `output.past_key_values` will be a `SinkCache` instance. `SinkCache` is defined
|
| 30 |
+
in `generate.py`, in this repository.
|
| 31 |
|
| 32 |
|
| 33 |
## Example usage
|
custom_generate/generate.py
CHANGED
|
@@ -193,7 +193,7 @@ class SinkCache(Cache):
|
|
| 193 |
return self.key_cache[layer_idx], self.value_cache[layer_idx]
|
| 194 |
|
| 195 |
|
| 196 |
-
def generate(model, **kwargs):
|
| 197 |
-
past_key_values = SinkCache(window_length=
|
| 198 |
generation_outputs = model.generate(**kwargs, past_key_values=past_key_values, use_cache=True)
|
| 199 |
return generation_outputs
|
|
|
|
| 193 |
return self.key_cache[layer_idx], self.value_cache[layer_idx]
|
| 194 |
|
| 195 |
|
| 196 |
+
def generate(model, window_length=256, num_sink_tokens=4, **kwargs):
|
| 197 |
+
past_key_values = SinkCache(window_length=window_length, num_sink_tokens=num_sink_tokens)
|
| 198 |
generation_outputs = model.generate(**kwargs, past_key_values=past_key_values, use_cache=True)
|
| 199 |
return generation_outputs
|
custom_generate/requirements.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
transformers>=4.53.0 # 4.52 results in an infinite loop
|