Update modeling.py
Browse files- modeling.py +4 -1
modeling.py
CHANGED
|
@@ -151,4 +151,7 @@ class CustomModel(PreTrainedModel):
|
|
| 151 |
def forward(self, tensor):
|
| 152 |
with torch.autocast('cuda', dtype=torch.bfloat16):
|
| 153 |
logits = self.model(tensor)
|
| 154 |
-
return CausalLMOutput(logits=logits)
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
def forward(self, tensor):
|
| 152 |
with torch.autocast('cuda', dtype=torch.bfloat16):
|
| 153 |
logits = self.model(tensor)
|
| 154 |
+
return CausalLMOutput(logits=logits)
|
| 155 |
+
|
| 156 |
+
def get_input_embeddings(self):
|
| 157 |
+
return self.embed
|