akaUNik commited on
Commit
35880ab
·
verified ·
1 Parent(s): 0178d80
Files changed (1) hide show
  1. app.py +47 -15
app.py CHANGED
@@ -2,16 +2,18 @@ import gradio as gr
2
  import numpy as np
3
  import random
4
 
5
- # import spaces #[uncomment to use ZeroGPU]
6
  from diffusers import DiffusionPipeline
7
  import torch
8
 
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
10
  MODEL_LIST = [
11
  "CompVis/stable-diffusion-v1-4",
12
  "stabilityai/sdxl-turbo",
13
  "runwayml/stable-diffusion-v1-5",
14
  "stabilityai/stable-diffusion-2-1",
 
15
  ]
16
 
17
  if torch.cuda.is_available():
@@ -19,24 +21,36 @@ if torch.cuda.is_available():
19
  else:
20
  torch_dtype = torch.float32
21
 
22
- # To avoid re-initializing pipelines repeatedly, we can cache them:
23
  model_cache = {}
24
 
25
  def load_pipeline(model_id: str):
26
- """Loads or retrieves a cached DiffusionPipeline."""
 
 
 
 
 
27
  if model_id in model_cache:
28
  return model_cache[model_id]
 
 
 
 
 
 
 
 
29
  else:
30
  pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch_dtype)
31
- pipe.to(device)
32
- model_cache[model_id] = pipe
33
- return pipe
 
34
 
35
  MAX_SEED = np.iinfo(np.int32).max
36
  MAX_IMAGE_SIZE = 1024
37
 
38
-
39
- # @spaces.GPU #[uncomment to use ZeroGPU]
40
  def infer(
41
  model_id,
42
  prompt,
@@ -47,6 +61,7 @@ def infer(
47
  height,
48
  guidance_scale,
49
  num_inference_steps,
 
50
  progress=gr.Progress(track_tqdm=True),
51
  ):
52
  # Load the pipeline for the chosen model
@@ -55,7 +70,15 @@ def infer(
55
  if randomize_seed:
56
  seed = random.randint(0, MAX_SEED)
57
 
58
- generator = torch.Generator().manual_seed(seed)
 
 
 
 
 
 
 
 
59
 
60
  image = pipe(
61
  prompt=prompt,
@@ -69,7 +92,6 @@ def infer(
69
 
70
  return image, seed
71
 
72
-
73
  examples = [
74
  "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
75
  "An astronaut riding a green horse",
@@ -113,7 +135,6 @@ with gr.Blocks(css=css) as demo:
113
  label="Negative prompt",
114
  max_lines=1,
115
  placeholder="Enter a negative prompt",
116
- # visible=False,
117
  )
118
 
119
  seed = gr.Slider(
@@ -132,7 +153,7 @@ with gr.Blocks(css=css) as demo:
132
  minimum=256,
133
  maximum=MAX_IMAGE_SIZE,
134
  step=32,
135
- value=1024, # Replace with defaults that work for your model
136
  )
137
 
138
  height = gr.Slider(
@@ -140,7 +161,7 @@ with gr.Blocks(css=css) as demo:
140
  minimum=256,
141
  maximum=MAX_IMAGE_SIZE,
142
  step=32,
143
- value=1024, # Replace with defaults that work for your model
144
  )
145
 
146
  with gr.Row():
@@ -149,7 +170,7 @@ with gr.Blocks(css=css) as demo:
149
  minimum=0.0,
150
  maximum=20.0,
151
  step=0.5,
152
- value=7.0, # Common default for SD
153
  )
154
 
155
  num_inference_steps = gr.Slider(
@@ -157,9 +178,19 @@ with gr.Blocks(css=css) as demo:
157
  minimum=1,
158
  maximum=100,
159
  step=1,
160
- value=20, # Common default for SD
161
  )
162
 
 
 
 
 
 
 
 
 
 
 
163
  gr.Examples(examples=examples, inputs=[prompt])
164
  gr.on(
165
  triggers=[run_button.click, prompt.submit],
@@ -174,6 +205,7 @@ with gr.Blocks(css=css) as demo:
174
  height,
175
  guidance_scale,
176
  num_inference_steps,
 
177
  ],
178
  outputs=[result, seed],
179
  )
 
2
  import numpy as np
3
  import random
4
 
 
5
  from diffusers import DiffusionPipeline
6
  import torch
7
 
8
  device = "cuda" if torch.cuda.is_available() else "cpu"
9
+
10
+ # Model list including your LoRA model
11
  MODEL_LIST = [
12
  "CompVis/stable-diffusion-v1-4",
13
  "stabilityai/sdxl-turbo",
14
  "runwayml/stable-diffusion-v1-5",
15
  "stabilityai/stable-diffusion-2-1",
16
+ "akaUNik/hw5-futurama-lora", # Your LoRA model option
17
  ]
18
 
19
  if torch.cuda.is_available():
 
21
  else:
22
  torch_dtype = torch.float32
23
 
24
+ # Cache to avoid re-initializing pipelines repeatedly
25
  model_cache = {}
26
 
27
  def load_pipeline(model_id: str):
28
+ """
29
+ Loads or retrieves a cached DiffusionPipeline.
30
+
31
+ If the chosen model is your LoRA adapter, then load the base model
32
+ (CompVis/stable-diffusion-v1-4) and apply the LoRA weights.
33
+ """
34
  if model_id in model_cache:
35
  return model_cache[model_id]
36
+
37
+ if model_id == "akaUNik/hw5-futurama-lora":
38
+ # Use the specified base model for your LoRA adapter.
39
+ base_model = "CompVis/stable-diffusion-v1-4"
40
+ pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch_dtype)
41
+ # Load the LoRA weights into the U-Net.
42
+ # This assumes that load_attn_procs loads the LoRA weights.
43
+ pipe.unet.load_attn_procs(model_id)
44
  else:
45
  pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch_dtype)
46
+
47
+ pipe.to(device)
48
+ model_cache[model_id] = pipe
49
+ return pipe
50
 
51
  MAX_SEED = np.iinfo(np.int32).max
52
  MAX_IMAGE_SIZE = 1024
53
 
 
 
54
  def infer(
55
  model_id,
56
  prompt,
 
61
  height,
62
  guidance_scale,
63
  num_inference_steps,
64
+ lora_scale, # New parameter for adjusting LoRA scale
65
  progress=gr.Progress(track_tqdm=True),
66
  ):
67
  # Load the pipeline for the chosen model
 
70
  if randomize_seed:
71
  seed = random.randint(0, MAX_SEED)
72
 
73
+ generator = torch.Generator(device=device).manual_seed(seed)
74
+
75
+ # If using the LoRA model, update the LoRA scale if supported.
76
+ if model_id == "akaUNik/hw5-futurama-lora":
77
+ # This assumes your pipeline's unet has a method to update the LoRA scale.
78
+ if hasattr(pipe.unet, "set_lora_scale"):
79
+ pipe.unet.set_lora_scale(lora_scale)
80
+ else:
81
+ print("Warning: LoRA scale adjustment method not found on UNet.")
82
 
83
  image = pipe(
84
  prompt=prompt,
 
92
 
93
  return image, seed
94
 
 
95
  examples = [
96
  "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
97
  "An astronaut riding a green horse",
 
135
  label="Negative prompt",
136
  max_lines=1,
137
  placeholder="Enter a negative prompt",
 
138
  )
139
 
140
  seed = gr.Slider(
 
153
  minimum=256,
154
  maximum=MAX_IMAGE_SIZE,
155
  step=32,
156
+ value=1024,
157
  )
158
 
159
  height = gr.Slider(
 
161
  minimum=256,
162
  maximum=MAX_IMAGE_SIZE,
163
  step=32,
164
+ value=1024,
165
  )
166
 
167
  with gr.Row():
 
170
  minimum=0.0,
171
  maximum=20.0,
172
  step=0.5,
173
+ value=7.0,
174
  )
175
 
176
  num_inference_steps = gr.Slider(
 
178
  minimum=1,
179
  maximum=100,
180
  step=1,
181
+ value=20,
182
  )
183
 
184
+ # New slider for LoRA scale.
185
+ lora_scale = gr.Slider(
186
+ label="LoRA Scale",
187
+ minimum=0.0,
188
+ maximum=2.0,
189
+ step=0.1,
190
+ value=1.0,
191
+ info="Adjust the influence of the LoRA weights",
192
+ )
193
+
194
  gr.Examples(examples=examples, inputs=[prompt])
195
  gr.on(
196
  triggers=[run_button.click, prompt.submit],
 
205
  height,
206
  guidance_scale,
207
  num_inference_steps,
208
+ lora_scale, # Pass the new slider value
209
  ],
210
  outputs=[result, seed],
211
  )