mqyqlx commited on
Commit ·
83ee88b
1
Parent(s): d8e070e
match config class
Browse files- generation_demo.py +1 -1
- modeling_dcformer.py +1 -0
generation_demo.py
CHANGED
|
@@ -7,7 +7,7 @@ os.environ['TOKENIZERS_PARALLELISM'] = 'false'
|
|
| 7 |
tokenizer = AutoTokenizer.from_pretrained("Caiyun-AI/DCFormer-2.8B")
|
| 8 |
model = AutoModelForCausalLM.from_pretrained("Caiyun-AI/DCFormer-2.8B", trust_remote_code=True)
|
| 9 |
|
| 10 |
-
device = torch.device('cuda')
|
| 11 |
MAX_BATCH_SIZE = 1
|
| 12 |
MAX_SEQ_LENGTH = 2048
|
| 13 |
NUM_TOKENS_TO_GENERATE = 100
|
|
|
|
| 7 |
tokenizer = AutoTokenizer.from_pretrained("Caiyun-AI/DCFormer-2.8B")
|
| 8 |
model = AutoModelForCausalLM.from_pretrained("Caiyun-AI/DCFormer-2.8B", trust_remote_code=True)
|
| 9 |
|
| 10 |
+
device = torch.device('cuda:1')
|
| 11 |
MAX_BATCH_SIZE = 1
|
| 12 |
MAX_SEQ_LENGTH = 2048
|
| 13 |
NUM_TOKENS_TO_GENERATE = 100
|
modeling_dcformer.py
CHANGED
|
@@ -70,6 +70,7 @@ class KVKWCache(nn.Module):
|
|
| 70 |
return k_out, v_out, kw_out
|
| 71 |
|
| 72 |
class DCFormer(PreTrainedModel):
|
|
|
|
| 73 |
'''
|
| 74 |
DCFormer's implementation is adapted from https://github.com/pytorch-labs/gpt-fast/blob/main/model.py#L89
|
| 75 |
'''
|
|
|
|
| 70 |
return k_out, v_out, kw_out
|
| 71 |
|
| 72 |
class DCFormer(PreTrainedModel):
|
| 73 |
+
config_class=DCFormerConfig
|
| 74 |
'''
|
| 75 |
DCFormer's implementation is adapted from https://github.com/pytorch-labs/gpt-fast/blob/main/model.py#L89
|
| 76 |
'''
|