|
|
import os |
|
|
import sys |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
from diffusers import FlowMatchEulerDiscreteScheduler |
|
|
from omegaconf import OmegaConf |
|
|
from PIL import Image |
|
|
|
|
|
# Make the project importable when this script is run directly: the script's
# own directory and its two nearest ancestors are added to ``sys.path``.
current_file_path = os.path.abspath(__file__)

_script_dir = os.path.dirname(current_file_path)
_parent_dir = os.path.dirname(_script_dir)
_grandparent_dir = os.path.dirname(_parent_dir)
project_roots = [_script_dir, _parent_dir, _grandparent_dir]

# Each missing root is inserted at position 0, so the outermost ancestor ends
# up first on ``sys.path``. Roots that are already present are left untouched.
for project_root in project_roots:
    if project_root not in sys.path:
        sys.path.insert(0, project_root)
|
|
|
|
|
from videox_fun.dist import set_multi_gpus_devices, shard_model |
|
|
from videox_fun.models import (AutoencoderKLFlux2, |
|
|
Mistral3ForConditionalGeneration, |
|
|
PixtralProcessor, Flux2ControlTransformer2DModel) |
|
|
from videox_fun.models.cache_utils import get_teacache_coefficients |
|
|
from videox_fun.pipeline import Flux2ControlPipeline |
|
|
from videox_fun.utils.fm_solvers import FlowDPMSolverMultistepScheduler |
|
|
from videox_fun.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler |
|
|
from videox_fun.utils.fp8_optimization import (convert_model_weight_to_float8, |
|
|
convert_weight_dtype_wrapper) |
|
|
from videox_fun.utils.lora_utils import merge_lora, unmerge_lora |
|
|
from videox_fun.utils.utils import (filter_kwargs, get_image, get_image_latent, |
|
|
get_image_to_video_latent, |
|
|
get_video_to_video_latent, |
|
|
save_videos_grid) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Run configuration. Everything below is a plain module-level constant that
# the rest of this script reads.
# ---------------------------------------------------------------------------

# How the pipeline is placed on the device. Values recognised further down:
#   "sequential_cpu_offload", "model_cpu_offload",
#   "model_cpu_offload_and_qfloat8", "model_full_load_and_qfloat8".
# Any other value moves the whole pipeline onto the device.
GPU_memory_mode = "model_cpu_offload"

# Parallel degrees handed to set_multi_gpus_devices(). The multi-GPU code path
# is taken only when at least one of them is greater than 1.
ulysses_degree = 1
ring_degree = 1

# FSDP sharding switches; only consulted inside the multi-GPU code path.
fsdp_dit = False
fsdp_text_encoder = False

# When True, each entry of the transformer's `transformer_blocks` is wrapped
# with torch.compile.
compile_dit = False

# YAML config loaded with OmegaConf; supplies `transformer_additional_kwargs`.
config_path = "config/flux2/flux2_control.yaml"

# Base model directory. Its "transformer", "vae", "tokenizer", "text_encoder"
# and "scheduler" subfolders are loaded with from_pretrained().
model_name = "models/Diffusion_Transformer/FLUX.2-dev"

# Scheduler choice; one of "Flow", "Flow_Unipc", "Flow_DPM++".
sampler_name = "Flow"

# Optional checkpoints loaded (non-strictly) on top of the base weights.
# A value of None skips the corresponding load.
transformer_path = "models/Personalized_Model/FLUX.2-dev-Fun-Controlnet-Union.safetensors"
vae_path = None
# Optional LoRA; merged before inference and unmerged afterwards.
lora_path = None

# Output size as [height, width].
sample_size = [1728, 992]

# dtype used for the transformer, VAE and text encoder weights.
weight_dtype = torch.bfloat16

# Input images. `image` may be a single path or a list of paths; the others are
# single paths. None disables the respective input (inpaint/mask then fall back
# to constant tensors).
image = "asset/8.png"
control_image = "asset/pose.jpg"
inpaint_image = None
mask_image = None
# Forwarded to the pipeline as `control_context_scale`.
control_context_scale = 0.75

# Text prompt.
prompt = "This is a panoramic portrait photo of a young woman. She has flowing long hair and a soft lavender like color. She is wearing a white sleeveless dress with a blue ribbon bow tied around the collar. She has a confident posture, with her left hand naturally hanging down and her right hand in her pocket, and her legs slightly apart. Look straight at the camera. The sea breeze gently brushed her long hair, and they stood on the sunny seaside path, surrounded by blooming purple seaside flowers and smooth pebbles, with the sparkling sea and blue sky behind them. The screen presents a bright summer atmosphere, with soft and natural lighting, realistic details, and 8K ultra high definition image quality, clearly presenting fine textures such as clothing and hair. "
# NOTE(review): defined here but not passed to the pipeline call in this script.
negative_prompt = " "

# Sampling parameters.
guidance_scale = 4.00
seed = 43
num_inference_steps = 50

# Strength used when merging/unmerging the LoRA (ignored when lora_path is None).
lora_weight = 0.55

# Directory that receives the generated PNG files.
save_path = "samples/flux2-t2i-control"
|
|
|
|
|
# Device for this process, derived from the configured parallel degrees.
# It is later handed to the pipeline, the FSDP wrappers and the RNG.
device = set_multi_gpus_devices(ulysses_degree, ring_degree)

# Parsed YAML configuration (provides `transformer_additional_kwargs`).
config = OmegaConf.load(config_path)
|
|
|
|
|
def _load_checkpoint_weights(model, checkpoint_path):
    """Load the weights stored at ``checkpoint_path`` into ``model`` (non-strict).

    Files whose name ends with ``"safetensors"`` are read with
    ``safetensors.torch.load_file``; anything else goes through ``torch.load``
    on the CPU. If the loaded mapping has a ``"state_dict"`` entry, that entry
    is used as the state dict. The number of missing and unexpected keys
    reported by ``load_state_dict`` is printed.

    The loaded tensors are local to this helper, so they can be released as
    soon as the weights have been loaded rather than staying bound at module
    level for the rest of the run.
    """
    print(f"From checkpoint: {checkpoint_path}")
    if checkpoint_path.endswith("safetensors"):
        from safetensors.torch import load_file

        weights = load_file(checkpoint_path)
    else:
        # NOTE(review): torch.load unpickles the file; only point this at
        # checkpoints from a trusted source.
        weights = torch.load(checkpoint_path, map_location="cpu")
    weights = weights["state_dict"] if "state_dict" in weights else weights

    missing, unexpected = model.load_state_dict(weights, strict=False)
    print(f"missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")


# Transformer: base weights from `model_name`, built with the extra kwargs from
# the YAML config, then optionally overlaid with `transformer_path`.
transformer = Flux2ControlTransformer2DModel.from_pretrained(
    model_name,
    subfolder="transformer",
    low_cpu_mem_usage=True,
    torch_dtype=weight_dtype,
    transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']),
).to(weight_dtype)

if transformer_path is not None:
    _load_checkpoint_weights(transformer, transformer_path)

# VAE: base weights from `model_name`, optionally overlaid with `vae_path`.
vae = AutoencoderKLFlux2.from_pretrained(
    model_name,
    subfolder="vae"
).to(weight_dtype)

if vae_path is not None:
    _load_checkpoint_weights(vae, vae_path)
|
|
|
|
|
|
|
|
# Tokenizer/processor and text encoder both come from subfolders of the base
# model directory; the text encoder is loaded directly in `weight_dtype`.
tokenizer = PixtralProcessor.from_pretrained(model_name, subfolder="tokenizer")

text_encoder = Mistral3ForConditionalGeneration.from_pretrained(
    model_name,
    subfolder="text_encoder",
    low_cpu_mem_usage=True,
    torch_dtype=weight_dtype,
)
|
|
|
|
|
|
|
|
# Map each supported `sampler_name` to its scheduler class. An unknown name
# raises KeyError at the lookup below.
_SCHEDULER_CLASSES = {
    "Flow": FlowMatchEulerDiscreteScheduler,
    "Flow_Unipc": FlowUniPCMultistepScheduler,
    "Flow_DPM++": FlowDPMSolverMultistepScheduler,
}

# `scheduler_dict` is kept as an alias of the selected class (not the table),
# matching the binding this script has always produced.
Chosen_Scheduler = scheduler_dict = _SCHEDULER_CLASSES[sampler_name]

# Instantiate the chosen scheduler from the base model's "scheduler" subfolder.
scheduler = Chosen_Scheduler.from_pretrained(model_name, subfolder="scheduler")
|
|
|
|
|
# Assemble the control pipeline from the components loaded above.
_pipeline_components = {
    "vae": vae,
    "tokenizer": tokenizer,
    "text_encoder": text_encoder,
    "transformer": transformer,
    "scheduler": scheduler,
}
pipeline = Flux2ControlPipeline(**_pipeline_components)
|
|
|
|
|
# Multi-GPU setup: only active when a parallel degree greater than 1 is set.
if ulysses_degree > 1 or ring_degree > 1:
    from functools import partial

    transformer.enable_multi_gpus_inference()

    if fsdp_dit:
        # Shard the transformer, wrapping both block lists.
        shard_fn = partial(
            shard_model,
            device_id=device,
            param_dtype=weight_dtype,
            module_to_wrapper=list(transformer.transformer_blocks) + list(transformer.single_transformer_blocks),
        )
        pipeline.transformer = shard_fn(pipeline.transformer)
        print("Add FSDP DIT")

    if fsdp_text_encoder:
        # Shard the text encoder's language-model layers; the token embedding
        # is passed as an ignored module.
        shard_fn = partial(
            shard_model,
            device_id=device,
            param_dtype=weight_dtype,
            module_to_wrapper=text_encoder.language_model.layers,
            ignored_modules=[text_encoder.language_model.embed_tokens],
            transformer_layer_cls_to_wrap=["MistralDecoderLayer", "PixtralTransformer"],
        )
        text_encoder = shard_fn(text_encoder)
        # Hand the sharded module to the pipeline, mirroring the DiT branch.
        # Previously only the local name was rebound and never read again, so
        # the pipeline kept its reference to the unwrapped module.
        # NOTE(review): assumes shard_model returns a wrapper that must be
        # called in place of the original module — confirm against
        # videox_fun.dist.shard_model.
        pipeline.text_encoder = text_encoder
        print("Add FSDP TEXT ENCODER")
|
|
|
|
|
# Optionally wrap every entry of the transformer's `transformer_blocks` with
# torch.compile, replacing each block in place.
# NOTE(review): `single_transformer_blocks` (referenced by the FSDP setup) are
# not compiled here — confirm that is intentional.
if compile_dit:
    _dit_blocks = pipeline.transformer.transformer_blocks
    for _block_index, _block in enumerate(_dit_blocks):
        _dit_blocks[_block_index] = torch.compile(_block)
    print("Add Compile")
|
|
|
|
|
# Apply the requested memory strategy.
#   * "sequential_cpu_offload"  -> sequential offload only.
#   * "*_and_qfloat8" modes     -> convert transformer weights to float8 first
#                                  (skipping the listed modules), then wrap it
#                                  for `weight_dtype`.
#   * "model_cpu_offload*"      -> model-level CPU offload.
#   * anything else             -> move the whole pipeline to the device.
_FLOAT8_MODES = ("model_cpu_offload_and_qfloat8", "model_full_load_and_qfloat8")
_MODEL_OFFLOAD_MODES = ("model_cpu_offload_and_qfloat8", "model_cpu_offload")

if GPU_memory_mode == "sequential_cpu_offload":
    pipeline.enable_sequential_cpu_offload(device=device)
else:
    if GPU_memory_mode in _FLOAT8_MODES:
        convert_model_weight_to_float8(
            transformer,
            exclude_module_name=["img_in", "txt_in", "timestep"],
            device=device,
        )
        convert_weight_dtype_wrapper(transformer, weight_dtype)

    if GPU_memory_mode in _MODEL_OFFLOAD_MODES:
        pipeline.enable_model_cpu_offload(device=device)
    else:
        pipeline.to(device=device)
|
|
|
|
|
# Seeded RNG on the target device so runs are repeatable for a given `seed`.
generator = torch.Generator(device=device)
generator = generator.manual_seed(seed)

# Optionally merge the LoRA into the pipeline before inference; it is unmerged
# again after sampling.
if lora_path is not None:
    pipeline = merge_lora(
        pipeline,
        lora_path,
        lora_weight,
        device=device,
        dtype=weight_dtype,
    )
|
|
|
|
|
# Prepare the image inputs and run the pipeline without tracking gradients.
with torch.no_grad():
    # Reference image(s): a single path or a list of paths, each loaded with
    # get_image(). None is passed through unchanged.
    if image is not None:
        if isinstance(image, list):
            image = [get_image(_image) for _image in image]
        else:
            image = get_image(image)

    # Inpaint input: loaded from disk when given, otherwise an all-zero
    # 3-channel tensor of the output size.
    # NOTE(review): the `[:, :, 0]` index presumably drops a frame axis
    # returned by get_image_latent — confirm against that helper.
    if inpaint_image is None:
        inpaint_image = torch.zeros([1, 3, sample_size[0], sample_size[1]])
    else:
        inpaint_image = get_image_latent(inpaint_image, sample_size=sample_size)[:, :, 0]

    # Mask input: first channel of the loaded mask when given, otherwise a
    # single-channel tensor filled with 255.
    if mask_image is None:
        mask_image = torch.ones([1, 1, sample_size[0], sample_size[1]]) * 255
    else:
        mask_image = get_image_latent(mask_image, sample_size=sample_size)[:, :1, 0]

    # Control input is optional; None is passed through unchanged.
    if control_image is not None:
        control_image = get_image_latent(control_image, sample_size=sample_size)[:, :, 0]

    # `sample_size` is [height, width].
    _pipeline_output = pipeline(
        prompt=prompt,
        height=sample_size[0],
        width=sample_size[1],
        generator=generator,
        guidance_scale=guidance_scale,
        image=image,
        inpaint_image=inpaint_image,
        mask_image=mask_image,
        control_image=control_image,
        num_inference_steps=num_inference_steps,
        control_context_scale=control_context_scale,
    )
    sample = _pipeline_output.images
|
|
|
|
|
# Undo the LoRA merge performed before inference, using the same path and
# weight, so the pipeline weights are restored.
if lora_path is not None:
    pipeline = unmerge_lora(
        pipeline,
        lora_path,
        lora_weight,
        device=device,
        dtype=weight_dtype,
    )
|
|
|
|
|
def save_results():
    """Write the first generated image into ``save_path`` as a numbered PNG.

    Reads the module-level ``save_path`` (output directory, created if needed)
    and ``sample`` (sequence of images; only ``sample[0]`` is saved, via its
    ``save`` method).

    The file name is an 8-digit, zero-padded index followed by ``.png``. The
    index starts at "number of entries already in the directory + 1" and is
    advanced past any name that is already taken, so an existing file is never
    overwritten (this matters when the directory has gaps or unrelated files).
    """
    # exist_ok=True already covers the "directory exists" case.
    os.makedirs(save_path, exist_ok=True)

    index = len(os.listdir(save_path)) + 1
    image_path = os.path.join(save_path, f"{index:08d}.png")
    while os.path.exists(image_path):
        index += 1
        image_path = os.path.join(save_path, f"{index:08d}.png")

    sample[0].save(image_path)
|
|
|
|
|
# Persist the result. In a multi-GPU run only rank 0 writes the file; a
# single-process run always writes it.
_is_multi_gpu_run = ulysses_degree * ring_degree > 1

if not _is_multi_gpu_run:
    save_results()
else:
    import torch.distributed as dist

    if dist.get_rank() == 0:
        save_results()