Update modeling_super_linear.py
Browse files — modeling_super_linear.py (+6 −4)
modeling_super_linear.py
CHANGED
|
@@ -590,10 +590,12 @@ class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
         # backbone expects (B, C, L)
         x_enc = inputs_embeds
 
-
-        x_enc = self.fourier_interp_dim1(x_enc)
-
-
+        if x_enc.shape[1] < 512:
+            x_enc = self.fourier_interp_dim1(x_enc)
+        mean = x_enc.mean()
+        std = x_enc.std().clamp_min(1e-6)
+        x_enc = (x_enc - mean) / std
+
         # backbone returns (B, pred_len, C)
         preds = self.backbone(x_enc)
         return CausalLMOutputWithCrossAttentions(loss=None,logits=preds,past_key_values=None,hidden_states=None,attentions=None,)

NOTE(review): the extraction lost indentation inside the added hunk — it is not recoverable from this page whether the `mean`/`std` normalization lines sit inside the `if x_enc.shape[1] < 512:` branch or apply unconditionally after it; confirm against the actual commit.