import spaces
import gradio as gr
import torch
import numpy as np
import random
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
from transformers import AutoTokenizer, Qwen3ForCausalLM
from controlnet_aux.processor import Processor
from PIL import Image

# Try to import ControlNet components, fall back to basic pipeline if unavailable
try:
    from videox_fun.pipeline import ZImageControlPipeline
    from videox_fun.models import ZImageControlTransformer2DModel
    CONTROLNET_AVAILABLE = True
except ImportError:
    from diffusers import ZImagePipeline
    CONTROLNET_AVAILABLE = False
    print("ControlNet components not available. Running in basic mode.")

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1280

# Configuration
MODEL_REPO = "Tongyi-MAI/Z-Image-Turbo"
CONTROLNET_WEIGHTS = "Z-Image-Turbo-Fun-Controlnet-Union.safetensors"  # Optional local path

print("Loading Z-Image Turbo model...")
print("This may take a few minutes on first run...")

device = "cuda" if torch.cuda.is_available() else "cpu"
weight_dtype = torch.bfloat16

# Load models
if CONTROLNET_AVAILABLE:
    print("Loading with ControlNet support...")

    # Load transformer with control layers
    transformer = ZImageControlTransformer2DModel.from_pretrained(
        MODEL_REPO,
        subfolder="transformer",
        transformer_additional_kwargs={
            "control_layers_places": [0, 5, 10, 15, 20, 25],
            "control_in_dim": 16,
        },
    ).to(device, weight_dtype)

    # Optionally load ControlNet weights if available
    try:
        import os

        from safetensors.torch import load_file

        if os.path.exists(CONTROLNET_WEIGHTS):
            print(f"Loading ControlNet weights from {CONTROLNET_WEIGHTS}")
            state_dict = load_file(CONTROLNET_WEIGHTS)
            state_dict = state_dict.get("state_dict", state_dict)
            missing, unexpected = transformer.load_state_dict(state_dict, strict=False)
            print(f"Loaded ControlNet: {len(missing)} missing keys, {len(unexpected)} unexpected keys")
    except Exception as e:
        print(f"Could not load ControlNet weights: {e}")

    # Load other components
    vae = AutoencoderKL.from_pretrained(
        MODEL_REPO,
        subfolder="vae",
    ).to(device, weight_dtype)

    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_REPO,
        subfolder="tokenizer",
    )

    text_encoder = Qwen3ForCausalLM.from_pretrained(
        MODEL_REPO,
        subfolder="text_encoder",
        torch_dtype=weight_dtype,
    ).to(device)

    scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
        MODEL_REPO,
        subfolder="scheduler",
    )

    pipe = ZImageControlPipeline(
        vae=vae,
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        transformer=transformer,
        scheduler=scheduler,
    )
    pipe.to(device, weight_dtype)
else:
    print("Loading basic Z-Image Turbo (no ControlNet)...")
    pipe = ZImagePipeline.from_pretrained(
        MODEL_REPO,
        torch_dtype=weight_dtype,
        low_cpu_mem_usage=False,
    )
    pipe.to(device)

print(f"Model loaded successfully on {device}!")

def rescale_image(image, scale, divisible_by=16):
    """Rescale an image, ensuring both dimensions are divisible by `divisible_by`."""
    width, height = image.size
    new_width = int(width * scale)
    new_height = int(height * scale)

    # Floor dimensions to the nearest multiple of divisible_by
    new_width = (new_width // divisible_by) * divisible_by
    new_height = (new_height // divisible_by) * divisible_by

    # Clamp to max size (MAX_IMAGE_SIZE is itself a multiple of 16)
    if new_width > MAX_IMAGE_SIZE:
        new_width = MAX_IMAGE_SIZE
    if new_height > MAX_IMAGE_SIZE:
        new_height = MAX_IMAGE_SIZE

    resized = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
    return resized, new_width, new_height
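
# Worked example: a 1000x750 input at scale=1.0 floors to 992x736
# (1000 // 16 * 16 = 992, 750 // 16 * 16 = 736); a 2048x2048 input at
# scale=1.0 is clamped to 1280x1280.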

def get_image_latent(image):
    """Convert a PIL image to its VAE latent representation."""
    import torchvision.transforms as transforms

    # Normalize pixel values to [-1, 1], as the VAE expects
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])
    # [C, H, W] -> [B, C, F, H, W] with a singleton frame dimension,
    # matching the video-style layout the pipeline expects
    img_tensor = transform(image).unsqueeze(0).unsqueeze(2)
    img_tensor = img_tensor.to(device, weight_dtype)

    with torch.no_grad():
        latent = pipe.vae.encode(img_tensor).latent_dist.sample()
        latent = latent * pipe.vae.config.scaling_factor
    return latent
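
# Shape walkthrough (spatial factor and channel count are assumptions,
# suggested by typical VAE configs and the control_in_dim=16 setting above):
#   PIL (W, H) -> pixel tensor [1, 3, 1, H, W] -> latent [1, 16, 1, H/8, W/8]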

@spaces.GPU  # ZeroGPU: allocate a GPU for the duration of each call
def generate_image(
    prompt,
    negative_prompt="blurry, ugly, bad quality",
    input_image=None,
    control_mode="Canny",
    control_context_scale=0.75,
    image_scale=1.0,
    num_inference_steps=9,
    guidance_scale=1.0,
    seed=42,
    randomize_seed=True,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate an image, optionally guided by a ControlNet control image."""
    if not prompt.strip():
        raise gr.Error("Please enter a prompt to generate an image.")

    # Set seed
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator(device).manual_seed(seed)

    # Basic generation (no control image)
    if input_image is None or not CONTROLNET_AVAILABLE:
        if input_image is not None and not CONTROLNET_AVAILABLE:
            gr.Warning("ControlNet not available. Generating without control image.")
        progress(0.1, desc="Generating image...")
        result = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt if negative_prompt else None,
            height=1024,
            width=1024,
            num_inference_steps=num_inference_steps,
            # Basic mode runs guidance-free
            guidance_scale=0.0 if not CONTROLNET_AVAILABLE else guidance_scale,
            generator=generator,
        )
        image = result.images[0]
        progress(1.0, desc="Complete!")
        return image, seed, None

    # ControlNet generation
    progress(0.1, desc="Processing control image...")

    # Map control mode to processor
    processor_map = {
        'Canny': 'canny',
        'HED': 'softedge_hed',
        'Depth': 'depth_midas',
        'MLSD': 'mlsd',
        'Pose': 'openpose_full',
    }
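    # These ids select controlnet_aux annotators: Canny edges, HED soft edges,
    # MiDaS depth, M-LSD line segments, and full-body OpenPose.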
    processor_id = processor_map.get(control_mode, 'canny')
    processor = Processor(processor_id)

    # Process control image: run the annotator at a fixed 1024x1024,
    # then resize its output back to the target resolution
    control_image, width, height = rescale_image(input_image, image_scale, 16)
    control_image_1024 = control_image.resize((1024, 1024))
    progress(0.3, desc=f"Applying {control_mode} detection...")
    control_image_processed = processor(control_image_1024, to_pil=True)
    control_image_processed = control_image_processed.resize((width, height))

    # Convert to latent space; drop the singleton frame dimension
    progress(0.5, desc="Converting to latent space...")
    control_image_torch = get_image_latent(control_image_processed)[:, :, 0]

    # Generate with control
    progress(0.6, desc="Generating controlled image...")
    try:
        result = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt if negative_prompt else None,
            height=height,
            width=width,
            generator=generator,
            guidance_scale=guidance_scale,
            control_image=control_image_torch,
            num_inference_steps=num_inference_steps,
            control_context_scale=control_context_scale,
        )
        image = result.images[0]
        progress(1.0, desc="Complete!")
        return image, seed, control_image_processed
    except Exception as e:
        raise gr.Error(f"Generation failed: {str(e)}")
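
# Direct-call smoke test (bypasses the UI; prompt and seed are illustrative):
#   image, used_seed, ctrl = generate_image("a watercolor fox", seed=7, randomize_seed=False)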

# Apple-style CSS
apple_css = """
.gradio-container {
    max-width: 1200px !important;
    margin: 0 auto !important;
    padding: 48px 20px !important;
    font-family: -apple-system, BlinkMacSystemFont, 'Inter', 'Segoe UI', sans-serif !important;
}
.header-container {
    text-align: center;
    margin-bottom: 48px;
}
.main-title {
    font-size: 56px !important;
    font-weight: 600 !important;
    letter-spacing: -0.02em !important;
    color: #1d1d1f !important;
    margin: 0 0 12px 0 !important;
}
.subtitle {
    font-size: 21px !important;
    color: #6e6e73 !important;
    margin: 0 0 24px 0 !important;
}
.info-badge {
    display: inline-block;
    background: #0071e3;
    color: white;
    padding: 6px 16px;
    border-radius: 20px;
    font-size: 14px;
    font-weight: 500;
    margin-bottom: 16px;
}
textarea {
    font-size: 17px !important;
    border-radius: 12px !important;
    border: 1px solid #d2d2d7 !important;
    padding: 12px 16px !important;
}
textarea:focus {
    border-color: #0071e3 !important;
    box-shadow: 0 0 0 4px rgba(0, 113, 227, 0.15) !important;
    outline: none !important;
}
button.primary {
    font-size: 17px !important;
    padding: 12px 32px !important;
    border-radius: 980px !important;
    background: #0071e3 !important;
    border: none !important;
    color: #ffffff !important;
    transition: all 0.2s ease !important;
}
button.primary:hover {
    background: #0077ed !important;
    transform: scale(1.02) !important;
}
.footer-text {
    text-align: center;
    margin-top: 48px;
    font-size: 14px !important;
    color: #86868b !important;
}
@media (max-width: 768px) {
    .main-title { font-size: 40px !important; }
    .subtitle { font-size: 19px !important; }
}
"""

# Create interface
with gr.Blocks(css=apple_css, title="Z-Image Turbo with ControlNet") as demo:
    # Header
    gr.HTML(f"""
    <div class="header-container">
        <div class="info-badge">{'✓ ControlNet Enabled' if CONTROLNET_AVAILABLE else '⚠ Basic Mode'}</div>
        <h1 class="main-title">Z-Image Turbo</h1>
        <p class="subtitle">Transform your ideas into stunning visuals with AI-powered control</p>
    </div>
    """)

    with gr.Row():
        # Left column - inputs
        with gr.Column(scale=1):
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="Describe the image you want to create...",
                lines=3,
                max_lines=6,
            )
            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                placeholder="What to avoid in the image...",
                value="blurry, ugly, bad quality",
                lines=2,
            )

            if CONTROLNET_AVAILABLE:
                input_image = gr.Image(
                    label="Control Image (Optional)",
                    type="pil",
                    sources=['upload', 'clipboard'],
                    height=290,
                )
                control_mode = gr.Radio(
                    choices=["Canny", "Depth", "HED", "MLSD", "Pose"],
                    value="Canny",
                    label="Control Mode",
                    info="Choose edge/depth/pose detection method",
                )

            with gr.Accordion("Advanced Settings", open=False):
                num_inference_steps = gr.Slider(
                    label="Inference Steps",
                    minimum=1,
                    maximum=30,
                    step=1,
                    value=9,
                    info="More steps = higher quality but slower",
                )
                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=1.0,
                    info="How closely to follow the prompt",
                )
                if CONTROLNET_AVAILABLE:
                    control_context_scale = gr.Slider(
                        label="Control Strength",
                        minimum=0.0,
                        maximum=1.0,
                        step=0.01,
                        value=0.75,
                        info="0.65-0.80 recommended for best results",
                    )
                    image_scale = gr.Slider(
                        label="Image Scale",
                        minimum=0.5,
                        maximum=2.0,
                        step=0.1,
                        value=1.0,
                        info="Resize control image",
                    )
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=42,
                )
                randomize_seed = gr.Checkbox(
                    label="Randomize Seed",
                    value=True,
                )

            generate_btn = gr.Button(
                "Generate Image",
                variant="primary",
                size="lg",
                elem_classes="primary",
            )

        # Right column - outputs
        with gr.Column(scale=1):
            output_image = gr.Image(
                label="Generated Image",
                type="pil",
                show_label=True,
            )
            seed_output = gr.Number(
                label="Used Seed",
                precision=0,
            )
            if CONTROLNET_AVAILABLE:
                with gr.Accordion("Preprocessor Output", open=False):
                    control_output = gr.Image(
                        label="Processed Control Image",
                        type="pil",
                    )

    # Footer
    gr.HTML("""
    <div class="footer-text">
        <p style="margin-bottom: 8px;">Powered by Z-Image Turbo from Tongyi-MAI</p>
        <p style="font-size: 13px;">
            <a href="https://huggingface.co/Tongyi-MAI/Z-Image-Turbo" style="color: #0071e3; text-decoration: none; margin: 0 8px;">
                Model Card
            </a> •
            <a href="https://huggingface.co/alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union" style="color: #0071e3; text-decoration: none; margin: 0 8px;">
                ControlNet
            </a> •
            <a href="https://github.com/aigc-apps/VideoX-Fun" style="color: #0071e3; text-decoration: none; margin: 0 8px;">
                GitHub
            </a>
        </p>
    </div>
    """)

    # Event handlers
    generate_inputs = [
        prompt,
        negative_prompt,
    ]
    if CONTROLNET_AVAILABLE:
        generate_inputs.extend([
            input_image,
            control_mode,
            control_context_scale,
            image_scale,
        ])
        generate_inputs.extend([
            num_inference_steps,
            guidance_scale,
            seed,
            randomize_seed,
        ])
        generate_outputs = [output_image, seed_output, control_output]
    else:
        # Add placeholders for the missing ControlNet params; they must stay in
        # the same positional order as generate_image's signature
        generate_inputs.extend([
            gr.State(None),     # input_image
            gr.State("Canny"),  # control_mode
            gr.State(0.75),     # control_context_scale
            gr.State(1.0),      # image_scale
        ])
        generate_inputs.extend([
            num_inference_steps,
            guidance_scale,
            seed,
            randomize_seed,
        ])
        generate_outputs = [output_image, seed_output, gr.State(None)]

    generate_btn.click(
        fn=generate_image,
        inputs=generate_inputs,
        outputs=generate_outputs,
    )
    prompt.submit(
        fn=generate_image,
        inputs=generate_inputs,
        outputs=generate_outputs,
    )

if __name__ == "__main__":
    demo.launch(
        share=False,
        show_error=True,
    )