| from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation |
| import gradio as gr |
| from PIL import Image |
| import torch |
| import matplotlib.pyplot as plt |
| import torch |
| import numpy as np |
|
|
# Hub id of the CLIPSeg checkpoint. The processor and model below are built
# once at import time from this same id and reused by process_image().
_CLIPSEG_CHECKPOINT = "CIDAS/clipseg-rd64-refined"

processor = CLIPSegProcessor.from_pretrained(_CLIPSEG_CHECKPOINT)
model = CLIPSegForImageSegmentation.from_pretrained(_CLIPSEG_CHECKPOINT)
|
|
|
|
def normalize_mask(mask):
    """Min-max scale ``mask`` into the range [0, 1].

    A constant mask (max == min) has no contrast to stretch, so it maps to
    all zeros instead of the NaNs a 0/0 division would produce.

    Args:
        mask: numeric array (``process_image`` passes a 2-D uint8 array).

    Returns:
        float64 array with the same shape as ``mask``.
    """
    mask = np.asarray(mask, dtype=np.float64)
    lo = mask.min()
    hi = mask.max()
    if hi == lo:
        return np.zeros_like(mask)
    return (mask - lo) / (hi - lo)


def process_image(image, prompt):
    """Score how strongly each pixel of ``image`` matches the text ``prompt``.

    Runs CLIPSeg on the (image, prompt) pair, turns the logits into
    probabilities with a sigmoid, resizes the map to the size of ``image``
    and min-max normalises it.

    Args:
        image: PIL image to segment.
        prompt: text describing the region of interest.

    Returns:
        float64 array of shape (image.height, image.width) with values in
        [0, 1]; all zeros when the resized map is uniform.
    """
    inputs = processor(
        text=prompt, images=image, padding="max_length", return_tensors="pt"
    )
    # Inference only: no gradients needed.
    with torch.no_grad():
        outputs = model(**inputs)

    # NOTE(review): for a single prompt the logits may be (H, W) or carry a
    # leading singleton batch axis depending on the transformers version;
    # squeeze() gives the 2-D array PIL needs in either case — confirm
    # against the pinned version.
    probs = torch.sigmoid(outputs.logits).squeeze().cpu().numpy()

    # Quantise to 8 bits so PIL can resample the map to the image size.
    # PIL sizes are (width, height), so the resulting array is (height, width).
    low_res = Image.fromarray(np.uint8(probs * 255), "L")
    mask = np.array(low_res.resize(image.size))
    return normalize_mask(mask)
|
|
|
|
def split_prompts(prompts):
    """Split a comma-separated prompt string into trimmed, non-empty prompts.

    ``None`` or an empty/blank string yields an empty list, so an unfilled
    textbox contributes no prompt at all (rather than a literal "" prompt).
    """
    if not prompts:
        return []
    return [part.strip() for part in prompts.split(",") if part.strip()]


def get_masks(prompts, img, threhsold):
    """Return one boolean mask per prompt in the comma-separated ``prompts``.

    Args:
        prompts: comma-separated text prompts; blank entries are ignored.
        img: PIL image passed to ``process_image``.
        threhsold: a pixel is True when its normalised score is strictly
            greater than this value. (Parameter name kept as-is so keyword
            callers keep working.)

    Returns:
        list of boolean arrays, one per non-blank prompt; empty when
        ``prompts`` has no non-blank entries.
    """
    return [
        process_image(img, prompt) > threhsold
        for prompt in split_prompts(prompts)
    ]


def union_masks(masks, shape):
    """Pixel-wise OR of ``masks``; an all-False array of ``shape`` if none."""
    if not masks:
        # np.stack raises on an empty sequence, so handle "no prompts" here.
        return np.zeros(shape, dtype=bool)
    return np.any(np.stack(masks), axis=0)


def extract_image(img, pos_prompts, neg_prompts, threshold):
    """Cut out the parts of ``img`` that match the positive prompts.

    A pixel is kept when it matches at least one positive prompt and no
    negative prompt. With no positive prompts nothing is kept; with no
    negative prompts nothing is excluded.

    Args:
        img: PIL image to segment.
        pos_prompts: comma-separated prompts describing what to keep.
        neg_prompts: comma-separated prompts describing what to exclude.
        threshold: score cut-off forwarded to ``get_masks``.

    Returns:
        (output_image, final_mask, inverse_mask): an RGBA image of
        ``img.size`` that is transparent outside the kept region, the kept
        region as an "L" image (255 = kept), and its complement as a uint8
        array (255 = not kept).
    """
    # process_image resizes to img.size (width, height) -> arrays are (h, w).
    shape = (img.height, img.width)
    pos_mask = union_masks(get_masks(pos_prompts, img, threshold), shape)
    neg_mask = union_masks(get_masks(neg_prompts, img, threshold), shape)
    keep = pos_mask & ~neg_mask

    final_mask = Image.fromarray(keep.astype(np.uint8) * 255, "L")
    inverse_mask = (~keep).astype(np.uint8) * 255

    # Paste the original pixels through the mask onto a fully transparent canvas.
    output_image = Image.new("RGBA", img.size, (0, 0, 0, 0))
    output_image.paste(img, mask=final_mask)

    return output_image, final_mask, inverse_mask
|
|
|
|
title = "Interactive demo: zero-shot image segmentation with CLIPSeg"
description = "Demo for using CLIPSeg, a CLIP-based model for zero- and one-shot image segmentation. To use it, upload an image, describe what you want to identify (and, optionally, what you want to ignore) as comma-separated text prompts, adjust the threshold if needed and click 'Process'. Results will show up in a few seconds."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2112.10003'>CLIPSeg: Image Segmentation Using Text and Image Prompts</a> | <a href='https://huggingface.co/docs/transformers/main/en/model_doc/clipseg'>HuggingFace docs</a></p>"


# Page layout: inputs on the left, the three outputs of extract_image on the right.
with gr.Blocks(title=title) as demo:
    gr.Markdown("# CLIPSeg: Image Segmentation Using Text and Image Prompts")
    gr.Markdown(article)
    gr.Markdown(description)

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil")
            positive_prompts = gr.Textbox(
                label="Please describe what you want to identify (comma separated)"
            )
            negative_prompts = gr.Textbox(
                label="Please describe what you want to ignore (comma separated)"
            )
            # Explicit step so the threshold moves in 0.01 increments.
            input_slider_T = gr.Slider(
                minimum=0, maximum=1, value=0.4, step=0.01, label="Threshold"
            )
            # A Button's caption is its `value` (first positional argument);
            # `label=` is not a documented Button parameter.
            btn_process = gr.Button("Process")

        with gr.Column():
            output_image = gr.Image(label="Result")
            output_mask = gr.Image(label="Mask")
            inverse_mask = gr.Image(label="Inverse")

    # Input/output order must match extract_image's signature and return tuple.
    btn_process.click(
        extract_image,
        inputs=[
            input_image,
            positive_prompts,
            negative_prompts,
            input_slider_T,
        ],
        outputs=[output_image, output_mask, inverse_mask],
        api_name="mask",
    )


demo.launch()