# Hugging Face Space app: interactive co-saliency image-group segmentation demo.
# (Web-page metadata from the Space listing was removed from this header.)
# Runtime dependency bootstrap for the hosted (Hugging Face Spaces) environment:
# heavy packages are installed via pip at process startup instead of requirements.txt.
import tqdm
#import fastCNN
import numpy as np
import gradio as gr
import os
#os.system("sudo apt-get install nvIDia-cuda-toolkit")
os.system("pip3 install torch")
#os.system("/usr/local/bin/python -m pip install --upgrade pip")
# NOTE(review): `collections` is part of the Python standard library; this pip
# install is a mistake / no-op and can likely be dropped — confirm.
os.system("pip3 install collections")
os.system("pip3 install torchvision")
os.system("pip3 install einops")
os.system("pip3 install opencv-python")
aaaa=0  # unused placeholder; never read anywhere in this file
#os.system("pip3 install pydensecrf")
#os.system("pip install argparse")
#import pydensecrf.densecrf as dcrf
from PIL import Image
import torch
import cv2
import torch.nn.functional as F
from torchvision import transforms
from model_video import build_model  # project-local model factory
import numpy as np  # NOTE(review): duplicate of the numpy import above
import collections
def show_coord(evt: gr.SelectData):
    """Format the clicked pixel position of a Gradio select event as "x,y"."""
    x_pos, y_pos = evt.index[0], evt.index[1]
    return f"{x_pos},{y_pos}"
def generate_mask(model_type, img, coord):
    """Segment a single image and return it blended with a colored mask overlay.

    Args:
        model_type: selected model name; passed through to sepia() (unused there).
        img: HxWx3 numpy image from the Gradio widget.
        coord: "x,y" click-coordinate string (currently unused — the point
            prompt is not wired into the network yet).

    Returns:
        HxWx3 uint8 numpy image: input blended 50/50 with a tinted mask.
    """
    # The network consumes a group of five images, so feed five copies of the
    # same frame; *0.999999 forces a float copy before the uint8 cast.
    frame = (img * 0.999999).astype(np.uint8)
    mask = sepia(model_type, frame, frame, frame, frame, frame, stack_image=False)
    # Upsample the low-resolution mask back to the original image size.
    mask = F.interpolate(
        torch.from_numpy(mask).unsqueeze(0).unsqueeze(0),
        size=[img.shape[0], img.shape[1]],
        mode='bilinear',
    ).squeeze().numpy()
    # Build the normalized 3-channel mask once and clone it for the tint
    # (the original computed the identical tensor twice).
    mask_torch = torch.from_numpy(mask).squeeze().unsqueeze(2).repeat(1, 1, 3)
    mask_torch = mask_torch / mask_torch.max()
    col = mask_torch.clone()
    # NOTE(review): divides by img.max() / mask max — an all-zero image or mask
    # would produce NaNs; confirm inputs are never fully black.
    img = img / img.max() * 255
    col = col * 255
    col[:, :, 0] = 0  # zero the red channel so the overlay reads as a cyan tint
    mix = (1 - mask_torch) * img + mask_torch * img * 0.5 + mask_torch * col * 0.5
    return mix.numpy().astype(np.uint8)
def create_mode2_interface():
    """Build the point-prompt interactive segmentation UI (mode 2).

    Returns:
        gr.Blocks: standalone interface — click the uploaded image to pick a
        point, then press the button to run segmentation.
    """
    with gr.Blocks() as mode2:
        with gr.Column():
            img_input = gr.Image(
                type="numpy",
                sources=["upload"],
                label="点击上传图片并选择点",
                interactive=True
            )
            # Hidden stores: the clicked coordinate, plus a default model name
            # so the click call matches generate_mask(model_type, img, coord) —
            # previously only two inputs were passed to the 3-parameter handler.
            coord_store = gr.Textbox(visible=False)
            model_store = gr.Textbox(value="CoSOD-SAM", visible=False)

            # Single select handler. The original bound two handlers; the
            # duplicate used inputs=img_input, which mismatches the 1-argument
            # callback and raised on every click.
            @img_input.select(inputs=[], outputs=coord_store)
            def capture_coordinates(evt: gr.SelectData):
                return f"{evt.index[0]},{evt.index[1]}"

            btn = gr.Button("生成分割掩码")
            mask_output = gr.Image(label="分割结果")
            btn.click(
                generate_mask,
                inputs=[model_store, img_input, coord_store],
                outputs=mask_output
            )
    return mode2
def create_mode3_interface():
    """Build the box-prompt interactive segmentation UI (mode 3).

    Returns:
        gr.Blocks: standalone interface — click the uploaded image to choose a
        box anchor, then press the button to run segmentation.
    """
    with gr.Blocks() as mode3:
        with gr.Column():
            img_input = gr.Image(
                type="numpy",
                sources=["upload"],
                label="点击上传图片并选择框",
                interactive=True
            )
            # Hidden stores: the clicked coordinate, plus a default model name
            # so the click call matches generate_mask(model_type, img, coord) —
            # previously only two inputs were passed to the 3-parameter handler.
            coord_store = gr.Textbox(visible=False)
            model_store = gr.Textbox(value="CoSOD-SAM", visible=False)

            # Single select handler. The original bound two handlers; the
            # duplicate used inputs=img_input, which mismatches the 1-argument
            # callback and raised on every click.
            @img_input.select(inputs=[], outputs=coord_store)
            def capture_coordinates(evt: gr.SelectData):
                return f"{evt.index[0]},{evt.index[1]}"

            btn = gr.Button("生成分割掩码")
            mask_output = gr.Image(label="分割结果")
            btn.click(
                generate_mask,
                inputs=[model_store, img_input, coord_store],
                outputs=mask_output
            )
    return mode3
#import argparse
# Model setup: build the co-saliency network on CPU and load checkpoint
# weights that were saved from a DataParallel-wrapped model.
device='cpu'
net = build_model(device).to(device)
#net=torch.nn.DataParallel(net)
model_path = 'image_best.pth'
print(model_path)
weight=torch.load(model_path,map_location=torch.device(device))
#print(type(weight))
# Strip the 'module.' prefix that torch.nn.DataParallel prepends to every
# parameter name so the keys match the plain (unwrapped) model.
# NOTE(review): assumes every key starts with 'module.'; keys without the
# prefix would be silently truncated — confirm against the checkpoint.
new_dict=collections.OrderedDict()
for k in weight.keys():
    new_dict[k[len('module.'):]]=weight[k]
net.load_state_dict(new_dict)
net.eval()
net = net.to(device)
def test(gpu_id, net, img_list, group_size, img_size,stack_image=True):
    """Run the co-saliency network on a group of five images.

    Args:
        gpu_id: unused; kept for caller compatibility.
        net: segmentation network; called as net(batch) -> (_, masks).
        img_list: list of five HxWx3 uint8 numpy images.
        group_size: unused; kept for caller compatibility (batch is fixed at 5).
        img_size: side length each image is resized to before inference
            (224 is expected by the fixed batch tensor below).
        stack_image: if False, return only the first predicted mask; otherwise
            return per-image stacks of input + separator + mask.

    Returns:
        np.ndarray (first 2-D mask) when stack_image is False, else a list of
        five arrays, each the resized input stacked vertically above its mask
        with a 2-pixel white separator.
    """
    print('test')
    # ImageNet-style normalization applied to every input image.
    img_transform = transforms.Compose([transforms.Resize((img_size, img_size)), transforms.ToTensor(),
                                        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
    with torch.no_grad():
        # Fixed batch of five RGB 224x224 images; placeholder values are
        # overwritten in the loop below.
        group_img = torch.rand(5, 3, 224, 224)
        for i in range(5):
            group_img[i] = img_transform(Image.fromarray(img_list[i]))
        _, pred_mask = net(group_img * 1)  # *1 forces a fresh tensor copy
        pred_mask = (pred_mask.detach().squeeze() * 255)
        # Un-normalize each input back to a displayable uint8 HxWx3 array.
        img_resize = [((group_img[i] - group_img[i].min()) / (group_img[i].max() - group_img[i].min()) * 255)
                      .permute(1, 2, 0).contiguous().numpy().astype(np.uint8)
                      for i in range(5)]
        pred_mask = [(pred_mask[i].numpy().astype(np.uint8)) for i in range(5)]
        if not stack_image:
            return pred_mask[0]
        print(pred_mask[0].shape)
        # 2-pixel-tall white horizontal separator between image and mask.
        white = (torch.ones(2, pred_mask[0].shape[1], 3) * 255).long()
        result = [torch.cat([torch.from_numpy(img_resize[i]), white,
                             torch.from_numpy(pred_mask[i]).unsqueeze(2).repeat(1, 1, 3)], dim=0).numpy()
                  for i in range(5)]
        print('done')
        return result
img_lst=[(torch.rand(352,352,3)*255).numpy().astype(np.uint8) for i in range(5)]
#simly test
res=test('cpu',net,img_lst,5,224)
'''for i in range(5):
assert res[i].shape[0]==352 and res[i].shape[1]==352 and res[i].shape[2]==3'''
def sepia(model_type, img1, img2, img3, img4, img5, stack_image=True):
    """Co-segment a group of five images with the module-level network.

    Args:
        model_type: selected model name (currently unused; reserved for
            switching between backbones).
        img1..img5: HxWx3 uint8 numpy images.
        stack_image: passed through to test(); if False, return only the first
            predicted mask instead of the side-by-side panorama.

    Returns:
        np.ndarray: the first 2-D mask when stack_image is False, otherwise
        the five image+mask stacks concatenated horizontally with 2-pixel
        white separators.
    """
    print('sepia')
    print(img1.shape, img2.shape, img3.shape, img4.shape, img5.shape)
    img_list = [img1, img2, img3, img4, img5]
    result_list = test(device, net, img_list, 5, 224, stack_image)
    if not stack_image:
        return result_list
    # Each result is a (input above mask) stack at network resolution.
    img1, img2, img3, img4, img5 = result_list
    # 2-pixel-wide white vertical separator between adjacent stacks.
    white = (torch.ones(img1.shape[0], 2, 3) * 255).numpy().astype(np.uint8)
    return np.concatenate([img1, white, img2, white, img3, white, img4, white, img5], axis=1)
#gr.Image(shape=(224, 2))
#demo = gr.Interface(sepia, inputs=["image","image","image","image","image"], outputs=["image","image","image","image","image"])#gr.Interface(sepia, gr.Image(shape=(200, 200)), "image")
#demo = gr.Interface(sepia, inputs=["image","image","image","image","image"], outputs=["image"])
#demo.launch(debug=True)
#replace Interface with Blocks
def create_mode1_interface():
    """Build the five-image co-segmentation UI (mode 1).

    Returns:
        gr.Blocks: the assembled interface. (The original built the UI but
        fell off the end and returned None.)
    """
    with gr.Blocks() as demo:
        with gr.Row():
            # Five equal-width upload slots.
            image_inputs = []
            for idx in range(5):
                with gr.Column(scale=1, min_width=150):
                    image_inputs.append(gr.Image(label=f"image{idx + 1}", type="numpy"))
        # Hidden default model name so the click call matches
        # sepia(model_type, img1..img5, ...). Previously only the five images
        # were passed, shifting every argument of the 6-parameter handler.
        model_store = gr.Textbox(value="CoSOD-SAM", visible=False)
        btn = gr.Button("start processing")
        with gr.Row():
            output = gr.Image(label="output", type="numpy")
        btn.click(
            fn=sepia,
            inputs=[model_store] + image_inputs,
            outputs=output
        )
    return demo
# Main application: one Blocks UI with a mode selector and three tabs
# (multi-image co-segmentation, point-prompt, box-prompt).
with gr.Blocks(title="交互式图像组分割系统") as demo:
    # Mode selector row: run mode radio + model dropdown shared by all tabs.
    with gr.Row():
        mode = gr.Radio(
            ["多图协同分割", "点提示交互分割","框提示交互分割"],
            value="多图协同分割",
            label="运行模式"
        )
        model_selector = gr.Dropdown(
            choices=["RepViT-SAM", "EdgeSAM", "CoSOD-SAM"],
            value="CoSOD-SAM",
            label="选择模型",
            container=False # drop the default container border
        )
    # Tab container instead of separate Blocks instances per mode.
    with gr.Tabs() as mode_container:
        with gr.Tab("多图模式", id=0) as tab1:
            # Mode 1: five uploads -> stacked panorama via sepia().
            with gr.Row():
                inputs = [gr.Image(type="numpy", label=f"图像{i+1}") for i in range(5)]
            process_btn = gr.Button("开始处理")
            output_img = gr.Image(label="处理结果")
            process_btn.click(
                sepia,
                inputs=[model_selector]+inputs,
                outputs=output_img
            )
        with gr.Tab("点选交互模式", id=1) as tab2:
            # Mode 2: click a point on one image, then generate its mask.
            img_input = gr.Image(type="numpy", label="点击上传图片并选择点")
            coord_store = gr.Textbox(visible=False)
            mask_btn = gr.Button("生成分割掩码")
            mask_output = gr.Image(label="分割结果")
            # Store the clicked "x,y" coordinate in the hidden textbox.
            @img_input.select(inputs=[], outputs=coord_store)
            def store_coordinate(evt: gr.SelectData):
                return f"{evt.index[0]},{evt.index[1]}"
            mask_btn.click(
                generate_mask,
                inputs=[model_selector,img_input, coord_store],
                outputs=mask_output
            )
        with gr.Tab("框选交互模式", id=2) as tab3:
            # Mode 3: same wiring as mode 2, but for a box prompt.
            img_input = gr.Image(type="numpy", label="点击上传图片并选择框")
            coord_store = gr.Textbox(visible=False)
            mask_btn = gr.Button("生成分割掩码")
            mask_output = gr.Image(label="分割结果")
            # Store the clicked "x,y" coordinate in the hidden textbox.
            @img_input.select(inputs=[], outputs=coord_store)
            def store_coordinate(evt: gr.SelectData):
                return f"{evt.index[0]},{evt.index[1]}"
            mask_btn.click(
                generate_mask,
                inputs=[model_selector, img_input, coord_store],
                outputs=mask_output
            )
    # Dynamic visibility: show only the tab matching the selected mode.
    mode.change(
        lambda x: (gr.update(visible=x=="多图协同分割"), gr.update(visible=x=="点提示交互分割"), gr.update(visible=x=="框提示交互分割")),
        inputs=mode,
        outputs=[tab1, tab2, tab3]
    )
demo.launch(debug=True)