Spaces:
Runtime error
Runtime error
| from abc import ABC, abstractmethod | |
| from typing import List, ClassVar, Dict, Optional, Set | |
| from dataclasses import dataclass, field | |
| from modules import shared | |
| from scripts.logging import logger | |
| from scripts.utils import ndarray_lru_cache | |
# Read the "controlnet_preprocessor_cache_size" command-line option, falling
# back to 0 when the option is not defined on ``shared.cmd_opts``.
# NOTE(review): presumably the capacity handed to ``ndarray_lru_cache`` for
# preprocessor results -- confirm against the code that consumes it.
CACHE_SIZE = getattr(
    shared.cmd_opts,
    "controlnet_preprocessor_cache_size",
    0,
)
@dataclass
class PreprocessorParameter:
    """
    Class representing a parameter for a preprocessor.

    Declared as a dataclass so that instances can be created with keyword
    overrides, e.g. ``PreprocessorParameter(visible=False)``; any field that
    is not passed keeps the default listed below.

    Attributes:
        label (str): The label for the parameter. Default is "EMPTY_LABEL".
        minimum (float): The minimum value of the parameter. Default is 0.0.
        maximum (float): The maximum value of the parameter. Default is 1.0.
        step (float): The step size for the parameter. Default is 0.01.
        value (float): The initial value of the parameter. Default is 0.5.
        visible (bool): Whether the parameter is visible or not. Default is True.
    """

    label: str = "EMPTY_LABEL"
    minimum: float = 0.0
    maximum: float = 1.0
    step: float = 0.01
    value: float = 0.5
    visible: bool = True

    def gradio_update_kwargs(self) -> dict:
        """Return all six fields as a keyword-argument dict.

        The keys are ``minimum``, ``maximum``, ``step``, ``label``, ``value``
        and ``visible``.
        NOTE(review): judging by the name, this is meant to be splatted into a
        gradio component update -- confirm at the call site.
        """
        return dict(
            minimum=self.minimum,
            maximum=self.maximum,
            step=self.step,
            label=self.label,
            value=self.value,
            visible=self.visible,
        )

    def api_json(self) -> dict:
        """Return the parameter as a dict using the API key names.

        The label is exposed as ``name`` and the bounds as ``min`` / ``max``;
        ``visible`` is not included.
        """
        return dict(
            name=self.label,
            value=self.value,
            min=self.minimum,
            max=self.maximum,
            step=self.step,
        )
@dataclass
class Preprocessor(ABC):
    """
    Class representing a preprocessor.

    Besides describing a single preprocessor, the class acts as a registry:
    instances passed to ``add_supported_preprocessor`` are stored in two
    class-level dictionaries, one keyed by UI label and one keyed by name.

    Dataclass fields (constructor arguments):
        name (str): The name of the preprocessor.
        _label (Optional[str]): Optional display label; ``label`` falls back
            to ``name`` when this is None.
        tags (List[str]): The tags associated with the preprocessor.
        returns_image (bool): Defaults to True.

    Class-level defaults (intentionally left un-annotated so that they stay
    plain class attributes and do not become dataclass fields):
        slider_resolution (PreprocessorParameter): The parameter representing the resolution of the slider.
        slider_1 (PreprocessorParameter): The first parameter of the slider.
        slider_2 (PreprocessorParameter): The second parameter of the slider.
        slider_3 (PreprocessorParameter): The third parameter of the slider.
        show_control_mode (bool): Whether to show the control mode or not.
        do_not_need_model (bool): True when the preprocessor does not need a model.
        sorting_priority (int): The sorting priority of the preprocessor.
        corp_image_with_a1111_mask_when_in_img2img_inpaint_tab (bool): Whether to crop the image with a1111 mask when in img2img inpaint tab or not.
        fill_mask_with_one_when_resize_and_fill (bool): Whether to fill the mask with one when resizing and filling or not.
        use_soft_projection_in_hr_fix (bool): Whether to use soft projection in hr fix or not.
        expand_mask_when_resize_and_fill (bool): Whether to expand the mask when resizing and filling or not.
    """

    name: str
    _label: Optional[str] = None
    tags: List[str] = field(default_factory=list)
    slider_resolution = PreprocessorParameter(
        label="Resolution",
        minimum=64,
        maximum=2048,
        value=512,
        step=8,
        visible=True,
    )
    slider_1 = PreprocessorParameter(visible=False)
    slider_2 = PreprocessorParameter(visible=False)
    slider_3 = PreprocessorParameter(visible=False)
    returns_image: bool = True
    show_control_mode = True
    do_not_need_model = False
    sorting_priority = 0  # higher goes to top in the list
    corp_image_with_a1111_mask_when_in_img2img_inpaint_tab = True
    fill_mask_with_one_when_resize_and_fill = False
    use_soft_projection_in_hr_fix = False
    expand_mask_when_resize_and_fill = False

    # Registries shared by the whole class hierarchy (ClassVar, so they are
    # not dataclass fields): label -> instance and name -> instance.
    all_processors: ClassVar[Dict[str, "Preprocessor"]] = {}
    all_processors_by_name: ClassVar[Dict[str, "Preprocessor"]] = {}

    @property
    def label(self) -> str:
        """Display name on UI."""
        return self._label if self._label is not None else self.name

    @classmethod
    def add_supported_preprocessor(cls, p: "Preprocessor"):
        """Register ``p`` under both its label and its name.

        Raises:
            AssertionError: If the label or the name is already registered.
        """
        # Check both keys before mutating either registry so that a failed
        # registration never leaves a half-registered preprocessor behind.
        assert p.label not in cls.all_processors, f"{p.label} already registered!"
        assert p.name not in cls.all_processors_by_name, f"{p.name} already registered!"
        cls.all_processors[p.label] = p
        cls.all_processors_by_name[p.name] = p
        logger.debug(f"{p.name} registered. Total preprocessors ({len(cls.all_processors)}).")

    @classmethod
    def get_preprocessor(cls, name: str) -> Optional["Preprocessor"]:
        """Look up a preprocessor by label first, then by name.

        Returns:
            The registered instance, or None when neither registry knows
            ``name``.
        """
        return cls.all_processors.get(name, cls.all_processors_by_name.get(name))

    @classmethod
    def get_sorted_preprocessors(cls) -> List["Preprocessor"]:
        """Return all registered preprocessors with "none" pinned first.

        The remaining entries are ordered descending by a string key made of
        the priority zero-padded to 8 characters followed by the label, so
        higher priorities come first and labels within one priority appear
        in reverse lexicographic order.

        Raises:
            KeyError: If no preprocessor labelled "none" is registered.
        """
        others = [p for k, p in cls.all_processors.items() if k != "none"]
        return [cls.all_processors["none"]] + sorted(
            others,
            key=lambda x: str(x.sorting_priority).zfill(8) + x.label,
            reverse=True,
        )

    @classmethod
    def get_all_preprocessor_tags(cls) -> List[str]:
        """Return "All" followed by every distinct tag in sorted order."""
        tags = set()
        for p in cls.all_processors.values():
            tags.update(p.tags)
        return ["All"] + sorted(tags)

    @classmethod
    def get_filtered_preprocessors(cls, tag: str) -> List["Preprocessor"]:
        """Return the sorted preprocessors that carry ``tag``.

        The "none" preprocessor is always included. The special tag "All"
        selects every registered preprocessor.
        """
        if tag == "All":
            # Return the sorted list rather than the registry dict itself:
            # the declared return type is a list, and callers such as
            # ``get_default_preprocessor`` index the result positionally.
            return cls.get_sorted_preprocessors()
        return [
            p
            for p in cls.get_sorted_preprocessors()
            if tag in p.tags or p.label == "none"
        ]

    @classmethod
    def get_default_preprocessor(cls, tag: str) -> "Preprocessor":
        """Return the default preprocessor for ``tag``.

        This is the second entry of the filtered list (the first one after
        "none"), or the only entry when the list has a single element.
        """
        ps = cls.get_filtered_preprocessors(tag)
        assert len(ps) > 0
        return ps[0] if len(ps) == 1 else ps[1]

    @classmethod
    def tag_to_filters(cls, tag: str) -> Set[str]:
        """Return the lower-cased tag together with its known aliases."""
        filters_aliases = {
            "instructp2p": ["ip2p"],
            "segmentation": ["seg"],
            "normalmap": ["normal"],
            "t2i-adapter": ["t2i_adapter", "t2iadapter", "t2ia"],
            "ip-adapter": ["ip_adapter", "ipadapter"],
            "openpose": ["openpose", "densepose"],
            "instant-id": ["instant_id", "instantid"],
            "scribble": ["sketch"],
            "tile": ["blur"],
        }
        tag = tag.lower()
        return {tag, *filters_aliases.get(tag, [])}

    # NOTE(review): ``ndarray_lru_cache`` and ``CACHE_SIZE`` are available at
    # module level but unused in this view, and the log message below talks
    # about running "outside of cache" -- this method looks like it is meant
    # to carry a caching decorator. Confirm the decorator's signature before
    # re-adding it; as written every call is forwarded uncached.
    def cached_call(self, *args, **kwargs):
        """Log the call and forward all arguments to ``__call__``."""
        logger.debug(f"Calling preprocessor {self.name} outside of cache.")
        return self(*args, **kwargs)

    def __hash__(self):
        """Hash by ``name`` only."""
        return hash(self.name)

    def __eq__(self, other):
        """Compare by hash value, i.e. effectively by ``name``."""
        return self.__hash__() == other.__hash__()

    @abstractmethod
    def __call__(
        self,
        input_image,
        resolution,
        slider_1=None,
        slider_2=None,
        slider_3=None,
        input_mask=None,
        **kwargs,
    ):
        """Run the preprocessor; concrete subclasses must implement this."""
        pass