yagizdevre committed on
Commit
992669d
·
1 Parent(s): 1177ff0
__pycache__/convolve.cpython-312.pyc CHANGED
Binary files a/__pycache__/convolve.cpython-312.pyc and b/__pycache__/convolve.cpython-312.pyc differ
 
__pycache__/stu.cpython-312.pyc CHANGED
Binary files a/__pycache__/stu.cpython-312.pyc and b/__pycache__/stu.cpython-312.pyc differ
 
layers.py CHANGED
@@ -57,7 +57,6 @@ class STULayer(nn.Module):
57
 
58
  # Normalize and apply STU
59
  x_normed = self.stu_norm(x).to(dtype=self.stu.M_inputs.dtype) # Match dtype for STU
60
- print(f"x dtype: {x.dtype}")
61
  x_stu = self.stu(x_normed).to(dtype=x.dtype) # Ensure output matches `x`'s dtype
62
  x = x + x_stu
63
 
 
57
 
58
  # Normalize and apply STU
59
  x_normed = self.stu_norm(x).to(dtype=self.stu.M_inputs.dtype) # Match dtype for STU
 
60
  x_stu = self.stu(x_normed).to(dtype=x.dtype) # Ensure output matches `x`'s dtype
61
  x = x + x_stu
62
 
stu.py CHANGED
@@ -29,11 +29,12 @@ class STU(nn.Module):
29
  self.d_out = config.n_embd
30
  self.use_hankel_L = config.use_hankel_L
31
  self.use_approx = config.use_approx
32
- self.flash_fft = (
33
- FlashFFTConv(self.n, dtype=torch_dtype)
34
- if config.use_flash_fft and flash_fft_available
35
- else None
36
- )
 
37
  if self.use_approx:
38
  self.M_inputs = nn.Parameter(
39
  torch.empty(self.d_in, self.d_out, dtype=torch_dtype)
 
29
  self.d_out = config.n_embd
30
  self.use_hankel_L = config.use_hankel_L
31
  self.use_approx = config.use_approx
32
+ self.flash_fft = None
33
+ if config.use_flash_fft and flash_fft_available:
34
+ if torch_dtype == torch.float16: # Only enable for float16
35
+ self.flash_fft = FlashFFTConv(self.n, dtype=torch.float16)
36
+ else:
37
+ print(f"Disabling FlashFFTConv for unsupported dtype: {torch_dtype}")
38
  if self.use_approx:
39
  self.M_inputs = nn.Parameter(
40
  torch.empty(self.d_in, self.d_out, dtype=torch_dtype)