Use Longformer
Browse files — modeling_cocom.py (+2 −2)
modeling_cocom.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, PreTrainedModel, PretrainedConfig, AutoModel
|
| 2 |
import torch
|
| 3 |
import math
|
| 4 |
from peft import get_peft_model, LoraConfig, TaskType
|
|
@@ -263,7 +263,7 @@ class COCOM(PreTrainedModel):
|
|
| 263 |
attention_mask=dec_attention_mask.to(device),
|
| 264 |
do_sample=False,
|
| 265 |
top_p=None,
|
| 266 |
-
max_new_tokens=max_new_tokens
|
| 267 |
)
|
| 268 |
decoded = self.decoder_tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
| 269 |
return decoded
|
|
|
|
| 1 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, PreTrainedModel, PretrainedConfig, AutoModel, LongformerForCausalLM, LongformerTokenizer
|
| 2 |
import torch
|
| 3 |
import math
|
| 4 |
from peft import get_peft_model, LoraConfig, TaskType
|
|
|
|
| 263 |
attention_mask=dec_attention_mask.to(device),
|
| 264 |
do_sample=False,
|
| 265 |
top_p=None,
|
| 266 |
+
max_new_tokens=min(max_new_tokens, 4096)
|
| 267 |
)
|
| 268 |
decoded = self.decoder_tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
| 269 |
return decoded
|