|
|
from core import runner |
|
|
import torch |
|
|
from torch import tensor |
|
|
from PIL import Image |
|
|
import numpy as np |
|
|
import torch.nn.functional as F |
|
|
import gradio as gr |
|
|
|
|
|
|
|
|
|
|
|
# Select the compute device once at import time: CUDA when torch can see a GPU,
# otherwise the CPU. Other parts of the app (input widgets, predict) branch on it.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f'{device.type=}')
|
|
|
|
|
# HTML description shown above the demo. Adjacent string literals are joined
# at compile time into one line of HTML; visual line breaks come from <br> tags.
# (Implicit concatenation replaces backslash continuations inside a quoted
# string, which are fragile and also pick up any leading indentation.)
description = (
    '<p> Choose an example below; OR <br>'
    'Upload by yourself: <br>'
    '1. Upload any test image (query) with any target object you wish to segment <br>'
    '2. Upload another image (support) with the target object or a variation of it <br>'
    '3. Upload a binary mask that segments the target object in the support image <br>'
    '</p>'
)
|
|
|
|
|
|
|
|
# Bundled demo episodes. Each row lines up with the first three inputs:
#   [query image, support image, support mask]
# predict() also reads column 2 (the mask path) to recognise example runs.
example_episodes = [
    ['./imgs/549870_35.jpg', './imgs/457070_00.jpg', './imgs/457070_00.png'],
    ['./imgs/ISIC_0000372.jpg', './imgs/ISIC_0013176.jpg', './imgs/ISIC_0013176_segmentation.png'],
    ['./imgs/d_r_450_.jpg', './imgs/d_r_465_.jpg', './imgs/d_r_465_.bmp'],
    ['./imgs/CHNCXR_0282_0.png', './imgs/CHNCXR_0324_0.png', './imgs/CHNCXR_0324_0_mask.png'],
    ['./imgs/1.jpg', './imgs/5.jpg', './imgs/5.png'],
    ['./imgs/cake1.png', './imgs/cake2.png', './imgs/cake2_mask.png'],
]

# Placeholder image path returned in place of predictions when none are produced.
blank_img = './imgs/blank.png'
|
|
|
|
|
def gr_img(name):
    """Return an image input (upload or webcam) labelled ``name``, delivered as PIL."""
    return gr.Image(label=name, sources=['upload', 'webcam'], type="pil")


# Three image inputs (query, support, support mask) followed by the re-adapt toggle.
inputs = [gr_img(label) for label in ('Query Img', 'Support Img', 'Support Mask')]
inputs.append(gr.Checkbox(label='re-adapt'))

# Re-adapting on CPU is slow, so on CPU an extra confirmation checkbox is exposed.
if device.type == 'cpu':
    inputs.append(gr.Checkbox(label='Confirm CPU run (CHOOSE ONLY WHEN REQUESTED)'))
|
|
|
|
|
def prepare_feat_maker():
    """Build and return a fresh feature maker from the default runner config.

    ``runner.makeFeatureMaker`` expects a dataset object; a minimal stand-in
    exposing only ``class_ids = [0]`` is supplied, because the demo treats every
    episode as a single class. The feature maker is created for the
    module-level ``device``.
    """
    config = runner.makeConfig()

    # Stand-in for a real dataset: only the class_ids attribute is provided.
    class DummyDataset:
        class_ids = [0]

    return runner.makeFeatureMaker(DummyDataset(), config, device=device)
|
|
|
|
|
# Shared state read by the Gradio callback:
#   has_fit    -- becomes True once from_model() has completed a forward pass.
#   feat_maker -- the active feature maker; reset_layers() replaces it.
has_fit = False
feat_maker = prepare_feat_maker()
|
|
|
|
|
|
|
|
|
|
|
def reset_layers():
    """Discard the current feature maker and install a freshly built one.

    Rebinds the module-level ``feat_maker``, so the next from_model() call
    starts from a newly constructed feature maker.
    """
    global feat_maker
    fresh = prepare_feat_maker()
    feat_maker = fresh
|
|
|
|
|
def prepare_batch(q_img_pil, s_img_pil, s_mask_pil):
    """Turn PIL query/support images and a support mask into a one-shot batch.

    Returns a dict with a batch dimension of size 1 on every tensor and, for
    the support entries, an extra shot dimension of size 1 (axis 1):
        query_img     -- transformed query image, unsqueezed at axis 0
        support_imgs  -- transformed support image, unsqueezed at axes 0 and 1
        support_masks -- support mask resized to the support image, same dims
        class_id      -- tensor([0])
    """
    # Imported lazily; the dataset class is configured for 400px inputs on each call.
    from data.dataset import FSSDataset
    FSSDataset.initialize(img_size=400, datapath='')

    query = FSSDataset.transform(q_img_pil)
    support = FSSDataset.transform(s_img_pil)

    # Grayscale mask -> float tensor, resized to the support image's spatial size.
    # 'nearest' avoids inventing intermediate values at mask borders.
    # NOTE(review): values keep the raw 0-255 gray levels here; presumably the
    # model side binarises them -- confirm in runner/SingleSampleEval.
    mask = torch.tensor(np.array(s_mask_pil.convert('L')))
    mask = F.interpolate(mask[None, None].float(), support.size()[-2:], mode='nearest').squeeze()

    return {
        'query_img': query.unsqueeze(0),
        'support_imgs': support.unsqueeze(0).unsqueeze(1),
        'support_masks': mask.unsqueeze(0).unsqueeze(1),
        'class_id': tensor([0]),
    }
|
|
|
|
|
def norm(t):
    """Min-max rescale ``t`` so its values span [0, 1].

    Works on anything with ``.min()``/``.max()`` and arithmetic (numpy arrays,
    torch tensors). A constant input has zero range; instead of dividing by
    zero (which yields NaN/inf), it is divided by 1 and so maps to all zeros.
    """
    lo = t.min()
    span = t.max() - lo
    # Guard the degenerate case (e.g. an all-background prediction).
    denom = span if span != 0 else 1
    return (t - lo) / denom
|
|
def overlay(img, mask):
    """Blend a min-max normalised image with a mask at equal 50/50 weight.

    ``mask`` gains a new axis at position 2 so it broadcasts across ``img``'s
    last axis. Presumably ``img`` is H x W x C and ``mask`` is H x W --
    from_model passes a CHW tensor permuted to HWC -- TODO confirm for other callers.
    """
    weight = 0.5
    blended_img = norm(img) * weight
    blended_mask = mask[:, :, None] * weight
    return blended_img + blended_mask
|
|
|
|
|
def from_model(q_img, s_img, s_mask):
    """Run one query/support episode through the current ``feat_maker``.

    Args:
        q_img, s_img, s_mask: PIL query image, support image and support mask.

    Returns:
        A pair of numpy arrays for display:
        (min-max normalised coarse logits of the first batch item,
         query image overlaid with the first predicted mask).

    Side effect: sets the module-level ``has_fit`` to True once the forward
    pass has returned (it stays unchanged if the pass raises).
    """
    global has_fit

    batch = prepare_batch(q_img, s_img, s_mask)
    sseval = runner.SingleSampleEval(batch, feat_maker)
    pred_logits, pred_mask = sseval.forward()
    has_fit = True

    def to_numpy(t):
        # .numpy() only works on CPU tensors that do not require grad; the app
        # may run on CUDA, so detach and copy to host first (no-op on CPU).
        return t.detach().cpu().numpy()

    coarse = norm(to_numpy(pred_logits[0]))
    # CHW -> HWC for display.
    query_hwc = to_numpy(batch['query_img'][0].permute(1, 2, 0))
    return coarse, overlay(query_hwc, to_numpy(pred_mask[0]))
|
|
|
|
|
def _is_example_mask(mask):
    """Return True if ``mask`` equals the support mask of any bundled example.

    Both sides are compared as 8-bit grayscale ('L'), the same conversion
    prepare_batch applies to masks, so a colour-mode conversion applied to the
    uploaded image does not prevent a match. Each example mask file is opened in
    a ``with`` block so its handle is closed, and the scan stops at the first hit.
    """
    if mask is None:
        return False
    candidate = np.array(mask.convert('L'))
    for episode in example_episodes:
        with Image.open(episode[2]) as example_mask:
            if np.array_equal(candidate, np.array(example_mask.convert('L'))):
                return True
    return False


def predict(q, s, m, re_adapt, confirmed=None):
    """Gradio callback: segment the target in query ``q`` given support ``s`` and mask ``m``.

    Args:
        q, s, m: query image, support image and support mask (PIL images, or
            None when an input was left empty).
        re_adapt: state of the 're-adapt' checkbox. Together with ``confirmed``
            being None this identifies an example-caching run (the examples
            only provide the three images).
        confirmed: state of the 'Confirm CPU run' checkbox. That checkbox is
            only added on CPU, so on other devices Gradio passes four arguments
            and this defaults to None; re-adaptation then needs no confirmation.

    Returns:
        (status message, coarse prediction, mask overlay). When no prediction
        is made, both images are ``blank_img``.
    """
    print(f'predict with {re_adapt=}, {confirmed=}')
    print(f'{type(q)=}')

    if q is None or s is None or m is None:
        msg = "Please provide a query image, a support image and a support mask."
        return msg, blank_img, blank_img

    is_cache_run = re_adapt is None and confirmed is None
    is_example = _is_example_mask(m)
    print(f'{is_example=}')

    if is_cache_run:
        # Example caching: always adapt from scratch for each example.
        reset_layers()
        pred = from_model(q, s, m)
        msg = 'Results ready.'
        return msg, *pred

    if re_adapt:
        # Confirmation is only required on CPU, where re-adaptation is slow.
        if confirmed or device.type != 'cpu':
            reset_layers()
            pred = from_model(q, s, m)
            msg = "Results ready.\nRemember to untick 're-adapt' if you wish to predict more images with the same parameters."
            return msg, *pred
        msg = "You chose to re-adapt but are on CPU.\nThis may take 1 minute on your local machine or 4 minutes on huggingface space.\nSelect 'Confirm CPU run' to start."
        return msg, blank_img, blank_img

    if is_example:
        msg = "Cached results for example have been shown previously already.\nTo view it again, click the example again.\nTo run adaption again from scratch, select 're-adapt'."
        return msg, blank_img, blank_img

    if has_fit:
        # Reuse the layers fitted by a previous run without resetting them.
        pred = from_model(q, s, m)
        msg = "Results predicted based on layers fitted from previous run.\nIf you wish to re-adapt, select 're-adapt'."
        return msg, *pred

    msg = "This is the first time you predict own images.\nThe attached layers need to be fitted.\nPlease select 're-adapt'."
    return msg, blank_img, blank_img
|
|
|
|
|
# Assemble the demo: predict() receives the inputs built above and returns a
# status string plus two images; the bundled episodes are offered as examples.
gradio_app = gr.Interface(
    fn=predict,
    inputs=inputs,
    outputs=[
        gr.Textbox(label="Status"),
        gr.Image(label="Coarse Query Prediction"),
        gr.Image(label="Mask Prediction"),
    ],
    title="abcdfss",
    description=description,
    examples=example_episodes,
)

gradio_app.launch()