AbstractQbit commited on
Commit
de797f2
·
1 Parent(s): fc2e79d
Files changed (1) hide show
  1. app.py +52 -5
app.py CHANGED
@@ -3,13 +3,16 @@ import numpy as np
3
  import random
4
 
5
  # import spaces #[uncomment to use ZeroGPU]
6
- from diffusers import DiffusionPipeline
 
7
  import torch
8
 
9
  device = "cuda" if torch.cuda.is_available() \
10
  else "xpu" if torch.xpu.is_available() \
11
  else "cpu"
12
  current_model_repo_id = "stabilityai/sdxl-turbo" # Replace to the model you would like to use
 
 
13
 
14
  if torch.cuda.is_available() or torch.xpu.is_available():
15
  torch_dtype = torch.float16
@@ -32,22 +35,39 @@ def clean_vram():
32
  def infer(
33
  prompt,
34
  model_repo,
 
 
35
  negative_prompt,
36
  seed,
37
  randomize_seed,
38
  width,
39
  height,
40
  guidance_scale,
 
41
  num_inference_steps,
42
  progress=gr.Progress(track_tqdm=True),
43
  ):
44
- global current_model_repo_id, pipe
45
 
46
- if model_repo != current_model_repo_id:
47
- print(f"The model changed to {model_repo}, reloading pipeline...")
 
 
 
 
 
 
48
  del pipe
49
  clean_vram()
 
50
  pipe = DiffusionPipeline.from_pretrained(model_repo, torch_dtype=torch_dtype).to(device)
 
 
 
 
 
 
 
51
 
52
  if randomize_seed:
53
  seed = random.randint(0, MAX_SEED)
@@ -58,6 +78,7 @@ def infer(
58
  prompt=prompt,
59
  negative_prompt=negative_prompt,
60
  guidance_scale=guidance_scale,
 
61
  num_inference_steps=num_inference_steps,
62
  width=width,
63
  height=height,
@@ -88,7 +109,7 @@ with gr.Blocks(css=css) as demo:
88
 
89
  model_repo = gr.Dropdown(
90
  label="Model repository path",
91
- choices=["stabilityai/sdxl-turbo", "CompVis/stable-diffusion-v1-4"],
92
  allow_custom_value=True
93
  )
94
 
@@ -148,6 +169,14 @@ with gr.Blocks(css=css) as demo:
148
  step=0.1,
149
  value=0.0, # Replace with defaults that work for your model
150
  )
 
 
 
 
 
 
 
 
151
 
152
  num_inference_steps = gr.Slider(
153
  label="Number of inference steps",
@@ -157,6 +186,21 @@ with gr.Blocks(css=css) as demo:
157
  value=2, # Replace with defaults that work for your model
158
  )
159
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
  gr.Examples(examples=examples, inputs=[prompt])
161
  gr.on(
162
  triggers=[run_button.click, prompt.submit],
@@ -164,12 +208,15 @@ with gr.Blocks(css=css) as demo:
164
  inputs=[
165
  prompt,
166
  model_repo,
 
 
167
  negative_prompt,
168
  seed,
169
  randomize_seed,
170
  width,
171
  height,
172
  guidance_scale,
 
173
  num_inference_steps,
174
  ],
175
  outputs=[result, seed],
 
3
  import random
4
 
5
  # import spaces #[uncomment to use ZeroGPU]
6
+ from diffusers import DiffusionPipeline, AutoPipelineForText2Image
7
+ from peft import PeftModel
8
  import torch
9
 
10
  device = "cuda" if torch.cuda.is_available() \
11
  else "xpu" if torch.xpu.is_available() \
12
  else "cpu"
13
  current_model_repo_id = "stabilityai/sdxl-turbo" # Replace to the model you would like to use
14
+ current_lora_repo = None
15
+ current_lora_scale = 1.0
16
 
17
  if torch.cuda.is_available() or torch.xpu.is_available():
18
  torch_dtype = torch.float16
 
35
  def infer(
36
  prompt,
37
  model_repo,
38
+ lora_repo,
39
+ lora_scale,
40
  negative_prompt,
41
  seed,
42
  randomize_seed,
43
  width,
44
  height,
45
  guidance_scale,
46
+ pag_scale,
47
  num_inference_steps,
48
  progress=gr.Progress(track_tqdm=True),
49
  ):
50
+ global current_model_repo_id, current_lora_repo, current_lora_scale, pipe
51
 
52
+ if lora_repo == "None":
53
+ lora_repo = None
54
+
55
+ if (model_repo != current_model_repo_id) or (lora_repo != current_lora_repo) or (current_lora_scale != lora_scale):
56
+ print(f"The model changed to {model_repo}, {lora_repo} lora, reloading pipeline...")
57
+ current_model_repo_id = model_repo
58
+ current_lora_repo = lora_repo
59
+ current_lora_scale = lora_scale
60
  del pipe
61
  clean_vram()
62
+
63
  pipe = DiffusionPipeline.from_pretrained(model_repo, torch_dtype=torch_dtype).to(device)
64
+ if lora_repo:
65
+ pipe.unet = PeftModel.from_pretrained(pipe.unet, lora_repo, subfolder="unet").to(device)
66
+ pipe.text_encoder = PeftModel.from_pretrained(pipe.text_encoder, lora_repo, subfolder="text_encoder").to(device)
67
+ pipe.unet.load_state_dict({k: lora_scale*v if 'lora' in k else v for k, v in pipe.unet.state_dict().items()})
68
+ pipe.text_encoder.load_state_dict({k: lora_scale*v if 'lora' in k else v for k, v in pipe.text_encoder.state_dict().items()})
69
+
70
+ pipe = AutoPipelineForText2Image.from_pipe(pipe, enable_pag=True)
71
 
72
  if randomize_seed:
73
  seed = random.randint(0, MAX_SEED)
 
78
  prompt=prompt,
79
  negative_prompt=negative_prompt,
80
  guidance_scale=guidance_scale,
81
+ pag_scale=pag_scale,
82
  num_inference_steps=num_inference_steps,
83
  width=width,
84
  height=height,
 
109
 
110
  model_repo = gr.Dropdown(
111
  label="Model repository path",
112
+ choices=["stabilityai/sdxl-turbo", "CompVis/stable-diffusion-v1-4", "stable-diffusion-v1-5/stable-diffusion-v1-5"],
113
  allow_custom_value=True
114
  )
115
 
 
169
  step=0.1,
170
  value=0.0, # Replace with defaults that work for your model
171
  )
172
+
173
+ pag_scale = gr.Slider(
174
+ label="PAG scale",
175
+ minimum=0.0,
176
+ maximum=10.0,
177
+ step=0.1,
178
+ value=0.0, # Replace with defaults that work for your model
179
+ )
180
 
181
  num_inference_steps = gr.Slider(
182
  label="Number of inference steps",
 
186
  value=2, # Replace with defaults that work for your model
187
  )
188
 
189
+ with gr.Row():
190
+ lora_repo = gr.Dropdown(
191
+ label="LoRA repository path",
192
+ choices=["None", "AbstractQbit/biskvit_cat_lora"],
193
+ allow_custom_value=True
194
+ )
195
+
196
+ lora_scale = gr.Slider(
197
+ label="LoRA scale",
198
+ minimum=0.0,
199
+ maximum=1.0,
200
+ step=0.1,
201
+ value=1.0, # Replace with defaults that work for your model
202
+ )
203
+
204
  gr.Examples(examples=examples, inputs=[prompt])
205
  gr.on(
206
  triggers=[run_button.click, prompt.submit],
 
208
  inputs=[
209
  prompt,
210
  model_repo,
211
+ lora_repo,
212
+ lora_scale,
213
  negative_prompt,
214
  seed,
215
  randomize_seed,
216
  width,
217
  height,
218
  guidance_scale,
219
+ pag_scale,
220
  num_inference_steps,
221
  ],
222
  outputs=[result, seed],