import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont
from torchvision import transforms

# Names of the maps read from the model's output mapping, in the channel
# order used when they are concatenated for the reconstruction pass.
_MAP_NAMES = ("basecolor", "normal", "rmd")

# Colours used by the comparison sheet.
_CANVAS_BG = (35, 35, 35)
_MAP_BAR_COLOR = (50, 160, 80)


def resize_pad(img_rgb: Image.Image, image_size: int) -> Image.Image:
    """Return a square ``image_size`` x ``image_size`` RGB version of an image.

    The image is scaled (bicubic) so that it *covers* the target square and
    is then center-cropped; despite the name, no padding is added unless
    rounding leaves the resized image a pixel short of the crop box.

    Args:
        img_rgb: Source image; converted to RGB first.
        image_size: Side length of the square result.

    Returns:
        The resized and cropped image, or an all-black square when the
        source has zero width or height.
    """
    img_rgb = img_rgb.convert("RGB")
    w, h = img_rgb.size
    size = int(image_size)
    if w == 0 or h == 0:
        return Image.new("RGB", (size, size), (0, 0, 0))

    # max() picks the factor that makes the *shorter* side reach `size`.
    scale = max(size / float(w), size / float(h))
    new_w = max(1, int(round(w * scale)))
    new_h = max(1, int(round(h * scale)))
    img_r = img_rgb.resize((new_w, new_h), resample=Image.BICUBIC)

    left = (new_w - size) // 2
    top = (new_h - size) // 2
    return img_r.crop((left, top, left + size, top + size))


def cond_tensor_from_pil(img_rgb: Image.Image, device: torch.device) -> torch.Tensor:
    """Convert a PIL image into a batched conditioning tensor on ``device``.

    The image goes through ``transforms.ToTensor()``, gains a leading batch
    dimension, and is rescaled with ``t * 2 - 1``.
    """
    t = transforms.ToTensor()(img_rgb).unsqueeze(0).to(device)
    return t * 2.0 - 1.0


def _tensor_to_unit_np(t: torch.Tensor) -> np.ndarray:
    """Map the first batch item from [-1, 1] to a clamped [0, 1] HWC array."""
    t = ((t + 1.0) / 2.0).clamp(0, 1)
    return t[0].detach().cpu().permute(1, 2, 0).numpy()


def _unit_np_to_pil(arr: np.ndarray) -> Image.Image:
    """Quantise a [0, 1] HWC float array to an 8-bit PIL image (round-half-up)."""
    return Image.fromarray((arr * 255.0 + 0.5).astype("uint8"))


def full_inference(
    model,
    img_rgb: Image.Image,
    image_size: int,
    device: torch.device,
    num_passes: int = 5,
    noise_std: float = 0.01,
):
    """Run the model on a whole (resized) image with multi-pass median merging.

    The input is squared with ``resize_pad``, then the model is evaluated
    ``num_passes`` times on the input plus Gaussian noise of standard
    deviation ``noise_std``. Each of the 'basecolor', 'normal' and 'rmd'
    predictions is merged with an element-wise median across passes. The
    merged maps are concatenated along dim 1 and fed back through
    ``model(..., mode=1)`` to obtain the 'rgb' output.

    Args:
        model: Callable returning a mapping with 'basecolor', 'normal' and
            'rmd' tensors, and an 'rgb' tensor when called with ``mode=1``.
        img_rgb: Source image.
        image_size: Side length the image is resized/cropped to.
        device: Device the conditioning tensor is moved to.
        num_passes: Number of noisy forward passes; values below 1 are
            treated as 1 so there is always something to merge.
        noise_std: Noise standard deviation; no noise is added when <= 0.

    Returns:
        ``(resized_input, outputs)`` where ``outputs`` maps 'basecolor',
        'normal', 'rmd' and 'rgb' to PIL images.
    """
    img_rgb = resize_pad(img_rgb, int(image_size))
    x = cond_tensor_from_pil(img_rgb, device)

    passes = max(1, int(num_passes))
    stacks = {k: [] for k in _MAP_NAMES}

    # Everything here is pure inference: keep it all under no_grad so the
    # repeated passes do not each retain an autograd graph.
    with torch.no_grad():
        for _ in range(passes):
            noise = torch.randn_like(x) * noise_std if noise_std > 0 else 0
            preds = model(x + noise)
            for k in _MAP_NAMES:
                stacks[k].append(preds[k])

        merged = {
            k: torch.median(torch.stack(stacks[k]), dim=0).values
            for k in _MAP_NAMES
        }
        inv_input = torch.cat([merged[k] for k in _MAP_NAMES], dim=1)
        merged["rgb"] = model(inv_input, mode=1)["rgb"]

    outputs = {k: _unit_np_to_pil(_tensor_to_unit_np(v)) for k, v in merged.items()}
    return img_rgb, outputs


def _tile_padding(length: int, tile: int, stride: int) -> int:
    """Return the padding that lets tiles of ``tile`` px step exactly to the edge.

    The padded length is at least ``tile`` and satisfies
    ``(padded - tile) % stride == 0``.
    """
    if length <= tile:
        # The modular formula alone can leave the padded size below `tile`
        # when the image is much smaller than the tile.
        return tile - length
    return (tile - (length % stride)) % stride


def _blend_ramp(tile: int, overlap: int, fade_in: bool, fade_out: bool) -> np.ndarray:
    """Build a 1-D blending weight of length ``tile``.

    The weight is 1 everywhere except for a linear ramp over the first
    ``overlap`` samples (``fade_in``) and/or the last ``overlap`` samples
    (``fade_out``). If the two ramps overlap, the trailing one wins.
    """
    ramp = np.ones((tile,), dtype=np.float32)
    if overlap > 0:
        if fade_in:
            ramp[:overlap] = np.linspace(
                0.0, 1.0, overlap, endpoint=False, dtype=np.float32
            )
        if fade_out:
            ramp[-overlap:] = np.linspace(
                1.0, 0.0, overlap, endpoint=False, dtype=np.float32
            )
    return ramp


def tiled_inference(
    model,
    img_rgb: Image.Image,
    tile: int,
    device: torch.device,
    overlap: int = 16,
):
    """Run the model tile by tile and blend the overlapping predictions.

    The image is zero-padded on the right/bottom so square tiles of side
    ``tile`` placed every ``tile - overlap`` pixels end exactly on the
    padded border. Each tile's 'basecolor', 'normal' and 'rmd' predictions
    are accumulated with linear cross-fade weights in the overlap regions
    and normalised by the summed weights. Padding is cropped off the result.

    Args:
        model: Callable returning a mapping with 'basecolor', 'normal' and
            'rmd' tensors for a conditioning tensor.
        img_rgb: Source image; converted to RGB.
        tile: Tile side length in pixels.
        device: Device each tile tensor is moved to.
        overlap: Overlap between neighbouring tiles; clamped to
            ``[0, tile - 1]``.

    Returns:
        ``(img_rgb, outputs)`` where ``img_rgb`` is the RGB-converted input
        and ``outputs`` maps 'basecolor', 'normal' and 'rmd' to PIL images
        of the original size.

    Raises:
        ValueError: If ``tile`` is smaller than 1.
    """
    if tile < 1:
        raise ValueError(f"tile must be >= 1, got {tile!r}")

    img_rgb = img_rgb.convert("RGB")
    overlap = min(max(int(overlap), 0), tile - 1)
    stride = max(1, tile - overlap)

    w, h = img_rgb.size
    pad_w = _tile_padding(w, tile, stride)
    pad_h = _tile_padding(h, tile, stride)
    if pad_w or pad_h:
        src_padded = Image.new("RGB", (w + pad_w, h + pad_h), (0, 0, 0))
        src_padded.paste(img_rgb, (0, 0))
    else:
        src_padded = img_rgb
    pw, ph = src_padded.size

    acc = {k: np.zeros((ph, pw, 3), dtype=np.float32) for k in _MAP_NAMES}
    wsum = np.zeros((ph, pw, 1), dtype=np.float32)

    # The padding guarantees the last offset in each list is exactly
    # `padded_size - tile`, so no extra edge tile is needed.
    xs = list(range(0, pw - tile + 1, stride))
    ys = list(range(0, ph - tile + 1, stride))

    # Horizontal ramps depend only on the column, so build them once.
    ramps_x = [_blend_ramp(tile, overlap, left > 0, left + tile < pw) for left in xs]

    for top in ys:
        ramp_y = _blend_ramp(tile, overlap, top > 0, top + tile < ph)
        for left, ramp_x in zip(xs, ramps_x):
            patch_img = src_padded.crop((left, top, left + tile, top + tile))
            cond = cond_tensor_from_pil(patch_img, device)
            with torch.no_grad():
                preds = model(cond)

            weight = (ramp_y[:, None] * ramp_x[None, :])[:, :, None]
            for k in _MAP_NAMES:
                np_pred = _tensor_to_unit_np(preds[k])
                acc[k][top : top + tile, left : left + tile, :] += np_pred * weight
            wsum[top : top + tile, left : left + tile, :] += weight

    # Guard the division in case some pixel ended up with zero total weight.
    norm = np.maximum(wsum, 1e-8)
    outputs = {
        k: _unit_np_to_pil(np.clip(acc[k] / norm, 0.0, 1.0)) for k in _MAP_NAMES
    }

    if pad_w or pad_h:
        for k in _MAP_NAMES:
            outputs[k] = outputs[k].crop((0, 0, w, h))

    return img_rgb, outputs


def _draw_label(img: Image.Image, label: str, bar_color=(0, 0, 0)) -> Image.Image:
    """Draw a filled bar across the top of ``img`` with ``label`` in white.

    ``img`` is modified in place and also returned. Arial at size 18 is
    used when it can be loaded; otherwise PIL's default font.
    """
    draw = ImageDraw.Draw(img)
    try:
        font = ImageFont.truetype("arial.ttf", 18)
    except OSError:
        font = ImageFont.load_default()
    draw.rectangle((0, 0, img.width, 24), fill=bar_color)
    draw.text((4, 2), label, fill=(255, 255, 255), font=font)
    return img


def _draw_arrow(img: Image.Image, color=(180, 180, 180)) -> Image.Image:
    """Draw a right-pointing arrow across the vertical middle of ``img``.

    ``img`` is modified in place and also returned.
    """
    draw = ImageDraw.Draw(img)
    cy = img.height // 2
    r = 8  # half-height (and length) of the arrow head
    draw.line((0, cy, img.width - r, cy), fill=color, width=3)
    draw.polygon(
        [(img.width - r, cy - r), (img.width - r, cy + r), (img.width, cy)],
        fill=color,
    )
    return img


def make_side_by_side(inp_img: Image.Image, outputs: dict) -> Image.Image:
    """Compose an "input -> predicted maps -> reconstruction" comparison sheet.

    Layout, left to right: the input, an arrow, a 3-column grid holding
    BASECOLOR, NORMAL, DEPTH, ROUGHNESS and METALLIC panels, an arrow, and
    the reconstructed RGB. The 'rmd' image is split into its channels:
    the first is labelled ROUGHNESS, the second METALLIC, the third DEPTH.

    Args:
        inp_img: The model input image.
        outputs: Mapping that must contain 'basecolor', 'normal', 'rmd' and
            'rgb' PIL images (a missing key raises ``KeyError``).

    Returns:
        The composed RGB canvas.
    """
    inp_img = inp_img.convert("RGB")
    outputs = {k: v.convert("RGB") for k, v in outputs.items()}
    r, g, b = outputs["rmd"].split()

    cell_size = 200
    arrow_w = 48
    gap = 8
    grid_w = cell_size * 3 + gap * 2
    total_w = cell_size + arrow_w + grid_w + arrow_w + cell_size
    total_h = cell_size * 2 + gap
    centered_y = (total_h - cell_size) // 2

    canvas = Image.new("RGB", (total_w, total_h), _CANVAS_BG)

    def panel(img: Image.Image, label: str, color) -> Image.Image:
        # resize() returns a new image, so labelling never touches the source.
        return _draw_label(
            img.resize((cell_size, cell_size), Image.BICUBIC), label, color
        )

    def arrow() -> Image.Image:
        return _draw_arrow(Image.new("RGB", (arrow_w, total_h), _CANVAS_BG))

    cx = 0
    canvas.paste(panel(inp_img, "INPUT", (50, 100, 200)), (cx, centered_y))
    cx += cell_size

    canvas.paste(arrow(), (cx, 0))
    cx += arrow_w

    ordered = [
        ("BASECOLOR", outputs["basecolor"]),
        ("NORMAL", outputs["normal"]),
        ("DEPTH", b.convert("RGB")),
        ("ROUGHNESS", r.convert("RGB")),
        ("METALLIC", g.convert("RGB")),
    ]
    for i, (label, img) in enumerate(ordered):
        row, col = divmod(i, 3)
        px = cx + col * (cell_size + gap)
        py = row * (cell_size + gap)
        canvas.paste(panel(img, label, _MAP_BAR_COLOR), (px, py))
    cx += grid_w

    canvas.paste(arrow(), (cx, 0))
    cx += arrow_w

    canvas.paste(panel(outputs["rgb"], "RECON RGB", (200, 120, 50)), (cx, centered_y))
    return canvas