| import torch |
| import torchvision |
| from PIL import Image |
| from pathlib import Path |
| import os |
| import numpy as np |
| from carvekit.api.high import HiInterface |
| import gradio as gr |
| import torch |
|
|
|
|
class PlatonicDistanceModel(torch.nn.Module):
    """Foreground Feature Averaging (FFA) intrinsic object similarity model.

    Removes the background with CarveKit, encodes the image with a DINOv2
    ViT-B/14 backbone, and averages the patch tokens that fall inside the
    foreground mask to obtain a single object-level embedding per image.
    """

    # Side length of the square input fed to the DINOv2 encoder.
    _INPUT_SIZE = 336
    # Spatial patch-token grid: 336 / 14 (ViT-B/14 patch size) = 24.
    _PATCH_GRID = 24

    def __init__(self, device, carvekit_object_type="object"):
        """
        :param device: string or torch.device object to run the model on.
        :param carvekit_object_type: object type for foreground segmentation. Can be "object" or "hairs-like".
            We find that "object" works well for most images in the CUTE dataset as well as vehicle ReID.
        """
        super().__init__()
        self.device = device
        # DINOv2 ViT-B/14 backbone; its patch tokens are pooled in forward_single.
        self.encoder = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14')
        self.encoder.to(self.device)

        # CarveKit background-removal pipeline (produces RGBA output with
        # alpha == 0 on background pixels).
        self.interface = HiInterface(object_type=carvekit_object_type,
                                     batch_size_seg=5,
                                     batch_size_matting=1,
                                     device=str(self.device),
                                     seg_mask_size=640,
                                     matting_mask_size=2048,
                                     trimap_prob_threshold=231,
                                     trimap_dilation=30,
                                     trimap_erosion_iters=5,
                                     fp16=False)

        # Built once here; the original rebuilt this Compose (and the RGB
        # conversion helper) for every image inside the preprocess loop.
        self._transform = torchvision.transforms.Compose([
            self._to_rgb,
            torchvision.transforms.Resize((self._INPUT_SIZE, self._INPUT_SIZE),
                                          interpolation=Image.BICUBIC),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                             std=[0.229, 0.224, 0.225]),
        ])

    @staticmethod
    def _to_rgb(x):
        """Return *x* converted to RGB mode (no-op if already RGB)."""
        return x.convert("RGB") if x.mode != "RGB" else x

    def preprocess(self, x_list):
        """Resize, normalize and batch a list of PIL images.

        :param x_list: list of PIL images.
        :return: float tensor of shape (len(x_list), 3, 336, 336) on self.device.
        """
        preprocessed_images = [self._transform(x) for x in x_list]
        return torch.stack(preprocessed_images, dim=0).to(self.device)

    def get_foreground_mask(self, tensor_imgs):
        """Derive a per-image foreground mask at the patch-token resolution.

        Background pixels were forced to a single color by the segmentation
        step, so after normalization they all share the minimum value of the
        channel-sum map; every other pixel is treated as foreground.

        :param tensor_imgs: batch of preprocessed images, shape (B, 3, H, W).
        :return: long tensor of shape (B, 1, 24, 24) with values in {0, 1}.
        """
        masks = []
        for tensor_img in tensor_imgs:
            tensor_img = tensor_img.detach().cpu()
            # Channel-sum map; the (zeroed) background collapses to its minimum.
            numpy_img_sum = tensor_img.sum(dim=0).numpy()
            min_value = np.min(numpy_img_sum)
            mask = ~(numpy_img_sum == min_value)
            mask = mask.astype(np.uint8)
            mask = Image.fromarray(mask * 255)
            # Downsample the pixel mask to the 24x24 patch grid.
            resized_mask = mask.resize((self._PATCH_GRID, self._PATCH_GRID), Image.BILINEAR)
            resized_mask_numpy = np.array(resized_mask)
            resized_mask_numpy = resized_mask_numpy / 255.0
            tensor_mask = torch.from_numpy(resized_mask_numpy.astype(np.float32))
            # Binarize: any patch more than half foreground counts fully.
            tensor_mask[tensor_mask > 0.5] = 1.0
            tensor_mask = tensor_mask.unsqueeze(0).long().to(self.device)
            if tensor_mask.sum() == 0:
                # Degenerate segmentation (no foreground found): fall back to
                # averaging over the whole image instead of dividing by zero.
                tensor_mask = torch.ones_like(tensor_mask)
            masks.append(tensor_mask)
        return torch.stack(masks, dim=0)

    def forward(self, variant, *x):
        """Dispatch to embedding or pairwise-similarity computation.

        :param variant: "Crop-Feat" or "Crop-Img" (see forward_single).
        :param x: one batch (list of PIL images or a tensor) -> embeddings;
            one bare image -> embedding of a singleton batch;
            two batches/images -> scalar cosine similarity of the first
            embedding of each.
        :raises ValueError: if more than two positional inputs are given.
        """
        if len(x) == 1 and isinstance(x[0], (list, torch.Tensor)):
            return self.forward_single(x[0], variant)
        elif len(x) == 1:
            return self.forward_single([x[0]], variant)
        elif len(x) == 2:
            return torch.cosine_similarity(self.forward_single(x[0], variant)[0], self.forward_single(x[1], variant)[0], dim=0).cpu().item()
        else:
            raise ValueError("Invalid number of inputs, only 1 or 2 inputs are supported.")

    def forward_single(self, x_list, variant):
        """Embed a list of PIL images via masked averaging of patch tokens.

        :param x_list: list of PIL images.
        :param variant: "Crop-Feat" extracts features from the background-removed
            image; "Crop-Img" extracts features from the original image. Both
            average only the patch tokens inside the foreground mask.
        :return: tensor of shape (len(x_list), feature_dim).
        :raises ValueError: for an unknown variant.
        """
        with torch.no_grad():
            # Background removal; CarveKit returns RGBA with transparent background.
            segmented = [np.array(self.interface([x])[0]) for x in x_list]
            for img in segmented:
                # Force fully transparent pixels to black so the background
                # collapses to one value (exploited by get_foreground_mask).
                img[img[..., 3] == 0] = [0, 0, 0, 0]
            img_list = [Image.fromarray(img) for img in segmented]
            preprocessed_imgs = self.preprocess(img_list)
            masks = self.get_foreground_mask(preprocessed_imgs)
            if variant == "Crop-Feat":
                emb = self.encoder.forward_features(preprocessed_imgs)
            elif variant == "Crop-Img":
                emb = self.encoder.forward_features(self.preprocess(x_list))
            else:
                raise ValueError("Invalid variant, only Crop-Feat and Crop-Img are supported.")

            # Patch tokens laid out on the 24x24 spatial grid.
            grid = emb["x_norm_patchtokens"].view(len(x_list), self._PATCH_GRID, self._PATCH_GRID, -1)

            # Masked spatial mean: only foreground patches contribute.
            return (grid * masks.permute(0, 2, 3, 1)).sum(dim=(1, 2)) / masks.sum(dim=(1, 2, 3)).unsqueeze(-1)
|
|
|
|
def compare(image_1, image_2, variant):
    """Gradio callback: format the FFA similarity of two PIL images as text."""
    score = model(variant, [image_1], [image_2])
    return f"The similarity score is: {score:.2f}"
|
|
|
|
# Prefer the GPU when one is available; both the encoder and the segmenter
# accept a plain device string.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model = PlatonicDistanceModel(device)

# Gradio UI: two image inputs plus a variant selector, text output.
demo = gr.Interface(
    title="Foreground Feature Averaging (FFA) Intrinsic Object Similarity Demo",
    description="Compare two images using the foreground feature averaging metric, a strong baseline for intrinsic object similarity. Please see our project website at https://s-tian.github.io/projects/cute/ for more information.",
    fn=compare,
    inputs=[
        gr.Image(type="pil", label="Image 1"),
        gr.Image(type="pil", label="Image 2"),
        gr.Radio(choices=["Crop-Feat", "Crop-Img"], value="Crop-Feat", label="Variant (use Crop-Feat if not sure)"),
    ],
    outputs="text",
)
|
|
if __name__ == "__main__":
    # Start the Gradio web server (blocks until the process is stopped).
    demo.launch()