import torch
import torch.nn as nn
from typing import Tuple, Optional
def bernstein_matrix(n_samples: int, n_ctrl_pts: int, device) -> torch.Tensor:
    """Build the (n_samples, n_ctrl_pts) Bernstein basis matrix of a Bezier curve.

    Row i holds the ``n_ctrl_pts`` Bernstein polynomials of degree
    ``n_ctrl_pts - 1`` evaluated at t_i, where t sweeps [0, 1] uniformly.
    Multiplying this matrix by an (n_ctrl_pts, 2) control-point matrix yields
    ``n_samples`` points along the curve.  Generalizes the previous cubic-only
    implementation; for n_ctrl_pts == 4 the result is identical.

    Args:
        n_samples: number of parameter samples t in [0, 1].
        n_ctrl_pts: number of control points (curve degree + 1); must be >= 2.
        device: torch device on which to allocate the matrix.

    Returns:
        Float32 tensor of shape (n_samples, n_ctrl_pts); each row sums to 1
        (partition of unity of the Bernstein basis).

    Raises:
        ValueError: if n_ctrl_pts < 2 (no curve can be defined).
    """
    from math import comb  # stdlib; local import keeps the module header unchanged

    if n_ctrl_pts < 2:
        raise ValueError(f"n_ctrl_pts must be >= 2, got {n_ctrl_pts}")
    t = torch.linspace(0, 1, n_samples, device=device, dtype=torch.float32)
    n = n_ctrl_pts - 1  # polynomial degree
    # B_{j,n}(t) = C(n, j) * t^j * (1 - t)^(n - j)
    B = torch.stack(
        [comb(n, j) * t ** j * (1 - t) ** (n - j) for j in range(n_ctrl_pts)],
        dim=1,
    )
    return B
@torch.compile
def render_patch(grid_patch, pts, sigma_s, alpha_s, sharpness):
d2 = ((grid_patch.unsqueeze(-2) - pts)**2).sum(-1)
gauss = torch.exp(-0.5 * sharpness * d2 / (sigma_s**2 + 1e-12))
density = (alpha_s * gauss).sum(-1)
patch_alpha = (1.0 - torch.exp(-density.clamp(min=1e-12))) ** sharpness
return patch_alpha
class StrokeRenderer(nn.Module):
    """Differentiable rasterizer for batches of parametric Bezier strokes.

    Each stroke is an ``n_ctrl_pts``-point Bezier curve sampled at
    ``n_samples`` parameter values; each sample splats an isotropic Gaussian
    whose width and opacity are linearly interpolated between per-stroke
    start/end values (see ``render_patch``).  Strokes are alpha-composited
    in order (later strokes on top), and the finished stroke layer is then
    composited over an optional previous canvas.
    """

    def __init__(self, canvas_size: Tuple[int, int], device, n_ctrl_pts: int, padding_k=3.0,
                 n_samples: int = 36, sharpness=2):
        """
        Args:
            canvas_size: (H, W) output resolution in pixels.
            device: torch device for the precomputed buffers.
            n_ctrl_pts: Bezier control points per stroke
                (``bernstein_matrix`` currently implements only 4, i.e. cubic).
            padding_k: stroke bounding-box padding, in multiples of the
                largest Gaussian sigma, so the soft falloff is not clipped.
            n_samples: number of curve samples per stroke.
            sharpness: exponent/scale forwarded to ``render_patch`` that
                controls edge hardness.
        """
        super().__init__()
        self.H, self.W = canvas_size
        self.n_ctrl_pts = n_ctrl_pts
        self.padding_k = padding_k
        self.n_samples = n_samples
        self.sharpness = sharpness
        # Precompute normalized canvas grid: pixel centers mapped to [0, 1]
        # in both axes (x along width, y along height).
        yy, xx = torch.meshgrid(
            torch.linspace(0, 1, self.H, device=device),
            torch.linspace(0, 1, self.W, device=device),
            indexing='ij'
        )
        # Registered as buffers so device moves / state_dict handle them;
        # the bare annotations only help static type checkers.
        self.grid_norm: torch.Tensor
        self.register_buffer('grid_norm', torch.stack(
            [xx, yy], dim=-1))  # (H,W,2), last dim is (x, y)
        # Curve parameter samples t in [0, 1], shape (n_samples,).
        self.t_64: torch.Tensor
        self.register_buffer('t_64', torch.linspace(
            0, 1, n_samples, device=device))
        # Bernstein basis matrix, shape (n_samples, n_ctrl_pts).
        self.bernstein: torch.Tensor
        self.register_buffer('bernstein', bernstein_matrix(
            self.n_samples, n_ctrl_pts, device=device))

    def forward(self, strokes: torch.Tensor,
                prev_canvas: Optional[torch.Tensor] = None,
                return_step_canvases: bool = False):
        """Rasterize strokes and composite them over ``prev_canvas``.

        Args:
            strokes: (B, S, n_ctrl_pts*2 + 7) per-stroke parameters laid out
                as [2*n_ctrl_pts control-point xy coords in [0,1],
                w_start, w_end, op_start, op_end, r, g, b].
            prev_canvas: optional (B, 3, H, W) background; zeros if None.
            return_step_canvases: if True, also record the (detached) canvas
                after each stroke.

        Returns:
            dict with 'end_canvas' (B,3,H,W) in [0,1] and 'step_canvases'
            ((B,S,3,H,W) detached tensor, or None when not requested).
        """
        B, S, _ = strokes.shape
        device = strokes.device
        # Parse stroke parameters; strokes are flattened to B*S rows.
        ctrl_pts = strokes[..., :2 *
                           self.n_ctrl_pts].view(B*S, self.n_ctrl_pts, 2)
        w_start = strokes[..., 2*self.n_ctrl_pts].view(B*S)     # start width
        w_end = strokes[..., 2*self.n_ctrl_pts+1].view(B*S)     # end width
        op_start = strokes[..., 2*self.n_ctrl_pts+2].view(B*S)  # start opacity
        op_end = strokes[..., 2*self.n_ctrl_pts+3].view(B*S)    # end opacity
        color = strokes[..., 2*self.n_ctrl_pts+4:].view(B*S, 3)  # RGB
        if prev_canvas is None:
            prev_canvas = torch.zeros((B, 3, self.H, self.W),
                                      device=device, dtype=strokes.dtype)
        else:
            prev_canvas = prev_canvas.to(device=device, dtype=strokes.dtype)
        # Initialize canvas: RGBA working layer for the strokes only;
        # prev_canvas is composited underneath at the very end.
        canvas = torch.zeros((B, 4, self.H, self.W),
                             device=device, dtype=strokes.dtype)
        # Step canvas storage (cumulative result after each stroke).
        if return_step_canvases:
            cumulative_rgb = prev_canvas.clone()
            cumulative_alpha = torch.zeros(
                (B, 1, self.H, self.W), device=device, dtype=strokes.dtype)
            step_canvas = torch.zeros(
                B, S, 3, self.H, self.W, device=device, dtype=strokes.dtype)
        else:
            cumulative_rgb = None
            cumulative_alpha = None
            step_canvas = None
        # Strokes are enumerated row-major: batch = idx // S, step = idx % S.
        for idx in range(B*S):
            # Sample the Bezier curve at the precomputed basis.
            pts = torch.matmul(self.bernstein, ctrl_pts[idx])  # (n_samples,2)
            # Linearly interpolate opacity and width along the curve.
            alpha_s = (op_start[idx] * (1 - self.t_64) +
                       op_end[idx] * self.t_64)  # (n_samples,)
            # NOTE(review): 2.8 presumably converts stroke width to a
            # Gaussian sigma (close to the FWHM factor 2.355) — confirm.
            sigma_s = (w_start[idx] * (1 - self.t_64) +
                       w_end[idx] * self.t_64) / 2.8
            sigma_s = sigma_s.clamp(min=1e-4)
            col = color[idx]  # (3,)
            # Axis-aligned bounding box of the curve, padded by padding_k
            # sigmas so the Gaussian falloff is not clipped.
            min_xy = pts.min(dim=0).values
            max_xy = pts.max(dim=0).values
            pad = self.padding_k * sigma_s.max()
            bbox_min = (min_xy - pad).clamp(0, 1)
            bbox_max = (max_xy + pad).clamp(0, 1)
            # Convert normalized bbox to inclusive pixel indices.
            x0 = max(0, int(torch.floor(bbox_min[0] * (self.W - 1)).item()))
            y0 = max(0, int(torch.floor(bbox_min[1] * (self.H - 1)).item()))
            x1 = min(
                self.W-1, int(torch.ceil(bbox_max[0] * (self.W - 1)).item()))
            y1 = min(
                self.H-1, int(torch.ceil(bbox_max[1] * (self.H - 1)).item()))
            if x1 < x0 or y1 < y0:
                # Stroke fell entirely outside the canvas; skip it.
                continue
            # Patch grid from registered buffer
            grid_patch = self.grid_norm[y0:y1+1, x0:x1+1]  # (h, w, 2)
            patch_alpha = render_patch(
                grid_patch, pts, sigma_s, alpha_s, self.sharpness)
            patch_rgb = col.view(3, 1, 1) * \
                patch_alpha.unsqueeze(0)  # (3,h,w), premultiplied alpha
            # Clone the canvas patch to avoid in-place issues
            canvas_rgb = canvas[idx//S, :3, y0:y1+1, x0:x1+1].clone()
            canvas_alpha = canvas[idx//S, 3, y0:y1+1, x0:x1+1].clone()
            # Composite: src over dst (current stroke over earlier strokes).
            inv_patch_alpha = (1.0 - patch_alpha).unsqueeze(0)  # (1,h,w)
            new_rgb = patch_rgb + canvas_rgb * inv_patch_alpha  # (3,h,w)
            new_alpha = patch_alpha + canvas_alpha * \
                (1.0 - patch_alpha)  # (h,w)
            # assign back to canvas
            canvas[idx//S, :3, y0:y1+1, x0:x1+1] = new_rgb
            canvas[idx//S, 3, y0:y1+1, x0:x1+1] = new_alpha
            if return_step_canvases and step_canvas is not None and cumulative_rgb is not None and cumulative_alpha is not None:
                batch_idx = idx // S
                step_idx = idx % S
                x_slice = slice(x0, x1+1)
                y_slice = slice(y0, y1+1)
                # Update cumulative canvas: same src-over composite, but
                # applied incrementally on top of prev_canvas.  Since
                # src-over is associative this matches end_canvas exactly.
                cumulative_rgb[batch_idx, :, y_slice, x_slice] = (
                    patch_rgb +
                    cumulative_rgb[batch_idx, :, y_slice,
                                   x_slice] * (1 - patch_alpha.unsqueeze(0))
                )
                cumulative_alpha[batch_idx, :, y_slice, x_slice] = (
                    patch_alpha.unsqueeze(
                        0) + cumulative_alpha[batch_idx, :, y_slice, x_slice] * (1 - patch_alpha.unsqueeze(0))
                )
                # Step canvases are detached: they are for visualization and
                # carry no gradient.
                step_canvas[batch_idx,
                            step_idx] = cumulative_rgb[batch_idx].detach()
        # Composite the finished stroke layer over the previous canvas.
        rgb = canvas[:, :3, ...]
        alpha = canvas[:, 3:, ...].clamp(0, 1)
        new_canvas = rgb + prev_canvas*(1-alpha)
        new_canvas = new_canvas.clamp(0, 1)
        return {'end_canvas': new_canvas, 'step_canvases': step_canvas if return_step_canvases is True else None}