import os
import io
import json
import base64
import random

import spaces
import torch
from PIL import Image
from gradio import Server
from fastapi.responses import HTMLResponse
from diffusers import (
    AutoencoderKL,
    FlowMatchEulerDiscreteScheduler,
    ZImagePipeline,
    ZImageTransformer2DModel,
)
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPImageProcessor

MODEL_PATH = os.environ.get("MODEL_PATH", "Tongyi-MAI/Z-Image-Turbo")
SAFETY_CHECKER_PATH = "CompVis/stable-diffusion-safety-checker"
MAX_SEED = 2**32 - 1

# ---------------------------------------------------------------------------
# Module-level model loading (runs once at startup, before ZeroGPU kicks in)
# ---------------------------------------------------------------------------

vae = AutoencoderKL.from_pretrained(
    MODEL_PATH, subfolder="vae",
    torch_dtype=torch.bfloat16, device_map="cuda",
)

text_encoder = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH, subfolder="text_encoder",
    torch_dtype=torch.bfloat16, device_map="cuda",
    trust_remote_code=True,
)

tokenizer = AutoTokenizer.from_pretrained(
    MODEL_PATH, subfolder="tokenizer", padding_side="left",
)

# Assemble the pipeline without a transformer or scheduler: the transformer
# is attached below, and a fresh scheduler is created per request.
pipe = ZImagePipeline(
    vae=vae, text_encoder=text_encoder,
    tokenizer=tokenizer, scheduler=None, transformer=None,
)

transformer = ZImageTransformer2DModel.from_pretrained(
    MODEL_PATH, subfolder="transformer", torch_dtype=torch.bfloat16,
).to("cuda")
pipe.transformer = transformer

# FlashAttention-3 backend (requires a flash-attn v3 build on the host GPU)
pipe.transformer.set_attention_backend("flash_3")

# Safety checker
safety_checker = StableDiffusionSafetyChecker.from_pretrained(SAFETY_CHECKER_PATH)
feature_extractor = CLIPImageProcessor.from_pretrained(SAFETY_CHECKER_PATH)
pipe.safety_checker = safety_checker.to("cuda")
pipe.feature_extractor = feature_extractor
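
# Hedged sketch: attaching safety_checker/feature_extractor to the pipeline
# above is not, on its own, known to make ZImagePipeline run them. The helper
# below shows one way to invoke the checker explicitly; `check_nsfw` is our
# own name, not a diffusers API, and it must run inside the @spaces.GPU
# context since the checker lives on CUDA.
def check_nsfw(image: Image.Image) -> bool:
    """Return True if the safety checker flags the image as NSFW."""
    import numpy as np  # local import keeps this sketch self-contained

    clip_input = feature_extractor(image, return_tensors="pt").to("cuda")
    np_image = np.array(image)[None].astype(np.float32) / 255.0
    _, has_nsfw = safety_checker(images=np_image, clip_input=clip_input.pixel_values)
    return bool(has_nsfw[0])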

# ---------------------------------------------------------------------------
# Server setup
# ---------------------------------------------------------------------------

app = Server()


@app.get("/", response_class=HTMLResponse)
async def homepage():
    html_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "index.html")
    with open(html_path, "r", encoding="utf-8") as f:
        return f.read()


@spaces.GPU
@app.api(name="generate")
def generate(
    prompt: str,
    width: int = 1024,
    height: int = 1024,
    seed: int = -1,
) -> str:
    """Generate an image from a text prompt. Returns base64-encoded PNG."""
    # Clamp to multiples of 64
    width = max(256, min(2048, (width // 64) * 64))
    height = max(256, min(2048, (height // 64) * 64))

    # seed < 0 means "pick one at random"; the chosen seed is echoed back
    # in the response so results can be reproduced
    if seed < 0:
        seed = random.randint(0, MAX_SEED)

    generator = torch.Generator("cuda").manual_seed(seed)

    # Fresh scheduler per call (it's stateful)
    pipe.scheduler = FlowMatchEulerDiscreteScheduler(
        num_train_timesteps=1000, shift=3.0,
    )

    # Turbo checkpoint: a handful of inference steps with guidance disabled
    result = pipe(
        prompt,
        height=height,
        width=width,
        guidance_scale=0.0,
        num_inference_steps=9,
        generator=generator,
        max_sequence_length=256,
    )
    image = result.images[0]
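
    # Hedged: ZImagePipeline is not documented to call the attached
    # safety_checker itself, so invoke the sketch helper defined above.
    # The blank-image fallback mirrors the Stable Diffusion convention.
    if check_nsfw(image):
        image = Image.new("RGB", (width, height))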

    # Encode as a base64 PNG and wrap in a JSON payload
    buf = io.BytesIO()
    image.save(buf, format="PNG")
    b64 = base64.b64encode(buf.getvalue()).decode("utf-8")

    return json.dumps(
        {"image_b64": b64, "seed": seed, "width": width, "height": height}
    )


demo = app
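
# Hedged usage sketch: decoding the payload returned by generate(); the
# direct call assumes @spaces.GPU leaves the function callable in-process.
#
#   payload = json.loads(generate("a watercolor fox", seed=42))
#   with open("out.png", "wb") as f:
#       f.write(base64.b64decode(payload["image_b64"]))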

if __name__ == "__main__":
    demo.launch(show_error=True)