image2painting / renderer.py
Lasercatz
Upload 9 files
97bca33 verified
import torch
import torch.nn as nn
from typing import Tuple, Optional
def bernstein_matrix(n_samples: int, n_ctrl_pts: int, device):
    """Return the Bernstein basis matrix for a Bezier curve.

    Row ``i`` holds the weights of every control point at parameter
    ``t_i``, where ``t`` is ``n_samples`` evenly spaced values in [0, 1].
    Multiplying the result with an ``(n_ctrl_pts, 2)`` control-point tensor
    yields ``n_samples`` points on the curve.

    Args:
        n_samples: Number of curve parameters ``t`` to evaluate.
        n_ctrl_pts: Number of control points (curve degree + 1). Any value
            >= 1 is supported; 4 gives the usual cubic Bezier basis.
        device: Torch device on which the matrix is created.

    Returns:
        A float32 tensor of shape ``(n_samples, n_ctrl_pts)``.

    Raises:
        ValueError: If ``n_ctrl_pts`` is smaller than 1.
    """
    import math  # local import keeps this function self-contained

    if n_ctrl_pts < 1:
        raise ValueError(
            f"n_ctrl_pts must be at least 1, got {n_ctrl_pts}")

    t = torch.linspace(0, 1, n_samples, device=device, dtype=torch.float32)
    one_minus_t = 1 - t
    degree = n_ctrl_pts - 1
    # B_{k,degree}(t) = C(degree, k) * t^k * (1 - t)^(degree - k).
    # Integer exponents are used (rather than a tensor exponent) so that the
    # cubic case evaluates exactly the same expressions as the closed form.
    columns = [
        math.comb(degree, k) * t ** k * one_minus_t ** (degree - k)
        for k in range(n_ctrl_pts)
    ]
    return torch.stack(columns, dim=1)
@torch.compile
def render_patch(grid_patch, pts, sigma_s, alpha_s, sharpness):
    """Rasterize one stroke's coverage over a patch of pixel coordinates.

    Every curve sample contributes an isotropic Gaussian blob; the blobs are
    summed into a density which is then squashed into a coverage value.

    Args:
        grid_patch: Pixel coordinates; the last axis is the coordinate pair
            and is broadcast against ``pts``.
        pts: Curve sample positions, one coordinate pair per sample.
        sigma_s: Per-sample Gaussian width.
        alpha_s: Per-sample opacity weight.
        sharpness: Scales the Gaussian exponent and is also the power applied
            to the final coverage.

    Returns:
        Coverage with the leading shape of ``grid_patch`` (sample and
        coordinate axes reduced away).
    """
    # Squared distance from every pixel to every curve sample.
    offsets = grid_patch.unsqueeze(-2) - pts
    sq_dist = (offsets ** 2).sum(-1)
    # The small epsilon keeps the division finite for a zero sigma.
    falloff = torch.exp(-0.5 * sharpness * sq_dist / (sigma_s**2 + 1e-12))
    accumulated = (alpha_s * falloff).sum(-1)
    # Clamp before exponentiating so the density never reaches exactly zero.
    coverage = 1.0 - torch.exp(-accumulated.clamp(min=1e-12))
    return coverage ** sharpness
class StrokeRenderer(nn.Module):
    """Differentiable renderer that paints Bezier strokes onto a canvas.

    Each stroke is a Bezier curve sampled at ``n_samples`` parameters. Width
    and opacity are linearly interpolated from the start to the end of the
    stroke, and every stroke is rasterized only inside its padded bounding
    box. Strokes are composited in order ("source over destination") and the
    result is finally laid over ``prev_canvas``.

    Stroke coordinates live in the same normalized space as ``grid_norm``:
    x = 0 / x = 1 are the first / last pixel column, y likewise for rows.
    """

    def __init__(self, canvas_size: Tuple[int, int], device, n_ctrl_pts: int, padding_k=3.0,
                 n_samples: int = 36, sharpness=2):
        """Precompute the pixel grid and curve-sampling tables.

        Args:
            canvas_size: Canvas ``(H, W)`` in pixels.
            device: Device on which the buffers are created.
            n_ctrl_pts: Number of Bezier control points per stroke.
            padding_k: Bounding-box padding, as a multiple of the largest
                per-sample sigma of the stroke.
            n_samples: Number of points sampled along each curve.
            sharpness: Forwarded to ``render_patch``.
        """
        super().__init__()
        self.H, self.W = canvas_size
        self.n_ctrl_pts = n_ctrl_pts
        self.padding_k = padding_k
        self.n_samples = n_samples
        self.sharpness = sharpness

        # Normalized pixel coordinates, stored as (x, y) pairs: (H, W, 2).
        yy, xx = torch.meshgrid(
            torch.linspace(0, 1, self.H, device=device),
            torch.linspace(0, 1, self.W, device=device),
            indexing='ij'
        )
        self.grid_norm: torch.Tensor
        self.register_buffer('grid_norm', torch.stack([xx, yy], dim=-1))

        # Curve parameters in [0, 1]. Despite the name this holds
        # ``n_samples`` values; the buffer name is kept for compatibility.
        self.t_64: torch.Tensor
        self.register_buffer('t_64', torch.linspace(
            0, 1, n_samples, device=device))

        # (n_samples, n_ctrl_pts) basis mapping control points to samples.
        self.bernstein: torch.Tensor
        self.register_buffer('bernstein', bernstein_matrix(
            self.n_samples, n_ctrl_pts, device=device))

    def forward(self, strokes: torch.Tensor,
                prev_canvas: Optional[torch.Tensor] = None,
                return_step_canvases: bool = False):
        """Render ``strokes`` on top of ``prev_canvas``.

        Args:
            strokes: ``(B, S, n_ctrl_pts*2 + 7)``. Per stroke, in order: the
                control points as ``(x, y)`` pairs, start width, end width,
                start opacity, end opacity and three color channels.
            prev_canvas: Optional ``(B, 3, H, W)`` background; an all-zero
                canvas is used when omitted.
            return_step_canvases: When truthy, also return the canvas after
                every individual stroke.

        Returns:
            A dict with ``'end_canvas'`` of shape ``(B, 3, H, W)`` clamped to
            [0, 1], and ``'step_canvases'`` of shape ``(B, S, 3, H, W)``
            (detached from autograd and not clamped), or ``None`` when step
            canvases were not requested.
        """
        B, S, _ = strokes.shape
        device = strokes.device
        dtype = strokes.dtype
        n = self.n_ctrl_pts

        # reshape (not view) so non-contiguous stroke tensors are accepted.
        flat = strokes.reshape(B * S, strokes.shape[-1])
        ctrl_pts = flat[:, :2 * n].reshape(B * S, n, 2)
        w_start = flat[:, 2 * n]
        w_end = flat[:, 2 * n + 1]
        op_start = flat[:, 2 * n + 2]
        op_end = flat[:, 2 * n + 3]
        color = flat[:, 2 * n + 4:].reshape(B * S, 3)

        if prev_canvas is None:
            prev_canvas = torch.zeros((B, 3, self.H, self.W),
                                      device=device, dtype=dtype)
        else:
            prev_canvas = prev_canvas.to(device=device, dtype=dtype)

        # Premultiplied RGB plus alpha of the strokes alone.
        canvas = torch.zeros((B, 4, self.H, self.W),
                             device=device, dtype=dtype)

        if return_step_canvases:
            # Step canvases are only ever exposed detached, so they are
            # tracked without building an autograd graph.
            cumulative_rgb = prev_canvas.detach().clone()
            step_canvas = torch.zeros(
                B, S, 3, self.H, self.W, device=device, dtype=dtype)
        else:
            cumulative_rgb = None
            step_canvas = None

        # Normalized (x, y) -> pixel index scale.
        px_scale = strokes.new_tensor([self.W - 1, self.H - 1])

        for idx in range(B * S):
            batch_idx, step_idx = divmod(idx, S)

            pts = torch.matmul(self.bernstein, ctrl_pts[idx])  # (n_samples,2)
            alpha_s = (op_start[idx] * (1 - self.t_64) +
                       op_end[idx] * self.t_64)  # (n_samples,)
            sigma_s = (w_start[idx] * (1 - self.t_64) +
                       w_end[idx] * self.t_64) / 2.8
            sigma_s = sigma_s.clamp(min=1e-4)

            # Padded bounding box of the stroke in pixel indices.
            pad = self.padding_k * sigma_s.max()
            bbox_min = (pts.min(dim=0).values - pad).clamp(0, 1)
            bbox_max = (pts.max(dim=0).values + pad).clamp(0, 1)
            bounds = torch.cat([torch.floor(bbox_min * px_scale),
                                torch.ceil(bbox_max * px_scale)])
            # One host transfer for all four bounds instead of four.
            x0, y0, x1, y1 = (int(v) for v in bounds.detach().tolist())
            x0, y0 = max(0, x0), max(0, y0)
            x1, y1 = min(self.W - 1, x1), min(self.H - 1, y1)

            if x0 <= x1 and y0 <= y1:
                ys = slice(y0, y1 + 1)
                xs = slice(x0, x1 + 1)

                patch_alpha = render_patch(
                    self.grid_norm[ys, xs], pts, sigma_s, alpha_s,
                    self.sharpness)  # (h,w)
                patch_rgb = color[idx].reshape(3, 1, 1) * \
                    patch_alpha.unsqueeze(0)  # (3,h,w)
                keep = (1.0 - patch_alpha).unsqueeze(0)  # (1,h,w)

                # Clone the destination so autograd does not see the values
                # it saved being overwritten by the in-place assignment.
                dst_rgb = canvas[batch_idx, :3, ys, xs].clone()
                dst_alpha = canvas[batch_idx, 3, ys, xs].clone()

                # Composite: src over dst.
                canvas[batch_idx, :3, ys, xs] = patch_rgb + dst_rgb * keep
                canvas[batch_idx, 3, ys, xs] = patch_alpha + \
                    dst_alpha * (1.0 - patch_alpha)

                if cumulative_rgb is not None:
                    cumulative_rgb[batch_idx, :, ys, xs] = (
                        patch_rgb.detach() +
                        cumulative_rgb[batch_idx, :, ys, xs] * keep.detach()
                    )

            # Recorded even for a stroke that painted nothing, so a step
            # never shows up as an all-zero canvas.
            if step_canvas is not None and cumulative_rgb is not None:
                step_canvas[batch_idx, step_idx] = cumulative_rgb[batch_idx]

        rgb = canvas[:, :3, ...]
        alpha = canvas[:, 3:, ...].clamp(0, 1)
        new_canvas = (rgb + prev_canvas * (1 - alpha)).clamp(0, 1)
        return {'end_canvas': new_canvas, 'step_canvases': step_canvas}