File size: 4,092 Bytes
1e76a6c
 
 
 
 
 
f0bad45
1e76a6c
 
 
 
 
 
3a26301
 
 
 
 
 
 
 
1e76a6c
 
 
 
3a26301
 
 
 
 
 
 
 
 
 
 
 
 
 
1e76a6c
 
 
 
3a26301
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1e76a6c
 
 
 
3a26301
 
 
 
1e76a6c
 
 
 
 
 
 
 
 
 
 
f0bad45
 
 
 
1e76a6c
f0bad45
1e76a6c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import functools

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as T
from huggingface_hub import hf_hub_download
from PIL import Image

# Prefer the GPU when CUDA is usable; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

class DownBlock(nn.Module):
    """Encoder stage: strided convolution, optional batch norm, LeakyReLU.

    The convolution uses a 4x4 kernel with stride 2 and padding 1, so an
    even spatial size is halved. When ``norm`` is true a ``BatchNorm2d``
    follows and the convolution carries no bias; otherwise the
    normalisation slot is an ``nn.Identity`` and the bias is enabled.

    Args:
        in_c: Number of input channels.
        out_c: Number of output channels.
        norm: Whether to apply batch normalisation after the convolution.
    """

    def __init__(self, in_c, out_c, norm=True):
        super().__init__()
        # The attribute names (conv / norm / relu) are also the state-dict
        # key prefixes, so they must stay as they are for checkpoints to load.
        self.conv = nn.Conv2d(
            in_c,
            out_c,
            kernel_size=4,
            stride=2,
            padding=1,
            bias=(not norm),
        )
        if norm:
            self.norm = nn.BatchNorm2d(out_c)
        else:
            self.norm = nn.Identity()
        self.relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        """Apply convolution, normalisation and activation to ``x``."""
        return self.relu(self.norm(self.conv(x)))

class UpBlock(nn.Module):
    """Decoder stage: transposed convolution, batch norm, ReLU, optional dropout.

    The transposed convolution uses a 4x4 kernel with stride 2 and padding 1
    (no bias), which doubles the spatial size. ``Dropout(0.5)`` is appended
    only when ``dropout`` is true.

    Args:
        in_c: Number of input channels.
        out_c: Number of output channels.
        dropout: Whether to apply dropout after the activation.
    """

    def __init__(self, in_c, out_c, dropout=False):
        super().__init__()
        self.use_dropout = dropout
        # deconv / bn are state-dict key prefixes; do not rename them.
        self.deconv = nn.ConvTranspose2d(
            in_c,
            out_c,
            kernel_size=4,
            stride=2,
            padding=1,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(out_c)
        self.relu = nn.ReLU(inplace=True)
        if self.use_dropout:
            # The dropout module only exists on blocks that asked for it.
            self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        """Upsample ``x`` and, if enabled, apply dropout to the result."""
        activated = self.relu(self.bn(self.deconv(x)))
        if not self.use_dropout:
            return activated
        return self.dropout(activated)

class UNetGenerator(nn.Module):
    """U-Net style generator: eight down stages, eight up stages, skip links.

    Every encoder output except the deepest one is concatenated (along the
    channel axis) with the decoder activation of matching depth before the
    next up stage. The final layer is a transposed convolution followed by
    ``tanh``, so outputs lie in [-1, 1].

    Args:
        in_c: Channels of the input image.
        out_c: Channels of the generated image.
        f: Base channel width; deeper stages use multiples of it.

    NOTE(review): with eight stride-2 stages the input height/width
    presumably need to be multiples of 256 for the skip shapes to line up;
    the caller in this file always resizes to 256x256.
    """

    def __init__(self, in_c=3, out_c=3, f=64):
        super().__init__()
        c2, c4, c8, c16 = f * 2, f * 4, f * 8, f * 16

        # Encoder. The first and last stages skip batch normalisation.
        self.d1 = DownBlock(in_c, f, norm=False)
        self.d2 = DownBlock(f, c2)
        self.d3 = DownBlock(c2, c4)
        self.d4 = DownBlock(c4, c8)
        self.d5 = DownBlock(c8, c8)
        self.d6 = DownBlock(c8, c8)
        self.d7 = DownBlock(c8, c8)
        self.d8 = DownBlock(c8, c8, norm=False)

        # Decoder. From u2 onward the input width is doubled by the skip
        # concatenation; the three deepest stages use dropout.
        self.u1 = UpBlock(c8, c8, dropout=True)
        self.u2 = UpBlock(c16, c8, dropout=True)
        self.u3 = UpBlock(c16, c8, dropout=True)
        self.u4 = UpBlock(c16, c8)
        self.u5 = UpBlock(c16, c4)
        self.u6 = UpBlock(c8, c2)
        self.u7 = UpBlock(c4, f)
        self.final_deconv = nn.ConvTranspose2d(
            c2, out_c, kernel_size=4, stride=2, padding=1
        )
        self.final_tanh = nn.Tanh()

    def forward(self, x):
        """Run the encoder, then decode while consuming skips deepest-first."""
        encoders = (
            self.d1, self.d2, self.d3, self.d4,
            self.d5, self.d6, self.d7, self.d8,
        )
        skips = []
        hidden = x
        for stage in encoders:
            hidden = stage(hidden)
            skips.append(hidden)

        # The bottleneck (d8 output) feeds u1 directly, without a skip.
        hidden = self.u1(skips.pop())
        decoders = (self.u2, self.u3, self.u4, self.u5, self.u6, self.u7)
        for stage in decoders:
            hidden = stage(torch.cat([hidden, skips.pop()], 1))

        # Only the d1 output is left on the stack at this point.
        logits = self.final_deconv(torch.cat([hidden, skips.pop()], 1))
        return self.final_tanh(logits)

def clean_sd(sd):
    """Return a new dict with a leading ``"module."`` removed from each key.

    Keys that do not start with that prefix are copied unchanged, and only
    one leading occurrence is stripped. Values are passed through as-is and
    the input mapping is not modified.
    (The prefix is presumably the one added by ``nn.DataParallel`` wrappers.)
    """
    prefix = "module."
    cut = len(prefix)
    return {
        (key[cut:] if key.startswith(prefix) else key): value
        for key, value in sd.items()
    }

def preprocess(img, size=256, force_gray=False):
    """Turn a PIL image into a normalised, batched tensor.

    Args:
        img: PIL image to convert.
        size: Target edge length; the image is resized to ``size`` x ``size``.
        force_gray: If true, drop colour by converting to "L" and back to
            "RGB" before the transform (the result still has 3 channels).

    Returns:
        Tensor with a leading batch dimension of 1, normalised per channel
        with mean 0.5 and std 0.5.
    """
    if force_gray:
        rgb = img.convert("L").convert("RGB")
    else:
        rgb = img.convert("RGB")

    pipeline = T.Compose(
        [
            T.Resize((size, size)),
            T.ToTensor(),
            T.Normalize([0.5] * 3, [0.5] * 3),
        ]
    )
    tensor = pipeline(rgb)
    return tensor.unsqueeze(0)

@functools.lru_cache(maxsize=2)
def _load_generator(ckpt):
    """Fetch checkpoint ``ckpt`` and return a generator ready for inference.

    The result is memoised per checkpoint name so that repeated requests do
    not rebuild the network and re-read the weights from disk each time.
    A failed download or load raises and is therefore not cached.
    """
    # Download the model checkpoint from the model repo
    ckpt_path = hf_hub_download(repo_id="SotaSF/q2-model", filename=ckpt)

    model = UNetGenerator().to(device)
    # NOTE(review): torch.load unpickles the file. The repo is fixed, but the
    # file name comes from a UI textbox; consider weights_only=True once the
    # checkpoint format is confirmed to contain tensors only.
    sd = torch.load(ckpt_path, map_location=device)
    # Checkpoints may store the generator weights under a "G" key.
    model.load_state_dict(clean_sd(sd["G"] if "G" in sd else sd), strict=True)
    model.eval()
    return model


@torch.no_grad()
def run(inp, ckpt='pix2pix_final.pt', force_gray=False):
    """Translate one input image with the pix2pix generator.

    Args:
        inp: Image as a numpy array (as delivered by the Gradio image
            component), or ``None`` when nothing was uploaded.
        ckpt: Checkpoint file name inside the "SotaSF/q2-model" repo. A blank
            or whitespace-only value falls back to the default name.
        force_gray: If truthy, the input is converted to grayscale first.

    Returns:
        A 256x256x3 ``uint8`` numpy array, or ``None`` if ``inp`` is ``None``.
    """
    if inp is None:
        return None

    # The textbox can hand over an empty string or stray whitespace, which
    # would otherwise be passed to the hub as a file name.
    ckpt = (ckpt or "").strip() or "pix2pix_final.pt"
    model = _load_generator(ckpt)

    x = preprocess(
        Image.fromarray(inp.astype(np.uint8)), 256, bool(force_gray)
    ).to(device)
    # Map the tanh output from [-1, 1] back to [0, 1] before quantising.
    y = (model(x) * 0.5 + 0.5).clamp(0, 1)
    return (y[0].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)

# Gradio UI: an input image, an output image, the checkpoint name, a
# grayscale toggle, and a button that wires them all to `run`.
with gr.Blocks() as demo:
    gr.Markdown("# Q2: Pix2Pix")
    input_image = gr.Image(type="numpy", label="Input")
    output_image = gr.Image(type="numpy", label="Output")
    checkpoint_box = gr.Textbox(value="pix2pix_final.pt", label="Checkpoint")
    gray_toggle = gr.Checkbox(value=False, label="Force grayscale input")
    generate_button = gr.Button("Generate")
    # Input order must match run(inp, ckpt, force_gray).
    generate_button.click(
        fn=run,
        inputs=[input_image, checkpoint_box, gray_toggle],
        outputs=[output_image],
    )

demo.launch()