# Hugging Face Space: Z-Image-Turbo text-to-image demo.
# (Page-scrape artifacts removed: Spaces banner, runtime-error status, file size.)
import os
import io
import base64
import random
import spaces
import torch
from PIL import Image
from gradio import Server
from fastapi.responses import HTMLResponse
from diffusers import ZImagePipeline, ZImageTransformer2DModel, FlowMatchEulerDiscreteScheduler
from diffusers import AutoencoderKL
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPImageProcessor
# Model repo on the Hub; override with the MODEL_PATH env var for a local checkout.
MODEL_PATH = os.environ.get("MODEL_PATH", "Tongyi-MAI/Z-Image-Turbo")
# NSFW classifier weights (the checker used by Stable Diffusion pipelines).
SAFETY_CHECKER_PATH = "CompVis/stable-diffusion-safety-checker"
# Largest accepted RNG seed (fits in an unsigned 32-bit integer).
MAX_SEED = 2**32 - 1
# ---------------------------------------------------------------------------
# Module-level model loading (runs once at startup, before ZeroGPU kicks in)
# ---------------------------------------------------------------------------
# VAE decodes latents to pixels; loaded directly onto the GPU in bfloat16.
vae = AutoencoderKL.from_pretrained(
MODEL_PATH, subfolder="vae",
torch_dtype=torch.bfloat16, device_map="cuda",
)
# Causal-LM text encoder; trust_remote_code because the repo ships custom code.
text_encoder = AutoModelForCausalLM.from_pretrained(
MODEL_PATH, subfolder="text_encoder",
torch_dtype=torch.bfloat16, device_map="cuda",
trust_remote_code=True,
)
# Left padding so prompt tokens sit at the end of the context window.
# NOTE(review): presumably what this text encoder expects — confirm upstream.
tokenizer = AutoTokenizer.from_pretrained(
MODEL_PATH, subfolder="tokenizer", padding_side="left",
)
# Pipeline is assembled incrementally: scheduler is created fresh per request
# (see generate()), and the transformer is attached just below.
pipe = ZImagePipeline(
vae=vae, text_encoder=text_encoder,
tokenizer=tokenizer, scheduler=None, transformer=None,
)
transformer = ZImageTransformer2DModel.from_pretrained(
MODEL_PATH, subfolder="transformer", torch_dtype=torch.bfloat16,
)
transformer = transformer.to("cuda")
pipe.transformer = transformer
# Select the FlashAttention-3 backend for the transformer's attention layers.
pipe.transformer.set_attention_backend("flash_3")
# Safety checker: classifier on GPU, its CLIP preprocessor on the pipeline.
safety_checker = StableDiffusionSafetyChecker.from_pretrained(SAFETY_CHECKER_PATH)
feature_extractor = CLIPImageProcessor.from_pretrained(SAFETY_CHECKER_PATH)
pipe.safety_checker = safety_checker.to("cuda")
pipe.feature_extractor = feature_extractor
# ---------------------------------------------------------------------------
# Server setup
# ---------------------------------------------------------------------------
# Gradio Server instance; routes are registered below via the
# @app.get / @app.api decorators.
app = Server()
@app.get("/", response_class=HTMLResponse)
async def homepage():
    """Serve the static front-end page (index.html) shipped next to this script."""
    app_dir = os.path.dirname(os.path.abspath(__file__))
    page_path = os.path.join(app_dir, "index.html")
    with open(page_path, "r", encoding="utf-8") as fh:
        return fh.read()
@spaces.GPU
@app.api(name="generate")
def generate(
    prompt: str,
    width: int = 1024,
    height: int = 1024,
    seed: int = -1,
) -> str:
    """Generate an image from a text prompt.

    Args:
        prompt: Text description of the desired image.
        width: Requested width; snapped down to a multiple of 64 and
            clamped to [256, 2048].
        height: Requested height; same snapping/clamping as width.
        seed: RNG seed; a negative value means "pick one at random".

    Returns:
        A JSON string with keys "image_b64" (base64-encoded PNG),
        "seed", "width", and "height".
    """
    import json  # local import: avoids touching the module-level import block

    # Clamp to multiples of 64 within the model's supported resolution range.
    width = max(256, min(2048, (width // 64) * 64))
    height = max(256, min(2048, (height // 64) * 64))
    if seed < 0:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator("cuda").manual_seed(seed)
    # Fresh scheduler per call — FlowMatchEulerDiscreteScheduler is stateful.
    pipe.scheduler = FlowMatchEulerDiscreteScheduler(
        num_train_timesteps=1000, shift=3.0,
    )
    result = pipe(
        prompt,
        height=height,
        width=width,
        guidance_scale=0.0,  # Turbo model: classifier-free guidance disabled
        num_inference_steps=9,
        generator=generator,
        max_sequence_length=256,
    )
    image = result.images[0]
    # Encode as base64 PNG.
    buf = io.BytesIO()
    image.save(buf, format="PNG")
    b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
    # Bug fix: build the response with json.dumps instead of a hand-rolled
    # f-string — the f-string produced invalid JSON as soon as any field
    # needed escaping, and json.dumps guarantees well-formed output.
    return json.dumps(
        {"image_b64": b64, "seed": seed, "width": width, "height": height}
    )
# Module-level alias — NOTE(review): presumably the Spaces runtime looks for
# a `demo` object by convention; confirm against the deployment docs.
demo = app
if __name__ == "__main__":
    # show_error surfaces server-side exceptions in the client UI.
    demo.launch(show_error=True)