Update modeling_super_linear.py
Browse files- modeling_super_linear.py +2 -0
modeling_super_linear.py
CHANGED
|
@@ -606,6 +606,8 @@ class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
|
|
| 606 |
# If input was 2D, remove the channel dimension
|
| 607 |
if x.shape[1] == 1:
|
| 608 |
upsample = upsample.squeeze(-1)
|
|
|
|
|
|
|
| 609 |
|
| 610 |
#print(f"Upsampled shape: {upsample.shape}")
|
| 611 |
return upsample
|
|
|
|
| 606 |
# If input was 2D, remove the channel dimension
|
| 607 |
if x.shape[1] == 1:
|
| 608 |
upsample = upsample.squeeze(-1)
|
| 609 |
+
|
| 610 |
+
upsample = upsample.float()
|
| 611 |
|
| 612 |
#print(f"Upsampled shape: {upsample.shape}")
|
| 613 |
return upsample
|