import os

import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import gradio as gr
from huggingface_hub import hf_hub_download

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fetched from the Hugging Face Hub (or the local HF cache) at import time.
CKPT_PATH = hf_hub_download(repo_id="SotaSF/q3-model", filename="cyclegan_final.pt")

# Direction labels shown in the UI. run() dispatches on the "Sketch" prefix:
# labels starting with "Sketch" use G_AB, anything else uses G_BA.
DIRECTION_AB = "Sketch -> Photo (G_AB)"
DIRECTION_BA = "Photo -> Sketch (G_BA)"


class ResnetBlock(nn.Module):
    """Residual block: two reflection-padded 3x3 convs with InstanceNorm, plus identity skip.

    The attribute name ``b`` is part of the checkpoint's state-dict keys; do not rename it.
    """

    def __init__(self, c):
        super().__init__()
        self.b = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(c, c, 3),
            nn.InstanceNorm2d(c),
            nn.ReLU(True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(c, c, 3),
            nn.InstanceNorm2d(c),
        )

    def forward(self, x):
        return x + self.b(x)


class ResnetGenerator(nn.Module):
    """ResNet-style image-to-image generator.

    Layout: c7s1-64, d128, d256, ``n_blocks`` x R256, u128, u64, c7s1-3, Tanh,
    so outputs lie in [-1, 1]. Two stride-2 downsamplings mean spatial input
    sizes should be divisible by 4 to round-trip exactly.

    The attribute name ``m`` is part of the checkpoint's state-dict keys; do not rename it.
    """

    def __init__(self, in_c=3, out_c=3, n_blocks=6, base=64):
        super().__init__()
        # c7s1-64
        m = [nn.ReflectionPad2d(3), nn.Conv2d(in_c, base, 7), nn.InstanceNorm2d(base), nn.ReLU(True)]
        f = base
        # d128 and d256
        for _ in range(2):
            m += [nn.Conv2d(f, f * 2, 3, 2, 1), nn.InstanceNorm2d(f * 2), nn.ReLU(True)]
            f *= 2
        # R256
        for _ in range(n_blocks):
            m += [ResnetBlock(f)]
        # u128 and u64
        for _ in range(2):
            m += [
                nn.ConvTranspose2d(f, f // 2, 3, 2, 1, output_padding=1),
                nn.InstanceNorm2d(f // 2),
                nn.ReLU(True),
            ]
            f //= 2
        # c7s1-3 (f is back to `base` here)
        m += [nn.ReflectionPad2d(3), nn.Conv2d(base, out_c, 7), nn.Tanh()]
        self.m = nn.Sequential(*m)

    def forward(self, x):
        return self.m(x)


def clean_sd(sd):
    """Return a copy of state dict ``sd`` with any leading ``"module."`` stripped from keys.

    That prefix is what a model wrapped in ``nn.DataParallel`` typically saves under.
    """
    prefix = "module."
    return {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in sd.items()}


# Lazily-populated model globals. load_error holds the failure message (if any)
# so a failed load is not retried on every request.
g_ab = None
g_ba = None
load_error = None


def _load_generator(state_dict):
    """Build a ResnetGenerator on ``device``, load ``state_dict`` into it, and set eval mode.

    Raises:
        RuntimeError: if the checkpoint lacks any generator parameter; such a
            model would run with random weights and produce meaningless output.
    """
    net = ResnetGenerator(n_blocks=6).to(device)
    # strict=False tolerates extra (unexpected) keys, but missing keys mean
    # some weights stayed at their random initialisation, so refuse those.
    result = net.load_state_dict(clean_sd(state_dict), strict=False)
    if result.missing_keys:
        raise RuntimeError(
            f"checkpoint is missing {len(result.missing_keys)} generator parameter(s), "
            f"e.g. {result.missing_keys[:3]}"
        )
    net.eval()
    return net


def load_models():
    """Load both generators from CKPT_PATH once; record any failure in ``load_error``.

    Globals are only assigned after both generators load successfully, so a
    failure never leaves one half-initialised model behind.
    """
    global g_ab, g_ba, load_error
    if g_ab is not None or load_error is not None:
        return
    try:
        # NOTE(review): torch.load unpickles the file. On older torch versions
        # weights_only defaults to False, so this trusts the Hub repo's content;
        # consider weights_only=True if the installed torch supports it.
        sd = torch.load(CKPT_PATH, map_location=device)
        ab = _load_generator(sd["G_AB"])
        ba = _load_generator(sd["G_BA"])
    except Exception as e:
        load_error = f"{type(e).__name__}: {e}"
        return
    g_ab, g_ba = ab, ba


def preprocess(inp, size=128):
    """Convert an image array to a (1, 3, size, size) float tensor scaled to [-1, 1].

    The image is cast to uint8, converted to RGB (which drops any alpha channel),
    and resized to size x size without preserving aspect ratio.
    """
    img = Image.fromarray(inp.astype(np.uint8)).convert("RGB").resize((size, size))
    arr = np.asarray(img).astype(np.float32) / 255.0
    arr = (arr - 0.5) / 0.5
    arr = np.transpose(arr, (2, 0, 1))  # HWC -> CHW
    return torch.from_numpy(arr).unsqueeze(0)


@torch.no_grad()
def run(inp, direction=DIRECTION_AB):
    """Translate ``inp`` with G_AB or G_BA and return a 128x128x3 uint8 image.

    Returns None when no input image is given.

    Raises:
        RuntimeError: if the models failed to load.
    """
    if inp is None:
        return None
    load_models()
    if load_error is not None:
        raise RuntimeError("Model load failed: " + load_error)
    x = preprocess(inp, 128).to(device)
    use_ab = (direction or DIRECTION_AB).startswith("Sketch")
    y = (g_ab if use_ab else g_ba)(x)
    # Tanh output in [-1, 1] -> [0, 1]; round rather than truncate when quantising.
    y = (y * 0.5 + 0.5).clamp(0, 1)
    return (y[0].permute(1, 2, 0) * 255).round().to(torch.uint8).cpu().numpy()


with gr.Blocks(title="Q3: CycleGAN") as demo:
    gr.Markdown("# Q3: CycleGAN")
    inp = gr.Image(type="numpy", label="Input")
    direction = gr.Radio([DIRECTION_AB, DIRECTION_BA], value=DIRECTION_AB, label="Direction")
    out = gr.Image(type="numpy", label="Output")
    btn = gr.Button("Translate")
    btn.click(run, [inp, direction], [out])

demo.launch(server_name="0.0.0.0", server_port=7860)