Commit ·
6eae452
1
Parent(s): 3a7a4b9
updating model class and logo
Browse files
- fixing multi-device training for model
- updating logo to sharpened version
- lola-logo.png +0 -0
- modeling_lola_gpt2.py +7 -1
lola-logo.png
CHANGED
|
|
modeling_lola_gpt2.py
CHANGED
|
@@ -204,7 +204,7 @@ class LOLAModel(GPT2PreTrainedModel):
|
|
| 204 |
if input_ids is not None and inputs_embeds is not None:
|
| 205 |
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
| 206 |
elif input_ids is not None:
|
| 207 |
-
|
| 208 |
input_shape = input_ids.size()
|
| 209 |
input_ids = input_ids.view(-1, input_shape[-1])
|
| 210 |
batch_size = input_ids.shape[0]
|
|
@@ -537,6 +537,12 @@ class LOLALMHeadModel(GPT2LMHeadModel):
|
|
| 537 |
return_dict=True, # Ensure we get a MoeModelOutputWithPast
|
| 538 |
)
|
| 539 |
hidden_states = transformer_outputs.last_hidden_state
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 540 |
lm_logits = self.lm_head(hidden_states)
|
| 541 |
|
| 542 |
aux_loss = transformer_outputs.aux_loss if hasattr(transformer_outputs, 'aux_loss') else None
|
|
|
|
| 204 |
if input_ids is not None and inputs_embeds is not None:
|
| 205 |
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
| 206 |
elif input_ids is not None:
|
| 207 |
+
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
|
| 208 |
input_shape = input_ids.size()
|
| 209 |
input_ids = input_ids.view(-1, input_shape[-1])
|
| 210 |
batch_size = input_ids.shape[0]
|
|
|
|
| 537 |
return_dict=True, # Ensure we get a MoeModelOutputWithPast
|
| 538 |
)
|
| 539 |
hidden_states = transformer_outputs.last_hidden_state
|
| 540 |
+
|
| 541 |
+
# Set device for model parallelism
|
| 542 |
+
if self.model_parallel:
|
| 543 |
+
torch.cuda.set_device(self.transformer.first_device)
|
| 544 |
+
hidden_states = hidden_states.to(self.lm_head.weight.device)
|
| 545 |
+
|
| 546 |
lm_logits = self.lm_head(hidden_states)
|
| 547 |
|
| 548 |
aux_loss = transformer_outputs.aux_loss if hasattr(transformer_outputs, 'aux_loss') else None
|