OmniSVG committed on
Commit
1015bc8
·
verified ·
1 Parent(s): de52740

Update tokenizer.py

Browse files
Files changed (1) hide show
  1. tokenizer.py +137 -171
tokenizer.py CHANGED
@@ -8,56 +8,61 @@ from deepsvg.svglib.geom import Bbox
8
 
9
 
10
  class SVGTokenizer:
11
- """SVG tokenizer for converting between tokens and SVG representations"""
12
 
13
- def __init__(self, config_path: str = "config.yaml"):
14
- with open(config_path, 'r') as f:
15
- self.config = yaml.safe_load(f)
 
16
 
17
- # Extract configuration values
18
- self.tokens_config = self.config['tokens']
19
- self.coordinates_config = self.config['coordinates']
20
- self.colors_config = self.config['colors']
21
- self.svg_commands = self.config['svg_commands']
 
 
 
 
 
 
 
 
 
 
22
 
23
  self.pixel2xy = self._create_pixel2xy_mapping()
24
 
25
  def _create_pixel2xy_mapping(self) -> Dict[int, np.ndarray]:
26
- """Create mapping from pixel indices to xy coordinates"""
27
- bbox = self.coordinates_config['bbox']
28
- coord_pad = self.coordinates_config['coord_pad_offset']
29
- svg_end = self.tokens_config['svg_end']
30
-
31
  pixel2xy = {}
32
- x = np.linspace(0, bbox-1, bbox)
33
- y = np.linspace(0, bbox-1, bbox)
34
  xx, yy = np.meshgrid(x, y)
35
  xy_grid = (np.array((xx.ravel(), yy.ravel())).T).astype(int)
36
 
37
  for pixel, xy in enumerate(xy_grid):
38
- pixel2xy[pixel] = xy + coord_pad + svg_end
 
39
 
40
  return pixel2xy
41
 
42
  def token_to_color(self, color_token: int) -> str:
43
  try:
44
- color_token_start = self.colors_config['color_token_start']
45
- max_color_tokens = self.colors_config['max_color_tokens']
 
 
 
 
46
 
47
- # Check special color tokens
48
- if color_token == color_token_start:
49
- return "none" # No color
50
- elif color_token == color_token_start + 1:
51
- return "currentColor" # Special color
52
 
53
- color_index = color_token - (color_token_start + 2)
54
- if color_index < 0 or color_index >= max_color_tokens:
55
- print(f"Warning: Color token {color_token} out of range, using default color")
56
- return "#808080" # Gray as default
57
 
58
- r = (color_index >> 8) & 0xF
59
- g = (color_index >> 4) & 0xF
60
- b = color_index & 0xF
61
 
62
  r = (r << 4) | r
63
  g = (g << 4) | g
@@ -67,94 +72,103 @@ class SVGTokenizer:
67
 
68
  except Exception as e:
69
  print(f"Error in token_to_color: {e}")
70
- return "#808080"
71
-
72
- def pixel_to_xy(self, pixel: int) -> np.ndarray:
73
- """Convert pixel token to xy coordinates"""
74
- base_offset = self.tokens_config['base_offset']
75
- pix_pad = self.coordinates_config['pix_pad_offset']
76
- svg_end = self.tokens_config['svg_end']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
- if self.tokens_config['eom'] < pixel < pix_pad + svg_end:
79
- xy = np.array([pixel - base_offset, pixel - base_offset]).astype(int)
80
- return xy
81
- elif pix_pad + svg_end <= pixel < self.colors_config['cmd_fill'] + base_offset + svg_end:
82
- pixel_index = pixel - pix_pad - svg_end
83
- if pixel_index in self.pixel2xy:
84
- return self.pixel2xy[pixel_index] - base_offset
85
- else:
86
- raise ValueError(f"Invalid pixel index: {pixel_index}")
87
  else:
88
- raise ValueError(f"Invalid pixel token: {pixel}")
89
-
90
- def raster_svg(self, pixels: np.ndarray) -> List[List[torch.Tensor]]:
91
- """Convert pixel sequence to SVG tensor representation"""
92
  try:
93
- adjustment = self.tokens_config['num_end_token'] + self.tokens_config['svg_end'] + 2 # 8
94
- pixels = pixels - adjustment
 
 
95
 
96
  svg_tensors = []
 
97
  path_tensor = []
98
- i = 0
99
 
 
100
  while i < len(pixels):
101
  try:
102
  pix = pixels[i]
103
 
104
- if pix[0] == self.svg_commands['move']: # Move command
105
- cmd_tensor = np.zeros(14)
106
- cmd_tensor[0] = 0
107
-
108
  if i + 2 >= len(pixels):
109
- break
110
-
 
111
  cmd_tensor[12:14] = pixels[i+2]
112
- start_pos = pixels[i+1]
113
- end_pos = pixels[i+2]
114
-
115
- if np.all(start_pos == end_pos) and path_tensor:
116
- svg_tensors.append(torch.tensor(path_tensor))
117
- path_tensor = []
118
  path_tensor.append(cmd_tensor.tolist())
119
  i += 3
120
 
121
- elif pix[0] == self.svg_commands['line']: # Line command
122
- cmd_tensor = np.zeros(14)
123
- cmd_tensor[0] = 1
124
-
125
  if i + 1 >= len(pixels):
126
- break
127
-
 
128
  cmd_tensor[12:14] = pixels[i+1]
129
  path_tensor.append(cmd_tensor.tolist())
130
  i += 2
131
 
132
- elif pix[0] == self.svg_commands['curve']: # Curve command
133
- cmd_tensor = np.zeros(14)
134
- cmd_tensor[0] = 2
135
-
136
  if i + 3 >= len(pixels):
137
- break
138
-
 
139
  cmd_tensor[8:10] = pixels[i+1]
140
  cmd_tensor[10:12] = pixels[i+2]
141
  cmd_tensor[12:14] = pixels[i+3]
142
  path_tensor.append(cmd_tensor.tolist())
143
  i += 4
144
 
145
- elif pix[0] == self.svg_commands['arc']: # Arc command
146
- cmd_tensor = np.zeros(14)
147
- cmd_tensor[0] = 3
148
-
149
  if i + 5 >= len(pixels):
150
- break
151
-
 
152
  radius = pixels[i+1]
153
- x_axis_rot = pixels[i+2][0]
154
- large_arc_flg = pixels[i+3][0]
155
- sweep_flg = pixels[i+4][0]
156
  end_pos = pixels[i+5]
157
-
158
  cmd_tensor[1:3] = radius
159
  cmd_tensor[3] = x_axis_rot
160
  cmd_tensor[4] = large_arc_flg
@@ -163,102 +177,57 @@ class SVGTokenizer:
163
  path_tensor.append(cmd_tensor.tolist())
164
  i += 6
165
 
166
- elif pix[0] == self.svg_commands['close']: # Close command
167
- cmd_tensor = np.zeros(14)
168
- cmd_tensor[0] = 6
169
-
170
  if i + 1 >= len(pixels):
171
- break
172
-
 
173
  cmd_tensor[12:14] = pixels[i+1]
174
  path_tensor.append(cmd_tensor.tolist())
175
  i += 2
 
 
 
 
 
 
 
 
 
 
176
  else:
177
- i += 1
178
 
179
- except IndexError:
180
- print(f"Index error at position {i}, stopping SVG processing")
181
  break
182
-
183
  if path_tensor:
184
  svg_tensors.append(torch.tensor(path_tensor))
185
 
186
- return [svg_tensors]
187
 
188
  except Exception as e:
189
  print(f"Error in raster_svg: {e}")
190
- return []
191
-
192
- def extract_colors_from_tokens(self, tokens: List[int]) -> List[int]:
193
- colors = []
194
- base_offset = self.tokens_config['base_offset']
195
- color_start = self.colors_config['color_start_offset']
196
- color_end = self.colors_config['color_end_offset']
197
-
198
- for token in tokens:
199
- if color_start <= token < color_end:
200
- colors.append(token - 1 - base_offset)
201
-
202
- return colors
203
-
204
- def process_generated_tokens(self, output_ids: torch.Tensor) -> Tuple[np.ndarray, List[int]]:
205
- # Remove <bos> and <eos> tokens
206
- generated_pixels = output_ids[:, 1:-1].tolist()
207
-
208
- generated_xy = []
209
- generated_colors = []
210
-
211
- for pixel_sequence in generated_pixels:
212
- xy_sequence = []
213
- colors = []
214
-
215
- for pixel in pixel_sequence:
216
- try:
217
- if self.tokens_config['eom'] < pixel < self.coordinates_config['pix_pad_offset'] + self.tokens_config['svg_end']:
218
- xy = self.pixel_to_xy(pixel)
219
- xy_sequence.append(xy)
220
- elif self.coordinates_config['pix_pad_offset'] + self.tokens_config['svg_end'] <= pixel < self.colors_config['cmd_fill'] + self.tokens_config['base_offset'] + self.tokens_config['svg_end']:
221
- xy = self.pixel_to_xy(pixel)
222
- xy_sequence.append(xy)
223
- elif self.colors_config['color_start_offset'] <= pixel < self.colors_config['color_end_offset']:
224
- colors.append(pixel - 1 - self.tokens_config['base_offset'])
225
- except ValueError as e:
226
- print(f"Error processing pixel {pixel}: {e}")
227
- continue
228
-
229
- if xy_sequence:
230
- generated_xy = np.vstack(xy_sequence)
231
- generated_colors = colors
232
-
233
- return generated_xy, generated_colors
234
-
235
- def apply_colors_to_svg(self, svg_tensors: Union[List[torch.Tensor], List[List[torch.Tensor]]], colors: Optional[List[int]]) -> SVG:
236
  paths = []
237
- bbox = self.coordinates_config['bbox']
238
 
239
- flat_tensors = []
240
- if svg_tensors and isinstance(svg_tensors[0], list):
241
- for tensor_list in svg_tensors:
242
- flat_tensors.extend(tensor_list)
243
- else:
244
- flat_tensors = svg_tensors
245
 
246
- if not flat_tensors:
247
- raise ValueError("No valid SVG tensors provided")
248
 
249
- if colors is None:
250
- colors = []
251
-
252
- for i, path_tensor in enumerate(flat_tensors):
253
  try:
254
  path = SVGTensor.from_data(path_tensor)
255
- path = SVG.from_tensor(path.data, viewbox=Bbox(bbox))
256
 
257
- if i < len(colors):
258
- color_token = colors[i]
259
- actual_color = self.token_to_color(color_token)
260
- else:
261
- actual_color = "none"
262
 
263
  for path_group in path:
264
  path_group.color = actual_color
@@ -266,19 +235,16 @@ class SVGTokenizer:
266
 
267
  path.fill_(True)
268
  paths.append(path)
269
-
270
 
271
  except Exception as e:
272
  print(f"Error processing path {i}: {e}")
273
  continue
274
 
275
  if not paths:
276
- raise ValueError("No valid paths could be generated")
 
277
  path_groups = paths[0].svg_path_groups
278
  for i in range(1, len(paths)):
279
- if i < len(paths):
280
- path_groups.extend(paths[i].svg_path_groups)
281
-
282
- svg = SVG(path_groups, viewbox=Bbox(bbox))
283
 
284
- return svg
 
8
 
9
 
10
  class SVGTokenizer:
 
11
 
12
+ def __init__(self, config_path: str = "./config.yaml"):
13
+ self.NUM_SVG_END = 1
14
+ self.BASE_OFFSET = 152064
15
+ self.NUM_MASK_AND_EOM = 2 + self.BASE_OFFSET
16
 
17
+ self.CMD_MOVE_RAW = 0 + self.NUM_MASK_AND_EOM
18
+ self.CMD_LINE_RAW = 1 + self.NUM_MASK_AND_EOM
19
+ self.CMD_CURVE_RAW = 2 + self.NUM_MASK_AND_EOM
20
+ self.CMD_ARC_RAW = 3 + self.NUM_MASK_AND_EOM
21
+ self.CMD_Z_RAW = 4 + self.NUM_MASK_AND_EOM
22
+
23
+ self.PIX_PAD = 5 + self.NUM_MASK_AND_EOM
24
+ self.COORD_PAD = self.PIX_PAD
25
+
26
+ self.BBOX = 200
27
+ self.ARC_PARAM_START = 44500 + self.BASE_OFFSET
28
+
29
+
30
+ self.COLOR_TOKEN_START = 40010 + self.BASE_OFFSET
31
+ self.COLOR_TOKEN_END = self.ARC_PARAM_START - 1
32
 
33
  self.pixel2xy = self._create_pixel2xy_mapping()
34
 
35
  def _create_pixel2xy_mapping(self) -> Dict[int, np.ndarray]:
 
 
 
 
 
36
  pixel2xy = {}
37
+ x = np.linspace(0, self.BBOX-1, self.BBOX)
38
+ y = np.linspace(0, self.BBOX-1, self.BBOX)
39
  xx, yy = np.meshgrid(x, y)
40
  xy_grid = (np.array((xx.ravel(), yy.ravel())).T).astype(int)
41
 
42
  for pixel, xy in enumerate(xy_grid):
43
+ # xy + COORD_PAD + NUM_SVG_END = xy + 151943 + 1 = xy + 151944
44
+ pixel2xy[pixel] = xy + self.COORD_PAD + self.NUM_SVG_END
45
 
46
  return pixel2xy
47
 
48
  def token_to_color(self, color_token: int) -> str:
49
  try:
50
+ COLOR_TOKEN_START = 40010
51
+
52
+ if color_token == COLOR_TOKEN_START:
53
+ return "none"
54
+ elif color_token == COLOR_TOKEN_START + 1:
55
+ return "currentColor"
56
 
57
+ color_index = color_token - (COLOR_TOKEN_START + 2)
 
 
 
 
58
 
59
+ if color_index < 0 or color_index >= 4098:
60
+ print(f"Warning: Color token {color_token} out of range")
61
+ return "#808080"
 
62
 
63
+ r = (color_index >> 8) & 0xF
64
+ g = (color_index >> 4) & 0xF
65
+ b = color_index & 0xF
66
 
67
  r = (r << 4) | r
68
  g = (g << 4) | g
 
72
 
73
  except Exception as e:
74
  print(f"Error in token_to_color: {e}")
75
+ return "#808080"
76
+
def process_generated_tokens(self, output_ids: torch.Tensor) -> np.ndarray:
    """Translate generated token ids into an ``(N, 2)`` integer array.

    The first and last column of ``output_ids`` are dropped and the rest is
    flattened, so the rows of a batch end up concatenated. Each remaining
    id yields one output row or is skipped:

    * ``151939 <= id < PIX_PAD + NUM_SVG_END``: ``id - BASE_OFFSET``,
      duplicated into both columns.
    * from there up to (excluding) ``40011 + BASE_OFFSET``: looked up in
      ``pixel2xy`` and reduced by ``BASE_OFFSET``; ids without an entry
      are skipped.
    * ``ARC_PARAM_START + 1`` to ``ARC_PARAM_START + 100`` inclusive: the
      zero-based position inside that window, duplicated.
    * ``40011 + BASE_OFFSET <= id < ARC_PARAM_START``: ``id - BASE_OFFSET``,
      duplicated.

    Any other id is ignored. An exception raised while handling one id is
    printed and that id is skipped.

    Returns:
        The stacked rows, or an empty array of shape ``(0, 2)`` when no id
        produced a row.
    """
    def duplicated(value):
        # Scalar values are stored as a pair so every row has two columns.
        return np.array([value, value]).astype(int)

    token_ids = output_ids[:, 1:-1].cpu().numpy().flatten()

    rows = []
    for pixel in token_ids:
        try:
            grid_start = self.PIX_PAD + self.NUM_SVG_END
            if 151939 <= pixel < grid_start:
                rows.append(duplicated(pixel - self.BASE_OFFSET))
                continue

            color_start = 40011 + self.BASE_OFFSET
            if grid_start <= pixel < color_start:
                entry = self.pixel2xy.get(pixel - grid_start)
                if entry is not None:
                    rows.append(entry - self.BASE_OFFSET)
                continue

            # Checked before the color range; the two windows do not overlap.
            arc_first = self.ARC_PARAM_START + 1
            if arc_first <= pixel < arc_first + 100:
                rows.append(duplicated(pixel - arc_first))
                continue

            if color_start <= pixel < self.ARC_PARAM_START:
                rows.append(duplicated(pixel - self.BASE_OFFSET))

        except Exception as e:
            print(f"Error processing pixel {pixel}: {e}")
            continue

    if not rows:
        return np.array([]).reshape(0, 2)
    return np.vstack(rows)
116
+
117
+ def raster_svg(self, pixels: np.ndarray) -> Tuple[List[List[torch.Tensor]], List[int]]:
 
118
  try:
119
+ if len(pixels) == 0:
120
+ return [[]], []
121
+
122
+ pixels = pixels - 8
123
 
124
  svg_tensors = []
125
+ color_tensors = []
126
  path_tensor = []
 
127
 
128
+ i = 0
129
  while i < len(pixels):
130
  try:
131
  pix = pixels[i]
132
 
133
+ if pix[0] == -5:
 
 
 
134
  if i + 2 >= len(pixels):
135
+ break
136
+ cmd_tensor = np.zeros(14)
137
+ cmd_tensor[0] = 0 # Move command
138
  cmd_tensor[12:14] = pixels[i+2]
 
 
 
 
 
 
139
  path_tensor.append(cmd_tensor.tolist())
140
  i += 3
141
 
142
+ elif pix[0] == -4:
 
 
 
143
  if i + 1 >= len(pixels):
144
+ break
145
+ cmd_tensor = np.zeros(14)
146
+ cmd_tensor[0] = 1 # Line command
147
  cmd_tensor[12:14] = pixels[i+1]
148
  path_tensor.append(cmd_tensor.tolist())
149
  i += 2
150
 
151
+ elif pix[0] == -3:
 
 
 
152
  if i + 3 >= len(pixels):
153
+ break
154
+ cmd_tensor = np.zeros(14)
155
+ cmd_tensor[0] = 2 # Curve command
156
  cmd_tensor[8:10] = pixels[i+1]
157
  cmd_tensor[10:12] = pixels[i+2]
158
  cmd_tensor[12:14] = pixels[i+3]
159
  path_tensor.append(cmd_tensor.tolist())
160
  i += 4
161
 
162
+ elif pix[0] == -2:
 
 
 
163
  if i + 5 >= len(pixels):
164
+ break
165
+ cmd_tensor = np.zeros(14)
166
+ cmd_tensor[0] = 3 # Arc command
167
  radius = pixels[i+1]
168
+ x_axis_rot = pixels[i+2][0] + 8
169
+ large_arc_flg = pixels[i+3][0] + 8
170
+ sweep_flg = pixels[i+4][0] + 8
171
  end_pos = pixels[i+5]
 
172
  cmd_tensor[1:3] = radius
173
  cmd_tensor[3] = x_axis_rot
174
  cmd_tensor[4] = large_arc_flg
 
177
  path_tensor.append(cmd_tensor.tolist())
178
  i += 6
179
 
180
+ elif pix[0] == -1:
 
 
 
181
  if i + 1 >= len(pixels):
182
+ break
183
+ cmd_tensor = np.zeros(14)
184
+ cmd_tensor[0] = 6 # Close command
185
  cmd_tensor[12:14] = pixels[i+1]
186
  path_tensor.append(cmd_tensor.tolist())
187
  i += 2
188
+
189
+ elif pix[0] >= 40003:
190
+ if path_tensor:
191
+ svg_tensors.append(torch.tensor(path_tensor))
192
+ # Inverse transform: recover the original color token
193
+ # pix[0] + 8 + 152064 - 1 - 152064 = pix[0] + 7
194
+ color_token = int(pix[0] + 7)
195
+ color_tensors.append(color_token)
196
+ path_tensor = []
197
+ i += 1
198
  else:
199
+ i += 1
200
 
201
+ except (IndexError, TypeError) as e:
202
+ print(f"Error at position {i}: {e}")
203
  break
204
+
205
  if path_tensor:
206
  svg_tensors.append(torch.tensor(path_tensor))
207
 
208
+ return [svg_tensors], color_tensors
209
 
210
  except Exception as e:
211
  print(f"Error in raster_svg: {e}")
212
+ import traceback
213
+ traceback.print_exc()
214
+ return [[]], []
215
+
216
+ def apply_colors_to_svg(self, svg_tensors: List[torch.Tensor],
217
+ colors: Optional[List[int]]) -> SVG:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
218
  paths = []
 
219
 
220
+ if not svg_tensors:
221
+ raise ValueError("No valid SVG tensors")
 
 
 
 
222
 
223
+ colors = colors or []
 
224
 
225
+ for i, path_tensor in enumerate(svg_tensors):
 
 
 
226
  try:
227
  path = SVGTensor.from_data(path_tensor)
228
+ path = SVG.from_tensor(path.data, viewbox=Bbox(self.BBOX))
229
 
230
+ actual_color = self.token_to_color(colors[i]) if i < len(colors) else "none"
 
 
 
 
231
 
232
  for path_group in path:
233
  path_group.color = actual_color
 
235
 
236
  path.fill_(True)
237
  paths.append(path)
 
238
 
239
  except Exception as e:
240
  print(f"Error processing path {i}: {e}")
241
  continue
242
 
243
  if not paths:
244
+ raise ValueError("No valid paths generated")
245
+
246
  path_groups = paths[0].svg_path_groups
247
  for i in range(1, len(paths)):
248
+ path_groups.extend(paths[i].svg_path_groups)
 
 
 
249
 
250
+ return SVG(path_groups, viewbox=Bbox(self.BBOX))