victor HF Staff commited on
Commit
a02cb2e
·
verified ·
1 Parent(s): 918ad58

Upload /Users/vm/code/image-studio/app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. Users/vm/code/image-studio/app.py +117 -0
Users/vm/code/image-studio/app.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import base64
import io
import json
import os
import random

import spaces
import torch
from PIL import Image
from gradio import Server
from fastapi.responses import HTMLResponse
from diffusers import ZImagePipeline, ZImageTransformer2DModel, FlowMatchEulerDiscreteScheduler
from diffusers import AutoencoderKL
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPImageProcessor
15
+
16
# Model repo to load; overridable via the MODEL_PATH environment variable.
MODEL_PATH = os.environ.get("MODEL_PATH", "Tongyi-MAI/Z-Image-Turbo")
# Repo providing the NSFW safety-checker weights and its CLIP preprocessor.
SAFETY_CHECKER_PATH = "CompVis/stable-diffusion-safety-checker"
# Largest seed value handed to torch.Generator.manual_seed (unsigned 32-bit).
MAX_SEED = 2**32 - 1

# ---------------------------------------------------------------------------
# Module-level model loading (runs once at startup, before ZeroGPU kicks in)
# ---------------------------------------------------------------------------

# VAE: decodes the denoised latents back into pixel space.
vae = AutoencoderKL.from_pretrained(
    MODEL_PATH, subfolder="vae",
    torch_dtype=torch.bfloat16, device_map="cuda",
)

# Causal LM used as the pipeline's text encoder. trust_remote_code=True is
# required because the repo ships custom modeling code.
text_encoder = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH, subfolder="text_encoder",
    torch_dtype=torch.bfloat16, device_map="cuda",
    trust_remote_code=True,
)

# NOTE(review): padding_side="left" presumably keeps prompt tokens flush with
# the end of the sequence for the causal encoder — confirm against model card.
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_PATH, subfolder="tokenizer", padding_side="left",
)

# Assemble the pipeline without scheduler/transformer; the transformer is
# attached just below, and a fresh scheduler is created per request in
# generate() because the scheduler is stateful.
pipe = ZImagePipeline(
    vae=vae, text_encoder=text_encoder,
    tokenizer=tokenizer, scheduler=None, transformer=None,
)

# Diffusion transformer, loaded separately and moved to GPU before being
# attached to the pipeline.
transformer = ZImageTransformer2DModel.from_pretrained(
    MODEL_PATH, subfolder="transformer", torch_dtype=torch.bfloat16,
)
transformer = transformer.to("cuda")
pipe.transformer = transformer

# Select the FlashAttention-3 attention backend for the transformer.
pipe.transformer.set_attention_backend("flash_3")

# Safety checker
# NOTE(review): the checker and feature extractor are attached to the pipe
# here, but generate() below never invokes them explicitly — verify that
# ZImagePipeline actually runs them, otherwise outputs are unfiltered.
safety_checker = StableDiffusionSafetyChecker.from_pretrained(SAFETY_CHECKER_PATH)
feature_extractor = CLIPImageProcessor.from_pretrained(SAFETY_CHECKER_PATH)
pipe.safety_checker = safety_checker.to("cuda")
pipe.feature_extractor = feature_extractor
57
+
58
# ---------------------------------------------------------------------------
# Server setup
# ---------------------------------------------------------------------------

app = Server()


@app.get("/", response_class=HTMLResponse)
async def homepage():
    """Serve the static index.html that sits next to this script."""
    here = os.path.dirname(os.path.abspath(__file__))
    with open(os.path.join(here, "index.html"), "r", encoding="utf-8") as fh:
        return fh.read()
70
+
71
+
72
@spaces.GPU
@app.api(name="generate")
def generate(
    prompt: str,
    width: int = 1024,
    height: int = 1024,
    seed: int = -1,
) -> str:
    """Generate an image from a text prompt.

    Args:
        prompt: Text description of the desired image.
        width: Requested width in pixels; snapped down to a multiple of 64
            and clamped to [256, 2048].
        height: Requested height in pixels; same treatment as ``width``.
        seed: RNG seed. A negative value means "pick one at random"; the
            seed actually used is echoed back in the response.

    Returns:
        A JSON string with keys ``image_b64`` (base64-encoded PNG),
        ``seed``, ``width`` and ``height``.
    """
    # Snap to multiples of 64, then clamp to a sane range (256 and 2048 are
    # themselves multiples of 64, so the clamp preserves the invariant).
    width = max(256, min(2048, (width // 64) * 64))
    height = max(256, min(2048, (height // 64) * 64))

    if seed < 0:
        seed = random.randint(0, MAX_SEED)

    generator = torch.Generator("cuda").manual_seed(seed)

    # Fresh scheduler per call (it's stateful)
    pipe.scheduler = FlowMatchEulerDiscreteScheduler(
        num_train_timesteps=1000, shift=3.0,
    )

    result = pipe(
        prompt,
        height=height,
        width=width,
        guidance_scale=0.0,
        num_inference_steps=9,
        generator=generator,
        max_sequence_length=256,
    )
    image = result.images[0]

    # Encode as base64 PNG
    buf = io.BytesIO()
    image.save(buf, format="PNG")
    b64 = base64.b64encode(buf.getvalue()).decode("utf-8")

    # Serialize with json.dumps rather than hand-rolled f-string
    # interpolation: the f-string would emit invalid JSON if any
    # interpolated value ever contained a quote or backslash, and it
    # silently miscounts braces on edit. json.dumps guarantees
    # well-formed output.
    return json.dumps(
        {"image_b64": b64, "seed": seed, "width": width, "height": height}
    )
112
+
113
+
114
# NOTE(review): presumably the Hugging Face Spaces runtime discovers a
# module-level `demo` object to serve — confirm before renaming this alias.
demo = app

# Local entry point: launch the server directly when run as a script.
if __name__ == "__main__":
    demo.launch(show_error=True)