import gc

import PIL.Image
import torch

from controlnet_aux_local import NormalBaeDetector
|
|
|
class Preprocessor:
    """Hold at most one annotator model and run images through it.

    ``load`` selects a model by name (only ``"NormalBae"`` is supported),
    and calling the instance forwards an image plus keyword arguments to
    whichever model is currently loaded.
    """

    # Repository id handed to ``NormalBaeDetector.from_pretrained``.
    MODEL_ID = "lllyasviel/Annotators"

    def __init__(self, device: str = "cuda"):
        """Create a preprocessor with no model loaded.

        Args:
            device: Target passed to the loaded model's ``.to(...)``.
                Defaults to ``"cuda"``, the previously hard-coded value.
        """
        self.model = None  # the loaded annotator, or None before load()
        self.name = ""  # name of the loaded annotator; "" when none
        self.device = device

    def load(self, name: str) -> None:
        """Load the annotator called *name*, replacing any current one.

        Loading the name that is already loaded is a no-op.

        Args:
            name: Annotator to load; only ``"NormalBae"`` is supported.

        Raises:
            ValueError: If *name* is not a supported annotator. The
                currently loaded model (if any) is left untouched.
        """
        if name == self.name:
            return
        # Validate before unloading so a bad name does not evict a
        # working model.
        if name != "NormalBae":
            raise ValueError(f"Unknown preprocessor: {name!r}")

        # Free the previous model first so two models are never held on
        # the device at the same time.
        if self.model is not None:
            self._unload()

        print("Loading NormalBae")
        self.model = NormalBaeDetector.from_pretrained(self.MODEL_ID).to(self.device)
        torch.cuda.empty_cache()
        self.name = name

    def _unload(self) -> None:
        """Drop the current model, collect garbage and empty the CUDA cache."""
        self.model = None
        self.name = ""
        gc.collect()
        torch.cuda.empty_cache()

    def __call__(self, image: "PIL.Image.Image", **kwargs) -> "PIL.Image.Image":
        """Run *image* through the loaded model.

        Args:
            image: Input image, passed to the model unchanged.
            **kwargs: Forwarded verbatim to the model call.

        Returns:
            Whatever the loaded model returns for ``image``.

        Raises:
            RuntimeError: If no model has been loaded yet.
        """
        if self.model is None:
            raise RuntimeError("No preprocessor loaded; call load() first")
        return self.model(image, **kwargs)