Update demo/model.py

demo/model.py  CHANGED  (+23 -17)
@@ -17,6 +17,7 @@ import cv2
 import numpy as np
 import torch.nn.functional as F
 
+
 def preprocessing(image, device):
     # Resize
     scale = 640 / max(image.shape[:2])
@@ -39,6 +40,7 @@ def preprocessing(image, device):
 
     return image, raw_image
 
+
 def imshow_keypoints(img,
                      pose_result,
                      skeleton=None,
@@ -138,18 +140,22 @@ class Model_all:
                                     use_conv=False).to(device)
         self.model_sketch.load_state_dict(torch.load("models/t2iadapter_sketch_sd14v1.pth", map_location=device))
         self.model_edge = pidinet().to(device)
-        self.model_edge.load_state_dict({k.replace('module.', ''): v for k, v in torch.load('models/table5_pidinet.pth', map_location=device)['state_dict'].items()})
+        self.model_edge.load_state_dict({k.replace('module.', ''): v for k, v in
+                                         torch.load('models/table5_pidinet.pth', map_location=device)[
+                                             'state_dict'].items()})
 
         # segmentation part
         self.model_seger = seger().to(device)
         self.model_seger.eval()
         self.coler = Colorize(n=182)
-        self.model_seg = Adapter(cin=int(3*64), channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False).to(device)
+        self.model_seg = Adapter(cin=int(3 * 64), channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True,
+                                 use_conv=False).to(device)
         self.model_seg.load_state_dict(torch.load("models/t2iadapter_seg_sd14v1.pth", map_location=device))
         self.depth_model = MiDaSInference(model_type='dpt_hybrid').to(device)
 
         # depth part
-        self.model_depth = Adapter(cin=3*64, channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False).to(device)
+        self.model_depth = Adapter(cin=3 * 64, channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True,
+                                   use_conv=False).to(device)
         self.model_depth.load_state_dict(torch.load("models/t2iadapter_depth_sd14v1.pth", map_location=device))
 
         # keypose part
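The hunk above loads the pidinet edge-detection checkpoint, which was evidently saved from an nn.DataParallel wrapper: every key in its state_dict carries a 'module.' prefix that has to be stripped before loading into the bare module. A minimal, self-contained sketch of that pattern (the function name and checkpoint path are stand-ins, not part of the demo):

    import torch
    import torch.nn as nn

    def load_dataparallel_checkpoint(model: nn.Module, path: str, device: str = "cpu") -> nn.Module:
        # Checkpoints saved from nn.DataParallel prefix every parameter key with 'module.'.
        ckpt = torch.load(path, map_location=device)
        state = ckpt.get("state_dict", ckpt)  # some checkpoints nest weights under 'state_dict'
        # Strip only a leading prefix; a bare replace() would also hit any inner
        # submodule whose name happens to contain 'module.'.
        state = {k[len("module."):] if k.startswith("module.") else k: v for k, v in state.items()}
        model.load_state_dict(state)
        return model

The demo's plain k.replace('module.', '') works here because no inner layer name of pidinet contains that substring.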
@@ -183,7 +189,7 @@ class Model_all:
                                 [0, 255, 0], [255, 128, 0], [51, 153, 255], [51, 153, 255], [51, 153, 255],
                                 [51, 153, 255],
                                 [51, 153, 255], [51, 153, 255], [51, 153, 255]]
-
+
     def load_vae(self):
         vae_sd = torch.load(os.path.join('models', 'anything-v4.0.vae.pt'), map_location="cuda")
         sd = vae_sd["state_dict"]
@@ -254,7 +260,7 @@ class Model_all:
 
     @torch.no_grad()
    def process_depth(self, input_img, type_in, prompt, neg_prompt, pos_prompt, fix_sample, scale,
-            con_strength, base_model):
+                      con_strength, base_model):
         if self.current_base != base_model:
             ckpt = os.path.join("models", base_model)
             pl_sd = torch.load(ckpt, map_location="cuda")
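Each process_* method receives the base checkpoint name and reloads weights only when it differs from the checkpoint currently in memory, so switching conditioning modes stays cheap. A rough sketch of that lazy-swap pattern (the ModelSwitcher class and ensure_base name are illustrative, not the demo's API):

    import os
    import torch

    class ModelSwitcher:
        """Reload base-model weights only when the requested checkpoint changes."""

        def __init__(self, model, device="cuda"):
            self.model = model
            self.device = device
            self.current_base = None

        def ensure_base(self, base_model: str, model_dir: str = "models"):
            if self.current_base == base_model:
                return  # already loaded; skip the expensive torch.load
            ckpt = os.path.join(model_dir, base_model)
            pl_sd = torch.load(ckpt, map_location=self.device)
            sd = pl_sd.get("state_dict", pl_sd)  # Lightning-style checkpoints nest weights
            self.model.load_state_dict(sd, strict=False)
            self.current_base = base_model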
@@ -312,7 +318,8 @@ class Model_all:
         return [im_depth, x_samples_ddim]
 
     @torch.no_grad()
-    def process_depth_keypose(self, input_img_depth, input_img_keypose, type_in_depth, type_in_keypose, w_depth, w_keypose, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model):
+    def process_depth_keypose(self, input_img_depth, input_img_keypose, type_in_depth, type_in_keypose, w_depth,
+                              w_keypose, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model):
         if self.current_base != base_model:
             ckpt = os.path.join("models", base_model)
             pl_sd = torch.load(ckpt, map_location="cuda")
@@ -343,8 +350,7 @@ class Model_all:
 
         # get keypose
         if type_in_keypose == 'Keypose':
-            im_keypose_out = im_keypose.copy()
-            pose = img2tensor(im_keypose).unsqueeze(0) / 255.
+            im_keypose_out = im_keypose.copy()[:,:,::-1]
         elif type_in_keypose == 'Image':
             image = im_keypose.copy()
             im_keypose = img2tensor(im_keypose).unsqueeze(0) / 255.
@@ -378,7 +384,7 @@ class Model_all:
                              pose_link_color=self.pose_link_color,
                              radius=2,
                              thickness=2)
-            im_keypose_out = im_keypose_out.astype(np.uint8)[:,:,::-1]
+            im_keypose_out = im_keypose_out.astype(np.uint8)[:, :, ::-1]
 
         # extract condition features
         c = self.base_model.get_learned_conditioning([prompt + ', ' + pos_prompt])
@@ -387,7 +393,8 @@ class Model_all:
         pose = img2tensor(im_keypose_out, bgr2rgb=True, float32=True) / 255.
         pose = pose.unsqueeze(0)
         features_adapter_keypose = self.model_pose(pose.to(self.device))
-        features_adapter = [f_d*w_depth + f_k*w_keypose for f_d, f_k in zip(features_adapter_depth, features_adapter_keypose)]
+        features_adapter = [f_d * w_depth + f_k * w_keypose for f_d, f_k in
+                            zip(features_adapter_depth, features_adapter_keypose)]
         shape = [4, 64, 64]
 
         # sampling
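Here the depth and keypose adapters each return a list of feature maps, one per UNet injection level, and the combined condition is a pairwise weighted sum across the two lists. A minimal sketch with dummy tensors; the shapes mirror the channels=[320, 640, 1280, 1280] configuration above but are otherwise illustrative:

    import torch

    def fuse_adapter_features(feats_a, feats_b, w_a: float, w_b: float):
        # Blend two equal-length lists of feature maps with scalar weights.
        assert len(feats_a) == len(feats_b)
        return [f_a * w_a + f_b * w_b for f_a, f_b in zip(feats_a, feats_b)]

    # Toy tensors standing in for the four injection levels of a 512x512 run.
    levels = [(320, 64), (640, 32), (1280, 16), (1280, 8)]
    feats_depth = [torch.randn(1, c, s, s) for c, s in levels]
    feats_pose = [torch.randn(1, c, s, s) for c, s in levels]
    fused = fuse_adapter_features(feats_depth, feats_pose, w_a=1.0, w_b=1.5)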
@@ -416,7 +423,7 @@ class Model_all:
 
     @torch.no_grad()
     def process_seg(self, input_img, type_in, prompt, neg_prompt, pos_prompt, fix_sample, scale,
-            con_strength, base_model):
+                    con_strength, base_model):
         if self.current_base != base_model:
             ckpt = os.path.join("models", base_model)
             pl_sd = torch.load(ckpt, map_location="cuda")
@@ -450,10 +457,10 @@ class Model_all:
         labelmap = np.argmax(probs, axis=0)
 
         labelmap = self.coler(labelmap)
-        labelmap = np.transpose(labelmap, (1,2,0))
+        labelmap = np.transpose(labelmap, (1, 2, 0))
         labelmap = cv2.resize(labelmap, (512, 512))
-        labelmap = img2tensor(labelmap, bgr2rgb=False, float32=True)/255.
-        im_seg = tensor2img(labelmap)[:,:,::-1]
+        labelmap = img2tensor(labelmap, bgr2rgb=False, float32=True) / 255.
+        im_seg = tensor2img(labelmap)[:, :, ::-1]
         labelmap = labelmap.unsqueeze(0)
 
         # extract condition features
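This segmentation branch moves through three layout conventions: the colorizer emits a CHW array, cv2.resize expects HWC, and the [:, :, ::-1] slice reverses channel order (BGR to RGB or back) for display. A small self-contained sketch of the same conversions in plain numpy/cv2, independent of the demo's img2tensor/tensor2img helpers:

    import cv2
    import numpy as np

    def chw_colormap_to_display(labelmap_chw: np.ndarray, size: int = 512):
        # CHW uint8 color map -> (RGB display image, normalized HWC float array).
        hwc = np.transpose(labelmap_chw, (1, 2, 0))   # CHW -> HWC for OpenCV
        hwc = cv2.resize(hwc, (size, size))
        display_rgb = hwc[:, :, ::-1]                 # reverse channel order (BGR <-> RGB)
        normalized = hwc.astype(np.float32) / 255.0   # scale to [0, 1] for the adapter input
        return display_rgb, normalized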
@@ -564,8 +571,7 @@ class Model_all:
         im = cv2.resize(input_img, (512, 512))
 
         if type_in == 'Keypose':
-            im_pose = im.copy()
-            im = img2tensor(im).unsqueeze(0) / 255.
+            im_pose = im.copy()[:,:,::-1]
         elif type_in == 'Image':
             image = im.copy()
             im = img2tensor(im).unsqueeze(0) / 255.
@@ -599,7 +605,7 @@ class Model_all:
                              pose_link_color=self.pose_link_color,
                              radius=2,
                              thickness=2)
-        im_pose = cv2.resize(im_pose, (512, 512))
+        # im_pose = cv2.resize(im_pose, (512, 512))
 
         # extract condition features
         c = self.base_model.get_learned_conditioning([prompt + ', ' + pos_prompt])