mqyqlx committed on
Commit
83ee88b
·
1 Parent(s): d8e070e

match config class

Browse files
Files changed (2) hide show
  1. generation_demo.py +1 -1
  2. modeling_dcformer.py +1 -0
generation_demo.py CHANGED
@@ -7,7 +7,7 @@ os.environ['TOKENIZERS_PARALLELISM'] = 'false'
7
  tokenizer = AutoTokenizer.from_pretrained("Caiyun-AI/DCFormer-2.8B")
8
  model = AutoModelForCausalLM.from_pretrained("Caiyun-AI/DCFormer-2.8B", trust_remote_code=True)
9
 
10
- device = torch.device('cuda')
11
  MAX_BATCH_SIZE = 1
12
  MAX_SEQ_LENGTH = 2048
13
  NUM_TOKENS_TO_GENERATE = 100
 
7
  tokenizer = AutoTokenizer.from_pretrained("Caiyun-AI/DCFormer-2.8B")
8
  model = AutoModelForCausalLM.from_pretrained("Caiyun-AI/DCFormer-2.8B", trust_remote_code=True)
9
 
10
+ device = torch.device('cuda:1')
11
  MAX_BATCH_SIZE = 1
12
  MAX_SEQ_LENGTH = 2048
13
  NUM_TOKENS_TO_GENERATE = 100
modeling_dcformer.py CHANGED
@@ -70,6 +70,7 @@ class KVKWCache(nn.Module):
70
  return k_out, v_out, kw_out
71
 
72
  class DCFormer(PreTrainedModel):
 
73
  '''
74
  DCFormer's implementation is adapted from https://github.com/pytorch-labs/gpt-fast/blob/main/model.py#L89
75
  '''
 
70
  return k_out, v_out, kw_out
71
 
72
  class DCFormer(PreTrainedModel):
73
+ config_class=DCFormerConfig
74
  '''
75
  DCFormer's implementation is adapted from https://github.com/pytorch-labs/gpt-fast/blob/main/model.py#L89
76
  '''