razmars commited on
Commit
c6c399a
·
verified ·
1 Parent(s): 6b54cc5

Update modeling_super_linear.py

Browse files
Files changed (1) hide show
  1. modeling_super_linear.py +2 -0
modeling_super_linear.py CHANGED
@@ -606,6 +606,8 @@ class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
606
  # If input was 2D, remove the channel dimension
607
  if x.shape[1] == 1:
608
  upsample = upsample.squeeze(-1)
 
 
609
 
610
  #print(f"Upsampled shape: {upsample.shape}")
611
  return upsample
 
606
  # If input was 2D, remove the channel dimension
607
  if x.shape[1] == 1:
608
  upsample = upsample.squeeze(-1)
609
+
610
+ upsample = upsample.float()
611
 
612
  #print(f"Upsampled shape: {upsample.shape}")
613
  return upsample