# victor's picture
# victor HF Staff
# Upload /Users/vm/code/image-studio/app.py with huggingface_hub
# a02cb2e verified
import base64
import io
import json
import os
import random

import spaces
import torch
from diffusers import AutoencoderKL
from diffusers import ZImagePipeline, ZImageTransformer2DModel, FlowMatchEulerDiscreteScheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from fastapi.responses import HTMLResponse
from gradio import Server
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPImageProcessor
MODEL_PATH = os.environ.get("MODEL_PATH", "Tongyi-MAI/Z-Image-Turbo")
SAFETY_CHECKER_PATH = "CompVis/stable-diffusion-safety-checker"
MAX_SEED = 2**32 - 1
# ---------------------------------------------------------------------------
# Module-level model loading (runs once at startup, before ZeroGPU kicks in)
# ---------------------------------------------------------------------------
vae = AutoencoderKL.from_pretrained(
MODEL_PATH, subfolder="vae",
torch_dtype=torch.bfloat16, device_map="cuda",
)
text_encoder = AutoModelForCausalLM.from_pretrained(
MODEL_PATH, subfolder="text_encoder",
torch_dtype=torch.bfloat16, device_map="cuda",
trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(
MODEL_PATH, subfolder="tokenizer", padding_side="left",
)
pipe = ZImagePipeline(
vae=vae, text_encoder=text_encoder,
tokenizer=tokenizer, scheduler=None, transformer=None,
)
transformer = ZImageTransformer2DModel.from_pretrained(
MODEL_PATH, subfolder="transformer", torch_dtype=torch.bfloat16,
)
transformer = transformer.to("cuda")
pipe.transformer = transformer
pipe.transformer.set_attention_backend("flash_3")
# Safety checker
safety_checker = StableDiffusionSafetyChecker.from_pretrained(SAFETY_CHECKER_PATH)
feature_extractor = CLIPImageProcessor.from_pretrained(SAFETY_CHECKER_PATH)
pipe.safety_checker = safety_checker.to("cuda")
pipe.feature_extractor = feature_extractor
# ---------------------------------------------------------------------------
# Server setup
# ---------------------------------------------------------------------------
app = Server()
@app.get("/", response_class=HTMLResponse)
async def homepage():
html_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "index.html")
with open(html_path, "r", encoding="utf-8") as f:
return f.read()
@spaces.GPU
@app.api(name="generate")
def generate(
prompt: str,
width: int = 1024,
height: int = 1024,
seed: int = -1,
) -> str:
"""Generate an image from a text prompt. Returns base64-encoded PNG."""
# Clamp to multiples of 64
width = max(256, min(2048, (width // 64) * 64))
height = max(256, min(2048, (height // 64) * 64))
if seed < 0:
seed = random.randint(0, MAX_SEED)
generator = torch.Generator("cuda").manual_seed(seed)
# Fresh scheduler per call (it's stateful)
pipe.scheduler = FlowMatchEulerDiscreteScheduler(
num_train_timesteps=1000, shift=3.0,
)
result = pipe(
prompt,
height=height,
width=width,
guidance_scale=0.0,
num_inference_steps=9,
generator=generator,
max_sequence_length=256,
)
image = result.images[0]
# Encode as base64 PNG
buf = io.BytesIO()
image.save(buf, format="PNG")
b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
return f'{{"image_b64":"{b64}","seed":{seed},"width":{width},"height":{height}}}'
demo = app
if __name__ == "__main__":
demo.launch(show_error=True)