File size: 11,804 Bytes
87c1bd1
b866603
6056c5f
 
6c1a1d5
6056c5f
 
d3fa5af
d7d8a78
af636b2
b2a4f5f
 
3b1f1d6
42b354f
d8264ca
71a649c
62db64f
71a649c
6056c5f
 
383d86e
dd76ffb
6056c5f
 
 
 
9d42a5c
100ae21
 
 
3df51c4
fa9a418
100ae21
36204dd
02710b1
d5b0ece
59f9e7b
5cab366
2d1e815
1ea08ed
2d1e815
 
 
1ea08ed
da8cd1e
100ae21
29437d1
 
 
aee8e3e
 
 
 
 
 
 
 
29437d1
 
aee8e3e
 
 
 
 
29437d1
 
 
 
 
 
 
 
 
 
 
 
 
 
eb9155a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b2a4f5f
d3fa5af
d86b569
9d42a5c
472e711
6056c5f
d86b569
6056c5f
 
 
 
 
 
d86b569
09d07ec
6056c5f
d86b569
a335d23
6056c5f
 
 
 
 
 
 
 
 
b02f90e
dd76ffb
af636b2
9bcbd3d
 
cd5e4d5
09d07ec
 
602b01b
 
78142e8
9d42a5c
e8eee0c
697cdd7
6056c5f
 
 
 
 
aaf297a
963faaf
aaf297a
 
 
b406026
 
3df51c4
6056c5f
209dd06
6056c5f
 
 
 
 
 
 
 
 
 
 
 
 
09d07ec
 
 
6056c5f
 
d9dc8f6
09d07ec
d9dc8f6
6056c5f
 
05af878
ee6eede
 
 
100ae21
 
ee6eede
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100ae21
7f70460
c08de50
 
 
 
 
 
 
7f06d9b
 
c08de50
 
 
7f70460
 
 
 
 
 
 
 
 
 
 
3df51c4
7f70460
 
 
eb9155a
7f70460
 
 
 
 
 
eb9155a
 
 
 
 
 
3df51c4
eb9155a
 
 
 
 
 
 
 
 
7f70460
 
 
 
 
 
91a4348
7f70460
 
100ae21
7f70460
100ae21
eb9155a
100ae21
eb9155a
100ae21
 
6056c5f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
import tqdm
#import fastCNN
import numpy as np


import gradio as gr
import os
# Runtime dependency bootstrap (typical for a Hugging Face Space): install
# packages at startup instead of via requirements.txt.
# NOTE(review): "pip3 install collections" is wrong — `collections` is a
# stdlib module, not a PyPI package; the call is a harmless no-op at best.
#os.system("sudo apt-get install nvIDia-cuda-toolkit")
os.system("pip3 install torch")
#os.system("/usr/local/bin/python -m pip install --upgrade pip")
os.system("pip3 install collections")
os.system("pip3 install torchvision")
os.system("pip3 install einops")
os.system("pip3 install opencv-python")
aaaa=0
#os.system("pip3 install pydensecrf")
#os.system("pip install argparse")
#import pydensecrf.densecrf as dcrf
from PIL import Image
import torch
import cv2
import torch.nn.functional as F
from torchvision import transforms
from model_video import build_model
import numpy as np
import collections
def show_coord(evt: gr.SelectData):
    """Format the click position of a Gradio select event as an "x,y" string."""
    x, y = evt.index[0], evt.index[1]
    return f"{x},{y}"

def generate_mask(model_type, img, coord):
    """Segment a single image with the co-saliency model and overlay the mask.

    The image is replicated five times to satisfy the group-of-5 input that
    `sepia` expects. `coord` (the click position, "x,y") is accepted for the
    UI wiring but is currently unused by the underlying model.

    Args:
        model_type: model-selector value, forwarded to `sepia`.
        img: HxWx3 numpy image.
        coord: "x,y" string from the click event (unused).

    Returns:
        HxWx3 uint8 numpy image: the input blended with the predicted mask
        (mask tinted by zeroing its red channel).
    """
    # Compute the uint8 cast once instead of five times.
    img_u8 = (img * 0.999999).astype(np.uint8)
    mask = sepia(model_type, img_u8, img_u8, img_u8, img_u8, img_u8, stack_image=False)
    # Resize the predicted mask back to the original image resolution.
    mask = F.interpolate(
        torch.from_numpy(mask).unsqueeze(0).unsqueeze(0),
        size=[img.shape[0], img.shape[1]],
        mode='bilinear',
    ).squeeze().numpy()
    # Broadcast the single-channel mask to 3 channels, normalized to [0,1].
    # (The original computed this identical tensor twice as `col` and
    # `mask_torch`; now it is computed once and cloned.)
    mask_torch = torch.from_numpy(mask).squeeze().unsqueeze(2).repeat(1, 1, 3)
    mask_torch = mask_torch / mask_torch.max()
    img = img / img.max() * 255
    # Tint color: the mask scaled to [0,255] with the red channel zeroed.
    col = mask_torch.clone() * 255
    col[:, :, 0] = 0
    # Blend: background unchanged, foreground half image + half tint.
    mix = (1 - mask_torch) * img + mask_torch * img * 0.5 + mask_torch * col * 0.5
    return mix.numpy().astype(np.uint8)

def create_mode2_interface():
    """Build the standalone point-prompt interface (not used by the main demo).

    Fixes vs. the original: the click handler was registered twice (two
    identical `@img_input.select` bindings), and `generate_mask` was bound
    with only (img, coord) although it requires a leading `model_type`
    argument — a guaranteed TypeError on click.

    Returns:
        The assembled `gr.Blocks` instance.
    """
    with gr.Blocks() as mode2:
        with gr.Column():
            img_input = gr.Image(
                type="numpy",
                sources=["upload"],
                label="点击上传图片并选择点",
                interactive=True
            )

            # Hidden textbox holding the last clicked "x,y" coordinate.
            coord_store = gr.Textbox(visible=False)

            # Single click-event binding.
            @img_input.select(inputs=[], outputs=coord_store)
            def store_coordinate(evt: gr.SelectData):
                return f"{evt.index[0]},{evt.index[1]}"

            btn = gr.Button("生成分割掩码")
            mask_output = gr.Image(label="分割结果")

        btn.click(
            # Supply the default model name generate_mask requires; this
            # standalone interface has no model selector of its own.
            lambda img, coord: generate_mask("CoSOD-SAM", img, coord),
            inputs=[img_input, coord_store],
            outputs=mask_output
        )
    return mode2

def create_mode3_interface():
    """Build the standalone box-prompt interface (not used by the main demo).

    Fixes vs. the original: the click handler was registered twice (two
    identical `@img_input.select` bindings), and `generate_mask` was bound
    with only (img, coord) although it requires a leading `model_type`
    argument — a guaranteed TypeError on click.

    Returns:
        The assembled `gr.Blocks` instance.
    """
    with gr.Blocks() as mode3:
        with gr.Column():
            img_input = gr.Image(
                type="numpy",
                sources=["upload"],
                label="点击上传图片并选择框",
                interactive=True
            )

            # Hidden textbox holding the last clicked "x,y" coordinate.
            coord_store = gr.Textbox(visible=False)

            # Single click-event binding.
            @img_input.select(inputs=[], outputs=coord_store)
            def store_coordinate(evt: gr.SelectData):
                return f"{evt.index[0]},{evt.index[1]}"

            btn = gr.Button("生成分割掩码")
            mask_output = gr.Image(label="分割结果")

        btn.click(
            # Supply the default model name generate_mask requires; this
            # standalone interface has no model selector of its own.
            lambda img, coord: generate_mask("CoSOD-SAM", img, coord),
            inputs=[img_input, coord_store],
            outputs=mask_output
        )
    return mode3
#import argparse
# Build the segmentation network on CPU and load the checkpoint weights.
device='cpu'
net = build_model(device).to(device)
#net=torch.nn.DataParallel(net)
model_path = 'image_best.pth'
print(model_path)
weight=torch.load(model_path,map_location=torch.device(device))
#print(type(weight))
# The checkpoint was saved from a DataParallel-wrapped model, so every key
# carries a 'module.' prefix; strip it so the bare model accepts the dict.
new_dict=collections.OrderedDict()
for k in weight.keys():
  new_dict[k[len('module.'):]]=weight[k]
net.load_state_dict(new_dict)
net.eval()
net = net.to(device)
def test(gpu_id, net, img_list, group_size, img_size, stack_image=True):
    """Run the co-saliency network on one group of images.

    Args:
        gpu_id: unused; kept for backward compatibility (inference runs on
            whatever device `net` lives on).
        net: segmentation network; called as `net(batch)` and expected to
            return a `(_, pred_mask)` tuple.
        img_list: list of HxWx3 uint8 numpy images forming one group.
        group_size: unused; the group size is taken from `len(img_list)`.
        img_size: side length images are resized to before inference.
            (The original ignored this and hard-coded 224; now honored.)
        stack_image: if False, return only the first predicted mask;
            otherwise return one image-over-mask stack per input.

    Returns:
        A single img_size x img_size uint8 mask (stack_image=False), or a
        list of (2*img_size+2) x img_size x 3 arrays (stack_image=True).
    """
    print('test')
    n = len(img_list)
    img_transform = transforms.Compose([
        transforms.Resize((img_size, img_size)), transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
    with torch.no_grad():
        # Batch the resized, normalized inputs.
        group_img = torch.rand(n, 3, img_size, img_size)
        for i in range(n):
            group_img[i] = img_transform(Image.fromarray(img_list[i]))
        _, pred_mask = net(group_img * 1)
        pred_mask = pred_mask.detach().squeeze() * 255
        # Min-max renormalize each input back to a displayable uint8 image.
        img_resize = [((group_img[i] - group_img[i].min()) /
                       (group_img[i].max() - group_img[i].min()) * 255)
                      .permute(1, 2, 0).contiguous().numpy().astype(np.uint8)
                      for i in range(n)]
        pred_mask = [pred_mask[i].numpy().astype(np.uint8) for i in range(n)]
        if not stack_image:
            return pred_mask[0]
        print(pred_mask[0].shape)
        # 2-pixel white horizontal separator between each image and its mask.
        white = (torch.ones(2, pred_mask[0].shape[1], 3) * 255).long()
        result = [torch.cat([torch.from_numpy(img_resize[i]), white,
                             torch.from_numpy(pred_mask[i]).unsqueeze(2).repeat(1, 1, 3)],
                            dim=0).numpy() for i in range(n)]
        print('done')
        return result
        
# Startup smoke test: run inference once on five random 352x352 images to
# verify the model loads and the forward pass works before serving the UI.
img_lst=[(torch.rand(352,352,3)*255).numpy().astype(np.uint8) for i in range(5)]

# simple smoke test
res=test('cpu',net,img_lst,5,224)
'''for i in range(5):
    assert res[i].shape[0]==352 and res[i].shape[1]==352 and res[i].shape[2]==3'''
def sepia(model_type, img1, img2, img3, img4, img5, stack_image=True):
    """Group-segment five images with the module-level co-saliency model.

    Args:
        model_type: selected model name from the UI dropdown; currently
            unused — inference always runs on the module-level `net`.
        img1, img2, img3, img4, img5: HxWx3 numpy images forming one group.
        stack_image: forwarded to `test`; if False, return only the first
            predicted mask (used by `generate_mask`).

    Returns:
        The first raw mask (stack_image=False), or one wide uint8-like image
        with all five image/mask stacks side by side, separated by 2-pixel
        white columns.
    """
    print('sepia')
    print(img1.shape, img2.shape, img3.shape, img4.shape, img5.shape)
    img_list = [img1, img2, img3, img4, img5]
    result_list = test(device, net, img_list, 5, 224, stack_image)
    if not stack_image:
        return result_list
    img1, img2, img3, img4, img5 = result_list
    # 2-pixel vertical white separator between adjacent results.
    white = (torch.ones(img1.shape[0], 2, 3) * 255).numpy().astype(np.uint8)
    return np.concatenate([img1, white, img2, white, img3, white, img4, white, img5], axis=1)

#gr.Image(shape=(224, 2))
#demo = gr.Interface(sepia, inputs=["image","image","image","image","image"], outputs=["image","image","image","image","image"])#gr.Interface(sepia, gr.Image(shape=(200, 200)), "image")
#demo = gr.Interface(sepia, inputs=["image","image","image","image","image"], outputs=["image"])
#demo.launch(debug=True)
#replace Interface with Blocks
def create_mode1_interface():
    """Build the standalone five-image interface (not used by the main demo).

    Fixes vs. the original: the function never returned the Blocks it built
    (callers received None), and `sepia` was bound with only the five image
    inputs although it requires a leading `model_type` argument — a
    guaranteed TypeError on click.

    Returns:
        The assembled `gr.Blocks` instance.
    """
    with gr.Blocks() as demo:
        with gr.Row():
            # Five upload slots laid out in a row, one per group member.
            with gr.Column(scale=1, min_width=150):
                input1 = gr.Image(label="image1", type="numpy")
            with gr.Column(scale=1, min_width=150):
                input2 = gr.Image(label="image2", type="numpy")
            with gr.Column(scale=1, min_width=150):
                input3 = gr.Image(label="image3", type="numpy")
            with gr.Column(scale=1, min_width=150):
                input4 = gr.Image(label="image4", type="numpy")
            with gr.Column(scale=1, min_width=150):
                input5 = gr.Image(label="image5", type="numpy")

        btn = gr.Button("start processing")

        with gr.Row():
            output = gr.Image(label="output", type="numpy")

        btn.click(
            # Supply the default model name sepia requires; this standalone
            # interface has no model selector of its own.
            fn=lambda i1, i2, i3, i4, i5: sepia("CoSOD-SAM", i1, i2, i3, i4, i5),
            inputs=[input1, input2, input3, input4, input5],
            outputs=output
        )
    return demo

# Main application: one Blocks app hosting all three interaction modes.
with gr.Blocks(title="交互式图像组分割系统") as demo:
    # Mode selector row: run-mode radio + model dropdown.
    with gr.Row():
        mode = gr.Radio(
            ["多图协同分割", "点提示交互分割","框提示交互分割"], 
            value="多图协同分割",
            label="运行模式"
        )
        model_selector = gr.Dropdown(
            choices=["RepViT-SAM", "EdgeSAM", "CoSOD-SAM"],
            value="CoSOD-SAM",
            label="选择模型",
            container=False  # drop the default container border
        )
    # Tabs host the three modes inside one Blocks instead of separate apps.
    with gr.Tabs() as mode_container:
        with gr.Tab("多图模式", id=0) as tab1:
            # Mode 1: five-image co-segmentation.
            with gr.Row():
                inputs = [gr.Image(type="numpy", label=f"图像{i+1}") for i in range(5)]
            process_btn = gr.Button("开始处理")
            output_img = gr.Image(label="处理结果")
            
            process_btn.click(
                sepia,
                inputs=[model_selector]+inputs,
                outputs=output_img
            )
            
        with gr.Tab("点选交互模式", id=1) as tab2:
            # Mode 2: single image + click-point prompt.
            img_input = gr.Image(type="numpy", label="点击上传图片并选择点")
            coord_store = gr.Textbox(visible=False)
            mask_btn = gr.Button("生成分割掩码")
            mask_output = gr.Image(label="分割结果")
            
            # Store the "x,y" click position in the hidden textbox.
            @img_input.select(inputs=[], outputs=coord_store)
            def store_coordinate(evt: gr.SelectData):
                return f"{evt.index[0]},{evt.index[1]}"
            
            mask_btn.click(
                generate_mask,
                inputs=[model_selector,img_input, coord_store],
                outputs=mask_output
            )
        with gr.Tab("框选交互模式", id=2) as tab3:
            # Mode 3: single image + box prompt (same handler as mode 2).
            img_input = gr.Image(type="numpy", label="点击上传图片并选择框")
            coord_store = gr.Textbox(visible=False)
            mask_btn = gr.Button("生成分割掩码")
            mask_output = gr.Image(label="分割结果")
            
            @img_input.select(inputs=[], outputs=coord_store)
            def store_coordinate(evt: gr.SelectData):
                return f"{evt.index[0]},{evt.index[1]}"
            
            mask_btn.click(
                generate_mask,
                inputs=[model_selector, img_input, coord_store],
                outputs=mask_output
            )
    
    # Show/hide the mode tabs when the radio selection changes.
    # NOTE(review): gr.update(visible=...) on a Tab may not hide it in every
    # gradio version — confirm against the installed gradio release.
    mode.change(
        lambda x: (gr.update(visible=x=="多图协同分割"), gr.update(visible=x=="点提示交互分割"), gr.update(visible=x=="框提示交互分割")),
        inputs=mode,
        outputs=[tab1, tab2, tab3]
    )

demo.launch(debug=True)