EDDIE2541 commited on
Commit
5f17680
·
verified ·
1 Parent(s): 5ad7ffe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -97
app.py CHANGED
@@ -2,8 +2,10 @@ import os
2
  import gradio as gr
3
  import torch
4
  import uuid
 
5
  from PIL import Image
6
  from diffusers import AutoPipelineForText2Image, StableDiffusionXLInpaintPipeline, StableDiffusionXLPipeline
 
7
 
8
  # Define global variables
9
  model_id = "stabilityai/stable-diffusion-xl-base-1.0"
@@ -13,14 +15,17 @@ trigger_word = {}
13
  # Load the pretrained model and add LoRAs
14
  pipe = AutoPipelineForText2Image.from_pretrained(model_id, torch_dtype=torch.float16, variant="fp16")
15
  pipe.to("cuda")
16
- pipe.load_lora_weights('lora_weights/lora_dolly_cat.safetensors')
17
- pipe.fuse_lora()
 
 
 
18
 
19
  # Create a dictionary of available LoRAs and their corresponding trigger words
20
  for i in os.scandir('lora_weights'):
21
  if i.name != '.gitignore':
22
  lora_models[i.name] = i.path
23
- trigger_word[i.name] = i.name.split('_')[0] + ' asd dolly cat'
24
 
25
  # Define helper functions
26
  def save_img(image_list, prompt):
@@ -38,103 +43,15 @@ def set_lora_model(lora_name, lora_scale):
38
  pipe.unfuse_lora(True)
39
  pipe.unload_lora_weights()
40
  print(lora_models[lora_name])
41
- pipe.load_lora_weights(lora_models[lora_name])
42
- pipe.fuse_lora(lora_scale=lora_scale)
 
 
 
43
  print('Model swapped')
44
  return trigger_word[lora_name]
45
 
46
- def toggle_freeU(freeU_toggle):
47
- if freeU_toggle:
48
- print('freeU enabled')
49
- pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
50
- else:
51
- print('freeU disabled')
52
- pipe.disable_freeu()
53
-
54
- def generate(
55
- prompt,
56
- guidance_scale,
57
- num_images_per_prompt,
58
- height,
59
- width,
60
- generator_seed,
61
- negative_prompt,
62
- lora_scale
63
- ):
64
- generator = torch.Generator("cuda").manual_seed(generator_seed)
65
- image = pipe(
66
- prompt=prompt,
67
- negative_prompt=negative_prompt,
68
- guidance_scale=guidance_scale,
69
- num_images_per_prompt=num_images_per_prompt,
70
- height=height,
71
- width=width,
72
- num_inference_steps=20,
73
- generator=generator,
74
- cross_attention_kwargs={"scale": lora_scale}
75
- ).images
76
- return image
77
-
78
- # Define the Gradio interface
79
- def main():
80
- with gr.Blocks() as demo:
81
- with gr.Row():
82
- with gr.Column():
83
- gallery = gr.Gallery(
84
- label="Generate",
85
- object_fit="contain", height="512",
86
- interactive=False)
87
- positive_prompt = gr.Textbox(
88
- label="Enter Positive Prompt...",
89
- value= 'qwe dolly cat '
90
- )
91
- negative_prompt = gr.Textbox(
92
- label="Enter Negative Prompt...",
93
- value='worst quality, normal quality, low quality, low res, blurry,less realistic text,mutated, ugly, disgusting, amputation, easynegative, bad-hands-5, watermark, logo, banner, extra digits, cropped, jpeg artifacts, signature, username, error, sketch ,duplicate, ugly, monochrome, horror, geometry,mutation, disgusting'
94
- )
95
- with gr.Row():
96
- lora_model_dropdown = gr.Dropdown(list(lora_models.keys()), label='Select LoRA model',
97
- value='lora_dolly_cat.safetensors')
98
- with gr.Row():
99
- with gr.Column():
100
- guidance_scale = gr.Slider(minimum=0, maximum=15,
101
- value=9.5, label='guidance scale')
102
- lora_scale = gr.Slider(minimum=0.1, maximum=1,
103
- value=1, step=0.01, label='Lora scale')
104
- with gr.Column():
105
- num_images_per_prompt = gr.Slider(minimum=1,
106
- maximum=4, value=1, step=1, label='number of images per prompt')
107
- generator_seed = gr.Slider(minimum=-1,
108
- maximum=100, value=1, step=1, label='generator seed')
109
- with gr.Row():
110
- height = gr.Slider(minimum=512, maximum=2048,
111
- value=1024, label='Image height')
112
- width = gr.Slider(minimum=512, maximum=2048,
113
- value=1024, step=8, label='Image width')
114
- freeu = gr.Checkbox(value=True, label='Toggle FreeU')
115
- with gr.Row():
116
- btn = gr.Button("Generate")
117
- download_btn = gr.Button("Download", visible=False)
118
-
119
- btn.click(generate,
120
- inputs=[positive_prompt,
121
- guidance_scale,
122
- num_images_per_prompt,
123
- height,
124
- width,
125
- generator_seed,
126
- negative_prompt,
127
- lora_scale],
128
- outputs=gallery)
129
- download_btn.click(save_img,
130
- inputs=[gallery,
131
- positive_prompt])
132
- freeu.select(toggle_freeU, freeu)
133
- lora_model_dropdown.select(set_lora_model,
134
- [lora_model_dropdown, lora_scale],
135
- positive_prompt)
136
-
137
- demo.launch(share=True)
138
 
139
  if __name__ == "__main__":
140
  main()
 
2
  import gradio as gr
3
  import torch
4
  import uuid
5
+ import peft
6
  from PIL import Image
7
  from diffusers import AutoPipelineForText2Image, StableDiffusionXLInpaintPipeline, StableDiffusionXLPipeline
8
+ from peft import PeftModel, PeftConfig
9
 
10
  # Define global variables
11
  model_id = "stabilityai/stable-diffusion-xl-base-1.0"
 
15
  # Load the pretrained model and add LoRAs
16
  pipe = AutoPipelineForText2Image.from_pretrained(model_id, torch_dtype=torch.float16, variant="fp16")
17
  pipe.to("cuda")
18
+
19
+ base_model = pipe.model
20
+ peft_config = PeftConfig.from_pretrained('lora_weights/qwe_cat_long.safetensors')
21
+ peft_model = PeftModel.from_pretrained(base_model, peft_config)
22
+ pipe.model = peft_model
23
 
24
  # Create a dictionary of available LoRAs and their corresponding trigger words
25
  for i in os.scandir('lora_weights'):
26
  if i.name != '.gitignore':
27
  lora_models[i.name] = i.path
28
+ trigger_word[i.name] = i.name.split('_')[0] + ' cat bright white fur'
29
 
30
  # Define helper functions
31
  def save_img(image_list, prompt):
 
43
  pipe.unfuse_lora(True)
44
  pipe.unload_lora_weights()
45
  print(lora_models[lora_name])
46
+ peft_config = PeftConfig.from_pretrained(lora_models[lora_name])
47
+ peft_config.lora_scale = lora_scale
48
+ peft_model = PeftModel.from_pretrained(base_model, peft_config)
49
+ pipe.model = peft_model
50
+ pipe.fuse_lora()
51
  print('Model swapped')
52
  return trigger_word[lora_name]
53
 
54
+ # ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
  if __name__ == "__main__":
57
  main()