# `spaces` is imported first, ahead of torch, as in the original file.
import spaces
import json
import logging
import os
import random
import re
import sys
import warnings

from PIL import Image
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from diffusers import (
    AutoencoderKL,
    FlowMatchEulerDiscreteScheduler,
    ZImagePipeline,
)
from diffusers.models.transformers.transformer_z_image import ZImageTransformer2DModel

# Environment setup
os.environ["TOKENIZERS_PARALLELISM"] = "false"
warnings.filterwarnings("ignore")
logging.getLogger("transformers").setLevel(logging.ERROR)

MODEL_PATH = os.environ.get("MODEL_PATH", "Tongyi-MAI/Z-Image-Turbo")
# Passed through to load_model, which currently does not apply any compilation.
ENABLE_COMPILE = os.environ.get("ENABLE_COMPILE", "true").lower() == "true"

# Sampling settings used by the UI. The Z-Image-Turbo model card recommends
# ~9 inference steps with guidance_scale=0.0; override these env vars when
# MODEL_PATH points at a checkpoint that needs different settings.
NUM_INFERENCE_STEPS = int(os.environ.get("NUM_INFERENCE_STEPS", "9"))
GUIDANCE_SCALE = float(os.environ.get("GUIDANCE_SCALE", "0.0"))

# Resolution options, grouped by base size. Each label starts with "WIDTHxHEIGHT".
RESOLUTION_OPTIONS = {
    "1024": [
        "1024x1024 (1:1)",
        "1152x896 (9:7)",
        "896x1152 (7:9)",
        "1152x864 (4:3)",
        "864x1152 (3:4)",
        "1248x832 (3:2)",
        "832x1248 (2:3)",
        "1280x720 (16:9)",
        "720x1280 (9:16)",
        "1344x576 (21:9)",
        "576x1344 (9:21)",
    ],
    "1280": [
        "1280x1280 (1:1)",
        "1440x1120 (9:7)",
        "1120x1440 (7:9)",
    ],
    "1536": [
        "1536x1536 (1:1)",
        "1728x1344 (9:7)",
        "1344x1728 (7:9)",
        "1728x1296 (4:3)",
        "1296x1728 (3:4)",
        "1872x1248 (3:2)",
        "1248x1872 (2:3)",
        "2048x1152 (16:9)",
        "1152x2048 (9:16)",
        "2016x864 (21:9)",
        "864x2016 (9:21)",
    ],
}

# Flat list of every resolution label, in group order, for the dropdown.
RESOLUTION_SET = [label for group in RESOLUTION_OPTIONS.values() for label in group]

# Example prompts (Chinese): a man and his poodle in matching outfits at a dog
# show; a moody low-key portrait of a woman in a dark room; a mirror selfie of
# a young woman with long black hair in a brightly lit elevator.
EXAMPLE_PROMPTS = [
    "一位男士和他的贵宾犬穿着配套的服装参加狗狗秀,室内灯光,背景中有观众。",
    "极具氛围感的暗调人像,一位优雅的中国美女在黑暗的房间里。",
    "一张中景手机自拍照片拍摄了一位留着长黑发的年轻东亚女子在灯光明亮的电梯内对着镜子自拍。",
]


def load_model(model_path, enable_compile=False):
    """Load the Z-Image components in bfloat16 on CUDA and assemble a pipeline.

    Args:
        model_path: Hub repo id or local directory containing the ``vae``,
            ``text_encoder``, ``tokenizer`` and ``transformer`` subfolders.
        enable_compile: Accepted for configuration compatibility; it is not
            used by this function.

    Returns:
        A ``ZImagePipeline`` moved to CUDA in bfloat16. Its scheduler is left
        unset here; ``generate_image`` attaches one per request.
    """
    print(f"Loading model from {model_path}...")

    vae = AutoencoderKL.from_pretrained(
        model_path,
        subfolder="vae",
        torch_dtype=torch.bfloat16,
        device_map="cuda",
    )
    text_encoder = AutoModelForCausalLM.from_pretrained(
        model_path,
        subfolder="text_encoder",
        torch_dtype=torch.bfloat16,
        device_map="cuda",
    ).eval()
    tokenizer = AutoTokenizer.from_pretrained(model_path, subfolder="tokenizer")

    # Load straight into bfloat16 so the weights are not first materialized in
    # fp32 and then cast (which doubles peak memory for the largest component).
    transformer = ZImageTransformer2DModel.from_pretrained(
        model_path,
        subfolder="transformer",
        torch_dtype=torch.bfloat16,
    )

    # Pass every component explicitly; the scheduler is set per call instead.
    pipe = ZImagePipeline(
        scheduler=None,
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        transformer=transformer,
    )
    pipe.to("cuda", torch.bfloat16)
    return pipe


def _parse_resolution(resolution, default=(1024, 1024)):
    """Return ``(width, height)`` parsed from a label like ``"1280x720 (16:9)"``.

    Accepts ``x`` or ``×`` as the separator. Falls back to *default* when the
    label is empty/None or contains no ``WIDTHxHEIGHT`` pair.
    """
    match = re.search(r"(\d+)\s*[×x]\s*(\d+)", resolution or "")
    if match is None:
        return default
    return int(match.group(1)), int(match.group(2))


@spaces.GPU
def generate_image(
    pipe,
    prompt,
    resolution="1024x1024 (1:1)",
    seed=42,
    guidance_scale=5.0,
    num_inference_steps=50,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image with the Z-Image pipeline.

    Args:
        pipe: Pipeline returned by ``load_model``.
        prompt: Text prompt.
        resolution: Label containing ``WIDTHxHEIGHT``; 1024x1024 if unparsable.
        seed: Integer seed for a CUDA ``torch.Generator``.
        guidance_scale: Forwarded to the pipeline call.
        num_inference_steps: Forwarded to the pipeline call.
        progress: Gradio progress tracker (mirrors tqdm bars when active).

    Returns:
        The first entry of the pipeline output's ``.images``.
    """
    width, height = _parse_resolution(resolution)
    generator = torch.Generator("cuda").manual_seed(int(seed))

    # Attach a fresh scheduler each call so no timestep state leaks between runs.
    pipe.scheduler = FlowMatchEulerDiscreteScheduler(
        num_train_timesteps=1000,
        shift=3.0,
    )

    output = pipe(
        prompt=prompt,
        height=height,
        width=width,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        generator=generator,
    )
    return output.images[0]


# Initialize the model. Failures are reported but do not stop the UI from
# starting; handle_generation surfaces the problem to users instead.
pipe = None
try:
    pipe = load_model(MODEL_PATH, enable_compile=ENABLE_COMPILE)
    print("Model loaded successfully")
except Exception as e:
    print(f"Error loading model: {e}", file=sys.stderr)


# Main application (app-level theme/css/footer options are passed to launch()).
with gr.Blocks(title="Z-Image Turbo") as demo:
    # Header section
    with gr.Row():
        gr.Markdown(
            """
# Z-Image Turbo
*Efficient Image Generation with Single-Stream Diffusion Transformer*
"""
        )

    # Main content area
    with gr.Row():
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(
                label="Describe your image",
                placeholder="Enter a detailed description of what you want to generate...",
                lines=3,
            )

            with gr.Accordion("⚙️ Advanced Settings", open=False):
                with gr.Row():
                    resolution_dropdown = gr.Dropdown(
                        choices=RESOLUTION_SET,
                        value="1024x1024 (1:1)",
                        label="Resolution",
                    )
                    seed_input = gr.Number(label="Seed", value=42, precision=0)
                    random_seed_check = gr.Checkbox(label="Use random seed", value=True)

            generate_btn = gr.Button("Generate Image 🎨", variant="primary", size="lg")

            gr.Examples(
                examples=EXAMPLE_PROMPTS,
                inputs=prompt_input,
                label="Try these examples:",
            )

        with gr.Column(scale=1):
            output_gallery = gr.Gallery(label="Generated Images", columns=2, height=500)
            # Declared in the layout (rather than inline in .click) so their
            # placement is explicit.
            seed_used_text = gr.Textbox(label="Seed Used")
            seed_value_number = gr.Number(label="Seed Value")

    def handle_generation(
        prompt,
        resolution,
        seed,
        use_random_seed,
        progress=gr.Progress(track_tqdm=True),
    ):
        """Validate UI inputs, pick a seed, and run ``generate_image``.

        Returns a one-image gallery list, the seed as a string, and the seed
        as a number. Raises ``gr.Error`` for an empty prompt or when the model
        failed to load at startup.
        """
        if not (prompt or "").strip():
            raise gr.Error("Please enter a prompt")
        if pipe is None:
            raise gr.Error("Model failed to load; check the server logs.")

        # An empty Seed box yields None; treat it like -1 (random).
        if use_random_seed or seed is None or seed == -1:
            actual_seed = random.randint(1, 1000000)
        else:
            actual_seed = int(seed)

        image = generate_image(
            pipe=pipe,
            prompt=prompt,
            resolution=resolution,
            seed=actual_seed,
            guidance_scale=GUIDANCE_SCALE,
            num_inference_steps=NUM_INFERENCE_STEPS,
        )
        return [image], str(actual_seed), actual_seed

    generate_btn.click(
        fn=handle_generation,
        inputs=[prompt_input, resolution_dropdown, seed_input, random_seed_check],
        outputs=[output_gallery, seed_used_text, seed_value_number],
        api_visibility="public",
    )

# Mobile optimization CSS
css = """
.gradio-container {
    max-width: 100% !important;
    padding: 10px !important;
}
.mobile-optimized {
    min-height: 400px !important;
}
"""

if __name__ == "__main__":
    demo.launch(
        theme=gr.themes.Soft(),
        css=css,
        footer_links=[
            {"label": "Built with anycoder", "url": "https://huggingface.co/spaces/akhaliq/anycoder"}
        ],
        mcp_server=True,
    )