Spaces:
Running on Zero
Running on Zero
fix: use float32 instead of bfloat16 for compatibility
Browse files- inference.py +7 -2
- src/flux/cli.py +2 -2
- src/flux/util.py +5 -5
- src/flux/xflux_pipeline.py +5 -5
inference.py
CHANGED
|
@@ -316,10 +316,15 @@ class CalligraphyGenerator:
|
|
| 316 |
print(f"Loading checkpoint from {checkpoint_path}")
|
| 317 |
checkpoint = self._load_checkpoint_file(checkpoint_path)
|
| 318 |
|
| 319 |
-
# Determine dtype from checkpoint
|
| 320 |
first_tensor = next(iter(checkpoint.values()))
|
| 321 |
checkpoint_dtype = first_tensor.dtype
|
| 322 |
print(f"Checkpoint dtype: {checkpoint_dtype}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 323 |
|
| 324 |
# Load weights into model (assign=True to use checkpoint tensors directly, preserving dtype)
|
| 325 |
model.load_state_dict(checkpoint, strict=False, assign=True)
|
|
@@ -420,7 +425,7 @@ class CalligraphyGenerator:
|
|
| 420 |
model_engine = deepspeed.init_inference(
|
| 421 |
model=model,
|
| 422 |
mp_size=1, # model parallel size
|
| 423 |
-
dtype=torch.bfloat16,
|
| 424 |
replace_with_kernel_inject=False, # Don't replace with DeepSpeed kernels for custom models
|
| 425 |
)
|
| 426 |
|
|
|
|
| 316 |
print(f"Loading checkpoint from {checkpoint_path}")
|
| 317 |
checkpoint = self._load_checkpoint_file(checkpoint_path)
|
| 318 |
|
| 319 |
+
# Determine dtype from checkpoint and convert to float32
|
| 320 |
first_tensor = next(iter(checkpoint.values()))
|
| 321 |
checkpoint_dtype = first_tensor.dtype
|
| 322 |
print(f"Checkpoint dtype: {checkpoint_dtype}")
|
| 323 |
+
|
| 324 |
+
# Convert checkpoint to float32 if needed
|
| 325 |
+
if checkpoint_dtype != torch.float32:
|
| 326 |
+
print(f"Converting checkpoint from {checkpoint_dtype} to float32...")
|
| 327 |
+
checkpoint = {k: v.float() for k, v in checkpoint.items()}
|
| 328 |
|
| 329 |
# Load weights into model (assign=True to use checkpoint tensors directly, preserving dtype)
|
| 330 |
model.load_state_dict(checkpoint, strict=False, assign=True)
|
|
|
|
| 425 |
model_engine = deepspeed.init_inference(
|
| 426 |
model=model,
|
| 427 |
mp_size=1, # model parallel size
|
| 428 |
+
dtype=torch.float32, # Use float32 for compatibility
|
| 429 |
replace_with_kernel_inject=False, # Don't replace with DeepSpeed kernels for custom models
|
| 430 |
)
|
| 431 |
|
src/flux/cli.py
CHANGED
|
@@ -185,7 +185,7 @@ def main(
|
|
| 185 |
opts.height,
|
| 186 |
opts.width,
|
| 187 |
device=torch_device,
|
| 188 |
-
dtype=torch.bfloat16,
|
| 189 |
seed=opts.seed,
|
| 190 |
)
|
| 191 |
opts.seed = None
|
|
@@ -213,7 +213,7 @@ def main(
|
|
| 213 |
|
| 214 |
# decode latents to pixel space
|
| 215 |
x = unpack(x.float(), opts.height, opts.width)
|
| 216 |
-
with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
|
| 217 |
x = ae.decode(x)
|
| 218 |
t1 = time.perf_counter()
|
| 219 |
|
|
|
|
| 185 |
opts.height,
|
| 186 |
opts.width,
|
| 187 |
device=torch_device,
|
| 188 |
+
dtype=torch.float32,
|
| 189 |
seed=opts.seed,
|
| 190 |
)
|
| 191 |
opts.seed = None
|
|
|
|
| 213 |
|
| 214 |
# decode latents to pixel space
|
| 215 |
x = unpack(x.float(), opts.height, opts.width)
|
| 216 |
+
with torch.autocast(device_type=torch_device.type, dtype=torch.float32):
|
| 217 |
x = ae.decode(x)
|
| 218 |
t1 = time.perf_counter()
|
| 219 |
|
src/flux/util.py
CHANGED
|
@@ -294,7 +294,7 @@ def load_flow_model(name: str, device: str | torch.device = "cuda", hf_download:
|
|
| 294 |
ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)
|
| 295 |
|
| 296 |
with torch.device("meta" if ckpt_path is not None else device):
|
| 297 |
-
model = Flux(configs[name].params).to(torch.bfloat16)
|
| 298 |
|
| 299 |
if ckpt_path is not None:
|
| 300 |
print("Loading checkpoint")
|
|
@@ -344,7 +344,7 @@ def load_flow_model_quintized(name: str, device: str | torch.device = "cuda", hf
|
|
| 344 |
json_path = hf_hub_download(configs[name].repo_id, 'flux_dev_quantization_map.json')
|
| 345 |
|
| 346 |
|
| 347 |
-
model = Flux(configs[name].params).to(torch.bfloat16)
|
| 348 |
|
| 349 |
print("Loading checkpoint")
|
| 350 |
# load_sft doesn't support torch.device
|
|
@@ -365,11 +365,11 @@ def load_controlnet(name, device, transformer=None):
|
|
| 365 |
|
| 366 |
def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
|
| 367 |
# max length 64, 128, 256 and 512 should work (if your sequence is short enough)
|
| 368 |
-
return HFEmbedder("xlabs-ai/xflux_text_encoders", max_length=max_length, torch_dtype=torch.bfloat16).to(device)
|
| 369 |
-
# return HFEmbedder("google/mt5-base", max_length=max_length, torch_dtype=torch.bfloat16).to(device)
|
| 370 |
|
| 371 |
def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
|
| 372 |
-
return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16).to(device)
|
| 373 |
|
| 374 |
|
| 375 |
def load_ae(name: str, device: str | torch.device = "cuda", hf_download: bool = True) -> AutoEncoder:
|
|
|
|
| 294 |
ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)
|
| 295 |
|
| 296 |
with torch.device("meta" if ckpt_path is not None else device):
|
| 297 |
+
model = Flux(configs[name].params).to(torch.float32)
|
| 298 |
|
| 299 |
if ckpt_path is not None:
|
| 300 |
print("Loading checkpoint")
|
|
|
|
| 344 |
json_path = hf_hub_download(configs[name].repo_id, 'flux_dev_quantization_map.json')
|
| 345 |
|
| 346 |
|
| 347 |
+
model = Flux(configs[name].params).to(torch.float32)
|
| 348 |
|
| 349 |
print("Loading checkpoint")
|
| 350 |
# load_sft doesn't support torch.device
|
|
|
|
| 365 |
|
| 366 |
def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
|
| 367 |
# max length 64, 128, 256 and 512 should work (if your sequence is short enough)
|
| 368 |
+
return HFEmbedder("xlabs-ai/xflux_text_encoders", max_length=max_length, torch_dtype=torch.float32).to(device)
|
| 369 |
+
# return HFEmbedder("google/mt5-base", max_length=max_length, torch_dtype=torch.float32).to(device)
|
| 370 |
|
| 371 |
def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
|
| 372 |
+
return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.float32).to(device)
|
| 373 |
|
| 374 |
|
| 375 |
def load_ae(name: str, device: str | torch.device = "cuda", hf_download: bool = True) -> AutoEncoder:
|
src/flux/xflux_pipeline.py
CHANGED
|
@@ -71,14 +71,14 @@ class XFluxPipeline:
|
|
| 71 |
|
| 72 |
# load image encoder
|
| 73 |
self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
|
| 74 |
-
self.device, dtype=torch.bfloat16
|
| 75 |
)
|
| 76 |
self.clip_image_processor = CLIPImageProcessor()
|
| 77 |
|
| 78 |
# setup image embedding projection model
|
| 79 |
self.improj = ImageProjModel(4096, 768, 4)
|
| 80 |
self.improj.load_state_dict(proj)
|
| 81 |
-
self.improj = self.improj.to(self.device, dtype=torch.bfloat16)
|
| 82 |
|
| 83 |
ip_attn_procs = {}
|
| 84 |
|
|
@@ -90,7 +90,7 @@ class XFluxPipeline:
|
|
| 90 |
if ip_state_dict:
|
| 91 |
ip_attn_procs[name] = IPDoubleStreamBlockProcessor(4096, 3072)
|
| 92 |
ip_attn_procs[name].load_state_dict(ip_state_dict)
|
| 93 |
-
ip_attn_procs[name].to(self.device, dtype=torch.bfloat16)
|
| 94 |
else:
|
| 95 |
ip_attn_procs[name] = self.model.attn_processors[name]
|
| 96 |
|
|
@@ -135,7 +135,7 @@ class XFluxPipeline:
|
|
| 135 |
|
| 136 |
def set_controlnet(self, control_type: str, local_path: str = None, repo_id: str = None, name: str = None):
|
| 137 |
self.model.to(self.device)
|
| 138 |
-
self.controlnet = load_controlnet(self.model_type, self.device).to(torch.bfloat16)
|
| 139 |
|
| 140 |
checkpoint = load_checkpoint(local_path, repo_id, name)
|
| 141 |
self.controlnet.load_state_dict(checkpoint, strict=False)
|
|
@@ -156,7 +156,7 @@ class XFluxPipeline:
|
|
| 156 |
image_prompt_embeds = self.image_encoder(
|
| 157 |
image_prompt
|
| 158 |
).image_embeds.to(
|
| 159 |
-
device=self.device, dtype=torch.bfloat16,
|
| 160 |
)
|
| 161 |
# encode image
|
| 162 |
image_proj = self.improj(image_prompt_embeds)
|
|
|
|
| 71 |
|
| 72 |
# load image encoder
|
| 73 |
self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
|
| 74 |
+
self.device, dtype=torch.float32
|
| 75 |
)
|
| 76 |
self.clip_image_processor = CLIPImageProcessor()
|
| 77 |
|
| 78 |
# setup image embedding projection model
|
| 79 |
self.improj = ImageProjModel(4096, 768, 4)
|
| 80 |
self.improj.load_state_dict(proj)
|
| 81 |
+
self.improj = self.improj.to(self.device, dtype=torch.float32)
|
| 82 |
|
| 83 |
ip_attn_procs = {}
|
| 84 |
|
|
|
|
| 90 |
if ip_state_dict:
|
| 91 |
ip_attn_procs[name] = IPDoubleStreamBlockProcessor(4096, 3072)
|
| 92 |
ip_attn_procs[name].load_state_dict(ip_state_dict)
|
| 93 |
+
ip_attn_procs[name].to(self.device, dtype=torch.float32)
|
| 94 |
else:
|
| 95 |
ip_attn_procs[name] = self.model.attn_processors[name]
|
| 96 |
|
|
|
|
| 135 |
|
| 136 |
def set_controlnet(self, control_type: str, local_path: str = None, repo_id: str = None, name: str = None):
|
| 137 |
self.model.to(self.device)
|
| 138 |
+
self.controlnet = load_controlnet(self.model_type, self.device).to(torch.float32)
|
| 139 |
|
| 140 |
checkpoint = load_checkpoint(local_path, repo_id, name)
|
| 141 |
self.controlnet.load_state_dict(checkpoint, strict=False)
|
|
|
|
| 156 |
image_prompt_embeds = self.image_encoder(
|
| 157 |
image_prompt
|
| 158 |
).image_embeds.to(
|
| 159 |
+
device=self.device, dtype=torch.float32,
|
| 160 |
)
|
| 161 |
# encode image
|
| 162 |
image_proj = self.improj(image_prompt_embeds)
|