Geek7 commited on
Commit
8b65f01
·
verified ·
1 Parent(s): 1368b5a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -217
app.py CHANGED
@@ -1,219 +1,6 @@
1
- from diffusers import AutoPipelineForImage2Image, AutoPipelineForText2Image
2
- import torch
3
- import os
4
 
5
- try:
6
- import intel_extension_for_pytorch as ipex
7
- except:
8
- pass
9
 
10
- from PIL import Image
11
- import numpy as np
12
- import gradio as gr
13
- import psutil
14
- import time
15
- import math
16
- from transformers.utils.hub import move_cache
17
-
18
- move_cache()
19
-
20
- SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", None)
21
- TORCH_COMPILE = os.environ.get("TORCH_COMPILE", None)
22
- HF_TOKEN = os.environ.get("HF_TOKEN", None)
23
- # check if MPS is available OSX only M1/M2/M3 chips
24
- mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
25
- xpu_available = hasattr(torch, "xpu") and torch.xpu.is_available()
26
- device = torch.device(
27
- "cuda" if torch.cuda.is_available() else "xpu" if xpu_available else "cpu"
28
- )
29
- torch_device = device
30
- torch_dtype = torch.float16
31
-
32
- print(f"SAFETY_CHECKER: {SAFETY_CHECKER}")
33
- print(f"TORCH_COMPILE: {TORCH_COMPILE}")
34
- print(f"device: {device}")
35
-
36
- if mps_available:
37
- device = torch.device("mps")
38
- torch_device = "cpu"
39
- torch_dtype = torch.float32
40
-
41
- if SAFETY_CHECKER == "True":
42
- i2i_pipe = AutoPipelineForImage2Image.from_pretrained(
43
- "stabilityai/sdxl-turbo",
44
- torch_dtype=torch_dtype,
45
- variant="fp16" if torch_dtype == torch.float16 else "fp32",
46
- )
47
- t2i_pipe = AutoPipelineForText2Image.from_pretrained(
48
- "stabilityai/sdxl-turbo",
49
- torch_dtype=torch_dtype,
50
- variant="fp16" if torch_dtype == torch.float16 else "fp32",
51
- )
52
- else:
53
- i2i_pipe = AutoPipelineForImage2Image.from_pretrained(
54
- "stabilityai/sdxl-turbo",
55
- safety_checker=None,
56
- torch_dtype=torch_dtype,
57
- variant="fp16" if torch_dtype == torch.float16 else "fp32",
58
- )
59
- t2i_pipe = AutoPipelineForText2Image.from_pretrained(
60
- "stabilityai/sdxl-turbo",
61
- safety_checker=None,
62
- torch_dtype=torch_dtype,
63
- variant="fp16" if torch_dtype == torch.float16 else "fp32",
64
- )
65
-
66
-
67
- t2i_pipe.to(device=torch_device, dtype=torch_dtype).to(device)
68
- t2i_pipe.set_progress_bar_config(disable=True)
69
- i2i_pipe.to(device=torch_device, dtype=torch_dtype).to(device)
70
- i2i_pipe.set_progress_bar_config(disable=True)
71
-
72
-
73
- def resize_crop(image, size=512):
74
- image = image.convert("RGB")
75
- w, h = image.size
76
- image = image.resize((size, int(size * (h / w))), Image.BICUBIC)
77
- return image
78
-
79
-
80
- async def predict(init_image, prompt, strength, steps, seed=1231231):
81
- if init_image is not None:
82
- init_image = resize_crop(init_image)
83
- generator = torch.manual_seed(seed)
84
- last_time = time.time()
85
-
86
- if int(steps * strength) < 1:
87
- steps = math.ceil(1 / max(0.10, strength))
88
-
89
- results = i2i_pipe(
90
- prompt=prompt,
91
- image=init_image,
92
- generator=generator,
93
- num_inference_steps=steps,
94
- guidance_scale=0.0,
95
- strength=strength,
96
- width=512,
97
- height=512,
98
- output_type="pil",
99
- )
100
- else:
101
- generator = torch.manual_seed(seed)
102
- last_time = time.time()
103
- results = t2i_pipe(
104
- prompt=prompt,
105
- generator=generator,
106
- num_inference_steps=steps,
107
- guidance_scale=0.0,
108
- width=512,
109
- height=512,
110
- output_type="pil",
111
- )
112
- print(f"Pipe took {time.time() - last_time} seconds")
113
- nsfw_content_detected = (
114
- results.nsfw_content_detected[0]
115
- if "nsfw_content_detected" in results
116
- else False
117
- )
118
- if nsfw_content_detected:
119
- gr.Warning("NSFW content detected.")
120
- return Image.new("RGB", (512, 512))
121
- return results.images[0]
122
-
123
-
124
- css = """
125
- #container{
126
- margin: 0 auto;
127
- max-width: 80rem;
128
- }
129
- #intro{
130
- max-width: 100%;
131
- text-align: center;
132
- margin: 0 auto;
133
- }
134
- """
135
- with gr.Blocks(css=css) as demo:
136
- init_image_state = gr.State()
137
- with gr.Column(elem_id="container"):
138
- gr.Markdown(
139
- """# SDXL Turbo Image to Image/Text to Image
140
- ## Unofficial Demo
141
- SDXL Turbo model can generate high quality images in a single pass read more on [stability.ai post](https://stability.ai/news/stability-ai-sdxl-turbo).
142
- **Model**: https://huggingface.co/stabilityai/sdxl-turbo
143
- """,
144
- elem_id="intro",
145
- )
146
- with gr.Row():
147
- prompt = gr.Textbox(
148
- placeholder="Insert your prompt here:",
149
- scale=5,
150
- container=False,
151
- )
152
- generate_bt = gr.Button("Generate", scale=1)
153
- with gr.Row():
154
- with gr.Column():
155
- image_input = gr.Image(
156
- sources=["upload", "webcam", "clipboard"],
157
- label="Webcam",
158
- type="pil",
159
- )
160
- with gr.Column():
161
- image = gr.Image(type="filepath")
162
- with gr.Accordion("Advanced options", open=False):
163
- strength = gr.Slider(
164
- label="Strength",
165
- value=0.7,
166
- minimum=0.0,
167
- maximum=1.0,
168
- step=0.001,
169
- )
170
- steps = gr.Slider(
171
- label="Steps", value=2, minimum=1, maximum=10, step=1
172
- )
173
- seed = gr.Slider(
174
- randomize=True,
175
- minimum=0,
176
- maximum=12013012031030,
177
- label="Seed",
178
- step=1,
179
- )
180
-
181
- with gr.Accordion("Run with diffusers"):
182
- gr.Markdown(
183
- """## Running SDXL Turbo with `diffusers`
184
- ```bash
185
- pip install diffusers==0.23.1
186
- ```
187
- ```py
188
- from diffusers import DiffusionPipeline
189
-
190
- pipe = DiffusionPipeline.from_pretrained(
191
- "stabilityai/sdxl-turbo"
192
- ).to("cuda")
193
- results = pipe(
194
- prompt="A cinematic shot of a baby racoon wearing an intricate italian priest robe",
195
- num_inference_steps=1,
196
- guidance_scale=0.0,
197
- )
198
- imga = results.images[0]
199
- imga.save("image.png")
200
- ```
201
- """
202
- )
203
-
204
- inputs = [image_input, prompt, strength, steps, seed]
205
- generate_bt.click(fn=predict, inputs=inputs, outputs=image, show_progress=False)
206
- prompt.change(fn=predict, inputs=inputs, outputs=image, show_progress=False)
207
- steps.change(fn=predict, inputs=inputs, outputs=image, show_progress=False)
208
- seed.change(fn=predict, inputs=inputs, outputs=image, show_progress=False)
209
- strength.change(fn=predict, inputs=inputs, outputs=image, show_progress=False)
210
- image_input.change(
211
- fn=lambda x: x,
212
- inputs=image_input,
213
- outputs=init_image_state,
214
- show_progress=False,
215
- queue=False,
216
- )
217
-
218
- demo.queue()
219
- demo.launch()
 
1
+ from diffusers import DiffusionPipeline
 
 
2
 
3
+ pipe = DiffusionPipeline.from_pretrained("prompthero/openjourney-v4")
 
 
 
4
 
5
+ prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
6
+ image = pipe(prompt).images[0]