Spaces:
Runtime error
Runtime error
Update warp_design_on_dress.py
Browse files- warp_design_on_dress.py +11 -6
warp_design_on_dress.py
CHANGED
|
@@ -4,7 +4,7 @@ import torch.nn.functional as F
|
|
| 4 |
from torchvision import transforms
|
| 5 |
from PIL import Image
|
| 6 |
from networks import GMM, UnetGenerator, load_checkpoint, Options
|
| 7 |
-
|
| 8 |
|
| 9 |
def run_design_warp_on_dress(dress_path, design_path, gmm_ckpt, tom_ckpt, output_dir):
|
| 10 |
os.makedirs(output_dir, exist_ok=True)
|
|
@@ -13,7 +13,8 @@ def run_design_warp_on_dress(dress_path, design_path, gmm_ckpt, tom_ckpt, output
|
|
| 13 |
im_h, im_w = 256, 192
|
| 14 |
tf = transforms.Compose([
|
| 15 |
transforms.Resize((im_h, im_w)),
|
| 16 |
-
transforms.ToTensor()
|
|
|
|
| 17 |
])
|
| 18 |
|
| 19 |
dress_img = Image.open(dress_path).convert("RGB")
|
|
@@ -28,11 +29,15 @@ def run_design_warp_on_dress(dress_path, design_path, gmm_ckpt, tom_ckpt, output
|
|
| 28 |
|
| 29 |
opt = Options()
|
| 30 |
gmm = GMM(opt)
|
| 31 |
-
load_checkpoint(gmm, gmm_ckpt, strict
|
| 32 |
gmm.cpu().eval()
|
| 33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
with torch.no_grad():
|
| 35 |
-
grid, _ = gmm(
|
| 36 |
warped_design = F.grid_sample(design_tensor, grid, padding_mode='border')
|
| 37 |
warped_mask = F.grid_sample(design_mask, grid, padding_mode='zeros')
|
| 38 |
|
|
@@ -52,8 +57,8 @@ def run_design_warp_on_dress(dress_path, design_path, gmm_ckpt, tom_ckpt, output
|
|
| 52 |
tryon = warped_design * m_composite + p_rendered * (1 - m_composite)
|
| 53 |
|
| 54 |
# Save output
|
| 55 |
-
out_img = tryon.squeeze().permute(1, 2, 0).cpu().numpy()
|
| 56 |
-
out_img = (
|
| 57 |
out_pil = Image.fromarray(out_img)
|
| 58 |
|
| 59 |
output_path = os.path.join(output_dir, "tryon.jpg")
|
|
|
|
| 4 |
from torchvision import transforms
|
| 5 |
from PIL import Image
|
| 6 |
from networks import GMM, UnetGenerator, load_checkpoint, Options
|
| 7 |
+
from preprocessing import pad_to_22_channels
|
| 8 |
|
| 9 |
def run_design_warp_on_dress(dress_path, design_path, gmm_ckpt, tom_ckpt, output_dir):
|
| 10 |
os.makedirs(output_dir, exist_ok=True)
|
|
|
|
| 13 |
im_h, im_w = 256, 192
|
| 14 |
tf = transforms.Compose([
|
| 15 |
transforms.Resize((im_h, im_w)),
|
| 16 |
+
transforms.ToTensor(),
|
| 17 |
+
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # Added normalization
|
| 18 |
])
|
| 19 |
|
| 20 |
dress_img = Image.open(dress_path).convert("RGB")
|
|
|
|
| 29 |
|
| 30 |
opt = Options()
|
| 31 |
gmm = GMM(opt)
|
| 32 |
+
load_checkpoint(gmm, gmm_ckpt, strict=False)
|
| 33 |
gmm.cpu().eval()
|
| 34 |
|
| 35 |
+
# Convert agnostic to 22 channels before passing to GMM
|
| 36 |
+
agnostic_22ch = pad_to_22_channels(agnostic)
|
| 37 |
+
design_mask_22ch = pad_to_22_channels(design_mask)
|
| 38 |
+
|
| 39 |
with torch.no_grad():
|
| 40 |
+
grid, _ = gmm(agnostic_22ch, design_mask_22ch) # Use padded inputs
|
| 41 |
warped_design = F.grid_sample(design_tensor, grid, padding_mode='border')
|
| 42 |
warped_mask = F.grid_sample(design_mask, grid, padding_mode='zeros')
|
| 43 |
|
|
|
|
| 57 |
tryon = warped_design * m_composite + p_rendered * (1 - m_composite)
|
| 58 |
|
| 59 |
# Save output
|
| 60 |
+
out_img = (tryon.squeeze().permute(1, 2, 0).cpu().numpy() + 1) * 127.5 # Denormalize
|
| 61 |
+
out_img = out_img.clip(0, 255).astype("uint8")
|
| 62 |
out_pil = Image.fromarray(out_img)
|
| 63 |
|
| 64 |
output_path = os.path.join(output_dir, "tryon.jpg")
|