# ShadeNet / inference_utils.py
# Author: singam96 · commit 39a275f ("revert") · 8.23 kB
import functools

import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont
from torchvision import transforms
def resize_pad(img_rgb: Image.Image, image_size: int) -> Image.Image:
    """Return an ``image_size`` x ``image_size`` RGB square cut from *img_rgb*.

    Despite the name, no padding is added. The image is scaled (bicubic) by the
    larger of the two per-axis factors so that it covers the whole square, and
    the overflow is then trimmed with a centred crop. An input with zero width
    or height yields an all-black square.
    """
    rgb = img_rgb.convert("RGB")
    side = int(image_size)
    src_w, src_h = rgb.size
    if 0 in (src_w, src_h):
        return Image.new("RGB", (side, side), (0, 0, 0))

    # "Cover" scaling: the shorter side lands on `side`, the longer one overflows.
    factor = max(side / float(src_w), side / float(src_h))
    scaled_w, scaled_h = (max(1, int(round(dim * factor))) for dim in (src_w, src_h))
    scaled = rgb.resize((scaled_w, scaled_h), resample=Image.BICUBIC)

    x0 = (scaled_w - side) // 2
    y0 = (scaled_h - side) // 2
    return scaled.crop((x0, y0, x0 + side, y0 + side))
def cond_tensor_from_pil(img_rgb: Image.Image, device: torch.device) -> torch.Tensor:
    """Turn a PIL image into a single-item batch tensor on *device*.

    ``transforms.ToTensor`` produces a ``(C, H, W)`` float tensor (in [0, 1]
    for 8-bit images). A leading batch axis is added and values are remapped
    linearly with ``v * 2 - 1``, i.e. [0, 1] -> [-1, 1].
    """
    to_tensor = transforms.ToTensor()
    batch = to_tensor(img_rgb)[None].to(device)
    return 2.0 * batch - 1.0
def _tensor_to_pil(tensor: torch.Tensor) -> Image.Image:
    """Convert the first item of a batched tensor to an 8-bit PIL image.

    Values are mapped from [-1, 1] to [0, 255] (clamped, rounded half up). The
    item is permuted from ``(C, H, W)`` to ``(H, W, C)``; presumably C == 3
    so that PIL builds an RGB image — TODO confirm for every model head.
    """
    out = ((tensor + 1.0) / 2.0).clamp(0, 1)
    out_np = out[0].detach().cpu().permute(1, 2, 0).numpy()
    return Image.fromarray((out_np * 255.0 + 0.5).astype("uint8"))


def full_inference(model, img_rgb: Image.Image, image_size: int, device: torch.device, num_passes: int = 5, noise_std: float = 0.01):
    """Predict material maps for a whole image using a noisy test-time ensemble.

    Steps:

    1. The image is squared to ``image_size`` pixels via ``resize_pad`` and
       converted with ``cond_tensor_from_pil``.
    2. It is fed to ``model`` ``num_passes`` times. Each pass adds fresh
       Gaussian noise with standard deviation ``noise_std`` to the input; no
       noise is added when ``noise_std <= 0``.
    3. The ``'basecolor'``, ``'normal'`` and ``'rmd'`` predictions are merged
       with an element-wise median across passes. For an even number of
       passes ``torch.median`` returns the lower of the two middle values.
    4. The merged maps are concatenated along the channel axis and passed back
       through ``model(..., mode=1)`` to obtain an ``'rgb'`` reconstruction.

    All forward passes run under ``torch.no_grad()``. This block does not call
    ``model.eval()``; that is left to the caller.

    Args:
        model: callable returning a dict of tensors. ``model(x)`` must provide
            the three map keys, and ``model(x, mode=1)`` an ``'rgb'`` key.
        img_rgb: input image.
        image_size: edge length of the square the image is resized/cropped to.
        device: device the input tensor is moved to.
        num_passes: number of ensemble passes; values below 1 are treated as 1.
        noise_std: standard deviation of the input noise per pass.

    Returns:
        ``(img_rgb, outputs)``: the squared input image and a dict of PIL
        images keyed ``'basecolor'``, ``'normal'``, ``'rmd'`` and ``'rgb'``.
    """
    img_rgb = resize_pad(img_rgb, int(image_size))
    x = cond_tensor_from_pil(img_rgb, device)
    map_names = ("basecolor", "normal", "rmd")
    # torch.stack([]) would fail for zero passes; always run at least one.
    num_passes = max(1, int(num_passes))

    # Inference only: without no_grad every ensemble pass would build and keep
    # its autograd graph alive until the function returns.
    with torch.no_grad():
        stacks = {k: [] for k in map_names}
        for _ in range(num_passes):
            noisy = x + torch.randn_like(x) * noise_std if noise_std > 0 else x
            preds = model(noisy)
            for k in map_names:
                stacks[k].append(preds[k])

        merged = {k: torch.median(torch.stack(stacks[k]), dim=0).values for k in map_names}
        inv_input = torch.cat([merged[k] for k in map_names], dim=1)
        merged["rgb"] = model(inv_input, mode=1)["rgb"]

    outputs = {k: _tensor_to_pil(v) for k, v in merged.items()}
    return img_rgb, outputs
def _grid_pad(length: int, tile: int, stride: int) -> int:
    """Return the padding that makes ``length + pad - tile`` a non-negative multiple of *stride*."""
    pad = (tile - length) % stride
    # Without this, sides no longer than the overlap stay smaller than one tile,
    # which produced a negative tile origin and a broadcast error downstream.
    if length + pad < tile:
        pad = tile - length
    return pad


def _tile_origins(total: int, tile: int, stride: int) -> list:
    """Return tile start offsets along one axis; the last tile ends exactly at *total* (>= tile)."""
    origins = list(range(0, total - tile + 1, stride))
    if origins[-1] != total - tile:
        origins.append(total - tile)
    return origins


def _edge_ramp(start: int, total: int, tile: int, overlap: int) -> np.ndarray:
    """Return 1-D blend weights for a tile starting at *start* on an axis of length *total*.

    Weights are 1, except on edges that border a neighbouring tile. There the
    first/last *overlap* entries ramp linearly: 0 -> 1 on the leading edge and
    1 -> 0 on the trailing edge (``endpoint=False``). Results are later
    normalised by the summed weights.
    """
    ramp = np.ones((tile,), dtype=np.float32)
    if overlap > 0:
        if start > 0:
            ramp[:overlap] = np.linspace(0.0, 1.0, overlap, endpoint=False, dtype=np.float32)
        if start + tile < total:
            ramp[-overlap:] = np.linspace(1.0, 0.0, overlap, endpoint=False, dtype=np.float32)
    return ramp


def _tensor_to_unit_array(t: torch.Tensor) -> np.ndarray:
    """Map the first item of a batched tensor from [-1, 1] to a clamped [0, 1] ``(H, W, C)`` array."""
    t = ((t + 1.0) / 2.0).clamp(0, 1)
    return t[0].detach().cpu().permute(1, 2, 0).numpy()


def tiled_inference(model, img_rgb: Image.Image, tile: int, device: torch.device, overlap: int = 16):
    """Predict material maps for an arbitrarily sized image by blending tiles.

    The procedure:

    1. Pad the image with black on the right and bottom so that ``tile``-sized
       windows placed every ``stride = tile - overlap`` pixels cover it
       exactly, and so that it spans at least one tile in each direction.
    2. Run each window through ``model`` once.
    3. Blend the per-tile predictions with linear ramps over the shared
       borders, then normalise by the summed weights.
    4. Crop the padding away.

    Unlike ``full_inference`` there is no noise ensemble and no ``'rgb'``
    reconstruction: only ``'basecolor'``, ``'normal'`` and ``'rmd'`` are returned.

    Args:
        model: callable mapping a conditioning batch to a dict with the three
            map keys. Each prediction is assumed to be a ``(1, 3, tile, tile)``
            tensor in [-1, 1] — TODO confirm against the model.
        img_rgb: input image (converted to RGB).
        tile: tile edge length in pixels; must be >= 1.
        device: device the tile tensors are moved to.
        overlap: overlap between neighbouring tiles, clamped to ``[0, tile - 1]``.

    Returns:
        ``(img_rgb, outputs)``: the RGB-converted input and a dict of PIL
        images with the same size as the input.

    Raises:
        ValueError: if ``tile`` is smaller than 1.
    """
    img_rgb = img_rgb.convert("RGB")
    tile = int(tile)
    if tile < 1:
        raise ValueError(f"tile must be a positive integer, got {tile}")
    overlap = min(max(int(overlap), 0), tile - 1)
    stride = tile - overlap  # >= 1 thanks to the clamp above

    w, h = img_rgb.size
    pad_w = _grid_pad(w, tile, stride)
    pad_h = _grid_pad(h, tile, stride)
    if pad_w or pad_h:
        src_padded = Image.new("RGB", (w + pad_w, h + pad_h), (0, 0, 0))
        src_padded.paste(img_rgb, (0, 0))
    else:
        src_padded = img_rgb
    pw, ph = src_padded.size

    map_names = ("basecolor", "normal", "rmd")
    acc = {k: np.zeros((ph, pw, 3), dtype=np.float32) for k in map_names}
    wsum = np.zeros((ph, pw, 1), dtype=np.float32)

    xs = _tile_origins(pw, tile, stride)
    ys = _tile_origins(ph, tile, stride)
    # Ramps depend only on the tile origin, so build each one once.
    ramps_x = {left: _edge_ramp(left, pw, tile, overlap) for left in xs}
    ramps_y = {top: _edge_ramp(top, ph, tile, overlap) for top in ys}

    with torch.no_grad():
        for top in ys:
            for left in xs:
                patch = src_padded.crop((left, top, left + tile, top + tile))
                preds = model(cond_tensor_from_pil(patch, device))
                weight = (ramps_y[top][:, None] * ramps_x[left][None, :])[:, :, None]
                window = (slice(top, top + tile), slice(left, left + tile))
                for k in map_names:
                    acc[k][window] += _tensor_to_unit_array(preds[k]) * weight
                wsum[window] += weight

    norm = np.maximum(wsum, 1e-8)  # guard against division by zero
    outputs = {}
    for k in map_names:
        out_np = np.clip(acc[k] / norm, 0.0, 1.0)
        img = Image.fromarray((out_np * 255.0 + 0.5).astype("uint8"))
        outputs[k] = img.crop((0, 0, w, h)) if (pad_w or pad_h) else img
    return img_rgb, outputs
# Drawing helpers (_draw_label, _draw_arrow) used by make_side_by_side follow.
@functools.lru_cache(maxsize=1)
def _label_font():
    """Return the panel-label font: Arial 18 if loadable, otherwise Pillow's default font.

    Cached so that the TrueType file is opened once, not once per label.
    """
    try:
        return ImageFont.truetype("arial.ttf", 18)
    except (OSError, ImportError):
        # OSError: font file not found; ImportError: Pillow's FreeType support unavailable.
        return ImageFont.load_default()


def _draw_label(img: Image.Image, label: str, bar_color=(0, 0, 0)) -> Image.Image:
    """Draw a title bar (rows 0-24) filled with *bar_color* and white *label* text at the top of *img*.

    *img* is modified in place and also returned, for chaining.
    """
    draw = ImageDraw.Draw(img)
    draw.rectangle((0, 0, img.width, 24), fill=bar_color)
    draw.text((4, 2), label, fill=(255, 255, 255), font=_label_font())
    return img
def _draw_arrow(img: Image.Image, color=(180, 180, 180)) -> Image.Image:
    """Draw a left-to-right arrow along the vertical centre of *img*.

    A 3 px shaft runs from x = 0 to the base of a triangular head. The head is
    8 px long and 8 px to either side of the centre line, with its tip at the
    right edge. *img* is modified in place and also returned.
    """
    draw = ImageDraw.Draw(img)
    mid_y = img.height // 2
    head = 8  # arrow-head length and half-height, in pixels
    tip_x = img.width
    base_x = tip_x - head
    draw.line((0, mid_y, base_x, mid_y), fill=color, width=3)
    draw.polygon([(base_x, mid_y - head), (base_x, mid_y + head), (tip_x, mid_y)], fill=color)
    return img
def make_side_by_side(inp_img: Image.Image, outputs: dict) -> Image.Image:
    """Render an overview image: input -> predicted maps -> reconstruction.

    Every panel is 200 x 200 with a coloured title bar. Stages are separated by
    arrows on a dark (35, 35, 35) background:

    * stage 1: the input image, vertically centred;
    * stage 2: a 3-column grid with BASECOLOR, NORMAL and DEPTH on the top row
      and ROUGHNESS and METALLIC on the bottom row. DEPTH, ROUGHNESS and
      METALLIC are the B, R and G channels of ``outputs['rmd']``, shown in
      greyscale;
    * stage 3: ``outputs['rgb']`` labelled "RECON RGB", vertically centred.

    If *outputs* has no ``'rgb'`` entry (as with ``tiled_inference`` results),
    stage 3 and its arrow are omitted instead of raising ``KeyError``.

    Args:
        inp_img: input image (converted to RGB).
        outputs: PIL images keyed ``'basecolor'``, ``'normal'``, ``'rmd'`` and
            optionally ``'rgb'``.

    Returns:
        A new RGB canvas with the layout above.
    """
    bg = (35, 35, 35)
    cell = 200
    arrow_w = 48
    gap = 8
    map_color = (50, 160, 80)

    inp_img = inp_img.convert("RGB")
    outputs = {k: v.convert("RGB") for k, v in outputs.items()}
    rough, metal, depth = outputs["rmd"].split()
    recon = outputs.get("rgb")

    grid_w = cell * 3 + gap * 2
    total_w = cell + arrow_w + grid_w
    if recon is not None:
        total_w += arrow_w + cell
    total_h = cell * 2 + gap
    centred_y = (total_h - cell) // 2
    canvas = Image.new("RGB", (total_w, total_h), bg)

    def panel(img, label, color):
        # Fresh resized copy, so _draw_label never touches the caller's image.
        return _draw_label(img.resize((cell, cell), Image.BICUBIC), label, color)

    def arrow_at(x):
        canvas.paste(_draw_arrow(Image.new("RGB", (arrow_w, total_h), bg)), (x, 0))

    cx = 0
    canvas.paste(panel(inp_img, "INPUT", (50, 100, 200)), (cx, centred_y))
    cx += cell
    arrow_at(cx)
    cx += arrow_w

    ordered = [
        ("BASECOLOR", outputs["basecolor"]),
        ("NORMAL", outputs["normal"]),
        ("DEPTH", depth.convert("RGB")),
        ("ROUGHNESS", rough.convert("RGB")),
        ("METALLIC", metal.convert("RGB")),
    ]
    for i, (label, img) in enumerate(ordered):
        row, col = divmod(i, 3)
        canvas.paste(panel(img, label, map_color), (cx + col * (cell + gap), row * (cell + gap)))
    cx += grid_w

    if recon is not None:
        arrow_at(cx)
        cx += arrow_w
        canvas.paste(panel(recon, "RECON RGB", (200, 120, 50)), (cx, centred_y))
    return canvas