|
|
import gc |
|
|
import gradio as gr |
|
|
from fastapi import FastAPI, Body,Header,status |
|
|
from gradio import components,Blocks,Row |
|
|
import json |
|
|
from PIL import Image |
|
|
import torchvision.transforms as transforms |
|
|
import torch |
|
|
from pathlib import Path |
|
|
import os |
|
|
import requests |
|
|
from segment_anything import sam_model_registry, SamPredictor |
|
|
import numpy as np |
|
|
from modules.safe import unsafe_torch_load, load |
|
|
from modules.devices import device, torch_gc, cpu |
|
|
|
|
|
from modules.processing import process_images |
|
|
import modules.scripts as scripts |
|
|
|
|
|
# Debug switch for this script. Nothing in the visible part of this file reads
# it, so it currently has no effect here.
UNIT_DEBUG = False
|
|
def import_or_install(package, pip_name=None):
    """Import *package*, installing it with pip first if it is missing.

    Args:
        package: Importable module name, e.g. ``"segment_anything"``.
        pip_name: Requirement string handed to ``pip install`` (a PyPI name,
            a VCS URL, ...). Defaults to *package* when omitted.

    Progress is reported with ``print``; nothing is returned. A failed
    installation is reported but does not raise.
    """
    import importlib
    import subprocess
    import sys

    if pip_name is None:
        pip_name = package

    try:
        importlib.import_module(package)
    except ImportError:
        print(f"{package} is not installed, installing now...")
        # Use the interpreter that is running this script so the package ends
        # up in the same environment, and install the requested pip_name
        # (which may differ from the import name, e.g. a git URL).
        exit_code = subprocess.call(
            [sys.executable, "-m", "pip", "install", pip_name]
        )
        if exit_code == 0:
            # Make the freshly installed package visible to later imports.
            importlib.invalidate_caches()
            print(f"{package} has been installed")
        else:
            print(f"Failed to install {package} (pip exited with code {exit_code})")
    else:
        print(f"{package} is already installed")
|
|
|
|
|
# Make sure the Segment Anything package is available.
# NOTE(review): this file already imports from `segment_anything` near the top,
# so that import would fail before this line runs on a fresh environment --
# confirm whether the check should happen earlier.
import_or_install(
    package="segment_anything",
    pip_name="git+https://github.com/facebookresearch/segment-anything.git",
)
|
|
|
|
|
class InteractiveImageSegmentor:
    """Prompt-driven image segmentation on top of a Segment Anything model.

    Attributes set by :meth:`load_model`:
        device: ``torch.device`` the model runs on (CUDA when available).
        sam: The loaded SAM model.
        predictor: ``SamPredictor`` wrapping ``sam``; the image to segment
            must be set on it by the caller before :meth:`segment` is used.
    """

    # Download location of each supported checkpoint, keyed by model choice.
    _CHECKPOINT_URLS = {
        "sam_vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
        "sam_vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
        "sam_vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
    }

    @staticmethod
    def download_file_if_not_exists(file_url, file_name):
        """Download *file_url* to *file_name* unless that file already exists.

        The body is streamed to a temporary ``.part`` file and renamed once
        complete, so large checkpoints are never held in memory and an
        interrupted download is not mistaken for a finished one.
        """
        if os.path.isfile(file_name):
            return

        partial_name = f"{file_name}.part"
        with requests.get(file_url, stream=True) as response:
            if response.status_code != 200:
                print("Failed to download the file.")
                return
            with open(partial_name, "wb") as file:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    file.write(chunk)
        os.replace(partial_name, file_name)
        print("File downloaded successfully!")

    def load_model(self, model_choice="sam_vit_b"):
        """Load the SAM checkpoint named *model_choice* and build a predictor.

        Known choices (``sam_vit_b``, ``sam_vit_l``, ``sam_vit_h``) are
        downloaded to ``<model_choice>.pth`` in the working directory when
        missing. Unknown choices fall back to the registry's ``"default"``
        architecture and expect ``<model_choice>.pth`` to exist already.
        """
        sam_checkpoint = f"{model_choice}.pth"
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        checkpoint_url = self._CHECKPOINT_URLS.get(model_choice)
        if checkpoint_url is not None:
            self.download_file_if_not_exists(checkpoint_url, sam_checkpoint)

        model_type = model_choice.replace("sam_", "")
        if model_type not in sam_model_registry:
            model_type = "default"
        print(f"Loading model {model_type} from {sam_checkpoint}")

        # The checkpoint is read with the unrestricted loader; always put the
        # regular loader back, even when building the model fails.
        torch.load = unsafe_torch_load
        try:
            self.sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
        finally:
            torch.load = load
        self.sam.to(self.device)
        self.predictor = SamPredictor(self.sam)

    def clear_sam_cache(self):
        """Release the loaded model and free cached memory.

        After this call ``sam`` and ``predictor`` are ``None``;
        :meth:`load_model` must be called again before segmenting.
        """
        sam = getattr(self, "sam", None)
        # Honour a model-provided unload hook when one exists; a plain SAM
        # module does not define it, so dropping the references is what
        # actually lets the memory be reclaimed.
        unload = getattr(sam, "unload_model", None)
        if callable(unload):
            unload()
        self.predictor = None
        self.sam = None
        gc.collect()
        torch_gc()

    def mask2image_multi(self, mask: torch.Tensor):
        """Turn a 3-channel mask into a tinted RGBA overlay tensor.

        Args:
            mask: 3-D tensor with 3 channels either last or first; only the
                first channel is used.

        Returns:
            A ``(4, h, w)`` tensor: first channel times a fixed RGBA colour.

        Raises:
            ValueError: If the mask does not have one of the accepted shapes.
        """
        if mask.dim() == 3 and mask.shape[-1] == 3:
            mask = mask.permute(2, 0, 1)
        elif not (mask.dim() == 3 and mask.shape[0] == 3):
            print(mask.shape)
            raise ValueError("Mask tensor has an unexpected shape.")

        color = torch.tensor(
            [255 / 255, 155 / 255, 114 / 255, 0.6], device=mask.device
        )
        binary_mask = mask[0, :, :]
        h, w = binary_mask.shape
        mask_image = binary_mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
        return mask_image.permute(2, 0, 1)

    def mask2image(self, mask: torch.Tensor):
        """Turn a single-channel mask into a ``(4, h, w)`` RGBA overlay tensor.

        The mask is copied into the three colour channels and a constant 0.6
        alpha channel is appended.

        Raises:
            ValueError: If the mask is neither ``(1, h, w)`` nor ``(h, w)``.
        """
        if mask.dim() == 3 and mask.shape[0] == 1:
            binary_mask = mask.squeeze(0)
        elif mask.dim() == 2:
            binary_mask = mask
        else:
            print(mask.shape)
            raise ValueError("Mask tensor has an unexpected shape.")

        h, w = binary_mask.shape
        rgb_image = binary_mask.repeat(3, 1, 1)
        alpha_channel = torch.full((1, h, w), 0.6, device=binary_mask.device)
        # NOTE(review): the alpha is 0.6 everywhere, including outside the
        # mask, and no tint colour is applied -- confirm this is intended.
        return torch.cat((rgb_image, alpha_channel), dim=0)

    def preview_segment(self, image: "Image.Image", points: list[list[float]] = None, bbox=None, labels: list[int] = None):
        """Return *image* (as RGBA) with the predicted mask composited on top."""
        base_image = image.convert("RGBA")
        mask_tensor = self.segment(points, bbox, labels)
        overlay = transforms.ToPILImage()(self.mask2image(mask_tensor))
        return Image.alpha_composite(base_image, overlay)

    def segment(self, points: list[list[float]] = None, bbox=None, labels: list[int] = None) -> torch.Tensor:
        """Predict a mask for the image currently set on the predictor.

        Args:
            points: ``[x, y]`` pixel coordinates in the original image.
            bbox: Boxes as ``[x0, y0, x1, y1]`` pixel coordinates.
            labels: One label per point (1 = include, 0 = exclude).

        ``None`` and empty lists both mean "no prompt of this kind".

        Returns:
            The mask of the first batch entry, as returned by
            ``predictor.predict_torch`` with ``multimask_output=False``.

        Raises:
            RuntimeError: If no image has been set on the predictor.
        """
        def is_empty(prompt):
            return prompt is None or len(prompt) == 0

        if not self.predictor.is_image_set:
            raise RuntimeError("An image must be set on the predictor before segmenting.")

        # predict_torch expects prompts already mapped into the resized model
        # input frame, so pixel coordinates are transformed first.
        transform = self.predictor.transform
        original_size = self.predictor.original_size

        point_coords = point_labels = box_coords = None
        if not is_empty(points):
            point_coords = torch.as_tensor(
                np.array(points), dtype=torch.float, device=self.device
            ).unsqueeze(0)
            point_coords = transform.apply_coords_torch(point_coords, original_size)
        if not is_empty(labels):
            point_labels = torch.as_tensor(
                np.array(labels), dtype=torch.int, device=self.device
            ).unsqueeze(0)
        if not is_empty(bbox):
            box_coords = torch.as_tensor(
                np.array(bbox), dtype=torch.float, device=self.device
            )
            box_coords = transform.apply_boxes_torch(box_coords, original_size)
        print(point_coords, point_labels, box_coords)

        masks, scores, logits = self.predictor.predict_torch(
            point_coords=point_coords,
            point_labels=point_labels,
            boxes=box_coords,
            multimask_output=False,
        )
        return masks[0]

    @staticmethod
    def _mask_pixels(image_tensor: torch.Tensor, mask_tensor: torch.Tensor, keep_selected: bool) -> torch.Tensor:
        """Zero the pixels of *image_tensor* inside or outside *mask_tensor*."""
        selected = mask_tensor.to(torch.bool)
        # Bool tensors cannot be subtracted from 1; invert them logically.
        keep = selected if keep_selected else selected.logical_not()
        return image_tensor * keep

    def _apply_selection(self, image, points, boxes, labels, keep_selected):
        """Segment with the prompts and blank the unwanted side of the mask."""
        image_tensor = transforms.PILToTensor()(image.convert("RGBA")).to(self.device)
        mask_tensor = self.segment(points=points, bbox=boxes, labels=labels)
        result_tensor = self._mask_pixels(image_tensor, mask_tensor, keep_selected)
        return transforms.ToPILImage()(result_tensor)

    def remove_selected(self, image: "Image.Image", points: list[list[float]] = None, boxes=None, labels: list[int] = None):
        """Return an RGBA copy of *image* with the selected region zeroed."""
        return self._apply_selection(image, points, boxes, labels, keep_selected=False)

    def remove_unselected(self, image: "Image.Image", points: list[list[float]] = None, boxes=None, labels: list[int] = None):
        """Return an RGBA copy of *image* with everything but the selection zeroed."""
        return self._apply_selection(image, points, boxes, labels, keep_selected=True)
|
|
|
|
|
def reset_image(image: "Image.Image | None"):
    """Make *image* the current image of the shared segmentor.

    The shared ``image_segmentor`` is created and its model loaded on first
    use. A ``None`` image is returned untouched without loading anything.

    Returns:
        The *image* argument, unchanged.
    """
    global image_segmentor

    if image is None:
        return image

    if image_segmentor is None:
        # Publish the instance only after the model loaded successfully, so a
        # failed load does not leave a half-initialised global behind.
        segmentor = InteractiveImageSegmentor()
        segmentor.load_model()
        image_segmentor = segmentor

    predictor = image_segmentor.predictor
    predictor.reset_image()
    # The predictor is given a 3-channel array whatever the source mode is.
    predictor.set_image(np.array(image.convert("RGB")))
    return image
|
|
|
|
|
def on_image_changed(image:Image):
    """Forget all prompts and hand the new image to the shared segmentor.

    Rebinds the module-level prompt lists to fresh empty lists, then calls
    ``reset_image``. Returns *image* unchanged.
    """
    global points, labels, box_cache, boxes

    # Four distinct empty lists: prompts of the previous image do not apply.
    points, labels, boxes, box_cache = [], [], [], []
    reset_image(image)
    return image
|
|
|
|
|
def on_image_clicked(image:Image,choice,input_type,event_data:gr.SelectData):
    """Record a click as a point or box-corner prompt and preview the mask.

    Args:
        image: The image being annotated.
        choice: ``"Select"`` / ``"Deselect"`` (mapped to labels 1 / 0); any
            other value is stored as the label unchanged.
        input_type: ``"Point"`` adds a labelled point; ``"Box"`` collects two
            clicks into one box.
        event_data: Gradio select event; ``index`` holds the click position.

    Returns:
        A preview with the mask overlaid once a prompt is complete, otherwise
        the unchanged *image* (first corner of a box, unknown input type).
    """
    global box_cache, boxes, points, labels

    if isinstance(choice, str):
        if choice == "Select":
            choice = 1
        elif choice == "Deselect":
            choice = 0

    if input_type == "Point":
        points.append(event_data.index)
        labels.append(choice)
        return image_segmentor.preview_segment(image, points=points, bbox=boxes, labels=labels)

    if input_type == "Box":
        box_cache.extend(event_data.index)
        if len(box_cache) == 4:
            # Corners arrive in click order; store the box as
            # [x_min, y_min, x_max, y_max] whichever corner came first.
            x_a, y_a, x_b, y_b = box_cache
            boxes.append([min(x_a, x_b), min(y_a, y_b), max(x_a, x_b), max(y_a, y_b)])
            box_cache = []
            return image_segmentor.preview_segment(image, points=points, bbox=boxes, labels=labels)

    return image
|
|
|
|
|
def on_remove_btn_clicked(image:Image,remove_type:str):
    """Preview the image with the selected or unselected region removed.

    Args:
        image: The image to edit.
        remove_type: ``"Selected"`` or ``"Unselected"``.

    Returns:
        The edited image, or *image* unchanged when there is nothing to work
        on (no image, segmentor not loaded yet, or unknown *remove_type*).
    """
    # The button can be pressed before any image was loaded, in which case the
    # shared segmentor has not been created yet.
    if image is None or image_segmentor is None:
        return image

    if remove_type == "Selected":
        return image_segmentor.remove_selected(image, points=points, boxes=boxes, labels=labels)
    if remove_type == "Unselected":
        return image_segmentor.remove_unselected(image, points=points, boxes=boxes, labels=labels)
    return image
|
|
|
|
|
class Script(scripts.Script):
    """img2img script that builds the inpainting mask from interactive prompts."""

    def title(self):
        """Return the name shown in the script list."""
        return "Interactive Image Segmentor"

    def show(self, is_img2img):
        """Offer the script on the img2img tab only."""
        return is_img2img

    def ui(self, is_img2img):
        """Build the annotation UI and return the components passed to ``run``.

        Only Gradio components are returned; the collected prompts live in the
        module-level ``points`` / ``labels`` / ``boxes`` lists and are read by
        ``run`` directly.
        """
        if not is_img2img:
            return

        with Blocks():
            with Row(equal_height=True):
                choice = components.Radio(choices=["Select", "Deselect"], value="Select", label="Selection Type")
                input_type = components.Radio(choices=["Point", "Box"], value="Point", label="Input Type")
                remove_type = components.Radio(choices=["Selected", "Unselected"], value="Selected", label="Remove Type")
            with Row(equal_height=True):
                image = components.Image(type="pil", interactive=True, image_mode="RGB")
                resulting_image = components.Image(type="pil", image_mode="RGBA")
                image.change(on_image_changed, inputs=[image], outputs=[resulting_image])
                image.select(on_image_clicked, inputs=[image, choice, input_type], outputs=[resulting_image])
            with Row(equal_height=True):
                remove_btn = components.Button(value="Preview Remove Effect")
                remove_btn.click(on_remove_btn_clicked, inputs=[image, remove_type], outputs=[resulting_image])

        # Plain Python lists are not UI components, so only the image is
        # exposed as a script input.
        return [image]

    @staticmethod
    def _current_prompts():
        """Return the module-level ``(points, labels, boxes)`` prompt lists."""
        return points, labels, boxes

    def run(self, p, image, points=None, labels=None, boxes=None):
        """Segment the image with the prompts and process *p* with that mask.

        Args:
            p: The processing object handed in by the web UI.
            image: Annotated image; falls back to ``p.init_images[0]``.
            points, labels, boxes: Optional explicit prompts. Each one that is
                ``None`` is taken from the prompts collected through the UI.

        Returns:
            The result of ``process_images(p)`` with the mask image appended
            to its ``images``.
        """
        ui_points, ui_labels, ui_boxes = self._current_prompts()
        points = ui_points if points is None else points
        labels = ui_labels if labels is None else labels
        boxes = ui_boxes if boxes is None else boxes

        if image is None:
            image = p.init_images[0]

        # Loads the model on first use and sets the image on the predictor.
        reset_image(image.convert("RGB"))

        mask_tensor = image_segmentor.segment(points=points, bbox=boxes, labels=labels)
        # Turn the mask into an 8-bit single-channel image: 255 inside the
        # selection, 0 outside.
        mask_image = transforms.ToPILImage()(mask_tensor.to(torch.uint8) * 255)

        p.image_mask = mask_image
        proc = process_images(p)
        proc.images.append(mask_image)
        return proc
|
|
|
|
|
def interactive_image_segmentor_api(_: Blocks, app: FastAPI):
    """Register the segmentor's HTTP endpoints on *app*.

    Images travel in both directions as base64-encoded PNG strings.

    Endpoints:
        POST /figma/interactive_image_segmentor/upload_image
            Sets the uploaded image as the current image of the shared
            segmentor and echoes the uploaded string back.
        POST /figma/interactive_image_segmentor/image_x_mask
            Applies the prompts to the posted image and returns it with the
            selected (``remove_type="Selected"``) or unselected
            (``"Unselected"``) region removed; any other value returns the
            image unchanged.
    """
    import base64
    import io

    from fastapi import HTTPException

    def decode_png(image_str):
        """Decode a base64 PNG string into a PIL image."""
        return Image.open(io.BytesIO(base64.b64decode(image_str)), formats=["PNG"])

    def encode_png(image):
        """Encode a PIL image as a base64 PNG string."""
        buffer = io.BytesIO()
        image.save(buffer, format="PNG")
        return base64.b64encode(buffer.getvalue()).decode("ascii")

    @app.post("/figma/interactive_image_segmentor/upload_image")
    async def upload_image(image_str: str = Body(...)):
        image = decode_png(image_str)
        # Loads the model on first use; PNGs may carry alpha, so the predictor
        # is given an RGB version.
        reset_image(image.convert("RGB"))
        # A PIL image cannot be serialised into the JSON response.
        return image_str

    @app.post("/figma/interactive_image_segmentor/image_x_mask")
    async def remove_selected(
        image_str: str = Body(...),
        points: list[list[float]] = Body(...),
        boxes: list[list[float]] = Body(...),
        labels: list[int] = Body(...),
        remove_type: str = Body(...),
    ):
        if image_segmentor is None:
            raise HTTPException(
                status_code=status.HTTP_409_CONFLICT,
                detail="No image has been uploaded yet.",
            )

        image = decode_png(image_str)
        if remove_type == "Selected":
            image = image_segmentor.remove_selected(image, points=points, boxes=boxes, labels=labels)
        elif remove_type == "Unselected":
            image = image_segmentor.remove_unselected(image, points=points, boxes=boxes, labels=labels)
        return encode_png(image)
|
|
|
|
|
# Prompt state shared by the UI callbacks and the script.
# `points` and `labels` grow together, one label per clicked point; `boxes`
# holds finished boxes; `box_cache` collects the clicks of the box in progress.
points, labels, boxes = [], [], []
box_cache: list = []

# Shared segmentor instance; stays None until `reset_image` creates it.
image_segmentor = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Register the HTTP endpoints once the web UI application has started.
# Registration is best-effort: a failure is reported but must not stop the
# script from loading. Only ordinary errors are caught, so interrupts and
# interpreter exits still propagate.
try:
    import modules.script_callbacks as script_callbacks

    script_callbacks.on_app_started(interactive_image_segmentor_api)
except Exception as error:
    print(f"Interactive Image Segmentor: API endpoints were not registered: {error}")