| | import os |
| | import sys |
| | import random |
| | from pathlib import Path |
| | from typing import Optional, Tuple |
| |
|
| | import gradio as gr |
| | import torch |
| | from PIL import Image |
| | from diffusers import FlowMatchEulerDiscreteScheduler |
| | from omegaconf import OmegaConf |
| | from safetensors.torch import load_file |
| |
|
| | |
# Resolve this script's absolute location and make its directory plus the two
# directories above it importable (presumably so the `videox_fun` package can
# be found regardless of the current working directory).
CURRENT_FILE = Path(__file__).resolve()
PROJECT_ROOTS = [
    CURRENT_FILE.parent,
    CURRENT_FILE.parent.parent,
    CURRENT_FILE.parent.parent.parent,
]
for root in PROJECT_ROOTS:
    root_str = str(root)
    if root_str in sys.path:
        continue
    # Each newly added entry is prepended, so directories handled later in the
    # loop take precedence over the ones handled earlier.
    sys.path.insert(0, root_str)
# The outermost of the three directories is treated as the repository root.
REPO_ROOT = PROJECT_ROOTS[-1]
| |
|
| | from videox_fun.models import ( |
| | AutoencoderKL, |
| | AutoTokenizer, |
| | Qwen3ForCausalLM, |
| | ZImageControlTransformer2DModel, |
| | ) |
| | from videox_fun.pipeline import ZImageControlPipeline |
| | from videox_fun.utils.fm_solvers import FlowDPMSolverMultistepScheduler |
| | from videox_fun.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler |
| |
|
| | |
# Locations of the pipeline config, the base model directory, the control
# checkpoint and the fallback pose image, all resolved relative to REPO_ROOT.
CONFIG_PATH = REPO_ROOT.joinpath("config", "z_image", "z_image_control.yaml")
MODEL_NAME = REPO_ROOT.joinpath("models", "Diffusion_Transformer", "Z-Image-Turbo")
TRANSFORMER_CKPT = REPO_ROOT.joinpath(
    "models",
    "Personalized_Model",
    "Z-Image-Turbo-Fun-Controlnet-Union.safetensors",
)
DEFAULT_POSE_PATH = REPO_ROOT.joinpath("asset", "pose_1024x1024.png")

# Scheduler classes addressable by name; DEFAULT_SAMPLER is the key looked up
# when the pipeline is constructed.
SAMPLERS = {
    "Flow": FlowMatchEulerDiscreteScheduler,
    "Flow_Unipc": FlowUniPCMultistepScheduler,
    "Flow_DPM++": FlowDPMSolverMultistepScheduler,
}
DEFAULT_SAMPLER = "Flow"
| |
|
# Module-level cache for the loaded pipeline together with the device and
# dtype it was built with. All three stay None until `_load_pipeline()` has
# run once; afterwards that function returns these cached values.
PIPELINE: Optional[ZImageControlPipeline] = None
PIPELINE_DEVICE: Optional[torch.device] = None
PIPELINE_DTYPE: Optional[torch.dtype] = None
| |
|
| |
|
| | def _pick_dtype() -> torch.dtype: |
| | if torch.cuda.is_available(): |
| | if torch.cuda.is_bf16_supported(): |
| | return torch.bfloat16 |
| | return torch.float16 |
| | return torch.float32 |
| |
|
| |
|
def _read_transformer_checkpoint(ckpt_path: Path):
    """Read a transformer checkpoint file and return its state dict.

    Files with a ``.safetensors`` suffix are read with
    ``safetensors.torch.load_file``; anything else goes through ``torch.load``
    mapped to CPU. If the loaded mapping contains a ``"state_dict"`` key, the
    value stored under that key is returned instead of the outer mapping.
    """
    if ckpt_path.suffix == ".safetensors":
        state_dict = load_file(ckpt_path)
    else:
        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code. Only point TRANSFORMER_CKPT at trusted checkpoints,
        # or consider passing weights_only=True once confirmed compatible.
        state_dict = torch.load(ckpt_path, map_location="cpu")
    if "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]
    return state_dict


def _load_pipeline() -> Tuple[ZImageControlPipeline, torch.device, torch.dtype]:
    """Build the Z-Image control pipeline once and return the cached instance.

    On the first call the transformer, VAE, tokenizer, text encoder and
    scheduler are loaded from subfolders of ``MODEL_NAME``; if
    ``TRANSFORMER_CKPT`` exists its weights are loaded into the transformer
    non-strictly. The assembled pipeline is moved to ``cuda:0`` when CUDA is
    available (CPU otherwise) and stored in the module-level ``PIPELINE*``
    globals. Later calls return the cached values without reloading.

    Returns:
        A ``(pipeline, device, dtype)`` tuple.
    """
    global PIPELINE, PIPELINE_DEVICE, PIPELINE_DTYPE
    if PIPELINE is not None:
        return PIPELINE, PIPELINE_DEVICE, PIPELINE_DTYPE

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    dtype = _pick_dtype()

    config = OmegaConf.load(CONFIG_PATH)

    transformer = ZImageControlTransformer2DModel.from_pretrained(
        MODEL_NAME,
        subfolder="transformer",
        low_cpu_mem_usage=True,
        torch_dtype=dtype,
        transformer_additional_kwargs=OmegaConf.to_container(config["transformer_additional_kwargs"]),
    ).to(dtype)

    if TRANSFORMER_CKPT.exists():
        state_dict = _read_transformer_checkpoint(TRANSFORMER_CKPT)
        # strict=False: key mismatches are tolerated and only their counts
        # are reported.
        missing, unexpected = transformer.load_state_dict(state_dict, strict=False)
        print(f"[load] transformer ckpt loaded, missing={len(missing)}, unexpected={len(unexpected)}")
    else:
        print(f"[warn] transformer checkpoint not found at {TRANSFORMER_CKPT}, using base weights")

    vae = AutoencoderKL.from_pretrained(MODEL_NAME, subfolder="vae").to(dtype)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, subfolder="tokenizer")
    text_encoder = Qwen3ForCausalLM.from_pretrained(
        MODEL_NAME,
        subfolder="text_encoder",
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
    )

    scheduler_cls = SAMPLERS.get(DEFAULT_SAMPLER, FlowMatchEulerDiscreteScheduler)
    scheduler = scheduler_cls.from_pretrained(MODEL_NAME, subfolder="scheduler")

    pipe = ZImageControlPipeline(
        vae=vae,
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        transformer=transformer,
        scheduler=scheduler,
    )

    # `device` already encodes the CUDA-vs-CPU decision, so a single call
    # covers both cases.
    pipe.to(device)

    PIPELINE = pipe
    PIPELINE_DEVICE = device
    PIPELINE_DTYPE = dtype
    return pipe, device, dtype
| |
|
| |
|
def _ensure_pose_image(pose_image: Optional[Image.Image]) -> Image.Image:
    """Return an RGB pose image, falling back to the default pose file.

    Args:
        pose_image: The supplied pose image, or ``None`` if there is none.

    Returns:
        The image at ``DEFAULT_POSE_PATH`` converted to RGB when
        ``pose_image`` is ``None``; otherwise ``pose_image`` itself if it is
        already RGB, or an RGB-converted copy of it.
    """
    if pose_image is None:
        default_pose = Image.open(DEFAULT_POSE_PATH)
        return default_pose.convert("RGB")
    return pose_image if pose_image.mode == "RGB" else pose_image.convert("RGB")
| |
|
| |
|
| | def _align_size(value: int) -> int: |
| | |
| | return max(256, (value // 16) * 16) |
| |
|
| |
|
def infer(
    prompt: str,
    negative_prompt: str,
    pose_image: Optional[Image.Image],
    height: int,
    width: int,
    steps: int,
    guidance_scale: float,
    control_strength: float,
    seed: int,
):
    """Generate one image from a prompt and a pose control image.

    Args:
        prompt: Text prompt; must contain non-whitespace characters.
        negative_prompt: Negative prompt forwarded to the pipeline.
        pose_image: Control image; ``None`` selects the default pose image.
        height: Requested height, aligned via ``_align_size``.
        width: Requested width, aligned via ``_align_size``.
        steps: Number of inference steps.
        guidance_scale: Forwarded to the pipeline as ``guidance_scale``.
        control_strength: Forwarded to the pipeline as
            ``control_context_scale``.
        seed: Generator seed; ``None`` or a negative value picks a random
            seed in ``[1, 2**31 - 1]``.

    Returns:
        A ``(image, seed)`` tuple: the first generated image and the seed
        that was actually used.

    Raises:
        gr.Error: If the prompt is empty or whitespace-only.
    """
    # Validate first so an invalid request fails before any model work.
    if not prompt or not prompt.strip():
        raise gr.Error("提示词不能为空")

    pipe, device, _ = _load_pipeline()

    pose_image = _ensure_pose_image(pose_image)
    # UI widgets may deliver numbers as floats; the pipeline is given ints.
    height = _align_size(int(height))
    width = _align_size(int(width))

    if seed is None or seed < 0:
        seed = random.randint(1, 2**31 - 1)
    # manual_seed requires an int.
    seed = int(seed)
    generator = torch.Generator(device=device).manual_seed(seed)

    with torch.inference_mode():
        result = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=height,
            width=width,
            num_inference_steps=int(steps),
            guidance_scale=float(guidance_scale),
            generator=generator,
            control_image=pose_image,
            control_context_scale=float(control_strength),
            max_sequence_length=128,
        ).images[0]

    return result, seed
| |
|
| |
|
def build_ui():
    """Assemble the Gradio Blocks interface and wire the button to ``infer``.

    The left column holds the prompt boxes, sampling sliders, seed field and
    the run button; the right column holds the pose upload, the result image
    and the field showing the seed that was used.

    Returns:
        The constructed ``gr.Blocks`` app (not yet launched).
    """
    css = """
    .compact-slider {padding-top: 4px; padding-bottom: 4px;}
    """

    with gr.Blocks(title="Z-Image Turbo 文生图 (Pose)", css=css) as demo:
        gr.Markdown("## Z-Image Turbo 文生图 (含 Pose 控制)")
        gr.Markdown(
            "上传姿态图,输入提示词即可生成图像。右侧为缩略图预览,可放大/下载原分辨率。",
        )

        with gr.Row():
            # Gradio expects integer `scale` values. The original 1 : 2.4
            # split is expressed as the equivalent integer ratio 5 : 12.
            with gr.Column(scale=5, min_width=320):
                prompt = gr.Textbox(
                    label="提示词",
                    placeholder="描述你想生成的画面",
                    lines=4,
                    value="1 girl, on the beach, summer, full body, highly detailed",
                )
                negative_prompt = gr.Textbox(
                    label="反向提示词",
                    placeholder="不希望出现的元素,例如 '低质量, 模糊'",
                    lines=3,
                    value="lowres, blurry, text, watermark",
                )
                steps = gr.Slider(minimum=4, maximum=30, step=1, value=9, label="采样步数", elem_classes=["compact-slider"])
                guidance_scale = gr.Slider(minimum=0.0, maximum=6.0, step=0.1, value=0.0, label="CFG 指数 (>=1 生效)", elem_classes=["compact-slider"])
                control_strength = gr.Slider(minimum=0.0, maximum=2.0, step=0.05, value=0.75, label="Pose 强度", elem_classes=["compact-slider"])
                height = gr.Slider(minimum=512, maximum=1792, step=16, value=1024, label="高度 (16 的倍数)", elem_classes=["compact-slider"])
                width = gr.Slider(minimum=512, maximum=1792, step=16, value=1024, label="宽度 (16 的倍数)", elem_classes=["compact-slider"])
                seed = gr.Number(value=-1, label="随机种子 (-1 表示随机)", precision=0)
                run_btn = gr.Button("生成", variant="primary")

            with gr.Column(scale=12):
                with gr.Row():
                    # Original 0.8 : 2.6 split as the integer ratio 4 : 13.
                    with gr.Column(scale=4, min_width=200):
                        pose_image = gr.Image(
                            label="姿态图上传 (RGB)",
                            type="pil",
                            height=320,
                            width=240,
                            show_download_button=True,
                        )
                    with gr.Column(scale=13):
                        result_img = gr.Image(
                            label="生成结果 (缩略图)",
                            type="pil",
                            height=520,
                            show_download_button=True,
                            show_fullscreen_button=True,
                        )
                        used_seed = gr.Number(label="实际种子", precision=0)

        # Input order must match the positional parameters of `infer`.
        run_btn.click(
            infer,
            inputs=[prompt, negative_prompt, pose_image, height, width, steps, guidance_scale, control_strength, seed],
            outputs=[result_img, used_seed],
        )

    return demo
| |
|
| |
|
def main():
    """Load the pipeline up front, then serve the UI on 0.0.0.0:7860."""
    # Loading before the server starts means the first request does not have
    # to wait for model loading.
    _load_pipeline()
    app = build_ui().queue()
    app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        inbrowser=False,
        share=False,
    )


if __name__ == "__main__":
    main()
| |
|