import spaces
import gradio as gr
import torch
import numpy as np
import random
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
from transformers import AutoTokenizer, Qwen3ForCausalLM
from controlnet_aux.processor import Processor
from PIL import Image
# Try to import ControlNet components, fall back to basic pipeline if unavailable
try:
    from videox_fun.pipeline import ZImageControlPipeline
    from videox_fun.models import ZImageControlTransformer2DModel
    CONTROLNET_AVAILABLE = True
except ImportError:
    from diffusers import ZImagePipeline
    CONTROLNET_AVAILABLE = False
    print("ControlNet components not available. Running in basic mode.")
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1280
# Configuration
MODEL_REPO = "Tongyi-MAI/Z-Image-Turbo"
CONTROLNET_WEIGHTS = "Z-Image-Turbo-Fun-Controlnet-Union.safetensors" # Optional local path
print("Loading Z-Image Turbo model...")
print("This may take a few minutes on first run...")
device = "cuda" if torch.cuda.is_available() else "cpu"
weight_dtype = torch.bfloat16
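# Note: bfloat16 assumes hardware with native bf16 support (Ampere-or-newer GPU);
# torch.float16 or torch.float32 would be the usual fallback on older devices.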
# Load models
if CONTROLNET_AVAILABLE:
    print("Loading with ControlNet support...")
    # Load transformer with control layers
    transformer = ZImageControlTransformer2DModel.from_pretrained(
        MODEL_REPO,
        subfolder="transformer",
        transformer_additional_kwargs={
            "control_layers_places": [0, 5, 10, 15, 20, 25],
            "control_in_dim": 16
        },
    ).to(device, weight_dtype)
    # Optionally load ControlNet weights if available
    try:
        from safetensors.torch import load_file
        import os
        if os.path.exists(CONTROLNET_WEIGHTS):
            print(f"Loading ControlNet weights from {CONTROLNET_WEIGHTS}")
            state_dict = load_file(CONTROLNET_WEIGHTS)
            state_dict = state_dict.get("state_dict", state_dict)
            m, u = transformer.load_state_dict(state_dict, strict=False)
            print(f"Loaded ControlNet: {len(m)} missing keys, {len(u)} unexpected keys")
    except Exception as e:
        print(f"Could not load ControlNet weights: {e}")
    # Load other components
    vae = AutoencoderKL.from_pretrained(
        MODEL_REPO,
        subfolder="vae",
    ).to(device, weight_dtype)
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_REPO,
        subfolder="tokenizer"
    )
    text_encoder = Qwen3ForCausalLM.from_pretrained(
        MODEL_REPO,
        subfolder="text_encoder",
        torch_dtype=weight_dtype,
    ).to(device)
    scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
        MODEL_REPO,
        subfolder="scheduler"
    )
    pipe = ZImageControlPipeline(
        vae=vae,
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        transformer=transformer,
        scheduler=scheduler,
    )
    pipe.to(device, weight_dtype)
else:
    print("Loading basic Z-Image Turbo (no ControlNet)...")
    pipe = ZImagePipeline.from_pretrained(
        MODEL_REPO,
        torch_dtype=weight_dtype,
        low_cpu_mem_usage=False,
    )
    pipe.to(device)
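# Memory-saving alternative (not enabled here; standard diffusers pipeline API,
# assuming the videox_fun pipeline inherits it): on smaller GPUs one could swap
# pipe.to(device) for pipe.enable_model_cpu_offload(), which keeps submodules on
# the CPU and moves them to the GPU only while they run.
#
#   # pipe.enable_model_cpu_offload()  # hypothetical alternative to pipe.to(device)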
print(f"Model loaded successfully on {device}!")
def rescale_image(image, scale, divisible_by=16):
    """Rescale image and ensure dimensions are divisible by specified value."""
    width, height = image.size
    new_width = int(width * scale)
    new_height = int(height * scale)
    # Make dimensions divisible by divisible_by
    new_width = (new_width // divisible_by) * divisible_by
    new_height = (new_height // divisible_by) * divisible_by
    # Clamp to max size
    if new_width > MAX_IMAGE_SIZE:
        new_width = MAX_IMAGE_SIZE
    if new_height > MAX_IMAGE_SIZE:
        new_height = MAX_IMAGE_SIZE
    resized = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
    return resized, new_width, new_height
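# Worked example of the snap-and-clamp arithmetic above (illustrative only, not
# executed by the app): a 1000x750 image at scale 1.0 snaps down to the nearest
# multiples of 16:
#
#   _, w, h = rescale_image(Image.new("RGB", (1000, 750)), 1.0)
#   assert (w, h) == (992, 736)   # 1000 -> 62*16, 750 -> 46*16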
def get_image_latent(image, sample_size):
    """Convert a PIL image to its VAE latent representation.

    `sample_size` is passed by the caller but not used here; the image is
    encoded at its own resolution.
    """
    import torchvision.transforms as transforms
    # Normalize to [-1, 1], the range the VAE expects
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ])
    img_tensor = transform(image).unsqueeze(0)  # [B, C, H, W]
    img_tensor = img_tensor.to(device, weight_dtype)
    with torch.no_grad():
        # diffusers' AutoencoderKL expects 4D input, so encode first and then
        # add the singleton temporal dim that callers slice off with [:, :, 0]
        latent = pipe.vae.encode(img_tensor).latent_dist.sample()
        latent = latent * pipe.vae.config.scaling_factor
    return latent.unsqueeze(2)  # [B, C, 1, H', W']
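# Shape sketch for the function above (the latent channel count and downsample
# factor come from the model's VAE config, so the exact numbers are assumptions):
# a 1024x1024 RGB control image encodes to roughly [1, C_latent, 1, 128, 128]
# under the typical 8x spatial compression, and the caller drops the singleton
# temporal axis with [:, :, 0] before handing it to the pipeline.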
@spaces.GPU()
def generate_image(
    prompt,
    negative_prompt="blurry, ugly, bad quality",
    input_image=None,
    control_mode="Canny",
    control_context_scale=0.75,
    image_scale=1.0,
    num_inference_steps=9,
    guidance_scale=1.0,
    seed=42,
    randomize_seed=True,
    progress=gr.Progress(track_tqdm=True)
):
    """Generate image with optional ControlNet guidance."""
    if not prompt.strip():
        raise gr.Error("Please enter a prompt to generate an image.")
    # Set seed
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator(device).manual_seed(seed)
    # Basic generation (no control image)
    if input_image is None or not CONTROLNET_AVAILABLE:
        if input_image is not None and not CONTROLNET_AVAILABLE:
            gr.Warning("ControlNet not available. Generating without control image.")
        progress(0.1, desc="Generating image...")
        result = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt if negative_prompt else None,
            height=1024,
            width=1024,
            num_inference_steps=num_inference_steps,
            guidance_scale=0.0 if not CONTROLNET_AVAILABLE else guidance_scale,
            generator=generator,
        )
        image = result.images[0]
        progress(1.0, desc="Complete!")
        return image, seed, None
    # ControlNet generation
    progress(0.1, desc="Processing control image...")
    # Map control mode to processor
    processor_map = {
        'Canny': 'canny',
        'HED': 'softedge_hed',
        'Depth': 'depth_midas',
        'MLSD': 'mlsd',
        'Pose': 'openpose_full'
    }
    processor_id = processor_map.get(control_mode, 'canny')
    processor = Processor(processor_id)
    # Process control image
    control_image, width, height = rescale_image(input_image, image_scale, 16)
    control_image_1024 = control_image.resize((1024, 1024))
    progress(0.3, desc=f"Applying {control_mode} detection...")
    control_image_processed = processor(control_image_1024, to_pil=True)
    control_image_processed = control_image_processed.resize((width, height))
    # Convert to latent
    progress(0.5, desc="Converting to latent space...")
    control_image_torch = get_image_latent(
        control_image_processed,
        sample_size=[height, width]
    )[:, :, 0]
    # Generate with control
    progress(0.6, desc="Generating controlled image...")
    try:
        result = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt if negative_prompt else None,
            height=height,
            width=width,
            generator=generator,
            guidance_scale=guidance_scale,
            control_image=control_image_torch,
            num_inference_steps=num_inference_steps,
            control_context_scale=control_context_scale,
        )
        image = result.images[0]
        progress(1.0, desc="Complete!")
        return image, seed, control_image_processed
    except Exception as e:
        raise gr.Error(f"Generation failed: {e}")
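# Standalone sketch of the controlnet_aux preprocessing step used above
# (illustrative only; "sketch.png" is a placeholder path):
#
#   canny = Processor("canny")
#   edges = canny(Image.open("sketch.png").resize((1024, 1024)), to_pil=True)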
# Apple-style CSS
apple_css = """
.gradio-container {
max-width: 1200px !important;
margin: 0 auto !important;
padding: 48px 20px !important;
font-family: -apple-system, BlinkMacSystemFont, 'Inter', 'Segoe UI', sans-serif !important;
}
.header-container {
text-align: center;
margin-bottom: 48px;
}
.main-title {
font-size: 56px !important;
font-weight: 600 !important;
letter-spacing: -0.02em !important;
color: #1d1d1f !important;
margin: 0 0 12px 0 !important;
}
.subtitle {
font-size: 21px !important;
color: #6e6e73 !important;
margin: 0 0 24px 0 !important;
}
.info-badge {
display: inline-block;
background: #0071e3;
color: white;
padding: 6px 16px;
border-radius: 20px;
font-size: 14px;
font-weight: 500;
margin-bottom: 16px;
}
textarea {
font-size: 17px !important;
border-radius: 12px !important;
border: 1px solid #d2d2d7 !important;
padding: 12px 16px !important;
}
textarea:focus {
border-color: #0071e3 !important;
box-shadow: 0 0 0 4px rgba(0, 113, 227, 0.15) !important;
outline: none !important;
}
button.primary {
font-size: 17px !important;
padding: 12px 32px !important;
border-radius: 980px !important;
background: #0071e3 !important;
border: none !important;
color: #ffffff !important;
transition: all 0.2s ease !important;
}
button.primary:hover {
background: #0077ed !important;
transform: scale(1.02) !important;
}
.footer-text {
text-align: center;
margin-top: 48px;
font-size: 14px !important;
color: #86868b !important;
}
@media (max-width: 768px) {
.main-title { font-size: 40px !important; }
.subtitle { font-size: 19px !important; }
}
"""
# Create interface
with gr.Blocks(title="Z-Image Turbo with ControlNet", css=apple_css) as demo:
    # Header
    gr.HTML(f"""
    <div class="header-container">
        <div class="info-badge">{'✓ ControlNet Enabled' if CONTROLNET_AVAILABLE else '⚠ Basic Mode'}</div>
        <h1 class="main-title">Z-Image Turbo</h1>
        <p class="subtitle">Transform your ideas into stunning visuals with AI-powered control</p>
    </div>
    """)
    with gr.Row():
        # Left column - Inputs
        with gr.Column(scale=1):
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="Describe the image you want to create...",
                lines=3,
                max_lines=6,
            )
            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                placeholder="What to avoid in the image...",
                value="blurry, ugly, bad quality",
                lines=2,
            )
            if CONTROLNET_AVAILABLE:
                input_image = gr.Image(
                    label="Control Image (Optional)",
                    type="pil",
                    sources=['upload', 'clipboard'],
                    height=290,
                )
                control_mode = gr.Radio(
                    choices=["Canny", "Depth", "HED", "MLSD", "Pose"],
                    value="Canny",
                    label="Control Mode",
                    info="Choose edge/depth/pose detection method"
                )
            with gr.Accordion("Advanced Settings", open=False):
                num_inference_steps = gr.Slider(
                    label="Inference Steps",
                    minimum=1,
                    maximum=30,
                    step=1,
                    value=9,
                    info="More steps = higher quality but slower"
                )
                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=1.0,
                    info="How closely to follow the prompt"
                )
                if CONTROLNET_AVAILABLE:
                    control_context_scale = gr.Slider(
                        label="Control Strength",
                        minimum=0.0,
                        maximum=1.0,
                        step=0.01,
                        value=0.75,
                        info="0.65-0.80 recommended for best results"
                    )
                    image_scale = gr.Slider(
                        label="Image Scale",
                        minimum=0.5,
                        maximum=2.0,
                        step=0.1,
                        value=1.0,
                        info="Resize control image"
                    )
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=42,
                )
                randomize_seed = gr.Checkbox(
                    label="Randomize Seed",
                    value=True
                )
            generate_btn = gr.Button(
                "Generate Image",
                variant="primary",
                size="lg",
                elem_classes="primary"
            )
        # Right column - Outputs
        with gr.Column(scale=1):
            output_image = gr.Image(
                label="Generated Image",
                type="pil",
                show_label=True,
            )
            seed_output = gr.Number(
                label="Used Seed",
                precision=0,
            )
            if CONTROLNET_AVAILABLE:
                with gr.Accordion("Preprocessor Output", open=False):
                    control_output = gr.Image(
                        label="Processed Control Image",
                        type="pil",
                    )
    # Footer
    gr.HTML("""
    <div class="footer-text">
        <p style="margin-bottom: 8px;">Powered by Z-Image Turbo from Tongyi-MAI</p>
        <p style="font-size: 13px;">
            <a href="https://huggingface.co/Tongyi-MAI/Z-Image-Turbo" style="color: #0071e3; text-decoration: none; margin: 0 8px;">
                Model Card
            </a> •
            <a href="https://huggingface.co/alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union" style="color: #0071e3; text-decoration: none; margin: 0 8px;">
                ControlNet
            </a> •
            <a href="https://github.com/aigc-apps/VideoX-Fun" style="color: #0071e3; text-decoration: none; margin: 0 8px;">
                GitHub
            </a>
        </p>
    </div>
    """)
    # Event handlers
    generate_inputs = [
        prompt,
        negative_prompt,
    ]
    if CONTROLNET_AVAILABLE:
        generate_inputs.extend([
            input_image,
            control_mode,
            control_context_scale,
            image_scale,
        ])
        generate_inputs.extend([
            num_inference_steps,
            guidance_scale,
            seed,
            randomize_seed,
        ])
        generate_outputs = [output_image, seed_output, control_output]
    else:
        # Supply constant placeholders for the ControlNet-only parameters
        generate_inputs.extend([
            gr.State(None),     # input_image
            gr.State("Canny"),  # control_mode
            gr.State(0.75),     # control_context_scale
            gr.State(1.0),      # image_scale
        ])
        generate_inputs.extend([
            num_inference_steps,
            guidance_scale,
            seed,
            randomize_seed,
        ])
        generate_outputs = [output_image, seed_output, gr.State(None)]
    generate_btn.click(
        fn=generate_image,
        inputs=generate_inputs,
        outputs=generate_outputs,
    )
    prompt.submit(
        fn=generate_image,
        inputs=generate_inputs,
        outputs=generate_outputs,
    )
if __name__ == "__main__":
    demo.launch(
        share=False,
        show_error=True,
    )