Txu647 committed on
Commit
e7cbbce
·
1 Parent(s): 0634e0c

fix: use float32 instead of bfloat16 for compatibility

Browse files
inference.py CHANGED
@@ -316,10 +316,15 @@ class CalligraphyGenerator:
316
  print(f"Loading checkpoint from {checkpoint_path}")
317
  checkpoint = self._load_checkpoint_file(checkpoint_path)
318
 
319
- # Determine dtype from checkpoint (safetensors saves as bfloat16)
320
  first_tensor = next(iter(checkpoint.values()))
321
  checkpoint_dtype = first_tensor.dtype
322
  print(f"Checkpoint dtype: {checkpoint_dtype}")
 
 
 
 
 
323
 
324
  # Load weights into model (assign=True to use checkpoint tensors directly, preserving dtype)
325
  model.load_state_dict(checkpoint, strict=False, assign=True)
@@ -420,7 +425,7 @@ class CalligraphyGenerator:
420
  model_engine = deepspeed.init_inference(
421
  model=model,
422
  mp_size=1, # model parallel size
423
- dtype=torch.bfloat16 if ds_config.get('bf16', {}).get('enabled', False) else torch.float16,
424
  replace_with_kernel_inject=False, # Don't replace with DeepSpeed kernels for custom models
425
  )
426
 
 
316
  print(f"Loading checkpoint from {checkpoint_path}")
317
  checkpoint = self._load_checkpoint_file(checkpoint_path)
318
 
319
+ # Determine dtype from checkpoint and convert to float32
320
  first_tensor = next(iter(checkpoint.values()))
321
  checkpoint_dtype = first_tensor.dtype
322
  print(f"Checkpoint dtype: {checkpoint_dtype}")
323
+
324
+ # Convert checkpoint to float32 if needed
325
+ if checkpoint_dtype != torch.float32:
326
+ print(f"Converting checkpoint from {checkpoint_dtype} to float32...")
327
+ checkpoint = {k: v.float() for k, v in checkpoint.items()}
328
 
329
  # Load weights into model (assign=True to use checkpoint tensors directly, preserving dtype)
330
  model.load_state_dict(checkpoint, strict=False, assign=True)
 
425
  model_engine = deepspeed.init_inference(
426
  model=model,
427
  mp_size=1, # model parallel size
428
+ dtype=torch.float32, # Use float32 for compatibility
429
  replace_with_kernel_inject=False, # Don't replace with DeepSpeed kernels for custom models
430
  )
431
 
src/flux/cli.py CHANGED
@@ -185,7 +185,7 @@ def main(
185
  opts.height,
186
  opts.width,
187
  device=torch_device,
188
- dtype=torch.bfloat16,
189
  seed=opts.seed,
190
  )
191
  opts.seed = None
@@ -213,7 +213,7 @@ def main(
213
 
214
  # decode latents to pixel space
215
  x = unpack(x.float(), opts.height, opts.width)
216
- with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
217
  x = ae.decode(x)
218
  t1 = time.perf_counter()
219
 
 
185
  opts.height,
186
  opts.width,
187
  device=torch_device,
188
+ dtype=torch.float32,
189
  seed=opts.seed,
190
  )
191
  opts.seed = None
 
213
 
214
  # decode latents to pixel space
215
  x = unpack(x.float(), opts.height, opts.width)
216
+ with torch.autocast(device_type=torch_device.type, dtype=torch.float32):
217
  x = ae.decode(x)
218
  t1 = time.perf_counter()
219
 
src/flux/util.py CHANGED
@@ -294,7 +294,7 @@ def load_flow_model(name: str, device: str | torch.device = "cuda", hf_download:
294
  ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)
295
 
296
  with torch.device("meta" if ckpt_path is not None else device):
297
- model = Flux(configs[name].params).to(torch.bfloat16)
298
 
299
  if ckpt_path is not None:
300
  print("Loading checkpoint")
@@ -344,7 +344,7 @@ def load_flow_model_quintized(name: str, device: str | torch.device = "cuda", hf
344
  json_path = hf_hub_download(configs[name].repo_id, 'flux_dev_quantization_map.json')
345
 
346
 
347
- model = Flux(configs[name].params).to(torch.bfloat16)
348
 
349
  print("Loading checkpoint")
350
  # load_sft doesn't support torch.device
@@ -365,11 +365,11 @@ def load_controlnet(name, device, transformer=None):
365
 
366
  def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
367
  # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
368
- return HFEmbedder("xlabs-ai/xflux_text_encoders", max_length=max_length, torch_dtype=torch.bfloat16).to(device)
369
- # return HFEmbedder("google/mt5-base", max_length=max_length, torch_dtype=torch.bfloat16).to(device)
370
 
371
  def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
372
- return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16).to(device)
373
 
374
 
375
  def load_ae(name: str, device: str | torch.device = "cuda", hf_download: bool = True) -> AutoEncoder:
 
294
  ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)
295
 
296
  with torch.device("meta" if ckpt_path is not None else device):
297
+ model = Flux(configs[name].params).to(torch.float32)
298
 
299
  if ckpt_path is not None:
300
  print("Loading checkpoint")
 
344
  json_path = hf_hub_download(configs[name].repo_id, 'flux_dev_quantization_map.json')
345
 
346
 
347
+ model = Flux(configs[name].params).to(torch.float32)
348
 
349
  print("Loading checkpoint")
350
  # load_sft doesn't support torch.device
 
365
 
366
  def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
367
  # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
368
+ return HFEmbedder("xlabs-ai/xflux_text_encoders", max_length=max_length, torch_dtype=torch.float32).to(device)
369
+ # return HFEmbedder("google/mt5-base", max_length=max_length, torch_dtype=torch.float32).to(device)
370
 
371
  def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
372
+ return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.float32).to(device)
373
 
374
 
375
  def load_ae(name: str, device: str | torch.device = "cuda", hf_download: bool = True) -> AutoEncoder:
src/flux/xflux_pipeline.py CHANGED
@@ -71,14 +71,14 @@ class XFluxPipeline:
71
 
72
  # load image encoder
73
  self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
74
- self.device, dtype=torch.float16
75
  )
76
  self.clip_image_processor = CLIPImageProcessor()
77
 
78
  # setup image embedding projection model
79
  self.improj = ImageProjModel(4096, 768, 4)
80
  self.improj.load_state_dict(proj)
81
- self.improj = self.improj.to(self.device, dtype=torch.bfloat16)
82
 
83
  ip_attn_procs = {}
84
 
@@ -90,7 +90,7 @@ class XFluxPipeline:
90
  if ip_state_dict:
91
  ip_attn_procs[name] = IPDoubleStreamBlockProcessor(4096, 3072)
92
  ip_attn_procs[name].load_state_dict(ip_state_dict)
93
- ip_attn_procs[name].to(self.device, dtype=torch.bfloat16)
94
  else:
95
  ip_attn_procs[name] = self.model.attn_processors[name]
96
 
@@ -135,7 +135,7 @@ class XFluxPipeline:
135
 
136
  def set_controlnet(self, control_type: str, local_path: str = None, repo_id: str = None, name: str = None):
137
  self.model.to(self.device)
138
- self.controlnet = load_controlnet(self.model_type, self.device).to(torch.bfloat16)
139
 
140
  checkpoint = load_checkpoint(local_path, repo_id, name)
141
  self.controlnet.load_state_dict(checkpoint, strict=False)
@@ -156,7 +156,7 @@ class XFluxPipeline:
156
  image_prompt_embeds = self.image_encoder(
157
  image_prompt
158
  ).image_embeds.to(
159
- device=self.device, dtype=torch.bfloat16,
160
  )
161
  # encode image
162
  image_proj = self.improj(image_prompt_embeds)
 
71
 
72
  # load image encoder
73
  self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
74
+ self.device, dtype=torch.float32
75
  )
76
  self.clip_image_processor = CLIPImageProcessor()
77
 
78
  # setup image embedding projection model
79
  self.improj = ImageProjModel(4096, 768, 4)
80
  self.improj.load_state_dict(proj)
81
+ self.improj = self.improj.to(self.device, dtype=torch.float32)
82
 
83
  ip_attn_procs = {}
84
 
 
90
  if ip_state_dict:
91
  ip_attn_procs[name] = IPDoubleStreamBlockProcessor(4096, 3072)
92
  ip_attn_procs[name].load_state_dict(ip_state_dict)
93
+ ip_attn_procs[name].to(self.device, dtype=torch.float32)
94
  else:
95
  ip_attn_procs[name] = self.model.attn_processors[name]
96
 
 
135
 
136
  def set_controlnet(self, control_type: str, local_path: str = None, repo_id: str = None, name: str = None):
137
  self.model.to(self.device)
138
+ self.controlnet = load_controlnet(self.model_type, self.device).to(torch.float32)
139
 
140
  checkpoint = load_checkpoint(local_path, repo_id, name)
141
  self.controlnet.load_state_dict(checkpoint, strict=False)
 
156
  image_prompt_embeds = self.image_encoder(
157
  image_prompt
158
  ).image_embeds.to(
159
+ device=self.device, dtype=torch.float32,
160
  )
161
  # encode image
162
  image_proj = self.improj(image_prompt_embeds)