# PhotoAI_Mellow / app.py
# Hugging Face Space by nik1-kaj (revision 48093c9, "Update app.py", 34.3 kB)
from share import *
import config
import cv2
import einops
import gradio as gr
import numpy as np
import torch
import random
###
import cv2
import gradio as gr
import os
from PIL import Image
import numpy as np
import torch
from torch.autograd import Variable
from torchvision import transforms
import torch.nn.functional as F
import gdown
import matplotlib.pyplot as plt
import warnings
###
from pytorch_lightning import seed_everything
from annotator.util import resize_image, HWC3
from annotator.hed import HEDdetector, nms
from cldm.model import create_model, load_state_dict
from cldm.ddim_hacked import DDIMSampler
# ---- Global ControlNet (scribble) + Stable Diffusion setup ----------------
# HED edge detector used to derive scribble-style control maps from images.
apply_hed = HEDdetector()
# Build the ControlNet-wrapped SD 1.5 model on CPU first; weights are loaded below.
model = create_model('./models/cldm_v15.yaml').cpu()
#model.load_state_dict(load_state_dict('./control_sd15_scribble.pth', location='cuda'))
# DDIM sampler bound to the (still weightless) model; it samples via the same object.
ddim_sampler = DDIMSampler(model)
from safetensors.torch import load_file as safe_load_file #add
# NOTE(review): `pl_sd` is loaded but never used — the same file is re-read just below.
pl_sd = safe_load_file('./Realistic_Vision_V2.0.safetensors') #add
# Base SD weights (Realistic Vision V2.0); strict=False because ControlNet keys are absent there.
model.load_state_dict(load_state_dict('./Realistic_Vision_V2.0.safetensors', location='cuda'),strict=False) #add
# Scribble ControlNet weights for the control branch only.
model.control_model.load_state_dict(load_state_dict('./control_scribble-fp16.safetensors',location='cuda'))
#model.load_state_dict(load_state_dict(pl_sd, strict=False)) #add
model = model.cuda()
#########
########
import torch
#
import torch.nn as nn
from torchvision import models
import torch.nn.functional as F
# Mean-reduced BCE used by the fusion losses below.
# `size_average=True` is deprecated; `reduction='mean'` is the exact modern equivalent.
bce_loss = nn.BCELoss(reduction='mean')
def muti_loss_fusion(preds, target):
    """Sum BCE losses over all side outputs.

    Returns (loss of the first/finest side output, total accumulated loss).
    Ground truth is bilinearly resized whenever a side output's spatial size
    differs from the target's.
    """
    loss0 = 0.0
    loss = 0.0
    for idx, pred in enumerate(preds):
        if pred.shape[2] != target.shape[2] or pred.shape[3] != target.shape[3]:
            # resize ground truth to this side output's resolution
            resized = F.interpolate(target, size=pred.size()[2:], mode='bilinear', align_corners=True)
            loss = loss + bce_loss(pred, resized)
        else:
            loss = loss + bce_loss(pred, target)
        if idx == 0:
            loss0 = loss
    return loss0, loss
# Mean-reduced distances used for feature-level supervision in muti_loss_fusion_kl.
# `size_average=True` is deprecated; `reduction='mean'` is the exact modern equivalent.
fea_loss = nn.MSELoss(reduction='mean')
kl_loss = nn.KLDivLoss(reduction='mean')
l1_loss = nn.L1Loss(reduction='mean')
smooth_l1_loss = nn.SmoothL1Loss(reduction='mean')
def muti_loss_fusion_kl(preds, target, dfs, fs, mode='MSE'):
    """BCE fusion loss plus a feature-consistency term.

    `preds` are side-output probability maps compared against `target` via BCE
    (resizing the target when resolutions differ); `dfs`/`fs` are paired
    feature maps compared with the distance selected by `mode`
    ('MSE' | 'KL' | 'MAE' | 'SmoothL1').
    Returns (loss of the first side output, total loss).
    """
    loss0 = 0.0
    loss = 0.0
    for idx, pred in enumerate(preds):
        if pred.shape[2] != target.shape[2] or pred.shape[3] != target.shape[3]:
            resized = F.interpolate(target, size=pred.size()[2:], mode='bilinear', align_corners=True)
            loss = loss + bce_loss(pred, resized)
        else:
            loss = loss + bce_loss(pred, target)
        if idx == 0:
            loss0 = loss
    # Feature supervision: add the selected distance between each feature pair.
    for j in range(0, len(dfs)):
        if mode == 'MSE':
            loss = loss + fea_loss(dfs[j], fs[j])
        elif mode == 'KL':
            loss = loss + kl_loss(F.log_softmax(dfs[j], dim=1), F.softmax(fs[j], dim=1))
        elif mode == 'MAE':
            loss = loss + l1_loss(dfs[j], fs[j])
        elif mode == 'SmoothL1':
            loss = loss + smooth_l1_loss(dfs[j], fs[j])
    return loss0, loss
class REBNCONV(nn.Module):
    """Conv2d(3x3) -> BatchNorm2d -> ReLU, with dilation set by `dirate`."""

    def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1):
        super(REBNCONV, self).__init__()
        # padding == dilation keeps the spatial size unchanged for a 3x3 kernel
        self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate, stride=stride)
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)

    def forward(self, x):
        y = self.conv_s1(x)
        y = self.bn_s1(y)
        return self.relu_s1(y)
## upsample tensor 'src' to have the same spatial size with tensor 'tar'
def _upsample_like(src,tar):
src = F.upsample(src,size=tar.shape[2:],mode='bilinear')
return src
### RSU-7 ###
class RSU7(nn.Module):
    """Residual U-block of depth 7: six pooling levels plus a dilated bottleneck.

    Output has `out_ch` channels and the same spatial size as the input; the
    input projection (`rebnconvin`) is added back as a residual connection.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
        # `img_size` is unused; kept for backward compatibility with callers.
        super(RSU7, self).__init__()
        self.in_ch = in_ch
        self.mid_ch = mid_ch
        self.out_ch = out_ch
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
        # encoder
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)  # dilated bottleneck
        # decoder (each step consumes the concat of previous decode + skip)
        self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        # NOTE: an unused `b, c, h, w = x.shape` unpack was removed here.
        hxin = self.rebnconvin(x)
        # encoder
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)
        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)
        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)
        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)
        hx5 = self.rebnconv5(hx)
        hx = self.pool5(hx5)
        hx6 = self.rebnconv6(hx)
        hx7 = self.rebnconv7(hx6)
        # decoder with skip connections; upsample each step back to the skip's size
        hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
        hx6dup = _upsample_like(hx6d, hx5)
        hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)
        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)
        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)
        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)
        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
        # residual connection from the input projection
        return hx1d + hxin
### RSU-6 ###
class RSU6(nn.Module):
    """Residual U-block of depth 6 (five pooling levels + dilated bottleneck)."""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU6, self).__init__()
        # input projection, also used as the residual connection
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
        # encoder path
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
        # bottleneck uses dilation instead of further pooling
        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
        # decoder path
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        residual = self.rebnconvin(x)
        # encoder
        e1 = self.rebnconv1(residual)
        e2 = self.rebnconv2(self.pool1(e1))
        e3 = self.rebnconv3(self.pool2(e2))
        e4 = self.rebnconv4(self.pool3(e3))
        e5 = self.rebnconv5(self.pool4(e4))
        e6 = self.rebnconv6(e5)
        # decoder with skip connections
        d5 = self.rebnconv5d(torch.cat((e6, e5), 1))
        d4 = self.rebnconv4d(torch.cat((_upsample_like(d5, e4), e4), 1))
        d3 = self.rebnconv3d(torch.cat((_upsample_like(d4, e3), e3), 1))
        d2 = self.rebnconv2d(torch.cat((_upsample_like(d3, e2), e2), 1))
        d1 = self.rebnconv1d(torch.cat((_upsample_like(d2, e1), e1), 1))
        return d1 + residual
### RSU-5 ###
class RSU5(nn.Module):
    """Residual U-block of depth 5 (four pooling levels + dilated bottleneck)."""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU5, self).__init__()
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
        # encoder
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)  # dilated bottleneck
        # decoder
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        residual = self.rebnconvin(x)
        e1 = self.rebnconv1(residual)
        e2 = self.rebnconv2(self.pool1(e1))
        e3 = self.rebnconv3(self.pool2(e2))
        e4 = self.rebnconv4(self.pool3(e3))
        e5 = self.rebnconv5(e4)  # no pooling before the bottleneck
        d4 = self.rebnconv4d(torch.cat((e5, e4), 1))
        d3 = self.rebnconv3d(torch.cat((_upsample_like(d4, e3), e3), 1))
        d2 = self.rebnconv2d(torch.cat((_upsample_like(d3, e2), e2), 1))
        d1 = self.rebnconv1d(torch.cat((_upsample_like(d2, e1), e1), 1))
        return d1 + residual
### RSU-4 ###
class RSU4(nn.Module):
    """Residual U-block of depth 4 (two pooling levels + dilated bottleneck)."""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4, self).__init__()
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
        # encoder
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)  # dilated bottleneck
        # decoder
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        residual = self.rebnconvin(x)
        e1 = self.rebnconv1(residual)
        e2 = self.rebnconv2(self.pool1(e1))
        e3 = self.rebnconv3(self.pool2(e2))
        e4 = self.rebnconv4(e3)
        d3 = self.rebnconv3d(torch.cat((e4, e3), 1))
        d2 = self.rebnconv2d(torch.cat((_upsample_like(d3, e2), e2), 1))
        d1 = self.rebnconv1d(torch.cat((_upsample_like(d2, e1), e1), 1))
        return d1 + residual
### RSU-4F ###
class RSU4F(nn.Module):
    """Dilation-only residual U-block ("F" variant): no pooling or upsampling,
    receptive field grows via dilation rates 1/2/4/8 instead."""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4F, self).__init__()
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
        # "encoder": increasing dilation instead of pooling
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
        # "decoder": decreasing dilation, skip concatenation at each step
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        residual = self.rebnconvin(x)
        e1 = self.rebnconv1(residual)
        e2 = self.rebnconv2(e1)
        e3 = self.rebnconv3(e2)
        e4 = self.rebnconv4(e3)
        d3 = self.rebnconv3d(torch.cat((e4, e3), 1))
        d2 = self.rebnconv2d(torch.cat((d3, e2), 1))
        d1 = self.rebnconv1d(torch.cat((d2, e1), 1))
        return d1 + residual
class myrebnconv(nn.Module):
    """Fully configurable Conv2d -> BatchNorm2d -> ReLU block."""

    def __init__(self, in_ch=3,
                 out_ch=1,
                 kernel_size=3,
                 stride=1,
                 padding=1,
                 dilation=1,
                 groups=1):
        super(myrebnconv, self).__init__()
        self.conv = nn.Conv2d(in_ch,
                              out_ch,
                              kernel_size=kernel_size,
                              stride=stride,
                              padding=padding,
                              dilation=dilation,
                              groups=groups)
        self.bn = nn.BatchNorm2d(out_ch)
        self.rl = nn.ReLU(inplace=True)

    def forward(self, x):
        y = self.conv(x)
        y = self.bn(y)
        return self.rl(y)
class ISNetGTEncoder(nn.Module):
    """Ground-truth encoder: a U^2-Net-style encoder over 1-channel masks.

    forward(x) returns:
      - a list of six sigmoid side-output maps, each upsampled to x's size
      - the six intermediate stage feature maps (used for feature supervision)
    """

    def __init__(self, in_ch=1, out_ch=1):
        super(ISNetGTEncoder, self).__init__()
        # Stem conv halves the spatial resolution.
        self.conv_in = myrebnconv(in_ch, 16, 3, stride=2, padding=1)  # nn.Conv2d(in_ch,64,3,stride=2,padding=1)
        # Encoder stages with pooling in between.
        self.stage1 = RSU7(16, 16, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.stage2 = RSU6(64, 16, 64)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.stage3 = RSU5(64, 32, 128)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.stage4 = RSU4(128, 32, 256)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.stage5 = RSU4F(256, 64, 512)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.stage6 = RSU4F(512, 64, 512)
        # One 3x3 side head per stage.
        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)

    def compute_loss(self, preds, targets):
        """Fused BCE over all side outputs; returns (loss0, total_loss)."""
        return muti_loss_fusion(preds, targets)

    def forward(self, x):
        hxin = self.conv_in(x)
        # encoder stages
        hx1 = self.stage1(hxin)
        hx = self.pool12(hx1)
        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)
        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)
        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)
        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)
        hx6 = self.stage6(hx)
        # side outputs, each upsampled back to the input resolution
        d1 = _upsample_like(self.side1(hx1), x)
        d2 = _upsample_like(self.side2(hx2), x)
        d3 = _upsample_like(self.side3(hx3), x)
        d4 = _upsample_like(self.side4(hx4), x)
        d5 = _upsample_like(self.side5(hx5), x)
        d6 = _upsample_like(self.side6(hx6), x)
        # F.sigmoid is deprecated; torch.sigmoid is the supported equivalent.
        return [torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3),
                torch.sigmoid(d4), torch.sigmoid(d5), torch.sigmoid(d6)], \
               [hx1, hx2, hx3, hx4, hx5, hx6]
class ISNetDIS(nn.Module):
    """U^2-Net-style DIS (highly accurate dichotomous image segmentation) network.

    forward(x) returns:
      - a list of six sigmoid side-output maps (finest first), each upsampled
        to x's spatial size
      - the six decoder/bottleneck feature maps used for feature supervision
    """

    def __init__(self, in_ch=3, out_ch=1):
        super(ISNetDIS, self).__init__()
        # Stem conv halves the spatial resolution.
        self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1)
        self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True)  # defined but unused in forward
        # Encoder.
        self.stage1 = RSU7(64, 32, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.stage2 = RSU6(64, 32, 128)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.stage3 = RSU5(128, 64, 256)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.stage4 = RSU4(256, 128, 512)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.stage5 = RSU4F(512, 256, 512)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.stage6 = RSU4F(512, 256, 512)
        # Decoder (inputs are concat of upsampled deeper features + encoder skip).
        self.stage5d = RSU4F(1024, 256, 512)
        self.stage4d = RSU4(1024, 128, 256)
        self.stage3d = RSU5(512, 64, 128)
        self.stage2d = RSU6(256, 32, 64)
        self.stage1d = RSU7(128, 16, 64)
        # One 3x3 side head per decoder stage (+ bottleneck).
        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
        # self.outconv = nn.Conv2d(6*out_ch,out_ch,1)

    def compute_loss_kl(self, preds, targets, dfs, fs, mode='MSE'):
        """BCE fusion loss plus feature supervision against GT-encoder features."""
        return muti_loss_fusion_kl(preds, targets, dfs, fs, mode=mode)

    def compute_loss(self, preds, targets):
        """Fused BCE over all side outputs; returns (loss0, total_loss)."""
        return muti_loss_fusion(preds, targets)

    def forward(self, x):
        hxin = self.conv_in(x)
        # hx = self.pool_in(hxin)
        # encoder
        hx1 = self.stage1(hxin)
        hx = self.pool12(hx1)
        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)
        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)
        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)
        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)
        hx6 = self.stage6(hx)
        hx6up = _upsample_like(hx6, hx5)
        # decoder with skip connections
        hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)
        hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)
        hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)
        hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)
        hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
        # side outputs, each upsampled back to the input resolution
        d1 = _upsample_like(self.side1(hx1d), x)
        d2 = _upsample_like(self.side2(hx2d), x)
        d3 = _upsample_like(self.side3(hx3d), x)
        d4 = _upsample_like(self.side4(hx4d), x)
        d5 = _upsample_like(self.side5(hx5d), x)
        d6 = _upsample_like(self.side6(hx6), x)
        # F.sigmoid is deprecated; torch.sigmoid is the supported equivalent.
        return [torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3),
                torch.sigmoid(d4), torch.sigmoid(d5), torch.sigmoid(d6)], \
               [hx1d, hx2d, hx3d, hx4d, hx5d, hx6]
###
##
######
# Silence library deprecation warnings in the Space logs.
warnings.filterwarnings("ignore")
from data_loader_cache import normalize, im_reader, im_preprocess
from models import *
import torch.nn as nn
# Run DIS (background-removal) inference on GPU when available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
class GOSNormalize(object):
    """Channel-wise image normalization, usable inside torchvision transforms."""

    def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
        self.mean = mean
        self.std = std

    def __call__(self, image):
        # delegate to the project's normalize() helper
        return normalize(image, self.mean, self.std)
# DIS preprocessing: center values around 0.5 with unit std (not ImageNet stats).
transform = transforms.Compose([GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0])])
def load_image(im_path, hypar):
    """Read and preprocess one image for the DIS network.

    Returns a (1, C, H, W) normalized float tensor and a (1, 2) tensor with
    the original (H, W) of the image.
    """
    im, im_shp = im_preprocess(im_path, hypar["cache_size"])
    im = torch.divide(im, 255.0)  # scale to [0, 1]
    shape = torch.from_numpy(np.array(im_shp))
    # add a batch dimension to both the image and its original shape
    return transform(im).unsqueeze(0), shape.unsqueeze(0)
def build_model(hypar, device):
    """Instantiate the network described by `hypar`, optionally restore weights,
    and return it on `device` in eval mode."""
    net = hypar["model"]
    if hypar["model_digit"] == "half":
        # Half precision for speed, but BatchNorm stays float32 for stability.
        net.half()
        for layer in net.modules():
            if isinstance(layer, nn.BatchNorm2d):
                layer.float()
    net.to(device)
    if hypar["restore_model"] != "":
        # Restore pretrained weights from disk onto the target device.
        net.load_state_dict(torch.load(hypar["model_path"] + "/" + hypar["restore_model"], map_location=device))
        net.to(device)
    net.eval()
    return net
def predict(net, inputs_val, shapes_val, hypar, device):
    """Given a preprocessed image batch, predict the (H, W) uint8 saliency mask.

    `shapes_val[0]` holds the original (H, W) so the prediction can be resized
    back; the result is min-max stretched to the full 0-255 range.
    """
    net.eval()
    # Match the network's floating point precision.
    if (hypar["model_digit"] == "full"):
        inputs_val = inputs_val.type(torch.FloatTensor)
    else:
        inputs_val = inputs_val.type(torch.HalfTensor)
    # torch.autograd.Variable has been a deprecated no-op wrapper since
    # PyTorch 0.4; a plain tensor moved to the device is equivalent.
    inputs_val_v = inputs_val.to(device)
    ds_val = net(inputs_val_v)[0]  # list of 6 side outputs, finest first
    pred_val = ds_val[0][0, :, :, :]  # first (most accurate) output, batch item 0
    # Recover the prediction's spatial size to the original image size.
    # (F.upsample is deprecated; F.interpolate is the drop-in replacement.)
    pred_val = torch.squeeze(F.interpolate(torch.unsqueeze(pred_val, 0),
                                           (shapes_val[0][0], shapes_val[0][1]),
                                           mode='bilinear'))
    ma = torch.max(pred_val)
    mi = torch.min(pred_val)
    pred_val = (pred_val - mi) / (ma - mi)  # stretch to [0, 1] (max = 1)
    if device == 'cuda': torch.cuda.empty_cache()
    return (pred_val.detach().cpu().numpy() * 255).astype(np.uint8)  # the mask we need
# Set Parameters
hypar = {}  # parameters for inferencing
hypar["model_path"] ="./model"  ## load trained weights from this path
hypar["restore_model"] = "isnet.pth"  ## name of the to-be-loaded weights
hypar["interm_sup"] = False  ## indicate if activate intermediate feature supervision
## choose floating point accuracy --
hypar["model_digit"] = "full"  ## indicates "half" or "full" accuracy of float number
hypar["seed"] = 0
hypar["cache_size"] = [1024, 1024]  ## cached input spatial resolution, can be configured into different size
## data augmentation parameters ---
hypar["input_size"] = [1024, 1024]  ## model input spatial size, usually the same as hypar["cache_size"] (no further resize)
hypar["crop_size"] = [1024, 1024]  ## random crop size from the input, usually set smaller than hypar["cache_size"] (e.g. [920,920]) for augmentation
hypar["model"] = ISNetDIS()
# Build Model: loads ./model/isnet.pth and puts the DIS network in eval mode.
net = build_model(hypar, device)
######
from numpy import asarray
from PIL import Image, ImageEnhance, ImageFilter
########
from diffusers import (ControlNetModel, DiffusionPipeline,
StableDiffusionControlNetPipeline,
UniPCMultistepScheduler)
import gc
######
from rembg import remove
from PIL import Image
def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta):
    """Generate a new AI background for a product photo.

    Pipeline: remove the background (rembg), darken the cut-out subject and
    extract a scribble control map from its HED edges, sample `num_samples`
    images with scribble-ControlNet + DDIM, then paste the original subject
    back over each generated background.

    Returns the last composited sample as an (H, W, 3) uint8 array (the UI
    shows a single image).
    """
    with torch.no_grad():
        # 1) Cut the subject out of the photo (RGBA with transparent background).
        img = Image.fromarray(input_image)
        subject = remove(img)
        # A heavily darkened copy: its HED edges outline the subject silhouette
        # without picking up interior texture.
        darkened = ImageEnhance.Brightness(subject).enhance(0.09)

        # 2) Build the scribble control map from the darkened subject.
        input_image = HWC3(asarray(darkened))
        detected_map = apply_hed(resize_image(input_image, detect_resolution))
        detected_map = HWC3(detected_map)
        img = resize_image(input_image, image_resolution)
        H, W, C = img.shape
        detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
        # Thin and binarize the edge map (standard scribble preprocessing).
        detected_map = nms(detected_map, 127, 3.0)
        detected_map = cv2.GaussianBlur(detected_map, (0, 0), 3.0)
        detected_map[detected_map > 4] = 255
        detected_map[detected_map < 255] = 0

        # 3) Diffusion sampling conditioned on the control map.
        control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
        control = torch.stack([control for _ in range(num_samples)], dim=0)
        control = einops.rearrange(control, 'b h w c -> b c h w').clone()
        if seed == -1:
            seed = random.randint(0, 65535)
        seed_everything(seed)
        if config.save_memory:
            model.low_vram_shift(is_diffusing=False)
        cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning(['RAW photo,'+prompt +', '+', minimal product photo, In the style of David Newton, Helen Koker, Aneta Laura, Nikki Astwood, Amy Shamblen, Hyperrealism, soft smooth lighting, luxury, pinterest, Product photography, product studio, sharp focus, digital art, hyper-realistic, 4K, Unreal Engine, Highly Detailed, HD, Dramatic Lighting by Brom, trending on Artstation' +', '+ a_prompt] * num_samples)]}
        un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
        shape = (4, H // 8, W // 8)
        if config.save_memory:
            model.low_vram_shift(is_diffusing=True)
        # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
        model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13)
        samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
                                                     shape, cond, verbose=False, eta=eta,
                                                     unconditional_guidance_scale=scale,
                                                     unconditional_conditioning=un_cond)
        if config.save_memory:
            model.low_vram_shift(is_diffusing=False)
        x_samples = model.decode_first_stage(samples)
        x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)

        # 4) Paste the original (un-darkened) RGBA subject back over each
        # generated background; the RGBA image doubles as its own alpha mask.
        subject_img = resize_image(asarray(subject.copy()), image_resolution)
        subject_img = Image.fromarray(subject_img)
        img_x_r = None
        xxsample = []
        for i in range(num_samples):
            # BUG FIX: the original did `results = results[i]` inside this loop,
            # clobbering the sample array after the first iteration so every
            # run with num_samples > 1 processed garbage. Index the decoded
            # batch directly instead.
            sample = resize_image(x_samples[i], image_resolution)
            composed = Image.fromarray(sample)
            composed.paste(subject_img, subject_img)  # paste(im, mask) shorthand
            img_x_r = asarray(composed)
            xxsample.append(img_x_r)
        # The UI output is a single image; return the last composited sample
        # (identical to the original behavior for the default num_samples=1).
        return img_x_r
# --------------------------- Gradio UI ---------------------------
block = gr.Blocks().queue()
with block:
    with gr.Row():
        gr.Markdown("## Background Generator")
    with gr.Row():
        with gr.Column():
            # Input controls: source photo and the background prompt.
            input_image = gr.Image(source='upload', type="numpy")
            prompt = gr.Textbox(label="Prompt")
            run_button = gr.Button(label="Run")
            with gr.Accordion("Advanced options", open=False):
                # Sampling / ControlNet knobs, passed straight to process().
                num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
                image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
                strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
                guess_mode = gr.Checkbox(label='Guess Mode', value=False)
                detect_resolution = gr.Slider(label="HED Resolution", minimum=128, maximum=1024, value=512, step=1)
                ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
                scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
                seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
                eta = gr.Number(label="eta (DDIM)", value=0.0)
                a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
                n_prompt = gr.Textbox(label="Negative Prompt",
                                      value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
        with gr.Column():
            #result_gallery = gr.Textbox()
            #result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
            result_gallery = gr.Image(label="Result I")
    # Wire the button to process(); argument order must match its signature.
    ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta]
    run_button.click(fn=process, inputs=ips, outputs=result_gallery,api_name="process")
block.launch(show_api=True, show_error=True,enable_queue=True, debug=True)