Spaces:
Runtime error
Runtime error
| from share import * | |
| import config | |
| import cv2 | |
| import einops | |
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| import random | |
| ### | |
| import cv2 | |
| import gradio as gr | |
| import os | |
| from PIL import Image | |
| import numpy as np | |
| import torch | |
| from torch.autograd import Variable | |
| from torchvision import transforms | |
| import torch.nn.functional as F | |
| import gdown | |
| import matplotlib.pyplot as plt | |
| import warnings | |
| ### | |
| from pytorch_lightning import seed_everything | |
| from annotator.util import resize_image, HWC3 | |
| from annotator.hed import HEDdetector, nms | |
| from cldm.model import create_model, load_state_dict | |
| from cldm.ddim_hacked import DDIMSampler | |
# --- ControlNet (scribble) + Stable Diffusion model set-up -------------------
# HED edge detector; used in process() to derive scribble-style control maps.
apply_hed = HEDdetector()
# Build the ControlNet-augmented SD v1.5 graph on CPU first; weights are
# loaded below and the whole model is moved to CUDA at the end.
model = create_model('./models/cldm_v15.yaml').cpu()
#model.load_state_dict(load_state_dict('./control_sd15_scribble.pth', location='cuda'))
ddim_sampler = DDIMSampler(model)
from safetensors.torch import load_file as safe_load_file #add
# NOTE(review): pl_sd is loaded but never used afterwards — presumably left
# over from the earlier loading path shown in the commented line below.
pl_sd = safe_load_file('./Realistic_Vision_V2.0.safetensors') #add
# Base SD weights (Realistic Vision V2.0); strict=False because this
# checkpoint does not contain the ControlNet-branch parameters.
model.load_state_dict(load_state_dict('./Realistic_Vision_V2.0.safetensors', location='cuda'),strict=False) #add
# ControlNet (scribble, fp16) branch weights.
model.control_model.load_state_dict(load_state_dict('./control_scribble-fp16.safetensors',location='cuda'))
#model.load_state_dict(load_state_dict(pl_sd, strict=False)) #add
model = model.cuda()
| ######### | |
| ######## | |
| import torch | |
| # | |
| import torch.nn as nn | |
| from torchvision import models | |
| import torch.nn.functional as F | |
# Mean-reduced binary cross-entropy shared by the multi-scale fusion losses.
# `size_average=True` has been deprecated since PyTorch 0.4;
# `reduction='mean'` is the exact modern equivalent.
bce_loss = nn.BCELoss(reduction='mean')
def muti_loss_fusion(preds, target):
    """Multi-scale BCE fusion loss for deep supervision.

    Each side output in ``preds`` is scored against ``target`` with
    mean-reduced binary cross-entropy; the target is bilinearly resized
    whenever a side output has a different spatial size.

    Args:
        preds: list of sigmoid probability maps, each (B, 1, H_i, W_i).
        target: ground-truth mask (B, 1, H, W), values in [0, 1].

    Returns:
        (loss0, loss): loss of the first (finest) prediction, and the sum
        of losses over all predictions.
    """
    loss0 = 0.0
    loss = 0.0
    for i in range(0, len(preds)):
        if preds[i].shape[2] != target.shape[2] or preds[i].shape[3] != target.shape[3]:
            # Resize the target down/up to this side output's resolution.
            tmp_target = F.interpolate(target, size=preds[i].size()[2:], mode='bilinear', align_corners=True)
            # F.binary_cross_entropy (mean reduction) replaces the module-level
            # bce_loss built with the deprecated size_average flag; identical result.
            loss = loss + F.binary_cross_entropy(preds[i], tmp_target)
        else:
            loss = loss + F.binary_cross_entropy(preds[i], target)
        if i == 0:
            loss0 = loss
    return loss0, loss
# Feature-matching / distillation criteria used by muti_loss_fusion_kl.
# `size_average=True` is deprecated; `reduction='mean'` is the exact
# modern equivalent for every one of these losses.
fea_loss = nn.MSELoss(reduction='mean')
kl_loss = nn.KLDivLoss(reduction='mean')
l1_loss = nn.L1Loss(reduction='mean')
smooth_l1_loss = nn.SmoothL1Loss(reduction='mean')
def muti_loss_fusion_kl(preds, target, dfs, fs, mode='MSE'):
    """Multi-scale BCE loss plus a feature-matching constraint.

    Same side-output BCE fusion as muti_loss_fusion, with an additional
    per-level penalty tying student features ``dfs`` to reference features
    ``fs`` (e.g. from a ground-truth encoder).

    Args:
        preds: list of sigmoid probability maps, each (B, 1, H_i, W_i).
        target: ground-truth mask (B, 1, H, W), values in [0, 1].
        dfs: list of predicted feature maps.
        fs: list of reference feature maps, same shapes as ``dfs``.
        mode: 'MSE' | 'KL' | 'MAE' | 'SmoothL1' feature-matching criterion.

    Returns:
        (loss0, loss): loss of the finest prediction, and total loss
        (BCE sum + feature-matching terms).
    """
    loss0 = 0.0
    loss = 0.0
    for i in range(0, len(preds)):
        if preds[i].shape[2] != target.shape[2] or preds[i].shape[3] != target.shape[3]:
            # Resize target to this side output's resolution.
            tmp_target = F.interpolate(target, size=preds[i].size()[2:], mode='bilinear', align_corners=True)
            loss = loss + F.binary_cross_entropy(preds[i], tmp_target)
        else:
            loss = loss + F.binary_cross_entropy(preds[i], target)
        if i == 0:
            loss0 = loss
    # Functional losses (mean reduction) replace the module-level criteria
    # that were built with the deprecated size_average flag; identical results.
    for i in range(0, len(dfs)):
        if mode == 'MSE':
            loss = loss + F.mse_loss(dfs[i], fs[i])
        elif mode == 'KL':
            loss = loss + F.kl_div(F.log_softmax(dfs[i], dim=1), F.softmax(fs[i], dim=1))
        elif mode == 'MAE':
            loss = loss + F.l1_loss(dfs[i], fs[i])
        elif mode == 'SmoothL1':
            loss = loss + F.smooth_l1_loss(dfs[i], fs[i])
    return loss0, loss
class REBNCONV(nn.Module):
    """Conv3x3 -> BatchNorm -> ReLU unit with configurable dilation/stride.

    Padding equals the dilation rate, so spatial size is preserved when
    stride is 1. Attribute names are kept for checkpoint compatibility.
    """

    def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1):
        super(REBNCONV, self).__init__()
        self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate, stride=stride)
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)

    def forward(self, x):
        # conv -> batch-norm -> relu, step by step.
        feat = self.conv_s1(x)
        feat = self.bn_s1(feat)
        return self.relu_s1(feat)
| ## upsample tensor 'src' to have the same spatial size with tensor 'tar' | |
| def _upsample_like(src,tar): | |
| src = F.upsample(src,size=tar.shape[2:],mode='bilinear') | |
| return src | |
### RSU-7 ###
class RSU7(nn.Module):
    """Residual U-block of depth 7 (deepest U^2-Net/ISNet building block).

    Encoder with 5 pooling levels, a dilated bottom convolution, and a
    symmetric decoder; the decoder output is added to the input projection
    (residual connection). Attribute names define checkpoint keys.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
        # NOTE(review): img_size is accepted but never used in this block.
        super(RSU7,self).__init__()
        self.in_ch = in_ch
        self.mid_ch = mid_ch
        self.out_ch = out_ch
        # Input projection to out_ch (also the residual branch).
        self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1) ## 1 -> 1/2
        # ---- encoder: conv stages separated by 2x max-pooling ----
        self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
        self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
        self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
        self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
        self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
        self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool5 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
        self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=1)
        # Bottom: dilated conv enlarges the receptive field at constant size.
        self.rebnconv7 = REBNCONV(mid_ch,mid_ch,dirate=2)
        # ---- decoder: each stage consumes concat(skip, upsampled deeper) ----
        self.rebnconv6d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)

    def forward(self,x):
        # NOTE(review): b/c/h/w are unpacked but unused below.
        b, c, h, w = x.shape
        hx = x
        hxin = self.rebnconvin(hx)
        # ---- encoder ----
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)
        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)
        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)
        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)
        hx5 = self.rebnconv5(hx)
        hx = self.pool5(hx5)
        hx6 = self.rebnconv6(hx)
        hx7 = self.rebnconv7(hx6)
        # ---- decoder with skip connections ----
        hx6d = self.rebnconv6d(torch.cat((hx7,hx6),1))
        hx6dup = _upsample_like(hx6d,hx5)
        hx5d = self.rebnconv5d(torch.cat((hx6dup,hx5),1))
        hx5dup = _upsample_like(hx5d,hx4)
        hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
        hx4dup = _upsample_like(hx4d,hx3)
        hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
        hx3dup = _upsample_like(hx3d,hx2)
        hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
        hx2dup = _upsample_like(hx2d,hx1)
        hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
        # Residual connection with the input projection.
        return hx1d + hxin
### RSU-6 ###
class RSU6(nn.Module):
    """Residual U-block of depth 6: 4 pooling levels + dilated bottom conv."""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU6,self).__init__()
        # Input projection (residual branch).
        self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
        # ---- encoder ----
        self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
        self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
        self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
        self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
        self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
        self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
        # Dilated bottom convolution (no further pooling).
        self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=2)
        # ---- decoder (concat skip + upsampled deeper feature) ----
        self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)

    def forward(self,x):
        hx = x
        hxin = self.rebnconvin(hx)
        # ---- encoder ----
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)
        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)
        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)
        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)
        hx5 = self.rebnconv5(hx)
        hx6 = self.rebnconv6(hx5)
        # ---- decoder with skip connections ----
        hx5d = self.rebnconv5d(torch.cat((hx6,hx5),1))
        hx5dup = _upsample_like(hx5d,hx4)
        hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
        hx4dup = _upsample_like(hx4d,hx3)
        hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
        hx3dup = _upsample_like(hx3d,hx2)
        hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
        hx2dup = _upsample_like(hx2d,hx1)
        hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
        # Residual connection.
        return hx1d + hxin
### RSU-5 ###
class RSU5(nn.Module):
    """Residual U-block of depth 5: 3 pooling levels + dilated bottom conv."""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU5,self).__init__()
        # Input projection (residual branch).
        self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
        # ---- encoder ----
        self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
        self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
        self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
        self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
        self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
        # Dilated bottom convolution.
        self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=2)
        # ---- decoder ----
        self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)

    def forward(self,x):
        hx = x
        hxin = self.rebnconvin(hx)
        # ---- encoder ----
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)
        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)
        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)
        hx4 = self.rebnconv4(hx)
        hx5 = self.rebnconv5(hx4)
        # ---- decoder with skip connections ----
        hx4d = self.rebnconv4d(torch.cat((hx5,hx4),1))
        hx4dup = _upsample_like(hx4d,hx3)
        hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
        hx3dup = _upsample_like(hx3d,hx2)
        hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
        hx2dup = _upsample_like(hx2d,hx1)
        hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
        # Residual connection.
        return hx1d + hxin
### RSU-4 ###
class RSU4(nn.Module):
    """Residual U-block of depth 4: 2 pooling levels + dilated bottom conv."""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4,self).__init__()
        # Input projection (residual branch).
        self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
        # ---- encoder ----
        self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
        self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
        self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
        self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
        # Dilated bottom convolution.
        self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=2)
        # ---- decoder ----
        self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)

    def forward(self,x):
        hx = x
        hxin = self.rebnconvin(hx)
        # ---- encoder ----
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)
        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)
        hx3 = self.rebnconv3(hx)
        hx4 = self.rebnconv4(hx3)
        # ---- decoder with skip connections ----
        hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
        hx3dup = _upsample_like(hx3d,hx2)
        hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
        hx2dup = _upsample_like(hx2d,hx1)
        hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
        # Residual connection.
        return hx1d + hxin
### RSU-4F ###
class RSU4F(nn.Module):
    """Dilation-only residual U-block ('F' = flat, no pooling/upsampling).

    Used at the deepest stages where further downsampling would destroy
    spatial detail; receptive field grows via dilation rates 1/2/4/8.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4F,self).__init__()
        # Input projection (residual branch).
        self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
        # ---- encoder: increasing dilation instead of pooling ----
        self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
        self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=2)
        self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=4)
        self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=8)
        # ---- decoder: mirrored dilations ----
        self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=4)
        self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=2)
        self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)

    def forward(self,x):
        hx = x
        hxin = self.rebnconvin(hx)
        hx1 = self.rebnconv1(hxin)
        hx2 = self.rebnconv2(hx1)
        hx3 = self.rebnconv3(hx2)
        hx4 = self.rebnconv4(hx3)
        # Decoder concatenates skips directly (all features share one size).
        hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
        hx2d = self.rebnconv2d(torch.cat((hx3d,hx2),1))
        hx1d = self.rebnconv1d(torch.cat((hx2d,hx1),1))
        # Residual connection.
        return hx1d + hxin
class myrebnconv(nn.Module):
    """Fully configurable Conv2d -> BatchNorm2d -> ReLU block.

    Unlike REBNCONV, every Conv2d hyper-parameter (kernel size, stride,
    padding, dilation, groups) is exposed. Attribute names are kept for
    checkpoint compatibility.
    """

    def __init__(self, in_ch=3,
                 out_ch=1,
                 kernel_size=3,
                 stride=1,
                 padding=1,
                 dilation=1,
                 groups=1):
        super(myrebnconv, self).__init__()
        self.conv = nn.Conv2d(in_ch, out_ch,
                              kernel_size=kernel_size,
                              stride=stride,
                              padding=padding,
                              dilation=dilation,
                              groups=groups)
        self.bn = nn.BatchNorm2d(out_ch)
        self.rl = nn.ReLU(inplace=True)

    def forward(self, x):
        # conv -> batch-norm -> relu.
        out = self.conv(x)
        out = self.bn(out)
        return self.rl(out)
class ISNetGTEncoder(nn.Module):
    """Ground-truth encoder for ISNet's intermediate feature supervision.

    Encodes a 1-channel GT mask through a stack of RSU blocks and returns
    six sigmoid side-output maps (upsampled to input size) plus the six
    encoder stage feature maps.
    """

    def __init__(self,in_ch=1,out_ch=1):
        super(ISNetGTEncoder,self).__init__()
        # Input stem halves the spatial resolution.
        self.conv_in = myrebnconv(in_ch,16,3,stride=2,padding=1) # nn.Conv2d(in_ch,64,3,stride=2,padding=1)
        # Encoder stages with 2x pooling between them.
        self.stage1 = RSU7(16,16,64)
        self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
        self.stage2 = RSU6(64,16,64)
        self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
        self.stage3 = RSU5(64,32,128)
        self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
        self.stage4 = RSU4(128,32,256)
        self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
        self.stage5 = RSU4F(256,64,512)
        self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
        self.stage6 = RSU4F(512,64,512)
        # 1-channel side output head per stage.
        self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
        self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
        self.side3 = nn.Conv2d(128,out_ch,3,padding=1)
        self.side4 = nn.Conv2d(256,out_ch,3,padding=1)
        self.side5 = nn.Conv2d(512,out_ch,3,padding=1)
        self.side6 = nn.Conv2d(512,out_ch,3,padding=1)

    def compute_loss(self, preds, targets):
        """Multi-scale BCE loss of the side outputs against the targets."""
        return muti_loss_fusion(preds,targets)

    def forward(self,x):
        hx = x
        hxin = self.conv_in(hx)
        # ---- encoder ----
        hx1 = self.stage1(hxin)
        hx = self.pool12(hx1)
        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)
        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)
        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)
        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)
        hx6 = self.stage6(hx)
        # ---- side outputs, upsampled back to the input resolution ----
        d1 = self.side1(hx1)
        d1 = _upsample_like(d1,x)
        d2 = self.side2(hx2)
        d2 = _upsample_like(d2,x)
        d3 = self.side3(hx3)
        d3 = _upsample_like(d3,x)
        d4 = self.side4(hx4)
        d4 = _upsample_like(d4,x)
        d5 = self.side5(hx5)
        d5 = _upsample_like(d5,x)
        d6 = self.side6(hx6)
        d6 = _upsample_like(d6,x)
        # F.sigmoid is deprecated; torch.sigmoid is the exact equivalent.
        return [torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5), torch.sigmoid(d6)], [hx1,hx2,hx3,hx4,hx5,hx6]
class ISNetDIS(nn.Module):
    """ISNet for Highly Accurate Dichotomous Image Segmentation (DIS).

    U^2-Net-style encoder/decoder built from RSU blocks. Returns six
    sigmoid side-output masks (finest first, upsampled to input size)
    plus the decoder/bottom feature maps used for optional intermediate
    supervision against ISNetGTEncoder.
    """

    def __init__(self,in_ch=3,out_ch=1):
        super(ISNetDIS,self).__init__()
        # Input stem halves the spatial resolution.
        self.conv_in = nn.Conv2d(in_ch,64,3,stride=2,padding=1)
        # NOTE(review): pool_in is registered but unused in forward().
        self.pool_in = nn.MaxPool2d(2,stride=2,ceil_mode=True)
        # ---- encoder ----
        self.stage1 = RSU7(64,32,64)
        self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
        self.stage2 = RSU6(64,32,128)
        self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
        self.stage3 = RSU5(128,64,256)
        self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
        self.stage4 = RSU4(256,128,512)
        self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
        self.stage5 = RSU4F(512,256,512)
        self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
        self.stage6 = RSU4F(512,256,512)
        # ---- decoder ----
        self.stage5d = RSU4F(1024,256,512)
        self.stage4d = RSU4(1024,128,256)
        self.stage3d = RSU5(512,64,128)
        self.stage2d = RSU6(256,32,64)
        self.stage1d = RSU7(128,16,64)
        # 1-channel side output head per decoder level (+ bottom stage).
        self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
        self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
        self.side3 = nn.Conv2d(128,out_ch,3,padding=1)
        self.side4 = nn.Conv2d(256,out_ch,3,padding=1)
        self.side5 = nn.Conv2d(512,out_ch,3,padding=1)
        self.side6 = nn.Conv2d(512,out_ch,3,padding=1)
        # self.outconv = nn.Conv2d(6*out_ch,out_ch,1)

    def compute_loss_kl(self, preds, targets, dfs, fs, mode='MSE'):
        """Side-output BCE loss plus feature matching against a GT encoder."""
        return muti_loss_fusion_kl(preds, targets, dfs, fs, mode=mode)

    def compute_loss(self, preds, targets):
        """Side-output multi-scale BCE loss."""
        return muti_loss_fusion(preds, targets)

    def forward(self,x):
        hx = x
        hxin = self.conv_in(hx)
        # ---- encoder ----
        hx1 = self.stage1(hxin)
        hx = self.pool12(hx1)
        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)
        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)
        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)
        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)
        hx6 = self.stage6(hx)
        hx6up = _upsample_like(hx6,hx5)
        # ---- decoder with skip connections ----
        hx5d = self.stage5d(torch.cat((hx6up,hx5),1))
        hx5dup = _upsample_like(hx5d,hx4)
        hx4d = self.stage4d(torch.cat((hx5dup,hx4),1))
        hx4dup = _upsample_like(hx4d,hx3)
        hx3d = self.stage3d(torch.cat((hx4dup,hx3),1))
        hx3dup = _upsample_like(hx3d,hx2)
        hx2d = self.stage2d(torch.cat((hx3dup,hx2),1))
        hx2dup = _upsample_like(hx2d,hx1)
        hx1d = self.stage1d(torch.cat((hx2dup,hx1),1))
        # ---- side outputs, upsampled back to the input resolution ----
        d1 = self.side1(hx1d)
        d1 = _upsample_like(d1,x)
        d2 = self.side2(hx2d)
        d2 = _upsample_like(d2,x)
        d3 = self.side3(hx3d)
        d3 = _upsample_like(d3,x)
        d4 = self.side4(hx4d)
        d4 = _upsample_like(d4,x)
        d5 = self.side5(hx5d)
        d5 = _upsample_like(d5,x)
        d6 = self.side6(hx6)
        d6 = _upsample_like(d6,x)
        # F.sigmoid is deprecated; torch.sigmoid is the exact equivalent.
        return [torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5), torch.sigmoid(d6)],[hx1d,hx2d,hx3d,hx4d,hx5d,hx6]
| ### | |
| ## | |
| ###### | |
# Silence deprecation/user warnings in the hosted demo.
warnings.filterwarnings("ignore")
from data_loader_cache import normalize, im_reader, im_preprocess
from models import *
import torch.nn as nn
# Run the DIS segmentation network on GPU when available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
class GOSNormalize(object):
    """Channel-wise normalization transform for DIS preprocessing.

    Thin callable wrapper around data_loader_cache.normalize so it can be
    composed with torchvision.transforms.Compose.
    """

    def __init__(self, mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]):
        self.mean = mean
        self.std = std

    def __call__(self, image):
        # Delegate to the project's normalize(image, mean, std) helper.
        return normalize(image, self.mean, self.std)
# DIS inference preprocessing: only a 0.5 mean shift per channel (std 1.0).
transform = transforms.Compose([GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0])])
def load_image(im_path, hypar):
    """Load and preprocess one image for the DIS network.

    Returns a normalized (1, C, H, W) float tensor and a (1, 2) tensor
    holding the original (H, W) so predictions can be resized back.
    """
    im, im_shp = im_preprocess(im_path, hypar["cache_size"])
    im = torch.divide(im, 255.0)                    # uint8 -> [0, 1]
    orig_shape = torch.from_numpy(np.array(im_shp))
    batched = transform(im).unsqueeze(0)            # add batch dimension
    return batched, orig_shape.unsqueeze(0)
def build_model(hypar, device):
    """Prepare the DIS network described by `hypar` for inference.

    Honors hypar["model_digit"] ("half" converts to fp16 but keeps
    BatchNorm layers in fp32 for numerical stability) and optionally
    restores weights from hypar["model_path"]/hypar["restore_model"].
    Returns the network on `device` in eval mode.
    """
    net = hypar["model"]
    if hypar["model_digit"] == "half":
        # Half precision for speed; BatchNorm stays fp32.
        net.half()
        for layer in net.modules():
            if isinstance(layer, nn.BatchNorm2d):
                layer.float()
    net.to(device)
    restore = hypar["restore_model"]
    if restore != "":
        weights_path = hypar["model_path"] + "/" + restore
        net.load_state_dict(torch.load(weights_path, map_location=device))
        net.to(device)
    net.eval()
    return net
def predict(net, inputs_val, shapes_val, hypar, device):
    """Run the segmentation net on one preprocessed image; return the mask.

    Args:
        net: model whose output[0] is a list of probability maps (finest first).
        inputs_val: (1, C, H, W) input tensor.
        shapes_val: (1, 2) sequence/tensor holding the original (H, W).
        hypar: config dict; "model_digit" selects full/half precision.
        device: 'cuda' or 'cpu'.

    Returns:
        uint8 numpy array (H_orig, W_orig), min-max scaled into 0..255.
    """
    net.eval()
    if hypar["model_digit"] == "full":
        inputs_val = inputs_val.type(torch.FloatTensor)
    else:
        inputs_val = inputs_val.type(torch.HalfTensor)
    # torch.autograd.Variable is deprecated (a no-op wrapper on modern
    # PyTorch); torch.no_grad() expresses the intended "no gradients"
    # behaviour explicitly and also saves memory.
    with torch.no_grad():
        inputs_val_v = inputs_val.to(device)
        ds_val = net(inputs_val_v)[0]  # list of side outputs, finest first
        pred_val = ds_val[0][0, :, :, :]  # most accurate map, (1, H, W)
        # Recover the original spatial size. F.upsample is deprecated;
        # F.interpolate is the direct replacement with identical defaults.
        pred_val = torch.squeeze(F.interpolate(torch.unsqueeze(pred_val, 0),
                                               (shapes_val[0][0], shapes_val[0][1]),
                                               mode='bilinear'))
    # Min-max normalize into [0, 1] so the mask spans the full 0..255 range.
    ma = torch.max(pred_val)
    mi = torch.min(pred_val)
    pred_val = (pred_val - mi) / (ma - mi)
    if device == 'cuda':
        torch.cuda.empty_cache()
    return (pred_val.detach().cpu().numpy() * 255).astype(np.uint8)
# --- Set parameters and build the DIS network --------------------------------
hypar = {} # parameters for inference
hypar["model_path"] ="./model" ## load trained weights from this path
hypar["restore_model"] = "isnet.pth" ## name of the to-be-loaded weights
hypar["interm_sup"] = False ## whether to activate intermediate feature supervision
## choose floating point accuracy --
hypar["model_digit"] = "full" ## "half" or "full" float accuracy
hypar["seed"] = 0
hypar["cache_size"] = [1024, 1024] ## cached input spatial resolution, configurable
## data augmentation parameters ---
hypar["input_size"] = [1024, 1024] ## model input size; same as cache_size means no further resize
hypar["crop_size"] = [1024, 1024] ## random crop size, usually <= cache_size (e.g. [920,920]) for augmentation
hypar["model"] = ISNetDIS()
# Build the model (loads ./model/isnet.pth and switches to eval mode).
net = build_model(hypar, device)
| ###### | |
| from numpy import asarray | |
| from PIL import Image, ImageEnhance, ImageFilter | |
| ######## | |
| from diffusers import (ControlNetModel, DiffusionPipeline, | |
| StableDiffusionControlNetPipeline, | |
| UniPCMultistepScheduler) | |
| import gc | |
| ###### | |
| from rembg import remove | |
| from PIL import Image | |
def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta):
    """Generate a new product-photo background for `input_image`.

    Pipeline: remove the background with rembg, darken the cut-out subject
    so HED mostly detects its outline, binarize that into a scribble-style
    ControlNet map, sample `num_samples` backgrounds with the DDIM sampler,
    then paste the untouched subject back over each generated image.

    Returns the final composited image (numpy HWC uint8; the last sample).
    """
    with torch.no_grad():
        img = Image.fromarray(input_image)
        # --- background removal ------------------------------------------
        output = remove(img)          # RGBA cut-out of the subject
        im_rgb = output
        im_rgx = output
        # Darken the subject strongly so the HED detector picks up mainly
        # its silhouette (scribble-style control signal).
        img_enhancer = ImageEnhance.Brightness(im_rgb)
        factor = 0.09
        im_rgb = img_enhancer.enhance(factor)
        im_rgba = im_rgb.copy()
        im_rgbx = im_rgx.copy()       # untouched cut-out, pasted back later
        input_image = asarray(im_rgba)
        # --- control map from the darkened cut-out -----------------------
        # (The original also ran apply_hed on the untouched cut-out, but
        # that result was overwritten below — the redundant pass is removed.)
        input_image = HWC3(input_image)
        detected_map = apply_hed(resize_image(input_image, detect_resolution))
        detected_map = HWC3(detected_map)
        img = resize_image(input_image, image_resolution)
        H, W, C = img.shape
        # Binarize the HED map into a scribble: nms + blur + hard threshold.
        detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
        detected_map = nms(detected_map, 127, 3.0)
        detected_map = cv2.GaussianBlur(detected_map, (0, 0), 3.0)
        detected_map[detected_map > 4] = 255
        detected_map[detected_map < 255] = 0
        control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
        control = torch.stack([control for _ in range(num_samples)], dim=0)
        control = einops.rearrange(control, 'b h w c -> b c h w').clone()
        if seed == -1:
            seed = random.randint(0, 65535)
        seed_everything(seed)
        if config.save_memory:
            model.low_vram_shift(is_diffusing=False)
        cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning(['RAW photo,'+prompt +', '+', minimal product photo, In the style of David Newton, Helen Koker, Aneta Laura, Nikki Astwood, Amy Shamblen, Hyperrealism, soft smooth lighting, luxury, pinterest, Product photography, product studio, sharp focus, digital art, hyper-realistic, 4K, Unreal Engine, Highly Detailed, HD, Dramatic Lighting by Brom, trending on Artstation' +', '+ a_prompt] * num_samples)]}
        un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
        shape = (4, H // 8, W // 8)
        if config.save_memory:
            model.low_vram_shift(is_diffusing=True)
        model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13)  # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
        samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
                                                     shape, cond, verbose=False, eta=eta,
                                                     unconditional_guidance_scale=scale,
                                                     unconditional_conditioning=un_cond)
        if config.save_memory:
            model.low_vram_shift(is_diffusing=False)
        x_samples = model.decode_first_stage(samples)
        x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
        results = np.array([x_samples[i] for i in range(num_samples)])
        # --- composite the subject back over each generated background ---
        in_img = im_rgbx.copy()
        im_img = im_rgbx.copy()
        # Flatten the RGBA cut-out onto a fully transparent canvas; the
        # alpha channel doubles as the paste mask (PIL treats an Image in
        # the box position as the mask).
        background = Image.new("RGBA", in_img.size, (0, 0, 0, 0))
        background.paste(in_img, in_img)
        in_img = background.convert("RGB")
        in_img = asarray(in_img)
        im_img = asarray(im_img)
        in_img = resize_image(in_img, image_resolution)
        im_img = resize_image(im_img, image_resolution)
        im_img = Image.fromarray(im_img)   # RGBA subject at output size
        xxsample = []
        img_x_r = None
        for i in range(num_samples):
            # BUG FIX: the original did `results = results[i]` inside this
            # loop, clobbering the batch array and crashing whenever
            # num_samples > 1. Index into the batch without reassigning it.
            sample = resize_image(results[i], image_resolution)
            sample = Image.fromarray(sample)
            sample.paste(im_img, im_img)   # paste subject, alpha as mask
            img_x_r = asarray(sample)
            xxsample.append(img_x_r)
        # The UI shows a single image, so return the last composited sample
        # (identical to the original behaviour for num_samples == 1).
        return img_x_r
# --- Gradio UI ---------------------------------------------------------------
# Simple two-column layout: inputs + advanced sliders on the left, the
# generated composite on the right. `process` above is the click handler.
block = gr.Blocks().queue()
with block:
    with gr.Row():
        gr.Markdown("## Background Generator")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(source='upload', type="numpy")
            prompt = gr.Textbox(label="Prompt")
            run_button = gr.Button(label="Run")
            with gr.Accordion("Advanced options", open=False):
                num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
                image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
                strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
                guess_mode = gr.Checkbox(label='Guess Mode', value=False)
                detect_resolution = gr.Slider(label="HED Resolution", minimum=128, maximum=1024, value=512, step=1)
                ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
                scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
                seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
                eta = gr.Number(label="eta (DDIM)", value=0.0)
                a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
                n_prompt = gr.Textbox(label="Negative Prompt",
                                      value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
        with gr.Column():
            #result_gallery = gr.Textbox()
            #result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
            result_gallery = gr.Image(label="Result I")
    # Argument order here must match process()'s signature.
    ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta]
    run_button.click(fn=process, inputs=ips, outputs=result_gallery,api_name="process")
block.launch(show_api=True, show_error=True,enable_queue=True, debug=True)