# NOTE: Hugging Face Spaces page chrome ("Running on Zero" / ZeroGPU badge)
# was captured during export; kept here as a comment so the module parses.
| import warnings | |
| import cv2 | |
| import numpy as np | |
| from PIL import Image | |
| import torch | |
| from torch import nn | |
| from torch.nn import functional as F | |
| import os | |
| from einops import rearrange | |
| from .util import HWC3, nms, safe_step, resize_image_with_pad, common_input_validate, get_intensity_mask, combine_layers | |
| from .pidi import pidinet | |
| from .ted import TED | |
| from .lineart import Generator as LineartGenerator | |
| from .informative_drawing import Generator | |
| from .hed import ControlNetHED_Apache2 | |
| from pathlib import Path | |
| from skimage import morphology | |
| import argparse | |
| from tqdm import tqdm | |
# Root directory holding preprocessor model weights. Overridable through the
# PREPROCESSORS_ROOT environment variable; the default is four directory
# levels above this file, under "models/preprocessors".
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
_DEFAULT_PREPROCESSORS_ROOT = os.path.join(
    os.path.dirname(os.path.dirname(os.path.dirname(_THIS_DIR))),
    "models/preprocessors",
)
PREPROCESSORS_ROOT = os.getenv("PREPROCESSORS_ROOT", _DEFAULT_PREPROCESSORS_ROOT)
class HEDDetector:
    """Holistically-Nested Edge Detection (HED) preprocessor.

    Wraps a ControlNetHED_Apache2 network and turns an RGB image into a soft
    edge map, optionally post-processed into a binary scribble-style map.
    """

    def __init__(self, netNetwork):
        # The loaded HED network; moved between devices via to().
        self.netNetwork = netNetwork
        self.device = "cpu"

    @classmethod  # fix: decorator was missing, so from_pretrained bound the filename to `cls`
    def from_pretrained(cls, filename="ControlNetHED.pth"):
        """Load HED weights from PREPROCESSORS_ROOT and return a detector instance."""
        model_path = os.path.join(PREPROCESSORS_ROOT, filename)
        netNetwork = ControlNetHED_Apache2()
        netNetwork.load_state_dict(torch.load(model_path, map_location='cpu'))
        netNetwork.float().eval()
        return cls(netNetwork)

    def to(self, device):
        """Move the underlying network to *device*; returns self for chaining."""
        self.netNetwork.to(device)
        self.device = device
        return self

    def __call__(self, input_image, detect_resolution=512, safe=False, output_type=None, scribble=True, upscale_method="INTER_CUBIC", **kwargs):
        """Run HED edge detection.

        Args:
            input_image: input image (validated by common_input_validate).
            detect_resolution: working resolution the image is resized/padded to.
            safe: if True, quantize the edge map via safe_step.
            output_type: "pil" to return a PIL.Image, otherwise a numpy HWC array.
            scribble: if True, post-process edges into a binary scribble map
                (NMS + Gaussian blur + hard threshold).
            upscale_method: cv2 interpolation name used during resizing.

        Returns:
            Edge map as numpy HWC uint8 array or PIL.Image.
        """
        input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
        input_image, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
        assert input_image.ndim == 3
        H, W, C = input_image.shape
        with torch.no_grad():
            image_hed = torch.from_numpy(input_image).float().to(self.device)
            image_hed = rearrange(image_hed, 'h w c -> 1 c h w')
            edges = self.netNetwork(image_hed)
            edges = [e.detach().cpu().numpy().astype(np.float32)[0, 0] for e in edges]
            edges = [cv2.resize(e, (W, H), interpolation=cv2.INTER_LINEAR) for e in edges]
            edges = np.stack(edges, axis=2)
            # Sigmoid of the mean over all side outputs -> single soft edge map in [0, 1].
            edge = 1 / (1 + np.exp(-np.mean(edges, axis=2).astype(np.float64)))
            if safe:
                edge = safe_step(edge)
            edge = (edge * 255.0).clip(0, 255).astype(np.uint8)
        detected_map = edge
        if scribble:
            # Thin edges, soften, then binarize to a scribble-like map.
            detected_map = nms(detected_map, 127, 3.0)
            detected_map = cv2.GaussianBlur(detected_map, (0, 0), 3.0)
            detected_map[detected_map > 4] = 255
            detected_map[detected_map < 255] = 0
        detected_map = HWC3(remove_pad(detected_map))
        if output_type == "pil":
            detected_map = Image.fromarray(detected_map)
        return detected_map
class CannyDetector:
    """Classic Canny edge detector wrapped in the common preprocessor interface."""

    def __call__(self, input_image=None, low_threshold=100, high_threshold=200, detect_resolution=512, output_type=None, upscale_method="INTER_CUBIC", **kwargs):
        """Return Canny edges for *input_image*.

        Args:
            input_image: input image (validated by common_input_validate).
            low_threshold: lower hysteresis threshold for cv2.Canny.
            high_threshold: upper hysteresis threshold for cv2.Canny.
            detect_resolution: working resolution the image is resized/padded to.
            output_type: "pil" to return a PIL.Image, otherwise a numpy HWC array.
            upscale_method: cv2 interpolation name used during resizing.
        """
        input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
        padded, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
        edges = cv2.Canny(padded, low_threshold, high_threshold)
        result = HWC3(remove_pad(edges))
        return Image.fromarray(result) if output_type == "pil" else result
class PidiNetDetector:
    """PiDiNet edge-detection preprocessor (pixel-difference network)."""

    def __init__(self, netNetwork):
        # The loaded PiDiNet network; moved between devices via to().
        self.netNetwork = netNetwork
        self.device = "cpu"

    @classmethod  # fix: decorator was missing, so from_pretrained bound the filename to `cls`
    def from_pretrained(cls, filename="table5_pidinet.pth"):
        """Load PiDiNet weights from PREPROCESSORS_ROOT and return a detector instance."""
        model_path = os.path.join(PREPROCESSORS_ROOT, filename)
        netNetwork = pidinet()
        # Strip the DataParallel 'module.' prefix from checkpoint keys.
        # fix: map_location='cpu' added for CPU-only hosts, consistent with the
        # other loaders in this module.
        netNetwork.load_state_dict({k.replace('module.', ''): v for k, v in torch.load(model_path, map_location='cpu')['state_dict'].items()})
        netNetwork.eval()
        return cls(netNetwork)

    def to(self, device):
        """Move the underlying network to *device*; returns self for chaining."""
        self.netNetwork.to(device)
        self.device = device
        return self

    def __call__(self, input_image, detect_resolution=512, safe=False, output_type=None, scribble=True, apply_filter=False, upscale_method="INTER_CUBIC", **kwargs):
        """Run PiDiNet edge detection.

        Args:
            input_image: input image (validated by common_input_validate).
            detect_resolution: working resolution the image is resized/padded to.
            safe: if True, quantize the edge map via safe_step.
            output_type: "pil" to return a PIL.Image, otherwise a numpy HWC array.
            scribble: if True, post-process edges into a binary scribble map.
            apply_filter: if True, binarize raw edges at 0.5 before scaling.
            upscale_method: cv2 interpolation name used during resizing.
        """
        input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
        detected_map, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
        # RGB -> BGR channel order for the network input.
        detected_map = detected_map[:, :, ::-1].copy()
        with torch.no_grad():
            image_pidi = torch.from_numpy(detected_map).float().to(self.device)
            image_pidi = image_pidi / 255.0
            image_pidi = rearrange(image_pidi, 'h w c -> 1 c h w')
            # Use the last (fused) side output.
            edge = self.netNetwork(image_pidi)[-1]
            edge = edge.cpu().numpy()
            if apply_filter:
                edge = edge > 0.5
            if safe:
                edge = safe_step(edge)
            edge = (edge * 255.0).clip(0, 255).astype(np.uint8)
        detected_map = edge[0, 0]
        if scribble:
            # Thin edges, soften, then binarize to a scribble-like map.
            detected_map = nms(detected_map, 127, 3.0)
            detected_map = cv2.GaussianBlur(detected_map, (0, 0), 3.0)
            detected_map[detected_map > 4] = 255
            detected_map[detected_map < 255] = 0
        detected_map = HWC3(remove_pad(detected_map))
        if output_type == "pil":
            detected_map = Image.fromarray(detected_map)
        return detected_map
class TEDDetector:
    """TEED (Tiny and Efficient Edge Detector) preprocessor."""

    def __init__(self, model):
        # The loaded TED network; moved between devices via to().
        self.model = model
        self.device = "cpu"

    @classmethod  # fix: decorator was missing, so from_pretrained bound the filename to `cls`
    def from_pretrained(cls, filename="7_model.pth"):
        """Load TED weights from PREPROCESSORS_ROOT and return a detector instance."""
        model_path = os.path.join(PREPROCESSORS_ROOT, filename)
        model = TED()
        model.load_state_dict(torch.load(model_path, map_location="cpu"))
        model.eval()
        return cls(model)

    def to(self, device):
        """Move the underlying network to *device*; returns self for chaining."""
        self.model.to(device)
        self.device = device
        return self

    def __call__(self, input_image, detect_resolution=512, safe_steps=2, upscale_method="INTER_CUBIC", output_type=None, **kwargs):
        """Run TEED edge detection.

        Args:
            input_image: input image (validated by common_input_validate).
            detect_resolution: working resolution the image is resized/padded to.
            safe_steps: number of quantization steps for safe_step; 0 disables it.
            upscale_method: cv2 interpolation name used during resizing.
            output_type: "pil" to return a PIL.Image, otherwise a numpy HWC array.
        """
        input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
        input_image, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
        H, W, _ = input_image.shape
        with torch.no_grad():
            image_teed = torch.from_numpy(input_image.copy()).float().to(self.device)
            image_teed = rearrange(image_teed, 'h w c -> 1 c h w')
            edges = self.model(image_teed)
            edges = [e.detach().cpu().numpy().astype(np.float32)[0, 0] for e in edges]
            edges = [cv2.resize(e, (W, H), interpolation=cv2.INTER_LINEAR) for e in edges]
            edges = np.stack(edges, axis=2)
            # Sigmoid of the mean over all side outputs -> single soft edge map in [0, 1].
            edge = 1 / (1 + np.exp(-np.mean(edges, axis=2).astype(np.float64)))
            if safe_steps != 0:
                edge = safe_step(edge, safe_steps)
            edge = (edge * 255.0).clip(0, 255).astype(np.uint8)
        detected_map = remove_pad(HWC3(edge))
        if output_type == "pil":
            detected_map = Image.fromarray(detected_map[..., :3])
        return detected_map
class LineartStandardDetector:
    """Training-free line extraction via a difference-of-Gaussian intensity map."""

    def __call__(self, input_image=None, guassian_sigma=6.0, intensity_threshold=8, detect_resolution=512, output_type=None, upscale_method="INTER_CUBIC", **kwargs):
        """Extract a grayscale lineart map from *input_image*.

        Args:
            input_image: input image (validated by common_input_validate).
            guassian_sigma: blur sigma (misspelled name kept — public API).
            intensity_threshold: minimum response counted when normalizing.
            detect_resolution: working resolution the image is resized/padded to.
            output_type: "pil" to return a PIL.Image, otherwise a numpy HWC array.
            upscale_method: cv2 interpolation name used during resizing.
        """
        input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
        input_image, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
        as_float = input_image.astype(np.float32)
        blurred = cv2.GaussianBlur(as_float, (0, 0), guassian_sigma)
        # Darkest channel of (blurred - original): positive where dark lines sit
        # on a lighter surround.
        line_strength = np.min(blurred - as_float, axis=2).clip(0, 255)
        # Normalize by the median of above-threshold responses; the floor of 16
        # guards against dividing by a tiny value, then scale toward mid-gray.
        norm = max(16, np.median(line_strength[line_strength > intensity_threshold]))
        line_strength = line_strength / norm * 127
        detected_map = line_strength.clip(0, 255).astype(np.uint8)
        detected_map = HWC3(remove_pad(detected_map))
        if output_type == "pil":
            detected_map = Image.fromarray(detected_map)
        return detected_map
class AnyLinePreprocessor:
    """Combines MTEED soft edges with cleaned standard lineart into one line map."""

    def __init__(self, mteed_model, lineart_standard_detector):
        self.device = "cpu"
        # TEDDetector instance producing the MTEED edge map.
        self.mteed_model = mteed_model
        # LineartStandardDetector instance producing the lineart map.
        self.lineart_standard_detector = lineart_standard_detector

    @classmethod  # fix: decorator was missing, so from_pretrained bound the filename to `cls`
    def from_pretrained(cls, mteed_filename="MTEED.pth"):
        """Build the preprocessor from pretrained MTEED weights."""
        mteed_model = TEDDetector.from_pretrained(filename=mteed_filename)
        lineart_standard_detector = LineartStandardDetector()
        return cls(mteed_model, lineart_standard_detector)

    def to(self, device):
        """Move the MTEED model to *device*; returns self for chaining."""
        self.mteed_model.to(device)
        self.device = device
        return self

    def __call__(self, image, resolution=512, lineart_lower_bound=0, lineart_upper_bound=1, object_min_size=36, object_connectivity=1):
        """Produce the combined AnyLine map.

        Args:
            image: input image.
            resolution: working resolution for both detectors.
            lineart_lower_bound: lower intensity bound for the lineart mask.
            lineart_upper_bound: upper intensity bound for the lineart mask.
            object_min_size: minimum connected-component size kept in the mask.
            object_connectivity: connectivity used by remove_small_objects.
        """
        # Process the image with the MTEED model.
        mteed_result = self.mteed_model(image, detect_resolution=resolution)
        # Process the image with the lineart standard preprocessor.
        # fix: the original passed `resolution=...`, which LineartStandardDetector
        # does not accept (its parameter is `detect_resolution`), so the kwarg
        # fell into **kwargs and the lineart pass ran at the default size.
        lineart_result = self.lineart_standard_detector(image, guassian_sigma=2, intensity_threshold=3, detect_resolution=resolution)
        _lineart_result = get_intensity_mask(lineart_result, lower_bound=lineart_lower_bound, upper_bound=lineart_upper_bound)
        # Drop small speckles before combining the layers.
        _cleaned = morphology.remove_small_objects(_lineart_result.astype(bool), min_size=object_min_size, connectivity=object_connectivity)
        _lineart_result = _lineart_result * _cleaned
        result = combine_layers(mteed_result, _lineart_result)
        return result
class LineartDetector:
    """Anime-style line art extraction with fine and coarse generator models."""

    def __init__(self, model, coarse_model):
        # Fine-detail generator.
        self.model = model
        # Coarse generator, used when coarse=True.
        self.model_coarse = coarse_model
        self.device = "cpu"

    @classmethod  # fix: decorator was missing, so from_pretrained bound the filename to `cls`
    def from_pretrained(cls, filename="sk_model.pth", coarse_filename="sk_model2.pth"):
        """Load both generators from PREPROCESSORS_ROOT and return a detector instance."""
        model_path = os.path.join(PREPROCESSORS_ROOT, filename)
        coarse_model_path = os.path.join(PREPROCESSORS_ROOT, coarse_filename)
        model = LineartGenerator(3, 1, 3)
        model.load_state_dict(torch.load(model_path, map_location="cpu"))
        model.eval()
        coarse_model = LineartGenerator(3, 1, 3)
        coarse_model.load_state_dict(torch.load(coarse_model_path, map_location="cpu"))
        coarse_model.eval()
        return cls(model, coarse_model)

    def to(self, device):
        """Move both generators to *device*; returns self for chaining."""
        self.model.to(device)
        self.model_coarse.to(device)
        self.device = device
        return self

    def __call__(self, input_image, coarse=False, detect_resolution=512, output_type=None, upscale_method="INTER_CUBIC", **kwargs):
        """Extract a line art map.

        Args:
            input_image: input image (validated by common_input_validate).
            coarse: if True, use the coarse model instead of the fine one.
            detect_resolution: working resolution the image is resized/padded to.
            output_type: "pil" to return a PIL.Image, otherwise a numpy HWC array.
            upscale_method: cv2 interpolation name used during resizing.
        """
        input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
        detected_map, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
        model = self.model_coarse if coarse else self.model
        assert detected_map.ndim == 3
        with torch.no_grad():
            image = torch.from_numpy(detected_map).float().to(self.device)
            image = image / 255.0
            image = rearrange(image, 'h w c -> 1 c h w')
            # First batch item, first channel of the generator output.
            line = model(image)[0][0]
            line = line.cpu().numpy()
            line = (line * 255.0).clip(0, 255).astype(np.uint8)
        detected_map = HWC3(line)
        # Invert: the generator emits dark-on-light lines.
        detected_map = remove_pad(255 - detected_map)
        if output_type == "pil":
            detected_map = Image.fromarray(detected_map)
        return detected_map
class InformativeDetector:
    """Informative-drawings sketch extractor with "anime" and "contour" styles."""

    def __init__(self, anime_model, contour_model):
        # Generator used for style="anime".
        self.anime_model = anime_model
        # Generator used for any other style value ("contour").
        self.contour_model = contour_model
        self.device = "cpu"

    @classmethod  # fix: decorator was missing, so from_pretrained bound the filename to `cls`
    def from_pretrained(cls, anime_filename="anime_style.pth", contour_filename="contour_style.pth"):
        """Load both style generators from PREPROCESSORS_ROOT and return a detector instance."""
        anime_model_path = os.path.join(PREPROCESSORS_ROOT, anime_filename)
        contour_model_path = os.path.join(PREPROCESSORS_ROOT, contour_filename)
        # Build the two Generator models.
        anime_model = Generator(3, 1, 3)  # input_nc=3, output_nc=1, n_blocks=3
        anime_model.load_state_dict(torch.load(anime_model_path, map_location="cpu"))
        anime_model.eval()
        contour_model = Generator(3, 1, 3)  # input_nc=3, output_nc=1, n_blocks=3
        contour_model.load_state_dict(torch.load(contour_model_path, map_location="cpu"))
        contour_model.eval()
        return cls(anime_model, contour_model)

    def to(self, device):
        """Move both generators to *device*; returns self for chaining."""
        self.anime_model.to(device)
        self.contour_model.to(device)
        self.device = device
        return self

    def __call__(self, input_image, style="anime", detect_resolution=512, output_type=None, upscale_method="INTER_CUBIC", **kwargs):
        """Extract a sketch.

        Args:
            input_image: input image (validated by common_input_validate).
            style: "anime" or "contour"; selects which generator runs.
            detect_resolution: working resolution the image is resized/padded to.
            output_type: "pil" to return a PIL.Image, otherwise a numpy HWC array.
            upscale_method: cv2 interpolation name used during resizing.
        """
        input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
        detected_map, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
        # Select the generator for the requested style.
        model = self.anime_model if style == "anime" else self.contour_model
        assert detected_map.ndim == 3
        with torch.no_grad():
            image = torch.from_numpy(detected_map).float().to(self.device)
            image = image / 255.0
            # (h, w, c) -> (1, c, h, w)
            image = image.permute(2, 0, 1).unsqueeze(0)
            # Generate the sketch.
            sketch = model(image)
            sketch = sketch[0][0]  # first batch item, first channel
            sketch = sketch.cpu().numpy()
            sketch = (sketch * 255.0).clip(0, 255).astype(np.uint8)
        detected_map = HWC3(sketch)
        # Invert: the generator emits dark-on-light lines.
        detected_map = remove_pad(255 - detected_map)
        if output_type == "pil":
            detected_map = Image.fromarray(detected_map)
        return detected_map