import math
from typing import Optional, Tuple

import torch
import torch.nn as nn


def bernstein_matrix(n_samples: int, n_ctrl_pts: int, device,
                     dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Return the Bernstein basis sampled at ``n_samples`` uniform parameters.

    Row ``k`` holds the weight of every control point at ``t_k``, with
    ``t = torch.linspace(0, 1, n_samples)``, so ``B @ ctrl_pts`` (``ctrl_pts``
    of shape ``(n_ctrl_pts, 2)``) samples the Bezier curve they define.

    Args:
        n_samples: number of curve samples (rows of the result).
        n_ctrl_pts: number of control points (columns); the curve has degree
            ``n_ctrl_pts - 1``. ``4`` yields the cubic basis.
        device: device the matrix is created on.
        dtype: floating dtype of the result (float32 by default).

    Returns:
        Tensor of shape ``(n_samples, n_ctrl_pts)``.

    Raises:
        ValueError: if ``n_ctrl_pts < 1``.
    """
    if n_ctrl_pts < 1:
        raise ValueError(f"n_ctrl_pts must be >= 1, got {n_ctrl_pts}")
    t = torch.linspace(0, 1, n_samples, device=device, dtype=dtype)
    degree = n_ctrl_pts - 1
    # B_{i,n}(t) = C(n, i) * t^i * (1 - t)^(n - i). For n = 3 this is exactly
    # (1-t)^3, 3t(1-t)^2, 3t^2(1-t), t^3 -- the previously hard-coded cubic.
    columns = [
        math.comb(degree, i) * t ** i * (1 - t) ** (degree - i)
        for i in range(n_ctrl_pts)
    ]
    return torch.stack(columns, dim=1)


@torch.compile
def render_patch(grid_patch, pts, sigma_s, alpha_s, sharpness):
    """Rasterize one stroke's soft coverage over a patch of pixel coordinates.

    Args:
        grid_patch: (h, w, 2) normalized xy coordinates of the patch pixels.
        pts: (n_samples, 2) points sampled along the stroke centreline.
        sigma_s: (n_samples,) per-sample Gaussian width.
        alpha_s: (n_samples,) per-sample opacity weight.
        sharpness: scales the Gaussian exponent and is also used as the power
            applied to the final coverage.

    Returns:
        (h, w) coverage map.
    """
    # Squared distance from every pixel to every sample: (h, w, n_samples).
    d2 = ((grid_patch.unsqueeze(-2) - pts) ** 2).sum(-1)
    gauss = torch.exp(-0.5 * sharpness * d2 / (sigma_s ** 2 + 1e-12))
    density = (alpha_s * gauss).sum(-1)
    # 1 - exp(-density) saturates overlapping samples instead of summing them.
    patch_alpha = (1.0 - torch.exp(-density.clamp(min=1e-12))) ** sharpness
    return patch_alpha


class StrokeRenderer(nn.Module):
    """Differentiable renderer compositing Bezier brush strokes onto a canvas.

    Each stroke is a Bezier centreline given by ``n_ctrl_pts`` xy control
    points in the same normalized [0, 1] coordinates as the pixel grid, with
    width and opacity interpolated linearly from start to end along the curve
    and a single RGB colour. Strokes are rendered as sums of Gaussians and
    composited front-to-back in stroke order with premultiplied "over".
    """

    def __init__(self, canvas_size: Tuple[int, int], device, n_ctrl_pts: int,
                 padding_k=3.0, n_samples: int = 36, sharpness=2):
        """
        Args:
            canvas_size: (H, W) of the output canvas in pixels.
            device: device for the precomputed buffers.
            n_ctrl_pts: control points per stroke.
            padding_k: bounding-box padding, in multiples of the largest
                per-stroke Gaussian width.
            n_samples: samples taken along each stroke centreline.
            sharpness: forwarded to ``render_patch``.
        """
        super().__init__()
        self.H, self.W = canvas_size
        self.n_ctrl_pts = n_ctrl_pts
        self.padding_k = padding_k
        self.n_samples = n_samples
        self.sharpness = sharpness

        # Normalized pixel grid: pixel (i, j) sits at (j / (W-1), i / (H-1)).
        yy, xx = torch.meshgrid(
            torch.linspace(0, 1, self.H, device=device),
            torch.linspace(0, 1, self.W, device=device),
            indexing='ij',
        )
        self.grid_norm: torch.Tensor
        self.register_buffer('grid_norm', torch.stack([xx, yy], dim=-1))  # (H, W, 2)
        # Curve parameters t in [0, 1], one per sample (length n_samples,
        # despite the buffer name; the name is kept for state_dict compatibility).
        self.t_64: torch.Tensor
        self.register_buffer('t_64', torch.linspace(0, 1, n_samples, device=device))
        self.bernstein: torch.Tensor
        self.register_buffer(
            'bernstein', bernstein_matrix(self.n_samples, n_ctrl_pts, device=device))

    def _parse_strokes(self, strokes: torch.Tensor):
        """Split (B, S, 2*n_ctrl_pts + 7) strokes into flat per-stroke tensors."""
        n_xy = 2 * self.n_ctrl_pts
        expected = n_xy + 7
        if strokes.shape[-1] != expected:
            raise ValueError(
                f"strokes must have {expected} values per stroke "
                f"(2*n_ctrl_pts + 7), got {strokes.shape[-1]}")
        n_strokes = strokes.shape[0] * strokes.shape[1]
        # reshape (not view) so non-contiguous inputs are accepted; explicit
        # sizes keep the B*S == 0 case well-defined.
        flat = strokes.reshape(n_strokes, expected)
        ctrl_pts = flat[:, :n_xy].reshape(n_strokes, self.n_ctrl_pts, 2)
        w_start = flat[:, n_xy]
        w_end = flat[:, n_xy + 1]
        op_start = flat[:, n_xy + 2]
        op_end = flat[:, n_xy + 3]
        color = flat[:, n_xy + 4:n_xy + 7]
        return ctrl_pts, w_start, w_end, op_start, op_end, color

    def _patch_bounds(self, pts: torch.Tensor, sigma_s: torch.Tensor):
        """Return inclusive pixel bounds (y0, y1, x0, x1) of the padded stroke
        bounding box, or None when the box is empty."""
        pad = self.padding_k * sigma_s.max()
        bbox_min = (pts.min(dim=0).values - pad).clamp(0, 1)
        bbox_max = (pts.max(dim=0).values + pad).clamp(0, 1)
        x0 = max(0, int(torch.floor(bbox_min[0] * (self.W - 1)).item()))
        y0 = max(0, int(torch.floor(bbox_min[1] * (self.H - 1)).item()))
        x1 = min(self.W - 1, int(torch.ceil(bbox_max[0] * (self.W - 1)).item()))
        y1 = min(self.H - 1, int(torch.ceil(bbox_max[1] * (self.H - 1)).item()))
        if x1 < x0 or y1 < y0:
            return None
        return y0, y1, x0, x1

    def forward(self, strokes: torch.Tensor,
                prev_canvas: Optional[torch.Tensor] = None,
                return_step_canvases: bool = False):
        """Render ``strokes`` over ``prev_canvas``.

        Args:
            strokes: (B, S, n_ctrl_pts*2 + 7). Per stroke: flattened xy control
                points, then w_start, w_end, op_start, op_end, then RGB (3).
            prev_canvas: optional (B, 3, H, W) background; zeros when None.
            return_step_canvases: if truthy, also record the (unclamped,
                detached) canvas after every stroke.

        Returns:
            dict with ``'end_canvas'`` (B, 3, H, W), clamped to [0, 1], and
            ``'step_canvases'`` (B, S, 3, H, W) or None.

        Raises:
            ValueError: if the last dimension of ``strokes`` is not
                ``2*n_ctrl_pts + 7``.
        """
        B, S, _ = strokes.shape
        device, dtype = strokes.device, strokes.dtype
        ctrl_pts, w_start, w_end, op_start, op_end, color = self._parse_strokes(strokes)

        if prev_canvas is None:
            prev_canvas = torch.zeros((B, 3, self.H, self.W), device=device, dtype=dtype)
        else:
            prev_canvas = prev_canvas.to(device=device, dtype=dtype)

        # Buffers are float32; cast once so the per-stroke math (notably the
        # matmul, which does not type-promote) follows strokes.dtype.
        bernstein = self.bernstein.to(dtype=dtype)
        t = self.t_64.to(dtype=dtype)
        grid = self.grid_norm.to(dtype=dtype)

        # Premultiplied RGB + alpha from the strokes alone; composited over
        # prev_canvas once at the end.
        canvas = torch.zeros((B, 4, self.H, self.W), device=device, dtype=dtype)

        if return_step_canvases:
            # Bookkeeping only (stored detached), so keep it out of autograd.
            cumulative_rgb = prev_canvas.detach().clone()
            step_canvas = torch.zeros(B, S, 3, self.H, self.W, device=device, dtype=dtype)
        else:
            cumulative_rgb = None
            step_canvas = None

        for idx in range(B * S):
            b, s = divmod(idx, S)
            pts = bernstein @ ctrl_pts[idx]  # (n_samples, 2)
            alpha_s = op_start[idx] * (1 - t) + op_end[idx] * t  # (n_samples,)
            # Width -> Gaussian std via a fixed 1/2.8 factor; the clamp keeps
            # the Gaussian denominator away from zero.
            sigma_s = ((w_start[idx] * (1 - t) + w_end[idx] * t) / 2.8).clamp(min=1e-4)

            bounds = self._patch_bounds(pts, sigma_s)
            if bounds is not None:
                y0, y1, x0, x1 = bounds
                ys, xs = slice(y0, y1 + 1), slice(x0, x1 + 1)

                patch_alpha = render_patch(
                    grid[ys, xs], pts, sigma_s, alpha_s, self.sharpness)  # (h, w)
                patch_rgb = color[idx].reshape(3, 1, 1) * patch_alpha.unsqueeze(0)
                inv_alpha = (1.0 - patch_alpha).unsqueeze(0)  # (1, h, w)

                # Clone the region before overwriting it in place, so autograd
                # saves an independent copy rather than a view that is mutated.
                canvas_rgb = canvas[b, :3, ys, xs].clone()
                canvas_alpha = canvas[b, 3, ys, xs].clone()
                canvas[b, :3, ys, xs] = patch_rgb + canvas_rgb * inv_alpha
                canvas[b, 3, ys, xs] = patch_alpha + canvas_alpha * (1.0 - patch_alpha)

                if cumulative_rgb is not None:
                    cumulative_rgb[b, :, ys, xs] = (
                        patch_rgb.detach()
                        + cumulative_rgb[b, :, ys, xs] * inv_alpha.detach())

            # Record every step, including skipped strokes, so the step canvas
            # always shows everything painted so far.
            if step_canvas is not None:
                step_canvas[b, s] = cumulative_rgb[b]

        rgb = canvas[:, :3]
        alpha = canvas[:, 3:].clamp(0, 1)
        end_canvas = (rgb + prev_canvas * (1 - alpha)).clamp(0, 1)
        return {'end_canvas': end_canvas, 'step_canvases': step_canvas}