SpawnedShoyo commited on
Commit
985bb46
·
verified ·
1 Parent(s): 8365100

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +204 -0
app.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import uuid
4
+
5
+ import gradio as gr
6
+ import numpy as np
7
+ from PIL import Image
8
+ import spaces
9
+ import torch
10
+ from diffusers import DiffusionPipeline
11
+
12
+ DESCRIPTION = """# Playground v2.5"""
13
+ if not torch.cuda.is_available():
14
+ DESCRIPTION += "\n<p>Running on CPU 🥶 This demo may not work on CPU.</p>"
15
+
16
+ MAX_SEED = np.iinfo(np.int32).max
17
+ CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "1") == "1"
18
+ MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "1536"))
19
+ USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
20
+ ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
21
+
22
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
23
+
24
+ NUM_IMAGES_PER_PROMPT = 1
25
+
26
+ if torch.cuda.is_available():
27
+ pipe = DiffusionPipeline.from_pretrained(
28
+ "playgroundai/playground-v2.5-1024px-aesthetic",
29
+ torch_dtype=torch.float16,
30
+ use_safetensors=True,
31
+ add_watermarker=False,
32
+ variant="fp16"
33
+ )
34
+ if ENABLE_CPU_OFFLOAD:
35
+ pipe.enable_model_cpu_offload()
36
+ else:
37
+ pipe.to(device)
38
+ print("Loaded on Device!")
39
+
40
+ if USE_TORCH_COMPILE:
41
+ pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
42
+ print("Model Compiled!")
43
+
44
+
45
+ def save_image(img):
46
+ unique_name = str(uuid.uuid4()) + ".png"
47
+ img.save(unique_name)
48
+ return unique_name
49
+
50
+
51
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
52
+ if randomize_seed:
53
+ seed = random.randint(0, MAX_SEED)
54
+ return seed
55
+
56
+
57
+ @spaces.GPU(enable_queue=True)
58
+ def generate(
59
+ prompt: str,
60
+ negative_prompt: str = "",
61
+ use_negative_prompt: bool = False,
62
+ seed: int = 0,
63
+ width: int = 1024,
64
+ height: int = 1024,
65
+ guidance_scale: float = 3,
66
+ randomize_seed: bool = False,
67
+ use_resolution_binning: bool = True,
68
+ progress=gr.Progress(track_tqdm=True),
69
+ ):
70
+ pipe.to(device)
71
+ seed = int(randomize_seed_fn(seed, randomize_seed))
72
+ generator = torch.Generator().manual_seed(seed)
73
+
74
+ if not use_negative_prompt:
75
+ negative_prompt = None # type: ignore
76
+
77
+ images = pipe(
78
+ prompt=prompt,
79
+ negative_prompt=negative_prompt,
80
+ width=width,
81
+ height=height,
82
+ guidance_scale=guidance_scale,
83
+ num_inference_steps=25,
84
+ generator=generator,
85
+ num_images_per_prompt=NUM_IMAGES_PER_PROMPT,
86
+ use_resolution_binning=use_resolution_binning,
87
+ output_type="pil",
88
+ ).images
89
+
90
+ image_paths = [save_image(img) for img in images]
91
+ print(image_paths)
92
+ return image_paths, seed
93
+
94
+
95
+ examples = [
96
+ "neon holography crystal cat",
97
+ "a cat eating a piece of cheese",
98
+ "an astronaut riding a horse in space",
99
+ "a cartoon of a boy playing with a tiger",
100
+ "a cute robot artist painting on an easel, concept art",
101
+ "a close up of a woman wearing a transparent, prismatic, elaborate nemeses headdress, over the should pose, brown skin-tone"
102
+ ]
103
+
104
+ css = '''
105
+ .gradio-container{max-width: 560px !important}
106
+ h1{text-align:center}
107
+ '''
108
+ with gr.Blocks(css=css) as demo:
109
+ gr.Markdown(DESCRIPTION)
110
+ gr.DuplicateButton(
111
+ value="Duplicate Space for private use",
112
+ elem_id="duplicate-button",
113
+ visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
114
+ )
115
+ with gr.Group():
116
+ with gr.Row():
117
+ prompt = gr.Text(
118
+ label="Prompt",
119
+ show_label=False,
120
+ max_lines=1,
121
+ placeholder="Enter your prompt",
122
+ container=False,
123
+ )
124
+ run_button = gr.Button("Run", scale=0)
125
+ result = gr.Gallery(label="Result", columns=NUM_IMAGES_PER_PROMPT, show_label=False)
126
+ with gr.Accordion("Advanced options", open=False):
127
+ with gr.Row():
128
+ use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=False)
129
+ negative_prompt = gr.Text(
130
+ label="Negative prompt",
131
+ max_lines=1,
132
+ placeholder="Enter a negative prompt",
133
+ visible=True,
134
+ )
135
+ seed = gr.Slider(
136
+ label="Seed",
137
+ minimum=0,
138
+ maximum=MAX_SEED,
139
+ step=1,
140
+ value=0,
141
+ )
142
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
143
+ with gr.Row(visible=True):
144
+ width = gr.Slider(
145
+ label="Width",
146
+ minimum=256,
147
+ maximum=MAX_IMAGE_SIZE,
148
+ step=32,
149
+ value=1024,
150
+ )
151
+ height = gr.Slider(
152
+ label="Height",
153
+ minimum=256,
154
+ maximum=MAX_IMAGE_SIZE,
155
+ step=32,
156
+ value=1024,
157
+ )
158
+ with gr.Row():
159
+ guidance_scale = gr.Slider(
160
+ label="Guidance Scale",
161
+ minimum=0.1,
162
+ maximum=20,
163
+ step=0.1,
164
+ value=3.0,
165
+ )
166
+
167
+ gr.Examples(
168
+ examples=examples,
169
+ inputs=prompt,
170
+ outputs=[result, seed],
171
+ fn=generate,
172
+ cache_examples=CACHE_EXAMPLES,
173
+ )
174
+
175
+ use_negative_prompt.change(
176
+ fn=lambda x: gr.update(visible=x),
177
+ inputs=use_negative_prompt,
178
+ outputs=negative_prompt,
179
+ api_name=False,
180
+ )
181
+
182
+ gr.on(
183
+ triggers=[
184
+ prompt.submit,
185
+ negative_prompt.submit,
186
+ run_button.click,
187
+ ],
188
+ fn=generate,
189
+ inputs=[
190
+ prompt,
191
+ negative_prompt,
192
+ use_negative_prompt,
193
+ seed,
194
+ width,
195
+ height,
196
+ guidance_scale,
197
+ randomize_seed,
198
+ ],
199
+ outputs=[result, seed],
200
+ api_name="run",
201
+ )
202
+
203
+ if __name__ == "__main__":
204
+ demo.queue(max_size=20).launch()