import functools

import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont
from torchvision import transforms
|
|
|
|
def resize_pad(img_rgb: Image.Image, image_size: int) -> Image.Image:
    """Scale an image to cover an ``image_size`` square, then centre-crop it.

    Note: despite the name, no padding is added. The image is converted to
    RGB and resized (bicubic) by the larger of the two per-axis ratios, so
    neither side ends up shorter than ``image_size``. The centred
    ``image_size`` x ``image_size`` window is then cropped out. An input with
    zero width or height yields a black square of the requested size.
    """
    rgb = img_rgb.convert("RGB")
    width, height = rgb.size
    side = int(image_size)
    if not (width and height):
        return Image.new("RGB", (side, side), (0, 0, 0))

    # Larger ratio => "cover" scaling: the short side hits `side`, the other overshoots.
    factor = max(side / float(width), side / float(height))
    scaled_w = max(1, int(round(width * factor)))
    scaled_h = max(1, int(round(height * factor)))
    scaled = rgb.resize((scaled_w, scaled_h), resample=Image.BICUBIC)

    x0 = (scaled_w - side) // 2
    y0 = (scaled_h - side) // 2
    box = (x0, y0, x0 + side, y0 + side)
    return scaled.crop(box)
|
|
|
|
def cond_tensor_from_pil(img_rgb: Image.Image, device: torch.device) -> torch.Tensor:
    """Turn a PIL image into a single-item batch tensor on ``device`` in [-1, 1].

    ``transforms.ToTensor`` produces a channels-first float tensor in [0, 1].
    A leading batch axis is added, the tensor is moved to ``device``, and
    values are mapped affinely from [0, 1] to [-1, 1]. No colour conversion
    happens here; presumably callers pass an RGB image (both callers in this
    module do).
    """
    to_tensor = transforms.ToTensor()
    batch = to_tensor(img_rgb)[None].to(device)
    return batch * 2.0 - 1.0
|
|
|
|
def full_inference(model, img_rgb: Image.Image, image_size: int, device: torch.device, num_passes: int = 5, noise_std: float = 0.01):
    """Predict material maps with a noise-perturbed ensemble, then re-render RGB.

    The image is scaled and centre-cropped to ``image_size`` via ``resize_pad``.
    The model is then evaluated ``num_passes`` times on the conditioning tensor,
    each time with i.i.d. Gaussian noise of std ``noise_std`` added (no noise
    when ``noise_std <= 0``). For each of 'basecolor', 'normal' and 'rmd', the
    element-wise median over passes is kept. For an even number of passes,
    ``torch.median`` returns the lower of the two middle values. The three
    medians are concatenated along dim 1 and fed back through
    ``model(..., mode=1)`` to obtain an 'rgb' reconstruction.

    Args:
        model: callable returning a dict with at least 'basecolor', 'normal'
            and 'rmd' tensors. When called with ``mode=1`` it must return a
            dict with 'rgb'.
        img_rgb: source image.
        image_size: side length of the square crop fed to the model.
        device: device the conditioning tensor is placed on.
        num_passes: number of noisy forward passes; must be >= 1.
        noise_std: standard deviation of the input noise.

    Returns:
        ``(cropped_input, outputs)`` where ``outputs`` maps 'basecolor',
        'normal', 'rmd' and 'rgb' to PIL images. Tensors are mapped from
        [-1, 1] to [0, 255] and only the first batch item is kept.

    Raises:
        ValueError: if ``num_passes < 1``.
    """
    if num_passes < 1:
        raise ValueError(f"num_passes must be >= 1, got {num_passes}")

    img_rgb = resize_pad(img_rgb, int(image_size))
    x = cond_tensor_from_pil(img_rgb, device)

    map_names = ['basecolor', 'normal', 'rmd']

    # Pure inference: without no_grad every noisy pass would record an autograd
    # graph, and all num_passes sets of activations would stay alive until return.
    with torch.no_grad():
        stacks = {k: [] for k in map_names}
        for _ in range(num_passes):
            noise = torch.randn_like(x) * noise_std if noise_std > 0 else 0
            # `x + noise` always builds a fresh tensor, so the model never sees `x` itself.
            preds = model(x + noise)
            for k in map_names:
                stacks[k].append(preds[k])

        merged = {k: torch.median(torch.stack(stacks[k]), dim=0).values for k in map_names}

        # Inverse (re-render) pass: material maps concatenated along the channel axis.
        inv_input = torch.cat([merged['basecolor'], merged['normal'], merged['rmd']], dim=1)
        merged['rgb'] = model(inv_input, mode=1)['rgb']

    def to_pil(tensor):
        # [-1, 1] -> [0, 1], first batch item, CHW -> HWC, rounded to uint8.
        out = ((tensor + 1.0) / 2.0).clamp(0, 1)
        out_np = out[0].detach().cpu().permute(1, 2, 0).numpy()
        return Image.fromarray((out_np * 255.0 + 0.5).astype("uint8"))

    outputs = {k: to_pil(v) for k, v in merged.items()}
    return img_rgb, outputs
|
|
|
|
def _padded_extent(extent: int, tile: int, stride: int) -> int:
    """Return the smallest ``tile + k * stride`` (k >= 0) that is >= ``extent``.

    Padding an axis to this length lets tiles start at 0, stride, 2*stride, ...
    with the last tile ending exactly on the padded border. No tile origin is
    ever negative, even when the image is smaller than one tile.
    """
    extra = max(0, extent - tile)
    steps = -(-extra // stride)  # ceil(extra / stride) in integer arithmetic
    return tile + steps * stride


def _edge_ramp(tile: int, overlap: int, fade_in: bool, fade_out: bool) -> np.ndarray:
    """Build 1-D blending weights for one tile axis: ones, with linear fades at the edges.

    ``fade_in`` ramps the first ``overlap`` weights from 0 towards 1, and
    ``fade_out`` ramps the last ``overlap`` weights from 1 towards 0. Edges
    on the image border keep weight 1.
    """
    ramp = np.ones((tile,), dtype=np.float32)
    if overlap > 0:
        if fade_in:
            ramp[:overlap] = np.linspace(0.0, 1.0, overlap, endpoint=False, dtype=np.float32)
        if fade_out:
            ramp[-overlap:] = np.linspace(1.0, 0.0, overlap, endpoint=False, dtype=np.float32)
    return ramp


def _pred_to_unit_np(t: torch.Tensor) -> np.ndarray:
    """Map the first batch item of a [-1, 1] channels-first tensor to an HWC float array in [0, 1]."""
    t = ((t + 1.0) / 2.0).clamp(0, 1)
    return t[0].detach().cpu().permute(1, 2, 0).numpy()


def tiled_inference(model, img_rgb: Image.Image, tile: int, device: torch.device, overlap: int = 16):
    """Run ``model`` over overlapping ``tile`` x ``tile`` windows and blend the results.

    The RGB image is zero-padded on the right and bottom so that tiles placed
    every ``stride = tile - overlap`` pixels cover it exactly. The padded size
    is always at least ``tile``, so images smaller than one tile also work.
    For every tile, the predictions for 'basecolor', 'normal' and 'rmd' are
    mapped from [-1, 1] to [0, 1] and weighted by linear ramps over the
    ``overlap`` pixels shared with neighbouring tiles. The weighted maps are
    accumulated, divided by the accumulated weight, and cropped back to the
    original size.

    Args:
        model: callable returning a dict with 'basecolor', 'normal' and 'rmd'
            tensors. NOTE(review): assumes each is (1, 3, tile, tile), i.e.
            the same spatial size as its input; this is not checked here.
        img_rgb: source image; converted to RGB.
        tile: tile side length in pixels; must be >= 1.
        device: device the tile tensors are placed on.
        overlap: pixels shared by neighbouring tiles; clamped to [0, tile - 1].

    Returns:
        ``(rgb_input, outputs)`` where ``outputs`` maps 'basecolor', 'normal'
        and 'rmd' to PIL images with the input's width and height.

    Raises:
        ValueError: if ``tile < 1``.
    """
    img_rgb = img_rgb.convert("RGB")

    tile = int(tile)
    if tile < 1:
        raise ValueError(f"tile must be a positive integer, got {tile}")
    # Clamp so the stride between tile origins is always at least one pixel.
    overlap = min(max(int(overlap), 0), tile - 1)
    stride = tile - overlap

    w, h = img_rgb.size
    pw = _padded_extent(w, tile, stride)
    ph = _padded_extent(h, tile, stride)
    padded = (pw, ph) != (w, h)

    if padded:
        src_padded = Image.new("RGB", (pw, ph), (0, 0, 0))
        src_padded.paste(img_rgb, (0, 0))
    else:
        src_padded = img_rgb

    map_names = ['basecolor', 'normal', 'rmd']
    acc = {k: np.zeros((ph, pw, 3), dtype=np.float32) for k in map_names}
    wsum = np.zeros((ph, pw, 1), dtype=np.float32)

    # pw - tile and ph - tile are multiples of stride, so these origins cover the
    # padded image exactly and the last tile ends on the border.
    xs = range(0, pw - tile + 1, stride)
    ys = range(0, ph - tile + 1, stride)

    with torch.no_grad():
        for top in ys:
            ramp_y = _edge_ramp(tile, overlap, top > 0, top + tile < ph)
            for left in xs:
                ramp_x = _edge_ramp(tile, overlap, left > 0, left + tile < pw)
                weight = (ramp_y[:, None] * ramp_x[None, :])[:, :, None]

                patch_img = src_padded.crop((left, top, left + tile, top + tile))
                preds = model(cond_tensor_from_pil(patch_img, device))

                window = (slice(top, top + tile), slice(left, left + tile))
                for k in map_names:
                    acc[k][window] += _pred_to_unit_np(preds[k]) * weight
                # One weight contribution per tile, shared by all maps.
                wsum[window] += weight

    def acc_to_pil(out_np):
        # Normalise by the accumulated weight; the floor guards against division by zero.
        out_np = out_np / np.maximum(wsum, 1e-8)
        out_np = np.clip(out_np, 0.0, 1.0)
        return Image.fromarray((out_np * 255.0 + 0.5).astype("uint8"))

    outputs = {k: acc_to_pil(acc[k]) for k in map_names}

    if padded:
        outputs = {k: v.crop((0, 0, w, h)) for k, v in outputs.items()}

    return img_rgb, outputs
|
|
|
|
# `_draw_label` is defined once, later in this module (an identical copy here
# was always shadowed by that later definition and never executed).
|
|
|
|
# `_draw_arrow` is defined once, later in this module (an identical copy here
# was always shadowed by that later definition and never executed).
|
|
|
|
@functools.lru_cache(maxsize=8)
def _label_font(size: int = 18):
    """Return the label font, loading it at most once per size.

    Tries Arial first and falls back to Pillow's built-in default font.
    Caching matters because a failed ``truetype`` lookup makes Pillow search
    the system font directories, and labels are drawn many times per canvas.
    """
    try:
        return ImageFont.truetype("arial.ttf", size)
    except (OSError, ImportError):
        # OSError: font file not found. ImportError: Pillow built without FreeType.
        return ImageFont.load_default()


def _draw_label(img: Image.Image, label: str, bar_color=(0, 0, 0)) -> Image.Image:
    """Draw a 24-px filled title bar with white ``label`` text across the top of ``img``.

    ``img`` is modified in place and also returned, for chaining.
    """
    draw = ImageDraw.Draw(img)
    draw.rectangle((0, 0, img.width, 24), fill=bar_color)
    draw.text((4, 2), label, fill=(255, 255, 255), font=_label_font(18))
    return img
|
|
|
|
def _draw_arrow(img: Image.Image, color=(180, 180, 180)) -> Image.Image:
    """Draw a right-pointing horizontal arrow through the vertical centre of ``img``.

    The 3-px shaft runs from the left edge to the base of a triangular head.
    The head is 8 px long and has a half-height of 8 px, and its tip sits on
    the right edge. ``img`` is modified in place and also returned.
    """
    draw = ImageDraw.Draw(img)
    cy = img.height // 2
    head = 8  # arrow-head length and half-height, in pixels
    base_x = img.width - head
    draw.line((0, cy, base_x, cy), fill=color, width=3)
    draw.polygon([(base_x, cy - head), (base_x, cy + head), (img.width, cy)], fill=color)
    return img
|
|
|
|
def make_side_by_side(inp_img: Image.Image, outputs: dict, *, cell_size: int = 200) -> Image.Image:
    """Compose an "INPUT -> material maps -> RECON RGB" overview image.

    Layout, left to right:
      * the input image;
      * an arrow;
      * a 3-column grid holding BASECOLOR and NORMAL, followed by the three
        channels of ``outputs['rmd']`` (B as DEPTH, R as ROUGHNESS, G as
        METALLIC), each shown as greyscale;
      * another arrow;
      * ``outputs['rgb']``.

    ``outputs`` must contain 'basecolor', 'normal' and 'rmd'. 'rgb' is
    optional: ``tiled_inference`` does not produce it, and a blank labelled
    placeholder is drawn in its place. Every panel is resized (bicubic) to
    ``cell_size`` x ``cell_size``, so non-square images are stretched.

    Args:
        inp_img: the input image shown in the first stage.
        outputs: mapping of map name to PIL image.
        cell_size: side length of each panel in pixels (keyword-only).

    Returns:
        A new RGB canvas image.
    """
    bg = (35, 35, 35)
    arrow_w = 48
    gap = 8
    map_color = (50, 160, 80)

    rgb_outputs = {k: v.convert("RGB") for k, v in outputs.items()}
    r, g, b = rgb_outputs['rmd'].split()

    grid_w = cell_size * 3 + gap * 2
    total_w = cell_size + arrow_w + grid_w + arrow_w + cell_size
    total_h = cell_size * 2 + gap
    mid_y = (total_h - cell_size) // 2  # vertical offset that centres single-cell stages

    canvas = Image.new("RGB", (total_w, total_h), bg)

    def labelled(img, label, color):
        # A cell-sized bicubic copy of `img` with a title bar stamped on top.
        return _draw_label(img.resize((cell_size, cell_size), Image.BICUBIC), label, color)

    def paste_arrow(x):
        canvas.paste(_draw_arrow(Image.new("RGB", (arrow_w, total_h), bg)), (x, 0))

    cx = 0

    # Stage 1: input.
    canvas.paste(labelled(inp_img.convert("RGB"), "INPUT", (50, 100, 200)), (cx, mid_y))
    cx += cell_size

    paste_arrow(cx)
    cx += arrow_w

    # Stage 2: predicted maps, 3 per row.
    ordered = [
        ("BASECOLOR", rgb_outputs['basecolor']),
        ("NORMAL", rgb_outputs['normal']),
        ("DEPTH", b.convert("RGB")),
        ("ROUGHNESS", r.convert("RGB")),
        ("METALLIC", g.convert("RGB")),
    ]
    for i, (label, img) in enumerate(ordered):
        row, col = divmod(i, 3)
        px = cx + col * (cell_size + gap)
        py = row * (cell_size + gap)
        canvas.paste(labelled(img, label, map_color), (px, py))
    cx += grid_w

    paste_arrow(cx)
    cx += arrow_w

    # Stage 3: reconstruction, or a placeholder when none was produced.
    recon = rgb_outputs.get('rgb')
    if recon is not None:
        stage3 = labelled(recon, "RECON RGB", (200, 120, 50))
    else:
        blank = Image.new("RGB", (cell_size, cell_size), (60, 60, 60))
        stage3 = _draw_label(blank, "RECON RGB N/A", (200, 120, 50))
    canvas.paste(stage3, (cx, mid_y))

    return canvas
|
|