Commit ·
df20d82
1
Parent(s): c7c9ff6
added files
Browse files- .gitignore +1 -0
- __pycache__/models.cpython-310.pyc +0 -0
- generate_4ch.py +36 -4
- generate_4ch_from_huggingface.py +160 -0
- test_out/0_img.png +0 -0
- test_out/0_mask.png +0 -0
.gitignore
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
pre_trained_checkpoint_4ch
|
__pycache__/models.cpython-310.pyc
ADDED
|
Binary file (12.2 kB). View file
|
|
|
generate_4ch.py
CHANGED
|
@@ -5,6 +5,7 @@ import torch.nn.functional as F
|
|
| 5 |
from torchvision.datasets import ImageFolder
|
| 6 |
from torch.utils.data import DataLoader
|
| 7 |
from torchvision import utils as vutils
|
|
|
|
| 8 |
|
| 9 |
import os
|
| 10 |
import random
|
|
@@ -36,12 +37,22 @@ def batch_save(images, folder_name):
|
|
| 36 |
for i, image in enumerate(images):
|
| 37 |
vutils.save_image(image.add(1).mul(0.5), folder_name+'/%d.jpg'%i)
|
| 38 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
|
| 40 |
if __name__ == "__main__":
|
| 41 |
parser = argparse.ArgumentParser(
|
| 42 |
description='generate images'
|
| 43 |
)
|
| 44 |
-
parser.add_argument('--ckpt', type=str, default="
|
| 45 |
parser.add_argument('--artifacts', type=str, default=".", help='path to artifacts.')
|
| 46 |
parser.add_argument('--cuda', type=int, default=0, help='index of gpu to use')
|
| 47 |
parser.add_argument('--start_iter', type=int, default=6)
|
|
@@ -50,7 +61,7 @@ if __name__ == "__main__":
|
|
| 50 |
parser.add_argument('--dist', type=str, default='test_out')
|
| 51 |
parser.add_argument('--size', type=int, default=256)
|
| 52 |
parser.add_argument('--batch', default=1, type=int, help='batch size')
|
| 53 |
-
parser.add_argument('--n_sample', type=int, default=
|
| 54 |
parser.add_argument('--big', action='store_true')
|
| 55 |
parser.add_argument('--im_size', type=int, default=256)
|
| 56 |
parser.add_argument("--save_option", default="image_and_mask", help="Options to svae output, image_only, mask_only, image_and_mask", choices=["image_only","mask_only", "image_and_mask"])
|
|
@@ -59,8 +70,17 @@ if __name__ == "__main__":
|
|
| 59 |
|
| 60 |
noise_dim = 256
|
| 61 |
device = torch.device('cuda:%d'%(args.cuda))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
net_ig.to(device)
|
| 65 |
|
| 66 |
#for epoch in [10000*i for i in range(args.start_iter, args.end_iter+1)]:
|
|
@@ -69,13 +89,25 @@ if __name__ == "__main__":
|
|
| 69 |
checkpoint = torch.load(ckpt)
|
| 70 |
# Remove prefix `module`.
|
| 71 |
checkpoint['g'] = {k.replace('module.', ''): v for k, v in checkpoint['g'].items()}
|
| 72 |
-
net_ig.load_state_dict(checkpoint['g'])
|
| 73 |
#load_params(net_ig, checkpoint['g_ema'])
|
| 74 |
|
| 75 |
#net_ig.eval()
|
| 76 |
print("load checkpoint success")
|
| 77 |
|
| 78 |
net_ig.to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
|
| 80 |
del checkpoint
|
| 81 |
|
|
|
|
| 5 |
from torchvision.datasets import ImageFolder
|
| 6 |
from torch.utils.data import DataLoader
|
| 7 |
from torchvision import utils as vutils
|
| 8 |
+
from huggingface_hub import PyTorchModelHubMixin
|
| 9 |
|
| 10 |
import os
|
| 11 |
import random
|
|
|
|
| 37 |
for i, image in enumerate(images):
|
| 38 |
vutils.save_image(image.add(1).mul(0.5), folder_name+'/%d.jpg'%i)
|
| 39 |
|
| 40 |
+
# To push the model to Huggingface model hub
|
| 41 |
+
class MyFastGanModel(nn.Module, PyTorchModelHubMixin):
|
| 42 |
+
|
| 43 |
+
def __init__(self, config: dict) -> None:
|
| 44 |
+
super().__init__()
|
| 45 |
+
|
| 46 |
+
self.model = Generator( ngf=config["ngf"], nz=config["noise_dim"], nc=config["nc"], im_size=config["im_size"])
|
| 47 |
+
|
| 48 |
+
def forward(self, x):
|
| 49 |
+
return self.model(x)
|
| 50 |
|
| 51 |
if __name__ == "__main__":
|
| 52 |
parser = argparse.ArgumentParser(
|
| 53 |
description='generate images'
|
| 54 |
)
|
| 55 |
+
parser.add_argument('--ckpt', type=str, default="/work/vajira/DL/FastGAN-pytorch/train_results/test1_4ch/models/all_50000.pth")
|
| 56 |
parser.add_argument('--artifacts', type=str, default=".", help='path to artifacts.')
|
| 57 |
parser.add_argument('--cuda', type=int, default=0, help='index of gpu to use')
|
| 58 |
parser.add_argument('--start_iter', type=int, default=6)
|
|
|
|
| 61 |
parser.add_argument('--dist', type=str, default='test_out')
|
| 62 |
parser.add_argument('--size', type=int, default=256)
|
| 63 |
parser.add_argument('--batch', default=1, type=int, help='batch size')
|
| 64 |
+
parser.add_argument('--n_sample', type=int, default=1)
|
| 65 |
parser.add_argument('--big', action='store_true')
|
| 66 |
parser.add_argument('--im_size', type=int, default=256)
|
| 67 |
parser.add_argument("--save_option", default="image_and_mask", help="Options to svae output, image_only, mask_only, image_and_mask", choices=["image_only","mask_only", "image_and_mask"])
|
|
|
|
| 70 |
|
| 71 |
noise_dim = 256
|
| 72 |
device = torch.device('cuda:%d'%(args.cuda))
|
| 73 |
+
|
| 74 |
+
# adding the model to the model hub
|
| 75 |
+
config={"ngf":64, "noise_dim":noise_dim, "nc":4, "im_size":args.im_size}
|
| 76 |
+
net_ig = MyFastGanModel(config=config)
|
| 77 |
+
|
| 78 |
|
| 79 |
+
|
| 80 |
+
# exit
|
| 81 |
+
#exit()
|
| 82 |
+
|
| 83 |
+
#net_ig = model #Generator( ngf=64, nz=noise_dim, nc=4, im_size=args.im_size)#, big=args.big )
|
| 84 |
net_ig.to(device)
|
| 85 |
|
| 86 |
#for epoch in [10000*i for i in range(args.start_iter, args.end_iter+1)]:
|
|
|
|
| 89 |
checkpoint = torch.load(ckpt)
|
| 90 |
# Remove prefix `module`.
|
| 91 |
checkpoint['g'] = {k.replace('module.', ''): v for k, v in checkpoint['g'].items()}
|
| 92 |
+
net_ig.model.load_state_dict(checkpoint['g'])
|
| 93 |
#load_params(net_ig, checkpoint['g_ema'])
|
| 94 |
|
| 95 |
#net_ig.eval()
|
| 96 |
print("load checkpoint success")
|
| 97 |
|
| 98 |
net_ig.to(device)
|
| 99 |
+
# Save locally
|
| 100 |
+
net_ig.save_pretrained("pre_trained_checkpoint_4ch", config=config) # Save the model locally
|
| 101 |
+
print("Model saved locally. Pushing to Huggingface model hub...")
|
| 102 |
+
|
| 103 |
+
# Push to the Huggingface model hub
|
| 104 |
+
# push to the hub
|
| 105 |
+
net_ig.push_to_hub("deepsynthbody/deepfake_gi_fastGAN", config=config)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
print("pushed to the Huggingface model hub. Done.")
|
| 109 |
+
exit()
|
| 110 |
+
|
| 111 |
|
| 112 |
del checkpoint
|
| 113 |
|
generate_4ch_from_huggingface.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
from torch import optim
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from torchvision.datasets import ImageFolder
|
| 6 |
+
from torch.utils.data import DataLoader
|
| 7 |
+
from torchvision import utils as vutils
|
| 8 |
+
from huggingface_hub import PyTorchModelHubMixin
|
| 9 |
+
|
| 10 |
+
import os
|
| 11 |
+
import random
|
| 12 |
+
import argparse
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
|
| 15 |
+
from models import Generator
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def load_params(model, new_param):
|
| 19 |
+
for p, new_p in zip(model.parameters(), new_param):
|
| 20 |
+
p.data.copy_(new_p)
|
| 21 |
+
|
| 22 |
+
def resize(img):
|
| 23 |
+
return F.interpolate(img, size=256)
|
| 24 |
+
|
| 25 |
+
def batch_generate(zs, netG, batch=8):
|
| 26 |
+
g_images = []
|
| 27 |
+
with torch.no_grad():
|
| 28 |
+
for i in range(len(zs)//batch):
|
| 29 |
+
g_images.append( netG(zs[i*batch:(i+1)*batch]).cpu() )
|
| 30 |
+
if len(zs)%batch>0:
|
| 31 |
+
g_images.append( netG(zs[-(len(zs)%batch):]).cpu() )
|
| 32 |
+
return torch.cat(g_images)
|
| 33 |
+
|
| 34 |
+
def batch_save(images, folder_name):
|
| 35 |
+
if not os.path.exists(folder_name):
|
| 36 |
+
os.mkdir(folder_name)
|
| 37 |
+
for i, image in enumerate(images):
|
| 38 |
+
vutils.save_image(image.add(1).mul(0.5), folder_name+'/%d.jpg'%i)
|
| 39 |
+
|
| 40 |
+
# To push the model to Huggingface model hub
|
| 41 |
+
class MyFastGanModel(nn.Module, PyTorchModelHubMixin):
|
| 42 |
+
|
| 43 |
+
def __init__(self, config: dict) -> None:
|
| 44 |
+
super().__init__()
|
| 45 |
+
|
| 46 |
+
self.model = Generator( ngf=config["ngf"], nz=config["noise_dim"], nc=config["nc"], im_size=config["im_size"])
|
| 47 |
+
|
| 48 |
+
def forward(self, x):
|
| 49 |
+
return self.model(x)
|
| 50 |
+
|
| 51 |
+
if __name__ == "__main__":
|
| 52 |
+
parser = argparse.ArgumentParser(
|
| 53 |
+
description='generate images'
|
| 54 |
+
)
|
| 55 |
+
parser.add_argument('--ckpt', type=str, default="/work/vajira/DL/FastGAN-pytorch/train_results/test1_4ch/models/all_50000.pth")
|
| 56 |
+
parser.add_argument('--artifacts', type=str, default=".", help='path to artifacts.')
|
| 57 |
+
parser.add_argument('--cuda', type=int, default=0, help='index of gpu to use')
|
| 58 |
+
parser.add_argument('--start_iter', type=int, default=6)
|
| 59 |
+
parser.add_argument('--end_iter', type=int, default=10)
|
| 60 |
+
|
| 61 |
+
parser.add_argument('--dist', type=str, default='test_out')
|
| 62 |
+
parser.add_argument('--size', type=int, default=256)
|
| 63 |
+
parser.add_argument('--batch', default=1, type=int, help='batch size')
|
| 64 |
+
parser.add_argument('--n_sample', type=int, default=1)
|
| 65 |
+
parser.add_argument('--big', action='store_true')
|
| 66 |
+
parser.add_argument('--im_size', type=int, default=256)
|
| 67 |
+
parser.add_argument("--save_option", default="image_and_mask", help="Options to svae output, image_only, mask_only, image_and_mask", choices=["image_only","mask_only", "image_and_mask"])
|
| 68 |
+
parser.set_defaults(big=False)
|
| 69 |
+
args = parser.parse_args()
|
| 70 |
+
|
| 71 |
+
noise_dim = 256
|
| 72 |
+
device = torch.device('cuda:%d'%(args.cuda))
|
| 73 |
+
|
| 74 |
+
# adding the model to the model hub
|
| 75 |
+
config={"ngf":64, "noise_dim":noise_dim, "nc":4, "im_size":args.im_size}
|
| 76 |
+
net_ig = MyFastGanModel(config=config)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
# exit
|
| 81 |
+
#exit()
|
| 82 |
+
|
| 83 |
+
#net_ig = model #Generator( ngf=64, nz=noise_dim, nc=4, im_size=args.im_size)#, big=args.big )
|
| 84 |
+
#net_ig.to(device)
|
| 85 |
+
|
| 86 |
+
#for epoch in [10000*i for i in range(args.start_iter, args.end_iter+1)]:
|
| 87 |
+
#ckpt = args.ckpt #f"{args.artifacts}/models/{epoch}.pth"
|
| 88 |
+
#checkpoint = torch.load(ckpt, map_location=lambda a,b: a)
|
| 89 |
+
#checkpoint = torch.load(ckpt)
|
| 90 |
+
# Remove prefix `module`.
|
| 91 |
+
#checkpoint['g'] = {k.replace('module.', ''): v for k, v in checkpoint['g'].items()}
|
| 92 |
+
#net_ig.model.load_state_dict(checkpoint['g'])
|
| 93 |
+
#load_params(net_ig, checkpoint['g_ema'])
|
| 94 |
+
|
| 95 |
+
net_ig = MyFastGanModel.from_pretrained("deepsynthbody/deepfake_gi_fastGAN", config=config) # Load the model from the hub
|
| 96 |
+
|
| 97 |
+
#net_ig.eval()
|
| 98 |
+
print("load checkpoint success")
|
| 99 |
+
|
| 100 |
+
net_ig.to(device)
|
| 101 |
+
# Save locally
|
| 102 |
+
# net_ig.save_pretrained("pre_trained_checkpoint_4ch", config=config) # Save the model locally
|
| 103 |
+
# print("Model saved locally. Pushing to Huggingface model hub...")
|
| 104 |
+
|
| 105 |
+
# Push to the Huggingface model hub
|
| 106 |
+
# push to the hub
|
| 107 |
+
# net_ig.push_to_hub("deepsynthbody/deepfake_gi_fastGAN", config=config)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
#print("pushed to the Huggingface model hub. Done.")
|
| 111 |
+
#exit()
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
#del checkpoint
|
| 115 |
+
|
| 116 |
+
#dist = 'eval_%d'%(epoch)
|
| 117 |
+
#dist = os.path.join(args.dist, 'img')
|
| 118 |
+
dist = args.dist
|
| 119 |
+
os.makedirs(dist, exist_ok=True)
|
| 120 |
+
|
| 121 |
+
with torch.no_grad():
|
| 122 |
+
for i in tqdm(range(args.n_sample//args.batch)):
|
| 123 |
+
noise = torch.randn(args.batch, noise_dim).to(device)
|
| 124 |
+
g_imgs = net_ig(noise)[0]
|
| 125 |
+
g_imgs = F.interpolate(g_imgs, 512)
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
for j, g_img in enumerate( g_imgs ):
|
| 129 |
+
#print("img sahpe=", g_img.shape)
|
| 130 |
+
g_mask = g_img.add(1).mul(0.5)[-1, :, :].expand(3, -1, -1)
|
| 131 |
+
g_img = g_img.add(1).mul(0.5)[0:3, :, :]
|
| 132 |
+
|
| 133 |
+
# Clean generated data using clamping
|
| 134 |
+
g_mask = torch.clamp(g_mask, min=0, max=1)
|
| 135 |
+
g_img = torch.clamp(g_img, min=0, max=1)
|
| 136 |
+
#print(g_mask.type())
|
| 137 |
+
g_mask = (g_mask > 0.5) * 1.0
|
| 138 |
+
#print(g_mask.type())
|
| 139 |
+
|
| 140 |
+
#print("gmask_min:", g_mask.min())
|
| 141 |
+
#print("gmask_max:", g_mask.max())
|
| 142 |
+
#exit()
|
| 143 |
+
|
| 144 |
+
#print("img sahpe=", g_img.shape)
|
| 145 |
+
|
| 146 |
+
if args.save_option == "image_and_mask":
|
| 147 |
+
vutils.save_image(g_img,
|
| 148 |
+
os.path.join(dist, '%d_img.png'%(i*args.batch+j)))#, normalize=True, range=(-1,1))
|
| 149 |
+
vutils.save_image(g_mask,
|
| 150 |
+
os.path.join(dist, '%d_mask.png'%(i*args.batch+j))) #, normalize=True, range=(0,1))
|
| 151 |
+
|
| 152 |
+
elif args.save_option == "image_only":
|
| 153 |
+
vutils.save_image(g_img,
|
| 154 |
+
os.path.join(dist, '%d_img.png'%(i*args.batch+j)))#, normalize=True, range=(-1,1))
|
| 155 |
+
|
| 156 |
+
elif args.save_option == "mask_only":
|
| 157 |
+
vutils.save_image(g_mask,
|
| 158 |
+
os.path.join(dist, '%d_mask.png'%(i*args.batch+j)))#, normalize=True, range=(-1,1))
|
| 159 |
+
else:
|
| 160 |
+
print("wrong choise to save option.")
|
test_out/0_img.png
ADDED
|
test_out/0_mask.png
ADDED
|