import os
import random

import spaces
import gradio as gr
import torch
import numpy as np
import torchvision.transforms as transforms
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
from transformers import AutoTokenizer, Qwen3ForCausalLM
from controlnet_aux.processor import Processor
from PIL import Image

# Try to import ControlNet components, fall back to basic pipeline if unavailable
try:
    from videox_fun.pipeline import ZImageControlPipeline
    from videox_fun.models import ZImageControlTransformer2DModel
    CONTROLNET_AVAILABLE = True
except ImportError:
    from diffusers import ZImagePipeline
    CONTROLNET_AVAILABLE = False
    print("ControlNet components not available. Running in basic mode.")

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1280

# Configuration
MODEL_REPO = "Tongyi-MAI/Z-Image-Turbo"
CONTROLNET_WEIGHTS = "Z-Image-Turbo-Fun-Controlnet-Union.safetensors"  # Optional local path

print("Loading Z-Image Turbo model...")
print("This may take a few minutes on first run...")

device = "cuda" if torch.cuda.is_available() else "cpu"
weight_dtype = torch.bfloat16

# Load models
if CONTROLNET_AVAILABLE:
    print("Loading with ControlNet support...")

    # Load transformer with control layers
    transformer = ZImageControlTransformer2DModel.from_pretrained(
        MODEL_REPO,
        subfolder="transformer",
        transformer_additional_kwargs={
            "control_layers_places": [0, 5, 10, 15, 20, 25],
            "control_in_dim": 16,
        },
    ).to(device, weight_dtype)

    # Optionally load ControlNet weights if available
    try:
        from safetensors.torch import load_file

        if os.path.exists(CONTROLNET_WEIGHTS):
            print(f"Loading ControlNet weights from {CONTROLNET_WEIGHTS}")
            state_dict = load_file(CONTROLNET_WEIGHTS)
            state_dict = state_dict.get("state_dict", state_dict)
            m, u = transformer.load_state_dict(state_dict, strict=False)
            print(f"Loaded ControlNet: {len(m)} missing keys, {len(u)} unexpected keys")
    except Exception as e:
        print(f"Could not load ControlNet weights: {e}")

    # Load other components
    vae = AutoencoderKL.from_pretrained(
        MODEL_REPO,
        subfolder="vae",
    ).to(device, weight_dtype)

    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_REPO,
        subfolder="tokenizer",
    )

    text_encoder = Qwen3ForCausalLM.from_pretrained(
        MODEL_REPO,
        subfolder="text_encoder",
        torch_dtype=weight_dtype,
    ).to(device)

    scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
        MODEL_REPO,
        subfolder="scheduler",
    )

    pipe = ZImageControlPipeline(
        vae=vae,
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        transformer=transformer,
        scheduler=scheduler,
    )
    pipe.to(device, weight_dtype)
else:
    print("Loading basic Z-Image Turbo (no ControlNet)...")
    pipe = ZImagePipeline.from_pretrained(
        MODEL_REPO,
        torch_dtype=weight_dtype,
        low_cpu_mem_usage=False,
    )
    pipe.to(device)

print(f"Model loaded successfully on {device}!")


def rescale_image(image, scale, divisible_by=16):
    """Rescale image and ensure dimensions are divisible by the specified value."""
    width, height = image.size
    new_width = int(width * scale)
    new_height = int(height * scale)

    # Make dimensions divisible by divisible_by
    new_width = (new_width // divisible_by) * divisible_by
    new_height = (new_height // divisible_by) * divisible_by

    # Clamp to max size (MAX_IMAGE_SIZE is itself a multiple of 16)
    new_width = min(new_width, MAX_IMAGE_SIZE)
    new_height = min(new_height, MAX_IMAGE_SIZE)

    resized = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
    return resized, new_width, new_height


def get_image_latent(image, sample_size=None):
    """Convert a PIL image to its VAE latent representation.

    sample_size is unused; callers resize the image to the target size beforehand.
    """
    # Normalize image to [-1, 1]
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])
    img_tensor = transform(image).unsqueeze(0).unsqueeze(2)  # [B, C, 1, H, W]
    img_tensor = img_tensor.to(device, weight_dtype)

    with torch.no_grad():
        latent = pipe.vae.encode(img_tensor).latent_dist.sample()
        latent = latent * pipe.vae.config.scaling_factor
    return latent
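
# The latent convention above mirrors the videox_fun examples: encode adds a
# singleton frame dim, which the call site strips with `[:, :, 0]`. A minimal
# round-trip sketch for debugging (assumes the standard diffusers AutoencoderKL
# decode API and a 4D latent, i.e. after that squeeze; hypothetical helper,
# not used by the app itself):
def decode_latent_to_image(latent):
    """Decode a [B, C, H/8, W/8] latent back to a PIL image."""
    with torch.no_grad():
        pixels = pipe.vae.decode(latent / pipe.vae.config.scaling_factor).sample
    pixels = (pixels / 2 + 0.5).clamp(0, 1)  # undo the [-1, 1] normalization
    array = (pixels[0].permute(1, 2, 0).float().cpu().numpy() * 255).astype(np.uint8)
    return Image.fromarray(array)
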
@spaces.GPU()
def generate_image(
    prompt,
    negative_prompt="blurry, ugly, bad quality",
    input_image=None,
    control_mode="Canny",
    control_context_scale=0.75,
    image_scale=1.0,
    num_inference_steps=9,
    guidance_scale=1.0,
    seed=42,
    randomize_seed=True,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate an image with optional ControlNet guidance."""
    if not prompt.strip():
        raise gr.Error("Please enter a prompt to generate an image.")

    # Set seed
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator(device).manual_seed(seed)

    # Basic generation (no control image)
    if input_image is None or not CONTROLNET_AVAILABLE:
        if input_image is not None and not CONTROLNET_AVAILABLE:
            gr.Warning("ControlNet not available. Generating without control image.")

        progress(0.1, desc="Generating image...")
        result = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt if negative_prompt else None,
            height=1024,
            width=1024,
            num_inference_steps=num_inference_steps,
            guidance_scale=0.0 if not CONTROLNET_AVAILABLE else guidance_scale,
            generator=generator,
        )
        image = result.images[0]
        progress(1.0, desc="Complete!")
        return image, seed, None

    # ControlNet generation
    progress(0.1, desc="Processing control image...")

    # Map control mode to a controlnet_aux processor id
    processor_map = {
        "Canny": "canny",
        "HED": "softedge_hed",
        "Depth": "depth_midas",
        "MLSD": "mlsd",
        "Pose": "openpose_full",
    }
    processor_id = processor_map.get(control_mode, "canny")
    processor = Processor(processor_id)

    # Process control image
    control_image, width, height = rescale_image(input_image, image_scale, 16)
    control_image_1024 = control_image.resize((1024, 1024))

    progress(0.3, desc=f"Applying {control_mode} detection...")
    control_image_processed = processor(control_image_1024, to_pil=True)
    control_image_processed = control_image_processed.resize((width, height))

    # Convert to latent
    progress(0.5, desc="Converting to latent space...")
    control_image_torch = get_image_latent(
        control_image_processed,
        sample_size=[height, width],
    )[:, :, 0]

    # Generate with control
    progress(0.6, desc="Generating controlled image...")
    try:
        result = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt if negative_prompt else None,
            height=height,
            width=width,
            generator=generator,
            guidance_scale=guidance_scale,
            control_image=control_image_torch,
            num_inference_steps=num_inference_steps,
            control_context_scale=control_context_scale,
        )
        image = result.images[0]
        progress(1.0, desc="Complete!")
        return image, seed, control_image_processed
    except Exception as e:
        raise gr.Error(f"Generation failed: {str(e)}")
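
# Note that generate_image builds a fresh Processor per request, which reloads
# annotator weights (e.g. MiDaS, OpenPose) every time. A minimal memoization
# sketch (hypothetical helper, not wired into generate_image above):
_PROCESSOR_CACHE = {}

def get_cached_processor(processor_id):
    """Return a cached controlnet_aux Processor, constructing it on first use."""
    if processor_id not in _PROCESSOR_CACHE:
        _PROCESSOR_CACHE[processor_id] = Processor(processor_id)
    return _PROCESSOR_CACHE[processor_id]
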
# Apple-style CSS
apple_css = """
.gradio-container {
    max-width: 1200px !important;
    margin: 0 auto !important;
    padding: 48px 20px !important;
    font-family: -apple-system, BlinkMacSystemFont, 'Inter', 'Segoe UI', sans-serif !important;
}
.header-container {
    text-align: center;
    margin-bottom: 48px;
}
.main-title {
    font-size: 56px !important;
    font-weight: 600 !important;
    letter-spacing: -0.02em !important;
    color: #1d1d1f !important;
    margin: 0 0 12px 0 !important;
}
.subtitle {
    font-size: 21px !important;
    color: #6e6e73 !important;
    margin: 0 0 24px 0 !important;
}
.info-badge {
    display: inline-block;
    background: #0071e3;
    color: white;
    padding: 6px 16px;
    border-radius: 20px;
    font-size: 14px;
    font-weight: 500;
    margin-bottom: 16px;
}
textarea {
    font-size: 17px !important;
    border-radius: 12px !important;
    border: 1px solid #d2d2d7 !important;
    padding: 12px 16px !important;
}
textarea:focus {
    border-color: #0071e3 !important;
    box-shadow: 0 0 0 4px rgba(0, 113, 227, 0.15) !important;
    outline: none !important;
}
button.primary {
    font-size: 17px !important;
    padding: 12px 32px !important;
    border-radius: 980px !important;
    background: #0071e3 !important;
    border: none !important;
    color: #ffffff !important;
    transition: all 0.2s ease !important;
}
button.primary:hover {
    background: #0077ed !important;
    transform: scale(1.02) !important;
}
.footer-text {
    text-align: center;
    margin-top: 48px;
    font-size: 14px !important;
    color: #86868b !important;
}
@media (max-width: 768px) {
    .main-title { font-size: 40px !important; }
    .subtitle { font-size: 19px !important; }
}
"""

# Create interface (custom CSS is applied via the Blocks constructor;
# demo.launch() has no css argument)
with gr.Blocks(title="Z-Image Turbo with ControlNet", css=apple_css) as demo:
    # Header
    gr.HTML(f"""
        <div class="header-container">
            <div class="info-badge">{'✓ ControlNet Enabled' if CONTROLNET_AVAILABLE else '⚠ Basic Mode'}</div>
            <h1 class="main-title">Z-Image Turbo</h1>
            <p class="subtitle">Transform your ideas into stunning visuals with AI-powered control</p>
        </div>
""") with gr.Row(): # Left column - Inputs with gr.Column(scale=1): prompt = gr.Textbox( label="Prompt", placeholder="Describe the image you want to create...", lines=3, max_lines=6, ) negative_prompt = gr.Textbox( label="Negative Prompt", placeholder="What to avoid in the image...", value="blurry, ugly, bad quality", lines=2, ) if CONTROLNET_AVAILABLE: input_image = gr.Image( label="Control Image (Optional)", type="pil", sources=['upload', 'clipboard'], height=290, ) control_mode = gr.Radio( choices=["Canny", "Depth", "HED", "MLSD", "Pose"], value="Canny", label="Control Mode", info="Choose edge/depth/pose detection method" ) with gr.Accordion("Advanced Settings", open=False): num_inference_steps = gr.Slider( label="Inference Steps", minimum=1, maximum=30, step=1, value=9, info="More steps = higher quality but slower" ) guidance_scale = gr.Slider( label="Guidance Scale", minimum=0.0, maximum=10.0, step=0.1, value=1.0, info="How closely to follow the prompt" ) if CONTROLNET_AVAILABLE: control_context_scale = gr.Slider( label="Control Strength", minimum=0.0, maximum=1.0, step=0.01, value=0.75, info="0.65-0.80 recommended for best results" ) image_scale = gr.Slider( label="Image Scale", minimum=0.5, maximum=2.0, step=0.1, value=1.0, info="Resize control image" ) seed = gr.Slider( label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42, ) randomize_seed = gr.Checkbox( label="Randomize Seed", value=True ) generate_btn = gr.Button( "Generate Image", variant="primary", size="lg", elem_classes="primary" ) # Right column - Outputs with gr.Column(scale=1): output_image = gr.Image( label="Generated Image", type="pil", show_label=True, ) seed_output = gr.Number( label="Used Seed", precision=0, ) if CONTROLNET_AVAILABLE: with gr.Accordion("Preprocessor Output", open=False): control_output = gr.Image( label="Processed Control Image", type="pil", ) # Footer gr.HTML(""" """) # Event handlers generate_inputs = [ prompt, negative_prompt, ] if CONTROLNET_AVAILABLE: generate_inputs.extend([ input_image, control_mode, control_context_scale, image_scale, ]) generate_inputs.extend([ num_inference_steps, guidance_scale, seed, randomize_seed, ]) generate_outputs = [output_image, seed_output, control_output] else: # Add None placeholders for missing ControlNet params generate_inputs.extend([ gr.State(None), # input_image gr.State("Canny"), # control_mode gr.State(0.75), # control_context_scale gr.State(1.0), # image_scale ]) generate_inputs.extend([ num_inference_steps, guidance_scale, seed, randomize_seed, ]) generate_outputs = [output_image, seed_output, gr.State(None)] generate_btn.click( fn=generate_image, inputs=generate_inputs, outputs=generate_outputs, ) prompt.submit( fn=generate_image, inputs=generate_inputs, outputs=generate_outputs, ) if __name__ == "__main__": demo.launch( share=False, show_error=True, css=apple_css, )