razmars commited on
Commit
b610119
·
verified ·
1 Parent(s): eeb454f

Update modeling_super_linear.py

Browse files
Files changed (1) hide show
  1. modeling_super_linear.py +7 -2
modeling_super_linear.py CHANGED
@@ -548,8 +548,13 @@ class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
548
 
549
  # backbone expects (B, C, L)
550
  x_enc = inputs_embeds
551
-
552
-
 
 
 
 
 
553
  # backbone returns (B, pred_len, C)
554
  preds = self.backbone(x_enc)
555
  return CausalLMOutputWithCrossAttentions(loss=None,logits=preds,past_key_values=None,hidden_states=None,attentions=None,)
 
548
 
549
  # backbone expects (B, C, L)
550
  x_enc = inputs_embeds
551
+ if x_enc.shape[1] < 512:
552
+ if len(x_enc) == 2:
553
+ x_enc = F.interpolate(x_enc.unsqueeze(0),size = 512,mode="linear", align_corners=False)
554
+ else:
555
+ x_enc = F.interpolate(x_enc,size = 512,mode="linear", align_corners=False)
556
+
557
+
558
  # backbone returns (B, pred_len, C)
559
  preds = self.backbone(x_enc)
560
  return CausalLMOutputWithCrossAttentions(loss=None,logits=preds,past_key_values=None,hidden_states=None,attentions=None,)