# q3 / app.py
# Author: SotaSF — "Upload folder using huggingface_hub" (commit 59baf73, verified)
import os
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import gradio as gr
from huggingface_hub import hf_hub_download
# Run on a CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Local filesystem path of the trained CycleGAN checkpoint fetched from the Hub;
# load_models() reads it with torch.load.
CKPT_PATH = hf_hub_download(
    repo_id="SotaSF/q3-model",
    filename="cyclegan_final.pt",
)
class ResnetBlock(nn.Module):
    """Residual block made of two reflection-padded 3x3 convolutions.

    Computes ``x + F(x)`` where ``F`` is
    pad -> conv -> instance-norm -> ReLU -> pad -> conv -> instance-norm,
    all at ``c`` channels. Padding by 1 before each unpadded 3x3 conv keeps
    the spatial size, so the residual sum is shape-compatible with ``x``.

    The layers live in ``self.b`` (an ``nn.Sequential``). State-dict keys are
    therefore ``b.<index>.*``, and the layer order must stay fixed for saved
    checkpoints to remain loadable.
    """

    def __init__(self, c):
        super().__init__()
        layers = []
        for half in range(2):
            layers += [nn.ReflectionPad2d(1), nn.Conv2d(c, c, 3), nn.InstanceNorm2d(c)]
            if half == 0:
                # Only the first half gets a non-linearity; the second half
                # feeds straight into the residual sum.
                layers.append(nn.ReLU(True))
        self.b = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.b(x)
        return x + residual
class ResnetGenerator(nn.Module):
    """ResNet-style CycleGAN generator.

    Layout:
      * c7s1 stem from ``in_c`` to ``base`` channels;
      * two stride-2 downsampling convs (``base`` -> ``2*base`` -> ``4*base``);
      * ``n_blocks`` residual blocks at ``4*base`` channels;
      * two stride-2 transposed convs back down to ``base`` channels;
      * a c7s1 head to ``out_c`` channels followed by tanh (output in [-1, 1]).

    All layers live in one ``nn.Sequential`` at ``self.m``. State-dict keys
    are ``m.<index>.*``, so the layer order must not change.
    """

    def __init__(self, in_c=3, out_c=3, n_blocks=6, base=64):
        super().__init__()

        def norm_relu(ch):
            # Every conv except the final head is followed by these two layers.
            return [nn.InstanceNorm2d(ch), nn.ReLU(True)]

        widths = [base, base * 2, base * 4]

        # Stem: c7s1-base.
        layers = [nn.ReflectionPad2d(3), nn.Conv2d(in_c, base, 7), *norm_relu(base)]
        # Downsample twice.
        for c_in, c_out in zip(widths, widths[1:]):
            layers += [nn.Conv2d(c_in, c_out, 3, 2, 1), *norm_relu(c_out)]
        # Residual trunk at the bottleneck width.
        layers += [ResnetBlock(widths[-1]) for _ in range(n_blocks)]
        # Upsample twice, mirroring the downsampling path.
        up = widths[::-1]
        for c_in, c_out in zip(up, up[1:]):
            layers += [
                nn.ConvTranspose2d(c_in, c_out, 3, 2, 1, output_padding=1),
                *norm_relu(c_out),
            ]
        # Head: c7s1-out_c, squashed into [-1, 1].
        layers += [nn.ReflectionPad2d(3), nn.Conv2d(base, out_c, 7), nn.Tanh()]
        self.m = nn.Sequential(*layers)

    def forward(self, x):
        return self.m(x)
def clean_sd(sd):
    """Return a new dict mapping the keys of ``sd`` to the same values.

    A single leading ``"module."`` is stripped from each key that has one;
    other keys are copied unchanged. (Presumably this handles checkpoints
    saved from a DataParallel-wrapped model — confirm against the training
    script.)
    """
    prefix = "module."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in sd.items()
    }
# Lazily populated by load_models(): the two generators (None until loaded)
# and, if loading failed, the error message (None otherwise).
g_ab = g_ba = None
load_error = None
def load_models():
    """Load both CycleGAN generators from ``CKPT_PATH``, at most once.

    On success, sets the globals ``g_ab`` / ``g_ba`` to eval-mode
    ``ResnetGenerator`` instances on ``device``. On failure, prints the
    traceback, stores ``"<ExceptionType>: <message>"`` in ``load_error`` and
    leaves both generators as ``None``. Later calls return immediately in
    either case, so a failed load is not retried.
    """
    import traceback

    global g_ab, g_ba, load_error
    if g_ab is not None or load_error is not None:
        return
    try:
        # NOTE(review): torch.load unpickles the file. It comes from a fixed
        # Hub repo, but consider weights_only=True if the torch version allows.
        sd = torch.load(CKPT_PATH, map_location=device)
        # Build into locals and publish to the globals only after both models
        # loaded, so a failure can never leave a half-initialised generator.
        loaded = {}
        for key in ("G_AB", "G_BA"):
            net = ResnetGenerator(n_blocks=6).to(device)
            result = net.load_state_dict(clean_sd(sd[key]), strict=False)
            # strict=False tolerates extra checkpoint keys, but a *missing* key
            # leaves that layer at its random init and yields garbage images.
            if result.missing_keys:
                raise RuntimeError(
                    f"{key}: checkpoint is missing keys {result.missing_keys}"
                )
            net.eval()
            loaded[key] = net
    except Exception as e:
        traceback.print_exc()
        load_error = f"{type(e).__name__}: {e}"
        return
    g_ab, g_ba = loaded["G_AB"], loaded["G_BA"]
def preprocess(inp, size=128):
    """Turn an image array into a normalised (1, 3, size, size) float tensor.

    Args:
        inp: numpy image array. It is cast to uint8 and converted to RGB by
            PIL, so grayscale or RGBA inputs also work.
        size: output height and width in pixels. The image is resized to a
            square, so the aspect ratio is not preserved.

    Returns:
        float32 CPU tensor of shape (1, 3, size, size) with values in [-1, 1].
    """
    pil = Image.fromarray(inp.astype(np.uint8))
    pil = pil.convert("RGB").resize((size, size))
    # [0, 255] -> [0, 1] -> [-1, 1], the range of the generator's tanh output.
    pixels = np.asarray(pil).astype(np.float32)
    pixels = (pixels / 255.0 - 0.5) / 0.5
    # HWC -> CHW, then prepend a batch axis.
    chw = pixels.transpose(2, 0, 1)
    return torch.from_numpy(chw)[None]
@torch.no_grad()
def run(inp, direction='Sketch -> Photo (G_AB)'):
    """Gradio callback: translate ``inp`` with the generator picked by ``direction``.

    Args:
        inp: numpy image from ``gr.Image(type="numpy")``, or ``None`` when no
            image was provided.
        direction: radio label. Labels starting with ``"Sketch"`` use
            ``g_ab``; any other label uses ``g_ba``.

    Returns:
        ``None`` if ``inp`` is ``None``. Otherwise a 128x128x3 uint8 numpy
        image, since the input is resized to 128x128 by ``preprocess``.

    Raises:
        RuntimeError: if the generators failed to load.
    """
    load_models()
    if inp is None:
        return None
    if load_error is not None:
        raise RuntimeError("Model load failed: " + load_error)
    x = preprocess(inp, 128).to(device)
    net = g_ab if direction.startswith("Sketch") else g_ba
    # Map the tanh output from [-1, 1] back to [0, 1].
    y = (net(x) * 0.5 + 0.5).clamp(0, 1)
    # Round to the nearest 8-bit level. A bare uint8 cast truncates, which
    # biases every pixel about half a level darker.
    y = (y[0] * 255).round().to(torch.uint8)
    return y.permute(1, 2, 0).contiguous().cpu().numpy()
# UI: an input image, a direction selector, an output image and a button
# that calls run() with the selected direction.
with gr.Blocks(title="Q3: CycleGAN") as demo:
    gr.Markdown("# Q3: CycleGAN")
    inp = gr.Image(type="numpy", label="Input")
    direction = gr.Radio(
        ["Sketch -> Photo (G_AB)", "Photo -> Sketch (G_BA)"],
        value="Sketch -> Photo (G_AB)",
        label="Direction",
    )
    out = gr.Image(type="numpy", label="Output")
    btn = gr.Button("Translate")
    btn.click(run, [inp, direction], [out])

# Start the server only when executed as a script (e.g. `python app.py`),
# so that importing this module does not block on launching a web server.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)