lime-j commited on
Commit
7732686
·
1 Parent(s): 15a930e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -0
app.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import gradio as gr
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from models.arch.revcol_sig import FullNet,FullNet_NLP
7
+ from models.arch.classifier import PretrainedConvNext
8
+ import torchvision.transforms.functional as TF
9
+ channels = [64, 128, 256, 512]
10
+ layers = [2, 2, 4, 2]
11
+ num_subnet = 4
12
+ net_i = FullNet_NLP(channels, layers, num_subnet, 4,num_classes=1000, drop_path=0,save_memory=True, inter_supv=True, head_init_scale=None,kernel_size=3).cpu()
13
+ net_i.load_state_dict(torch.load('./merge_stem_reg_014_00055524.pt')['icnn'])
14
+ net_c = PretrainedConvNext("convnext_small_in22k").cpu()
15
+ net_c.load_state_dict(torch.load('./cls_newdis_058_00014384.pt')['icnn'])
16
+ net_i.eval()
17
+ net_c.eval()
18
+ def align(x1):
19
+ x2 = x1
20
+ h, w = x1.height, x1.width
21
+ h, w = h // 32 * 32, w // 32 * 32
22
+ x1 = x1.resize((w, h))
23
+ x2 = x2.resize((w, h))
24
+ return x1
25
+ def predict(img):
26
+ with torch.no_grad():
27
+ img=align(img)
28
+ image_tensor=TF.to_tensor(img).cpu()
29
+ image_tensor=image_tensor.unsqueeze(0).cuda()
30
+ ipt=net_c(image_tensor)
31
+ output_i, output_j=net_i(image_tensor,ipt,prompt=True)
32
+ output_j_out=[]
33
+ for i in range(4):
34
+ out_reflection, out_clean = output_j[i][:, :3, ...], output_j[i][:, 3:, ...]
35
+ output_j_out.append(out_clean)
36
+ output_j_out.append(out_reflection)
37
+ clean = output_j_out[6]
38
+ clean=torch.clamp(clean, 0, 1)
39
+ return clean
40
+ demo=gr.Interface(predict, gr.Image(), "image")
41
+
42
+ demo.launch()