heyoujue committed
Commit 492bb92 · 1 Parent(s): afd410c

for gradio app
README.md CHANGED
@@ -1,49 +1,10 @@
- # Adapt Before Comparison: A New Perspective on Cross-Domain Few-Shot Segmentation
-
- [[`Paper`](https://arxiv.org/abs/2402.17614)] accepted for CVPR'24.
-
- ## Preparing Data
- Because we follow the evaluation procedure of PATNet and Remember the Difference (RtD), please refer to their work for preparation of the following datasets:
- - Deepglobe (PAT)
- - ISIC (PAT)
- - Chest X-Ray (Lung) (PAT)
- - FSS-1000 (PAT)
- - SUIM (RtD)
-
- You do not need to get all datasets. Just prepare the one you want to test our method with.
-
- ## Python package prerequisites
- 1. torch
- 2. torchvision
- 3. cv2 (opencv-python)
- 4. numpy
- 5. for others, follow the console output
-
- ## Run it
- Call
- `python main.py --benchmark {} --datapath {} --nshot {}`
-
- for example
- `python main.py --benchmark deepglobe --datapath ./datasets/deepglobe/ --nshot 1`
-
- Available `benchmark` strings: `deepglobe`, `isic`, `lung`, `fss`, `suim`. The easiest to prepare are `lung` and `fss`.
-
- The default is quick-infer mode.
- To change this, set `config.featext.fit_every_episode=True` in the main file.
- You can change all other parameters likewise; check the available parameters in `core/runner->makeConfig()`.
-
- ## Await it
-
- You can experiment with this code. Before opening issues, I suggest waiting until nicer demonstrations and documentation have been added.
-
- ## Cite it
- If you use ABCDFSS in your research, please use the following BibTeX entry.
- ```
- @article{herzog2024cdfss,
-   title={Adapt Before Comparison: A New Perspective on Cross-Domain Few-Shot Segmentation},
-   author={Jonas Herzog},
-   journal={arXiv:2402.17614},
-   year={2024}
- }
- ```
-
 
+ ---
+ title: Abcdfss
+ emoji: 🔥
+ colorFrom: yellow
+ colorTo: purple
+ sdk: gradio
+ sdk_version: 4.24.0
+ app_file: app.py
+ pinned: false
+ ---
app.py ADDED
@@ -0,0 +1,122 @@
+ from core import runner
+ import torch
+ from torch import tensor
+ from PIL import Image
+ import numpy as np
+ import torch.nn.functional as F
+ import gradio as gr
+
+ #-----------------------global definitions------------------------#
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ print(f'{device.type=}')
+
+ description = '<p> Choose an example below; OR <br>\
+ Upload by yourself: <br>\
+ 1. Upload any test image (query) with any target object you wish to segment <br>\
+ 2. Upload another image (support) with the target object or a variation of it <br>\
+ 3. Upload a binary mask that segments the target object in the support image <br>\
+ </p>'
+
+ # qimg, simg, smask
+ example_episodes = [
+     ['./imgs/cake1.png', './imgs/cake2.png', './imgs/cake2_mask.png'],
+     ['./imgs/549870_35.jpg', './imgs/457070_00.jpg', './imgs/457070_00.png'],
+     ['./imgs/ISIC_0000372.jpg', './imgs/ISIC_0013176.jpg', './imgs/ISIC_0013176_segmentation.png'],
+     ['./imgs/d_r_450_.jpg', './imgs/d_r_465_.jpg', './imgs/d_r_465_.bmp'],
+     ['./imgs/CHNCXR_0282_0.png', './imgs/CHNCXR_0324_0.png', './imgs/CHNCXR_0324_0_mask.png'],
+     ['./imgs/1.jpg', './imgs/5.jpg', './imgs/5.png']
+ ]
+ blank_img = './imgs/blank.png'
+
+ gr_img = lambda name: gr.Image(label=name, sources=['upload', 'webcam'], type="pil")
+ inputs = [gr_img('Query Img'), gr_img('Support Img'), gr_img('Support Mask'), gr.Checkbox(label='re-adapt')]
+ if device.type == 'cpu':
+     inputs.append(gr.Checkbox(label='Confirm CPU run (CHOOSE ONLY WHEN REQUESTED)'))
+
+ def prepare_feat_maker():
+     config = runner.makeConfig()
+     class DummyDataset:
+         class_ids = [0]
+     fake_feat_maker = runner.makeFeatureMaker(DummyDataset(), config, device=device)
+     return fake_feat_maker
+
+ feat_maker = prepare_feat_maker()
+ has_fit = False
+
+ #-----------------------------------------------------------------#
+
+ def reset_layers():
+     global feat_maker
+     feat_maker = prepare_feat_maker()
+
+ def prepare_batch(q_img_pil, s_img_pil, s_mask_pil):
+     from data.dataset import FSSDataset
+     FSSDataset.initialize(img_size=400, datapath='')
+     q_img_tensor = FSSDataset.transform(q_img_pil)
+     s_img_tensor = FSSDataset.transform(s_img_pil)
+     s_mask_tensor = torch.tensor(np.array(s_mask_pil.convert('L')))
+     s_mask_tensor = F.interpolate(s_mask_tensor.unsqueeze(0).unsqueeze(0).float(), s_img_tensor.size()[-2:], mode='nearest').squeeze()
+     add_batch_dim = lambda t: t.unsqueeze(0)
+     add_kshot_dim = lambda t: t.unsqueeze(1)
+     fake_batch = {'query_img': add_batch_dim(q_img_tensor), 'support_imgs': add_kshot_dim(add_batch_dim(s_img_tensor)), 'support_masks': add_kshot_dim(add_batch_dim(s_mask_tensor)), 'class_id': tensor([0])}
+     return fake_batch
+
+ norm = lambda t: (t - t.min()) / (t.max() - t.min())
+ def overlay(img, mask):
+     # img: h,w,3 (float); mask: h,w (float)
+     return norm(img)*0.5 + mask[:, :, np.newaxis]*0.5
+
+ def from_model(q_img, s_img, s_mask):
+     batch = prepare_batch(q_img, s_img, s_mask)
+     sseval = runner.SingleSampleEval(batch, feat_maker)
+     pred_logits, pred_mask = sseval.forward()
+     global has_fit
+     has_fit = True
+     # logit mask in range -1 to 1, and mask-overlaid query image 0 to 1
+     return norm(pred_logits[0].numpy()), overlay(batch['query_img'][0].permute(1, 2, 0).numpy(), pred_mask[0].numpy())
+
+ def predict(q, s, m, re_adapt, confirmed=None):  # confirmed checkbox only exists on CPU
+     print(f'predict with {re_adapt=}, {confirmed=}')
+     print(f'{type(q)=}')
+     is_cache_run = re_adapt is None and confirmed is None
+     is_example = any(np.array_equal(np.array(m), np.array(Image.open(e[2]))) for e in example_episodes)  # e[2] is the support mask
+     print(f'{is_example=}')
+
+     if is_cache_run:
+         reset_layers()
+         pred = from_model(q, s, m)
+         msg = 'Results ready.'
+         return msg, *pred
+     elif re_adapt:
+         if confirmed or device.type != 'cpu':  # no confirmation needed on GPU
+             reset_layers()
+             pred = from_model(q, s, m)
+             msg = "Results ready.\nRemember to untick 're-adapt' if you wish to predict more images with the same parameters."
+             return msg, *pred
+         else:
+             msg = "You chose to re-adapt but are on CPU.\nThis may take 1 minute on your local machine or 4 minutes on a Hugging Face Space.\nSelect 'Confirm CPU run' to start."
+             return msg, blank_img, blank_img
+     else:
+         if is_example:
+             msg = "Cached results for this example have already been shown.\nTo view them again, click the example again.\nTo run adaptation again from scratch, select 're-adapt'."
+             return msg, blank_img, blank_img
+         else:
+             if has_fit:
+                 pred = from_model(q, s, m)
+                 msg = "Results predicted using layers fitted in the previous run.\nIf you wish to re-adapt, select 're-adapt'."
+                 return msg, *pred
+             else:
+                 msg = "This is the first time you predict your own images.\nThe attached layers need to be fitted.\nPlease select 're-adapt'."
+                 return msg, blank_img, blank_img
+
+ gradio_app = gr.Interface(
+     fn=predict,
+     inputs=inputs,
+     outputs=[gr.Textbox(label="Status"), gr.Image(label="Coarse Query Prediction"), gr.Image(label="Mask Prediction")],
+     description=description,
+     examples=example_episodes,
+     title="abcdfss",
+ )
+
+ gradio_app.launch()
imgs/1.jpg ADDED
imgs/457070_00.jpg ADDED
imgs/457070_00.png ADDED
imgs/5.jpg ADDED
imgs/5.png ADDED
imgs/549870_35.jpg ADDED
imgs/CHNCXR_0282_0.png ADDED
imgs/CHNCXR_0324_0.png ADDED
imgs/CHNCXR_0324_0_mask.png ADDED
imgs/ISIC_0000372.jpg ADDED
imgs/ISIC_0013176.jpg ADDED
imgs/ISIC_0013176_segmentation.png ADDED
imgs/blank.png ADDED
imgs/cake1.png ADDED
imgs/cake2.png ADDED
imgs/cake2_mask.png ADDED
imgs/d_r_450_.jpg ADDED
imgs/d_r_465_.bmp ADDED
imgs/d_r_465_.jpg ADDED
main.py DELETED
@@ -1,37 +0,0 @@
- from core import runner
- import torch
- import argparse
-
- def parse_opts():
-     r"""arguments"""
-     parser = argparse.ArgumentParser(description='Adapt Before Comparison - A New Perspective on Cross-Domain Few-Shot Segmentation')
-
-     # common
-     parser.add_argument('--benchmark', type=str, default='lung', choices=['fss', 'deepglobe', 'lung', 'isic', 'suim'])
-     parser.add_argument('--datapath', type=str)
-     parser.add_argument('--nshot', type=int, default=1)
-
-     args = parser.parse_args()
-     return args
-
-
- if __name__ == '__main__':
-     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-     args = parse_opts()
-     print(args)
-     runner.args.benchmark = args.benchmark
-     runner.args.datapath = args.datapath
-     runner.args.nshot = args.nshot
-
-     dataloader = runner.makeDataloader()
-     config = runner.makeConfig()
-     feat_maker = runner.makeFeatureMaker(dataloader.dataset, config, device=device)
-     average_meter = runner.AverageMeterWrapper(dataloader, device)
-
-     for idx, batch in enumerate(dataloader):
-         sseval = runner.SingleSampleEval(batch, feat_maker)
-         sseval.forward()
-         sseval.calc_metrics()
-         average_meter.update(sseval)
-         average_meter.write(idx)
-     print('Result m|FB:', average_meter.average_meter.compute_iou())
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ torch
+ torchvision
+ opencv-python
+ numpy
+ tensorboardX