michaelbzhu commited on
Commit
eaf9a45
·
verified ·
1 Parent(s): 5d119f2

Update modeling.py

Browse files
Files changed (1) hide show
  1. modeling.py +4 -1
modeling.py CHANGED
@@ -151,4 +151,7 @@ class CustomModel(PreTrainedModel):
151
  def forward(self, tensor):
152
  with torch.autocast('cuda', dtype=torch.bfloat16):
153
  logits = self.model(tensor)
154
- return CausalLMOutput(logits=logits)
 
 
 
 
151
  def forward(self, tensor):
152
  with torch.autocast('cuda', dtype=torch.bfloat16):
153
  logits = self.model(tensor)
154
+ return CausalLMOutput(logits=logits)
155
+
156
+ def get_input_embeddings(self):
157
+ return self.embed