Update modeling_super_linear.py
Browse files- modeling_super_linear.py +7 -2
modeling_super_linear.py
CHANGED
|
@@ -548,8 +548,13 @@ class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
|
|
| 548 |
|
| 549 |
# backbone expects (B, C, L)
|
| 550 |
x_enc = inputs_embeds
|
| 551 |
-
|
| 552 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 553 |
# backbone returns (B, pred_len, C)
|
| 554 |
preds = self.backbone(x_enc)
|
| 555 |
return CausalLMOutputWithCrossAttentions(loss=None,logits=preds,past_key_values=None,hidden_states=None,attentions=None,)
|
|
|
|
| 548 |
|
| 549 |
# backbone expects (B, C, L)
|
| 550 |
x_enc = inputs_embeds
|
| 551 |
+
if x_enc.shape[1] < 512:
|
| 552 |
+
if len(x_enc) == 2:
|
| 553 |
+
x_enc = F.interpolate(x_enc.unsqueeze(0),size = 512,mode="linear", align_corners=False)
|
| 554 |
+
else:
|
| 555 |
+
x_enc = F.interpolate(x_enc,size = 512,mode="linear", align_corners=False)
|
| 556 |
+
|
| 557 |
+
|
| 558 |
# backbone returns (B, pred_len, C)
|
| 559 |
preds = self.backbone(x_enc)
|
| 560 |
return CausalLMOutputWithCrossAttentions(loss=None,logits=preds,past_key_values=None,hidden_states=None,attentions=None,)
|