allow passing a sink cache back in
Browse files
custom_generate/generate.py
CHANGED
|
@@ -194,6 +194,17 @@ class SinkCache(Cache):
|
|
| 194 |
|
| 195 |
|
| 196 |
def generate(model, window_length=256, num_sink_tokens=4, **kwargs):
|
| 197 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 198 |
generation_outputs = model.generate(**kwargs, past_key_values=past_key_values, use_cache=True)
|
| 199 |
return generation_outputs
|
|
|
|
| 194 |
|
| 195 |
|
| 196 |
def generate(model, window_length=256, num_sink_tokens=4, **kwargs):
|
| 197 |
+
# compatibility with transformers 4.52: we must pop `custom_generate` from kwargs, otherwise it will result in an
|
| 198 |
+
# infinite loop. This is solved in transformers 4.53.
|
| 199 |
+
kwargs.pop("custom_generate", None)
|
| 200 |
+
|
| 201 |
+
# prepare the cache, it is was not passed.
|
| 202 |
+
past_key_values = kwargs.pop("past_key_values", None)
|
| 203 |
+
if past_key_values is None:
|
| 204 |
+
past_key_values = SinkCache(window_length=window_length, num_sink_tokens=num_sink_tokens)
|
| 205 |
+
elif not isinstance(past_key_values, SinkCache):
|
| 206 |
+
raise ValueError(f"`past_key_values` must be a `SinkCache` instance, got a {type(past_key_values)} instance")
|
| 207 |
+
|
| 208 |
+
# generate with the cache
|
| 209 |
generation_outputs = model.generate(**kwargs, past_key_values=past_key_values, use_cache=True)
|
| 210 |
return generation_outputs
|
custom_generate/requirements.txt
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
transformers>=4.53.0 # 4.52 results in an infinite loop
|
|
|
|
|
|