Update app.py
Browse files
app.py
CHANGED
|
@@ -308,14 +308,15 @@ class CustomImageDataset(Dataset):
|
|
| 308 |
return image
|
| 309 |
|
| 310 |
@spaces.GPU
|
| 311 |
-
def invert(
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
|
|
|
| 316 |
|
| 317 |
-
|
| 318 |
-
network = LoRAw2w(
|
| 319 |
unet,
|
| 320 |
rank=1,
|
| 321 |
multiplier=1.0,
|
|
@@ -367,18 +368,27 @@ def invert(self, image, mask, pcs=10000, epochs=400, weight_decay = 1e-10, lr=1e
|
|
| 367 |
optim.zero_grad()
|
| 368 |
loss.backward()
|
| 369 |
optim.step()
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 373 |
|
| 374 |
|
| 375 |
@spaces.GPU
|
| 376 |
-
def run_inversion(dict, pcs, epochs, weight_decay,lr):
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
#sample an image
|
| 383 |
prompt = "sks person"
|
| 384 |
negative_prompt = "low quality, blurry, unfinished, nudity"
|
|
@@ -387,7 +397,7 @@ def run_inversion(dict, pcs, epochs, weight_decay,lr):
|
|
| 387 |
steps = 25
|
| 388 |
image = inference( prompt, negative_prompt, cfg, steps, seed)
|
| 389 |
torch.save(network.proj, "model.pt" )
|
| 390 |
-
return
|
| 391 |
|
| 392 |
|
| 393 |
@spaces.GPU
|
|
@@ -408,7 +418,7 @@ def file_upload(file, net):
|
|
| 408 |
cfg = 3.0
|
| 409 |
steps = 25
|
| 410 |
image = inference(net, prompt, negative_prompt, cfg, steps, seed)
|
| 411 |
-
return net,image
|
| 412 |
|
| 413 |
|
| 414 |
|
|
@@ -504,8 +514,8 @@ with gr.Blocks(css="style.css") as demo:
|
|
| 504 |
|
| 505 |
|
| 506 |
invert_button.click(fn=run_inversion,
|
| 507 |
-
inputs=[input_image, pcs, epochs, weight_decay,lr],
|
| 508 |
-
outputs = [
|
| 509 |
|
| 510 |
|
| 511 |
# Sampling button: sample_then_run takes the current `net` value and its
# results are routed to `net`, `file_output` and `input_image`.
sample.click(
    fn=sample_then_run,
    inputs=[net],
    outputs=[net, file_output, input_image],
)
|
|
|
|
| 308 |
return image
|
| 309 |
|
| 310 |
@spaces.GPU
|
| 311 |
+
def invert(image, mask, pcs=10000, epochs=400, weight_decay = 1e-10, lr=1e-1):
|
| 312 |
+
device = "cuda"
|
| 313 |
+
mean.to(device)
|
| 314 |
+
std.to(device)
|
| 315 |
+
v.to(device)
|
| 316 |
+
|
| 317 |
|
| 318 |
+
weights = torch.zeros(1,pcs).bfloat16().to(device)
|
| 319 |
+
network = LoRAw2w( weights, mean, std, v[:, :pcs],
|
| 320 |
unet,
|
| 321 |
rank=1,
|
| 322 |
multiplier=1.0,
|
|
|
|
| 368 |
optim.zero_grad()
|
| 369 |
loss.backward()
|
| 370 |
optim.step()
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
#pad to 10000 PCs
|
| 374 |
+
pcs_original = weights.shape[1]
|
| 375 |
+
padding = torch.zeros((1,10000-pcs_original)).to(device)
|
| 376 |
+
weights = network.proj.detach()
|
| 377 |
+
weights = torch.cat((weights, padding), 1)
|
| 378 |
+
|
| 379 |
+
net = "model_"+str(uuid.uuid4())[:4]+".pt"
|
| 380 |
+
torch.save(weights, net)
|
| 381 |
+
|
| 382 |
+
return net
|
| 383 |
|
| 384 |
|
| 385 |
@spaces.GPU
|
| 386 |
+
def run_inversion(net, dict, pcs, epochs, weight_decay,lr):
|
| 387 |
+
init_image = dict["background"].convert("RGB").resize((512, 512))
|
| 388 |
+
mask = dict["layers"][0].convert("RGB").resize((512, 512))
|
| 389 |
+
|
| 390 |
+
net = invert(init_image, mask, pcs, epochs, weight_decay,lr)
|
| 391 |
+
|
| 392 |
#sample an image
|
| 393 |
prompt = "sks person"
|
| 394 |
negative_prompt = "low quality, blurry, unfinished, nudity"
|
|
|
|
| 397 |
steps = 25
|
| 398 |
image = inference( prompt, negative_prompt, cfg, steps, seed)
|
| 399 |
torch.save(network.proj, "model.pt" )
|
| 400 |
+
return net, net, image
|
| 401 |
|
| 402 |
|
| 403 |
@spaces.GPU
|
|
|
|
| 418 |
cfg = 3.0
|
| 419 |
steps = 25
|
| 420 |
image = inference(net, prompt, negative_prompt, cfg, steps, seed)
|
| 421 |
+
return net, image
|
| 422 |
|
| 423 |
|
| 424 |
|
|
|
|
| 514 |
|
| 515 |
|
| 516 |
# Inversion button: run_inversion is called with the current `net` value, the
# edited input image and the optimisation hyper-parameters (pcs, epochs,
# weight_decay, lr). Its three return values are routed, in order, to
# `net`, `file_output` and `input_image`.
invert_button.click(
    fn=run_inversion,
    inputs=[net, input_image, pcs, epochs, weight_decay, lr],
    outputs=[net, file_output, input_image],
)
|
| 519 |
|
| 520 |
|
| 521 |
# Sampling button: sample_then_run takes the current `net` value and its
# results are routed to `net`, `file_output` and `input_image`.
sample.click(
    fn=sample_then_run,
    inputs=[net],
    outputs=[net, file_output, input_image],
)
|