Oranblock committed on
Commit
9f5148a
·
verified ·
1 Parent(s): d963421

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +298 -32
app.py CHANGED
@@ -1,41 +1,307 @@
 
 
 
 
 
 
 
1
  import gradio as gr
 
 
 
2
  import torch
3
- from diffusers import StableDiffusionPipeline
4
- import uuid
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- # Check for the Model Base
7
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
8
 
 
 
9
  if torch.cuda.is_available():
10
- pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16, use_safetensors=True)
11
- pipe.to(device)
12
- else:
13
- pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
14
-
15
- def generate_stickers(word):
16
- prompt = (f"Create a set of child-friendly stickers based on the word '{word}'. "
17
- "Each sticker should have a black bold outline and fit within a label size of 70 x 35 mm, "
18
- "and be organized into six distinct groups that fit a single A4 sheet with 24 labels per sheet. "
19
- "Ensure the stickers do not overlap and the black outline is clearly visible. "
20
- "The stickers should be fun, colorful, and cartoon-like, suitable for children, with a transparent or white background.")
21
- negative_prompt = ("text, logos, watermarks, out of frame, ugly, extra limbs, bad anatomy, blurry, "
22
- "overlapping stickers, missing outline")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
- width, height = 1024, 768 # Appropriate resolution for detailed stickers
25
- images = pipe(prompt=prompt, negative_prompt=negative_prompt, width=width, height=height, num_inference_steps=50).images
26
- output_paths = []
27
- for image in images:
28
- filename = f"{uuid.uuid4()}.png"
29
- image.save(filename)
30
- output_paths.append(filename)
31
- return output_paths
32
-
33
- iface = gr.Interface(
34
- fn=generate_stickers,
35
- inputs="text",
36
- outputs="image",
37
- description="Enter a word to generate a set of child-friendly stickers that fit Sorex A4 Premium Sticker sheets (70x35mm labels per sheet)."
38
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
  if __name__ == "__main__":
41
- iface.launch()
 
1
+ #!/usr/bin/env python
2
+
3
+ import os
4
+ import random
5
+ import uuid
6
+ import json
7
+
8
  import gradio as gr
9
+ import numpy as np
10
+ from PIL import Image
11
+ import spaces
12
  import torch
13
+ from diffusers import DiffusionPipeline
14
+ from typing import Tuple
15
+
16
+ # Check for the Model Base..//
17
+
18
# Banned-word lists are configurable through environment variables;
# the defaults keep obviously child-unfriendly themes out of prompts.
bad_words = json.loads(os.getenv('BAD_WORDS', '["violence", "blood", "scary"]'))
bad_words_negative = json.loads(os.getenv('BAD_WORDS_NEGATIVE', "[]"))
default_negative = os.getenv("default_negative", "")

def check_text(prompt, negative=""):
    """Return True when *prompt* (or *negative*) contains a banned word.

    Substring match, case-sensitive, against the env-configured lists.
    """
    if any(word in prompt for word in bad_words):
        return True
    return any(word in negative for word in bad_words_negative)
30
+
31
+ # Updated to child-friendly styles
32
+
33
# Child-friendly prompt styles: each entry wraps the user's prompt in a
# themed template and supplies a matching negative prompt.
style_list = [
    {
        "name": "Cartoon",
        "prompt": "colorful cartoon {prompt}. vibrant, playful, friendly, suitable for children, highly detailed, bright colors",
        "negative_prompt": "scary, dark, violent, ugly, realistic",
    },
    {
        "name": "Children's Illustration",
        "prompt": "children's illustration {prompt}. cute, colorful, fun, simple shapes, smooth lines, highly detailed, joyful",
        "negative_prompt": "scary, dark, violent, deformed, ugly",
    },
    {
        "name": "Sticker",
        "prompt": "children's sticker of {prompt}. bright colors, playful, high resolution, cartoonish",
        "negative_prompt": "scary, dark, violent, ugly, low resolution",
    },
    {
        "name": "Fantasy",
        "prompt": "fantasy world for children with {prompt}. magical, vibrant, friendly, beautiful, colorful",
        "negative_prompt": "dark, scary, violent, ugly, realistic",
    },
    {
        "name": "(No style)",
        "prompt": "{prompt}",
        "negative_prompt": "",
    },
]

styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
STYLE_NAMES = list(styles.keys())
DEFAULT_STYLE_NAME = "Sticker"

def apply_style(style_name: str, positive: str, negative: str = "") -> Tuple[str, str]:
    """Expand *positive* with the chosen style template.

    Returns a (prompt, negative_prompt) pair. Unknown style names fall
    back to DEFAULT_STYLE_NAME. The style's own negative prompt and the
    caller-supplied one are joined with ", " — the previous plain
    concatenation glued them together with no separator.
    """
    p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
    if not negative:
        # Normalize None / empty to "" so joining below is safe.
        negative = ""
    combined_negative = ", ".join(part for part in (n, negative) if part)
    return p.replace("{prompt}", positive), combined_negative
74
+
75
# Markdown shown at the top of the Gradio page.
DESCRIPTION = """## Children's Sticker Generator

Generate fun and playful stickers for children using AI.
"""

# Surface a warning in the UI when no GPU is available.
if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>⚠️Running on CPU, This may not work on CPU.</p>"

# Upper bound for the seed slider / random seed generation.
MAX_SEED = np.iinfo(np.int32).max
# Deployment knobs, all overridable via environment variables.
CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "0") == "1"
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "1024"))
USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Images produced per pipeline call.
# NOTE(review): the UI "Images" slider is not wired to this constant.
NUM_IMAGES_PER_PROMPT = 1
92
+
93
# NOTE(review): on a CPU-only host this whole branch is skipped, so `pipe`
# and `pipe2` are never defined and generate() would raise NameError.
if torch.cuda.is_available():
    # Two checkpoints are loaded; generate() runs both per request and
    # concatenates their outputs.
    pipe = DiffusionPipeline.from_pretrained(
        "SG161222/RealVisXL_V3.0_Turbo",
        torch_dtype=torch.float16,
        use_safetensors=True,
        add_watermarker=False,
        variant="fp16"
    )
    pipe2 = DiffusionPipeline.from_pretrained(
        "SG161222/RealVisXL_V2.02_Turbo",
        torch_dtype=torch.float16,
        use_safetensors=True,
        add_watermarker=False,
        variant="fp16"
    )
    if ENABLE_CPU_OFFLOAD:
        # Trade throughput for VRAM: model parts move to GPU on demand.
        pipe.enable_model_cpu_offload()
        pipe2.enable_model_cpu_offload()
    else:
        pipe.to(device)
        pipe2.to(device)
        print("Loaded on Device!")

    if USE_TORCH_COMPILE:
        # One-time compilation cost for faster UNet inference afterwards.
        pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
        pipe2.unet = torch.compile(pipe2.unet, mode="reduce-overhead", fullgraph=True)
        print("Model Compiled!")
120
+
121
def save_image(img):
    """Write *img* to the working directory under a fresh UUID filename.

    Returns the generated file name so callers can hand it to the gallery.
    """
    target = f"{uuid.uuid4()}.png"
    img.save(target)
    return target
125
+
126
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Pick a fresh random seed in [0, MAX_SEED] when requested, else pass *seed* through."""
    return random.randint(0, MAX_SEED) if randomize_seed else seed
130
+
131
+ @spaces.GPU(enable_queue=True)
132
+ def generate(
133
+ prompt: str,
134
+ negative_prompt: str = "",
135
+ use_negative_prompt: bool = False,
136
+ style: str = DEFAULT_STYLE_NAME,
137
+ seed: int = 0,
138
+ width: int = 512,
139
+ height: int = 512,
140
+ guidance_scale: float = 3,
141
+ randomize_seed: bool = False,
142
+ use_resolution_binning: bool = True,
143
+ progress=gr.Progress(track_tqdm=True),
144
+ ):
145
+ if check_text(prompt, negative_prompt):
146
+ raise ValueError("Prompt contains restricted words.")
147
+
148
+ prompt, negative_prompt = apply_style(style, prompt, negative_prompt)
149
+ seed = int(randomize_seed_fn(seed, randomize_seed))
150
+ generator = torch.Generator().manual_seed(seed)
151
+
152
+ if not use_negative_prompt:
153
+ negative_prompt = "" # type: ignore
154
+ negative_prompt += default_negative
155
+
156
+ options = {
157
+ "prompt": prompt,
158
+ "negative_prompt": negative_prompt,
159
+ "width": width,
160
+ "height": height,
161
+ "guidance_scale": guidance_scale,
162
+ "num_inference_steps": 25,
163
+ "generator": generator,
164
+ "num_images_per_prompt": NUM_IMAGES_PER_PROMPT,
165
+ "use_resolution_binning": use_resolution_binning,
166
+ "output_type": "pil",
167
+ }
168
 
169
+ images = pipe(**options).images + pipe2(**options).images
170
+
171
+ image_paths = [save_image(img) for img in images]
172
+ return image_paths, seed
173
+
174
# Example prompts offered beneath the input box.
examples = [
    "A cute cartoon bunny holding a carrot in a colorful garden",
    "A playful dragon flying through the clouds, bright and friendly",
    "A magical unicorn standing on a rainbow with sparkles",
]

# Page-level CSS: narrow the container and center the title.
css = '''
.gradio-container{max-width: 700px !important}
h1{text-align:center}
'''
184
+
185
# Gradio UI: prompt input, advanced options accordion, and event wiring.
with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
        visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
    )
    with gr.Group():
        with gr.Row():
            prompt = gr.Text(
                label="Enter your prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter a fun, child-friendly idea (e.g., cute bunny with a rainbow)",
                container=False,
            )
            run_button = gr.Button("Run")
        result = gr.Gallery(label="Generated Stickers", columns=1, preview=True)
    with gr.Accordion("Advanced options", open=False):
        use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=True, visible=True)
        negative_prompt = gr.Text(
            label="Negative prompt",
            max_lines=1,
            placeholder="Enter a negative prompt",
            value="(scary, violent, deformed, ugly, dark)",
            visible=True,
        )
        with gr.Row():
            # NOTE(review): this slider is not in the gr.on() inputs below,
            # so it has no effect — generate() fixes steps at 25.
            num_inference_steps = gr.Slider(
                label="Steps",
                minimum=10,
                maximum=60,
                step=1,
                value=25,
            )
        with gr.Row():
            # NOTE(review): also unwired — generate() uses the module-level
            # NUM_IMAGES_PER_PROMPT constant instead.
            num_images_per_prompt = gr.Slider(
                label="Images",
                minimum=1,
                maximum=5,
                step=1,
                value=2,
            )
        seed = gr.Slider(
            label="Seed",
            minimum=0,
            maximum=MAX_SEED,
            step=1,
            value=0,
            visible=True
        )
        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
        with gr.Row(visible=True):
            width = gr.Slider(
                label="Width",
                minimum=512,
                maximum=1024,
                step=8,
                value=512,
            )
            height = gr.Slider(
                label="Height",
                minimum=512,
                maximum=1024,
                step=8,
                value=512,
            )
        with gr.Row():
            # NOTE(review): UI default is 7 while generate()'s own default is 3;
            # the slider value is what actually gets passed.
            guidance_scale = gr.Slider(
                label="Guidance Scale",
                minimum=0.1,
                maximum=20.0,
                step=0.1,
                value=7,
            )
        with gr.Row(visible=True):
            style_selection = gr.Radio(
                show_label=True,
                container=True,
                interactive=True,
                choices=STYLE_NAMES,
                value=DEFAULT_STYLE_NAME,
                label="Sticker Style",
            )
    gr.Examples(
        examples=examples,
        inputs=prompt,
        outputs=[result, seed],
        fn=generate,
        cache_examples=CACHE_EXAMPLES,
    )

    # Toggle the negative-prompt textbox visibility with the checkbox.
    use_negative_prompt.change(
        fn=lambda x: gr.update(visible=x),
        inputs=use_negative_prompt,
        outputs=negative_prompt,
        api_name=False,
    )

    # Run generation on Enter in either textbox or on the Run button.
    gr.on(
        triggers=[
            prompt.submit,
            negative_prompt.submit,
            run_button.click,
        ],
        fn=generate,
        inputs=[
            prompt,
            negative_prompt,
            use_negative_prompt,
            style_selection,
            seed,
            width,
            height,
            guidance_scale,
            randomize_seed,
        ],
        outputs=[result, seed],
        api_name="run",
    )
305
 
306
if __name__ == "__main__":
    # Serve the app; cap the request queue at 20 pending jobs.
    demo.queue(max_size=20).launch()