Spaces:
Runtime error
Runtime error
| import random | |
| import cv2 | |
| import numpy as np | |
| from PIL import Image | |
| from torchvision import transforms as T | |
| from torchvision.transforms.functional import InterpolationMode | |
| from tools.controlnet.annotator.hed import HEDdetector | |
| from tools.controlnet.annotator.util import HWC3, nms, resize_image | |
| preprocessor = None | |
| def transform_control_signal(control_signal, hw): | |
| if isinstance(control_signal, str): | |
| control_signal = Image.open(control_signal) | |
| elif isinstance(control_signal, Image.Image): | |
| control_signal = control_signal | |
| elif isinstance(control_signal, np.ndarray): | |
| control_signal = Image.fromarray(control_signal) | |
| else: | |
| raise ValueError("control_signal must be a path or a PIL.Image.Image or a numpy array") | |
| transform = T.Compose( | |
| [ | |
| T.Lambda(lambda img: img.convert("RGB")), | |
| T.Resize((int(hw[0, 0]), int(hw[0, 1])), interpolation=InterpolationMode.BICUBIC), # Image.BICUBIC | |
| T.CenterCrop((int(hw[0, 0]), int(hw[0, 1]))), | |
| T.ToTensor(), | |
| T.Normalize([0.5], [0.5]), | |
| ] | |
| ) | |
| return transform(control_signal).unsqueeze(0) | |
| def get_scribble_map(input_image, det, detect_resolution=512, thickness=None): | |
| """ | |
| Generate scribble map from input image | |
| Args: | |
| input_image: Input image (numpy array, HWC format) | |
| det: Detector type ('Scribble_HED', 'Scribble_PIDI', 'None') | |
| detect_resolution: Processing resolution | |
| thickness: Line thickness (between 0-24, None for random) | |
| Returns: | |
| Processed scribble map | |
| """ | |
| global preprocessor | |
| # Initialize detector | |
| if "HED" in det and not isinstance(preprocessor, HEDdetector): | |
| preprocessor = HEDdetector() | |
| input_image = HWC3(input_image) | |
| if det == "None": | |
| detected_map = input_image.copy() | |
| else: | |
| # Generate scribble map | |
| detected_map = preprocessor(resize_image(input_image, detect_resolution)) | |
| detected_map = HWC3(detected_map) | |
| # Post-processing | |
| detected_map = nms(detected_map, 127, 3.0) | |
| detected_map = cv2.GaussianBlur(detected_map, (0, 0), 3.0) | |
| detected_map[detected_map > 4] = 255 | |
| detected_map[detected_map < 255] = 0 | |
| # Control line thickness | |
| if thickness is None: | |
| thickness = random.randint(0, 24) # Random thickness, including 0 | |
| if thickness == 0: | |
| # Use erosion operation to get thinner lines | |
| kernel = np.ones((4, 4), np.uint8) | |
| detected_map = cv2.erode(detected_map, kernel, iterations=1) | |
| elif thickness > 1: | |
| kernel_size = thickness // 2 | |
| kernel = np.ones((kernel_size, kernel_size), np.uint8) | |
| detected_map = cv2.dilate(detected_map, kernel, iterations=1) | |
| return detected_map | |