# z-image-studio / app.py
import os
import io
import json
import base64
import random

# On ZeroGPU Spaces, `spaces` should be imported before torch / CUDA libraries.
import spaces
import torch
from fastapi.responses import HTMLResponse
from gradio import Server
from transformers import AutoModelForCausalLM, AutoTokenizer
from diffusers import (
    AutoencoderKL,
    FlowMatchEulerDiscreteScheduler,
    ZImagePipeline,
    ZImageTransformer2DModel,
)
MODEL_PATH = "Tongyi-MAI/Z-Image-Turbo"
MAX_SEED = 2**32 - 1
# ---------------------------------------------------------------------------
# Module-level model loading (component-by-component, matching reference space)
# ---------------------------------------------------------------------------
print("Loading VAE...")
vae = AutoencoderKL.from_pretrained(
    MODEL_PATH,
    subfolder="vae",
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)
print("Loading text encoder...")
text_encoder = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    subfolder="text_encoder",
    torch_dtype=torch.bfloat16,
    device_map="cuda",
    trust_remote_code=True,
)
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_PATH,
    subfolder="tokenizer",
    padding_side="left",
)
print("Building pipeline...")
# Build the pipeline shell first; the transformer is attached right after
# loading, and the scheduler is set per call in run_inference.
pipe = ZImagePipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    scheduler=None,
    transformer=None,
)
print("Loading transformer...")
transformer = ZImageTransformer2DModel.from_pretrained(
    MODEL_PATH,
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)
pipe.transformer = transformer.to("cuda")
print("All models loaded!")
# ---------------------------------------------------------------------------
# Server
# ---------------------------------------------------------------------------
app = Server()
@app.get("/", response_class=HTMLResponse)
async def homepage():
    """Serve the static front-end (index.html) that ships alongside this file."""
    html_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "index.html")
    with open(html_path, "r", encoding="utf-8") as f:
        return f.read()
@spaces.GPU(duration=30)
def run_inference(prompt, width, height, seed):
    """Run the diffusion pipeline on GPU."""
    generator = torch.Generator("cuda").manual_seed(seed)
    # Recreate the scheduler per call so every request starts from a clean state.
    pipe.scheduler = FlowMatchEulerDiscreteScheduler(
        num_train_timesteps=1000,
        shift=3.0,
    )
    image = pipe(
        prompt=prompt,
        height=height,
        width=width,
        guidance_scale=0.0,  # Turbo (distilled) checkpoints run without CFG
        num_inference_steps=9,
        generator=generator,
        max_sequence_length=256,
    ).images[0]
    return image
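
# A rough speed/quality knob (a sketch, untested here): distilled "Turbo"
# checkpoints usually run with guidance_scale=0.0 and single-digit step counts,
# so lowering num_inference_steps trades detail for latency inside the
# 30-second GPU window, e.g.:
#
#   image = pipe(prompt=prompt, height=height, width=width, guidance_scale=0.0,
#                num_inference_steps=4, generator=generator).images[0]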
@app.api(name="generate")
def generate(
    prompt: str,
    width: int = 1024,
    height: int = 1024,
    seed: int = -1,
) -> str:
    """Generate an image from a text prompt.

    Returns a JSON string containing a base64-encoded PNG plus the resolved
    seed and dimensions.
    """
    # Snap dimensions to a 64-pixel grid and clamp them to [512, 2048].
    width = max(512, min(2048, (width // 64) * 64))
    height = max(512, min(2048, (height // 64) * 64))
    # A negative seed means "pick one at random" (and report it back).
    if seed < 0:
        seed = random.randint(0, MAX_SEED)
    image = run_inference(prompt, width, height, seed)
    buf = io.BytesIO()
    image.save(buf, format="PNG")
    b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
    return json.dumps({"image_b64": b64, "seed": seed, "width": width, "height": height})
# Hugging Face Spaces (Gradio SDK) looks for a module-level `demo` object.
demo = app
if __name__ == "__main__":
    demo.launch(show_error=True, ssr_mode=False)