# q2 / app.py
# Uploaded by SotaSF using huggingface_hub (commit 3a26301, verified).
import functools

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as T
from huggingface_hub import hf_hub_download
from PIL import Image
# Run on the current CUDA device when torch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class DownBlock(nn.Module):
    """Encoder stage: stride-2 4x4 conv, optional BatchNorm, then LeakyReLU(0.2).

    With kernel 4, stride 2 and padding 1 the conv halves H and W (floor).
    The conv bias is disabled whenever BatchNorm follows, because BN's own
    learned shift makes it redundant.

    Args:
        in_c: input channel count.
        out_c: output channel count.
        norm: when True apply BatchNorm2d after the conv, else pass through.
    """

    def __init__(self, in_c, out_c, norm=True):
        super().__init__()
        # Attribute names (conv / norm / relu) are the state_dict key prefixes
        # that saved checkpoints are matched against; keep them stable.
        self.conv = nn.Conv2d(
            in_c, out_c, kernel_size=4, stride=2, padding=1, bias=not norm
        )
        self.norm = nn.BatchNorm2d(out_c) if norm else nn.Identity()
        self.relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        # In-place activation is safe: it only overwrites the freshly
        # produced conv/BN output, never the caller's tensor.
        return self.relu(self.norm(self.conv(x)))
class UpBlock(nn.Module):
    """Decoder stage: stride-2 4x4 transposed conv, BatchNorm, ReLU, optional Dropout(0.5).

    With kernel 4, stride 2 and padding 1 the transposed conv doubles H and W.

    Args:
        in_c: input channel count.
        out_c: output channel count.
        dropout: when True, apply Dropout(p=0.5) after the ReLU.
    """

    def __init__(self, in_c, out_c, dropout=False):
        super().__init__()
        # deconv / bn names are the state_dict key prefixes; keep them stable.
        self.deconv = nn.ConvTranspose2d(
            in_c, out_c, kernel_size=4, stride=2, padding=1, bias=False
        )
        self.bn = nn.BatchNorm2d(out_c)
        self.relu = nn.ReLU(inplace=True)
        self.use_dropout = dropout
        # A Dropout submodule is only registered when requested; it holds no
        # parameters, so the state_dict layout is identical either way.
        if self.use_dropout:
            self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        activated = self.relu(self.bn(self.deconv(x)))
        return self.dropout(activated) if self.use_dropout else activated
class UNetGenerator(nn.Module):
    """U-Net generator: 8 stride-2 encoder stages, 7 decoder stages with skips, tanh head.

    Each encoder stage halves H and W, so spatial sizes should be multiples of
    2**8 = 256 for every skip concatenation to line up (the app feeds 256x256).
    Decoder stages u2..u7 receive the previous decoder output concatenated on
    the channel axis with the mirrored encoder map (d7..d2), which is why their
    in-channel counts are doubled. The final transposed conv sees [u7, d1].

    Args:
        in_c: input image channels.
        out_c: output image channels.
        f: base feature width; deeper stages use multiples of it (capped at 8*f).
    """

    def __init__(self, in_c=3, out_c=3, f=64):
        super().__init__()
        # Attribute names d1..d8 / u1..u7 / final_* are the checkpoint's
        # state_dict keys (loaded with strict=True); do not rename or reorder.
        self.d1 = DownBlock(in_c, f, norm=False)
        self.d2 = DownBlock(f, 2 * f)
        self.d3 = DownBlock(2 * f, 4 * f)
        self.d4 = DownBlock(4 * f, 8 * f)
        self.d5 = DownBlock(8 * f, 8 * f)
        self.d6 = DownBlock(8 * f, 8 * f)
        self.d7 = DownBlock(8 * f, 8 * f)
        # No BatchNorm at the bottleneck (presumably because its map is 1x1
        # for 256x256 inputs, where batch statistics are degenerate).
        self.d8 = DownBlock(8 * f, 8 * f, norm=False)
        self.u1 = UpBlock(8 * f, 8 * f, dropout=True)
        self.u2 = UpBlock(16 * f, 8 * f, dropout=True)
        self.u3 = UpBlock(16 * f, 8 * f, dropout=True)
        self.u4 = UpBlock(16 * f, 8 * f)
        self.u5 = UpBlock(16 * f, 4 * f)
        self.u6 = UpBlock(8 * f, 2 * f)
        self.u7 = UpBlock(4 * f, f)
        self.final_deconv = nn.ConvTranspose2d(2 * f, out_c, 4, 2, 1)
        self.final_tanh = nn.Tanh()

    def forward(self, x):
        # Encoder: keep every stage output; skips[k] is the output of d(k+1).
        skips = []
        h = x
        for encoder in (self.d1, self.d2, self.d3, self.d4,
                        self.d5, self.d6, self.d7, self.d8):
            h = encoder(h)
            skips.append(h)

        # Bottleneck (d8) goes through u1 without a skip.
        h = self.u1(skips[-1])

        # u2..u7 pair with d7..d2 respectively.
        decoders = (self.u2, self.u3, self.u4, self.u5, self.u6, self.u7)
        for decoder, skip in zip(decoders, reversed(skips[1:-1])):
            h = decoder(torch.cat([h, skip], 1))

        # Head: [u7, d1] -> out_c channels at the input resolution, squashed to [-1, 1].
        return self.final_tanh(self.final_deconv(torch.cat([h, skips[0]], 1)))
def clean_sd(sd):
    """Return a new dict with one leading "module." removed from each key of *sd*.

    Presumably for weights saved from an nn.DataParallel-wrapped model, whose
    parameter names carry that prefix (TODO confirm against the training code).
    Keys without the prefix are copied unchanged and insertion order is kept.
    If stripping makes two keys collide, the later entry wins. *sd* itself is
    not modified.
    """
    prefix = "module."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in sd.items()
    }
def preprocess(img, size=256, force_gray=False):
    """Convert a PIL image into a normalized (1, 3, size, size) float tensor.

    The image is converted to RGB. When *force_gray* is set it first goes
    through single-channel "L", so all three channels carry the same luminance.
    It is then resized to exactly *size* x *size* (the aspect ratio is not
    preserved), scaled to [0, 1] by ToTensor and mapped to [-1, 1] by
    Normalize(mean=0.5, std=0.5). A leading batch axis is added.
    """
    rgb = img.convert("L").convert("RGB") if force_gray else img.convert("RGB")
    pipeline = T.Compose([
        T.Resize((size, size)),
        T.ToTensor(),
        T.Normalize([0.5] * 3, [0.5] * 3),
    ])
    batch = pipeline(rgb).unsqueeze(0)
    return batch
@functools.lru_cache(maxsize=4)
def _load_generator(ckpt):
    """Download *ckpt* from the SotaSF/q2-model repo and return an eval-mode generator.

    Cached per checkpoint filename, so repeated clicks reuse the loaded model
    instead of re-downloading and re-deserializing it every time. The cache is
    bounded because the filename comes from a free-text box. Failed loads raise
    and are not cached.
    """
    ckpt_path = hf_hub_download(repo_id="SotaSF/q2-model", filename=ckpt)
    # NOTE(review): torch.load unpickles the file. It is only fetched from the
    # fixed SotaSF/q2-model repo, but consider weights_only=True if the
    # checkpoint contents and the installed torch version allow it.
    sd = torch.load(ckpt_path, map_location=device)
    model = UNetGenerator().to(device)
    # The checkpoint may be a dict holding the generator weights under "G";
    # otherwise it is treated as the generator state_dict itself.
    model.load_state_dict(clean_sd(sd["G"] if "G" in sd else sd), strict=True)
    model.eval()
    return model


@torch.no_grad()
def run(inp, ckpt='pix2pix_final.pt', force_gray=False):
    """Translate one input image with the pix2pix generator.

    Args:
        inp: image array from the Gradio Image component (type="numpy"), or None.
        ckpt: checkpoint filename in the SotaSF/q2-model repo; a blank value
            falls back to the default 'pix2pix_final.pt'.
        force_gray: if truthy, convert the input to grayscale (replicated to
            3 channels) before inference.

    Returns:
        A 256x256x3 uint8 array, or None when no input image was given.
    """
    if inp is None:
        return None
    # Guard against an emptied or whitespace-only checkpoint textbox.
    ckpt = (ckpt or "").strip() or 'pix2pix_final.pt'
    model = _load_generator(ckpt)
    img = Image.fromarray(inp.astype(np.uint8))
    x = preprocess(img, 256, bool(force_gray)).to(device)
    # The generator ends in tanh ([-1, 1]); map back to [0, 1] for display.
    y = (model(x) * 0.5 + 0.5).clamp(0, 1)
    return (y[0].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
# Gradio UI: an input image, a checkpoint name and a grayscale toggle are fed
# to run(); its result is shown in the output image.
with gr.Blocks() as demo:
    gr.Markdown("# Q2: Pix2Pix")
    input_image = gr.Image(type="numpy", label="Input")
    output_image = gr.Image(type="numpy", label="Output")
    ckpt_box = gr.Textbox(value="pix2pix_final.pt", label="Checkpoint")
    gray_toggle = gr.Checkbox(value=False, label="Force grayscale input")
    generate_btn = gr.Button("Generate")
    # Inputs map positionally onto run(inp, ckpt, force_gray).
    generate_btn.click(run, [input_image, ckpt_box, gray_toggle], [output_image])

# Launch only when executed as a script (as `python app.py` does), so that
# importing this module does not start a server as a side effect.
if __name__ == "__main__":
    demo.launch()