import os
import io
import base64
import random
import spaces
import torch
from PIL import Image
from gradio import Server
from fastapi.responses import HTMLResponse
from diffusers import ZImagePipeline, ZImageTransformer2DModel, FlowMatchEulerDiscreteScheduler
from diffusers import AutoencoderKL
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPImageProcessor

# Model repo for the Z-Image pipeline; overridable through the environment.
MODEL_PATH = os.environ.get("MODEL_PATH", "Tongyi-MAI/Z-Image-Turbo")
SAFETY_CHECKER_PATH = "CompVis/stable-diffusion-safety-checker"
# Largest random seed handed out when the caller does not supply one
# (fits in an unsigned 32-bit integer).
MAX_SEED = 2**32 - 1

# ---------------------------------------------------------------------------
# Module-level model loading (runs once at startup, before ZeroGPU kicks in)
# ---------------------------------------------------------------------------
# VAE decodes latents into pixels; loaded in bf16 directly onto the GPU.
vae = AutoencoderKL.from_pretrained(
    MODEL_PATH,
    subfolder="vae",
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)
# The text encoder is a causal LM; trust_remote_code is needed because the
# repo ships custom modeling code alongside the weights.
text_encoder = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    subfolder="text_encoder",
    torch_dtype=torch.bfloat16,
    device_map="cuda",
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_PATH,
    subfolder="tokenizer",
    padding_side="left",
)
# The pipeline is assembled by hand: the scheduler is created fresh per
# request (it is stateful — see generate()), and the transformer is attached
# below after it has been loaded and moved to the GPU.
pipe = ZImagePipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    scheduler=None,
    transformer=None,
)
transformer = ZImageTransformer2DModel.from_pretrained(
    MODEL_PATH,
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)
transformer = transformer.to("cuda")
pipe.transformer = transformer
# NOTE(review): "flash_3" selects the FlashAttention-3 backend — confirm the
# deployment hardware/build actually supports it, or this call will fail.
pipe.transformer.set_attention_backend("flash_3")

# Safety checker: screens generated images; its CLIP feature extractor
# prepares inputs for it.
safety_checker = StableDiffusionSafetyChecker.from_pretrained(SAFETY_CHECKER_PATH)
feature_extractor = CLIPImageProcessor.from_pretrained(SAFETY_CHECKER_PATH)
pipe.safety_checker = safety_checker.to("cuda")
pipe.feature_extractor = feature_extractor
# ---------------------------------------------------------------------------
# Server setup
# ---------------------------------------------------------------------------
app = Server()


@app.get("/", response_class=HTMLResponse)
async def homepage() -> str:
    """Serve index.html from the directory that contains this module."""
    html_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "index.html")
    with open(html_path, "r", encoding="utf-8") as f:
        return f.read()


@spaces.GPU
@app.api(name="generate")
def generate(
    prompt: str,
    width: int = 1024,
    height: int = 1024,
    seed: int = -1,
) -> str:
    """Generate an image from a text prompt.

    Args:
        prompt: Text description of the desired image.
        width: Requested width in pixels; clamped to [256, 2048] and snapped
            down to a multiple of 64.
        height: Requested height; same clamping as ``width``.
        seed: RNG seed. Negative means "pick one at random".

    Returns:
        A JSON string with keys ``image_b64`` (base64-encoded PNG), ``seed``
        (the seed actually used), ``width`` and ``height`` (final dimensions).
    """
    # Local import: the file's top-level import block is outside this span.
    import json

    # Clamp to [256, 2048] and snap down to a multiple of 64.
    width = max(256, min(2048, (width // 64) * 64))
    height = max(256, min(2048, (height // 64) * 64))

    if seed < 0:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator("cuda").manual_seed(seed)

    # Fresh scheduler per call (it's stateful across denoising steps).
    pipe.scheduler = FlowMatchEulerDiscreteScheduler(
        num_train_timesteps=1000,
        shift=3.0,
    )

    result = pipe(
        prompt,
        height=height,
        width=width,
        guidance_scale=0.0,  # Turbo checkpoint: CFG disabled
        num_inference_steps=9,
        generator=generator,
        max_sequence_length=256,
    )
    image = result.images[0]

    # Encode as base64 PNG.
    buf = io.BytesIO()
    image.save(buf, format="PNG")
    b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
    # json.dumps instead of a hand-built f-string: guarantees valid JSON even
    # if a future field needs escaping. compact separators keep the output
    # byte-identical to the previous format.
    return json.dumps(
        {"image_b64": b64, "seed": seed, "width": width, "height": height},
        separators=(",", ":"),
    )


demo = app

if __name__ == "__main__":
    demo.launch(show_error=True)