TSXu commited on
Commit
8af673c
·
1 Parent(s): 9d88d74

Use torch.autocast for automatic mixed precision inference

Browse files
Files changed (1) hide show
  1. src/flux/xflux_pipeline.py +40 -38
src/flux/xflux_pipeline.py CHANGED
@@ -365,44 +365,46 @@ class XFluxPipeline:
365
  if neg_image_proj is not None:
366
  neg_image_proj = neg_image_proj.to(inference_dtype)
367
 
368
- if self.controlnet_loaded:
369
- x = denoise_controlnet(
370
- self.model,
371
- **inp_cond,
372
- controlnet=self.controlnet,
373
- timesteps=timesteps,
374
- guidance=guidance,
375
- controlnet_cond=controlnet_image,
376
- timestep_to_start_cfg=timestep_to_start_cfg,
377
- neg_txt=neg_inp_cond['txt'],
378
- neg_txt_ids=neg_inp_cond['txt_ids'],
379
- neg_vec=neg_inp_cond['vec'],
380
- true_gs=true_gs,
381
- controlnet_gs=control_weight,
382
- image_proj=image_proj,
383
- neg_image_proj=neg_image_proj,
384
- ip_scale=ip_scale,
385
- neg_ip_scale=neg_ip_scale,
386
- )
387
- else:
388
- x = denoise(
389
- self.model,
390
- **inp_cond,
391
- timesteps=timesteps,
392
- guidance=guidance,
393
- cond_latent=cond_latent,
394
- cond_txt_latent=cond_txt_latent,
395
- timestep_to_start_cfg=timestep_to_start_cfg,
396
- neg_txt=neg_inp_cond['txt'],
397
- neg_txt_ids=neg_inp_cond['txt_ids'],
398
- neg_vec=neg_inp_cond['vec'],
399
- true_gs=true_gs,
400
- image_proj=image_proj,
401
- neg_image_proj=neg_image_proj,
402
- ip_scale=ip_scale,
403
- neg_ip_scale=neg_ip_scale,
404
- is_generation=is_generation,
405
- )
 
 
406
 
407
  if self.offload:
408
  self.offload_model_to_cpu(self.model)
 
365
  if neg_image_proj is not None:
366
  neg_image_proj = neg_image_proj.to(inference_dtype)
367
 
368
+ # Use autocast for automatic mixed precision - handles fp16/fp32 fallback
369
+ with torch.autocast(device_type='cuda', dtype=torch.float16):
370
+ if self.controlnet_loaded:
371
+ x = denoise_controlnet(
372
+ self.model,
373
+ **inp_cond,
374
+ controlnet=self.controlnet,
375
+ timesteps=timesteps,
376
+ guidance=guidance,
377
+ controlnet_cond=controlnet_image,
378
+ timestep_to_start_cfg=timestep_to_start_cfg,
379
+ neg_txt=neg_inp_cond['txt'],
380
+ neg_txt_ids=neg_inp_cond['txt_ids'],
381
+ neg_vec=neg_inp_cond['vec'],
382
+ true_gs=true_gs,
383
+ controlnet_gs=control_weight,
384
+ image_proj=image_proj,
385
+ neg_image_proj=neg_image_proj,
386
+ ip_scale=ip_scale,
387
+ neg_ip_scale=neg_ip_scale,
388
+ )
389
+ else:
390
+ x = denoise(
391
+ self.model,
392
+ **inp_cond,
393
+ timesteps=timesteps,
394
+ guidance=guidance,
395
+ cond_latent=cond_latent,
396
+ cond_txt_latent=cond_txt_latent,
397
+ timestep_to_start_cfg=timestep_to_start_cfg,
398
+ neg_txt=neg_inp_cond['txt'],
399
+ neg_txt_ids=neg_inp_cond['txt_ids'],
400
+ neg_vec=neg_inp_cond['vec'],
401
+ true_gs=true_gs,
402
+ image_proj=image_proj,
403
+ neg_image_proj=neg_image_proj,
404
+ ip_scale=ip_scale,
405
+ neg_ip_scale=neg_ip_scale,
406
+ is_generation=is_generation,
407
+ )
408
 
409
  if self.offload:
410
  self.offload_model_to_cpu(self.model)