Aklavya committed on
Commit
3f98338
·
verified ·
1 Parent(s): 3b0f256

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +289 -96
app.py CHANGED
@@ -1,104 +1,297 @@
 
 
 
1
  import gradio as gr
 
 
 
2
  import torch
3
- from diffusers import DALL_E
4
- from PIL import Image as PILImage
5
- import concurrent.futures
6
-
7
- # Model cache to avoid reloading the model multiple times
8
- model_cache = {}
9
-
10
- def load_model():
11
- model_name = "dalle-mini/dalle-mini" # Using DALL·E Mini model
12
- # Check if the model is already cached to avoid reloading every time
13
- if model_name in model_cache:
14
- return model_cache[model_name]
15
-
16
- print(f"Loading model: {model_name}")
17
- try:
18
- # Select device (CPU only for ZeroGPU plan)
19
- device = "cpu" # Set to CPU, as you don't have GPU access
20
-
21
- # Load the model with float32 (since float16 is not supported on CPU)
22
- model = DALL_E.from_pretrained(model_name, torch_dtype=torch.float32)
23
- model.to(device)
24
-
25
- # Cache the model for future use
26
- model_cache[model_name] = model
27
- print("Model loaded successfully.")
28
- return model
29
- except Exception as e:
30
- print(f"Error loading model: {e}")
31
- return None
32
-
33
- # Function to generate the image with a timeout
34
- def generate_image_with_timeout(prompt):
35
- timeout = 180 # Timeout after 180 seconds
36
-
37
- try:
38
- # Use ThreadPoolExecutor to handle the timeout
39
- with concurrent.futures.ThreadPoolExecutor() as executor:
40
- future = executor.submit(generate_image, prompt)
41
- return future.result(timeout=timeout) # Will raise TimeoutError if the process exceeds timeout
42
-
43
- except concurrent.futures.TimeoutError:
44
- return "Error: The image generation timed out. Please try again."
45
-
46
- # Function to generate the image
47
- def generate_image(prompt):
48
- model = load_model()
49
-
50
- if model is None:
51
- return "Error loading the model."
52
-
53
- try:
54
- # Generate the image from the prompt
55
- with torch.no_grad():
56
- output = model(prompt)
57
- image = output.images[0] # Assuming the first image is the one we need
58
- image = PILImage.fromarray(image) # Convert to PIL image format for Gradio
59
- return image
60
- except Exception as e:
61
- print(f"Error generating image: {e}")
62
- return "Error generating the image."
63
-
64
- # Define the Gradio interface using gr.Blocks
65
- def create_gradio_interface():
66
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
67
- gr.Markdown("""
68
- <h1 style="
69
- text-align: center;
70
- color: white;
71
- font-weight: bold;
72
- text-transform: uppercase;
73
- text-decoration: underline;
74
- margin-top: 30px;
75
- font-family: 'Arial', sans-serif;
76
- background: linear-gradient(45deg, #ff6b6b, #f06595);
77
- padding: 10px 20px;
78
- border-radius: 15px;
79
- box-shadow: 0px 4px 8px rgba(0, 0, 0, 0.3);
80
- ">
81
- SNAPSCRIBE
82
- </h1>
83
- """)
84
 
85
- with gr.Row():
86
- with gr.Column(scale=3, min_width=300): # Changed scale to integer
87
- prompt_input = gr.Textbox(label="Enter your prompt here", placeholder="e.g., A futuristic city skyline")
88
- submit_button = gr.Button("Generate Image")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
- with gr.Column(scale=7, min_width=600): # Changed scale to integer
91
- output_image = gr.Image(label="Generated Image", height=640)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
- submit_button.click(fn=generate_image_with_timeout, inputs=[prompt_input], outputs=output_image)
 
 
 
 
94
 
95
- gr.Markdown("""
96
- <div style="position: relative; left: 0; bottom: 0; width: 100%; background-color: #0B0F19; color: white; text-align: center; padding: 10px 0;">
97
- <p>Developed with ❤ by Aklavya (Bucky)</p>
98
- </div>
99
- """)
 
100
 
101
- demo.launch() # Removed `share=True`
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
- # Launch the Gradio interface
104
- create_gradio_interface()
 
1
+ import os
2
+ import random
3
+ import uuid
4
  import gradio as gr
5
+ import numpy as np
6
+ from PIL import Image
7
+ import spaces
8
  import torch
9
+ from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
10
+ from typing import Tuple
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
# Page-level CSS: narrow the app column, center the h1 title, and hide
# Gradio's default footer.
css = '''
.gradio-container{max-width: 575px !important}
h1{text-align:center}
footer {
visibility: hidden
}
'''

# Markdown heading rendered at the top of the app.
DESCRIPTIONXX = """## TEXT 2 IMAGE🥠"""

# One-click example prompts shown under the prompt box.
examples = [
    "A tiny astronaut hatching from an egg on the moon, 4k, planet theme, --style raw5 --v 6.0",
    "An anime-style illustration of a delicious, golden-brown wiener schnitzel on a plate, served with fresh lemon slices, parsley --style raw5",
    "Cold coffee in a cup bokeh --ar 85:128 --v 6.0 --style raw5, 4K, Photo-Realistic",
    "A cat holding a sign that says hello world --ar 85:128 --v 6.0 --style raw"
]

# Dropdown label -> Hugging Face model repository id.
MODEL_OPTIONS = {
    "LIGHTNING V5.0": "SG161222/RealVisXL_V5.0_Lightning",
    "LIGHTNING V4.0": "SG161222/RealVisXL_V4.0_Lightning",
}

# Deployment knobs, overridable via environment variables.
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))  # upper bound of the width/height sliders
USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"  # opt-in torch.compile of the pipelines
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"  # opt-in model CPU offload
BATCH_SIZE = int(os.getenv("BATCH_SIZE", "1"))  # images generated per pipeline call

# Prefer the first CUDA device when available; otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
40
+
41
# Quality-style templates: each wraps the user prompt in a prompt/negative
# pair; "Style Zero" passes the prompt through untouched.
style_list = [
    {
        "name": "3840 x 2160",
        "prompt": "hyper-realistic 8K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
        "negative_prompt": "cartoonish, low resolution, blurry, simplistic, abstract, deformed, ugly",
    },
    {
        "name": "2560 x 1440",
        "prompt": "hyper-realistic 4K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
        "negative_prompt": "cartoonish, low resolution, blurry, simplistic, abstract, deformed, ugly",
    },
    {
        "name": "HD+",
        "prompt": "hyper-realistic 2K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
        "negative_prompt": "cartoonish, low resolution, blurry, simplistic, abstract, deformed, ugly",
    },
    {
        "name": "Style Zero",
        "prompt": "{prompt}",
        "negative_prompt": "",
    },
]

# name -> (prompt template, negative prompt) for O(1) lookup.
styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
DEFAULT_STYLE_NAME = "3840 x 2160"
STYLE_NAMES = list(styles.keys())


def apply_style(style_name: str, positive: str, negative: str = "") -> Tuple[str, str]:
    """Expand *positive* through the named style template.

    Unknown style names fall back to ``DEFAULT_STYLE_NAME``.  The style's
    own negative prompt is concatenated with the caller's *negative* text
    (which may be ``None`` or empty).

    Returns:
        Tuple of (expanded positive prompt, combined negative prompt).
    """
    # styles.get with a default already covers unknown names; the original
    # wrapped this in a redundant `if style_name in styles` / `else` pair
    # that produced the same result on both branches.
    p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
    return p.replace("{prompt}", positive), n + (negative or "")
78
+
79
+
80
def load_and_prepare_model(model_id):
    """Load an SDXL checkpoint and prepare it for inference.

    Loads fp16 weights on CUDA (fp32 on CPU), installs an Euler-Ancestral
    scheduler, and enables attention slicing to lower peak memory.

    Args:
        model_id: Hugging Face repository id of the SDXL checkpoint.

    Returns:
        A ready-to-call StableDiffusionXLPipeline on ``device``.
    """
    pipe = StableDiffusionXLPipeline.from_pretrained(
        model_id,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        use_safetensors=True,
        add_watermarker=False,
    ).to(device)
    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)

    # Trade a little speed for a lower peak-memory footprint.
    pipe.enable_attention_slicing()

    # Optionally compile the UNet (the hot path).  Diffusers pipelines have
    # no ``compile()`` method, so the previous ``pipe.compile()`` call raised
    # AttributeError whenever USE_TORCH_COMPILE=1.
    if USE_TORCH_COMPILE:
        pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)

    # CPU offload only makes sense when there is a GPU to offload from.
    if ENABLE_CPU_OFFLOAD and device.type == "cuda":
        pipe.enable_model_cpu_offload()

    return pipe
101
+
102
+
103
# Eagerly load every model at import time so switching models in the UI is
# instant; keyed by the dropdown label.  NOTE(review): this holds both SDXL
# checkpoints in memory at once — confirm the host has room for two.
models = {key: load_and_prepare_model(value) for key, value in MODEL_OPTIONS.items()}

# Upper bound for the seed slider: largest signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max
107
+
108
+
109
def save_image(img):
    """Write *img* to the working directory under a fresh UUID-based name.

    Returns the generated filename so callers (the Gradio gallery) can
    reference the file on disk.
    """
    filename = f"{uuid.uuid4()}.png"
    img.save(filename)
    return filename
113
+
114
+
115
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Return a fresh random seed in [0, MAX_SEED] when *randomize_seed* is
    set; otherwise pass *seed* through unchanged."""
    return random.randint(0, MAX_SEED) if randomize_seed else seed
119
+
120
+
121
@spaces.GPU(duration=60, enable_queue=True)
def generate(
    model_choice: str,
    prompt: str,
    negative_prompt: str = "",
    use_negative_prompt: bool = False,
    style_selection: str = DEFAULT_STYLE_NAME,
    seed: int = 1,
    width: int = 1024,
    height: int = 1024,
    guidance_scale: float = 3,
    num_inference_steps: int = 25,
    randomize_seed: bool = False,
    num_images: int = 1,
    use_resolution_binning: bool = True,
    progress=gr.Progress(track_tqdm=True),
):
    """Run the selected SDXL pipeline and return (image file paths, seed used).

    The parameters mirror the Gradio controls and are bound positionally by
    ``gr.on``, so their order must match the UI ``inputs`` list.

    NOTE(review): ``num_images`` was moved ahead of ``use_resolution_binning``
    so the 12 positional UI inputs line up with the first 12 parameters;
    previously the "Number of Images" slider was silently bound to
    ``use_resolution_binning`` and ``num_images`` always stayed 1.
    """
    pipe = models[model_choice]

    seed = int(randomize_seed_fn(seed, randomize_seed))
    generator = torch.Generator(device=device).manual_seed(seed)

    prompt, negative_prompt = apply_style(style_selection, prompt, negative_prompt)

    options = {
        "prompt": [prompt] * num_images,
        "width": width,
        "height": height,
        "guidance_scale": guidance_scale,
        "num_inference_steps": num_inference_steps,
        "generator": generator,
        "output_type": "pil",
    }
    # Only include a negative prompt when requested.  Previously the key was
    # always present (set to None when unused) and then sliced below, raising
    # TypeError ("'NoneType' object is not subscriptable") whenever
    # use_negative_prompt was False.
    if use_negative_prompt:
        options["negative_prompt"] = [negative_prompt] * num_images

    if use_resolution_binning:
        options["use_resolution_binning"] = True

    # Generate in chunks of BATCH_SIZE so large num_images doesn't exhaust
    # GPU memory in a single pipeline call.
    images = []
    for i in range(0, num_images, BATCH_SIZE):
        batch_options = options.copy()
        batch_options["prompt"] = options["prompt"][i:i + BATCH_SIZE]
        if "negative_prompt" in batch_options:
            batch_options["negative_prompt"] = options["negative_prompt"][i:i + BATCH_SIZE]
        # autocast is a no-op (with a warning) on CPU-only hosts.
        with torch.cuda.amp.autocast():
            images.extend(pipe(**batch_options).images)

    image_paths = [save_image(img) for img in images]
    return image_paths, seed
171
+
172
+
173
# ---- UI definition -------------------------------------------------------
with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
    gr.Markdown(DESCRIPTIONXX)
    # Prompt box and Run button share one row.
    with gr.Row():
        prompt = gr.Text(
            label="Prompt",
            show_label=False,
            max_lines=1,
            placeholder="Enter your prompt",
            container=False,
        )
        run_button = gr.Button("Run", scale=0)
    # Generated images are displayed here as a single-column gallery.
    result = gr.Gallery(label="Result", columns=1, show_label=False)

    with gr.Row():
        model_choice = gr.Dropdown(
            label="Model Selection⬇️",
            choices=list(MODEL_OPTIONS.keys()),
            value="LIGHTNING V5.0"
        )

    # Advanced controls are constructed but hidden (visible=False); their
    # values still feed generate() through the gr.on wiring below.
    with gr.Accordion("Advanced options", open=False, visible=False):
        style_selection = gr.Radio(
            show_label=True,
            container=True,
            interactive=True,
            choices=STYLE_NAMES,
            value=DEFAULT_STYLE_NAME,
            label="Quality Style",
        )
        num_images = gr.Slider(
            label="Number of Images",
            minimum=1,
            maximum=5,
            step=1,
            value=1,
        )
        with gr.Row():
            with gr.Column(scale=1):
                use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=True)
                negative_prompt = gr.Text(
                    label="Negative prompt",
                    max_lines=5,
                    lines=4,
                    placeholder="Enter a negative prompt",
                    value="(deformed, distorted, disfigured:1.3), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, (mutated hands and fingers:1.4), disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation",
                    visible=True,
                )
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=0,
                )
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
        with gr.Row():
            width = gr.Slider(
                label="Width",
                minimum=512,
                maximum=MAX_IMAGE_SIZE,
                step=8,
                value=1024,
            )
            height = gr.Slider(
                label="Height",
                minimum=512,
                maximum=MAX_IMAGE_SIZE,
                step=8,
                value=1024,
            )
        with gr.Row():
            guidance_scale = gr.Slider(
                label="Guidance Scale",
                minimum=0.1,
                maximum=6,
                step=0.1,
                value=3.0,
            )
            num_inference_steps = gr.Slider(
                label="Number of inference steps",
                minimum=1,
                maximum=60,
                step=1,
                value=28,
            )

    gr.Examples(
        examples=examples,
        inputs=prompt,
        cache_examples=False
    )

    # Show/hide the negative-prompt box when its checkbox is toggled.
    use_negative_prompt.change(
        fn=lambda x: gr.update(visible=x),
        inputs=use_negative_prompt,
        outputs=negative_prompt,
        api_name=False,
    )

    # NOTE(review): these inputs are bound to generate() positionally —
    # keep this list's order in sync with generate's parameter order.
    gr.on(
        triggers=[
            prompt.submit,
            negative_prompt.submit,
            run_button.click,
        ],
        fn=generate,
        inputs=[
            model_choice,
            prompt,
            negative_prompt,
            use_negative_prompt,
            style_selection,
            seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
            randomize_seed,
            num_images,
        ],
        outputs=[result, seed],
    )
295
 
296
if __name__ == "__main__":
    # Queue incoming requests (up to 50 waiting) and expose the API docs page.
    demo.queue(max_size=50).launch(show_api=True)