Spaces: Running on Zero
Commit: Disable autocast for MLPEmbedder and Modulation to fix CUBLAS errors
Browse files — src/flux/modules/layers.py (+15 −7)

src/flux/modules/layers.py (CHANGED)
@@ -59,13 +59,14 @@ class MLPEmbedder(nn.Module):
  59         self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
  60
  61         def forward(self, x: Tensor) -> Tensor:
  62  -          #  (removed comment line — content not captured by this extraction)
  63            orig_dtype = x.dtype
  64  -          (removed line — content not captured by this extraction)
  65  -          (removed line — content not captured by this extraction)
  66  -          (removed line — content not captured by this extraction)
  67  -          (removed line — content not captured by this extraction)
  68  -          (removed line — content not captured by this extraction)
  69            return x.to(orig_dtype)
  70
  71
@@ -176,7 +177,14 @@ class Modulation(nn.Module):
 176         self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
 177
 178         def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
 179  -          (removed line — content not captured by this extraction)
 180
 181            return (
 182                ModulationOut(*out[:3]),
  59         self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
  60
  61         def forward(self, x: Tensor) -> Tensor:
  62  +          # Disable autocast and use fp32 for computation to avoid CUBLAS errors
  63            orig_dtype = x.dtype
  64  +          with torch.autocast(device_type='cuda', enabled=False):
  65  +              x = x.float()
  66  +              # Compute with fp32 weights
  67  +              x = F.linear(x, self.in_layer.weight.float(), self.in_layer.bias.float() if self.in_layer.bias is not None else None)
  68  +              x = self.silu(x)
  69  +              x = F.linear(x, self.out_layer.weight.float(), self.out_layer.bias.float() if self.out_layer.bias is not None else None)
  70            return x.to(orig_dtype)
  71
  72

(NOTE: indentation of the `+` lines under the `with` block is reconstructed from their logical nesting; the extraction flattened the original indentation — verify against the repository.)
 177         self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
 178
 179         def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
 180  +          # Disable autocast and use fp32 for computation to avoid CUBLAS errors
 181  +          orig_dtype = vec.dtype
 182  +          with torch.autocast(device_type='cuda', enabled=False):
 183  +              vec = vec.float()
 184  +              out = F.linear(F.silu(vec), self.lin.weight.float(), self.lin.bias.float() if self.lin.bias is not None else None)
 185  +              out = out[:, None, :].chunk(self.multiplier, dim=-1)
 186  +          # Convert back to original dtype
 187  +          out = tuple(o.to(orig_dtype) for o in out)
 188
 189            return (
 190                ModulationOut(*out[:3]),

(NOTE: indentation of the `+` lines under the `with` block is reconstructed from their logical nesting; the extraction flattened the original indentation — verify against the repository.)