Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -41,10 +41,6 @@ good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtyp
|
|
| 41 |
txt2img_pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype).to(device)
|
| 42 |
txt2img_pipe.__class__.load_lora_into_transformer = classmethod(load_lora_into_transformer)
|
| 43 |
|
| 44 |
-
# img2img model
|
| 45 |
-
img2img_pipe = AutoPipelineForImage2Image.from_pretrained(base_model, vae=good_vae, transformer=txt2img_pipe.transformer, text_encoder=txt2img_pipe.text_encoder, tokenizer=txt2img_pipe.tokenizer, text_encoder_2=txt2img_pipe.text_encoder_2, tokenizer_2=txt2img_pipe.tokenizer_2, torch_dtype=dtype)
|
| 46 |
-
img2img_pipe.__class__.load_lora_into_transformer = classmethod(load_lora_into_transformer)
|
| 47 |
-
|
| 48 |
|
| 49 |
MAX_SEED = 2**32 - 1
|
| 50 |
|
|
@@ -118,15 +114,7 @@ def run_lora(prompt, image_url, lora_strings_json, image_strength, cfg_scale, s
|
|
| 118 |
img2img_model = False
|
| 119 |
orginal_image = None
|
| 120 |
print(device)
|
| 121 |
-
|
| 122 |
-
print("img2img")
|
| 123 |
-
orginal_image = load_image(image_url).to(device)
|
| 124 |
-
img2img_model = True
|
| 125 |
-
img2img_pipe.to(device)
|
| 126 |
-
else:
|
| 127 |
-
print("txt2img")
|
| 128 |
-
txt2img_pipe.to(device)
|
| 129 |
-
|
| 130 |
# Set random seed for reproducibility
|
| 131 |
if randomize_seed:
|
| 132 |
with calculateDuration("Set random seed"):
|
|
@@ -135,9 +123,8 @@ def run_lora(prompt, image_url, lora_strings_json, image_strength, cfg_scale, s
|
|
| 135 |
# Load LoRA weights
|
| 136 |
gr.Info("Start to load LoRA ...")
|
| 137 |
with calculateDuration("Unloading LoRA"):
|
| 138 |
-
# img2img_pipe.unload_lora_weights()
|
| 139 |
txt2img_pipe.unload_lora_weights()
|
| 140 |
-
|
| 141 |
lora_configs = None
|
| 142 |
adapter_names = []
|
| 143 |
lora_names = []
|
|
@@ -165,19 +152,13 @@ def run_lora(prompt, image_url, lora_strings_json, image_strength, cfg_scale, s
|
|
| 165 |
adapter_weights.append(adapter_weight)
|
| 166 |
if lora_repo and weights and adapter_name:
|
| 167 |
try:
|
| 168 |
-
|
| 169 |
-
img2img_pipe.load_lora_weights(lora_repo, weight_name=weights, low_cpu_mem_usage=True, adapter_name=lora_name)
|
| 170 |
-
else:
|
| 171 |
-
txt2img_pipe.load_lora_weights(lora_repo, weight_name=weights, low_cpu_mem_usage=True, adapter_name=lora_name)
|
| 172 |
except:
|
| 173 |
print("load lora error")
|
| 174 |
|
| 175 |
# set lora weights
|
| 176 |
if len(lora_names) > 0:
|
| 177 |
-
|
| 178 |
-
img2img_pipe.set_adapters(lora_names, adapter_weights=adapter_weights)
|
| 179 |
-
else:
|
| 180 |
-
txt2img_pipe.set_adapters(lora_names, adapter_weights=adapter_weights)
|
| 181 |
|
| 182 |
# Generate image
|
| 183 |
error_message = ""
|
|
@@ -185,36 +166,20 @@ def run_lora(prompt, image_url, lora_strings_json, image_strength, cfg_scale, s
|
|
| 185 |
gr.Info("Start to generate images ...")
|
| 186 |
with calculateDuration(f"Make a new generator: {seed}"):
|
| 187 |
generator = torch.Generator(device=device).manual_seed(seed)
|
| 188 |
-
|
| 189 |
with calculateDuration("Generating image"):
|
| 190 |
# Generate image
|
| 191 |
joint_attention_kwargs = {"scale": 1}
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
height=height,
|
| 203 |
-
generator=generator,
|
| 204 |
-
joint_attention_kwargs=joint_attention_kwargs
|
| 205 |
-
).images[0]
|
| 206 |
-
else:
|
| 207 |
-
txt2img_pipe.to(device)
|
| 208 |
-
final_image = txt2img_pipe(
|
| 209 |
-
prompt=prompt,
|
| 210 |
-
num_inference_steps=steps,
|
| 211 |
-
guidance_scale=cfg_scale,
|
| 212 |
-
width=width,
|
| 213 |
-
height=height,
|
| 214 |
-
max_sequence_length=512,
|
| 215 |
-
generator=generator,
|
| 216 |
-
joint_attention_kwargs=joint_attention_kwargs
|
| 217 |
-
).images[0]
|
| 218 |
|
| 219 |
except Exception as e:
|
| 220 |
error_message = str(e)
|
|
|
|
| 41 |
txt2img_pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype).to(device)
|
| 42 |
txt2img_pipe.__class__.load_lora_into_transformer = classmethod(load_lora_into_transformer)
|
| 43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
|
| 45 |
MAX_SEED = 2**32 - 1
|
| 46 |
|
|
|
|
| 114 |
img2img_model = False
|
| 115 |
orginal_image = None
|
| 116 |
print(device)
|
| 117 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
# Set random seed for reproducibility
|
| 119 |
if randomize_seed:
|
| 120 |
with calculateDuration("Set random seed"):
|
|
|
|
| 123 |
# Load LoRA weights
|
| 124 |
gr.Info("Start to load LoRA ...")
|
| 125 |
with calculateDuration("Unloading LoRA"):
|
|
|
|
| 126 |
txt2img_pipe.unload_lora_weights()
|
| 127 |
+
print(device)
|
| 128 |
lora_configs = None
|
| 129 |
adapter_names = []
|
| 130 |
lora_names = []
|
|
|
|
| 152 |
adapter_weights.append(adapter_weight)
|
| 153 |
if lora_repo and weights and adapter_name:
|
| 154 |
try:
|
| 155 |
+
txt2img_pipe.load_lora_weights(lora_repo, weight_name=weights, low_cpu_mem_usage=True, adapter_name=lora_name)
|
|
|
|
|
|
|
|
|
|
| 156 |
except:
|
| 157 |
print("load lora error")
|
| 158 |
|
| 159 |
# set lora weights
|
| 160 |
if len(lora_names) > 0:
|
| 161 |
+
txt2img_pipe.set_adapters(lora_names, adapter_weights=adapter_weights)
|
|
|
|
|
|
|
|
|
|
| 162 |
|
| 163 |
# Generate image
|
| 164 |
error_message = ""
|
|
|
|
| 166 |
gr.Info("Start to generate images ...")
|
| 167 |
with calculateDuration(f"Make a new generator: {seed}"):
|
| 168 |
generator = torch.Generator(device=device).manual_seed(seed)
|
| 169 |
+
print(device)
|
| 170 |
with calculateDuration("Generating image"):
|
| 171 |
# Generate image
|
| 172 |
joint_attention_kwargs = {"scale": 1}
|
| 173 |
+
final_image = txt2img_pipe(
|
| 174 |
+
prompt=prompt,
|
| 175 |
+
num_inference_steps=steps,
|
| 176 |
+
guidance_scale=cfg_scale,
|
| 177 |
+
width=width,
|
| 178 |
+
height=height,
|
| 179 |
+
max_sequence_length=512,
|
| 180 |
+
generator=generator,
|
| 181 |
+
joint_attention_kwargs=joint_attention_kwargs
|
| 182 |
+
).images[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 183 |
|
| 184 |
except Exception as e:
|
| 185 |
error_message = str(e)
|