shlok123 committed on
Commit
8eec341
·
1 Parent(s): b87c600

add everything

Browse files
Files changed (4) hide show
  1. app.py +49 -0
  2. inference.py +129 -0
  3. model.py +142 -0
  4. model/model_final.pth +3 -0
app.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+ import os
4
+ import numpy as np
5
+ # from outpaint import outpainting
6
+ # from model import colorazation, UNETmodel, utils1
7
+ # from model import inference, model
8
+ # from model import colorazation, deeplabmodel, utils
9
+ from model import MainModel
10
+ import inference as inf
11
+
12
+
13
+ # pretrained model
14
# Checkpoint location, built with os.path.join so it resolves on every OS
# (a backslash-separated literal only works on Windows).
_MODEL_PATH = os.path.join("model", "model_final.pth")

# Lazily populated cache so the checkpoint is read from disk only once
# instead of on every request.
_colorization_model = None


def colorize_image(image):
    """Colorize an uploaded image with the pretrained model.

    Args:
        image: PIL image supplied by the Gradio input component, or ``None``
            when nothing was uploaded.

    Returns:
        The result of ``inf.predict_color`` for ``image``, or ``None`` when
        no image was provided.
    """
    global _colorization_model

    if image is None:
        return None

    if _colorization_model is None:
        _colorization_model = inf.load_model(
            model_class=MainModel, file_path=_MODEL_PATH
        )
    return inf.predict_color(_colorization_model, image=image)
21
+
22
+
23
+
24
+ # pretrained model
25
# Gradio UI for the pretrained colorization model: one PIL image in,
# one image out, both handled by colorize_image.
colorization_interface = gr.Interface(
    fn=colorize_image,
    inputs=gr.Image(type="pil", label="Input Image"),
    outputs=[gr.Image(type="pil", label="Output Image")],
    title="Image Colorization",
    description="Upload an image to perform colorization.",
)
33
+
34
+ # deeplab model
35
+ # depinterface = gr.Interface(
36
+ # depColorize_image,
37
+ # gr.Image(type="pil", label="Input Image"),
38
+ # [gr.Image(type="pil", label="Output Image")],
39
+ # title="Image Colorization",
40
+ # description="Upload an image to perform colorization.",
41
+
42
+ # )
43
+
44
+ # scratch mod
45
+
46
+ # Launch the interface
47
+ # interface.launch(share=True)
48
# Wrap the single interface in a tabbed layout and start the server.
# share=True asks Gradio to also create a public share link.
with gr.TabbedInterface([colorization_interface], ["Colorization_pretrain_unet"]) as tabs:
    tabs.launch(share=True)
inference.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import glob
3
+ import time
4
+ import numpy as np
5
+ from PIL import Image
6
+ from pathlib import Path
7
+ from tqdm.notebook import tqdm
8
+ import matplotlib.pyplot as plt
9
+ from skimage.color import rgb2lab, lab2rgb
10
+
11
+ import torch
12
+ from torch import nn, optim
13
+ from torchvision import transforms
14
+ from torchvision.utils import make_grid
15
+ from torch.utils.data import Dataset, DataLoader
16
+
17
+
18
def init_model(model, device):
    """Move ``model`` to ``device`` and initialize its weights.

    Args:
        model: ``nn.Module`` to prepare.
        device: Target ``torch.device`` (or device string) for the module.

    Returns:
        The same module, on ``device``, after ``init_weights`` has been
        applied with its default settings.
    """
    model = model.to(device)
    return init_weights(model)
22
+
23
+
24
def init_weights(net, init='norm', gain=0.02):
    """Initialize the conv and BatchNorm2d layers of ``net`` in place.

    Args:
        net: ``nn.Module`` whose sub-modules are visited with ``net.apply``.
        init: Scheme for conv-like layers: ``'norm'`` (normal, std=``gain``),
            ``'xavier'`` (Xavier normal, gain=``gain``) or ``'kaiming'``
            (Kaiming normal, fan-in).
        gain: Standard deviation / gain used by the chosen scheme and by the
            BatchNorm2d weight initialization.

    Returns:
        ``net`` itself, after initialization.

    Raises:
        ValueError: If ``init`` is not one of the supported schemes. Without
            this check an unknown name would skip conv initialization while
            still printing the success message.
    """
    if init not in ('norm', 'xavier', 'kaiming'):
        raise ValueError(
            f"unsupported init {init!r}; expected 'norm', 'xavier' or 'kaiming'"
        )

    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and 'Conv' in classname:
            if init == 'norm':
                nn.init.normal_(m.weight.data, mean=0.0, std=gain)
            elif init == 'xavier':
                nn.init.xavier_normal_(m.weight.data, gain=gain)
            else:  # 'kaiming'
                nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')

            # Conv layers built with bias=False have bias set to None.
            if hasattr(m, 'bias') and m.bias is not None:
                nn.init.constant_(m.bias.data, 0.0)
        elif 'BatchNorm2d' in classname:
            # affine=False leaves weight/bias as None; nothing to initialize.
            if m.weight is not None:
                nn.init.normal_(m.weight.data, 1., gain)
            if m.bias is not None:
                nn.init.constant_(m.bias.data, 0.)

    net.apply(init_func)
    print(f"model initialized with {init} initialization")
    return net
45
+
46
+ from fastai.vision.learner import create_body
47
+ from torchvision.models.resnet import resnet18
48
+ from fastai.vision.models.unet import DynamicUnet
49
+
50
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
51
+
52
def lab_to_rgb(L, ab):
    """Convert a batch of normalized L and ab tensors to RGB images.

    Args:
        L: Lightness tensor, concatenated with ``ab`` along dim 1. It is
            rescaled with ``(L + 1) * 50`` (presumably from [-1, 1] to the
            Lab range [0, 100] -- confirm against the training pipeline).
        ab: Color-channel tensor, rescaled with ``ab * 110``.

    Returns:
        A numpy array stacking the ``lab2rgb`` result of every image in the
        batch along axis 0.
    """
    # Detach and move both inputs to the CPU first, so the concat works even
    # when they live on different devices and .numpy() works even when they
    # are attached to the autograd graph.
    L = (L.detach().cpu() + 1.) * 50.
    ab = ab.detach().cpu() * 110.
    # Channels-first -> channels-last, the layout lab2rgb is called with.
    Lab = torch.cat([L, ab], dim=1).permute(0, 2, 3, 1).numpy()
    return np.stack([lab2rgb(img) for img in Lab], axis=0)
65
+
66
def build_res_unet(n_input=1, n_output=2, size=256):
    """Build a fastai ``DynamicUnet`` generator on top of a ResNet-18 body.

    Args:
        n_input: Number of input channels passed to ``create_body``.
        n_output: Number of output channels of the U-Net.
        size: Side length of the square image size given to ``DynamicUnet``.

    Returns:
        The ``DynamicUnet`` module, moved to CUDA when available, else CPU.
    """
    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # cut=-2 is forwarded to create_body (presumably drops the last two
    # ResNet children, i.e. the pooling and classifier head -- confirm
    # against the installed fastai version).
    body = create_body(resnet18(), pretrained=True, n_in=n_input, cut=-2)
    generator = DynamicUnet(body, n_output, (size, size)).to(target_device)
    return generator


# Module-level generator used by load_model(); built once at import time.
net_G = build_res_unet(n_input=1, n_output=2, size=256)
73
+
74
class GANLoss(nn.Module):
    """Adversarial loss that builds its own real/fake target tensors.

    Args:
        gan_mode: ``'vanilla'`` uses ``nn.BCEWithLogitsLoss``; ``'lsgan'``
            uses ``nn.MSELoss``.
        real_label: Target value used when the prediction should be "real".
        fake_label: Target value used when the prediction should be "fake".

    Raises:
        ValueError: If ``gan_mode`` is not a supported mode (otherwise
            ``self.loss`` would be left undefined and fail on first use).
    """

    def __init__(self, gan_mode='vanilla', real_label=1.0, fake_label=0.0):
        super().__init__()
        # Buffers (not parameters) so the labels follow the module across
        # .to(device) calls without being trained.
        self.register_buffer('real_label', torch.tensor(real_label))
        self.register_buffer('fake_label', torch.tensor(fake_label))
        if gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        else:
            raise ValueError(
                f"unsupported gan_mode {gan_mode!r}; expected 'vanilla' or 'lsgan'"
            )

    def get_labels(self, preds, target_is_real):
        """Return the real or fake label expanded to the shape of ``preds``."""
        labels = self.real_label if target_is_real else self.fake_label
        return labels.expand_as(preds)

    def forward(self, preds, target_is_real):
        """Compute the loss of ``preds`` against the selected target label.

        Implemented as ``forward`` rather than ``__call__`` so that
        ``nn.Module.__call__`` (and any registered hooks) keep working;
        ``criterion(preds, flag)`` is still the calling convention.
        """
        labels = self.get_labels(preds, target_is_real)
        return self.loss(preds, labels)
95
+
96
def load_model(model_class, file_path):
    """Instantiate ``model_class`` and load its trained weights.

    The checkpoint at ``file_path`` is loaded first; afterwards every entry
    of ``model/res18-unet.pt`` whose key also exists in the model's state
    dict is copied on top of it.

    Args:
        model_class: Class called as ``model_class(net_G=net_G)`` with the
            module-level generator.
        file_path: Path to the full-model state dict.

    Returns:
        The model with both sets of weights applied.

    Raises:
        FileNotFoundError: If either checkpoint file is missing.
    """
    model = model_class(net_G=net_G)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    model.load_state_dict(torch.load(file_path, map_location=device))

    # os.path.join keeps the path portable; a backslash-separated literal
    # only resolves on Windows.
    resnet_path = os.path.join("model", "res18-unet.pt")
    resnet_weights = torch.load(resnet_path, map_location=device)
    # Accept both a bare state dict and a {'state_dict': ...} wrapper.
    resnet_state_dict = (
        resnet_weights['state_dict'] if 'state_dict' in resnet_weights else resnet_weights
    )

    # Only keys the model already has are overlaid; everything else is ignored.
    model_dict = model.state_dict()
    model_dict.update(
        {k: v for k, v in resnet_state_dict.items() if k in model_dict}
    )
    model.load_state_dict(model_dict)
    return model
109
+
110
+ # return model
111
+ # model = model_class()
112
+ # model.load_state_dict(torch.load(file_path))
113
+ # return model
114
+
115
def predict_color(model, image):
    """Colorize a PIL image with ``model``.

    Args:
        model: Model forwarded to ``predict_and_return_image``.
        image: PIL image; it is resized to 256x256 before inference.

    Returns:
        The colorized image returned by ``predict_and_return_image``.
    """
    resized = image.resize((256, 256))
    # Keep only the first channel and rescale it to lie between -1 and 1.
    # NOTE(review): for a colour input the first channel is red, not a true
    # lightness value -- confirm this matches how the model was trained.
    gray = transforms.ToTensor()(resized)[:1] * 2. - 1.
    return predict_and_return_image(model, gray)
123
+
124
def predict_and_return_image(model, img):
    """Run the generator on one image tensor and return the RGB result.

    Args:
        model: Object exposing ``eval()`` and a ``net_G`` generator.
        img: Single-image tensor without a batch dimension; it is used both
            as the generator input and as the L channel of the output.

    Returns:
        The first (only) image of the batch produced by ``lab_to_rgb``.
    """
    model.eval()
    batch = img.unsqueeze(0)  # add the batch dimension the generator expects
    with torch.no_grad():
        preds = model.net_G(batch.to(device))
    colorized = lab_to_rgb(batch, preds.cpu())[0]
    return colorized
model.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn, optim
3
+ from torchvision import transforms
4
+ from torchvision.utils import make_grid
5
+ from inference import init_model, GANLoss
6
+
7
+
8
class UnetBlock(nn.Module):
    """One down/up level of a U-Net with a skip connection.

    Args:
        nf: Number of channels produced by the up-convolution.
        ni: Number of channels produced by the down-convolution.
        submodule: Inner block placed between the down and up paths
            (unused for the innermost block).
        input_c: Number of input channels; defaults to ``nf``.
        dropout: Append ``nn.Dropout(0.5)`` to the up path of a middle block.
        innermost: Build the bottleneck block (no submodule, no down norm).
        outermost: Build the outer block (no norms, ``Tanh`` output, no skip).
    """

    def __init__(self, nf, ni, submodule=None, input_c=None, dropout=False,
                 innermost=False, outermost=False):
        super().__init__()
        self.outermost = outermost
        if input_c is None:
            input_c = nf
        downconv = nn.Conv2d(input_c, ni, kernel_size=4,
                             stride=2, padding=1, bias=False)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = nn.BatchNorm2d(ni)
        uprelu = nn.ReLU(True)
        upnorm = nn.BatchNorm2d(nf)

        if outermost:
            # The submodule returns its input concatenated with its output,
            # hence ni * 2 input channels for the up-convolution.
            upconv = nn.ConvTranspose2d(ni * 2, nf, kernel_size=4,
                                        stride=2, padding=1)
            down = [downconv]
            up = [uprelu, upconv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            # No submodule here, so the up-convolution sees only ni channels.
            upconv = nn.ConvTranspose2d(ni, nf, kernel_size=4,
                                        stride=2, padding=1, bias=False)
            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]
            model = down + up
        else:
            upconv = nn.ConvTranspose2d(ni * 2, nf, kernel_size=4,
                                        stride=2, padding=1, bias=False)
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]
            if dropout:
                up += [nn.Dropout(0.5)]
            model = down + [submodule] + up
        self.model = nn.Sequential(*model)

    def forward(self, x):
        """Return the block output; non-outermost blocks concatenate ``x`` on dim 1."""
        if self.outermost:
            return self.model(x)
        return torch.cat([x, self.model(x)], 1)
47
+
48
class Unet(nn.Module):
    """U-Net generator assembled from nested ``UnetBlock`` modules.

    Args:
        input_c: Number of input channels of the outermost block.
        output_c: Number of output channels of the outermost block.
        n_down: Total number of down-sampling levels; ``n_down - 5`` dropout
            blocks are inserted around the innermost block.
        num_filters: Base channel count; the deepest blocks use
            ``num_filters * 8`` channels.
    """

    def __init__(self, input_c=1, output_c=2, n_down=8, num_filters=64):
        super().__init__()
        # Build from the inside out, starting with the bottleneck.
        unet_block = UnetBlock(num_filters * 8, num_filters * 8, innermost=True)
        for _ in range(n_down - 5):
            unet_block = UnetBlock(num_filters * 8, num_filters * 8,
                                   submodule=unet_block, dropout=True)
        # Three blocks that halve the channel count at each step outward.
        out_filters = num_filters * 8
        for _ in range(3):
            unet_block = UnetBlock(out_filters // 2, out_filters, submodule=unet_block)
            out_filters //= 2
        self.model = UnetBlock(output_c, out_filters, input_c=input_c,
                               submodule=unet_block, outermost=True)

    def forward(self, x):
        """Run ``x`` through the nested blocks."""
        return self.model(x)
62
+
63
class PatchDiscriminator(nn.Module):
    """Convolutional discriminator producing a grid of per-patch logits.

    Args:
        input_c: Number of channels of the input image.
        num_filters: Channel count after the first convolution; doubled by
            every following block.
        n_down: Number of intermediate blocks; the last of them uses
            stride 1 instead of 2.
    """

    def __init__(self, input_c, num_filters=64, n_down=3):
        super().__init__()
        model = [self.get_layers(input_c, num_filters, norm=False)]
        # The conditional stride keeps the last block in this loop from
        # down-sampling (stride 1 instead of 2).
        model += [
            self.get_layers(num_filters * 2 ** i, num_filters * 2 ** (i + 1),
                            s=1 if i == (n_down - 1) else 2)
            for i in range(n_down)
        ]
        # No normalization or activation on the final layer of the model.
        model += [self.get_layers(num_filters * 2 ** n_down, 1, s=1, norm=False, act=False)]
        self.model = nn.Sequential(*model)

    def get_layers(self, ni, nf, k=4, s=2, p=1, norm=True, act=True):
        """Build one conv (+ optional BatchNorm2d, + optional LeakyReLU) block.

        A separate method keeps the repetitive layer groups in one place.
        The convolution drops its bias whenever a norm layer follows.
        """
        layers = [nn.Conv2d(ni, nf, k, s, p, bias=not norm)]
        if norm:
            layers += [nn.BatchNorm2d(nf)]
        if act:
            layers += [nn.LeakyReLU(0.2, True)]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the discriminator output for ``x``."""
        return self.model(x)
82
+
83
class MainModel(nn.Module):
    """Generator + patch discriminator pair with their losses and optimizers.

    Args:
        net_G: Generator to use. When ``None`` a fresh ``Unet`` is built and
            weight-initialized; otherwise the given module is moved to the
            model's device.
        lr_G: Adam learning rate for the generator.
        lr_D: Adam learning rate for the discriminator.
        beta1: First Adam beta, shared by both optimizers.
        beta2: Second Adam beta, shared by both optimizers.
        lambda_L1: Weight of the L1 term in the generator loss.
    """

    def __init__(self, net_G=None, lr_G=2e-4, lr_D=2e-4,
                 beta1=0.5, beta2=0.999, lambda_L1=100.):
        super().__init__()

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.lambda_L1 = lambda_L1

        if net_G is None:
            self.net_G = init_model(Unet(input_c=1, output_c=2, n_down=8, num_filters=64), self.device)
        else:
            self.net_G = net_G.to(self.device)
        # The discriminator sees L and ab concatenated, hence 3 input channels.
        self.net_D = init_model(PatchDiscriminator(input_c=3, n_down=3, num_filters=64), self.device)
        self.GANcriterion = GANLoss(gan_mode='vanilla').to(self.device)
        self.L1criterion = nn.L1Loss()
        self.opt_G = optim.Adam(self.net_G.parameters(), lr=lr_G, betas=(beta1, beta2))
        self.opt_D = optim.Adam(self.net_D.parameters(), lr=lr_D, betas=(beta1, beta2))

    def set_requires_grad(self, model, requires_grad=True):
        """Set ``requires_grad`` on every parameter of ``model``."""
        for p in model.parameters():
            p.requires_grad = requires_grad

    def setup_input(self, data):
        """Store the ``'L'`` and ``'ab'`` entries of ``data`` on the model's device."""
        self.L = data['L'].to(self.device)
        self.ab = data['ab'].to(self.device)

    def forward(self):
        """Generate ``fake_color`` from the stored ``L`` input."""
        self.fake_color = self.net_G(self.L)

    def backward_D(self):
        """Compute the discriminator loss on fake and real pairs and backpropagate."""
        fake_image = torch.cat([self.L, self.fake_color], dim=1)
        # detach() keeps the discriminator loss from reaching the generator.
        fake_preds = self.net_D(fake_image.detach())
        self.loss_D_fake = self.GANcriterion(fake_preds, False)
        real_image = torch.cat([self.L, self.ab], dim=1)
        real_preds = self.net_D(real_image)
        self.loss_D_real = self.GANcriterion(real_preds, True)
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()

    def backward_G(self):
        """Compute the generator loss (adversarial + weighted L1) and backpropagate."""
        fake_image = torch.cat([self.L, self.fake_color], dim=1)
        fake_preds = self.net_D(fake_image)
        # The generator is rewarded when the discriminator labels fakes as real.
        self.loss_G_GAN = self.GANcriterion(fake_preds, True)
        self.loss_G_L1 = self.L1criterion(self.fake_color, self.ab) * self.lambda_L1
        self.loss_G = self.loss_G_GAN + self.loss_G_L1
        self.loss_G.backward()

    def optimize(self):
        """Run one training step: update the discriminator, then the generator."""
        self.forward()

        # Discriminator update.
        self.net_D.train()
        self.set_requires_grad(self.net_D, True)
        self.opt_D.zero_grad()
        self.backward_D()
        self.opt_D.step()

        # Generator update, with the discriminator's parameters frozen.
        self.net_G.train()
        self.set_requires_grad(self.net_D, False)
        self.opt_G.zero_grad()
        self.backward_G()
        self.opt_G.step()
model/model_final.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:58876849eeea903233b5d0931ed0accabc5dd4230e5b897aa9aa0097df5ab93a
3
+ size 135588892