# Hugging Face Space app (page scrape showed status: "Runtime error" —
# caused by the syntax errors fixed below).
# NOTE: `spaces` must be imported before any CUDA-touching library
# (Hugging Face ZeroGPU requirement), so it stays first.
import spaces

import json
import logging
import os
import random
import re
import sys
import warnings

import gradio as gr
import torch
from diffusers import (
    AutoencoderKL,
    FlowMatchEulerDiscreteScheduler,
    ZImagePipeline,
)
from diffusers.models.transformers.transformer_z_image import ZImageTransformer2DModel
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer
# Environment setup: silence tokenizer fork warnings, generic Python warnings,
# and noisy transformers logging before any model is loaded.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
warnings.filterwarnings("ignore")
logging.getLogger("transformers").setLevel(logging.ERROR)

# Model location and torch.compile toggle, both overridable via environment.
MODEL_PATH = os.environ.get("MODEL_PATH", "Tongyi-MAI/Z-Image-Turbo")
ENABLE_COMPILE = os.environ.get("ENABLE_COMPILE", "true").lower() == "true"
# Resolution presets, grouped by base size. Labels are "WIDTHxHEIGHT (aspect)"
# strings that generate_image() later parses with a regex.
RESOLUTION_OPTIONS = {
    "1024": [
        "1024x1024 (1:1)", "1152x896 (9:7)", "896x1152 (7:9)",
        "1152x864 (4:3)", "864x1152 (3:4)", "1248x832 (3:2)",
        "832x1248 (2:3)", "1280x720 (16:9)", "720x1280 (9:16)", "1344x576 (21:9)", "576x1344 (9:21)"
    ],
    "1280": [
        "1280x1280 (1:1)", "1440x1120 (9:7)", "1120x1440 (7:9)"
    ],
    "1536": [
        "1536x1536 (1:1)", "1728x1344 (9:7)", "1344x1728 (7:9)",
        "1728x1296 (4:3)", "1296x1728 (3:4)", "1872x1248 (3:2)", "1248x1872 (2:3)",
        "2048x1152 (16:9)", "1152x2048 (9:16)", "2016x864 (21:9)", "864x2016 (9:21)"
    ],
}

# Flat list of every label, preserving dict insertion order (used as the
# dropdown's choices).
RESOLUTION_SET = [
    label
    for group in RESOLUTION_OPTIONS.values()
    for label in group
]
# Example prompts (in Chinese) shown in the UI's examples panel.
# These are runtime strings fed directly to the model — kept verbatim.
EXAMPLE_PROMPTS = [
    "一位男士和他的贵宾犬穿着配套的服装参加狗狗秀,室内灯光,背景中有观众。",
    "极具氛围感的暗调人像,一位优雅的中国美女在黑暗的房间里。",
    "一张中景手机自拍照片拍摄了一位留着长黑发的年轻东亚女子在灯光明亮的电梯内对着镜子自拍。",
]
# Model loading function
def load_model(model_path, enable_compile=False):
    """Load the Z-Image pipeline components onto CUDA in bfloat16.

    Args:
        model_path: HF repo id or local path holding the ``vae``,
            ``text_encoder``, ``tokenizer`` and ``transformer`` subfolders.
        enable_compile: accepted for interface compatibility; currently
            unused in this body — no torch.compile call is wired up.
            TODO(review): confirm whether compilation was intended here.

    Returns:
        A ``ZImagePipeline`` with all components on CUDA.
    """
    print(f"Loading model from {model_path}...")
    vae = AutoencoderKL.from_pretrained(
        model_path,
        subfolder="vae",
        torch_dtype=torch.bfloat16,
        device_map="cuda",
    )
    text_encoder = AutoModelForCausalLM.from_pretrained(
        model_path,
        subfolder="text_encoder",
        torch_dtype=torch.bfloat16,
        device_map="cuda",
    ).eval()
    # BUG FIX: the original line carried an extra closing ')' (SyntaxError).
    tokenizer = AutoTokenizer.from_pretrained(model_path, subfolder="tokenizer")
    # Initialize pipeline without the transformer, then attach it separately.
    pipe = ZImagePipeline(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
    )
    transformer = ZImageTransformer2DModel.from_pretrained(
        model_path,
        subfolder="transformer",
    )
    pipe.transformer = transformer
    pipe.to("cuda", torch.bfloat16)
    return pipe
# Image generation function
def generate_image(
    pipe,
    prompt,
    resolution="1024x1024 (1:1)",
    seed=42,
    guidance_scale=5.0,
    num_inference_steps=50,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image with the Z-Image pipeline.

    Args:
        pipe: loaded ``ZImagePipeline`` (see ``load_model``).
        prompt: text description of the desired image.
        resolution: label of the form "WIDTHxHEIGHT (aspect)"; falls back to
            1024x1024 when the label does not match.
        seed: RNG seed for reproducible generation.
        guidance_scale: classifier-free guidance strength.
        num_inference_steps: diffusion step count.
        progress: Gradio progress tracker (mutable default is the standard
            Gradio idiom for enabling tqdm tracking).

    Returns:
        A PIL image (first element of the pipeline's ``images`` output).
    """
    # Default resolution, used when parsing fails.
    width, height = 1024, 1024
    # Accepts both ASCII 'x' and the multiplication sign '×' as separators.
    match = re.search(r"(\d+)\s*[×x]\s*(\d+)", resolution)
    if match:
        # BUG FIX: the original line had unbalanced parentheses (SyntaxError).
        width, height = int(match.group(1)), int(match.group(2))
    generator = torch.Generator("cuda").manual_seed(seed)
    # Fresh scheduler per call so no sampling state leaks between requests.
    scheduler = FlowMatchEulerDiscreteScheduler(
        num_train_timesteps=1000,
        shift=3.0,
    )
    pipe.scheduler = scheduler
    # Generate image
    image = pipe(
        prompt=prompt,
        height=height,
        width=width,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        generator=generator,
    ).images[0]
    return image
# Initialize the model once at import time so the UI can generate immediately.
pipe = None
try:
    pipe = load_model(MODEL_PATH, enable_compile=ENABLE_COMPILE)
    print("Model loaded successfully")
except Exception as e:
    # Deliberate broad catch: startup continues with pipe=None and generation
    # will fail at request time instead of crashing the whole app here.
    print(f"Error loading model: {e}")
# Main application
# Mobile optimization CSS — defined before Blocks so it can be passed to the
# constructor (see BUG FIX note below).
css = """
.gradio-container {
    max-width: 100% !important;
    padding: 10px !important;
}
.mobile-optimized {
    min-height: 400px !important;
}
"""

with gr.Blocks(
    title="Z-Image Turbo",
    theme=gr.themes.Soft(),
    # BUG FIX: `css` was being passed to demo.launch(), but Blocks.launch()
    # has no `css` parameter — it belongs on the gr.Blocks constructor.
    css=css,
    # BUG FIX: this dict was missing its closing brace (SyntaxError).
    # NOTE(review): `footer_links` is not part of core gradio's Blocks
    # signature — confirm the deployed gradio version accepts it.
    footer_links=[
        {"label": "Built with anycoder", "url": "https://huggingface.co/spaces/akhaliq/anycoder"}
    ],
) as demo:
    # Header section
    with gr.Row():
        gr.Markdown("""
        # Z-Image Turbo
        *Efficient Image Generation with Single-Stream Diffusion Transformer*
        """)

    # Main content area: controls on the left, output gallery on the right.
    with gr.Row():
        with gr.Column(scale=1):
            # Prompt input
            prompt_input = gr.Textbox(
                label="Describe your image",
                placeholder="Enter a detailed description of what you want to generate...",
                lines=3,
            )
            # Settings in accordion
            with gr.Accordion("⚙️ Advanced Settings", open=False):
                with gr.Row():
                    resolution_dropdown = gr.Dropdown(
                        choices=RESOLUTION_SET,
                        value="1024x1024 (1:1)",
                        label="Resolution",
                    )
                    seed_input = gr.Number(
                        label="Seed",
                        value=42,
                        precision=0,
                    )
                    random_seed_check = gr.Checkbox(
                        label="Use random seed",
                        value=True,
                    )
            # Generate button
            generate_btn = gr.Button(
                "Generate Image 🎨",
                variant="primary",
                size="lg",
            )
            # Examples
            gr.Examples(
                examples=EXAMPLE_PROMPTS,
                inputs=prompt_input,
                label="Try these examples:",
            )
        with gr.Column(scale=1):
            # Output gallery
            output_gallery = gr.Gallery(
                label="Generated Images",
                columns=2,
                height=500,
            )

    # Generation handler
    def handle_generation(prompt, resolution, seed, use_random_seed):
        """Validate the prompt, resolve the seed, and run one generation."""
        if not prompt.strip():
            raise gr.Error("Please enter a prompt")
        if use_random_seed:
            actual_seed = random.randint(1, 1000000)
        else:
            # A manual seed of -1 also means "pick a random seed".
            actual_seed = int(seed) if seed != -1 else random.randint(1, 1000000)
        # Generate image (uses the module-level `pipe` loaded at startup).
        image = generate_image(
            pipe=pipe,
            prompt=prompt,
            resolution=resolution,
            seed=actual_seed,
        )
        return [image], str(actual_seed), actual_seed

    generate_btn.click(
        fn=handle_generation,
        inputs=[prompt_input, resolution_dropdown, seed_input, random_seed_check],
        outputs=[output_gallery, gr.Textbox(label="Seed Used"), gr.Number(label="Seed Value")],
        # NOTE(review): `api_visibility` is only in recent gradio releases —
        # confirm the pinned gradio version supports it.
        api_visibility="public",
    )

demo.launch(
    mcp_server=True,
)