Spaces:
Runtime error
Runtime error
| import numpy as np | |
| import cv2 | |
| from PIL import Image, ImageDraw | |
| import mediapipe as mp | |
| from transformers import pipeline | |
| from skimage.measure import label, regionprops | |
| import gradio as gr | |
| import torch | |
| import diffusers | |
| import tqdm as notebook_tqdm | |
| from diffusers import StableDiffusionInpaintPipeline | |
| from diffusers.models.modeling_outputs import Transformer2DModelOutput | |
| import cv2 | |
| import math | |
| import gradio as gr | |
| import numpy as np | |
| import os | |
| import mediapipe as mp | |
| from mediapipe.tasks import python | |
| from mediapipe.tasks.python import vision | |
| from mediapipe.tasks.python.components import containers | |
| from skimage.measure import label, regionprops | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import cv2 | |
| from skimage.measure import label | |
| from skimage.measure import regionprops | |
| from PIL import Image | |
| import torch | |
| import requests | |
| import tensorflow as tf | |
| import spaces | |
def _normalized_to_pixel_coordinates(
        normalized_x: float, normalized_y: float, image_width: int,
        image_height: int):
    """Convert a normalized (x, y) pair to integer pixel coordinates.

    Args:
        normalized_x: Horizontal position as a fraction of the image width.
        normalized_y: Vertical position as a fraction of the image height.
        image_width: Image width in pixels.
        image_height: Image height in pixels.

    Returns:
        A tuple ``(x_px, y_px)`` clamped to the valid pixel range, or
        ``None`` when either coordinate lies outside ``[0, 1]`` by more than
        a rounding error.
    """
    # math.isclose(0, value) with the default (purely relative) tolerance is
    # only true for value == 0, so an absolute tolerance is required for the
    # lower bound to accept tiny negative rounding errors the same way the
    # upper bound already accepts values marginally above 1.
    tolerance = 1e-9

    # Checks if the float value is between 0 and 1 (within the tolerance).
    def is_valid_normalized_value(value: float) -> bool:
        return (value > 0 or math.isclose(0, value, abs_tol=tolerance)) and (
            value < 1 or math.isclose(1, value, abs_tol=tolerance))

    if not (is_valid_normalized_value(normalized_x) and
            is_valid_normalized_value(normalized_y)):
        # TODO: Draw coordinates even if it's outside of the image bounds.
        return None

    # Clamp on both sides: a value a hair below 0 floors to -1 and a value
    # of exactly 1 floors to the image size, both of which are off-image.
    x_px = min(max(math.floor(normalized_x * image_width), 0),
               image_width - 1)
    y_px = min(max(math.floor(normalized_y * image_height), 0),
               image_height - 1)
    return x_px, y_px
# Run on the GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stable Diffusion 2 inpainting pipeline, loaded once at import time.
# Half precision is only used on CUDA: on the CPU fallback float16 weights
# are poorly supported, so the pipeline is loaded in float32 there.
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting",
    torch_dtype=torch.float16 if device.type == "cuda" else torch.float32,
).to(device)

BG_COLOR = (192, 192, 192)  # gray
MASK_COLOR = (255, 255, 255)  # white

# Short aliases for the MediaPipe types used to describe the point of interest.
RegionOfInterest = vision.InteractiveSegmenterRegionOfInterest
NormalizedKeypoint = containers.keypoint.NormalizedKeypoint

# Create the options that will be used for InteractiveSegmenter.
# The model file is expected next to this script as 'model.tflite'.
# NOTE(review): ImageSegmenterOptions is passed to InteractiveSegmenter below;
# MediaPipe also provides InteractiveSegmenterOptions -- confirm which one the
# installed version expects.
base_options = python.BaseOptions(model_asset_path='model.tflite')
options = vision.ImageSegmenterOptions(base_options=base_options,
                                       output_category_mask=True)
def segment(image_file_name, x, y, prompt):
    """Segment the region around a keypoint and build a rectangular mask.

    Args:
        image_file_name: Path of the image file to segment.
        x: Normalized horizontal position of the point of interest (0..1).
        y: Normalized vertical position of the point of interest (0..1).
        prompt: Not used by this function; accepted so the signature matches
            ``generate``.

    Returns:
        A tuple ``(output_image, bbox_image)``:

        * ``output_image`` -- RGB ``uint8`` array: the input image with a
          50 % pink overlay wherever the category mask is ``<= 0.1`` and a
          marker drawn at the point of interest (when it lies in the image).
        * ``bbox_image`` -- 2-D ``uint8`` array, 255 inside the padded
          bounding box of all pixels whose category mask equals 255 and 0
          elsewhere; all zeros when no such pixel exists.
    """
    OVERLAY_COLOR = (255, 105, 180)  # pink

    # Create the segmenter and run it on the requested point of interest.
    with python.vision.InteractiveSegmenter.create_from_options(options) as segmenter:
        # Create the MediaPipe image.
        image = mp.Image.create_from_file(image_file_name)
        roi = RegionOfInterest(format=RegionOfInterest.Format.KEYPOINT,
                               keypoint=NormalizedKeypoint(x, y))
        segmentation_result = segmenter.segment(image, roi)
        # Copy the result so it does not depend on segmenter-owned memory.
        category = np.array(segmentation_result.category_mask.numpy_view())

    mask = category.astype(np.uint8)

    # MediaPipe images are already RGB, which is what Gradio displays, so no
    # channel swap is applied. Grayscale input is expanded to three channels
    # and any alpha channel is dropped so the blend below stays 3-channel.
    pixels = image.numpy_view()
    if pixels.ndim == 2:
        pixels = np.stack((pixels,) * 3, axis=-1)
    image_data = pixels[..., :3]

    # Solid-colour layer that is blended over the image.
    overlay_image = np.zeros(image_data.shape, dtype=np.uint8)
    overlay_image[:] = OVERLAY_COLOR

    # Blend at 50 % opacity wherever the category mask is <= 0.1.
    # NOTE(review): this selects mask value 0 while the bounding box below
    # selects mask value 255 -- confirm which one marks the chosen object.
    alpha = (np.stack((category,) * 3, axis=-1) <= 0.1).astype(float) * 0.5
    output_image = (image_data * (1 - alpha) + overlay_image * alpha).astype(np.uint8)

    # Draw a white ring inside a black ring at the point of interest. The
    # caller's normalized x/y are used here (they must not be overwritten by
    # pixel coordinates), and the marker is skipped when the point falls
    # outside the image because the helper then returns None.
    marker_radius, line_thickness = 6, 10
    keypoint_px = _normalized_to_pixel_coordinates(x, y, image.width, image.height)
    if keypoint_px is not None:
        cv2.circle(output_image, keypoint_px, marker_radius + 5, (0, 0, 0), line_thickness)
        cv2.circle(output_image, keypoint_px, marker_radius, (255, 255, 255), line_thickness)

    # Binary mask of the pixels whose category value is exactly 255.
    binary_mask = (mask == 255).astype(np.uint8)

    # Start from an all-black mask; it stays black when nothing is selected.
    bbox_image = np.zeros_like(binary_mask)

    # Bounding box covering every connected region (not just the last one).
    regions = regionprops(label(binary_mask))
    if regions:
        minr = min(region.bbox[0] for region in regions)
        minc = min(region.bbox[1] for region in regions)
        maxr = max(region.bbox[2] for region in regions)
        maxc = max(region.bbox[3] for region in regions)

        # Pad the box: 300 px towards the top/left, 400 px towards the
        # bottom/right, clipped to the image bounds.
        minr = max(0, minr - 300)
        minc = max(0, minc - 300)
        maxr = min(binary_mask.shape[0], maxr + 400)
        maxc = min(binary_mask.shape[1], maxc + 400)

        # Draw the bounding box in white.
        bbox_image[minr:maxr, minc:maxc] = 255

    return output_image, bbox_image
def _image_grid(imgs, rows, cols):
    """Paste equally sized PIL images into a ``rows`` x ``cols`` grid image."""
    if len(imgs) != rows * cols:
        raise ValueError(
            f"expected {rows * cols} images for a {rows}x{cols} grid, "
            f"got {len(imgs)}")
    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))
    for i, tile in enumerate(imgs):
        grid.paste(tile, box=(i % cols * w, i // cols * h))
    return grid


def generate(image_file_path, x, y, prompt):
    """Segment the uploaded image and inpaint the masked area from a prompt.

    Args:
        image_file_path: Path of the uploaded image, or ``None`` when no
            image was provided.
        x: Normalized horizontal position of the point of interest (0..1).
        y: Normalized vertical position of the point of interest (0..1).
        prompt: Text prompt passed to the inpainting pipeline.

    Returns:
        A tuple ``(output_image, grid_image)``: the annotated segmentation
        image returned by ``segment`` and a 1 x 3 grid of the generated
        images. ``(None, None)`` when there is nothing to process.
    """
    num_images = 3

    # Check the input before segmenting: segment() cannot open a missing file.
    # Two values are returned because the interface has two outputs.
    if image_file_path is None:
        return None, None

    output_image, bbox_image = segment(image_file_path, x, y, prompt)
    if bbox_image is None:
        return None, None

    # Read the image and release the file handle right away.
    with Image.open(image_file_path) as source:
        img = source.convert("RGB")

    # Pass the mask as a PIL image so the pipeline resizes it together with
    # the input image (white = area to repaint).
    mask_image = Image.fromarray(bbox_image).convert("L")

    # Fixed seed for reproducible results; the generator lives on the same
    # device as the pipeline so this also works on the CPU fallback.
    images = pipe(prompt=prompt,
                  image=img,
                  mask_image=mask_image,
                  generator=torch.Generator(device=device).manual_seed(0),
                  num_images_per_prompt=num_images).images

    # Create an image grid with all generated images on one row.
    grid_image = _image_grid(images, 1, num_images)
    return output_image, grid_image
# Gradio front end: an image upload, the normalized x/y position of the point
# of interest and a text prompt go in; the annotated segmentation and the grid
# of generated images come out.
webapp = gr.Interface(
    fn=generate,
    inputs=[
        gr.Image(type="filepath", label="Upload an image"),
        # One slider per axis, both ranging over the normalized 0..1 interval.
        *[gr.Slider(minimum=0, maximum=1, step=0.01, label=axis)
          for axis in ("x", "y")],
        gr.Textbox(label="Prompt"),
    ],
    outputs=[
        gr.Image(type="pil", label="Segmented Image"),
        gr.Image(type="pil", label="Generated Image Grid"),
    ],
)

# Start the web server.
webapp.launch()