| import os |
| import numpy as np |
| from PIL import Image |
| import torch |
| import torch.nn as nn |
| import gradio as gr |
| from huggingface_hub import hf_hub_download |
|
|
# Run on the GPU when PyTorch reports one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Local path of the checkpoint file fetched from the Hugging Face Hub repo.
# NOTE: this runs at import time, so importing the module performs the fetch.
CKPT_PATH = hf_hub_download(
    filename="cyclegan_final.pt",
    repo_id="SotaSF/q3-model",
)
|
|
class ResnetBlock(nn.Module):
    """Residual block: two reflection-padded 3x3 convs with instance norm.

    The block maps ``c`` channels to ``c`` channels and adds its input back
    onto the result (skip connection).

    Args:
        c: Number of input and output channels.
    """

    def __init__(self, c):
        super().__init__()
        # The position of each layer inside the Sequential determines the
        # state-dict keys (the convs live at ``b.1`` and ``b.5``), so the
        # order below must not change.
        layers = []
        for is_second_half in (False, True):
            layers.append(nn.ReflectionPad2d(1))
            layers.append(nn.Conv2d(c, c, kernel_size=3))
            layers.append(nn.InstanceNorm2d(c))
            if not is_second_half:
                # Only the first conv is followed by an activation.
                layers.append(nn.ReLU(inplace=True))
        self.b = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` plus the conv branch evaluated on ``x``."""
        residual = self.b(x)
        return x + residual
|
|
class ResnetGenerator(nn.Module):
    """ResNet-style image-to-image generator.

    Layout: a 7x7 stem, two stride-2 downsampling convs, ``n_blocks``
    residual blocks, two stride-2 transposed convs, and a 7x7 output conv
    followed by ``Tanh`` (so outputs lie in [-1, 1]).

    Args:
        in_c: Channels of the input image.
        out_c: Channels of the generated image.
        n_blocks: Number of ``ResnetBlock`` modules in the bottleneck.
        base: Channel width after the stem; doubled by each downsampling step.
    """

    def __init__(self, in_c=3, out_c=3, n_blocks=6, base=64):
        super().__init__()

        # Stem: reflection padding keeps the spatial size through the 7x7 conv.
        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_c, base, kernel_size=7),
            nn.InstanceNorm2d(base),
            nn.ReLU(inplace=True),
        ]

        # Encoder: each step halves the resolution and doubles the width.
        width = base
        for _ in range(2):
            layers.extend([
                nn.Conv2d(width, width * 2, kernel_size=3, stride=2, padding=1),
                nn.InstanceNorm2d(width * 2),
                nn.ReLU(inplace=True),
            ])
            width *= 2

        # Bottleneck: residual blocks at the widest channel count.
        layers.extend(ResnetBlock(width) for _ in range(n_blocks))

        # Decoder: mirror of the encoder using transposed convolutions.
        for _ in range(2):
            layers.extend([
                nn.ConvTranspose2d(
                    width, width // 2,
                    kernel_size=3, stride=2, padding=1, output_padding=1,
                ),
                nn.InstanceNorm2d(width // 2),
                nn.ReLU(inplace=True),
            ])
            width //= 2

        # Output head: project back to image channels and squash with Tanh.
        layers.extend([
            nn.ReflectionPad2d(3),
            nn.Conv2d(base, out_c, kernel_size=7),
            nn.Tanh(),
        ])

        # Attribute name and layer order define the state-dict keys.
        self.m = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the full layer stack and return the result."""
        return self.m(x)
|
|
def clean_sd(sd):
    """Return a copy of state dict ``sd`` with one leading ``"module."`` stripped from keys.

    Keys that do not start with ``"module."`` are copied unchanged; values are
    passed through as-is and the input mapping is not modified.
    """
    prefix = "module."
    return {
        (name[len(prefix):] if name.startswith(prefix) else name): tensor
        for name, tensor in sd.items()
    }
|
|
# Module-level cache filled in by load_models(): the two generators, plus the
# failure message if loading raised. All three start out unset.
g_ab = g_ba = None
load_error = None
|
|
def load_models():
    """Load both generators from the checkpoint into the module-level cache.

    The work is done at most once: if ``g_ab`` is already set, or a previous
    attempt recorded ``load_error``, the call returns immediately.

    On success ``g_ab`` and ``g_ba`` are set together to generators in eval
    mode on ``device``. On any failure neither global is assigned and
    ``load_error`` is set to ``"<ExceptionType>: <message>"``; the exception
    itself is not propagated.
    """
    global g_ab, g_ba, load_error
    if g_ab is not None or load_error is not None:
        return
    try:
        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code. The checkpoint is downloaded from the Hub, so
        # consider weights_only=True if the supported torch versions and the
        # checkpoint contents allow it.
        sd = torch.load(CKPT_PATH, map_location=device)

        loaded = []
        for key in ("G_AB", "G_BA"):
            net = ResnetGenerator(n_blocks=6).to(device)
            # strict=False tolerates extra keys in the checkpoint, but a
            # *missing* key would leave that layer at its random
            # initialisation and silently produce garbage images.
            result = net.load_state_dict(clean_sd(sd[key]), strict=False)
            missing = list(result.missing_keys)
            if missing:
                raise RuntimeError(
                    f"checkpoint entry {key!r} lacks {len(missing)} generator "
                    f"weight(s), e.g. {missing[0]!r}"
                )
            net.eval()
            loaded.append(net)
    except Exception as e:
        # Keep the exception type: str(KeyError('G_AB')) alone is just "'G_AB'".
        load_error = f"{type(e).__name__}: {e}"
    else:
        # Publish both models only after everything succeeded, so the cache is
        # never left half-initialised.
        g_ab, g_ba = loaded
|
|
def preprocess(inp, size=128):
    """Convert an image array into a normalised batch tensor for the generators.

    Args:
        inp: Image as a NumPy array; it is cast to ``uint8`` and converted to
            RGB, so the channel count of the input does not carry through.
        size: Side length of the square the image is resized to.

    Returns:
        A float32 tensor of shape ``(1, 3, size, size)`` with values in
        ``[-1, 1]``.
    """
    pil_img = Image.fromarray(inp.astype(np.uint8))
    pil_img = pil_img.convert("RGB")
    pil_img = pil_img.resize((size, size))

    # [0, 255] -> [0, 1] -> [-1, 1]
    pixels = np.asarray(pil_img).astype(np.float32) / 255.0
    pixels = (pixels - 0.5) / 0.5

    # HWC -> CHW, then add the leading batch dimension.
    chw = np.transpose(pixels, (2, 0, 1))
    return torch.from_numpy(chw).unsqueeze(0)
|
|
@torch.no_grad()
def run(inp, direction='Sketch -> Photo (G_AB)'):
    """Translate one image with the generator selected by ``direction``.

    Args:
        inp: Input image as a NumPy array, or ``None`` when nothing was given.
        direction: Label of the translation direction; labels starting with
            ``"Sketch"`` use ``g_ab``, anything else uses ``g_ba``.

    Returns:
        ``None`` if ``inp`` is ``None``; otherwise a ``uint8`` array of shape
        ``(128, 128, 3)`` holding the generated image.

    Raises:
        RuntimeError: If loading the models failed.
    """
    # Loading is attempted first, even for an empty input.
    load_models()

    if inp is None:
        return None

    if load_error is not None:
        raise RuntimeError("Model load failed: " + load_error)

    batch = preprocess(inp, 128).to(device)

    if direction.startswith("Sketch"):
        generated = g_ab(batch)
    else:
        generated = g_ba(batch)

    # Map the Tanh range [-1, 1] back to [0, 1] before scaling to bytes.
    unit_range = (generated * 0.5 + 0.5).clamp(0, 1)
    hwc = unit_range[0].permute(1, 2, 0).cpu().numpy()
    return (hwc * 255).astype(np.uint8)
|
|
# Build the UI: one input image, a direction selector, one output image and a
# button that triggers the translation.
with gr.Blocks(title="Q3: CycleGAN") as demo:
    gr.Markdown("# Q3: CycleGAN")
    inp = gr.Image(label="Input", type="numpy")
    direction = gr.Radio(
        choices=["Sketch -> Photo (G_AB)", "Photo -> Sketch (G_BA)"],
        value="Sketch -> Photo (G_AB)",
        label="Direction",
    )
    out = gr.Image(label="Output", type="numpy")
    btn = gr.Button("Translate")
    btn.click(fn=run, inputs=[inp, direction], outputs=[out])


# Start the server on all interfaces, port 7860. This runs at import time.
demo.launch(
    server_name="0.0.0.0",
    server_port=7860,
)
|
|