File size: 5,232 Bytes
492bb92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92b52c5
 
492bb92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
from core import runner
import torch
from torch import tensor
from PIL import Image
import numpy as np
import torch.nn.functional as F
import gradio as gr

#-----------------------global definitions------------------------#

# Prefer GPU; the UI grows an extra confirmation checkbox on CPU (below).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'{device.type=}')

# HTML blurb shown above the interface explaining the three inputs.
description = '<p> Choose an example below; OR <br>\
    Upload by yourself: <br>\
    1. Upload any test image (query) with any target object you wish to segment <br>\
    2. Upload another image (support) with the target object or a variation of it <br>\
    3. Upload a binary mask that segments the target objet in the support image <br>\
    </p>'
    
# qimg, simg, smask
# Each example episode is [query image, support image, support mask] paths.
example_episodes = [
            ['./imgs/549870_35.jpg', './imgs/457070_00.jpg', './imgs/457070_00.png'],
            ['./imgs/ISIC_0000372.jpg', './imgs/ISIC_0013176.jpg', './imgs/ISIC_0013176_segmentation.png'],
            ['./imgs/d_r_450_.jpg', './imgs/d_r_465_.jpg', './imgs/d_r_465_.bmp'],
            ['./imgs/CHNCXR_0282_0.png', './imgs/CHNCXR_0324_0.png', './imgs/CHNCXR_0324_0_mask.png'],
            ['./imgs/1.jpg', './imgs/5.jpg', './imgs/5.png'],
            ['./imgs/cake1.png', './imgs/cake2.png', './imgs/cake2_mask.png']
            ]
# Placeholder returned in the image slots when no prediction is produced.
blank_img = './imgs/blank.png'

# Factory for a PIL-typed Gradio image input with upload/webcam sources.
gr_img = lambda name: gr.Image(label=name, sources=['upload', 'webcam'], type="pil")
inputs = [gr_img('Query Img'), gr_img('Support Img'), gr_img('Support Mask'), gr.Checkbox(label='re-adapt')]
# On CPU only, add a confirmation checkbox because adaptation is slow there.
# NOTE(review): this makes predict() receive 5 args on CPU but 4 on GPU.
if device.type=='cpu':
  inputs.append(gr.Checkbox(label='Confirm CPU run (CHOOSE ONLY WHEN REQUESTED)'))

def prepare_feat_maker():
    """Build a fresh (unfitted) feature maker from a default runner config.

    A minimal stub dataset with a single class id is passed because the
    runner's factory expects a dataset object here.
    """
    cfg = runner.makeConfig()

    class _StubDataset:
        class_ids = [0]

    return runner.makeFeatureMaker(_StubDataset(), cfg, device=device)

# Mutable module-level state shared by the Gradio callbacks below.
feat_maker = prepare_feat_maker()  # current model; rebuilt by reset_layers()
has_fit = False  # whether the attached layers have been fitted this session

#-----------------------------------------------------------------#

def reset_layers():
    """Discard the current feature maker and replace it with a fresh, unfitted one."""
    global feat_maker
    feat_maker = prepare_feat_maker()
    
def prepare_batch(q_img_pil, s_img_pil, s_mask_pil):
    """Pack PIL query/support images and a support mask into a one-episode batch.

    Returns a dict with keys 'query_img' (1,C,H,W), 'support_imgs' and
    'support_masks' (with batch and k-shot dims added), and 'class_id'.
    """
    from data.dataset import FSSDataset
    FSSDataset.initialize(img_size=400, datapath='')

    query = FSSDataset.transform(q_img_pil)
    support = FSSDataset.transform(s_img_pil)

    # Greyscale mask -> float tensor, resized with 'nearest' (keeps it binary)
    # to match the transformed support image's spatial size.
    mask = torch.tensor(np.array(s_mask_pil.convert('L')))
    mask = F.interpolate(mask[None, None].float(), support.size()[-2:], mode='nearest').squeeze()

    def with_batch_dim(t):
        return t.unsqueeze(0)

    def with_kshot_dim(t):
        return t.unsqueeze(1)

    return {
        'query_img': with_batch_dim(query),
        'support_imgs': with_kshot_dim(with_batch_dim(support)),
        'support_masks': with_kshot_dim(with_batch_dim(mask)),
        'class_id': tensor([0]),
    }

def norm(t):
    """Min-max normalize *t* into [0, 1].

    Works on both torch tensors and numpy arrays (only .min()/.max() are used).
    NOTE(review): a constant-valued input divides by zero here — callers
    currently pass real images/logits, so this is left unchanged; confirm
    before feeding degenerate inputs.
    """
    return (t - t.min()) / (t.max() - t.min())


def overlay(img, mask):
    """Blend a float image (h, w, 3) with a float mask (h, w), each at weight 0.5.

    The image is min-max normalized first, so the result stays in [0, 1]
    for a binary mask.
    """
    return norm(img) * 0.5 + mask[:, :, np.newaxis] * 0.5
    
def from_model(q_img, s_img, s_mask):
    """Run one few-shot segmentation episode and return displayable outputs.

    Returns (normalized logit map, query image overlaid with the predicted
    mask), both as numpy arrays in [0, 1].
    """
    batch = prepare_batch(q_img, s_img, s_mask)
    evaluator = runner.SingleSampleEval(batch, feat_maker)
    pred_logits, pred_mask = evaluator.forward()

    global has_fit
    has_fit = True  # the attached layers have now been fitted at least once

    query_np = batch['query_img'][0].permute(1, 2, 0).numpy()
    # logit mask in range from -1 to 1, and mask-overlaid query image 0 to 1
    return norm(pred_logits[0].numpy()), overlay(query_np, pred_mask[0].numpy())

def predict(q, s, m, re_adapt, confirmed=None):
    """Gradio callback: adapt and/or predict for one query/support episode.

    Args:
        q: query PIL image.
        s: support PIL image.
        m: support-mask PIL image.
        re_adapt: 're-adapt' checkbox value (None during example caching).
        confirmed: 'Confirm CPU run' checkbox value. Defaults to None because
            that checkbox is only added to `inputs` on CPU, so on GPU Gradio
            calls this function with only four arguments — without the
            default this raised TypeError on GPU.

    Returns:
        (status message, coarse query prediction, mask prediction); the two
        images are model outputs or the blank placeholder.
    """
    print(f'predict with {re_adapt=}, {confirmed=}')
    print(f'{type(q)=}')
    # During example caching Gradio supplies no checkbox values at all.
    is_cache_run = re_adapt is None and confirmed is None
    if device.type != 'cpu':
        # Off-CPU runs are fast; no slow-run confirmation is required.
        confirmed = True
    # Recognize a replayed example by matching the uploaded mask against each
    # example's support mask ([2] pointing to support mask).
    is_example = any(np.array_equal(np.array(m), np.array(Image.open(e[2]))) for e in example_episodes)
    print(f'{is_example=}')

    if is_cache_run:
        reset_layers()
        pred = from_model(q, s, m)
        msg = 'Results ready.'
        return msg, *pred

    if re_adapt:
        if not confirmed:
            msg = "You chose to re-adapt but are on CPU.\nThis may take 1 minute on your local machine or 4 minutes on huggingface space.\nSelect 'Confirm CPU run' to start."
            return msg, blank_img, blank_img
        reset_layers()
        pred = from_model(q, s, m)
        msg = "Results ready.\nRemember to untick 're-adapt' if you wish to predict more images with the same parameters."
        return msg, *pred

    if is_example:
        msg = "Cached results for example have been shown previously already.\nTo view it again, click the example again.\nTo run adaption again from scratch, select 're-adapt'."
        return msg, blank_img, blank_img

    if has_fit:
        pred = from_model(q, s, m)
        msg = "Results predicted based on layers fitted from previous run.\nIf you wish to re-adapt, select 're-adapt'."
        return msg, *pred

    msg = "This is the first time you predict own images.\nThe attached layers need to be fitted.\nPlease select 're-adapt'."
    return msg, blank_img, blank_img
    
# Wire the callback, the inputs built above, and three outputs
# (status text + two images) into a Gradio interface, then serve it.
gradio_app = gr.Interface(
    fn=predict,
    inputs=inputs,
    outputs=[gr.Textbox(label="Status"), gr.Image(label="Coarse Query Prediction"), gr.Image(label="Mask Prediction")],
    description=description,
    examples=example_episodes,
    title="abcdfss",
)

gradio_app.launch()