|
|
import os |
|
|
import sys |
|
|
|
|
|
import torch |
|
|
|
|
|
from diffusers import (FlowMatchEulerDiscreteScheduler) |
|
|
|
|
|
# Put the script's own directory and its two ancestors on sys.path,
# presumably so the `videox_fun` imports that follow resolve when the script
# is launched from a nested examples folder without installing the package.
current_file_path = os.path.abspath(__file__)
project_roots = [
    os.path.dirname(current_file_path),
    os.path.dirname(os.path.dirname(current_file_path)),
    os.path.dirname(os.path.dirname(os.path.dirname(current_file_path))),
]
for project_root in project_roots:
    # Each missing root is inserted at index 0, so the outermost ancestor that
    # gets added ends up first in the search order.
    if project_root not in sys.path:
        sys.path.insert(0, project_root)
|
|
|
|
|
from videox_fun.dist import set_multi_gpus_devices, shard_model |
|
|
from videox_fun.models import (AutoencoderKLFlux2, |
|
|
Mistral3ForConditionalGeneration, |
|
|
PixtralProcessor, Flux2Transformer2DModel) |
|
|
from videox_fun.models.cache_utils import get_teacache_coefficients |
|
|
from videox_fun.pipeline import Flux2Pipeline |
|
|
from videox_fun.utils.fm_solvers import FlowDPMSolverMultistepScheduler |
|
|
from videox_fun.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler |
|
|
from videox_fun.utils.fp8_optimization import (convert_model_weight_to_float8, |
|
|
convert_weight_dtype_wrapper) |
|
|
from videox_fun.utils.lora_utils import merge_lora, unmerge_lora |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Run configuration. Every name below is read by the script further down.
# ---------------------------------------------------------------------------

# Memory strategy. Values this script handles explicitly:
#   "sequential_cpu_offload", "model_cpu_offload",
#   "model_cpu_offload_and_qfloat8", "model_full_load_and_qfloat8".
# Any other value moves the whole pipeline onto the device.
GPU_memory_mode = "sequential_cpu_offload"

# Parallel degrees handed to set_multi_gpus_devices. The multi-GPU branch of
# the script is entered as soon as either one is greater than 1.
ulysses_degree = ring_degree = 1

# FSDP switches; they are only consulted inside the multi-GPU branch.
fsdp_dit = fsdp_text_encoder = False

# When True, each entry of the transformer's `transformer_blocks` is wrapped
# with torch.compile.
compile_dit = False

# Model directory; the transformer, vae, tokenizer, text_encoder and scheduler
# subfolders are loaded from it.
model_name = "models/Diffusion_Transformer/FLUX.2-dev"

# Scheduler key: "Flow", "Flow_Unipc" or "Flow_DPM++".
sampler_name = "Flow"

# Optional checkpoint / LoRA files; None skips the corresponding step.
transformer_path = vae_path = lora_path = None

# Passed to the pipeline as [height, width].
sample_size = [1344, 768]

# Precision used when loading and casting the models.
weight_dtype = torch.bfloat16

# Generation inputs.
prompt = "1girl, black_hair, brown_eyes, earrings, freckles, grey_background, jewelry, lips, long_hair, looking_at_viewer, nose, piercing, realistic, red_lips, solo, upper_body"
# NOTE: negative_prompt is defined here but is not passed to the pipeline
# call in this script.
negative_prompt = " "
guidance_scale = 4.0
seed = 43
num_inference_steps = 50
# Strength used for both merging and unmerging the LoRA.
lora_weight = 0.55
# Output directory for the generated PNG.
save_path = "samples/flux2-t2i"
|
|
|
|
|
# `set_multi_gpus_devices` yields the device handle used by every later step;
# whatever setup it performs lives in videox_fun.dist.
device = set_multi_gpus_devices(ulysses_degree, ring_degree)

# Diffusion transformer, read from the "transformer" subfolder of the model
# directory and requested directly in the working precision.
transformer = Flux2Transformer2DModel.from_pretrained(
    model_name,
    subfolder="transformer",
    low_cpu_mem_usage=True,
    torch_dtype=weight_dtype,
)
# Explicit cast kept from the original, applied after loading.
transformer = transformer.to(weight_dtype)
|
|
|
|
|
# Optionally overlay transformer weights from a standalone checkpoint file.
if transformer_path is not None:
    print(f"From checkpoint: {transformer_path}")
    if transformer_path.endswith("safetensors"):
        # Imported lazily so safetensors is only required for this format.
        from safetensors.torch import load_file

        state_dict = load_file(transformer_path)
    else:
        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code. Only point transformer_path at trusted checkpoints.
        state_dict = torch.load(transformer_path, map_location="cpu")
    # Unwrap checkpoints that nest the weights under a "state_dict" key.
    if "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]

    # strict=False: mismatching keys are reported below instead of raising.
    missing_keys, unexpected_keys = transformer.load_state_dict(state_dict, strict=False)
    print(f"missing keys: {len(missing_keys)}, unexpected keys: {len(unexpected_keys)}")
|
|
|
|
|
|
|
|
# VAE, read from the "vae" subfolder of the model directory and then cast to
# the working precision.
vae = AutoencoderKLFlux2.from_pretrained(model_name, subfolder="vae")
vae = vae.to(weight_dtype)
|
|
|
|
|
# Optionally overlay VAE weights from a standalone checkpoint file.
if vae_path is not None:
    print(f"From checkpoint: {vae_path}")
    if vae_path.endswith("safetensors"):
        # Imported lazily so safetensors is only required for this format.
        from safetensors.torch import load_file

        state_dict = load_file(vae_path)
    else:
        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code. Only point vae_path at trusted checkpoints.
        state_dict = torch.load(vae_path, map_location="cpu")
    # Unwrap checkpoints that nest the weights under a "state_dict" key.
    if "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]

    # strict=False: mismatching keys are reported below instead of raising.
    missing_keys, unexpected_keys = vae.load_state_dict(state_dict, strict=False)
    print(f"missing keys: {len(missing_keys)}, unexpected keys: {len(unexpected_keys)}")
|
|
|
|
|
|
|
|
# Prompt processor, read from the "tokenizer" subfolder.
tokenizer = PixtralProcessor.from_pretrained(model_name, subfolder="tokenizer")

# Text encoder, read from the "text_encoder" subfolder in the working precision.
text_encoder = Mistral3ForConditionalGeneration.from_pretrained(
    model_name,
    subfolder="text_encoder",
    torch_dtype=weight_dtype,
    low_cpu_mem_usage=True,
)
|
|
|
|
|
|
|
|
# Map the configured sampler name to its scheduler class; an unknown name
# raises KeyError here.
Chosen_Scheduler = {
    "Flow": FlowMatchEulerDiscreteScheduler,
    "Flow_Unipc": FlowUniPCMultistepScheduler,
    "Flow_DPM++": FlowDPMSolverMultistepScheduler,
}[sampler_name]
# NOTE: despite its name, `scheduler_dict` has always been bound to the
# selected class (not to the mapping); that binding is kept as-is.
scheduler_dict = Chosen_Scheduler

# Instantiate the scheduler from the "scheduler" subfolder of the model directory.
scheduler = Chosen_Scheduler.from_pretrained(model_name, subfolder="scheduler")
|
|
|
|
|
# Assemble the text-to-image pipeline from the components loaded above.
pipeline = Flux2Pipeline(
    vae=vae,
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    transformer=transformer,
    scheduler=scheduler,
)
|
|
|
|
|
# Multi-GPU setup: entered when either parallel degree is greater than 1.
if ulysses_degree > 1 or ring_degree > 1:
    from functools import partial

    transformer.enable_multi_gpus_inference()
    if fsdp_dit:
        # The modules handed to shard_model as `module_to_wrapper` are both
        # block lists of the transformer.
        shard_fn = partial(
            shard_model,
            device_id=device,
            param_dtype=weight_dtype,
            module_to_wrapper=list(transformer.transformer_blocks) + list(transformer.single_transformer_blocks),
        )
        pipeline.transformer = shard_fn(pipeline.transformer)
        print("Add FSDP DIT")
    if fsdp_text_encoder:
        shard_fn = partial(
            shard_model,
            device_id=device,
            param_dtype=weight_dtype,
            module_to_wrapper=text_encoder.language_model.layers,
        )
        text_encoder = shard_fn(text_encoder)
        # Store the result on the pipeline as well, mirroring the DiT branch.
        # Previously only the local name was rebound, so the pipeline kept
        # using the module it was constructed with.
        # NOTE(review): assumes shard_model returns the module that should be
        # called from now on - confirm against videox_fun.dist.
        pipeline.text_encoder = text_encoder
        print("Add FSDP TEXT ENCODER")
|
|
|
|
|
# Optional torch.compile pass over the transformer. Only the entries of
# `transformer_blocks` are replaced by their compiled versions here.
if compile_dit:
    dit_blocks = pipeline.transformer.transformer_blocks
    for block_idx in range(len(dit_blocks)):
        dit_blocks[block_idx] = torch.compile(dit_blocks[block_idx])
    print("Add Compile")
|
|
|
|
|
# Apply the configured memory strategy. Two independent decisions are made:
# whether the transformer weights are converted to float8 first, and how the
# pipeline is placed (sequential offload, model offload, or full load).
if GPU_memory_mode == "sequential_cpu_offload":
    pipeline.enable_sequential_cpu_offload(device=device)
else:
    if GPU_memory_mode in ("model_cpu_offload_and_qfloat8", "model_full_load_and_qfloat8"):
        # The listed module names are excluded from the float8 conversion.
        convert_model_weight_to_float8(transformer, exclude_module_name=["img_in", "txt_in", "timestep"], device=device)
        convert_weight_dtype_wrapper(transformer, weight_dtype)

    if GPU_memory_mode in ("model_cpu_offload_and_qfloat8", "model_cpu_offload"):
        pipeline.enable_model_cpu_offload(device=device)
    else:
        # "model_full_load_and_qfloat8" and every unrecognised mode end up here.
        pipeline.to(device=device)
|
|
|
|
|
# Seeded generator on the target device, so the run uses the configured seed.
generator = torch.Generator(device=device)
generator.manual_seed(seed)

# Merge the LoRA into the pipeline for the duration of the generation.
if lora_path is not None:
    pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype)

# Inference only: run the pipeline with gradient tracking disabled.
with torch.no_grad():
    output = pipeline(
        prompt=prompt,
        height=sample_size[0],
        width=sample_size[1],
        generator=generator,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
    )
sample = output.images

# Undo the LoRA merge with the same path and weight that were used above.
if lora_path is not None:
    pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype)
|
|
|
|
|
def save_results(output_dir=None, images=None):
    """Write the first generated image to disk as a zero-padded, numbered PNG.

    Args:
        output_dir: Directory to write into; created if missing. Defaults to
            the module-level ``save_path``.
        images: Sequence whose first item exposes ``save(path)``. Defaults to
            the module-level ``sample``.

    Returns:
        The path of the file that was written.
    """
    if output_dir is None:
        output_dir = save_path
    if images is None:
        images = sample

    # exist_ok=True already covers the "directory exists" case.
    os.makedirs(output_dir, exist_ok=True)

    # Start from "number of entries + 1" as before, then step past any name
    # that is already taken. Without this, a directory with gaps (e.g. after
    # deleting an earlier image) would have an existing file overwritten.
    index = len(os.listdir(output_dir)) + 1
    image_path = os.path.join(output_dir, str(index).zfill(8) + ".png")
    while os.path.exists(image_path):
        index += 1
        image_path = os.path.join(output_dir, str(index).zfill(8) + ".png")

    images[0].save(image_path)
    return image_path
|
|
|
|
|
# Single-process runs always save. In a multi-GPU run only rank 0 writes the
# image, so the processes do not all write the same output.
if ulysses_degree * ring_degree <= 1:
    save_results()
else:
    import torch.distributed as dist

    if dist.get_rank() == 0:
        save_results()