# Import libraries
import cv2
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import segmentation_models as sm
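
# Note (assumption): this expects the `segmentation-models` package (PyPI) on top of
# TensorFlow/Keras, plus OpenCV and Pillow. The Keras backend can alternatively be
# selected with the SM_FRAMEWORK=tf.keras environment variable before the import.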


def get_mask(image: Image.Image) -> Image.Image:
    """
    This function generates a mask of the image that highlights all the sofas in it,
    using a pre-trained Unet model with a resnet50 backbone.
    Remark: The model was trained on 640x640 images, so the input image should
    ideally have the same size.
    Parameters:
        image = original image
    Return:
        mask = corresponding mask of the image
    """
    model_path = "model_checkpoint.h5"
    CLASSES = ['sofa']
    BACKBONE = 'resnet50'
    # define network parameters
    n_classes = 1 if len(CLASSES) == 1 else (len(CLASSES) + 1)  # case for binary and multiclass segmentation
    activation = 'sigmoid' if n_classes == 1 else 'softmax'
    # make sure segmentation_models uses tf.keras before any other call
    sm.set_framework('tf.keras')
    preprocess_input = sm.get_preprocessing(BACKBONE)
    LR = 0.0001
    # create the model architecture
    model = sm.Unet(BACKBONE, classes=n_classes, activation=activation)
    # define the optimizer
    optim = keras.optimizers.Adam(LR)
    # segmentation_models losses can be combined with '+' and scaled by an integer or float factor
    dice_loss = sm.losses.DiceLoss()
    focal_loss = sm.losses.BinaryFocalLoss() if n_classes == 1 else sm.losses.CategoricalFocalLoss()
    total_loss = dice_loss + (1 * focal_loss)
    # actually, total_loss can be imported directly from the library; the lines above just show how losses can be combined
    # total_loss = sm.losses.binary_focal_dice_loss  # or sm.losses.categorical_focal_dice_loss
    metrics = [sm.metrics.IOUScore(threshold=0.5), sm.metrics.FScore(threshold=0.5)]
    # compile the keras model with the defined optimizer, loss and metrics
    model.compile(optim, total_loss, metrics)
    # load the trained weights
    model.load_weights(model_path)
    # convert the PIL image to an array and bring it into the shape the model expects
    test_img = np.array(image)
    test_img = cv2.resize(test_img, (640, 640))
    test_img = cv2.cvtColor(test_img, cv2.COLOR_RGB2BGR)
    test_img = np.expand_dims(test_img, axis=0)
    # predict and round the sigmoid output to a binary mask
    prediction = model.predict(preprocess_input(test_img)).round()
    mask = Image.fromarray((prediction[..., 0].squeeze() * 255).astype(np.uint8)).convert("L")
    return mask
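
# Usage sketch (hypothetical paths, for illustration only):
# img = Image.open("input/sofa.jpg").convert("RGB").resize((640, 640))
# sofa_mask = get_mask(img)
# sofa_mask.save("masks/sofa_mask.png")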


def replace_sofa(image: Image.Image, mask: Image.Image, styled_sofa: Image.Image) -> Image.Image:
    """
    This function replaces the original sofa in the image with the new styled sofa
    according to the mask.
    Remark: All images should have the same size.
    Parameters:
        image = original image
        mask = generated mask highlighting the sofas in the image
        styled_sofa = styled image
    Return:
        new_image = image containing the styled sofa
    """
    image, mask, styled_sofa = np.array(image), np.array(mask), np.array(styled_sofa)
    # the styled image is assumed to be BGR-ordered (e.g. produced with OpenCV)
    styled_sofa = cv2.cvtColor(styled_sofa, cv2.COLOR_BGR2RGB)
    # binarize the mask and build its inverse
    _, mask = cv2.threshold(mask, 10, 255, cv2.THRESH_BINARY)
    mask_inv = cv2.bitwise_not(mask)
    # keep the background of the original image and the sofa of the styled image
    image_bg = cv2.bitwise_and(image, image, mask=mask_inv)
    sofa_fg = cv2.bitwise_and(styled_sofa, styled_sofa, mask=mask)
    # combine both parts into the final image
    new_image = cv2.add(image_bg, sofa_fg)
    return Image.fromarray(new_image)


# image = cv2.imread('input/sofa.jpg')
# image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# mask = cv2.imread('masks/sofa.jpg', cv2.IMREAD_GRAYSCALE)
# styled_sofa = cv2.imread('output/sofa_stylized_style.jpg')
# plt.imshow(replace_sofa(image, mask, styled_sofa))
# plt.show()
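
# End-to-end sketch (paths illustrative): generate the mask with get_mask and
# composite a styled sofa back into the scene. The styled image is loaded with
# OpenCV so it arrives BGR-ordered, matching the conversion inside replace_sofa.
# original = Image.open('input/sofa.jpg').convert('RGB').resize((640, 640))
# sofa_mask = get_mask(original)
# styled_bgr = cv2.resize(cv2.imread('output/sofa_stylized_style.jpg'), original.size)
# result = replace_sofa(original, sofa_mask, Image.fromarray(styled_bgr))
# result.save('output/sofa_replaced.jpg')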