toto10's picture
49dd4a12034a58d8195abccf923cb1e9c2031af1a7963206e766ea35985d7fb7
7e74c48
raw
history blame
4.29 kB
import os
import cv2
import torch
import numpy as np
import torch.nn as nn
from einops import rearrange
from modules import devices
from annotator.annotator_path import models_path
# Normalization layer class used by ResidualBlock and Generator below;
# it is called as norm_layer(num_channels) to build each normalization layer.
norm_layer = nn.InstanceNorm2d
class ResidualBlock(nn.Module):
    """Residual unit: two reflection-padded 3x3 convolutions plus a skip connection.

    The channel count is the same on input and output (``in_features``), and
    each convolution is followed by ``norm_layer``; a ReLU sits between the
    two convolution stages.
    """

    def __init__(self, in_features):
        """Build the convolutional branch for ``in_features`` channels."""
        super().__init__()
        # The layer order fixes the nn.Sequential indices, which in turn
        # determine the parameter names inside a saved state_dict.
        self.conv_block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            norm_layer(in_features),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            norm_layer(in_features),
        )

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch."""
        branch = self.conv_block(x)
        return x + branch
class Generator(nn.Module):
    """Convolutional encoder / residual / decoder image-to-image network.

    The network is split into five sequential stages stored as ``model0``
    through ``model4``:

    * ``model0`` - 7x7 convolution from ``input_nc`` to 64 channels.
    * ``model1`` - two stride-2 convolutions, doubling the channels each time.
    * ``model2`` - ``n_residual_blocks`` instances of ``ResidualBlock``.
    * ``model3`` - two stride-2 transposed convolutions, halving the channels
      each time.
    * ``model4`` - 7x7 convolution to ``output_nc`` channels, optionally
      followed by a sigmoid.
    """

    def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True):
        """Construct all stages.

        Args:
            input_nc: Number of channels of the input tensor.
            output_nc: Number of channels produced by the output layer.
            n_residual_blocks: How many residual blocks go into ``model2``.
            sigmoid: When true, ``nn.Sigmoid`` is appended to the output stage.
        """
        super().__init__()

        # Stage 0: initial convolution block.
        self.model0 = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, 64, 7),
            norm_layer(64),
            nn.ReLU(inplace=True),
        )

        # Stage 1: downsampling, channel width doubles per step.
        width = 64
        down_layers = []
        for _ in range(2):
            wider = width * 2
            down_layers.extend([
                nn.Conv2d(width, wider, 3, stride=2, padding=1),
                norm_layer(wider),
                nn.ReLU(inplace=True),
            ])
            width = wider
        self.model1 = nn.Sequential(*down_layers)

        # Stage 2: residual blocks at the bottleneck width.
        self.model2 = nn.Sequential(
            *[ResidualBlock(width) for _ in range(n_residual_blocks)]
        )

        # Stage 3: upsampling, channel width halves per step.
        up_layers = []
        for _ in range(2):
            narrower = width // 2
            up_layers.extend([
                nn.ConvTranspose2d(width, narrower, 3, stride=2, padding=1, output_padding=1),
                norm_layer(narrower),
                nn.ReLU(inplace=True),
            ])
            width = narrower
        self.model3 = nn.Sequential(*up_layers)

        # Stage 4: output layer (input width is hard-coded to 64).
        head = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(64, output_nc, 7),
        ]
        if sigmoid:
            head.append(nn.Sigmoid())
        self.model4 = nn.Sequential(*head)

    def forward(self, x, cond=None):
        """Run ``x`` through the five stages in order and return the result.

        ``cond`` is accepted but not used by this implementation.
        """
        out = x
        for stage in (self.model0, self.model1, self.model2, self.model3, self.model4):
            out = stage(out)
        return out
class LineartDetector:
    """Callable wrapper that runs a ``Generator`` checkpoint on an image array.

    The checkpoint is loaded on first call (downloaded into ``model_dir`` when
    the file is not present locally) and kept on ``self.device``.
    """

    # Directory where checkpoint files are looked up / downloaded to.
    model_dir = os.path.join(models_path, "lineart")
    # Checkpoint file names offered by this class.
    model_default = 'sk_model.pth'
    model_coarse = 'sk_model2.pth'

    def __init__(self, model_name):
        """Remember which checkpoint to use; the network itself is loaded lazily."""
        self.model = None
        self.model_name = model_name
        self.device = devices.get_device_for("controlnet")

    def load_model(self, name):
        """Load checkpoint ``name`` into ``self.model``, downloading it if missing."""
        weights_url = "https://huggingface.co/lllyasviel/Annotators/resolve/main/" + name
        weights_file = os.path.join(self.model_dir, name)

        if not os.path.exists(weights_file):
            # Imported here so the dependency is only needed when a download happens.
            from basicsr.utils.download_util import load_file_from_url
            load_file_from_url(weights_url, model_dir=self.model_dir)

        net = Generator(3, 1, 3)
        # NOTE(review): depending on the torch version, torch.load may unpickle
        # arbitrary objects - only load checkpoints from a trusted source.
        state = torch.load(weights_file, map_location=torch.device('cpu'))
        net.load_state_dict(state)
        self.model = net.eval().to(self.device)

    def unload_model(self):
        """Move the loaded network (if any) to the CPU."""
        if self.model is None:
            return
        self.model.cpu()

    def __call__(self, input_image):
        """Run the network on ``input_image`` and return the result as uint8.

        Args:
            input_image: NumPy array with three dimensions, interpreted as
                (height, width, channels). Values are divided by 255 before
                inference; the network built in ``load_model`` takes 3 channels.

        Returns:
            A ``np.uint8`` array taken from the first channel of the first
            batch element of the network output, scaled by 255 and clipped to
            [0, 255].
        """
        if self.model is None:
            self.load_model(self.model_name)
        # Make sure the network is on the target device (it may have been
        # moved to the CPU by unload_model).
        self.model.to(self.device)

        assert input_image.ndim == 3

        with torch.no_grad():
            pixels = torch.from_numpy(input_image).float().to(self.device)
            pixels = pixels / 255.0
            batch = rearrange(pixels, 'h w c -> 1 c h w')
            prediction = self.model(batch)[0][0]

            line = prediction.cpu().numpy()
            line = (line * 255.0).clip(0, 255).astype(np.uint8)

        return line