import functools

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as T
from huggingface_hub import hf_hub_download
from PIL import Image
|
|
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
|
|
class DownBlock(nn.Module):
    """Encoder stage: 4x4 stride-2 convolution, optional BatchNorm, LeakyReLU(0.2).

    With kernel 4, stride 2 and padding 1 the stage halves even spatial sizes.
    When ``norm`` is enabled the convolution is built without a bias, because
    the BatchNorm that follows subtracts the batch mean and applies its own
    learned shift, which would make a conv bias redundant.

    Args:
        in_c: number of input channels.
        out_c: number of output channels.
        norm: if True, apply ``BatchNorm2d`` after the convolution;
            otherwise an ``Identity`` is used in its place.
    """

    def __init__(self, in_c, out_c, norm=True):
        super().__init__()
        use_bias = not norm
        # Attribute names double as state_dict keys (conv.*, norm.*) of saved
        # checkpoints, so they must stay exactly as they are.
        self.conv = nn.Conv2d(
            in_c, out_c, kernel_size=4, stride=2, padding=1, bias=use_bias
        )
        if norm:
            self.norm = nn.BatchNorm2d(out_c)
        else:
            self.norm = nn.Identity()
        self.relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        """Return ``LeakyReLU(norm(conv(x)))``."""
        return self.relu(self.norm(self.conv(x)))
|
|
class UpBlock(nn.Module):
    """Decoder stage: 4x4 stride-2 transposed conv, BatchNorm, ReLU, optional Dropout(0.5).

    With kernel 4, stride 2 and padding 1 the transposed convolution doubles
    the spatial size. The transposed conv has no bias; the BatchNorm that
    follows provides the learned shift.

    Args:
        in_c: number of input channels.
        out_c: number of output channels.
        dropout: if True, apply ``Dropout(0.5)`` after the ReLU. ``nn.Dropout``
            is only active in training mode.
    """

    def __init__(self, in_c, out_c, dropout=False):
        super().__init__()
        # Names are kept as-is: deconv.* and bn.* are checkpoint state_dict keys.
        self.deconv = nn.ConvTranspose2d(
            in_c, out_c, kernel_size=4, stride=2, padding=1, bias=False
        )
        self.bn = nn.BatchNorm2d(out_c)
        self.relu = nn.ReLU(inplace=True)
        self.use_dropout = dropout
        # The Dropout submodule only exists when requested.
        if self.use_dropout:
            self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        """Return ``ReLU(bn(deconv(x)))``, passed through dropout if enabled."""
        out = self.relu(self.bn(self.deconv(x)))
        return self.dropout(out) if self.use_dropout else out
|
|
class UNetGenerator(nn.Module):
    """Pix2Pix-style U-Net: eight DownBlocks, seven UpBlocks, final transposed conv + Tanh.

    Each decoder stage after ``u1`` receives its predecessor's output
    concatenated on the channel axis with the mirrored encoder output
    (``d7`` for ``u2`` … ``d1`` for the final layer). The network's output
    passes through ``Tanh``, so values lie in [-1, 1].

    Eight stride-2 halvings mean H and W should be multiples of 2**8 = 256;
    otherwise the skip concatenations will not line up.

    Args:
        in_c: input image channels.
        out_c: output image channels.
        f: base feature width; stages use f, 2f, 4f and 8f channels.
    """

    def __init__(self, in_c=3, out_c=3, f=64):
        super().__init__()
        c1, c2, c4, c8 = f, 2 * f, 4 * f, 8 * f

        # Encoder. The outermost (d1) and innermost (d8) stages skip BatchNorm.
        # Attribute names d1..d8 / u1..u7 / final_* are checkpoint keys.
        self.d1 = DownBlock(in_c, c1, norm=False)
        self.d2 = DownBlock(c1, c2)
        self.d3 = DownBlock(c2, c4)
        self.d4 = DownBlock(c4, c8)
        self.d5 = DownBlock(c8, c8)
        self.d6 = DownBlock(c8, c8)
        self.d7 = DownBlock(c8, c8)
        self.d8 = DownBlock(c8, c8, norm=False)

        # Decoder. Inputs after u1 are doubled because of the skip concatenation.
        self.u1 = UpBlock(c8, c8, dropout=True)
        self.u2 = UpBlock(2 * c8, c8, dropout=True)
        self.u3 = UpBlock(2 * c8, c8, dropout=True)
        self.u4 = UpBlock(2 * c8, c8)
        self.u5 = UpBlock(2 * c8, c4)
        self.u6 = UpBlock(2 * c4, c2)
        self.u7 = UpBlock(2 * c2, c1)
        self.final_deconv = nn.ConvTranspose2d(
            2 * c1, out_c, kernel_size=4, stride=2, padding=1
        )
        self.final_tanh = nn.Tanh()

    def forward(self, x):
        """Encode ``x`` through d1..d8, then decode with skip connections."""
        encoder = (self.d1, self.d2, self.d3, self.d4,
                   self.d5, self.d6, self.d7, self.d8)
        skips = []
        h = x
        for stage in encoder:
            h = stage(h)
            skips.append(h)

        # The last entry is the bottleneck (d8); the rest are consumed deepest-first.
        h = self.u1(skips.pop())
        decoder = (self.u2, self.u3, self.u4, self.u5, self.u6, self.u7)
        for stage, skip in zip(decoder, reversed(skips)):  # pairs d7..d2
            h = stage(torch.cat([h, skip], 1))

        # The outermost skip (d1) feeds the final layer.
        h = self.final_deconv(torch.cat([h, skips[0]], 1))
        return self.final_tanh(h)
|
|
def clean_sd(sd):
    """Return a new dict with one leading ``"module."`` removed from each key.

    That prefix is typically what ``nn.DataParallel``/DDP-wrapped models
    leave on their state_dict keys. Stripping it lets the weights load into a
    bare module. Keys without the prefix are copied unchanged, and insertion
    order is preserved.
    """
    prefix = "module."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in sd.items()
    }
|
|
def preprocess(img, size=256, force_gray=False):
    """Convert a PIL image into a normalized ``(1, 3, size, size)`` float tensor.

    The image is forced to RGB, resized to exactly ``size`` x ``size``
    (aspect ratio is not preserved), converted to a tensor and normalized
    per channel with mean 0.5 / std 0.5.

    Args:
        img: input ``PIL.Image``.
        size: target height and width in pixels.
        force_gray: if True, collapse to luminance first and replicate it to
            three channels, so the network still receives an RGB tensor.
    """
    if force_gray:
        img = img.convert("L")
    img = img.convert("RGB")

    resized = T.Resize((size, size))(img)
    tensor = T.ToTensor()(resized)
    tensor = T.Normalize([0.5] * 3, [0.5] * 3)(tensor)
    # Add the leading batch dimension expected by the model.
    return tensor.unsqueeze(0)
|
|
@functools.lru_cache(maxsize=4)
def _load_generator(ckpt):
    """Fetch ``ckpt`` from the Hub repo and return an eval-mode ``UNetGenerator``.

    Cached per checkpoint filename, so repeated clicks reuse the loaded model
    instead of re-resolving the download, re-unpickling the file and
    allocating a new generator on ``device`` every time. ``maxsize`` bounds
    how many generators stay resident, because the filename comes from a
    free-text box. Failed loads raise and are not cached. If the file on the
    Hub changes, restart the app to pick up the new weights.
    """
    ckpt_path = hf_hub_download(repo_id="SotaSF/q2-model", filename=ckpt)
    model = UNetGenerator().to(device)
    # NOTE(review): torch.load unpickles arbitrary objects. The repo is fixed,
    # but the filename is user-chosen. Consider weights_only=True once the
    # checkpoint contents are confirmed to be plain tensors/dicts.
    sd = torch.load(ckpt_path, map_location=device)
    # Some checkpoints nest the generator weights under "G" (presumably a full
    # training checkpoint); otherwise the file is the state_dict itself.
    state = sd["G"] if "G" in sd else sd
    model.load_state_dict(clean_sd(state), strict=True)
    model.eval()
    return model


@torch.no_grad()
def run(inp, ckpt='pix2pix_final.pt', force_gray=False):
    """Translate one image with the Pix2Pix generator loaded from ``ckpt``.

    Args:
        inp: image array as supplied by ``gr.Image(type="numpy")``, or None.
        ckpt: checkpoint filename inside the ``SotaSF/q2-model`` Hub repo.
            Surrounding whitespace is ignored; blank falls back to the default.
        force_gray: convert the input to grayscale before inference.

    Returns:
        A ``(256, 256, 3)`` uint8 array, or None when no image was given.
    """
    if inp is None:
        return None

    ckpt = (ckpt or "").strip() or "pix2pix_final.pt"
    model = _load_generator(ckpt)

    x = preprocess(Image.fromarray(inp.astype(np.uint8)), 256, bool(force_gray)).to(device)
    # Undo the [-1, 1] normalization back to [0, 1].
    y = (model(x) * 0.5 + 0.5).clamp(0, 1)
    # Round to the nearest 8-bit level; a bare astype would truncate (bias down).
    return (y[0].permute(1, 2, 0).cpu().numpy() * 255).round().astype(np.uint8)
|
|
# Gradio UI: one input image, checkpoint name, grayscale toggle -> output image.
with gr.Blocks() as demo:
    gr.Markdown("# Q2: Pix2Pix")
    i = gr.Image(type="numpy", label="Input")
    o = gr.Image(type="numpy", label="Output")
    ck = gr.Textbox(value="pix2pix_final.pt", label="Checkpoint")
    fg = gr.Checkbox(value=False, label="Force grayscale input")
    b = gr.Button("Generate")
    # Inputs map positionally onto run(inp, ckpt, force_gray).
    b.click(run, inputs=[i, ck, fg], outputs=[o])


# Start the (blocking) server only when executed as a script, so importing
# this module, e.g. from tests or a notebook, does not launch it.
if __name__ == "__main__":
    demo.launch()
|
|