Spaces:
Running
on
Zero
Running
on
Zero
| import numpy as np | |
| import torch | |
| import yaml | |
| from typing import List, Tuple, Dict, Optional, Union | |
| from deepsvg.difflib.tensor import SVGTensor | |
| from deepsvg.svglib.svg import SVG | |
| from deepsvg.svglib.geom import Bbox | |
class SVGTokenizer:
    """SVG tokenizer - supports both 8B and 4B models via config.yaml.

    Decodes generated token ids into deepsvg ``SVG`` objects in three steps:
    ``process_generated_tokens`` (ids -> raw xy rows), ``raster_svg``
    (xy rows -> per-path command tensors + color tokens) and
    ``apply_colors_to_svg`` (tensors + colors -> ``SVG``). Token offsets come
    from the config, with per-model overrides under ``models.<size>``.
    """

    def __init__(self, config_path: str = "./config.yaml", model_size: Optional[str] = None):
        """
        Initialize SVGTokenizer.

        Args:
            config_path: Path to config.yaml.
            model_size: Key under ``models`` in the config (e.g. "8B" or "4B").
                If None, uses ``default_model_size`` from the config, or "8B".

        Raises:
            ValueError: If ``model_size`` is not one of the configured models.
            KeyError: If a required config value is missing.
        """
        # YAML is UTF-8; don't depend on the platform's default encoding.
        with open(config_path, "r", encoding="utf-8") as f:
            # An empty file loads as None; treat it as an empty mapping so the
            # model_size check below reports a clear error instead of crashing.
            self.config = yaml.safe_load(f) or {}

        # Determine model size
        self.model_size = model_size or self.config.get("default_model_size", "8B")
        models = self.config.get("models") or {}
        if self.model_size not in models:
            raise ValueError(
                f"Invalid model_size: {self.model_size}. "
                f"Must be one of: {list(models.keys())}"
            )
        self._load_config()
        self.pixel2xy = self._create_pixel2xy_mapping()

    @staticmethod
    def _lookup(tree, keys: Tuple[str, ...]):
        """Walk nested dicts along ``keys``; return None if any step is missing."""
        node = tree
        for key in keys:
            if not isinstance(node, dict) or key not in node:
                return None
            node = node[key]
        return node

    def _get_model_specific_config(self, *keys):
        """Get a model-specific config value, falling back to the shared config.

        A model-specific value that is absent (or explicitly null) falls
        through to the same key path at the top level of the config.
        Returns None if neither location defines it.
        """
        model_cfg = self.config.get("models", {}).get(self.model_size, {})
        value = self._lookup(model_cfg, keys)
        if value is None:
            value = self._lookup(self.config, keys)
        return value

    def _require_model_specific_config(self, *keys):
        """Like ``_get_model_specific_config`` but raise if the value is absent.

        Raises:
            KeyError: If neither ``models.<size>.<keys>`` nor ``<keys>`` is set.
        """
        value = self._get_model_specific_config(*keys)
        if value is None:
            path = ".".join(keys)
            raise KeyError(
                f"Missing config value '{path}' "
                f"(checked models.{self.model_size}.{path} and {path})"
            )
        return value

    def _load_config(self):
        """Load all constants from the config, applying model-specific overrides.

        Raises:
            KeyError: If a shared section/key is missing, or if one of the
                model-specific values used in the offset arithmetic below is
                missing from both ``models.<size>`` and the shared config.
        """
        # ========== Token-related configs ==========
        # Model-specific tokens (required: used in the derived constants below)
        self.NUM_MASK_AND_EOM = self._require_model_specific_config("tokens", "num_mask_and_eom")
        self.BASE_OFFSET = self._require_model_specific_config("tokens", "base_offset")
        # Shared tokens
        tokens_cfg = self.config["tokens"]
        self.NUM_SVG_END = tokens_cfg["svg_end"]
        self.NUM_END_TOKEN = tokens_cfg["num_end_token"]

        # ========== Coordinate-related configs ==========
        # Model-specific coordinates (required)
        self.PIX_PAD = self._require_model_specific_config("coordinates", "pix_pad_offset")
        self.COORD_PAD = self._require_model_specific_config("coordinates", "coord_pad_offset")
        # Shared coordinates: side length of the square coordinate grid
        self.BBOX = self.config["coordinates"]["bbox"]

        # ========== Color-related configs ==========
        colors_cfg = self.config["colors"]
        self.COLOR_TOKEN_START_RAW = colors_cfg["color_token_start"]
        self.MAX_COLOR_TOKENS = colors_cfg["max_color_tokens"]
        # Model-specific colors; optional (None when absent) since none of the
        # methods here read them.
        self.COLOR_START_OFFSET = self._get_model_specific_config("colors", "color_start_offset")
        self.COLOR_END_OFFSET = self._get_model_specific_config("colors", "color_end_offset")

        # ========== SVG command values ==========
        commands_cfg = self.config["svg_commands"]
        self.CMD_MOVE = commands_cfg["move"]
        self.CMD_LINE = commands_cfg["line"]
        self.CMD_CURVE = commands_cfg["curve"]
        self.CMD_ARC = commands_cfg["arc"]
        self.CMD_CLOSE = commands_cfg["close"]

        # ========== Model-related configs ==========
        model_cfg = self.config["model"]
        self.BOS_TOKEN_ID = model_cfg["bos_token_id"]
        self.EOS_TOKEN_ID = model_cfg["eos_token_id"]
        self.PAD_TOKEN_ID = model_cfg["pad_token_id"]

        # ========== Arc parameter configs ==========
        # `arc:` may be absent or null; both fall back to the defaults.
        arc_cfg = self.config.get("arc") or {}
        self.ARC_PARAM_OFFSET = arc_cfg.get("param_offset", 44500)
        self.ARC_PARAM_RANGE = arc_cfg.get("param_range", 100)
        self.ARC_PARAM_START = self.ARC_PARAM_OFFSET + self.BASE_OFFSET

        # ========== Derived constants ==========
        # Subtracted in raster_svg so that the first command token
        # (CMD_TOKEN_START) decodes to exactly CMD_MOVE.
        self.PIXEL_OFFSET = (self.NUM_MASK_AND_EOM - self.BASE_OFFSET +
                             self.NUM_SVG_END - self.CMD_MOVE)
        # Command token range [CMD_TOKEN_START, CMD_TOKEN_END)
        self.CMD_TOKEN_START = self.NUM_MASK_AND_EOM + self.NUM_SVG_END
        self.CMD_TOKEN_END = self.PIX_PAD + self.NUM_SVG_END
        # Coordinate tokens start exactly where command tokens end
        self.COORD_TOKEN_START = self.PIX_PAD + self.NUM_SVG_END
        # First token id past the coordinate range; color tokens begin here
        self.COLOR_COORD_BOUNDARY = self.COLOR_TOKEN_START_RAW + 1 + self.BASE_OFFSET
        # Lowest shifted x value that raster_svg treats as a color token
        self.COLOR_THRESHOLD = self.COLOR_TOKEN_START_RAW - self.PIXEL_OFFSET + 1

    def _create_pixel2xy_mapping(self) -> Dict[int, np.ndarray]:
        """Create the grid-index -> xy mapping (following dataset.py logic).

        Index ``p`` of the BBOX x BBOX grid maps to the integer array
        ``[p % BBOX, p // BBOX] + COORD_PAD + NUM_SVG_END``.
        """
        offset = self.COORD_PAD + self.NUM_SVG_END
        axis = np.linspace(0, self.BBOX - 1, self.BBOX)
        xx, yy = np.meshgrid(axis, axis)
        xy_grid = np.column_stack((xx.ravel(), yy.ravel())).astype(int)
        return {pixel: xy + offset for pixel, xy in enumerate(xy_grid)}

    def token_to_color(self, color_token: int) -> str:
        """Convert a color token to an SVG color string (following dataset.py).

        ``COLOR_TOKEN_START_RAW`` -> "none" and the next id -> "currentColor".
        Later ids are a palette index whose 4-bit R, G and B nibbles
        (bits 8-11, 4-7, 0-3) are each expanded to a full byte, so "#rgb"
        becomes "#rrggbb". Out-of-range or invalid tokens yield "#808080".
        """
        try:
            if color_token == self.COLOR_TOKEN_START_RAW:
                return "none"
            if color_token == self.COLOR_TOKEN_START_RAW + 1:
                return "currentColor"
            color_index = color_token - (self.COLOR_TOKEN_START_RAW + 2)
            if not 0 <= color_index < self.MAX_COLOR_TOKENS:
                print(f"Warning: Color token {color_token} out of range")
                return "#808080"
            # NOTE(review): only 12 bits are decoded; indices >= 4096 would
            # alias if MAX_COLOR_TOKENS were ever larger than 4096.
            nibbles = [(color_index >> shift) & 0xF for shift in (8, 4, 0)]
            # (n << 4) | n duplicates a nibble into a byte, e.g. 0xA -> 0xAA.
            return "#" + "".join(f"{(n << 4) | n:02x}" for n in nibbles)
        except Exception as e:
            print(f"Error in token_to_color: {e}")
            return "#808080"

    def process_generated_tokens(self, output_ids: torch.Tensor) -> np.ndarray:
        """Map generated token ids to raw (x, y) rows (following dataset.py).

        Args:
            output_ids: Token ids of shape (batch, seq) or (seq,). The first
                and last position of each row are dropped as BOS/EOS, and
                all rows are then concatenated.

        Returns:
            Integer array of shape (N, 2), one row per recognized token.
            Command, arc-parameter and color tokens become a duplicated
            scalar ``[v, v]``; coordinate tokens become a grid point. Tokens
            outside every range (and coordinate ids beyond the grid) are
            dropped.
        """
        if output_ids.dim() == 1:
            # Accept a single unbatched sequence as a batch of one.
            output_ids = output_ids.unsqueeze(0)
        # Remove bos/eos
        generated_pixels = output_ids[:, 1:-1].cpu().numpy().flatten()

        sample_xys = []
        for pixel in generated_pixels:
            try:
                xy = self._token_to_xy(pixel)
            except Exception as e:
                # Best effort: skip tokens that fail to decode.
                print(f"Error processing pixel {pixel}: {e}")
                continue
            if xy is not None:
                sample_xys.append(xy)

        if sample_xys:
            return np.vstack(sample_xys)
        # Same integer dtype as the non-empty case.
        return np.empty((0, 2), dtype=int)

    def _token_to_xy(self, pixel) -> Optional[np.ndarray]:
        """Decode one token id to an xy row, or None if it matches no range."""
        # 1. Command tokens: CMD_TOKEN_START <= pixel < CMD_TOKEN_END
        if self.CMD_TOKEN_START <= pixel < self.CMD_TOKEN_END:
            value = pixel - self.BASE_OFFSET
            return np.array([value, value]).astype(int)
        # 2. Coordinate tokens: COORD_TOKEN_START <= pixel < COLOR_COORD_BOUNDARY
        if self.COORD_TOKEN_START <= pixel < self.COLOR_COORD_BOUNDARY:
            grid_xy = self.pixel2xy.get(pixel - self.COORD_TOKEN_START)
            return None if grid_xy is None else grid_xy - self.BASE_OFFSET
        # 3. Arc parameters: ARC_PARAM_START + 1 <= pixel < ARC_PARAM_START + 1 + ARC_PARAM_RANGE
        if self.ARC_PARAM_START + 1 <= pixel < self.ARC_PARAM_START + 1 + self.ARC_PARAM_RANGE:
            value = pixel - self.ARC_PARAM_START - 1
            return np.array([value, value]).astype(int)
        # 4. Color tokens: COLOR_COORD_BOUNDARY <= pixel < ARC_PARAM_START
        if self.COLOR_COORD_BOUNDARY <= pixel < self.ARC_PARAM_START:
            value = pixel - self.BASE_OFFSET
            return np.array([value, value]).astype(int)
        return None

    @staticmethod
    def _new_row(cmd_index: int) -> np.ndarray:
        """Return a zeroed 14-slot command row with slot 0 set to ``cmd_index``."""
        row = np.zeros(14)
        row[0] = cmd_index
        return row

    def raster_svg(self, pixels: np.ndarray) -> Tuple[List[List[torch.Tensor]], List[int]]:
        """Convert xy rows into per-path command tensors (following dataset.py).

        Args:
            pixels: (N, 2) array as returned by ``process_generated_tokens``.

        Returns:
            ``([paths], colors)``. ``paths`` is a list of float tensors of
            shape (num_commands, 14), one per path. ``colors`` lists the raw
            color tokens in stream order. A color token closes the path
            being built (if any) and is always recorded; colors and paths are
            later paired by position. A truncated command, or an
            IndexError/TypeError, stops decoding and keeps what was decoded.
            Any other error returns ``([[]], [])``.

        Each row is filled as follows:
            [0]      command index (0 move, 1 line, 2 curve, 3 arc, 6 close)
            [1:3]    arc radius
            [3]      x-axis rotation
            [4]      large-arc flag
            [5]      sweep flag
            [8:10]   first curve control point
            [10:12]  second curve control point
            [12:14]  end point
        Slots that are not set stay 0.
        """
        try:
            if len(pixels) == 0:
                return [[]], []
            # Undo the token offset: the first command token now decodes to
            # CMD_MOVE, the next to CMD_MOVE + 1, and so on.
            pixels = pixels - self.PIXEL_OFFSET
            n = len(pixels)
            svg_tensors = []
            color_tensors = []
            path_tensor = []
            i = 0
            while i < n:
                try:
                    head = pixels[i][0]
                    if head == self.CMD_MOVE:
                        if i + 2 >= n:
                            break
                        # NOTE(review): row i + 1 is skipped; only i + 2 is the
                        # end point (mirrors dataset.py encoding — confirm).
                        row = self._new_row(0)
                        row[12:14] = pixels[i + 2]
                        step = 3
                    elif head == self.CMD_LINE:
                        if i + 1 >= n:
                            break
                        row = self._new_row(1)
                        row[12:14] = pixels[i + 1]
                        step = 2
                    elif head == self.CMD_CURVE:
                        if i + 3 >= n:
                            break
                        row = self._new_row(2)
                        row[8:10] = pixels[i + 1]
                        row[10:12] = pixels[i + 2]
                        row[12:14] = pixels[i + 3]
                        step = 4
                    elif head == self.CMD_ARC:
                        if i + 5 >= n:
                            break
                        row = self._new_row(3)
                        row[1:3] = pixels[i + 1]
                        # Arc scalars were decoded as raw values, so re-add the
                        # offset that was subtracted above.
                        row[3] = pixels[i + 2][0] + self.PIXEL_OFFSET
                        row[4] = pixels[i + 3][0] + self.PIXEL_OFFSET
                        row[5] = pixels[i + 4][0] + self.PIXEL_OFFSET
                        row[12:14] = pixels[i + 5]
                        step = 6
                    elif head == self.CMD_CLOSE:
                        if i + 1 >= n:
                            break
                        row = self._new_row(6)
                        row[12:14] = pixels[i + 1]
                        step = 2
                    elif head >= self.COLOR_THRESHOLD:
                        if path_tensor:
                            svg_tensors.append(torch.tensor(path_tensor))
                        # Reverse transform: restore original color token
                        color_tensors.append(int(head + self.PIXEL_OFFSET - 1))
                        path_tensor = []
                        i += 1
                        continue
                    else:
                        # Unrecognized row: skip it.
                        i += 1
                        continue
                    path_tensor.append(row.tolist())
                    i += step
                except (IndexError, TypeError) as e:
                    print(f"Error at position {i}: {e}")
                    break
            # Keep the trailing path that no color token closed.
            if path_tensor:
                svg_tensors.append(torch.tensor(path_tensor))
            return [svg_tensors], color_tensors
        except Exception as e:
            print(f"Error in raster_svg: {e}")
            import traceback
            traceback.print_exc()
            return [[]], []

    def apply_colors_to_svg(self, svg_tensors: List[torch.Tensor],
                            colors: Optional[List[int]]) -> "SVG":
        """Apply colors and merge all paths into one filled SVG.

        Args:
            svg_tensors: Per-path command tensors, as in the first element of
                ``raster_svg``'s first return value.
            colors: Raw color tokens; ``colors[i]`` colors ``svg_tensors[i]``.
                Paths without a matching entry get color "none".

        Returns:
            An ``SVG`` with viewbox ``Bbox(BBOX)`` holding the path groups of
            every path that converted successfully. Stroke color is "none".

        Raises:
            ValueError: If ``svg_tensors`` is empty or no path converts.
        """
        if not svg_tensors:
            raise ValueError("No valid SVG tensors")
        colors = colors or []

        paths = []
        for i, path_tensor in enumerate(svg_tensors):
            try:
                tensor = SVGTensor.from_data(path_tensor)
                svg = SVG.from_tensor(tensor.data, viewbox=Bbox(self.BBOX))
                actual_color = self.token_to_color(colors[i]) if i < len(colors) else "none"
                for path_group in svg:
                    path_group.color = actual_color
                    path_group.stroke_color = "none"
                svg.fill_(True)
                paths.append(svg)
            except Exception as e:
                # Best effort: skip paths that cannot be converted.
                print(f"Error processing path {i}: {e}")

        if not paths:
            raise ValueError("No valid paths generated")
        # Build a fresh list rather than extending paths[0]'s list in place.
        path_groups = [group for svg in paths for group in svg.svg_path_groups]
        return SVG(path_groups, viewbox=Bbox(self.BBOX))