Commit ·
992669d
1
Parent(s): 1177ff0
ix
Browse files
- __pycache__/convolve.cpython-312.pyc +0 -0
- __pycache__/stu.cpython-312.pyc +0 -0
- layers.py +0 -1
- stu.py +6 -5
__pycache__/convolve.cpython-312.pyc
CHANGED
|
Binary files a/__pycache__/convolve.cpython-312.pyc and b/__pycache__/convolve.cpython-312.pyc differ
|
|
|
__pycache__/stu.cpython-312.pyc
CHANGED
|
Binary files a/__pycache__/stu.cpython-312.pyc and b/__pycache__/stu.cpython-312.pyc differ
|
|
|
layers.py
CHANGED
|
@@ -57,7 +57,6 @@ class STULayer(nn.Module):
|
|
| 57 |
|
| 58 |
# Normalize and apply STU
|
| 59 |
x_normed = self.stu_norm(x).to(dtype=self.stu.M_inputs.dtype) # Match dtype for STU
|
| 60 |
-
print(f"x dtype: {x.dtype}")
|
| 61 |
x_stu = self.stu(x_normed).to(dtype=x.dtype) # Ensure output matches `x`'s dtype
|
| 62 |
x = x + x_stu
|
| 63 |
|
|
|
|
| 57 |
|
| 58 |
# Normalize and apply STU
|
| 59 |
x_normed = self.stu_norm(x).to(dtype=self.stu.M_inputs.dtype) # Match dtype for STU
|
|
|
|
| 60 |
x_stu = self.stu(x_normed).to(dtype=x.dtype) # Ensure output matches `x`'s dtype
|
| 61 |
x = x + x_stu
|
| 62 |
|
stu.py
CHANGED
|
@@ -29,11 +29,12 @@ class STU(nn.Module):
|
|
| 29 |
self.d_out = config.n_embd
|
| 30 |
self.use_hankel_L = config.use_hankel_L
|
| 31 |
self.use_approx = config.use_approx
|
| 32 |
-
self.flash_fft =
|
| 33 |
-
|
| 34 |
-
if
|
| 35 |
-
|
| 36 |
-
|
|
|
|
| 37 |
if self.use_approx:
|
| 38 |
self.M_inputs = nn.Parameter(
|
| 39 |
torch.empty(self.d_in, self.d_out, dtype=torch_dtype)
|
|
|
|
| 29 |
self.d_out = config.n_embd
|
| 30 |
self.use_hankel_L = config.use_hankel_L
|
| 31 |
self.use_approx = config.use_approx
|
| 32 |
+
self.flash_fft = None
|
| 33 |
+
if config.use_flash_fft and flash_fft_available:
|
| 34 |
+
if torch_dtype == torch.float16: # Only enable for float16
|
| 35 |
+
self.flash_fft = FlashFFTConv(self.n, dtype=torch.float16)
|
| 36 |
+
else:
|
| 37 |
+
print(f"Disabling FlashFFTConv for unsupported dtype: {torch_dtype}")
|
| 38 |
if self.use_approx:
|
| 39 |
self.M_inputs = nn.Parameter(
|
| 40 |
torch.empty(self.d_in, self.d_out, dtype=torch_dtype)
|