Update modeling_super_linear.py
Browse files- modeling_super_linear.py +2 -2
modeling_super_linear.py
CHANGED
|
@@ -548,9 +548,9 @@ class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
|
|
| 548 |
|
| 549 |
# backbone expects (B, C, L)
|
| 550 |
x_enc = inputs_embeds
|
| 551 |
-
print(x_enc.shape)
|
| 552 |
if x_enc.shape[1] < 512:
|
| 553 |
-
if len(x_enc) == 2:
|
| 554 |
x_enc = F.interpolate(x_enc.unsqueeze(0),size = 512,mode="linear", align_corners=False)
|
| 555 |
else:
|
| 556 |
x_enc = F.interpolate(x_enc,size = 512,mode="linear", align_corners=False)
|
|
|
|
| 548 |
|
| 549 |
# backbone expects (B, C, L)
|
| 550 |
x_enc = inputs_embeds
|
| 551 |
+
#print(x_enc.shape)
|
| 552 |
if x_enc.shape[1] < 512:
|
| 553 |
+
if len(x_enc.shape) == 2:
|
| 554 |
x_enc = F.interpolate(x_enc.unsqueeze(0),size = 512,mode="linear", align_corners=False)
|
| 555 |
else:
|
| 556 |
x_enc = F.interpolate(x_enc,size = 512,mode="linear", align_corners=False)
|