import threading

import cv2
import numpy as np
import torch
from torchvision import transforms

from clip.clipseg import CLIPDensePredT
from roop.typing import Frame

THREAD_LOCK_CLIP = threading.Lock()
class Mask_Clip2Seg:
    plugin_options: dict = None
    model_clip = None

    processorname = "clip2seg"
    type = "mask"
    def Initialize(self, plugin_options: dict):
        if self.plugin_options is not None:
            # Re-initialize when the requested device changed.
            if self.plugin_options["devicename"] != plugin_options["devicename"]:
                self.Release()

        self.plugin_options = plugin_options
        if self.model_clip is None:
            # CLIPSeg dense-prediction head on a ViT-B/16 backbone, loaded
            # with the refined rd64 weights.
            self.model_clip = CLIPDensePredT(
                version="ViT-B/16", reduce_dim=64, complex_trans_conv=True
            )
            self.model_clip.eval()
            self.model_clip.load_state_dict(
                torch.load(
                    "models/CLIP/rd64-uni-refined.pth",
                    map_location=torch.device("cpu"),
                ),
                strict=False,
            )
        device = torch.device(self.plugin_options["devicename"])
        self.model_clip.to(device)
    def Run(self, img1, keywords: str) -> Frame:
        if keywords is None or len(keywords) < 1 or img1 is None:
            return img1

        source_image_small = cv2.resize(img1, (256, 256))

        # Base mask: a filled white rectangle inset by the border offsets,
        # feathered with a Gaussian blur and normalized to [0, 1].
        img_mask = np.full(
            (source_image_small.shape[0], source_image_small.shape[1]),
            0,
            dtype=np.float32,
        )
        mask_border = 1
        l = 0
        t = 0
        r = 1
        b = 1
        mask_blur = 5
        clip_blur = 5

        img_mask = cv2.rectangle(
            img_mask,
            (mask_border + int(l), mask_border + int(t)),
            (256 - mask_border - int(r), 256 - mask_border - int(b)),
            (255, 255, 255),
            -1,
        )
        img_mask = cv2.GaussianBlur(img_mask, (mask_blur * 2 + 1, mask_blur * 2 + 1), 0)
        img_mask /= 255
        # Normalize with the ImageNet statistics that CLIP expects.
        input_image = source_image_small
        transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
                transforms.Resize((256, 256)),
            ]
        )
        img = transform(input_image).unsqueeze(0)
        thresh = 0.5
        prompts = keywords.split(",")
        with THREAD_LOCK_CLIP:
            with torch.no_grad():
                # One batched forward pass: the image is repeated once per prompt.
                preds = self.model_clip(img.repeat(len(prompts), 1, 1, 1), prompts)[0]

        # Merge the per-prompt predictions into a single mask.
        clip_mask = torch.sigmoid(preds[0][0])
        for i in range(len(prompts) - 1):
            clip_mask += torch.sigmoid(preds[i + 1][0])
        clip_mask = clip_mask.data.cpu().numpy()
        # np.clip returns a new array, so the result must be assigned back.
        clip_mask = np.clip(clip_mask, 0, 1)
        # Binarize, dilate to grow the region slightly, then feather the edges.
        clip_mask[clip_mask > thresh] = 1.0
        clip_mask[clip_mask <= thresh] = 0.0
        kernel = np.ones((5, 5), np.float32)
        clip_mask = cv2.dilate(clip_mask, kernel, iterations=1)
        clip_mask = cv2.GaussianBlur(
            clip_mask, (clip_blur * 2 + 1, clip_blur * 2 + 1), 0
        )

        img_mask *= clip_mask
        img_mask[img_mask < 0.0] = 0.0
        return img_mask
    def Release(self):
        self.model_clip = None
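

# Minimal usage sketch (an assumption, not part of the plugin API): the frame
# is a BGR uint8 image as produced by cv2.imread, "example.jpg" is a
# hypothetical path, and "cpu" sidesteps any CPU/GPU tensor placement issues.
# Run() returns a single-channel float32 mask in [0, 1] at the 256x256 model
# input size rather than the original frame size.
if __name__ == "__main__":
    masker = Mask_Clip2Seg()
    masker.Initialize({"devicename": "cpu"})
    frame = cv2.imread("example.jpg")  # hypothetical input image
    mask = masker.Run(frame, "face,hair")  # comma-separated CLIPSeg prompts
    print(mask.shape, mask.min(), mask.max())
    masker.Release()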