import json
import random
import time

import gradio as gr
import numpy as np
import spaces
import torch
from controlnet_aux.processor import Processor
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
from transformers import AutoTokenizer, Qwen3ForCausalLM

from utils import repo_utils
from utils.image_utils import get_image_latent, rescale_image
from utils.prompt_utils import polish_prompt
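

# Pull in VideoX-Fun for its Z-Image control pipeline and transformer classes,
# moving the package into the app directory so the imports below resolve.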
repo_utils.clone_repo_if_not_exists("https://github.com/aigc-apps/VideoX-Fun.git", "app/repos")
repo_utils.move_folder("app/repos/VideoX-Fun/videox_fun", "app/videox_fun")
from videox_fun.pipeline import ZImageControlPipeline
from videox_fun.models import ZImageControlTransformer2DModel
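
# Download the base model and the ControlNet-Union checkpoint from the Hub
# (skipped when the folders already exist locally).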
repo_utils.clone_repo_if_not_exists("https://huggingface.co/Tongyi-MAI/Z-Image-Turbo", "app/models")
repo_utils.clone_repo_if_not_exists("https://huggingface.co/alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.0", "app/models")

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1280

MODEL_LOCAL = "models/Z-Image-Turbo/"
TRANSFORMER_LOCAL = "models/Z-Image-Turbo-Fun-Controlnet-Union-2.0/Z-Image-Turbo-Fun-Controlnet-Union-2.0.safetensors"

weight_dtype = torch.bfloat16
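
# Build the transformer with its control branch enabled; judging by the kwarg
# names, the extra kwargs select which blocks receive control injections.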
transformer = ZImageControlTransformer2DModel.from_pretrained(
    MODEL_LOCAL,
    subfolder="transformer",
    transformer_additional_kwargs={
        "control_layers_places": [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28],
        "control_refiner_layers_places": [0, 1],
        "add_control_noise_refiner": True,
        "control_in_dim": 33,
    },
).to("cuda", torch.bfloat16)
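
# Overlay the Fun ControlNet-Union weights on top of the base transformer;
# strict=False tolerates keys that only exist on one side.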
if TRANSFORMER_LOCAL is not None:
    print(f"From checkpoint: {TRANSFORMER_LOCAL}")
    from safetensors.torch import load_file

    state_dict = load_file(TRANSFORMER_LOCAL)
    state_dict = state_dict.get("state_dict", state_dict)

    missing, unexpected = transformer.load_state_dict(state_dict, strict=False)
    print(f"missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")
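
# Assemble the remaining components and build the control pipeline.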
vae = AutoencoderKL.from_pretrained(
    MODEL_LOCAL,
    subfolder="vae",
    device_map="cuda",
).to(weight_dtype)

tokenizer = AutoTokenizer.from_pretrained(
    MODEL_LOCAL,
    subfolder="tokenizer",
)
text_encoder = Qwen3ForCausalLM.from_pretrained(
    MODEL_LOCAL,
    subfolder="text_encoder",
    torch_dtype=weight_dtype,
)

scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
    MODEL_LOCAL,
    subfolder="scheduler",
)
pipe = ZImageControlPipeline(
    vae=vae,
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    transformer=transformer,
    scheduler=scheduler,
)
pipe.to("cuda", torch.bfloat16)
print("pipe ready.")


def prepare(prompt, is_polish_prompt):
    """Optionally rewrite the user prompt with the prompt polisher."""
    if not is_polish_prompt:
        return prompt, False
    return polish_prompt(prompt), True
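

# spaces.GPU runs this function on a GPU worker when hosted on ZeroGPU Spaces.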
@spaces.GPU
def inference(
    prompt,
    negative_prompt,
    input_image,
    image_scale=1.0,
    control_mode="Canny",
    control_context_scale=0.75,
    seed=42,
    randomize_seed=True,
    guidance_scale=1.5,
    num_inference_steps=8,
    progress=gr.Progress(track_tqdm=True),
):
    timestamp = time.time()
    print(f"timestamp: {timestamp}")

    print("DEBUG: process image")
    if input_image is None:
        print("Error: input_image is empty.")
        # Return a 3-tuple to match the click handler's outputs even on failure.
        return None, seed, None

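    # Map the UI control mode to a controlnet_aux preprocessor id.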
    processor_ids = {
        "Canny": "canny",
        "HED": "softedge_hed",
        "Depth": "depth_midas",
        "MLSD": "mlsd",
        "Pose": "openpose_full",
    }
    processor_id = processor_ids.get(control_mode, "canny")

    print(f"DEBUG: processor_id={processor_id}")
    processor = Processor(processor_id)

    # Snap the target resolution to multiples of 16, run the preprocessor at a
    # fixed 1024x1024, then resize the control map back to the target size.
    control_image, width, height = rescale_image(input_image, image_scale, 16)
    control_image = control_image.resize((1024, 1024))

    print("DEBUG: control_image_torch")
    sample_size = [height, width]
    control_image = processor(control_image, to_pil=True)
    control_image = control_image.resize((width, height))
    # get_image_latent returns a video-shaped tensor; keep the single frame.
    control_image_torch = get_image_latent(control_image, sample_size=sample_size)[:, :, 0]

    # Draw a fresh seed when requested so repeated runs differ.
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator().manual_seed(seed)
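
    # Text-to-image conditioned on the control map; the image and mask inputs
    # stay unset because this demo does not do image-to-image or inpainting.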
    image = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        height=height,
        width=width,
        generator=generator,
        guidance_scale=guidance_scale,
        image=None,
        mask_image=None,
        control_image=control_image_torch,
        num_inference_steps=num_inference_steps,
        control_context_scale=control_context_scale,
    ).images[0]

    return image, seed, control_image


def read_file(path: str) -> str:
    with open(path, "r", encoding="utf-8") as f:
        return f.read()


| css = """ |
| #col-container { |
| margin: 0 auto; |
| max-width: 960px; |
| } |
| """ |
|
|
| with open('static/data.json', 'r') as file: |
| data = json.load(file) |
| examples = data['examples'] |
|
|
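
# UI: input image, prompt, and controls on the left; results on the right.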
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        with gr.Column():
            gr.HTML(read_file("static/header.html"))
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(
                    height=290,
                    sources=["upload", "clipboard"],
                    image_mode="RGB",
                    type="pil",
                    label="Upload",
                )
                prompt = gr.Textbox(
                    label="Prompt",
                    show_label=False,
                    lines=2,
                    placeholder="Enter your prompt",
                )
                is_polish_prompt = gr.Checkbox(label="Polish prompt", value=True)
                control_mode = gr.Radio(
                    choices=["Canny", "Depth", "HED", "MLSD", "Pose"],
                    value="Canny",
                    label="Control Mode",
                )
                run_button = gr.Button("Generate", variant="primary")
| with gr.Accordion("Advanced Settings", open=False): |
| |
| negative_prompt = gr.Textbox( |
| label="Negative prompt", |
| lines=2, |
| container=False, |
| placeholder="Enter your negative prompt", |
| value="blurry ugly bad" |
| ) |
| with gr.Row(): |
| num_inference_steps = gr.Slider( |
| label="Steps", |
| minimum=1, |
| maximum=30, |
| step=1, |
| value=9, |
| ) |
| control_context_scale = gr.Slider( |
| label="Context scale", |
| minimum=0.0, |
| maximum=1.0, |
| step=0.01, |
| value=0.75, |
| ) |
|
|
| with gr.Row(): |
| guidance_scale = gr.Slider( |
| label="Guidance scale", |
| minimum=0.0, |
| maximum=10.0, |
| step=0.1, |
| value=1.0, |
| ) |
|
|
| image_scale = gr.Slider( |
| label="Image scale", |
| minimum=0.5, |
| maximum=2.0, |
| step=0.1, |
| value=1.0, |
| ) |
|
|
| seed = gr.Slider( |
| label="Seed", |
| minimum=0, |
| maximum=MAX_SEED, |
| step=1, |
| value=42, |
| ) |
| randomize_seed = gr.Checkbox(label="Randomize seed", value=False) |
|
|
            with gr.Column():
                output_image = gr.Image(label="Generated image", show_label=False)
                polished_prompt = gr.Textbox(label="Polished prompt", interactive=False)

                with gr.Accordion("Preprocessor output", open=False):
                    control_image = gr.Image(label="Control image", show_label=False)

        gr.Examples(examples=examples, inputs=[input_image, prompt, control_mode])
        gr.Markdown(read_file("static/footer.md"))
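
    # Polish the prompt first, then feed the polished text to inference.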
    run_button.click(
        fn=prepare,
        inputs=[prompt, is_polish_prompt],
        outputs=[polished_prompt, is_polish_prompt],
    ).then(
        fn=inference,
        inputs=[
            polished_prompt,
            negative_prompt,
            input_image,
            image_scale,
            control_mode,
            control_context_scale,
            seed,
            randomize_seed,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=[output_image, seed, control_image],
    )


if __name__ == "__main__":
    # mcp_server=True additionally exposes the app's endpoints over MCP.
    demo.launch(mcp_server=True)