Update app.py
Browse files
app.py
CHANGED
|
@@ -124,10 +124,13 @@ def sample_then_run(net):
|
|
| 124 |
standev = torch.std(proj, 0)
|
| 125 |
|
| 126 |
# sample
|
| 127 |
-
sample = torch.zeros([1,
|
|
|
|
|
|
|
| 128 |
for i in range(1000):
|
| 129 |
sample[0, i] = torch.normal(m[i], standev[i], (1,1))
|
| 130 |
|
|
|
|
| 131 |
net = "model_"+str(uuid.uuid4())[:4]+".pt"
|
| 132 |
torch.save(sample, net)
|
| 133 |
|
|
@@ -148,7 +151,7 @@ def inference(net, prompt, negative_prompt, guidance_scale, ddim_steps, seed):
|
|
| 148 |
v.to(device)
|
| 149 |
|
| 150 |
weights = torch.load(net).to(device)
|
| 151 |
-
network = LoRAw2w(weights, mean, std, v[:, :
|
| 152 |
unet,
|
| 153 |
rank=1,
|
| 154 |
multiplier=1.0,
|
|
@@ -215,7 +218,7 @@ def edit_inference(net, prompt, negative_prompt, guidance_scale, ddim_steps, see
|
|
| 215 |
|
| 216 |
|
| 217 |
weights = torch.load(net).to(device)
|
| 218 |
-
network = LoRAw2w(weights, mean, std, v[:, :
|
| 219 |
unet,
|
| 220 |
rank=1,
|
| 221 |
multiplier=1.0,
|
|
@@ -386,32 +389,25 @@ def run_inversion(self, dict, pcs, epochs, weight_decay,lr):
|
|
| 386 |
|
| 387 |
|
| 388 |
@spaces.GPU
|
| 389 |
-
def file_upload(
|
| 390 |
-
|
|
|
|
|
|
|
|
|
|
| 391 |
|
| 392 |
#pad to 10000 Principal components to keep everything consistent
|
| 393 |
-
pcs =
|
| 394 |
padding = torch.zeros((1,10000-pcs)).to(device)
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
| 398 |
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
rank=1,
|
| 402 |
-
multiplier=1.0,
|
| 403 |
-
alpha=27.0,
|
| 404 |
-
train_method="xattn-strict"
|
| 405 |
-
).to(device, torch.bfloat16)
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
prompt = "sks person"
|
| 409 |
-
negative_prompt = "low quality, blurry, unfinished, nudity"
|
| 410 |
seed = 5
|
| 411 |
cfg = 3.0
|
| 412 |
steps = 25
|
| 413 |
-
image = inference( prompt, negative_prompt, cfg, steps, seed)
|
| 414 |
-
return image
|
| 415 |
|
| 416 |
|
| 417 |
|
|
@@ -516,7 +512,7 @@ with gr.Blocks(css="style.css") as demo:
|
|
| 516 |
submit.click(
|
| 517 |
fn=edit_inference, inputs=[net, prompt, negative_prompt, cfg, steps, seed, injection_step, a1, a2, a3, a4], outputs=[net, gallery]
|
| 518 |
)
|
| 519 |
-
|
| 520 |
|
| 521 |
|
| 522 |
|
|
|
|
| 124 |
standev = torch.std(proj, 0)
|
| 125 |
|
| 126 |
# sample
|
| 127 |
+
sample = torch.zeros([1, 10000]).to(device)
|
| 128 |
+
|
| 129 |
+
#only first 1000 PCs
|
| 130 |
for i in range(1000):
|
| 131 |
sample[0, i] = torch.normal(m[i], standev[i], (1,1))
|
| 132 |
|
| 133 |
+
|
| 134 |
net = "model_"+str(uuid.uuid4())[:4]+".pt"
|
| 135 |
torch.save(sample, net)
|
| 136 |
|
|
|
|
| 151 |
v.to(device)
|
| 152 |
|
| 153 |
weights = torch.load(net).to(device)
|
| 154 |
+
network = LoRAw2w(weights, mean, std, v[:, :10000],
|
| 155 |
unet,
|
| 156 |
rank=1,
|
| 157 |
multiplier=1.0,
|
|
|
|
| 218 |
|
| 219 |
|
| 220 |
weights = torch.load(net).to(device)
|
| 221 |
+
network = LoRAw2w(weights, mean, std, v[:, :10000],
|
| 222 |
unet,
|
| 223 |
rank=1,
|
| 224 |
multiplier=1.0,
|
|
|
|
| 389 |
|
| 390 |
|
| 391 |
@spaces.GPU
|
| 392 |
+
def file_upload(file):
|
| 393 |
+
device="cuda"
|
| 394 |
+
weights = torch.load(file.name).to(device)
|
| 395 |
+
net = "model_"+str(uuid.uuid4())[:4]+".pt"
|
| 396 |
+
torch.save(weights, net)
|
| 397 |
|
| 398 |
#pad to 10000 Principal components to keep everything consistent
|
| 399 |
+
pcs = net.shape[1]
|
| 400 |
padding = torch.zeros((1,10000-pcs)).to(device)
|
| 401 |
+
weights = torch.cat((weights, padding), 1)
|
| 402 |
+
|
|
|
|
| 403 |
|
| 404 |
+
image = prompt = "sks person"
|
| 405 |
+
negative_prompt = "low quality, blurry, unfinished, nudity, weapon"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 406 |
seed = 5
|
| 407 |
cfg = 3.0
|
| 408 |
steps = 25
|
| 409 |
+
image = inference(net, prompt, negative_prompt, cfg, steps, seed)
|
| 410 |
+
return net,net,image
|
| 411 |
|
| 412 |
|
| 413 |
|
|
|
|
| 512 |
submit.click(
|
| 513 |
fn=edit_inference, inputs=[net, prompt, negative_prompt, cfg, steps, seed, injection_step, a1, a2, a3, a4], outputs=[net, gallery]
|
| 514 |
)
|
| 515 |
+
file_input.change(fn=file_upload, inputs=[file_input, net], outputs = [net, gallery])
|
| 516 |
|
| 517 |
|
| 518 |
|