TSXu committed on
Commit
aecc9f1
·
1 Parent(s): 8af673c

Disable autocast for MLPEmbedder and Modulation to fix CUBLAS errors

Browse files
Files changed (1) hide show
  1. src/flux/modules/layers.py +15 -7
src/flux/modules/layers.py CHANGED
@@ -59,13 +59,14 @@ class MLPEmbedder(nn.Module):
59
  self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
60
 
61
  def forward(self, x: Tensor) -> Tensor:
62
- # Use fp32 for computation to avoid CUBLAS errors, then convert back
63
  orig_dtype = x.dtype
64
- x = x.float()
65
- # Compute with fp32 weights
66
- x = F.linear(x, self.in_layer.weight.float(), self.in_layer.bias.float() if self.in_layer.bias is not None else None)
67
- x = self.silu(x)
68
- x = F.linear(x, self.out_layer.weight.float(), self.out_layer.bias.float() if self.out_layer.bias is not None else None)
 
69
  return x.to(orig_dtype)
70
 
71
 
@@ -176,7 +177,14 @@ class Modulation(nn.Module):
176
  self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
177
 
178
  def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
179
- out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
 
 
 
 
 
 
 
180
 
181
  return (
182
  ModulationOut(*out[:3]),
 
59
  self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
60
 
61
  def forward(self, x: Tensor) -> Tensor:
62
+ # Disable autocast and use fp32 for computation to avoid CUBLAS errors
63
  orig_dtype = x.dtype
64
+ with torch.autocast(device_type='cuda', enabled=False):
65
+ x = x.float()
66
+ # Compute with fp32 weights
67
+ x = F.linear(x, self.in_layer.weight.float(), self.in_layer.bias.float() if self.in_layer.bias is not None else None)
68
+ x = self.silu(x)
69
+ x = F.linear(x, self.out_layer.weight.float(), self.out_layer.bias.float() if self.out_layer.bias is not None else None)
70
  return x.to(orig_dtype)
71
 
72
 
 
177
  self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
178
 
179
  def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
180
+ # Disable autocast and use fp32 for computation to avoid CUBLAS errors
181
+ orig_dtype = vec.dtype
182
+ with torch.autocast(device_type='cuda', enabled=False):
183
+ vec = vec.float()
184
+ out = F.linear(F.silu(vec), self.lin.weight.float(), self.lin.bias.float() if self.lin.bias is not None else None)
185
+ out = out[:, None, :].chunk(self.multiplier, dim=-1)
186
+ # Convert back to original dtype
187
+ out = tuple(o.to(orig_dtype) for o in out)
188
 
189
  return (
190
  ModulationOut(*out[:3]),