temp
Browse files- modeling_cocom.py +9 -7
modeling_cocom.py
CHANGED
|
@@ -72,6 +72,7 @@ class COCOMConfig(PretrainedConfig):
|
|
| 72 |
training_form="both",
|
| 73 |
lora_r=16,
|
| 74 |
attn_implementation="eager",
|
|
|
|
| 75 |
**kwargs):
|
| 76 |
super().__init__(**kwargs)
|
| 77 |
|
|
@@ -86,6 +87,7 @@ class COCOMConfig(PretrainedConfig):
|
|
| 86 |
self.training_form = training_form # training form, could be compressor: training only compressor; both:
|
| 87 |
self.lora_r = lora_r # lora_r for lora training, we use 16 throughout the experiment.
|
| 88 |
self.attn_implementation = attn_implementation
|
|
|
|
| 89 |
|
| 90 |
class COCOM(PreTrainedModel):
|
| 91 |
config_class = COCOMConfig
|
|
@@ -100,7 +102,7 @@ class COCOM(PreTrainedModel):
|
|
| 100 |
torch_dtype=torch.float16,
|
| 101 |
attn_implementation=attn_impl,
|
| 102 |
low_cpu_mem_usage = True,
|
| 103 |
-
device_map=
|
| 104 |
)
|
| 105 |
elif cfg.quantization == "int4":
|
| 106 |
quant_config = BitsAndBytesConfig(
|
|
@@ -117,7 +119,7 @@ class COCOM(PreTrainedModel):
|
|
| 117 |
resume_download=True,
|
| 118 |
low_cpu_mem_usage = True,
|
| 119 |
trust_remote_code=True,
|
| 120 |
-
device_map=
|
| 121 |
)
|
| 122 |
elif cfg.quantization == "int8":
|
| 123 |
quant_config = BitsAndBytesConfig(
|
|
@@ -134,7 +136,7 @@ class COCOM(PreTrainedModel):
|
|
| 134 |
resume_download=True,
|
| 135 |
low_cpu_mem_usage = True,
|
| 136 |
trust_remote_code=True,
|
| 137 |
-
device_map=
|
| 138 |
)
|
| 139 |
else:
|
| 140 |
raise NotImplementedError()
|
|
@@ -300,10 +302,10 @@ class COCOM(PreTrainedModel):
|
|
| 300 |
|
| 301 |
# generate
|
| 302 |
model_input = {
|
| 303 |
-
'enc_input_ids': enc_input['input_ids'],
|
| 304 |
-
'enc_attention_mask': enc_input['attention_mask'],
|
| 305 |
-
'dec_input_ids': inp_dec['input_ids'],
|
| 306 |
-
'dec_attention_mask': inp_dec['attention_mask']
|
| 307 |
}
|
| 308 |
|
| 309 |
return self.generate(model_input, max_new_tokens)
|
|
|
|
| 72 |
training_form="both",
|
| 73 |
lora_r=16,
|
| 74 |
attn_implementation="eager",
|
| 75 |
+
device_map = "cuda",
|
| 76 |
**kwargs):
|
| 77 |
super().__init__(**kwargs)
|
| 78 |
|
|
|
|
| 87 |
self.training_form = training_form # training form, could be compressor: training only compressor; both:
|
| 88 |
self.lora_r = lora_r # lora_r for lora training, we use 16 throughout the experiment.
|
| 89 |
self.attn_implementation = attn_implementation
|
| 90 |
+
self.device_map = device_map
|
| 91 |
|
| 92 |
class COCOM(PreTrainedModel):
|
| 93 |
config_class = COCOMConfig
|
|
|
|
| 102 |
torch_dtype=torch.float16,
|
| 103 |
attn_implementation=attn_impl,
|
| 104 |
low_cpu_mem_usage = True,
|
| 105 |
+
device_map =cfg.device_map
|
| 106 |
)
|
| 107 |
elif cfg.quantization == "int4":
|
| 108 |
quant_config = BitsAndBytesConfig(
|
|
|
|
| 119 |
resume_download=True,
|
| 120 |
low_cpu_mem_usage = True,
|
| 121 |
trust_remote_code=True,
|
| 122 |
+
device_map =cfg.device_map
|
| 123 |
)
|
| 124 |
elif cfg.quantization == "int8":
|
| 125 |
quant_config = BitsAndBytesConfig(
|
|
|
|
| 136 |
resume_download=True,
|
| 137 |
low_cpu_mem_usage = True,
|
| 138 |
trust_remote_code=True,
|
| 139 |
+
device_map =cfg.device_map
|
| 140 |
)
|
| 141 |
else:
|
| 142 |
raise NotImplementedError()
|
|
|
|
| 302 |
|
| 303 |
# generate
|
| 304 |
model_input = {
|
| 305 |
+
'enc_input_ids': enc_input['input_ids'].to(self.decoder.device),
|
| 306 |
+
'enc_attention_mask': enc_input['attention_mask'].to(self.decoder.device),
|
| 307 |
+
'dec_input_ids': inp_dec['input_ids'].to(self.decoder.device),
|
| 308 |
+
'dec_attention_mask': inp_dec['attention_mask'].to(self.decoder.device)
|
| 309 |
}
|
| 310 |
|
| 311 |
return self.generate(model_input, max_new_tokens)
|