oweller2 committed
Commit 8686e3f · Parent(s): cf03b9b
added

modeling_flexbert.py CHANGED (+9 -3)
@@ -1536,7 +1536,7 @@ class FlexBertForCasualLM(FlexBertPreTrainedModel):
         # Initialize weights and apply final processing
         self._init_weights(reset_params=False)
 
-    def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None):
+    def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None):
         # Handle the XOR condition
         assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified"
 
@@ -1556,7 +1556,7 @@ class FlexBertForCasualLM(FlexBertPreTrainedModel):
 
         if not self.config.tie_word_embeddings:
             init_weights(self.config, self.decoder, self.config.hidden_size, type_of_module=ModuleType.final_out)
-
+
     @classmethod
     def from_composer(
         cls,
@@ -1702,13 +1702,19 @@ class FlexBertForCasualLM(FlexBertPreTrainedModel):
             )
 
         if self.pad_logits:
+            # Reshape logits to 3D if needed
+            new_logits = self.pad_inputs(logits, indices, batch_size, seq_len)[0]
+            if len(new_logits.shape) == 2:
+                new_logits = new_logits.unsqueeze(0)
             return CausalLMOutput(
                 loss=loss,
-                logits=self.pad_inputs(logits, indices, batch_size, seq_len)[0],
+                logits=new_logits,
                 hidden_states=None,
                 attentions=None,
             )
         else:
+            if len(logits.shape) == 2:
+                logits = logits.unsqueeze(0)
             return CausalLMOutput(
                 loss=loss,
                 logits=logits,
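The substantive change is in the pad_logits branch of forward(): the padded logits are computed up front, and a batch dimension is added whenever the tensor comes back 2D, so callers always receive logits shaped [batch_size, seq_len, vocab_size]. A minimal sketch of the reshape step, using hypothetical sizes and a plain random tensor in place of real model output:

# Sketch only: hypothetical shapes, standalone tensor instead of a real
# model forward pass; mirrors the unsqueeze logic added in this commit.
import torch

seq_len, vocab_size = 4, 8  # hypothetical sizes for illustration

# With unpadded (sequence-packed) inputs, logits can come out 2D,
# i.e. [total_tokens, vocab_size] rather than [batch, seq_len, vocab_size].
logits = torch.randn(seq_len, vocab_size)

if len(logits.shape) == 2:
    logits = logits.unsqueeze(0)  # restore the batch dimension

assert logits.shape == (1, seq_len, vocab_size)

Without the unsqueeze, consumers that index the last position as logits[:, -1, :] (as Hugging Face generation utilities do) would fail on a 2D tensor.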