Spaces:
Runtime error
Runtime error
| from collections import defaultdict | |
| import matplotlib.pyplot as plt | |
| import matplotlib.patches as mpatches | |
| from matplotlib import cm | |
| import cv2 | |
| from PIL import Image | |
| import numpy as np | |
| import torch | |
| from transformers import AutoImageProcessor, UperNetForSemanticSegmentation | |
| from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation | |
| from diffusers import StableDiffusionInpaintPipeline | |
class VirtualStagingToolV2():
    """Virtually stage a room photo with segmentation-guided inpainting.

    A semantic-segmentation model labels every pixel of the input image.
    Pixels whose label is in a per-room "keep" list are protected; all
    remaining pixels are repainted by a Stable Diffusion inpainting
    pipeline from a text prompt built from the detected room type, the
    detected items and the requested style.

    NOTE(review): the numeric label IDs used below are presumably ADE20K
    class indices as exposed by the model's ``config.id2label`` -- confirm
    against the checkpoint before editing them.
    """

    # Segmentation checkpoints this class knows how to load.
    _UPERNET_VERSION = "openmmlab/upernet-convnext-tiny"
    _SEGFORMER_VERSION = "nvidia/segformer-b5-finetuned-ade-640-640"

    # Label IDs that are kept (never repainted) for every indoor room type.
    _DEFAULT_KEEP_IDS = (0, 3, 5, 8, 14)

    # Side length, in pixels, of the square kernel used to grow the
    # repaint region so the inpainting also covers object borders.
    _MASK_BUFFER_SIZE = 10

    # Item names that may be mentioned in the prompt, per room type.
    # Rooms missing from this table ('backyard', 'room') are not filtered.
    # The trailing space in 'bed ' is intentional -- presumably it matches
    # the label text in id2label; verify before "fixing" it.
    _ROOM_ITEM_WHITELIST = {
        'kitchen': ('cabinet', 'shelf', 'counter', 'countertop', 'stool'),
        'bedroom': ('bed ', 'table', 'chest of drawers', 'desk', 'armchair',
                    'wardrobe'),
        'bathroom': ('shower', 'bathtub', 'screen door', 'cabinet'),
        'living room': ('table', 'sofa', 'chest of drawers', 'armchair',
                        'cabinet', 'coffee table'),
        'dining room': ('table', 'chair', 'cabinet'),
    }

    def __init__(self,
                 segmentation_version='openmmlab/upernet-convnext-tiny',
                 diffusion_version="stabilityai/stable-diffusion-2-inpainting",
                 device=None
                 ):
        """Load the segmentation model and the inpainting pipeline.

        Args:
            segmentation_version: Hugging Face model id of the segmentation
                checkpoint; must be one of the two supported ids.
            diffusion_version: Hugging Face model id of the inpainting
                checkpoint.
            device: Device for the inpainting pipeline. ``None`` selects
                "cuda" when available and "cpu" otherwise.

        Raises:
            ValueError: If ``segmentation_version`` is not supported.
        """
        supported = (self._UPERNET_VERSION, self._SEGFORMER_VERSION)
        if segmentation_version not in supported:
            # Fail before any download: otherwise the extractor/model
            # attributes would be missing and surface later as an
            # AttributeError far from the real cause.
            raise ValueError(
                f"unsupported segmentation_version {segmentation_version!r}; "
                f"expected one of {supported}")

        self.segmentation_version = segmentation_version
        self.diffusion_version = diffusion_version

        if segmentation_version == self._UPERNET_VERSION:
            self.feature_extractor = AutoImageProcessor.from_pretrained(
                self.segmentation_version)
            self.segmentation_model = UperNetForSemanticSegmentation.from_pretrained(
                self.segmentation_version)
        else:
            self.feature_extractor = SegformerFeatureExtractor.from_pretrained(
                self.segmentation_version)
            self.segmentation_model = SegformerForSemanticSegmentation.from_pretrained(
                self.segmentation_version)

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        # Half precision is only used on CUDA; CPU runs in float32.
        on_cuda = str(device).startswith("cuda")
        dtype = torch.float16 if on_cuda else torch.float32

        # NOTE: the attribute name keeps its historical spelling
        # ("diffution") because external code may already read it.
        self.diffution_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
            self.diffusion_version,
            torch_dtype=dtype,
        )
        self.diffution_pipeline = self.diffution_pipeline.to(device)

    def _predict(self, image):
        """Return the per-pixel label map for a PIL image.

        The map is post-processed back to the image's own resolution.
        """
        # Segmentation checkpoints expect three channels; converting is a
        # no-op for images that are already RGB.
        inputs = self.feature_extractor(images=image.convert("RGB"),
                                        return_tensors="pt")
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            outputs = self.segmentation_model(**inputs)
        # PIL reports (width, height); target_sizes wants (height, width).
        prediction = self.feature_extractor.post_process_semantic_segmentation(
            outputs, target_sizes=[image.size[::-1]])[0]
        return prediction

    def _save_mask(self, img, prediction_array, mask_items=None):
        """Build the inpainting mask as a single-channel PIL image.

        Pixels whose label is in ``mask_items`` are black (0, preserved);
        all other pixels are white (255, repainted). The white region is
        then grown by ``_MASK_BUFFER_SIZE`` pixels.

        Args:
            img: Unused; kept for interface compatibility.
            prediction_array: Per-pixel label map.
            mask_items: Label IDs to preserve. ``None`` means none.
        """
        keep_ids = [] if mask_items is None else mask_items
        labels = np.asarray(prediction_array)
        mask = np.where(np.isin(labels, keep_ids), 0, 255).astype(np.uint8)

        kernel = np.ones((self._MASK_BUFFER_SIZE, self._MASK_BUFFER_SIZE),
                         np.uint8)
        # For a 0/255 mask and an all-ones kernel the dilation already
        # contains the original mask, so it equals "mask OR (dilated - mask)".
        dilated = cv2.dilate(mask, kernel, iterations=1)
        return Image.fromarray(dilated)

    def _save_transparent_mask(self, img, prediction_array, mask_items=None):
        """Return an RGBA copy of ``img`` with the repaint region cleared.

        Pixels whose label is NOT in ``mask_items`` become fully
        transparent white; every other pixel keeps its colour and is fully
        opaque.

        Args:
            img: Source PIL image.
            prediction_array: Per-pixel label map matching ``img``.
            mask_items: Label IDs to preserve. ``None`` means none.
        """
        keep_ids = [] if mask_items is None else mask_items
        labels = np.asarray(prediction_array)
        repaint = ~np.isin(labels, keep_ids)

        rgb = np.array(img.convert("RGB"))
        rgb[repaint, :] = 255
        # Alpha comes from the segmentation itself. Testing "red == 255"
        # would also punch holes into bright pixels that must be kept.
        alpha = np.where(repaint, 0, 255).astype(np.uint8)
        return Image.fromarray(np.dstack([rgb, alpha]))

    @staticmethod
    def _classify_room(label_ids):
        """Guess the room type and the label IDs to preserve.

        Args:
            label_ids: Iterable of label IDs present in the image.

        Returns:
            ``(room, mask_items)``: the room name and a new list of label
            IDs that must not be repainted. Rules are checked in order and
            the first match wins.
        """
        present = {int(i) for i in label_ids}
        base = list(VirtualStagingToolV2._DEFAULT_KEEP_IDS)

        if present & {1, 25}:
            return 'backyard', [1, 2, 4, 25, 32]
        if present & {73, 50, 61}:
            return 'kitchen', base + [50, 61, 71, 73, 118, 124, 129]
        if present & {37, 65} or {27, 47, 70} <= present:
            return 'bathroom', base + [27, 65]
        if 7 in present:
            return 'bedroom', base
        if present & {23, 49}:
            return 'living room', base + [49]
        if {15, 19} <= present:
            return 'dining room', base
        return 'room', base

    @staticmethod
    def _build_prompt(room, items, style, color_preference=None,
                      additional_info=None):
        """Compose the inpainting prompt.

        Args:
            room: Room name from ``_classify_room``.
            items: Comma-separated item names; may be empty.
            style: Requested style text.
            color_preference: Optional colour text appended with " in ".
            additional_info: Optional free text appended as a new sentence.
        """
        if room == 'backyard':
            prompt = f'Realistic, high resolution, {room} with {style}'
        elif items:
            prompt = f'Realistic {items}, high resolution, in the {style} style {room}'
        else:
            # No usable items: avoid the dangling "Realistic , ..." form.
            prompt = f'Realistic, high resolution, in the {style} style {room}'
        if color_preference:
            prompt = f"{prompt} in {color_preference}"
        if additional_info:
            prompt = f'{prompt}. {additional_info}'
        return prompt

    def get_mask(self, image_path=None, image=None):
        """Segment an image and derive the masks used for staging.

        Args:
            image_path: Path of the image to open; takes precedence.
            image: Already-loaded PIL image, used when no path is given.

        Returns:
            ``(mask_image, transparent_mask_image, image, items, room)``.

        Raises:
            ValueError: If neither ``image_path`` nor ``image`` is given.
        """
        if image_path:
            image = Image.open(image_path)
        elif image is None:
            raise ValueError("no image provided")

        prediction = self._predict(image)
        label_ids = np.unique(np.asarray(prediction))
        room, mask_items = self._classify_room(label_ids)

        # Names of everything detected that is not in the preserved set.
        id2label = self.segmentation_model.config.id2label
        items = [id2label[int(i)] for i in label_ids
                 if int(i) not in mask_items]

        mask_image = self._save_mask(image, prediction, mask_items)
        transparent_mask_image = self._save_transparent_mask(
            image, prediction, mask_items)
        return mask_image, transparent_mask_image, image, items, room

    def _edit_image(self, init_image, mask_image, prompt, number_images=1):
        """Run the inpainting pipeline and return the generated images.

        Both inputs are resized to 512x512 and converted to RGB before
        the call; white mask pixels mark the area to repaint.
        """
        init_image = init_image.resize((512, 512)).convert("RGB")
        mask_image = mask_image.resize((512, 512)).convert("RGB")
        output_images = self.diffution_pipeline(
            prompt=prompt, image=init_image, mask_image=mask_image,
            num_images_per_prompt=number_images).images
        return output_images

    def virtual_stage(self, image_path=None, image=None, style=None,
                      color_preference=None, additional_info=None,
                      number_images=1):
        """Stage an image in the requested style.

        Args:
            image_path: Path of the image to stage; takes precedence.
            image: Already-loaded PIL image, used when no path is given.
            style: Required style text used in the prompt.
            color_preference: Optional colour text for the prompt.
            additional_info: Optional extra sentence for the prompt.
            number_images: Number of images to generate.

        Returns:
            ``(images, transparent_mask_image)`` where ``images`` are
            resized back to the input image's size.

        Raises:
            ValueError: If ``style`` is missing or no image is provided.
        """
        # Validate first so a bad call never pays for a segmentation run.
        if not style:
            raise ValueError('style not provided.')

        mask_image, transparent_mask_image, init_image, items, room = \
            self.get_mask(image_path, image)

        allowed = self._ROOM_ITEM_WHITELIST.get(room)
        if allowed is not None:
            items = [i for i in items if i in allowed]

        prompt = self._build_prompt(room, ', '.join(items), style,
                                    color_preference, additional_info)
        print(prompt)

        output_images = self._edit_image(init_image, mask_image, prompt,
                                         number_images)
        final_output_images = [output_image.resize(init_image.size)
                               for output_image in output_images]
        return final_output_images, transparent_mask_image