| | import getpass |
| | import math |
| | import os |
| | from dataclasses import dataclass |
| | from pathlib import Path |
| |
|
| | import requests |
| | import torch |
| | from einops import rearrange |
| | from huggingface_hub import hf_hub_download, login |
| | from PIL import ExifTags, Image |
| | from safetensors.torch import load_file as load_sft |
| |
|
| | from .model import Flux, FluxLoraWrapper, FluxParams |
| | from .modules.autoencoder import AutoEncoder, AutoEncoderParams |
| | from .modules.conditioner import HFEmbedder |
| | from shared.utils import files_locator as fl |
| | from shared.convert import convert_diffusers_to_flux |
| |
|
| | CHECKPOINTS_DIR = Path("checkpoints") |
| |
|
| | BFL_API_KEY = os.getenv("BFL_API_KEY") |
| |
|
| |
|
def ensure_hf_auth():
    """Try to authenticate with HuggingFace, returning True on success.

    Attempts, in order: the HF_TOKEN environment variable, then a token
    previously cached by huggingface-cli under ~/.cache/huggingface.
    """
    token_from_env = os.environ.get("HF_TOKEN")
    if token_from_env:
        print("Trying to authenticate to HuggingFace with the HF_TOKEN environment variable.")
        try:
            login(token=token_from_env)
            print("Successfully authenticated with HuggingFace using HF_TOKEN")
            return True
        except Exception as e:
            # Fall through to the cached-token check below.
            print(f"Warning: Failed to authenticate with HF_TOKEN: {e}")

    cached_token_path = os.path.expanduser("~/.cache/huggingface/token")
    if os.path.exists(cached_token_path):
        print("Already authenticated with HuggingFace")
        return True

    return False
| |
|
| |
|
def prompt_for_hf_auth():
    """Interactively prompt for a HuggingFace token and log in with it.

    Returns True when login succeeds; False on empty input, Ctrl-C, or any
    authentication error.
    """
    try:
        token = getpass.getpass("HF Token (hidden input): ").strip()
        if not token:
            print("No token provided. Aborting.")
            return False

        login(token=token)
        print("Successfully authenticated!")
        return True
    except KeyboardInterrupt:
        print("\nAuthentication cancelled by user.")
        return False
    except Exception as auth_e:
        print(f"Authentication failed: {auth_e}")
        print("Tip: You can also run 'huggingface-cli login' or set HF_TOKEN environment variable")
        return False
| |
|
| |
|
def get_checkpoint_path(repo_id: str, filename: str, env_var: str) -> Path:
    """Get the local path for a checkpoint file, downloading if necessary.

    NOTE(review): the download path below is disabled (`if False:`), so this
    function currently just returns `filename` unchanged. The annotated return
    type is Path, but the value returned is whatever type `filename` is
    (callers pass strings) -- confirm whether the annotation or behavior
    should change.

    NOTE(review): `env_var` is unused, the `offload` import is unused, and the
    dead branch references an undefined `checkpoint_dir` -- it would raise
    NameError if ever re-enabled as-is.
    """
    local_path = filename
    # Unused import; presumably left over from an offload-based code path.
    from mmgp import offload

    # Dead code: kept as reference for the original HF download flow.
    if False:
        print(f"Downloading (unknown) from {repo_id} to {local_path}")
        try:
            ensure_hf_auth()
            hf_hub_download(repo_id=repo_id, filename=filename, local_dir=checkpoint_dir)
        except Exception as e:
            # Gated repos need an interactive login before retrying.
            if "gated repo" in str(e).lower() or "restricted" in str(e).lower():
                print(f"\nError: Cannot access {repo_id} -- this is a gated repository.")

                if prompt_for_hf_auth():
                    print("Retrying download...")
                    hf_hub_download(repo_id=repo_id, filename=filename, local_dir=checkpoint_dir)
                else:
                    print("Authentication failed or cancelled.")
                    print("You can also run 'huggingface-cli login' or set HF_TOKEN environment variable")
                    raise RuntimeError(f"Authentication required for {repo_id}")
            else:
                raise e

    return local_path
| |
|
| |
|
| | def download_onnx_models_for_trt(model_name: str, trt_transformer_precision: str = "bf16") -> str | None: |
| | """Download ONNX models for TRT to our checkpoints directory""" |
| | onnx_repo_map = { |
| | "flux-dev": "black-forest-labs/FLUX.1-dev-onnx", |
| | "flux-schnell": "black-forest-labs/FLUX.1-schnell-onnx", |
| | "flux-dev-canny": "black-forest-labs/FLUX.1-Canny-dev-onnx", |
| | "flux-dev-depth": "black-forest-labs/FLUX.1-Depth-dev-onnx", |
| | "flux-dev-redux": "black-forest-labs/FLUX.1-Redux-dev-onnx", |
| | "flux-dev-fill": "black-forest-labs/FLUX.1-Fill-dev-onnx", |
| | "flux-dev-kontext": "black-forest-labs/FLUX.1-Kontext-dev-onnx", |
| | } |
| |
|
| | if model_name not in onnx_repo_map: |
| | return None |
| |
|
| | repo_id = onnx_repo_map[model_name] |
| | safe_repo_name = repo_id.replace("/", "_") |
| | onnx_dir = CHECKPOINTS_DIR / safe_repo_name |
| |
|
| | |
| | onnx_file_map = { |
| | "clip": "clip.opt/model.onnx", |
| | "transformer": f"transformer.opt/{trt_transformer_precision}/model.onnx", |
| | "transformer_data": f"transformer.opt/{trt_transformer_precision}/backbone.onnx_data", |
| | "t5": "t5.opt/model.onnx", |
| | "t5_data": "t5.opt/backbone.onnx_data", |
| | "vae": "vae.opt/model.onnx", |
| | } |
| |
|
| | |
| | if onnx_dir.exists(): |
| | all_files_exist = True |
| | custom_paths = [] |
| | for module, onnx_file in onnx_file_map.items(): |
| | if module.endswith("_data"): |
| | continue |
| | local_path = onnx_dir / onnx_file |
| | if not local_path.exists(): |
| | all_files_exist = False |
| | break |
| | custom_paths.append(f"{module}:{local_path}") |
| |
|
| | if all_files_exist: |
| | print(f"ONNX models ready in {onnx_dir}") |
| | return ",".join(custom_paths) |
| |
|
| | |
| | print(f"Downloading ONNX models from {repo_id} to {onnx_dir}") |
| | print(f"Using transformer precision: {trt_transformer_precision}") |
| | onnx_dir.mkdir(exist_ok=True) |
| |
|
| | |
| | for module, onnx_file in onnx_file_map.items(): |
| | local_path = onnx_dir / onnx_file |
| | if local_path.exists(): |
| | continue |
| |
|
| | |
| | local_path.parent.mkdir(parents=True, exist_ok=True) |
| |
|
| | try: |
| | print(f"Downloading {onnx_file}") |
| | hf_hub_download(repo_id=repo_id, filename=onnx_file, local_dir=onnx_dir) |
| | except Exception as e: |
| | if "does not exist" in str(e).lower() or "not found" in str(e).lower(): |
| | continue |
| | elif "gated repo" in str(e).lower() or "restricted" in str(e).lower(): |
| | print(f"Cannot access {repo_id} - requires license acceptance") |
| | print("Please follow these steps:") |
| | print(f" 1. Visit: https://huggingface.co/{repo_id}") |
| | print(" 2. Log in to your HuggingFace account") |
| | print(" 3. Accept the license terms and conditions") |
| | print(" 4. Then retry this command") |
| | raise RuntimeError(f"License acceptance required for {model_name}") |
| | else: |
| | |
| | raise |
| |
|
| | print(f"ONNX models ready in {onnx_dir}") |
| |
|
| | |
| | |
| | custom_paths = [] |
| | for module, onnx_file in onnx_file_map.items(): |
| | if module.endswith("_data"): |
| | continue |
| | full_path = onnx_dir / onnx_file |
| | if full_path.exists(): |
| | custom_paths.append(f"{module}:{full_path}") |
| |
|
| | return ",".join(custom_paths) |
| |
|
| |
|
def check_onnx_access_for_trt(model_name: str, trt_transformer_precision: str = "bf16") -> str | None:
    """Check ONNX access and download models for TRT - returns ONNX directory path"""
    # Pure delegation: downloading doubles as the access check.
    custom_paths = download_onnx_models_for_trt(model_name, trt_transformer_precision)
    return custom_paths
| |
|
| |
|
def track_usage_via_api(name: str, n=1) -> None:
    """
    Track usage of licensed models via the BFL API for commercial licensing compliance.

    For more information on licensing BFL's models for commercial use and usage reporting,
    see the README.md or visit: https://dashboard.bfl.ai/licensing/subscriptions?showInstructions=true

    Args:
        name: Internal model name; unknown names are skipped with a warning.
        n: Number of generations to report.

    Raises:
        Exception: When the API responds with a non-200 status code.
    """
    assert BFL_API_KEY is not None, "BFL_API_KEY is not set"

    # All "tools" variants (fill/depth/canny/redux and their LoRAs) bill
    # against the shared "flux-tools" subscription slug.
    model_slug_map = {
        "flux-dev": "flux-1-dev",
        "flux-dev-kontext": "flux-1-kontext-dev",
        "flux-dev-fill": "flux-tools",
        "flux-dev-depth": "flux-tools",
        "flux-dev-canny": "flux-tools",
        "flux-dev-canny-lora": "flux-tools",
        "flux-dev-depth-lora": "flux-tools",
        "flux-dev-redux": "flux-tools",
    }

    if name not in model_slug_map:
        print(f"Skipping tracking usage for {name}, as it cannot be tracked. Please check the model name.")
        return

    model_slug = model_slug_map[name]
    url = f"https://api.bfl.ai/v1/licenses/models/{model_slug}/usage"
    headers = {"x-key": BFL_API_KEY, "Content-Type": "application/json"}
    payload = {"number_of_generations": n}

    # BUGFIX: explicit timeout so a stalled connection cannot hang the
    # generation pipeline indefinitely (requests has no default timeout).
    response = requests.post(url, headers=headers, json=payload, timeout=30)
    if response.status_code != 200:
        raise Exception(f"Failed to track usage: {response.status_code} {response.text}")
    else:
        print(f"Successfully tracked usage for {name} with {n} generations")
| |
|
| |
|
def save_image(
    nsfw_classifier,
    name: str,
    output_name: str,
    idx: int,
    x: torch.Tensor,
    add_sampling_metadata: bool,
    prompt: str,
    nsfw_threshold: float = 0.85,
    track_usage: bool = False,
) -> int:
    """Write a decoded image tensor to disk as a JPEG, gated by an NSFW check.

    Returns the next free image index (incremented only when the file was
    actually saved).
    """
    fn = output_name.format(idx=idx)
    print(f"Saving {fn}")

    # Map the [-1, 1] float tensor (batch dim 0) to an 8-bit HWC PIL image.
    pixels = rearrange(x.clamp(-1, 1)[0], "c h w -> h w c")
    img = Image.fromarray((127.5 * (pixels + 1.0)).cpu().byte().numpy())

    if nsfw_classifier is None:
        # Classifier disabled: pick a score that always passes the threshold.
        nsfw_score = nsfw_threshold - 1.0
    else:
        nsfw_score = [pred["score"] for pred in nsfw_classifier(img) if pred["label"] == "nsfw"][0]

    if nsfw_score >= nsfw_threshold:
        print("Your generated image may contain NSFW content.")
        return idx

    exif_data = Image.Exif()
    if name in ["flux-dev", "flux-schnell"]:
        exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
    else:
        exif_data[ExifTags.Base.Software] = "AI generated;img2img;flux"
    exif_data[ExifTags.Base.Make] = "Black Forest Labs"
    exif_data[ExifTags.Base.Model] = name
    if add_sampling_metadata:
        exif_data[ExifTags.Base.ImageDescription] = prompt
    img.save(fn, exif=exif_data, quality=95, subsampling=0)
    if track_usage:
        track_usage_via_api(name, 1)
    return idx + 1
| |
|
| |
|
@dataclass
class ModelSpec:
    """Static description of one Flux model variant: architecture
    hyper-parameters plus where its weights live."""

    params: FluxParams              # flow-transformer hyper-parameters
    ae_params: AutoEncoderParams    # autoencoder hyper-parameters
    repo_id: str                    # HF repo for the flow weights ("" = resolved elsewhere)
    repo_flow: str                  # flow checkpoint filename within the repo
    repo_ae: str                    # local/relative path of the VAE checkpoint
    lora_repo_id: str | None = None     # optional HF repo holding a LoRA to apply
    lora_filename: str | None = None    # LoRA checkpoint filename within that repo
| |
|
| |
|
# Registry of all supported model variants.
# Flux2-family entries share a 32-channel VAE (scale 0.5 / shift 0.0) and
# 4-axis RoPE with theta=2000; Flux1-family entries share the classic
# 16-channel VAE (scale 0.3611 / shift 0.1159) and 3-axis RoPE, theta=10_000.
configs = {
    # Flux2 flagship: 48 heads, hidden 6144, guidance-distilled.
    "flux2-dev": ModelSpec(
        repo_id="",
        repo_flow="",
        repo_ae="ckpts/flux2_vae.safetensors",
        params=FluxParams(
            in_channels=128,
            out_channels=128,
            vec_in_dim=1,
            context_in_dim=15360,
            hidden_size=6144,
            mlp_ratio=3.0,
            single_linear1_mlp_ratio=6.0,
            single_mlp_hidden_ratio=3.0,
            double_mlp_ratio=3.0,
            double_linear1_mlp_ratio=6.0,
            num_heads=48,
            depth=8,
            depth_single_blocks=48,
            axes_dim=[32, 32, 32, 32],
            theta=2000,
            qkv_bias=False,
            guidance_embed=True,
            flux2=True,
        ),
        ae_params=AutoEncoderParams(
            resolution=1024,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=32,
            scale_factor=0.5,
            shift_factor=0.0,
        ),
    ),
    # Smaller Flux2 "klein" variant (~4B): 24 heads, hidden 3072, no guidance.
    "flux2-klein-4b": ModelSpec(
        repo_id="",
        repo_flow="",
        repo_ae="ckpts/flux2_vae.safetensors",
        params=FluxParams(
            in_channels=128,
            out_channels=128,
            vec_in_dim=1,
            context_in_dim=7680,
            hidden_size=3072,
            mlp_ratio=3.0,
            single_linear1_mlp_ratio=6.0,
            single_mlp_hidden_ratio=3.0,
            double_mlp_ratio=3.0,
            double_linear1_mlp_ratio=6.0,
            num_heads=24,
            depth=5,
            depth_single_blocks=20,
            axes_dim=[32, 32, 32, 32],
            theta=2000,
            qkv_bias=False,
            guidance_embed=False,
            flux2=True,
        ),
        ae_params=AutoEncoderParams(
            resolution=1024,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=32,
            scale_factor=0.5,
            shift_factor=0.0,
        ),
    ),
    # Mid-size Flux2 "klein" variant (~9B): 32 heads, hidden 4096, no guidance.
    "flux2-klein-9b": ModelSpec(
        repo_id="",
        repo_flow="",
        repo_ae="ckpts/flux2_vae.safetensors",
        params=FluxParams(
            in_channels=128,
            out_channels=128,
            vec_in_dim=1,
            context_in_dim=12288,
            hidden_size=4096,
            mlp_ratio=3.0,
            single_linear1_mlp_ratio=6.0,
            single_mlp_hidden_ratio=3.0,
            double_mlp_ratio=3.0,
            double_linear1_mlp_ratio=6.0,
            num_heads=32,
            depth=8,
            depth_single_blocks=24,
            axes_dim=[32, 32, 32, 32],
            theta=2000,
            qkv_bias=False,
            guidance_embed=False,
            flux2=True,
        ),
        ae_params=AutoEncoderParams(
            resolution=1024,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=32,
            scale_factor=0.5,
            shift_factor=0.0,
        ),
    ),
    # flux2-dev architecture with the piflow flag enabled.
    "pi-flux2": ModelSpec(
        repo_id="",
        repo_flow="",
        repo_ae="ckpts/flux2_vae.safetensors",
        params=FluxParams(
            in_channels=128,
            out_channels=128,
            vec_in_dim=1,
            context_in_dim=15360,
            hidden_size=6144,
            mlp_ratio=3.0,
            single_linear1_mlp_ratio=6.0,
            single_mlp_hidden_ratio=3.0,
            double_mlp_ratio=3.0,
            double_linear1_mlp_ratio=6.0,
            num_heads=48,
            depth=8,
            depth_single_blocks=48,
            axes_dim=[32, 32, 32, 32],
            theta=2000,
            qkv_bias=False,
            guidance_embed=True,
            flux2=True,
            piflow=True,
        ),
        ae_params=AutoEncoderParams(
            resolution=1024,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=32,
            scale_factor=0.5,
            shift_factor=0.0,
        ),
    ),
    # Flux1 base model (guidance-distilled).
    "flux-dev": ModelSpec(
        repo_id="",
        repo_flow="",
        repo_ae="ckpts/flux_vae.safetensors",
        params=FluxParams(
            in_channels=64,
            out_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=True,
        ),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
    # Timestep-distilled Flux1 (few-step sampling, no guidance embedding).
    "flux-schnell": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-schnell",
        repo_flow="",
        repo_ae="ckpts/flux_vae.safetensors",
        params=FluxParams(
            in_channels=64,
            out_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=False,
        ),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
    # Community Chroma finetune of the Flux1 architecture.
    "flux-chroma": ModelSpec(
        repo_id="lodestones/Chroma1-HD",
        repo_flow="",
        repo_ae="ckpts/flux_vae.safetensors",
        params=FluxParams(
            in_channels=64,
            out_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=False,
            chroma=True,
        ),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
    # Chroma Radiance: predicts RGB directly (out_channels=3) through an
    # extra radiance head instead of the VAE latent.
    "flux-chroma-radiance": ModelSpec(
        repo_id="lodestones/Chroma1-Radiance",
        repo_flow="",
        repo_ae="ckpts/flux_vae.safetensors",
        params=FluxParams(
            in_channels=64,
            out_channels=3,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=False,
            chroma=True,
            radiance=True,
            radiance_patch_size=16,
            radiance_hidden_size=64,
            radiance_mlp_ratio=4,
            radiance_depth=4,
            radiance_max_freqs=8,
            radiance_tile_size=0,
            radiance_final_head_type="conv",
        ),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
    # Canny control model: conditioning latents concatenated (in_channels=128).
    "flux-dev-canny": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-Canny-dev",
        repo_flow="",
        repo_ae="ckpts/flux_vae.safetensors",
        params=FluxParams(
            in_channels=128,
            out_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=True,
        ),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
    # Canny control as a LoRA applied on top of the base flux-dev weights.
    "flux-dev-canny-lora": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-dev",
        repo_flow="",
        repo_ae="ckpts/flux_vae.safetensors",
        lora_repo_id="black-forest-labs/FLUX.1-Canny-dev-lora",
        lora_filename="flux1-canny-dev-lora.safetensors",
        params=FluxParams(
            in_channels=128,
            out_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=True,
        ),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
    # Depth control model (conditioning concatenated, in_channels=128).
    "flux-dev-depth": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-Depth-dev",
        repo_flow="",
        repo_ae="ckpts/flux_vae.safetensors",
        params=FluxParams(
            in_channels=128,
            out_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=True,
        ),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
    # Depth control as a LoRA applied on top of the base flux-dev weights.
    "flux-dev-depth-lora": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-dev",
        repo_flow="",
        repo_ae="ckpts/flux_vae.safetensors",
        lora_repo_id="black-forest-labs/FLUX.1-Depth-dev-lora",
        lora_filename="flux1-depth-dev-lora.safetensors",
        params=FluxParams(
            in_channels=128,
            out_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=True,
        ),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
    # Redux image-variation model (standard 64-channel latent input).
    "flux-dev-redux": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-Redux-dev",
        repo_flow="",
        repo_ae="ckpts/flux_vae.safetensors",
        params=FluxParams(
            in_channels=64,
            out_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=True,
        ),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
    # Inpainting/outpainting model: image + mask conditioning (in_channels=384).
    "flux-dev-fill": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-Fill-dev",
        repo_flow="",
        repo_ae="ckpts/flux_vae.safetensors",
        params=FluxParams(
            in_channels=384,
            out_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=True,
        ),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
    # Kontext instruction-based image editing model.
    "flux-dev-kontext": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-Kontext-dev",
        repo_flow="",
        repo_ae="ckpts/flux_vae.safetensors",
        params=FluxParams(
            in_channels=64,
            out_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=True,
        ),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
    # USO subject-driven variant (eso flag enables its extra conditioning path).
    "flux-dev-uso": ModelSpec(
        repo_id="",
        repo_flow="",
        repo_ae="ckpts/flux_vae.safetensors",
        params=FluxParams(
            in_channels=64,
            out_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=True,
            eso= True,
        ),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
    # UMO multi-reference variant (shares the eso conditioning path).
    "flux-dev-umo": ModelSpec(
        repo_id="",
        repo_flow="",
        repo_ae="ckpts/flux_vae.safetensors",
        params=FluxParams(
            in_channels=64,
            out_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=True,
            eso= True,
        ),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
    # DreamOmni2 finetune of Kontext (shares the eso conditioning path).
    "flux-dev-kontext-dreamomni2": ModelSpec(
        repo_id="",
        repo_flow="",
        repo_ae="ckpts/flux_vae.safetensors",
        params=FluxParams(
            in_channels=64,
            out_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=True,
            eso= True,
        ),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
}
| |
|
| |
|
# Resolution buckets for Kontext editing. Every pair is a multiple of 16 with
# an area close to 1024x1024, ordered from tallest to widest; the list is
# symmetric (each (w, h) has its (h, w) counterpart).
# NOTE(review): keeps the historical "PREFERED" spelling -- renaming would
# break external imports.
PREFERED_KONTEXT_RESOLUTIONS = [
    (672, 1568),
    (688, 1504),
    (720, 1456),
    (752, 1392),
    (800, 1328),
    (832, 1248),
    (880, 1184),
    (944, 1104),
    (1024, 1024),
    (1104, 944),
    (1184, 880),
    (1248, 832),
    (1328, 800),
    (1392, 752),
    (1456, 720),
    (1504, 688),
    (1568, 672),
]
| |
|
| |
|
def aspect_ratio_to_height_width(aspect_ratio: str, area: int = 1024**2) -> tuple[int, int]:
    """Turn a "W:H" aspect-ratio string into concrete pixel dimensions.

    The returned (width, height) has approximately the requested area, with
    each side snapped down to a multiple of 16.
    """
    parts = aspect_ratio.split(":")
    ratio = float(parts[0]) / float(parts[1])
    # Solve w * h = area with w / h = ratio.
    raw_width = round(math.sqrt(area * ratio))
    raw_height = round(math.sqrt(area / ratio))
    return 16 * (raw_width // 16), 16 * (raw_height // 16)
| |
|
| |
|
def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
    """Report missing/unexpected state-dict keys from a non-strict load."""
    def _report(kind: str, keys: list[str]) -> None:
        print(f"Got {len(keys)} {kind} keys:\n\t" + "\n\t".join(keys))

    if missing and unexpected:
        _report("missing", missing)
        # Visual separator between the two key lists.
        print("\n" + "-" * 79 + "\n")
        _report("unexpected", unexpected)
    elif missing:
        _report("missing", missing)
    elif unexpected:
        _report("unexpected", unexpected)
| |
|
| |
|
def preprocess_flux2_state_dict(state_dict: dict, params: FluxParams) -> dict:
    """
    Remap Flux2 checkpoint keys to the shared Flux1-style module layout.
    - Duplicate shared modulation weights to each single/double block.
    - Drop unused guidance embeddings that are identity in Flux2.

    The input dict is not modified; a shallow copy is returned.
    """
    sd = dict(state_dict)

    def take_modulation(prefix: str):
        # Shared modulation may be stored under ".lin" or the older ".linear".
        weight = sd.pop(f"{prefix}.lin.weight", None)
        bias = sd.pop(f"{prefix}.lin.bias", None)
        if weight is None:
            weight = sd.pop(f"{prefix}.linear.weight", None)
            if bias is None:
                bias = sd.pop(f"{prefix}.linear.bias", None)
        return weight, bias

    # (source prefix, per-block destination template, number of blocks)
    replication_plan = [
        ("double_stream_modulation_img", "double_blocks.{i}.img_mod.lin", params.depth),
        ("double_stream_modulation_txt", "double_blocks.{i}.txt_mod.lin", params.depth),
        ("single_stream_modulation", "single_blocks.{i}.modulation.lin", params.depth_single_blocks),
    ]
    for src_prefix, dst_template, block_count in replication_plan:
        weight, bias = take_modulation(src_prefix)
        if weight is None:
            continue
        for i in range(block_count):
            dst = dst_template.format(i=i)
            sd[f"{dst}.weight"] = weight
            if bias is not None:
                sd[f"{dst}.bias"] = bias

    # Remove any leftover shared-modulation entries under either spelling.
    for prefix in (
        "double_stream_modulation_img",
        "double_stream_modulation_txt",
        "single_stream_modulation",
    ):
        for tail in (".lin.weight", ".lin.bias", ".linear.weight", ".linear.bias"):
            sd.pop(prefix + tail, None)

    # Guidance embedding is identity in Flux2; its weights are never used.
    for unused_key in (
        "guidance_in.in_layer.weight",
        "guidance_in.in_layer.bias",
        "guidance_in.out_layer.weight",
        "guidance_in.out_layer.bias",
    ):
        sd.pop(unused_key, None)
    return sd
| |
|
| |
|
def preprocess_flux_state_dict(state_dict: dict) -> dict:
    # Convert a diffusers-layout Flux1 checkpoint to the original BFL key
    # layout expected by the Flux module in this package.
    return convert_diffusers_to_flux.convert_state_dict(state_dict)
| |
|
| |
|
def load_flow_model(
    name: str,
    model_filename,
    device: str | torch.device = "cuda",
    verbose: bool = True,
    preprocess_sd=preprocess_flux_state_dict,
) -> Flux:
    """Instantiate the Flux model described by `configs[name]` and load weights.

    The model is built on the meta device (no memory allocated) in bfloat16,
    then materialized from `model_filename` by mmgp.offload.

    Args:
        name: Key into the module-level `configs` registry.
        model_filename: Checkpoint path(s) passed to offload.load_model_data.
        device: Unused; kept for backward compatibility with existing callers.
        verbose: Unused; kept for backward compatibility with existing callers.
        preprocess_sd: Hook applied to the raw state dict before loading.

    Returns:
        The loaded Flux model (FluxLoraWrapper when the config declares LoRA
        weights to apply).
    """
    config = configs[name]

    with torch.device("meta"):
        # Wrap with LoRA support only when the config specifies LoRA weights.
        if config.lora_repo_id is not None and config.lora_filename is not None:
            model = FluxLoraWrapper(params=config.params).to(torch.bfloat16)
        else:
            model = Flux(config.params).to(torch.bfloat16)

    # Imported lazily to keep module import light; materializes the meta
    # parameters from the checkpoint. (Removed an unused `ckpt_path` local.)
    from mmgp import offload
    offload.load_model_data(model, model_filename, preprocess_sd=preprocess_sd)

    return model
| |
|
| |
|
def load_t5(device: str | torch.device = "cuda", text_encoder_filename = None, max_length: int = 512) -> HFEmbedder:
    # Build the T5 text encoder directly from a local checkpoint file (the
    # empty string skips repo resolution); prompts are capped at `max_length`.
    return HFEmbedder("",text_encoder_filename, max_length=max_length, torch_dtype=torch.bfloat16).to(device)
| |
|
| |
|
def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
    # CLIP ViT-L/14 text encoder, located via the shared files locator.
    # 77 is CLIP's fixed maximum token length.
    return HFEmbedder( fl.locate_folder("clip_vit_large_patch14"), "", max_length=77, torch_dtype=torch.bfloat16, is_clip =True).to(device)
| |
|
| |
|
def load_ae(name: str, device: str | torch.device = "cuda") -> AutoEncoder:
    """Load the autoencoder configured for `configs[name]` onto `device`.

    NOTE(review): the checkpoint is always resolved as "flux_vae.safetensors"
    via the files locator, ignoring config.repo_ae (which names
    "flux2_vae.safetensors" for Flux2 configs) -- confirm this is intended.
    """
    config = configs[name]
    ckpt_path = str(get_checkpoint_path(config.repo_id, fl.locate_file("flux_vae.safetensors"), "FLUX_AE"))

    # Build on the meta device (no allocation); weights are materialized below.
    with torch.device("meta"):
        ae = AutoEncoder(config.ae_params)

    # assign=True replaces the meta tensors with the loaded ones in place.
    sd = load_sft(ckpt_path, device=str(device))
    missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
    print_load_warning(missing, unexpected)
    return ae
| |
|
| |
|
def optionally_expand_state_dict(model: torch.nn.Module, state_dict: dict) -> dict:
    """
    Optionally expand the state dict to match the model's parameters shapes.

    Tensors smaller than the corresponding model parameter are zero-padded,
    keeping the loaded values in the leading slice of each dimension.
    Mutates and returns `state_dict`.
    """
    for param_name, param in model.named_parameters():
        if param_name not in state_dict:
            continue
        loaded = state_dict[param_name]
        if loaded.shape == param.shape:
            continue
        print(
            f"Expanding '{param_name}' with shape {loaded.shape} to model parameter with shape {param.shape}."
        )
        # Zero-filled tensor of the target shape; copy the loaded values into
        # its leading corner.
        padded = torch.zeros_like(param, device=loaded.device)
        leading_slices = tuple(slice(0, dim) for dim in loaded.shape)
        padded[leading_slices] = loaded
        state_dict[param_name] = padded

    return state_dict
| |
|
| |
|
class Flux2TransformerShared:
    """
    Convenience wrapper to instantiate the shared Flux2 transformer with explicit defaults.
    """

    def __new__(cls, **kwargs):
        # Imported lazily so this module does not hard-depend on diffusers.
        from diffusers.models.transformers.transformer_flux2 import Flux2Transformer2DModel

        # flux2-dev defaults; any caller kwargs override them.
        config = {
            "patch_size": 1,
            "in_channels": 128,
            "out_channels": None,
            "num_layers": 8,
            "num_single_layers": 48,
            "attention_head_dim": 128,
            "num_attention_heads": 48,
            "joint_attention_dim": 15360,
            "timestep_guidance_channels": 256,
            "mlp_ratio": 3.0,
            "axes_dims_rope": (32, 32, 32, 32),
            "rope_theta": 2000,
            "eps": 1e-6,
        }
        config.update(kwargs)
        # NOTE: __new__ deliberately returns a foreign instance, so calling
        # Flux2TransformerShared(...) yields a Flux2Transformer2DModel.
        return Flux2Transformer2DModel(**config)
| |
|
| |
|
| |
|