for gradio app
Browse files- README.md +10 -49
- app.py +122 -0
- imgs/1.jpg +0 -0
- imgs/457070_00.jpg +0 -0
- imgs/457070_00.png +0 -0
- imgs/5.jpg +0 -0
- imgs/5.png +0 -0
- imgs/549870_35.jpg +0 -0
- imgs/CHNCXR_0282_0.png +0 -0
- imgs/CHNCXR_0324_0.png +0 -0
- imgs/CHNCXR_0324_0_mask.png +0 -0
- imgs/ISIC_0000372.jpg +0 -0
- imgs/ISIC_0013176.jpg +0 -0
- imgs/ISIC_0013176_segmentation.png +0 -0
- imgs/blank.png +0 -0
- imgs/cake1.png +0 -0
- imgs/cake2.png +0 -0
- imgs/cake2_mask.png +0 -0
- imgs/d_r_450_.jpg +0 -0
- imgs/d_r_465_.bmp +0 -0
- imgs/d_r_465_.jpg +0 -0
- main.py +0 -37
- requirements.txt +5 -0
README.md
CHANGED
|
@@ -1,49 +1,10 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
-
|
| 11 |
-
- SUIM (RtD)
|
| 12 |
-
|
| 13 |
-
You do not need to get all datasets. Just prepare the one you want to test our method with.
|
| 14 |
-
|
| 15 |
-
## Python package prerequisites
|
| 16 |
-
1. torch
|
| 17 |
-
2. torchvision
|
| 18 |
-
3. cv2
|
| 19 |
-
4. numpy
|
| 20 |
-
5. for others, follow the console output
|
| 21 |
-
|
| 22 |
-
## Run it
|
| 23 |
-
Call
|
| 24 |
-
`python main.py --benchmark {} --datapath {} --nshot {}`
|
| 25 |
-
|
| 26 |
-
for example
|
| 27 |
-
`python main.py --benchmark deepglobe --datapath ./datasets/deepglobe/ --nshot 1`
|
| 28 |
-
|
| 29 |
-
Available `benchmark` strings: `deepglobe`,`isic`,`lung`,`fss`,`suim`. Easiest to prepare should be `lung` or `fss`.
|
| 30 |
-
|
| 31 |
-
Default is quick-infer mode.
|
| 32 |
-
To change this, set `config.featext.fit_every_episode=True` in the main file.
|
| 33 |
-
You can change all other parameters likewise, check the available parameters in `core/runner->makeConfig()`.
|
| 34 |
-
|
| 35 |
-
## Await it
|
| 36 |
-
|
| 37 |
-
You can experiment with this code. Before opening issues, I suggest awaiting nicer demonstrations and documentation to be added.
|
| 38 |
-
|
| 39 |
-
## Cite it
|
| 40 |
-
If you use ABCDFSS in your research, please use the following BibTeX entry.
|
| 41 |
-
```
|
| 42 |
-
@article{herzog2024cdfss,
|
| 43 |
-
title={Adapt Before Comparison: A New Perspective on Cross-Domain Few-Shot Segmentation},
|
| 44 |
-
author={Jonas Herzog},
|
| 45 |
-
journal={arXiv:2402.17614},
|
| 46 |
-
year={2024}
|
| 47 |
-
}
|
| 48 |
-
```
|
| 49 |
-
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Abcdfss
|
| 3 |
+
emoji: 🔥
|
| 4 |
+
colorFrom: yellow
|
| 5 |
+
colorTo: purple
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 4.24.0
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from core import runner
|
| 2 |
+
import torch
|
| 3 |
+
from torch import tensor
|
| 4 |
+
from PIL import Image
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
import gradio as gr
|
| 8 |
+
|
| 9 |
+
#-----------------------global definitions------------------------#

# Run on GPU when available; the UI behaves differently on CPU (extra confirm box).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'{device.type=}')

# HTML blurb shown above the gradio inputs.
# Fix: "target objet" -> "target object" (user-facing typo).
description = '<p> Choose an example below; OR <br>\
Upload by yourself: <br>\
1. Upload any test image (query) with any target object you wish to segment <br>\
2. Upload another image (support) with the target object or a variation of it <br>\
3. Upload a binary mask that segments the target object in the support image <br>\
</p>'
|
| 20 |
+
|
| 21 |
+
# Bundled demo episodes, each a [query image, support image, support mask] triple.
_IMG_DIR = './imgs'
_EPISODE_FILES = [
    ('cake1.png', 'cake2.png', 'cake2_mask.png'),
    ('549870_35.jpg', '457070_00.jpg', '457070_00.png'),
    ('ISIC_0000372.jpg', 'ISIC_0013176.jpg', 'ISIC_0013176_segmentation.png'),
    ('d_r_450_.jpg', 'd_r_465_.jpg', 'd_r_465_.bmp'),
    ('CHNCXR_0282_0.png', 'CHNCXR_0324_0.png', 'CHNCXR_0324_0_mask.png'),
    ('1.jpg', '5.jpg', '5.png'),
]
example_episodes = [[f'{_IMG_DIR}/{fname}' for fname in triple] for triple in _EPISODE_FILES]
# Placeholder shown in the output image slots when there is nothing to display.
blank_img = f'{_IMG_DIR}/blank.png'
|
| 31 |
+
|
| 32 |
+
def gr_img(name):
    """Build an image input (upload or webcam) that yields PIL images."""
    # def instead of a lambda bound to a name (PEP 8 E731); behavior unchanged.
    return gr.Image(label=name, sources=['upload', 'webcam'], type="pil")

inputs = [gr_img('Query Img'), gr_img('Support Img'), gr_img('Support Mask'), gr.Checkbox(label='re-adapt')]
# CPU runs are slow, so only on CPU an explicit confirmation checkbox is added.
if device.type == 'cpu':
    inputs.append(gr.Checkbox(label='Confirm CPU run (CHOOSE ONLY WHEN REQUESTED)'))
|
| 36 |
+
|
| 37 |
+
def prepare_feat_maker():
    """Build a fresh (unfitted) feature maker via the runner module.

    runner.makeFeatureMaker normally takes a dataset; here only its
    class_ids attribute appears to be needed, so a one-class dummy
    stands in for it.
    """
    class DummyDataset:
        class_ids = [0]

    cfg = runner.makeConfig()
    return runner.makeFeatureMaker(DummyDataset(), cfg, device=device)
|
| 43 |
+
|
| 44 |
+
# Module-level state shared by the gradio callbacks below.
feat_maker = prepare_feat_maker()
# Becomes True after the first successful from_model() call, i.e. once the
# attached layers have been fitted at least once in this process.
has_fit = False

#-----------------------------------------------------------------#
|
| 48 |
+
|
| 49 |
+
def reset_layers():
    """Discard the current feature maker and replace it with a fresh, unfitted one."""
    global feat_maker
    feat_maker = prepare_feat_maker()
|
| 52 |
+
|
| 53 |
+
def prepare_batch(q_img_pil, s_img_pil, s_mask_pil):
    """Turn three PIL images into a single-episode batch dict for the runner.

    The dict mimics one FSSDataset sample with batch size 1 and k-shot 1:
    'query_img' (1,C,H,W), 'support_imgs' (1,1,C,H,W),
    'support_masks' (1,1,H,W), 'class_id' (1,).
    """
    # Imported lazily so the gradio app starts without the dataset package
    # being touched at import time.
    from data.dataset import FSSDataset
    FSSDataset.initialize(img_size=400, datapath='')

    query = FSSDataset.transform(q_img_pil)
    support = FSSDataset.transform(s_img_pil)
    # Grayscale the mask, then nearest-resize it to the support image size
    # so mask values are not interpolated.
    # NOTE(review): mask values stay in 0..255 here; presumably the runner
    # thresholds or normalizes them downstream — confirm.
    mask = torch.tensor(np.array(s_mask_pil.convert('L')))
    mask = F.interpolate(mask[None, None].float(), support.size()[-2:], mode='nearest').squeeze()

    fake_batch = {
        'query_img': query.unsqueeze(0),
        'support_imgs': support.unsqueeze(0).unsqueeze(1),
        'support_masks': mask.unsqueeze(0).unsqueeze(1),
        'class_id': tensor([0]),
    }
    return fake_batch
|
| 64 |
+
|
| 65 |
+
def norm(t):
    """Min-max normalize t (numpy array or torch tensor) into [0, 1].

    Fix: a constant input used to divide by zero and produce NaN/inf;
    it now returns all zeros instead.
    """
    span = t.max() - t.min()
    if span == 0:
        return t - t.min()  # constant input: all zeros, no NaN
    return (t - t.min()) / span

def overlay(img, mask):
    """Blend a float image (h, w, 3) 50/50 with a float mask (h, w)."""
    return norm(img) * 0.5 + mask[:, :, np.newaxis] * 0.5
|
| 69 |
+
|
| 70 |
+
def from_model(q_img, s_img, s_mask):
    """Run one few-shot segmentation episode with the current feat_maker.

    Builds a batch from the three PIL inputs, evaluates it, flips the
    module-level has_fit flag, and returns images for the two output slots.
    """
    batch = prepare_batch(q_img, s_img, s_mask)
    sseval = runner.SingleSampleEval(batch, feat_maker)
    pred_logits, pred_mask = sseval.forward()
    # Record that the attached layers have been fitted at least once.
    global has_fit
    has_fit = True
    # logit mask in range from -1 to 1, and mask-overlaid query image 0 to 1
    # NOTE(review): .numpy() assumes forward() returns detached CPU tensors —
    # confirm for CUDA / grad-enabled runs.
    return norm(pred_logits[0].numpy()), overlay(batch['query_img'][0].permute(1,2,0).numpy(), pred_mask[0].numpy())
|
| 78 |
+
|
| 79 |
+
def predict(q, s, m, re_adapt=None, confirmed=None):
    """Gradio callback: segment query q given support image s and mask m (PIL).

    re_adapt: tick to refit the adapted layers from scratch.
    confirmed: CPU-run confirmation. The corresponding checkbox is only
        appended to `inputs` on CPU, so on GPU gradio calls this function
        with four arguments — the original required-parameter signature
        raised a TypeError there; it now defaults to None and is bypassed
        on GPU.
    Returns (status message, coarse prediction, mask overlay); the image
    slots fall back to blank_img when nothing was predicted.
    """
    print(f'predict with {re_adapt=}, {confirmed=}')
    print(f'{type(q)=}')
    # gradio passes None for inputs the user left empty; guard before any
    # array conversion or model call.
    if q is None or s is None or m is None:
        return 'Please provide a query image, a support image and a support mask.', blank_img, blank_img
    # When gradio replays/caches an example, the checkbox inputs arrive as None.
    is_cache_run = re_adapt is None and confirmed is None
    # An episode counts as an example when its mask equals one of the bundled
    # support masks (index [2] of each episode triple).
    is_example = any(np.array_equal(np.array(m), np.array(Image.open(e[2]))) for e in example_episodes)
    print(f'{is_example=}')

    if is_cache_run:
        reset_layers()
        pred = from_model(q, s, m)
        msg = 'Results ready.'
        return msg, *pred
    elif re_adapt:
        # The CPU confirmation checkbox only exists on CPU builds of the UI;
        # on GPU there is nothing to confirm, so proceed directly.
        if confirmed or device.type != 'cpu':
            reset_layers()
            pred = from_model(q, s, m)
            msg = "Results ready.\nRemember to untick 're-adapt' if you wish to predict more images with the same parameters."
            return msg, *pred
        else:
            msg = "You chose to re-adapt but are on CPU.\nThis may take 1 minute on your local machine or 4 minutes on huggingface space.\nSelect 'Confirm CPU run' to start."
            return msg, blank_img, blank_img
    else:
        if is_example:
            msg = "Cached results for example have been shown previously already.\nTo view it again, click the example again.\nTo run adaption again from scratch, select 're-adapt'."
            return msg, blank_img, blank_img
        elif has_fit:
            # Layers were already fitted in a previous call; reuse them.
            pred = from_model(q, s, m)
            msg = "Results predicted based on layers fitted from previous run.\nIf you wish to re-adapt, select 're-adapt'."
            return msg, *pred
        else:
            msg = "This is the first time you predict own images.\nThe attached layers need to be fitted.\nPlease select 're-adapt'."
            return msg, blank_img, blank_img
|
| 112 |
+
|
| 113 |
+
# Assemble the UI: one status textbox plus the two prediction images,
# fed by predict() and seeded with the bundled example episodes.
gradio_app = gr.Interface(
    fn=predict,
    inputs=inputs,
    outputs=[gr.Textbox(label="Status"), gr.Image(label="Coarse Query Prediction"), gr.Image(label="Mask Prediction")],
    description=description,
    examples=example_episodes,
    title="abcdfss",
)

# Blocking call; serves the app (default host/port on Spaces).
gradio_app.launch()
|
imgs/1.jpg
ADDED
|
imgs/457070_00.jpg
ADDED
|
imgs/457070_00.png
ADDED
|
imgs/5.jpg
ADDED
|
imgs/5.png
ADDED
|
imgs/549870_35.jpg
ADDED
|
imgs/CHNCXR_0282_0.png
ADDED
|
imgs/CHNCXR_0324_0.png
ADDED
|
imgs/CHNCXR_0324_0_mask.png
ADDED
|
imgs/ISIC_0000372.jpg
ADDED
|
imgs/ISIC_0013176.jpg
ADDED
|
imgs/ISIC_0013176_segmentation.png
ADDED
|
imgs/blank.png
ADDED
|
imgs/cake1.png
ADDED
|
imgs/cake2.png
ADDED
|
imgs/cake2_mask.png
ADDED
|
imgs/d_r_450_.jpg
ADDED
|
imgs/d_r_465_.bmp
ADDED
|
imgs/d_r_465_.jpg
ADDED
|
main.py
DELETED
|
@@ -1,37 +0,0 @@
|
|
| 1 |
-
from core import runner
|
| 2 |
-
import torch
|
| 3 |
-
import argparse
|
| 4 |
-
|
| 5 |
-
def parse_opts():
    r"""Parse command-line arguments for the cross-domain FSS evaluation run.

    --benchmark: dataset to evaluate on (default 'lung')
    --datapath:  root directory of that dataset
    --nshot:     number of support shots per episode (default 1)
    """
    parser = argparse.ArgumentParser(description='Adapt Before Comparison - A New Perspective on Cross-Domain Few-Shot Segmentation')

    # common
    # Fix: the choices list contained duplicates ('fss', 'lung' twice) and
    # omitted 'suim', which the README lists as an available benchmark.
    parser.add_argument('--benchmark', type=str, default='lung', choices=['fss', 'deepglobe', 'lung', 'isic', 'suim'])
    parser.add_argument('--datapath', type=str)
    parser.add_argument('--nshot', type=int, default=1)

    args = parser.parse_args()
    return args
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
if __name__ == '__main__':
    # Evaluate the method over one whole benchmark dataset and report the
    # final mean foreground/background IoU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    args = parse_opts()
    print(args)
    # The runner reads its options from a module-level args object; forward
    # the CLI values onto it before building any components.
    runner.args.benchmark = args.benchmark
    runner.args.datapath = args.datapath
    runner.args.nshot = args.nshot

    dataloader = runner.makeDataloader()
    config = runner.makeConfig()
    feat_maker = runner.makeFeatureMaker(dataloader.dataset, config, device=device)
    average_meter = runner.AverageMeterWrapper(dataloader, device)

    # One episode per batch: forward, score, and fold into the running metric.
    for idx, batch in enumerate(dataloader):
        sseval = runner.SingleSampleEval(batch, feat_maker)
        sseval.forward()
        sseval.calc_metrics()
        average_meter.update(sseval)
        average_meter.write(idx)
    print('Result m|FB:', average_meter.average_meter.compute_iou())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
requirements.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
torchvision
|
| 3 |
+
opencv-python
|
| 4 |
+
numpy
|
| 5 |
+
tensorboardX
|