# Standard-library and third-party dependencies of this inference script.
import os
import sys

import numpy as np
import torch
from diffusers import FlowMatchEulerDiscreteScheduler
from omegaconf import OmegaConf
from PIL import Image

# Put this script's directory and its two nearest ancestor directories on
# sys.path. Presumably this lets the project-local `videox_fun` package resolve
# when the script is launched from inside the repository tree.
current_file_path = os.path.abspath(__file__)
project_roots = [
    os.path.dirname(current_file_path),
    os.path.dirname(os.path.dirname(current_file_path)),
    os.path.dirname(os.path.dirname(os.path.dirname(current_file_path))),
]
for project_root in project_roots:
    # Prepend only when absent so the search path never accumulates duplicates.
    if project_root not in sys.path:
        sys.path.insert(0, project_root)

# Project-local imports (resolved through the sys.path entries added earlier
# in this file).
from videox_fun.dist import set_multi_gpus_devices, shard_model
from videox_fun.models import (AutoencoderKLFlux2,
                               Flux2ControlTransformer2DModel,
                               Mistral3ForConditionalGeneration,
                               PixtralProcessor)
from videox_fun.models.cache_utils import get_teacache_coefficients
from videox_fun.pipeline import Flux2ControlPipeline
from videox_fun.utils.fm_solvers import FlowDPMSolverMultistepScheduler
from videox_fun.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from videox_fun.utils.fp8_optimization import (convert_model_weight_to_float8,
                                               convert_weight_dtype_wrapper)
from videox_fun.utils.lora_utils import merge_lora, unmerge_lora
from videox_fun.utils.utils import (filter_kwargs, get_image, get_image_latent,
                                    get_image_to_video_latent,
                                    get_video_to_video_latent,
                                    save_videos_grid)

# ---------------------------------------------------------------------------
# Run configuration: edit the values below to change what the script does.
# ---------------------------------------------------------------------------

# How the pipeline is placed on the device. The script recognises
# "sequential_cpu_offload", "model_cpu_offload",
# "model_cpu_offload_and_qfloat8" and "model_full_load_and_qfloat8";
# any other value moves the whole pipeline to the device unchanged.
GPU_memory_mode = "model_cpu_offload"

# Parallelism degrees handed to set_multi_gpus_devices(). Multi-GPU inference
# is enabled when either value is greater than 1.
ulysses_degree = 1
ring_degree = 1

# When multi-GPU inference is enabled, shard the transformer / the text
# encoder with shard_model.
fsdp_dit = False
fsdp_text_encoder = False

# Wrap every entry of pipeline.transformer.transformer_blocks in torch.compile.
compile_dit = False

# YAML file loaded with OmegaConf; its `transformer_additional_kwargs` entry is
# forwarded to the transformer's from_pretrained call.
config_path = "config/flux2/flux2_control.yaml"

# Base model location. The "transformer", "vae", "tokenizer", "text_encoder"
# and "scheduler" subfolders are loaded from it.
model_name = "models/Diffusion_Transformer/FLUX.2-dev"

# Scheduler key: one of "Flow", "Flow_Unipc", "Flow_DPM++".
sampler_name = "Flow"

# Optional checkpoints loaded non-strictly over the base weights (None skips).
transformer_path = "models/Personalized_Model/FLUX.2-dev-Fun-Controlnet-Union.safetensors"
vae_path = None
# Optional LoRA merged into the pipeline before sampling and unmerged after.
lora_path = None

# Output size as [height, width] (index 0 is passed as `height`, index 1 as
# `width` to the pipeline).
sample_size = [1728, 992]

# dtype used when loading / casting the transformer, VAE and text encoder.
weight_dtype = torch.bfloat16

# Optional conditioning inputs; each is only loaded when it is not None.
# `image` may be a single value or a list (each element goes through get_image).
image = None
control_image = None
inpaint_image = "asset/8.png"
mask_image = "asset/mask.png"
# Forwarded to the pipeline as `control_context_scale`.
control_context_scale = 0.75

prompt = "This is a panoramic portrait photo of a young man. He has medium-length flowing hair with a soft lavender-like color. He is wearing a white sleeveless shirt with a blue ribbon bow tied around the collar. He has a confident posture, with his left hand naturally hanging down and his right hand in his pocket, and his legs slightly apart. He looks straight at the camera. The sea breeze gently brushes his hair as he stands on the sunny seaside path, surrounded by blooming purple seaside flowers and smooth pebbles, with the sparkling sea and blue sky behind him. The scene presents a bright summer atmosphere, with soft and natural lighting, realistic details, and 8K ultra high definition image quality, clearly presenting fine textures such as clothing and hair."
# NOTE(review): negative_prompt is defined here but is not among the arguments
# passed to the pipeline call in this script — confirm whether that is intended.
negative_prompt = " "
guidance_scale = 4.00
# Seed for the torch.Generator handed to the pipeline.
seed = 43
num_inference_steps = 50
# Strength used when merging / unmerging the LoRA (only when lora_path is set).
lora_weight = 0.55
# Directory that receives the numbered output PNG.
save_path = "samples/flux2-t2i-control"

# Resolve the target device from the two parallelism degrees and read the
# YAML model configuration.
device = set_multi_gpus_devices(ulysses_degree, ring_degree)
config = OmegaConf.load(config_path)

# Build the control transformer from the base model's "transformer" subfolder,
# forwarding the extra constructor arguments declared in the YAML config.
transformer = Flux2ControlTransformer2DModel.from_pretrained(
    model_name,
    subfolder="transformer",
    low_cpu_mem_usage=True,
    torch_dtype=weight_dtype,
    transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']),
).to(weight_dtype)

# Optionally overlay a separate checkpoint on top of the base weights.
if transformer_path is not None:
    print(f"From checkpoint: {transformer_path}")
    if transformer_path.endswith("safetensors"):
        from safetensors.torch import load_file
        state_dict = load_file(transformer_path)
    else:
        # NOTE(security): torch.load unpickles the file and can execute
        # arbitrary code; only point transformer_path at trusted checkpoints.
        state_dict = torch.load(transformer_path, map_location="cpu")
    # Accept checkpoints that nest the weights under a "state_dict" key.
    state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict

    # strict=False tolerates key mismatches; the counts are reported instead.
    m, u = transformer.load_state_dict(state_dict, strict=False)
    print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")

# Load the VAE from the base model's "vae" subfolder and cast it to the
# working dtype.
vae = AutoencoderKLFlux2.from_pretrained(
    model_name,
    subfolder="vae"
).to(weight_dtype)

# Optionally overlay a separate VAE checkpoint on top of the base weights.
if vae_path is not None:
    print(f"From checkpoint: {vae_path}")
    if vae_path.endswith("safetensors"):
        from safetensors.torch import load_file
        state_dict = load_file(vae_path)
    else:
        # NOTE(security): torch.load unpickles the file and can execute
        # arbitrary code; only point vae_path at trusted checkpoints.
        state_dict = torch.load(vae_path, map_location="cpu")
    # Accept checkpoints that nest the weights under a "state_dict" key.
    state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict

    # strict=False tolerates key mismatches; the counts are reported instead.
    m, u = vae.load_state_dict(state_dict, strict=False)
    print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")

# Load the processor (used as the pipeline's tokenizer) and the text encoder
# from their subfolders of the base model.
tokenizer = PixtralProcessor.from_pretrained(
    model_name,
    subfolder="tokenizer",
)
text_encoder = Mistral3ForConditionalGeneration.from_pretrained(
    model_name,
    subfolder="text_encoder",
    torch_dtype=weight_dtype,
    low_cpu_mem_usage=True,
)

# Map each supported sampler_name to its scheduler class. An unsupported
# sampler_name raises KeyError at the lookup below.
scheduler_dict = {
    "Flow": FlowMatchEulerDiscreteScheduler,
    "Flow_Unipc": FlowUniPCMultistepScheduler,
    "Flow_DPM++": FlowDPMSolverMultistepScheduler,
}
Chosen_Scheduler = scheduler_dict[sampler_name]

# Instantiate the chosen scheduler from the base model's "scheduler" subfolder.
scheduler = Chosen_Scheduler.from_pretrained(
    model_name,
    subfolder="scheduler",
)

# Assemble the control pipeline from the components loaded so far.
pipeline = Flux2ControlPipeline(
    vae=vae,
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    transformer=transformer,
    scheduler=scheduler,
)

# Multi-GPU inference: enabled when either parallelism degree exceeds 1.
if ulysses_degree > 1 or ring_degree > 1:
    from functools import partial

    transformer.enable_multi_gpus_inference()

    if fsdp_dit:
        # Shard the transformer; both block lists are passed as the modules to
        # wrap.
        shard_fn = partial(
            shard_model,
            device_id=device,
            param_dtype=weight_dtype,
            module_to_wrapper=list(transformer.transformer_blocks) + list(transformer.single_transformer_blocks),
        )
        pipeline.transformer = shard_fn(pipeline.transformer)
        print("Add FSDP DIT")

    if fsdp_text_encoder:
        shard_fn = partial(
            shard_model,
            device_id=device,
            param_dtype=weight_dtype,
            module_to_wrapper=text_encoder.language_model.layers,
            ignored_modules=[text_encoder.language_model.embed_tokens],
            transformer_layer_cls_to_wrap=["MistralDecoderLayer", "PixtralTransformer"],
        )
        text_encoder = shard_fn(text_encoder)
        # Hand the sharded module to the pipeline as well, mirroring the DiT
        # branch; rebinding only the local name would leave the pipeline
        # holding the module it was constructed with.
        # NOTE(review): assumes shard_model returns a new wrapper object
        # rather than modifying its argument in place — confirm.
        pipeline.text_encoder = text_encoder
        print("Add FSDP TEXT ENCODER")

# Optionally wrap each entry of transformer_blocks in torch.compile.
if compile_dit:
    dit_blocks = pipeline.transformer.transformer_blocks
    # Iterate over a snapshot because the container is updated in place.
    for i, dit_block in enumerate(list(dit_blocks)):
        dit_blocks[i] = torch.compile(dit_block)
    print("Add Compile")

# Apply the selected memory strategy.
# Both "*_and_qfloat8" modes first convert the transformer weights to float8
# (leaving the listed modules out) and install the dtype wrapper, then place
# the pipeline the same way as their non-quantised counterpart.
if GPU_memory_mode in ("model_cpu_offload_and_qfloat8", "model_full_load_and_qfloat8"):
    convert_model_weight_to_float8(transformer, exclude_module_name=["img_in", "txt_in", "timestep"], device=device)
    convert_weight_dtype_wrapper(transformer, weight_dtype)

if GPU_memory_mode == "sequential_cpu_offload":
    pipeline.enable_sequential_cpu_offload(device=device)
elif GPU_memory_mode in ("model_cpu_offload_and_qfloat8", "model_cpu_offload"):
    pipeline.enable_model_cpu_offload(device=device)
else:
    # "model_full_load_and_qfloat8" and every unrecognised value: move the
    # whole pipeline to the device.
    pipeline.to(device=device)

# Seeded generator on the target device, passed to the pipeline call.
generator = torch.Generator(device=device).manual_seed(seed)

# Merge the optional LoRA into the pipeline before sampling.
if lora_path is not None:
    pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype)

# Prepare the conditioning inputs and run the pipeline without autograd.
with torch.no_grad():
    # Reference image(s): a list is loaded element-wise, anything else as a
    # single item.
    if image is not None:
        if isinstance(image, list):
            image = [get_image(_image) for _image in image]
        else:
            image = get_image(image)

    height, width = sample_size[0], sample_size[1]

    # Each loaded latent is indexed at position 0 of its third axis —
    # presumably dropping a singleton frame axis; confirm against
    # get_image_latent.
    if inpaint_image is not None:
        inpaint_image = get_image_latent(inpaint_image, sample_size=sample_size)[:, :, 0]
    else:
        # No inpaint source: all-zero 3-channel tensor at the output size.
        inpaint_image = torch.zeros([1, 3, height, width])

    if mask_image is not None:
        # Only the first channel of the loaded mask is kept.
        mask_image = get_image_latent(mask_image, sample_size=sample_size)[:, :1, 0]
    else:
        # No mask: single-channel tensor filled with 255 at the output size.
        mask_image = torch.ones([1, 1, height, width]) * 255

    if control_image is not None:
        control_image = get_image_latent(control_image, sample_size=sample_size)[:, :, 0]

    sample = pipeline(
        prompt=prompt,
        height=height,
        width=width,
        generator=generator,
        guidance_scale=guidance_scale,
        image=image,
        inpaint_image=inpaint_image,
        mask_image=mask_image,
        control_image=control_image,
        num_inference_steps=num_inference_steps,
        control_context_scale=control_context_scale,
    ).images

# Undo the LoRA merge after sampling, using the same path and weight that
# were used to merge it.
if lora_path is not None:
    pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype)

def save_results():
    """Write the first generated image to ``save_path`` as a numbered PNG.

    Reads the module-level ``save_path`` (output directory, created if it
    does not exist) and ``sample`` (the pipeline's image list; only element 0
    is saved). The file name is an 8-digit zero-padded index starting from
    the number of entries already in the directory plus one; if that name is
    already taken the index is advanced until a free name is found, so an
    existing file is never overwritten.
    """
    # exist_ok=True makes a separate existence check unnecessary.
    os.makedirs(save_path, exist_ok=True)

    index = len(os.listdir(save_path)) + 1
    image_path = os.path.join(save_path, str(index).zfill(8) + ".png")
    # The entry count can lag behind the highest used index (e.g. after files
    # were deleted), so skip forward past any name that already exists.
    while os.path.exists(image_path):
        index += 1
        image_path = os.path.join(save_path, str(index).zfill(8) + ".png")

    sample[0].save(image_path)

# Save the result once: in a multi-GPU run only rank 0 writes the file,
# otherwise the single process does.
if ulysses_degree * ring_degree > 1:
    import torch.distributed as dist
    is_saving_process = dist.get_rank() == 0
else:
    is_saving_process = True

if is_saving_process:
    save_results()