Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| from typing import Tuple, Optional | |
def bernstein_matrix(n_samples: int, n_ctrl_pts: int, device):
    """Return the Bernstein basis sampled at ``n_samples`` uniform parameters.

    Row ``k`` holds the weights of the ``n_ctrl_pts`` control points at
    ``t_k = k / (n_samples - 1)``. So ``bernstein_matrix(...) @ ctrl_pts``
    evaluates a Bezier curve of degree ``n_ctrl_pts - 1`` at those parameters
    (for example, 4 control points give a cubic Bezier).

    Args:
        n_samples: number of uniformly spaced parameters in ``[0, 1]``.
        n_ctrl_pts: number of control points (curve degree + 1).
        device: device on which the result is created.

    Returns:
        float32 tensor of shape ``(n_samples, n_ctrl_pts)``.

    Raises:
        ValueError: if ``n_ctrl_pts < 1``.
    """
    from math import comb  # local import keeps the file's import block unchanged

    if n_ctrl_pts < 1:
        raise ValueError(f"n_ctrl_pts must be >= 1, got {n_ctrl_pts}")
    degree = n_ctrl_pts - 1
    t = torch.linspace(0, 1, n_samples, device=device, dtype=torch.float32)
    # B_{i,degree}(t) = C(degree, i) * t^i * (1 - t)^(degree - i)
    return torch.stack(
        [comb(degree, i) * t ** i * (1 - t) ** (degree - i)
         for i in range(n_ctrl_pts)],
        dim=1,
    )
def render_patch(grid_patch, pts, sigma_s, alpha_s, sharpness):
    """Turn a stroke's Gaussian sample points into an opacity map.

    Args:
        grid_patch: ``(..., 2)`` pixel coordinates to evaluate.
        pts: ``(n_samples, 2)`` sample positions along the stroke.
        sigma_s: per-sample Gaussian width, broadcastable to ``(..., n_samples)``.
        alpha_s: per-sample opacity weight, broadcastable likewise.
        sharpness: scales the Gaussian exponent and is also applied as a
            power to the final opacity.

    Returns:
        Opacity of shape ``(...)``, i.e. ``grid_patch`` without its last axis.
    """
    offsets = grid_patch.unsqueeze(-2) - pts          # (..., n_samples, 2)
    sq_dist = offsets.pow(2).sum(dim=-1)              # (..., n_samples)
    variance = sigma_s ** 2 + 1e-12                   # epsilon guards a zero width
    kernel = torch.exp(-0.5 * sharpness * sq_dist / variance)
    coverage = torch.sum(kernel * alpha_s, dim=-1)
    # Clamp away from 0 so fractional `sharpness` powers keep a finite gradient.
    return (1.0 - torch.exp(-coverage.clamp(min=1e-12))) ** sharpness
class StrokeRenderer(nn.Module):
    """Rasterize parametric brush strokes onto RGB canvases.

    Each stroke is a Bezier curve whose ``n_ctrl_pts`` control points are
    mixed by the precomputed ``bernstein`` basis into ``n_samples`` points.
    Every sample point is splatted as a Gaussian. Its width and opacity are
    linearly interpolated from the stroke's start values to its end values.
    Strokes are alpha-composited in order, so later strokes land on top.
    Coordinates are normalized to ``[0, 1]``, with x along the width and y
    along the height.
    """

    def __init__(self, canvas_size: Tuple[int, int], device, n_ctrl_pts: int, padding_k=3.0,
                 n_samples: int = 36, sharpness=2):
        """
        Args:
            canvas_size: ``(H, W)`` of the output canvas in pixels.
            device: device on which the precomputed buffers are created.
            n_ctrl_pts: control points per stroke; must be supported by
                ``bernstein_matrix``.
            padding_k: padding of each stroke's bounding box, in multiples of
                its largest sample sigma.
            n_samples: number of points sampled along each curve.
            sharpness: forwarded to ``render_patch``, where it scales the
                Gaussian exponent and is applied as a power to the opacity.
        """
        super().__init__()
        self.H, self.W = canvas_size
        self.n_ctrl_pts = n_ctrl_pts
        self.padding_k = padding_k
        self.n_samples = n_samples
        self.sharpness = sharpness
        # Normalized (x, y) coordinate of every pixel, shape (H, W, 2).
        yy, xx = torch.meshgrid(
            torch.linspace(0, 1, self.H, device=device),
            torch.linspace(0, 1, self.W, device=device),
            indexing='ij'
        )
        # Buffer names are state_dict keys, so they are kept unchanged.
        self.grid_norm: torch.Tensor
        self.register_buffer('grid_norm', torch.stack([xx, yy], dim=-1))
        # Curve parameters t in [0, 1], one per sample (length n_samples).
        self.t_64: torch.Tensor
        self.register_buffer('t_64', torch.linspace(0, 1, n_samples, device=device))
        # (n_samples, n_ctrl_pts) basis: bernstein @ ctrl_pts -> sample points.
        self.bernstein: torch.Tensor
        self.register_buffer('bernstein', bernstein_matrix(
            self.n_samples, n_ctrl_pts, device=device))

    def _pixel_boxes(self, pts: torch.Tensor, sigma_s: torch.Tensor):
        """Return one inclusive pixel box ``(x0, y0, x1, y1)`` per stroke.

        Args:
            pts: ``(N, n_samples, 2)`` normalized sample positions.
            sigma_s: ``(N, n_samples)`` per-sample widths.

        Each box covers the stroke's samples, padded by
        ``padding_k * max(sigma_s)`` and clipped to the canvas. All boxes are
        computed in one batch and copied to the host with a single
        ``.tolist()``, instead of several ``.item()`` syncs per stroke.
        """
        with torch.no_grad():  # the boxes only select pixels; no gradient needed
            pad = (self.padding_k * sigma_s.max(dim=1).values).unsqueeze(1)   # (N, 1)
            lo = (pts.min(dim=1).values - pad).clamp(0, 1)                   # (N, 2)
            hi = (pts.max(dim=1).values + pad).clamp(0, 1)                   # (N, 2)
            scale = lo.new_tensor([self.W - 1, self.H - 1])
            lo_px = torch.floor(lo * scale).tolist()
            hi_px = torch.ceil(hi * scale).tolist()
        return [
            (max(0, int(fx0)), max(0, int(fy0)),
             min(self.W - 1, int(fx1)), min(self.H - 1, int(fy1)))
            for (fx0, fy0), (fx1, fy1) in zip(lo_px, hi_px)
        ]

    def forward(self, strokes: torch.Tensor,
                prev_canvas: Optional[torch.Tensor] = None,
                return_step_canvases: bool = False):
        """Composite a batch of stroke sequences over background canvases.

        Args:
            strokes: ``(B, S, 2 * n_ctrl_pts + 7)``. For each stroke the last
                axis holds the flattened ``(x, y)`` control points, then
                ``w_start, w_end, op_start, op_end``, then an RGB colour.
            prev_canvas: optional ``(B, 3, H, W)`` background; zeros if omitted.
                It is converted to the device and dtype of ``strokes``.
            return_step_canvases: if truthy, also record the canvas after
                every stroke.

        Returns:
            dict with two entries:
            ``'end_canvas'``: ``(B, 3, H, W)``, clamped to ``[0, 1]``.
            ``'step_canvases'``: a detached, unclamped ``(B, S, 3, H, W)``
            tensor, or ``None`` when not requested.

        Raises:
            ValueError: if the last axis of ``strokes`` is not
                ``2 * n_ctrl_pts + 7``.
        """
        B, S, D = strokes.shape
        n = self.n_ctrl_pts
        if D != 2 * n + 7:
            raise ValueError(
                f"strokes must have last dim 2*n_ctrl_pts+7={2 * n + 7}, got {D}")
        device, dtype = strokes.device, strokes.dtype

        # Flatten (B, S) -> B*S. reshape (unlike view) also accepts non-contiguous input.
        flat = strokes.reshape(B * S, D)
        ctrl_pts = flat[:, :2 * n].reshape(B * S, n, 2)
        w_start, w_end = flat[:, 2 * n], flat[:, 2 * n + 1]
        op_start, op_end = flat[:, 2 * n + 2], flat[:, 2 * n + 3]
        color = flat[:, 2 * n + 4:]                                # (B*S, 3)

        # Per-sample geometry for all strokes at once. The basis buffer is
        # float32, so cast it to avoid a matmul dtype mismatch (e.g. float64).
        t = self.t_64
        pts = torch.matmul(self.bernstein.to(dtype), ctrl_pts)     # (B*S, n_samples, 2)
        alpha_s = op_start[:, None] * (1 - t) + op_end[:, None] * t
        sigma_s = ((w_start[:, None] * (1 - t) + w_end[:, None] * t) / 2.8).clamp(min=1e-4)
        boxes = self._pixel_boxes(pts, sigma_s)

        if prev_canvas is None:
            prev_canvas = torch.zeros((B, 3, self.H, self.W), device=device, dtype=dtype)
        else:
            prev_canvas = prev_canvas.to(device=device, dtype=dtype)

        # Channels 0-2: premultiplied RGB of the strokes; channel 3: accumulated alpha.
        canvas = torch.zeros((B, 4, self.H, self.W), device=device, dtype=dtype)

        if return_step_canvases:
            # Bookkeeping only: the step canvases are returned detached, so
            # keep this off the autograd graph.
            running_rgb = prev_canvas.detach().clone()
            step_canvases = torch.zeros(
                (B, S, 3, self.H, self.W), device=device, dtype=dtype)
        else:
            running_rgb = None
            step_canvases = None

        for idx, (x0, y0, x1, y1) in enumerate(boxes):
            b, s = divmod(idx, S)
            if x0 <= x1 and y0 <= y1:
                rows, cols = slice(y0, y1 + 1), slice(x0, x1 + 1)
                patch_alpha = render_patch(
                    self.grid_norm[rows, cols], pts[idx], sigma_s[idx],
                    alpha_s[idx], self.sharpness)                  # (h, w)
                patch_rgb = color[idx].view(3, 1, 1) * patch_alpha.unsqueeze(0)  # (3, h, w)
                keep = (1.0 - patch_alpha).unsqueeze(0)            # (1, h, w)

                # "src over dst". Clone the destination first: the multiply
                # saves it for backward, and the in-place write below would
                # otherwise invalidate it.
                dst_rgb = canvas[b, :3, rows, cols].clone()
                dst_alpha = canvas[b, 3, rows, cols].clone()
                canvas[b, :3, rows, cols] = patch_rgb + dst_rgb * keep
                canvas[b, 3, rows, cols] = patch_alpha + dst_alpha * (1.0 - patch_alpha)

                if running_rgb is not None:
                    with torch.no_grad():
                        running_rgb[b, :, rows, cols] = (
                            patch_rgb + running_rgb[b, :, rows, cols] * keep)
            if step_canvases is not None and running_rgb is not None:
                # Recorded even for culled strokes so every step holds the current state.
                step_canvases[b, s] = running_rgb[b]

        alpha = canvas[:, 3:].clamp(0, 1)
        end_canvas = (canvas[:, :3] + prev_canvas * (1 - alpha)).clamp(0, 1)
        return {'end_canvas': end_canvas, 'step_canvases': step_canvases}