"""Gradio demo for a Pix2Pix U-Net generator.

The generator checkpoint is downloaded from the Hugging Face Hub repo
``SotaSF/q2-model`` and applied to an uploaded image.
"""

import functools

import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torchvision.transforms as T
import gradio as gr
from huggingface_hub import hf_hub_download

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hub repository holding the checkpoints, and the checkpoint used by default.
REPO_ID = "SotaSF/q2-model"
DEFAULT_CKPT = "pix2pix_final.pt"


class DownBlock(nn.Module):
    """Encoder stage: 4x4 stride-2 conv -> optional BatchNorm -> LeakyReLU(0.2).

    With kernel 4 / stride 2 / padding 1 the spatial size is halved (for even
    inputs). Attribute names define the checkpoint's state-dict keys and must
    not be renamed.
    """

    def __init__(self, in_c, out_c, norm=True):
        super().__init__()
        # A conv bias is redundant when BatchNorm follows, so it is only kept
        # for the un-normalized blocks.
        self.conv = nn.Conv2d(in_c, out_c, 4, 2, 1, bias=not norm)
        self.norm = nn.BatchNorm2d(out_c) if norm else nn.Identity()
        self.relu = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        return self.relu(x)


class UpBlock(nn.Module):
    """Decoder stage: 4x4 stride-2 transposed conv -> BatchNorm -> ReLU -> optional Dropout(0.5).

    Doubles the spatial size. Attribute names define the checkpoint's
    state-dict keys and must not be renamed.
    """

    def __init__(self, in_c, out_c, dropout=False):
        super().__init__()
        self.deconv = nn.ConvTranspose2d(in_c, out_c, 4, 2, 1, bias=False)
        self.bn = nn.BatchNorm2d(out_c)
        self.relu = nn.ReLU(inplace=True)
        self.use_dropout = dropout
        if dropout:
            self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x = self.deconv(x)
        x = self.bn(x)
        x = self.relu(x)
        if self.use_dropout:
            x = self.dropout(x)
        return x


class UNetGenerator(nn.Module):
    """Eight-level U-Net generator with skip connections and a Tanh output.

    Eight stride-2 encoder stages are mirrored by the decoder, and each decoder
    input is concatenated (dim 1) with the matching encoder output. The input's
    spatial size should therefore be a multiple of 256 so every skip pair lines
    up (``preprocess`` resizes to 256x256). Output values lie in [-1, 1].

    Args:
        in_c: Input channels.
        out_c: Output channels.
        f: Base feature width; deeper stages use multiples of it.
    """

    def __init__(self, in_c=3, out_c=3, f=64):
        super().__init__()
        # Encoder. The first and last stages are built without BatchNorm.
        self.d1 = DownBlock(in_c, f, norm=False)
        self.d2 = DownBlock(f, f * 2)
        self.d3 = DownBlock(f * 2, f * 4)
        self.d4 = DownBlock(f * 4, f * 8)
        self.d5 = DownBlock(f * 8, f * 8)
        self.d6 = DownBlock(f * 8, f * 8)
        self.d7 = DownBlock(f * 8, f * 8)
        self.d8 = DownBlock(f * 8, f * 8, norm=False)
        # Decoder. Input widths after u1 are doubled by the skip concatenation.
        self.u1 = UpBlock(f * 8, f * 8, dropout=True)
        self.u2 = UpBlock(f * 16, f * 8, dropout=True)
        self.u3 = UpBlock(f * 16, f * 8, dropout=True)
        self.u4 = UpBlock(f * 16, f * 8)
        self.u5 = UpBlock(f * 16, f * 4)
        self.u6 = UpBlock(f * 8, f * 2)
        self.u7 = UpBlock(f * 4, f)
        self.final_deconv = nn.ConvTranspose2d(f * 2, out_c, 4, 2, 1)
        self.final_tanh = nn.Tanh()

    def forward(self, x):
        d1 = self.d1(x)
        d2 = self.d2(d1)
        d3 = self.d3(d2)
        d4 = self.d4(d3)
        d5 = self.d5(d4)
        d6 = self.d6(d5)
        d7 = self.d7(d6)
        d8 = self.d8(d7)
        u1 = self.u1(d8)
        u2 = self.u2(torch.cat([u1, d7], 1))
        u3 = self.u3(torch.cat([u2, d6], 1))
        u4 = self.u4(torch.cat([u3, d5], 1))
        u5 = self.u5(torch.cat([u4, d4], 1))
        u6 = self.u6(torch.cat([u5, d3], 1))
        u7 = self.u7(torch.cat([u6, d2], 1))
        out = self.final_deconv(torch.cat([u7, d1], 1))
        return self.final_tanh(out)


def clean_sd(sd):
    """Return a copy of state dict ``sd`` with a leading ``"module."`` stripped from every key.

    Wrappers such as ``nn.DataParallel`` save parameters under a ``module.``
    prefix; removing it lets those weights load into a bare ``UNetGenerator``.
    Keys without the prefix are kept unchanged.
    """
    prefix = "module."
    return {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in sd.items()}


def preprocess(img, size=256, force_gray=False):
    """Turn a PIL image into a ``(1, 3, size, size)`` float tensor normalized to [-1, 1].

    Args:
        img: Input PIL image (any mode; converted to RGB).
        size: Side length the image is resized to (aspect ratio is not kept).
        force_gray: If True, reduce to luminance first, then replicate it back to
            three channels so the generator still receives RGB input.
    """
    if force_gray:
        img = img.convert("L").convert("RGB")
    else:
        img = img.convert("RGB")
    tfm = T.Compose([
        T.Resize((size, size)),
        T.ToTensor(),
        # [0, 1] -> [-1, 1], matching the generator's Tanh output range.
        T.Normalize([0.5] * 3, [0.5] * 3),
    ])
    return tfm(img).unsqueeze(0)


@functools.lru_cache(maxsize=4)
def _load_generator(ckpt):
    """Download checkpoint ``ckpt`` from ``REPO_ID``, build the generator in eval mode, and cache it.

    Cached per filename (bounded, since the name comes from a free-text box) so
    repeated requests do not re-download, re-deserialize and rebuild the model.
    Failed loads raise and are not cached.
    """
    ckpt_path = hf_hub_download(repo_id=REPO_ID, filename=ckpt)
    # NOTE(review): torch.load unpickles the file, and the filename is
    # user-editable (though limited to files in REPO_ID). If the checkpoints
    # only hold tensors/dicts, consider passing weights_only=True explicitly.
    sd = torch.load(ckpt_path, map_location=device)
    model = UNetGenerator().to(device)
    # Accept either {"G": generator_state, ...} or a bare generator state dict.
    model.load_state_dict(clean_sd(sd["G"] if "G" in sd else sd), strict=True)
    model.eval()
    return model


@torch.no_grad()
def run(inp, ckpt=DEFAULT_CKPT, force_gray=False):
    """Translate one image with the Pix2Pix generator.

    Args:
        inp: Image as a NumPy array (as delivered by ``gr.Image(type="numpy")``),
            or None when nothing was uploaded.
        ckpt: Checkpoint filename inside ``REPO_ID``. Blank falls back to
            ``DEFAULT_CKPT``.
        force_gray: If truthy, the input is converted to grayscale before inference.

    Returns:
        A ``(256, 256, 3)`` uint8 array, or None when ``inp`` is None.
    """
    if inp is None:
        return None
    # An emptied textbox would otherwise ask the Hub for filename "".
    ckpt = (ckpt or "").strip() or DEFAULT_CKPT
    model = _load_generator(ckpt)
    x = preprocess(Image.fromarray(inp.astype(np.uint8)), 256, bool(force_gray)).to(device)
    # Undo the [-1, 1] normalization back to [0, 1].
    y = (model(x) * 0.5 + 0.5).clamp(0, 1)
    # Round (not truncate) when quantizing to 8-bit to avoid a darkening bias.
    return (y[0].permute(1, 2, 0).cpu().numpy() * 255).round().astype(np.uint8)


with gr.Blocks() as demo:
    gr.Markdown("# Q2: Pix2Pix")
    input_image = gr.Image(type="numpy", label="Input")
    output_image = gr.Image(type="numpy", label="Output")
    ckpt_box = gr.Textbox(value=DEFAULT_CKPT, label="Checkpoint")
    gray_box = gr.Checkbox(value=False, label="Force grayscale input")
    generate_btn = gr.Button("Generate")
    generate_btn.click(run, [input_image, ckpt_box, gray_box], [output_image])

if __name__ == "__main__":
    demo.launch()