import os
import io
import base64
import random
import spaces
import torch
from PIL import Image
import gradio as gr
from fastapi import FastAPI
from fastapi.responses import HTMLResponse
from diffusers import ZImagePipeline, ZImageTransformer2DModel, FlowMatchEulerDiscreteScheduler
from diffusers import AutoencoderKL
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPImageProcessor

MODEL_PATH = os.environ.get("MODEL_PATH", "Tongyi-MAI/Z-Image-Turbo")
SAFETY_CHECKER_PATH = "CompVis/stable-diffusion-safety-checker"
MAX_SEED = 2**32 - 1

# ---------------------------------------------------------------------------
# Module-level model loading (runs once at startup, before ZeroGPU kicks in)
# ---------------------------------------------------------------------------
vae = AutoencoderKL.from_pretrained(
    MODEL_PATH, subfolder="vae",
    torch_dtype=torch.bfloat16, device_map="cuda",
)
text_encoder = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH, subfolder="text_encoder",
    torch_dtype=torch.bfloat16, device_map="cuda",
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_PATH, subfolder="tokenizer", padding_side="left",
)
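# (padding_side="left" is presumably set because a causal LM acts as the text
# encoder: left padding keeps the real prompt tokens at the end of the
# sequence, where the final hidden states are read. That rationale is an
# assumption, not something this listing states.)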

pipe = ZImagePipeline(
    vae=vae, text_encoder=text_encoder,
    tokenizer=tokenizer, scheduler=None, transformer=None,
)
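# The pipeline is assembled with transformer=None and scheduler=None on
# purpose: the transformer is attached right below, and a fresh scheduler is
# created per request inside generate().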
transformer = ZImageTransformer2DModel.from_pretrained(
    MODEL_PATH, subfolder="transformer", torch_dtype=torch.bfloat16,
)
transformer = transformer.to("cuda")
pipe.transformer = transformer
pipe.transformer.set_attention_backend("flash_3")
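# "flash_3" selects FlashAttention-3 kernels through diffusers' attention
# backend dispatch; this assumes a diffusers version recent enough to ship
# set_attention_backend and a GPU with FA3 support (e.g. the Hopper-class
# hardware ZeroGPU provides). If that assumption fails, dropping this call
# falls back to the default backend.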

# Safety checker
safety_checker = StableDiffusionSafetyChecker.from_pretrained(SAFETY_CHECKER_PATH)
feature_extractor = CLIPImageProcessor.from_pretrained(SAFETY_CHECKER_PATH)
pipe.safety_checker = safety_checker.to("cuda")
pipe.feature_extractor = feature_extractor
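# Note: the checker is loaded and moved to GPU but never invoked inside
# generate() in this listing. A hypothetical wiring (names and shapes assumed,
# not taken from the source) would look roughly like:
#   inputs = feature_extractor(image, return_tensors="pt").to("cuda")
#   _, has_nsfw = safety_checker(images=np.array(image)[None], clip_input=inputs.pixel_values)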

# ---------------------------------------------------------------------------
# Server setup
# ---------------------------------------------------------------------------
app = FastAPI()

@app.get("/", response_class=HTMLResponse)
async def homepage():
    # Serve the static frontend that lives next to this script
    html_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "index.html")
    with open(html_path, "r", encoding="utf-8") as f:
        return f.read()

@spaces.GPU
def generate(
    prompt: str,
    width: int = 1024,
    height: int = 1024,
    seed: int = -1,
) -> str:
| """Generate an image from a text prompt. Returns base64-encoded PNG.""" | |
| # Clamp to multiples of 64 | |
| width = max(256, min(2048, (width // 64) * 64)) | |
| height = max(256, min(2048, (height // 64) * 64)) | |
| if seed < 0: | |
| seed = random.randint(0, MAX_SEED) | |
    generator = torch.Generator("cuda").manual_seed(seed)
    # Fresh scheduler per call (it's stateful)
    pipe.scheduler = FlowMatchEulerDiscreteScheduler(
        num_train_timesteps=1000, shift=3.0,
    )
    result = pipe(
        prompt,
        height=height,
        width=width,
        guidance_scale=0.0,  # Turbo checkpoint: distilled for guidance-free, few-step sampling
        num_inference_steps=9,
        generator=generator,
        max_sequence_length=256,
    )
    image = result.images[0]
    # Encode as base64 PNG
    buf = io.BytesIO()
    image.save(buf, format="PNG")
    b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
    return f'{{"image_b64":"{b64}","seed":{seed},"width":{width},"height":{height}}}'

# Mount a minimal Gradio wrapper under /gradio so `generate` is reachable as an API
demo = gr.Interface(fn=generate, inputs=["text", "number", "number", "number"], outputs="text")
app = gr.mount_gradio_app(app, demo, path="/gradio")

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
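# Example client call, a sketch assuming the /gradio mount and the default
# Interface endpoint name; none of this appears in the original listing:
#   from gradio_client import Client
#   client = Client("http://localhost:7860/gradio")
#   out = client.predict("a watercolor fox", 1024, 1024, -1, api_name="/predict")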