File size: 1,954 Bytes
27a537a
7732686
 
 
 
 
95a65f7
7732686
 
27a537a
b3787da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2d50bde
 
 
b3787da
 
d66e84e
 
b3787da
fe83596
15064f3
b3787da
 
f6bdb80
 
 
db76fe1
15064f3
 
48fac1d
7732686
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import spaces
import numpy as np
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.arch.RDnet_ import FullNet_NLP
from models.arch.classifier import PretrainedConvNext
import torchvision.transforms.functional as TF

class Pipe:
    def __init__(self):
        channels = [64, 128, 256, 512]
        layers = [2, 2, 4, 2]
        num_subnet = 4
        self.net_i = FullNet_NLP(channels, layers, num_subnet, 4,num_classes=1000, drop_path=0,save_memory=True, inter_supv=True, head_init_scale=None,kernel_size=3)
        for param in self.net_i.parameters():
            param.data = param.data.to(torch.float16)
        self.net_i.load_state_dict(torch.load('./fp16_check.pt')['icnn'])
        self.net_i = self.net_i.to('cpu')
        self.net_c = PretrainedConvNext("convnext_small_in22k")
        self.net_c.load_state_dict(torch.load('./classifier_32.pt')['icnn'])
        self.net_c=self.net_c.to('cpu')
        #net_c=net_c.to('cuda')
        self.net_i.eval().to('cuda')
        self.net_c.eval().to('cuda')
        self.output = None

    def __call__(self, img):
        with torch.no_grad():
            image_tensor = torch.from_numpy(img).permute(2, 0, 1).float().unsqueeze(0)
            h, w = image_tensor.shape[-2], image_tensor.shape[-1]
            h, w = h // 32 * 32, w // 32 * 32
            image_tensor = torch.nn.functional.interpolate(image_tensor, size=(h, w), mode='bilinear').cuda()
            
            ipt=self.net_c(image_tensor)
            image_tensor = image_tensor.half()
            ipt = ipt.half()
            output_i, output_j=self.net_i(image_tensor,ipt,prompt=True)
            clean = output_j[-1][:, 3:, ...]
            clean=torch.clamp(clean, 0, 1)
            
            self.output = clean

# Build the networks once at import time; every request reuses this instance.
pipe = Pipe()


@spaces.GPU(duration=120)
def predict(img):
    """Gradio handler: run the pipeline on ``img`` and return a displayable image.

    Args:
        img: NumPy image from the ``gr.Image`` input, or ``None`` when the user
            submits without providing an image.

    Returns:
        ``H x W x C`` float32 NumPy array with values in [0, 1], or ``None`` when
        no image was given.
    """
    if img is None:
        return None
    pipe(img)
    # pipe.output is a (1, C, H, W) tensor on the GPU (fp16 from net_i). The
    # gr.Image output cannot post-process torch tensors, so convert it to a
    # CPU, channels-last, float32 NumPy array.
    clean = pipe.output
    return clean[0].permute(1, 2, 0).float().cpu().numpy()
# Input type "numpy" (gr.Image's default, stated explicitly) yields the
# H x W x C array that Pipe.__call__ expects.
demo = gr.Interface(predict, gr.Image(type="numpy"), "image")

# Launch only when run as a script so importing this module has no server side effect.
if __name__ == "__main__":
    demo.launch()