Create app.py
Browse files
app.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import gradio as gr
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from models.arch.revcol_sig import FullNet,FullNet_NLP
|
| 7 |
+
from models.arch.classifier import PretrainedConvNext
|
| 8 |
+
import torchvision.transforms.functional as TF
|
| 9 |
+
channels = [64, 128, 256, 512]
|
| 10 |
+
layers = [2, 2, 4, 2]
|
| 11 |
+
num_subnet = 4
|
| 12 |
+
net_i = FullNet_NLP(channels, layers, num_subnet, 4,num_classes=1000, drop_path=0,save_memory=True, inter_supv=True, head_init_scale=None,kernel_size=3).cpu()
|
| 13 |
+
net_i.load_state_dict(torch.load('./merge_stem_reg_014_00055524.pt')['icnn'])
|
| 14 |
+
net_c = PretrainedConvNext("convnext_small_in22k").cpu()
|
| 15 |
+
net_c.load_state_dict(torch.load('./cls_newdis_058_00014384.pt')['icnn'])
|
| 16 |
+
net_i.eval()
|
| 17 |
+
net_c.eval()
|
| 18 |
+
def align(x1):
|
| 19 |
+
x2 = x1
|
| 20 |
+
h, w = x1.height, x1.width
|
| 21 |
+
h, w = h // 32 * 32, w // 32 * 32
|
| 22 |
+
x1 = x1.resize((w, h))
|
| 23 |
+
x2 = x2.resize((w, h))
|
| 24 |
+
return x1
|
| 25 |
+
def predict(img):
|
| 26 |
+
with torch.no_grad():
|
| 27 |
+
img=align(img)
|
| 28 |
+
image_tensor=TF.to_tensor(img).cpu()
|
| 29 |
+
image_tensor=image_tensor.unsqueeze(0).cuda()
|
| 30 |
+
ipt=net_c(image_tensor)
|
| 31 |
+
output_i, output_j=net_i(image_tensor,ipt,prompt=True)
|
| 32 |
+
output_j_out=[]
|
| 33 |
+
for i in range(4):
|
| 34 |
+
out_reflection, out_clean = output_j[i][:, :3, ...], output_j[i][:, 3:, ...]
|
| 35 |
+
output_j_out.append(out_clean)
|
| 36 |
+
output_j_out.append(out_reflection)
|
| 37 |
+
clean = output_j_out[6]
|
| 38 |
+
clean=torch.clamp(clean, 0, 1)
|
| 39 |
+
return clean
|
| 40 |
+
demo=gr.Interface(predict, gr.Image(), "image")
|
| 41 |
+
|
| 42 |
+
demo.launch()
|