Spaces:
Running on Zero
Running on Zero
Update tokenizer.py
Browse files- tokenizer.py +137 -171
tokenizer.py
CHANGED
|
@@ -8,56 +8,61 @@ from deepsvg.svglib.geom import Bbox
|
|
| 8 |
|
| 9 |
|
| 10 |
class SVGTokenizer:
|
| 11 |
-
"""SVG tokenizer for converting between tokens and SVG representations"""
|
| 12 |
|
| 13 |
-
def __init__(self, config_path: str = "config.yaml"):
|
| 14 |
-
|
| 15 |
-
|
|
|
|
| 16 |
|
| 17 |
-
|
| 18 |
-
self.
|
| 19 |
-
self.
|
| 20 |
-
self.
|
| 21 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
self.pixel2xy = self._create_pixel2xy_mapping()
|
| 24 |
|
| 25 |
def _create_pixel2xy_mapping(self) -> Dict[int, np.ndarray]:
|
| 26 |
-
"""Create mapping from pixel indices to xy coordinates"""
|
| 27 |
-
bbox = self.coordinates_config['bbox']
|
| 28 |
-
coord_pad = self.coordinates_config['coord_pad_offset']
|
| 29 |
-
svg_end = self.tokens_config['svg_end']
|
| 30 |
-
|
| 31 |
pixel2xy = {}
|
| 32 |
-
x = np.linspace(0,
|
| 33 |
-
y = np.linspace(0,
|
| 34 |
xx, yy = np.meshgrid(x, y)
|
| 35 |
xy_grid = (np.array((xx.ravel(), yy.ravel())).T).astype(int)
|
| 36 |
|
| 37 |
for pixel, xy in enumerate(xy_grid):
|
| 38 |
-
|
|
|
|
| 39 |
|
| 40 |
return pixel2xy
|
| 41 |
|
| 42 |
def token_to_color(self, color_token: int) -> str:
|
| 43 |
try:
|
| 44 |
-
|
| 45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
|
| 47 |
-
|
| 48 |
-
if color_token == color_token_start:
|
| 49 |
-
return "none" # No color
|
| 50 |
-
elif color_token == color_token_start + 1:
|
| 51 |
-
return "currentColor" # Special color
|
| 52 |
|
| 53 |
-
color_index
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
return "#808080" # Gray as default
|
| 57 |
|
| 58 |
-
r = (color_index >> 8) & 0xF
|
| 59 |
-
g = (color_index >> 4) & 0xF
|
| 60 |
-
b = color_index & 0xF
|
| 61 |
|
| 62 |
r = (r << 4) | r
|
| 63 |
g = (g << 4) | g
|
|
@@ -67,94 +72,103 @@ class SVGTokenizer:
|
|
| 67 |
|
| 68 |
except Exception as e:
|
| 69 |
print(f"Error in token_to_color: {e}")
|
| 70 |
-
return "#808080"
|
| 71 |
-
|
| 72 |
-
def
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
|
| 78 |
-
if
|
| 79 |
-
|
| 80 |
-
return xy
|
| 81 |
-
elif pix_pad + svg_end <= pixel < self.colors_config['cmd_fill'] + base_offset + svg_end:
|
| 82 |
-
pixel_index = pixel - pix_pad - svg_end
|
| 83 |
-
if pixel_index in self.pixel2xy:
|
| 84 |
-
return self.pixel2xy[pixel_index] - base_offset
|
| 85 |
-
else:
|
| 86 |
-
raise ValueError(f"Invalid pixel index: {pixel_index}")
|
| 87 |
else:
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
def raster_svg(self, pixels: np.ndarray) -> List[List[torch.Tensor]]:
|
| 91 |
-
"""Convert pixel sequence to SVG tensor representation"""
|
| 92 |
try:
|
| 93 |
-
|
| 94 |
-
|
|
|
|
|
|
|
| 95 |
|
| 96 |
svg_tensors = []
|
|
|
|
| 97 |
path_tensor = []
|
| 98 |
-
i = 0
|
| 99 |
|
|
|
|
| 100 |
while i < len(pixels):
|
| 101 |
try:
|
| 102 |
pix = pixels[i]
|
| 103 |
|
| 104 |
-
if pix[0] ==
|
| 105 |
-
cmd_tensor = np.zeros(14)
|
| 106 |
-
cmd_tensor[0] = 0
|
| 107 |
-
|
| 108 |
if i + 2 >= len(pixels):
|
| 109 |
-
break
|
| 110 |
-
|
|
|
|
| 111 |
cmd_tensor[12:14] = pixels[i+2]
|
| 112 |
-
start_pos = pixels[i+1]
|
| 113 |
-
end_pos = pixels[i+2]
|
| 114 |
-
|
| 115 |
-
if np.all(start_pos == end_pos) and path_tensor:
|
| 116 |
-
svg_tensors.append(torch.tensor(path_tensor))
|
| 117 |
-
path_tensor = []
|
| 118 |
path_tensor.append(cmd_tensor.tolist())
|
| 119 |
i += 3
|
| 120 |
|
| 121 |
-
elif pix[0] ==
|
| 122 |
-
cmd_tensor = np.zeros(14)
|
| 123 |
-
cmd_tensor[0] = 1
|
| 124 |
-
|
| 125 |
if i + 1 >= len(pixels):
|
| 126 |
-
break
|
| 127 |
-
|
|
|
|
| 128 |
cmd_tensor[12:14] = pixels[i+1]
|
| 129 |
path_tensor.append(cmd_tensor.tolist())
|
| 130 |
i += 2
|
| 131 |
|
| 132 |
-
elif pix[0] ==
|
| 133 |
-
cmd_tensor = np.zeros(14)
|
| 134 |
-
cmd_tensor[0] = 2
|
| 135 |
-
|
| 136 |
if i + 3 >= len(pixels):
|
| 137 |
-
break
|
| 138 |
-
|
|
|
|
| 139 |
cmd_tensor[8:10] = pixels[i+1]
|
| 140 |
cmd_tensor[10:12] = pixels[i+2]
|
| 141 |
cmd_tensor[12:14] = pixels[i+3]
|
| 142 |
path_tensor.append(cmd_tensor.tolist())
|
| 143 |
i += 4
|
| 144 |
|
| 145 |
-
elif pix[0] ==
|
| 146 |
-
cmd_tensor = np.zeros(14)
|
| 147 |
-
cmd_tensor[0] = 3
|
| 148 |
-
|
| 149 |
if i + 5 >= len(pixels):
|
| 150 |
-
break
|
| 151 |
-
|
|
|
|
| 152 |
radius = pixels[i+1]
|
| 153 |
-
x_axis_rot = pixels[i+2][0]
|
| 154 |
-
large_arc_flg = pixels[i+3][0]
|
| 155 |
-
sweep_flg = pixels[i+4][0]
|
| 156 |
end_pos = pixels[i+5]
|
| 157 |
-
|
| 158 |
cmd_tensor[1:3] = radius
|
| 159 |
cmd_tensor[3] = x_axis_rot
|
| 160 |
cmd_tensor[4] = large_arc_flg
|
|
@@ -163,102 +177,57 @@ class SVGTokenizer:
|
|
| 163 |
path_tensor.append(cmd_tensor.tolist())
|
| 164 |
i += 6
|
| 165 |
|
| 166 |
-
elif pix[0] ==
|
| 167 |
-
cmd_tensor = np.zeros(14)
|
| 168 |
-
cmd_tensor[0] = 6
|
| 169 |
-
|
| 170 |
if i + 1 >= len(pixels):
|
| 171 |
-
break
|
| 172 |
-
|
|
|
|
| 173 |
cmd_tensor[12:14] = pixels[i+1]
|
| 174 |
path_tensor.append(cmd_tensor.tolist())
|
| 175 |
i += 2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 176 |
else:
|
| 177 |
-
i += 1
|
| 178 |
|
| 179 |
-
except IndexError:
|
| 180 |
-
print(f"
|
| 181 |
break
|
| 182 |
-
|
| 183 |
if path_tensor:
|
| 184 |
svg_tensors.append(torch.tensor(path_tensor))
|
| 185 |
|
| 186 |
-
return [svg_tensors]
|
| 187 |
|
| 188 |
except Exception as e:
|
| 189 |
print(f"Error in raster_svg: {e}")
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
color_end = self.colors_config['color_end_offset']
|
| 197 |
-
|
| 198 |
-
for token in tokens:
|
| 199 |
-
if color_start <= token < color_end:
|
| 200 |
-
colors.append(token - 1 - base_offset)
|
| 201 |
-
|
| 202 |
-
return colors
|
| 203 |
-
|
| 204 |
-
def process_generated_tokens(self, output_ids: torch.Tensor) -> Tuple[np.ndarray, List[int]]:
|
| 205 |
-
# Remove <bos> and <eos> tokens
|
| 206 |
-
generated_pixels = output_ids[:, 1:-1].tolist()
|
| 207 |
-
|
| 208 |
-
generated_xy = []
|
| 209 |
-
generated_colors = []
|
| 210 |
-
|
| 211 |
-
for pixel_sequence in generated_pixels:
|
| 212 |
-
xy_sequence = []
|
| 213 |
-
colors = []
|
| 214 |
-
|
| 215 |
-
for pixel in pixel_sequence:
|
| 216 |
-
try:
|
| 217 |
-
if self.tokens_config['eom'] < pixel < self.coordinates_config['pix_pad_offset'] + self.tokens_config['svg_end']:
|
| 218 |
-
xy = self.pixel_to_xy(pixel)
|
| 219 |
-
xy_sequence.append(xy)
|
| 220 |
-
elif self.coordinates_config['pix_pad_offset'] + self.tokens_config['svg_end'] <= pixel < self.colors_config['cmd_fill'] + self.tokens_config['base_offset'] + self.tokens_config['svg_end']:
|
| 221 |
-
xy = self.pixel_to_xy(pixel)
|
| 222 |
-
xy_sequence.append(xy)
|
| 223 |
-
elif self.colors_config['color_start_offset'] <= pixel < self.colors_config['color_end_offset']:
|
| 224 |
-
colors.append(pixel - 1 - self.tokens_config['base_offset'])
|
| 225 |
-
except ValueError as e:
|
| 226 |
-
print(f"Error processing pixel {pixel}: {e}")
|
| 227 |
-
continue
|
| 228 |
-
|
| 229 |
-
if xy_sequence:
|
| 230 |
-
generated_xy = np.vstack(xy_sequence)
|
| 231 |
-
generated_colors = colors
|
| 232 |
-
|
| 233 |
-
return generated_xy, generated_colors
|
| 234 |
-
|
| 235 |
-
def apply_colors_to_svg(self, svg_tensors: Union[List[torch.Tensor], List[List[torch.Tensor]]], colors: Optional[List[int]]) -> SVG:
|
| 236 |
paths = []
|
| 237 |
-
bbox = self.coordinates_config['bbox']
|
| 238 |
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
for tensor_list in svg_tensors:
|
| 242 |
-
flat_tensors.extend(tensor_list)
|
| 243 |
-
else:
|
| 244 |
-
flat_tensors = svg_tensors
|
| 245 |
|
| 246 |
-
|
| 247 |
-
raise ValueError("No valid SVG tensors provided")
|
| 248 |
|
| 249 |
-
|
| 250 |
-
colors = []
|
| 251 |
-
|
| 252 |
-
for i, path_tensor in enumerate(flat_tensors):
|
| 253 |
try:
|
| 254 |
path = SVGTensor.from_data(path_tensor)
|
| 255 |
-
path = SVG.from_tensor(path.data, viewbox=Bbox(
|
| 256 |
|
| 257 |
-
if i < len(colors)
|
| 258 |
-
color_token = colors[i]
|
| 259 |
-
actual_color = self.token_to_color(color_token)
|
| 260 |
-
else:
|
| 261 |
-
actual_color = "none"
|
| 262 |
|
| 263 |
for path_group in path:
|
| 264 |
path_group.color = actual_color
|
|
@@ -266,19 +235,16 @@ class SVGTokenizer:
|
|
| 266 |
|
| 267 |
path.fill_(True)
|
| 268 |
paths.append(path)
|
| 269 |
-
|
| 270 |
|
| 271 |
except Exception as e:
|
| 272 |
print(f"Error processing path {i}: {e}")
|
| 273 |
continue
|
| 274 |
|
| 275 |
if not paths:
|
| 276 |
-
raise ValueError("No valid paths
|
|
|
|
| 277 |
path_groups = paths[0].svg_path_groups
|
| 278 |
for i in range(1, len(paths)):
|
| 279 |
-
|
| 280 |
-
path_groups.extend(paths[i].svg_path_groups)
|
| 281 |
-
|
| 282 |
-
svg = SVG(path_groups, viewbox=Bbox(bbox))
|
| 283 |
|
| 284 |
-
return
|
|
|
|
| 8 |
|
| 9 |
|
| 10 |
class SVGTokenizer:
|
|
|
|
| 11 |
|
| 12 |
+
def __init__(self, config_path: str = "./config.yaml"):
|
| 13 |
+
self.NUM_SVG_END = 1
|
| 14 |
+
self.BASE_OFFSET = 152064
|
| 15 |
+
self.NUM_MASK_AND_EOM = 2 + self.BASE_OFFSET
|
| 16 |
|
| 17 |
+
self.CMD_MOVE_RAW = 0 + self.NUM_MASK_AND_EOM
|
| 18 |
+
self.CMD_LINE_RAW = 1 + self.NUM_MASK_AND_EOM
|
| 19 |
+
self.CMD_CURVE_RAW = 2 + self.NUM_MASK_AND_EOM
|
| 20 |
+
self.CMD_ARC_RAW = 3 + self.NUM_MASK_AND_EOM
|
| 21 |
+
self.CMD_Z_RAW = 4 + self.NUM_MASK_AND_EOM
|
| 22 |
+
|
| 23 |
+
self.PIX_PAD = 5 + self.NUM_MASK_AND_EOM
|
| 24 |
+
self.COORD_PAD = self.PIX_PAD
|
| 25 |
+
|
| 26 |
+
self.BBOX = 200
|
| 27 |
+
self.ARC_PARAM_START = 44500 + self.BASE_OFFSET
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
self.COLOR_TOKEN_START = 40010 + self.BASE_OFFSET
|
| 31 |
+
self.COLOR_TOKEN_END = self.ARC_PARAM_START - 1
|
| 32 |
|
| 33 |
self.pixel2xy = self._create_pixel2xy_mapping()
|
| 34 |
|
| 35 |
def _create_pixel2xy_mapping(self) -> Dict[int, np.ndarray]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
pixel2xy = {}
|
| 37 |
+
x = np.linspace(0, self.BBOX-1, self.BBOX)
|
| 38 |
+
y = np.linspace(0, self.BBOX-1, self.BBOX)
|
| 39 |
xx, yy = np.meshgrid(x, y)
|
| 40 |
xy_grid = (np.array((xx.ravel(), yy.ravel())).T).astype(int)
|
| 41 |
|
| 42 |
for pixel, xy in enumerate(xy_grid):
|
| 43 |
+
# xy + COORD_PAD + NUM_SVG_END = xy + 151943 + 1 = xy + 151944
|
| 44 |
+
pixel2xy[pixel] = xy + self.COORD_PAD + self.NUM_SVG_END
|
| 45 |
|
| 46 |
return pixel2xy
|
| 47 |
|
| 48 |
def token_to_color(self, color_token: int) -> str:
|
| 49 |
try:
|
| 50 |
+
COLOR_TOKEN_START = 40010
|
| 51 |
+
|
| 52 |
+
if color_token == COLOR_TOKEN_START:
|
| 53 |
+
return "none"
|
| 54 |
+
elif color_token == COLOR_TOKEN_START + 1:
|
| 55 |
+
return "currentColor"
|
| 56 |
|
| 57 |
+
color_index = color_token - (COLOR_TOKEN_START + 2)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
|
| 59 |
+
if color_index < 0 or color_index >= 4098:
|
| 60 |
+
print(f"Warning: Color token {color_token} out of range")
|
| 61 |
+
return "#808080"
|
|
|
|
| 62 |
|
| 63 |
+
r = (color_index >> 8) & 0xF
|
| 64 |
+
g = (color_index >> 4) & 0xF
|
| 65 |
+
b = color_index & 0xF
|
| 66 |
|
| 67 |
r = (r << 4) | r
|
| 68 |
g = (g << 4) | g
|
|
|
|
| 72 |
|
| 73 |
except Exception as e:
|
| 74 |
print(f"Error in token_to_color: {e}")
|
| 75 |
+
return "#808080"
|
| 76 |
+
|
| 77 |
+
def process_generated_tokens(self, output_ids: torch.Tensor) -> np.ndarray:
|
| 78 |
+
|
| 79 |
+
generated_pixels = output_ids[:, 1:-1].cpu().numpy().flatten()
|
| 80 |
+
|
| 81 |
+
sample_xys = []
|
| 82 |
+
|
| 83 |
+
for pixel in generated_pixels:
|
| 84 |
+
try:
|
| 85 |
+
if 151939 <= pixel < self.PIX_PAD + self.NUM_SVG_END: # < 151944
|
| 86 |
+
xy = np.array([pixel - self.BASE_OFFSET,
|
| 87 |
+
pixel - self.BASE_OFFSET]).astype(int)
|
| 88 |
+
sample_xys.append(xy)
|
| 89 |
+
|
| 90 |
+
elif (self.PIX_PAD + self.NUM_SVG_END <= pixel <
|
| 91 |
+
40011 + self.BASE_OFFSET):
|
| 92 |
+
pixel_index = pixel - self.PIX_PAD - self.NUM_SVG_END
|
| 93 |
+
if pixel_index in self.pixel2xy:
|
| 94 |
+
xy = self.pixel2xy[pixel_index] - self.BASE_OFFSET
|
| 95 |
+
sample_xys.append(xy)
|
| 96 |
+
|
| 97 |
+
elif (self.ARC_PARAM_START + 1 <= pixel <
|
| 98 |
+
self.ARC_PARAM_START + 1 + 100):
|
| 99 |
+
value = pixel - self.ARC_PARAM_START - 1
|
| 100 |
+
xy = np.array([value, value]).astype(int)
|
| 101 |
+
sample_xys.append(xy)
|
| 102 |
+
|
| 103 |
+
elif 40011 + self.BASE_OFFSET <= pixel < self.ARC_PARAM_START:
|
| 104 |
+
xy = np.array([pixel - self.BASE_OFFSET,
|
| 105 |
+
pixel - self.BASE_OFFSET]).astype(int)
|
| 106 |
+
sample_xys.append(xy)
|
| 107 |
+
|
| 108 |
+
except Exception as e:
|
| 109 |
+
print(f"Error processing pixel {pixel}: {e}")
|
| 110 |
+
continue
|
| 111 |
|
| 112 |
+
if sample_xys:
|
| 113 |
+
return np.vstack(sample_xys)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
else:
|
| 115 |
+
return np.array([]).reshape(0, 2)
|
| 116 |
+
|
| 117 |
+
def raster_svg(self, pixels: np.ndarray) -> Tuple[List[List[torch.Tensor]], List[int]]:
|
|
|
|
| 118 |
try:
|
| 119 |
+
if len(pixels) == 0:
|
| 120 |
+
return [[]], []
|
| 121 |
+
|
| 122 |
+
pixels = pixels - 8
|
| 123 |
|
| 124 |
svg_tensors = []
|
| 125 |
+
color_tensors = []
|
| 126 |
path_tensor = []
|
|
|
|
| 127 |
|
| 128 |
+
i = 0
|
| 129 |
while i < len(pixels):
|
| 130 |
try:
|
| 131 |
pix = pixels[i]
|
| 132 |
|
| 133 |
+
if pix[0] == -5:
|
|
|
|
|
|
|
|
|
|
| 134 |
if i + 2 >= len(pixels):
|
| 135 |
+
break
|
| 136 |
+
cmd_tensor = np.zeros(14)
|
| 137 |
+
cmd_tensor[0] = 0 # Move command
|
| 138 |
cmd_tensor[12:14] = pixels[i+2]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
path_tensor.append(cmd_tensor.tolist())
|
| 140 |
i += 3
|
| 141 |
|
| 142 |
+
elif pix[0] == -4:
|
|
|
|
|
|
|
|
|
|
| 143 |
if i + 1 >= len(pixels):
|
| 144 |
+
break
|
| 145 |
+
cmd_tensor = np.zeros(14)
|
| 146 |
+
cmd_tensor[0] = 1 # Line command
|
| 147 |
cmd_tensor[12:14] = pixels[i+1]
|
| 148 |
path_tensor.append(cmd_tensor.tolist())
|
| 149 |
i += 2
|
| 150 |
|
| 151 |
+
elif pix[0] == -3:
|
|
|
|
|
|
|
|
|
|
| 152 |
if i + 3 >= len(pixels):
|
| 153 |
+
break
|
| 154 |
+
cmd_tensor = np.zeros(14)
|
| 155 |
+
cmd_tensor[0] = 2 # Curve command
|
| 156 |
cmd_tensor[8:10] = pixels[i+1]
|
| 157 |
cmd_tensor[10:12] = pixels[i+2]
|
| 158 |
cmd_tensor[12:14] = pixels[i+3]
|
| 159 |
path_tensor.append(cmd_tensor.tolist())
|
| 160 |
i += 4
|
| 161 |
|
| 162 |
+
elif pix[0] == -2:
|
|
|
|
|
|
|
|
|
|
| 163 |
if i + 5 >= len(pixels):
|
| 164 |
+
break
|
| 165 |
+
cmd_tensor = np.zeros(14)
|
| 166 |
+
cmd_tensor[0] = 3 # Arc command
|
| 167 |
radius = pixels[i+1]
|
| 168 |
+
x_axis_rot = pixels[i+2][0] + 8
|
| 169 |
+
large_arc_flg = pixels[i+3][0] + 8
|
| 170 |
+
sweep_flg = pixels[i+4][0] + 8
|
| 171 |
end_pos = pixels[i+5]
|
|
|
|
| 172 |
cmd_tensor[1:3] = radius
|
| 173 |
cmd_tensor[3] = x_axis_rot
|
| 174 |
cmd_tensor[4] = large_arc_flg
|
|
|
|
| 177 |
path_tensor.append(cmd_tensor.tolist())
|
| 178 |
i += 6
|
| 179 |
|
| 180 |
+
elif pix[0] == -1:
|
|
|
|
|
|
|
|
|
|
| 181 |
if i + 1 >= len(pixels):
|
| 182 |
+
break
|
| 183 |
+
cmd_tensor = np.zeros(14)
|
| 184 |
+
cmd_tensor[0] = 6 # Close command
|
| 185 |
cmd_tensor[12:14] = pixels[i+1]
|
| 186 |
path_tensor.append(cmd_tensor.tolist())
|
| 187 |
i += 2
|
| 188 |
+
|
| 189 |
+
elif pix[0] >= 40003:
|
| 190 |
+
if path_tensor:
|
| 191 |
+
svg_tensors.append(torch.tensor(path_tensor))
|
| 192 |
+
# 逆转换:还原原始颜色token
|
| 193 |
+
# pix[0] + 8 + 152064 - 1 - 152064 = pix[0] + 7
|
| 194 |
+
color_token = int(pix[0] + 7)
|
| 195 |
+
color_tensors.append(color_token)
|
| 196 |
+
path_tensor = []
|
| 197 |
+
i += 1
|
| 198 |
else:
|
| 199 |
+
i += 1
|
| 200 |
|
| 201 |
+
except (IndexError, TypeError) as e:
|
| 202 |
+
print(f"Error at position {i}: {e}")
|
| 203 |
break
|
| 204 |
+
|
| 205 |
if path_tensor:
|
| 206 |
svg_tensors.append(torch.tensor(path_tensor))
|
| 207 |
|
| 208 |
+
return [svg_tensors], color_tensors
|
| 209 |
|
| 210 |
except Exception as e:
|
| 211 |
print(f"Error in raster_svg: {e}")
|
| 212 |
+
import traceback
|
| 213 |
+
traceback.print_exc()
|
| 214 |
+
return [[]], []
|
| 215 |
+
|
| 216 |
+
def apply_colors_to_svg(self, svg_tensors: List[torch.Tensor],
|
| 217 |
+
colors: Optional[List[int]]) -> SVG:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 218 |
paths = []
|
|
|
|
| 219 |
|
| 220 |
+
if not svg_tensors:
|
| 221 |
+
raise ValueError("No valid SVG tensors")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 222 |
|
| 223 |
+
colors = colors or []
|
|
|
|
| 224 |
|
| 225 |
+
for i, path_tensor in enumerate(svg_tensors):
|
|
|
|
|
|
|
|
|
|
| 226 |
try:
|
| 227 |
path = SVGTensor.from_data(path_tensor)
|
| 228 |
+
path = SVG.from_tensor(path.data, viewbox=Bbox(self.BBOX))
|
| 229 |
|
| 230 |
+
actual_color = self.token_to_color(colors[i]) if i < len(colors) else "none"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 231 |
|
| 232 |
for path_group in path:
|
| 233 |
path_group.color = actual_color
|
|
|
|
| 235 |
|
| 236 |
path.fill_(True)
|
| 237 |
paths.append(path)
|
|
|
|
| 238 |
|
| 239 |
except Exception as e:
|
| 240 |
print(f"Error processing path {i}: {e}")
|
| 241 |
continue
|
| 242 |
|
| 243 |
if not paths:
|
| 244 |
+
raise ValueError("No valid paths generated")
|
| 245 |
+
|
| 246 |
path_groups = paths[0].svg_path_groups
|
| 247 |
for i in range(1, len(paths)):
|
| 248 |
+
path_groups.extend(paths[i].svg_path_groups)
|
|
|
|
|
|
|
|
|
|
| 249 |
|
| 250 |
+
return SVG(path_groups, viewbox=Bbox(self.BBOX))
|