TSXu committed on
Commit
b2bfb8e
·
1 Parent(s): aecc9f1

Use pure fp32 for ZeroGPU - disable autocast entirely

Browse files
src/flux/modules/layers.py CHANGED
@@ -59,15 +59,7 @@ class MLPEmbedder(nn.Module):
59
  self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
60
 
61
  def forward(self, x: Tensor) -> Tensor:
62
- # Disable autocast and use fp32 for computation to avoid CUBLAS errors
63
- orig_dtype = x.dtype
64
- with torch.autocast(device_type='cuda', enabled=False):
65
- x = x.float()
66
- # Compute with fp32 weights
67
- x = F.linear(x, self.in_layer.weight.float(), self.in_layer.bias.float() if self.in_layer.bias is not None else None)
68
- x = self.silu(x)
69
- x = F.linear(x, self.out_layer.weight.float(), self.out_layer.bias.float() if self.out_layer.bias is not None else None)
70
- return x.to(orig_dtype)
71
 
72
 
73
  class RMSNorm(torch.nn.Module):
@@ -177,14 +169,7 @@ class Modulation(nn.Module):
177
  self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
178
 
179
  def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
180
- # Disable autocast and use fp32 for computation to avoid CUBLAS errors
181
- orig_dtype = vec.dtype
182
- with torch.autocast(device_type='cuda', enabled=False):
183
- vec = vec.float()
184
- out = F.linear(F.silu(vec), self.lin.weight.float(), self.lin.bias.float() if self.lin.bias is not None else None)
185
- out = out[:, None, :].chunk(self.multiplier, dim=-1)
186
- # Convert back to original dtype
187
- out = tuple(o.to(orig_dtype) for o in out)
188
 
189
  return (
190
  ModulationOut(*out[:3]),
 
59
  self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
60
 
61
  def forward(self, x: Tensor) -> Tensor:
62
+ return self.out_layer(self.silu(self.in_layer(x)))
 
 
 
 
 
 
 
 
63
 
64
 
65
  class RMSNorm(torch.nn.Module):
 
169
  self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
170
 
171
  def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
172
+ out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
 
 
 
 
 
 
 
173
 
174
  return (
175
  ModulationOut(*out[:3]),
src/flux/xflux_pipeline.py CHANGED
@@ -365,46 +365,46 @@ class XFluxPipeline:
365
  if neg_image_proj is not None:
366
  neg_image_proj = neg_image_proj.to(inference_dtype)
367
 
368
- # Use autocast for automatic mixed precision - handles fp16/fp32 fallback
369
- with torch.autocast(device_type='cuda', dtype=torch.float16):
370
- if self.controlnet_loaded:
371
- x = denoise_controlnet(
372
- self.model,
373
- **inp_cond,
374
- controlnet=self.controlnet,
375
- timesteps=timesteps,
376
- guidance=guidance,
377
- controlnet_cond=controlnet_image,
378
- timestep_to_start_cfg=timestep_to_start_cfg,
379
- neg_txt=neg_inp_cond['txt'],
380
- neg_txt_ids=neg_inp_cond['txt_ids'],
381
- neg_vec=neg_inp_cond['vec'],
382
- true_gs=true_gs,
383
- controlnet_gs=control_weight,
384
- image_proj=image_proj,
385
- neg_image_proj=neg_image_proj,
386
- ip_scale=ip_scale,
387
- neg_ip_scale=neg_ip_scale,
388
- )
389
- else:
390
- x = denoise(
391
- self.model,
392
- **inp_cond,
393
- timesteps=timesteps,
394
- guidance=guidance,
395
- cond_latent=cond_latent,
396
- cond_txt_latent=cond_txt_latent,
397
- timestep_to_start_cfg=timestep_to_start_cfg,
398
- neg_txt=neg_inp_cond['txt'],
399
- neg_txt_ids=neg_inp_cond['txt_ids'],
400
- neg_vec=neg_inp_cond['vec'],
401
- true_gs=true_gs,
402
- image_proj=image_proj,
403
- neg_image_proj=neg_image_proj,
404
- ip_scale=ip_scale,
405
- neg_ip_scale=neg_ip_scale,
406
- is_generation=is_generation,
407
- )
408
 
409
  if self.offload:
410
  self.offload_model_to_cpu(self.model)
 
365
  if neg_image_proj is not None:
366
  neg_image_proj = neg_image_proj.to(inference_dtype)
367
 
368
+ # Disable autocast - ZeroGPU has CUBLAS issues with fp16
369
+ # Use pure fp32 for all operations
370
+ if self.controlnet_loaded:
371
+ x = denoise_controlnet(
372
+ self.model,
373
+ **inp_cond,
374
+ controlnet=self.controlnet,
375
+ timesteps=timesteps,
376
+ guidance=guidance,
377
+ controlnet_cond=controlnet_image,
378
+ timestep_to_start_cfg=timestep_to_start_cfg,
379
+ neg_txt=neg_inp_cond['txt'],
380
+ neg_txt_ids=neg_inp_cond['txt_ids'],
381
+ neg_vec=neg_inp_cond['vec'],
382
+ true_gs=true_gs,
383
+ controlnet_gs=control_weight,
384
+ image_proj=image_proj,
385
+ neg_image_proj=neg_image_proj,
386
+ ip_scale=ip_scale,
387
+ neg_ip_scale=neg_ip_scale,
388
+ )
389
+ else:
390
+ x = denoise(
391
+ self.model,
392
+ **inp_cond,
393
+ timesteps=timesteps,
394
+ guidance=guidance,
395
+ cond_latent=cond_latent,
396
+ cond_txt_latent=cond_txt_latent,
397
+ timestep_to_start_cfg=timestep_to_start_cfg,
398
+ neg_txt=neg_inp_cond['txt'],
399
+ neg_txt_ids=neg_inp_cond['txt_ids'],
400
+ neg_vec=neg_inp_cond['vec'],
401
+ true_gs=true_gs,
402
+ image_proj=image_proj,
403
+ neg_image_proj=neg_image_proj,
404
+ ip_scale=ip_scale,
405
+ neg_ip_scale=neg_ip_scale,
406
+ is_generation=is_generation,
407
+ )
408
 
409
  if self.offload:
410
  self.offload_model_to_cpu(self.model)