|
|
import getpass |
|
|
import math |
|
|
import os |
|
|
from dataclasses import dataclass |
|
|
from pathlib import Path |
|
|
|
|
|
import requests |
|
|
import torch |
|
|
from einops import rearrange |
|
|
from huggingface_hub import hf_hub_download, login |
|
|
from PIL import ExifTags, Image |
|
|
from safetensors.torch import load_file as load_sft |
|
|
|
|
|
from .model import Flux, FluxLoraWrapper, FluxParams |
|
|
from .modules.autoencoder import AutoEncoder, AutoEncoderParams |
|
|
from .modules.conditioner import HFEmbedder |
|
|
from shared.utils import files_locator as fl |
|
|
|
|
|
# Local cache directory used by download_onnx_models_for_trt.
CHECKPOINTS_DIR = Path("checkpoints")


# Optional Black Forest Labs API key; only required by track_usage_via_api.
BFL_API_KEY = os.getenv("BFL_API_KEY")
|
|
|
|
|
|
|
|
def ensure_hf_auth():
    """Attempt non-interactive HuggingFace authentication.

    Tries a login with the ``HF_TOKEN`` environment variable first, then falls
    back to checking for an already-cached token on disk.  Returns ``True`` if
    either succeeds, ``False`` otherwise.
    """
    env_token = os.environ.get("HF_TOKEN")
    if env_token:
        print("Trying to authenticate to HuggingFace with the HF_TOKEN environment variable.")
        try:
            login(token=env_token)
        except Exception as e:
            # Fall through to the cached-token check below.
            print(f"Warning: Failed to authenticate with HF_TOKEN: {e}")
        else:
            print("Successfully authenticated with HuggingFace using HF_TOKEN")
            return True

    cached_token_file = os.path.expanduser("~/.cache/huggingface/token")
    if os.path.exists(cached_token_file):
        print("Already authenticated with HuggingFace")
        return True

    return False
|
|
|
|
|
|
|
|
def prompt_for_hf_auth():
    """Interactively prompt for a HuggingFace token and log in with it.

    Returns ``True`` on success; ``False`` when the user aborts, enters an
    empty token, or the login fails.
    """
    try:
        token = getpass.getpass("HF Token (hidden input): ").strip()
        if token:
            login(token=token)
            print("Successfully authenticated!")
            return True
        print("No token provided. Aborting.")
    except KeyboardInterrupt:
        print("\nAuthentication cancelled by user.")
    except Exception as auth_e:
        print(f"Authentication failed: {auth_e}")
        print("Tip: You can also run 'huggingface-cli login' or set HF_TOKEN environment variable")
    return False
|
|
|
|
|
|
|
|
def get_checkpoint_path(repo_id: str, filename: str, env_var: str) -> Path:
    """Resolve the local path for a checkpoint file.

    In this build the checkpoint is expected to already exist locally, so the
    given ``filename`` is returned unchanged.  ``repo_id`` and ``env_var`` are
    kept for interface compatibility with the upstream implementation, which
    downloaded missing files from HuggingFace (authenticating via
    ``ensure_hf_auth`` / ``prompt_for_hf_auth`` when needed).

    Args:
        repo_id: HuggingFace repository id (unused in this build).
        filename: Path to the checkpoint on disk; returned as-is.
        env_var: Name of an environment-variable override (unused in this build).

    Returns:
        The ``filename`` argument unchanged.
    """
    # NOTE(review): the upstream download branch was dead code here (guarded by
    # ``if False:`` and referencing an undefined ``checkpoint_dir``); it has
    # been removed together with an unused ``from mmgp import offload`` import.
    return filename
|
|
|
|
|
|
|
|
def download_onnx_models_for_trt(model_name: str, trt_transformer_precision: str = "bf16") -> str | None: |
|
|
"""Download ONNX models for TRT to our checkpoints directory""" |
|
|
onnx_repo_map = { |
|
|
"flux-dev": "black-forest-labs/FLUX.1-dev-onnx", |
|
|
"flux-schnell": "black-forest-labs/FLUX.1-schnell-onnx", |
|
|
"flux-dev-canny": "black-forest-labs/FLUX.1-Canny-dev-onnx", |
|
|
"flux-dev-depth": "black-forest-labs/FLUX.1-Depth-dev-onnx", |
|
|
"flux-dev-redux": "black-forest-labs/FLUX.1-Redux-dev-onnx", |
|
|
"flux-dev-fill": "black-forest-labs/FLUX.1-Fill-dev-onnx", |
|
|
"flux-dev-kontext": "black-forest-labs/FLUX.1-Kontext-dev-onnx", |
|
|
} |
|
|
|
|
|
if model_name not in onnx_repo_map: |
|
|
return None |
|
|
|
|
|
repo_id = onnx_repo_map[model_name] |
|
|
safe_repo_name = repo_id.replace("/", "_") |
|
|
onnx_dir = CHECKPOINTS_DIR / safe_repo_name |
|
|
|
|
|
|
|
|
onnx_file_map = { |
|
|
"clip": "clip.opt/model.onnx", |
|
|
"transformer": f"transformer.opt/{trt_transformer_precision}/model.onnx", |
|
|
"transformer_data": f"transformer.opt/{trt_transformer_precision}/backbone.onnx_data", |
|
|
"t5": "t5.opt/model.onnx", |
|
|
"t5_data": "t5.opt/backbone.onnx_data", |
|
|
"vae": "vae.opt/model.onnx", |
|
|
} |
|
|
|
|
|
|
|
|
if onnx_dir.exists(): |
|
|
all_files_exist = True |
|
|
custom_paths = [] |
|
|
for module, onnx_file in onnx_file_map.items(): |
|
|
if module.endswith("_data"): |
|
|
continue |
|
|
local_path = onnx_dir / onnx_file |
|
|
if not local_path.exists(): |
|
|
all_files_exist = False |
|
|
break |
|
|
custom_paths.append(f"{module}:{local_path}") |
|
|
|
|
|
if all_files_exist: |
|
|
print(f"ONNX models ready in {onnx_dir}") |
|
|
return ",".join(custom_paths) |
|
|
|
|
|
|
|
|
print(f"Downloading ONNX models from {repo_id} to {onnx_dir}") |
|
|
print(f"Using transformer precision: {trt_transformer_precision}") |
|
|
onnx_dir.mkdir(exist_ok=True) |
|
|
|
|
|
|
|
|
for module, onnx_file in onnx_file_map.items(): |
|
|
local_path = onnx_dir / onnx_file |
|
|
if local_path.exists(): |
|
|
continue |
|
|
|
|
|
|
|
|
local_path.parent.mkdir(parents=True, exist_ok=True) |
|
|
|
|
|
try: |
|
|
print(f"Downloading {onnx_file}") |
|
|
hf_hub_download(repo_id=repo_id, filename=onnx_file, local_dir=onnx_dir) |
|
|
except Exception as e: |
|
|
if "does not exist" in str(e).lower() or "not found" in str(e).lower(): |
|
|
continue |
|
|
elif "gated repo" in str(e).lower() or "restricted" in str(e).lower(): |
|
|
print(f"Cannot access {repo_id} - requires license acceptance") |
|
|
print("Please follow these steps:") |
|
|
print(f" 1. Visit: https://huggingface.co/{repo_id}") |
|
|
print(" 2. Log in to your HuggingFace account") |
|
|
print(" 3. Accept the license terms and conditions") |
|
|
print(" 4. Then retry this command") |
|
|
raise RuntimeError(f"License acceptance required for {model_name}") |
|
|
else: |
|
|
|
|
|
raise |
|
|
|
|
|
print(f"ONNX models ready in {onnx_dir}") |
|
|
|
|
|
|
|
|
|
|
|
custom_paths = [] |
|
|
for module, onnx_file in onnx_file_map.items(): |
|
|
if module.endswith("_data"): |
|
|
continue |
|
|
full_path = onnx_dir / onnx_file |
|
|
if full_path.exists(): |
|
|
custom_paths.append(f"{module}:{full_path}") |
|
|
|
|
|
return ",".join(custom_paths) |
|
|
|
|
|
|
|
|
def check_onnx_access_for_trt(model_name: str, trt_transformer_precision: str = "bf16") -> str | None:
    """Check ONNX access and download models for TRT - returns ONNX directory path"""
    # Thin convenience alias; all of the work happens in the downloader.
    resolved = download_onnx_models_for_trt(model_name, trt_transformer_precision)
    return resolved
|
|
|
|
|
|
|
|
def track_usage_via_api(name: str, n=1) -> None:
    """
    Track usage of licensed models via the BFL API for commercial licensing compliance.

    For more information on licensing BFL's models for commercial use and usage reporting,
    see the README.md or visit: https://dashboard.bfl.ai/licensing/subscriptions?showInstructions=true

    Args:
        name: Model config name; names without a known license slug are
            skipped with a warning instead of raising.
        n: Number of generations to report.

    Raises:
        ValueError: if the BFL_API_KEY environment variable is not set.
        Exception: if the API responds with a non-200 status code.
    """
    # Was an ``assert``, which is silently stripped under ``python -O``;
    # validation must survive optimized runs, so raise explicitly.
    if BFL_API_KEY is None:
        raise ValueError("BFL_API_KEY is not set")

    model_slug_map = {
        "flux-dev": "flux-1-dev",
        "flux-dev-kontext": "flux-1-kontext-dev",
        "flux-dev-fill": "flux-tools",
        "flux-dev-depth": "flux-tools",
        "flux-dev-canny": "flux-tools",
        "flux-dev-canny-lora": "flux-tools",
        "flux-dev-depth-lora": "flux-tools",
        "flux-dev-redux": "flux-tools",
    }

    if name not in model_slug_map:
        print(f"Skipping tracking usage for {name}, as it cannot be tracked. Please check the model name.")
        return

    model_slug = model_slug_map[name]
    url = f"https://api.bfl.ai/v1/licenses/models/{model_slug}/usage"
    headers = {"x-key": BFL_API_KEY, "Content-Type": "application/json"}
    payload = {"number_of_generations": n}

    # timeout prevents an unbounded hang if the API endpoint is unreachable.
    response = requests.post(url, headers=headers, json=payload, timeout=30)
    if response.status_code != 200:
        raise Exception(f"Failed to track usage: {response.status_code} {response.text}")
    else:
        print(f"Successfully tracked usage for {name} with {n} generations")
|
|
|
|
|
|
|
|
def save_image(
    nsfw_classifier,
    name: str,
    output_name: str,
    idx: int,
    x: torch.Tensor,
    add_sampling_metadata: bool,
    prompt: str,
    nsfw_threshold: float = 0.85,
    track_usage: bool = False,
) -> int:
    """Convert a decoded tensor to a PIL image, NSFW-screen it, and save it.

    Returns the next free image index: ``idx + 1`` when the image was saved,
    ``idx`` unchanged when it was rejected by the NSFW classifier.
    """
    fn = output_name.format(idx=idx)
    print(f"Saving {fn}")

    # Map [-1, 1] channel-first output to an HWC uint8 image.
    pixels = rearrange(x.clamp(-1, 1)[0], "c h w -> h w c")
    img = Image.fromarray((127.5 * (pixels + 1.0)).cpu().byte().numpy())

    if nsfw_classifier is None:
        # No classifier available: score below any threshold, i.e. treat as safe.
        nsfw_score = nsfw_threshold - 1.0
    else:
        nsfw_score = [entry["score"] for entry in nsfw_classifier(img) if entry["label"] == "nsfw"][0]

    if nsfw_score >= nsfw_threshold:
        print("Your generated image may contain NSFW content.")
        return idx

    exif_data = Image.Exif()
    is_txt2img = name in ["flux-dev", "flux-schnell"]
    exif_data[ExifTags.Base.Software] = (
        "AI generated;txt2img;flux" if is_txt2img else "AI generated;img2img;flux"
    )
    exif_data[ExifTags.Base.Make] = "Black Forest Labs"
    exif_data[ExifTags.Base.Model] = name
    if add_sampling_metadata:
        exif_data[ExifTags.Base.ImageDescription] = prompt

    img.save(fn, exif=exif_data, quality=95, subsampling=0)
    if track_usage:
        track_usage_via_api(name, 1)
    return idx + 1
|
|
|
|
|
|
|
|
@dataclass
class ModelSpec:
    """Static configuration bundle for one Flux model variant.

    Groups the transformer and autoencoder hyperparameters with the
    HuggingFace repository locations used by the loader functions below.
    """

    # Transformer (flow model) hyperparameters, passed to Flux/FluxLoraWrapper.
    params: FluxParams
    # Autoencoder hyperparameters, passed to AutoEncoder in load_ae.
    ae_params: AutoEncoderParams
    # HuggingFace repo id of the flow model ("" when weights are provided locally).
    repo_id: str
    # Flow checkpoint filename within repo_id (unused in this build -- TODO confirm).
    repo_flow: str
    # Autoencoder checkpoint path (presumably relative to a local ckpts dir).
    repo_ae: str
    # When both lora_* fields are set, load_flow_model builds a FluxLoraWrapper.
    lora_repo_id: str | None = None
    lora_filename: str | None = None
|
|
|
|
|
|
|
|
def _ae_params() -> AutoEncoderParams:
    """Autoencoder hyperparameters; identical for every Flux variant below."""
    return AutoEncoderParams(
        resolution=256,
        in_channels=3,
        ch=128,
        out_ch=3,
        ch_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        z_channels=16,
        scale_factor=0.3611,
        shift_factor=0.1159,
    )


def _flux_params(**overrides) -> FluxParams:
    """Baseline Flux transformer hyperparameters.

    All variants share the same backbone; per-variant deltas (``in_channels``
    for control models, ``guidance_embed``, the ``chroma``/``eso`` flags) are
    supplied as keyword overrides.
    """
    kwargs: dict = dict(
        in_channels=64,
        out_channels=64,
        vec_in_dim=768,
        context_in_dim=4096,
        hidden_size=3072,
        mlp_ratio=4.0,
        num_heads=24,
        depth=19,
        depth_single_blocks=38,
        axes_dim=[16, 56, 56],
        theta=10_000,
        qkv_bias=True,
        guidance_embed=True,
    )
    kwargs.update(overrides)
    return FluxParams(**kwargs)


# Shared autoencoder checkpoint location (identical for every entry).
_AE_CKPT = "ckpts/flux_vae.safetensors"

# Registry of supported model variants.  Each entry previously spelled out the
# full parameter sets verbatim; they are now built from the factories above --
# the resulting values are unchanged.
configs = {
    "flux-dev": ModelSpec(
        repo_id="",
        repo_flow="",
        repo_ae=_AE_CKPT,
        params=_flux_params(),
        ae_params=_ae_params(),
    ),
    "flux-schnell": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-schnell",
        repo_flow="",
        repo_ae=_AE_CKPT,
        params=_flux_params(guidance_embed=False),
        ae_params=_ae_params(),
    ),
    "flux-chroma": ModelSpec(
        repo_id="lodestones/Chroma1-HD",
        repo_flow="",
        repo_ae=_AE_CKPT,
        params=_flux_params(guidance_embed=False, chroma=True),
        ae_params=_ae_params(),
    ),
    "flux-dev-canny": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-Canny-dev",
        repo_flow="",
        repo_ae=_AE_CKPT,
        params=_flux_params(in_channels=128),
        ae_params=_ae_params(),
    ),
    "flux-dev-canny-lora": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-dev",
        repo_flow="",
        repo_ae=_AE_CKPT,
        lora_repo_id="black-forest-labs/FLUX.1-Canny-dev-lora",
        lora_filename="flux1-canny-dev-lora.safetensors",
        params=_flux_params(in_channels=128),
        ae_params=_ae_params(),
    ),
    "flux-dev-depth": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-Depth-dev",
        repo_flow="",
        repo_ae=_AE_CKPT,
        params=_flux_params(in_channels=128),
        ae_params=_ae_params(),
    ),
    "flux-dev-depth-lora": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-dev",
        repo_flow="",
        repo_ae=_AE_CKPT,
        lora_repo_id="black-forest-labs/FLUX.1-Depth-dev-lora",
        lora_filename="flux1-depth-dev-lora.safetensors",
        params=_flux_params(in_channels=128),
        ae_params=_ae_params(),
    ),
    "flux-dev-redux": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-Redux-dev",
        repo_flow="",
        repo_ae=_AE_CKPT,
        params=_flux_params(),
        ae_params=_ae_params(),
    ),
    "flux-dev-fill": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-Fill-dev",
        repo_flow="",
        repo_ae=_AE_CKPT,
        params=_flux_params(in_channels=384),
        ae_params=_ae_params(),
    ),
    "flux-dev-kontext": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-Kontext-dev",
        repo_flow="",
        repo_ae=_AE_CKPT,
        params=_flux_params(),
        ae_params=_ae_params(),
    ),
    "flux-dev-uso": ModelSpec(
        repo_id="",
        repo_flow="",
        repo_ae=_AE_CKPT,
        params=_flux_params(eso=True),
        ae_params=_ae_params(),
    ),
    "flux-dev-umo": ModelSpec(
        repo_id="",
        repo_flow="",
        repo_ae=_AE_CKPT,
        params=_flux_params(eso=True),
        ae_params=_ae_params(),
    ),
    "flux-dev-kontext-dreamomni2": ModelSpec(
        repo_id="",
        repo_flow="",
        repo_ae=_AE_CKPT,
        params=_flux_params(eso=True),
        ae_params=_ae_params(),
    ),
}
|
|
|
|
|
|
|
|
# Resolution buckets for the Kontext model: every pair multiplies out to
# roughly one megapixel (~1024x1024) and each side is a multiple of 16.
# NOTE(review): the axis order ((w, h) vs (h, w)) is not evident from this
# file -- confirm against the caller.  The "PREFERED" spelling is kept as-is
# because external code may reference this name.
PREFERED_KONTEXT_RESOLUTIONS = [
    (672, 1568),
    (688, 1504),
    (720, 1456),
    (752, 1392),
    (800, 1328),
    (832, 1248),
    (880, 1184),
    (944, 1104),
    (1024, 1024),
    (1104, 944),
    (1184, 880),
    (1248, 832),
    (1328, 800),
    (1392, 752),
    (1456, 720),
    (1504, 688),
    (1568, 672),
]
|
|
|
|
|
|
|
|
def aspect_ratio_to_height_width(aspect_ratio: str, area: int = 1024**2) -> tuple[int, int]:
    """Turn a "W:H" aspect-ratio string into concrete pixel dimensions.

    The result covers roughly ``area`` pixels at the requested ratio, with
    each side rounded down to a multiple of 16.  Note: despite the function
    name, the returned tuple is (width, height).
    """
    parts = aspect_ratio.split(":")
    ratio = float(parts[0]) / float(parts[1])
    raw_w = round(math.sqrt(area * ratio))
    raw_h = round(math.sqrt(area / ratio))
    return 16 * (raw_w // 16), 16 * (raw_h // 16)
|
|
|
|
|
|
|
|
def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
    """Report missing/unexpected state-dict keys from a non-strict load."""

    def _report(label: str, keys: list[str]) -> None:
        # One key per indented line, matching the original output format.
        print(f"Got {len(keys)} {label} keys:\n\t" + "\n\t".join(keys))

    if missing and unexpected:
        _report("missing", missing)
        print("\n" + "-" * 79 + "\n")
        _report("unexpected", unexpected)
    elif missing:
        _report("missing", missing)
    elif unexpected:
        _report("unexpected", unexpected)
|
|
|
|
|
|
|
|
def load_flow_model(name: str, model_filename, device: str | torch.device = "cuda", verbose: bool = True) -> Flux:
    """Instantiate a Flux flow model and load its weights.

    The model is first constructed on the ``meta`` device (no memory is
    allocated), then materialized by mmgp's offload loader directly from
    ``model_filename``.

    Args:
        name: Key into the module-level ``configs`` registry.
        model_filename: Weight file location forwarded to
            ``mmgp.offload.load_model_data``.
        device: Kept for interface compatibility; unused in this code path.
        verbose: Kept for interface compatibility; unused in this code path.

    Returns:
        The loaded module in bfloat16 -- a ``FluxLoraWrapper`` when the config
        declares LoRA weights, otherwise a plain ``Flux``.
    """
    config = configs[name]

    # Build on "meta" so no real tensors are allocated before loading.
    with torch.device("meta"):
        if config.lora_repo_id is not None and config.lora_filename is not None:
            model = FluxLoraWrapper(params=config.params).to(torch.bfloat16)
        else:
            model = Flux(config.params).to(torch.bfloat16)

    # Imported lazily so this module can be imported without mmgp installed.
    from mmgp import offload
    offload.load_model_data(model, model_filename)

    return model
|
|
|
|
|
|
|
|
def load_t5(device: str | torch.device = "cuda", text_encoder_filename = None, max_length: int = 512) -> HFEmbedder:
    """Build the T5 text encoder from a local weight file and move it to *device*."""
    embedder = HFEmbedder("", text_encoder_filename, max_length=max_length, torch_dtype=torch.bfloat16)
    return embedder.to(device)
|
|
|
|
|
|
|
|
def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
    """Build the CLIP text encoder from locally installed weights on *device*."""
    clip = HFEmbedder(
        fl.locate_folder("clip_vit_large_patch14"),
        "",
        max_length=77,
        torch_dtype=torch.bfloat16,
        is_clip=True,
    )
    return clip.to(device)
|
|
|
|
|
|
|
|
def load_ae(name: str, device: str | torch.device = "cuda") -> AutoEncoder:
    """Instantiate the autoencoder for *name* and load its weights onto *device*.

    Missing/unexpected state-dict keys are reported via print_load_warning
    rather than raising (strict=False load).
    """
    spec = configs[name]
    ckpt_path = str(get_checkpoint_path(spec.repo_id, fl.locate_file("flux_vae.safetensors"), "FLUX_AE"))

    # Build on "meta" so no real tensors are allocated before loading.
    with torch.device("meta"):
        ae = AutoEncoder(spec.ae_params)

    state = load_sft(ckpt_path, device=str(device))
    missing, unexpected = ae.load_state_dict(state, strict=False, assign=True)
    print_load_warning(missing, unexpected)
    return ae
|
|
|
|
|
|
|
|
def optionally_expand_state_dict(model: torch.nn.Module, state_dict: dict) -> dict:
    """
    Optionally expand the state dict to match the model's parameters shapes.

    Entries whose shape differs from the corresponding model parameter are
    zero-padded up to the parameter's shape (modifying ``state_dict`` in
    place) so that a subsequent non-strict load can succeed.
    """
    for param_name, param in model.named_parameters():
        if param_name not in state_dict:
            continue
        stored = state_dict[param_name]
        if stored.shape == param.shape:
            continue
        print(
            f"Expanding '{param_name}' with shape {stored.shape} to model parameter with shape {param.shape}."
        )
        # Zero-filled tensor of the full parameter shape; copy the stored
        # values into its leading slice along every dimension.
        padded = torch.zeros_like(param, device=stored.device)
        padded[tuple(slice(0, dim) for dim in stored.shape)] = stored
        state_dict[param_name] = padded
    return state_dict
|
|
|
|
|
|
|
|
|