import random
import torch
from torch import nn
import numpy as np
import re
from einops import rearrange
from dataclasses import dataclass
from torchvision import transforms
from diffusers.models.modeling_utils import ModelMixin

from transformers.utils import ModelOutput
from typing import Iterable, Optional, Union, List

import step1x3d_geometry
from step1x3d_geometry.utils.typing import *
from step1x3d_geometry.utils.misc import get_device

from .base import BaseLabelEncoder

DEFAULT_POSE = 0
NUM_POSE_CLASSES = 3
POSE_MAPPING = {"unknown": 0, "t-pose": 1, "a-pose": 2, "uncond": 3}

DEFAULT_SYMMETRY_TYPE = 0
NUM_SYMMETRY_TYPE_CLASSES = 2
SYMMETRY_TYPE_MAPPING = {"asymmetry": 0, "x": 1, "y": 0, "z": 0, "uncond": 2}

DEFAULT_GEOMETRY_QUALITY = 0
NUM_GEOMETRY_QUALITY_CLASSES = 3
GEOMETRY_QUALITY_MAPPING = {"normal": 0, "smooth": 1, "sharp": 2, "uncond": 3}
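
# Note: each mapping reserves an extra "uncond" index equal to NUM_*_CLASSES.
# That row is only allocated in the embedding tables when
# `cfg.zero_uncond_embeds` is False (see `configure`, which then builds
# NUM_*_CLASSES + 1 rows); otherwise the unconditional embedding is simply an
# all-zero vector.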


@step1x3d_geometry.register("label-encoder")
class LabelEncoder(BaseLabelEncoder, ModelMixin):
    """
    Embeds categorical shape labels (pose, symmetry type, geometry quality)
    into vector representations. Also provides the unconditional ("dropped
    label") embeddings used for classifier-free guidance.

    Config:
        hidden_size (`int`): The size of the vector embeddings.
        zero_uncond_embeds (`bool`): If True, unconditional embeddings are all
            zeros; otherwise a dedicated "uncond" class embedding is learned.
        pretrained_model_name_or_path (`str`, optional): Checkpoint from which
            the embedding tables are loaded.
    """

    def configure(self) -> None:
        super().configure()

        # One embedding table per label category. When the unconditional
        # embedding is all zeros, no extra "uncond" row is needed; otherwise
        # reserve one additional row for the learned "uncond" class.
        if self.cfg.zero_uncond_embeds:
            self.embedding_table_tpose = nn.Embedding(
                NUM_POSE_CLASSES, self.cfg.hidden_size
            )
            self.embedding_table_symmetry_type = nn.Embedding(
                NUM_SYMMETRY_TYPE_CLASSES, self.cfg.hidden_size
            )
            self.embedding_table_geometry_quality = nn.Embedding(
                NUM_GEOMETRY_QUALITY_CLASSES, self.cfg.hidden_size
            )
        else:
            self.embedding_table_tpose = nn.Embedding(
                NUM_POSE_CLASSES + 1, self.cfg.hidden_size
            )
            self.embedding_table_symmetry_type = nn.Embedding(
                NUM_SYMMETRY_TYPE_CLASSES + 1, self.cfg.hidden_size
            )
            self.embedding_table_geometry_quality = nn.Embedding(
                NUM_GEOMETRY_QUALITY_CLASSES + 1, self.cfg.hidden_size
            )

        # Unconditional (label-dropout) embeddings for classifier-free guidance.
        # When they are not forced to zero, encode the dedicated "uncond" class
        # so that its embedding rows are actually learned.
        if self.cfg.zero_uncond_embeds:
            self.empty_label_embeds = torch.zeros(
                (1, 3, self.cfg.hidden_size)
            ).detach()
        else:
            self.empty_label_embeds = self.encode_label(
                [{"pose": "uncond", "symmetry": "uncond", "geometry_type": "uncond"}]
            ).detach()

        # Optionally initialize from a pretrained checkpoint, keeping only the
        # weights stored under the "label_condition." prefix.
        if self.cfg.pretrained_model_name_or_path is not None:
            print(f"Loading ckpt from {self.cfg.pretrained_model_name_or_path}")
            ckpt = torch.load(
                self.cfg.pretrained_model_name_or_path, map_location="cpu"
            )["state_dict"]
            pretrained_model_ckpt = {}
            for k, v in ckpt.items():
                if k.startswith("label_condition."):
                    pretrained_model_ckpt[k.replace("label_condition.", "")] = v
            self.load_state_dict(pretrained_model_ckpt, strict=True)

    def encode_label(self, labels: List[dict]) -> torch.FloatTensor:
        """Encode a batch of label dicts into a (B, 3, hidden_size) tensor.

        Each dict may contain the keys "pose", "symmetry" and "geometry_type".
        Missing keys fall back to the default class; None or empty-string
        values are encoded as zero vectors.
        """
        tpose_label_embeds = []
        symmetry_type_label_embeds = []
        geometry_quality_label_embeds = []

        for label in labels:
            # Pose label ("unknown" / "t-pose" / "a-pose" / "uncond").
            if "pose" in label:
                if label["pose"] is None or label["pose"] == "":
                    tpose_label_embeds.append(
                        torch.zeros(self.cfg.hidden_size).detach().to(get_device())
                    )
                else:
                    tpose_label_embeds.append(
                        self.embedding_table_tpose(
                            torch.tensor(POSE_MAPPING[label["pose"]]).to(get_device())
                        )
                    )
            else:
                tpose_label_embeds.append(
                    self.embedding_table_tpose(
                        torch.tensor(DEFAULT_POSE).to(get_device())
                    )
                )

            # Symmetry label ("asymmetry" / "x" / "y" / "z" / "uncond").
            if "symmetry" in label:
                if label["symmetry"] is None or label["symmetry"] == "":
                    symmetry_type_label_embeds.append(
                        torch.zeros(self.cfg.hidden_size).detach().to(get_device())
                    )
                else:
                    symmetry_type_label_embeds.append(
                        self.embedding_table_symmetry_type(
                            torch.tensor(
                                SYMMETRY_TYPE_MAPPING[label["symmetry"]]
                            ).to(get_device())
                        )
                    )
            else:
                symmetry_type_label_embeds.append(
                    self.embedding_table_symmetry_type(
                        torch.tensor(DEFAULT_SYMMETRY_TYPE).to(get_device())
                    )
                )

            # Geometry-quality label ("normal" / "smooth" / "sharp" / "uncond").
            if "geometry_type" in label:
                if label["geometry_type"] is None or label["geometry_type"] == "":
                    geometry_quality_label_embeds.append(
                        torch.zeros(self.cfg.hidden_size).detach().to(get_device())
                    )
                else:
                    geometry_quality_label_embeds.append(
                        self.embedding_table_geometry_quality(
                            torch.tensor(
                                GEOMETRY_QUALITY_MAPPING[label["geometry_type"]]
                            ).to(get_device())
                        )
                    )
            else:
                geometry_quality_label_embeds.append(
                    self.embedding_table_geometry_quality(
                        torch.tensor(DEFAULT_GEOMETRY_QUALITY).to(get_device())
                    )
                )

        tpose_label_embeds = torch.stack(tpose_label_embeds)
        symmetry_type_label_embeds = torch.stack(symmetry_type_label_embeds)
        geometry_quality_label_embeds = torch.stack(geometry_quality_label_embeds)

        # (B, 3, hidden_size): one embedding per label category.
        label_embeds = torch.stack(
            [
                tpose_label_embeds,
                symmetry_type_label_embeds,
                geometry_quality_label_embeds,
            ],
            dim=1,
        ).to(self.dtype)

        return label_embeds
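

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not executed): once the framework has built a
# `LabelEncoder` instance with a config providing `hidden_size`,
# `zero_uncond_embeds` and (optionally) `pretrained_model_name_or_path`, it
# can be called as below. The construction step is framework-specific and is
# omitted here; the label values come from the mappings defined above.
#
#   labels = [
#       {"pose": "t-pose", "symmetry": "x", "geometry_type": "sharp"},
#       {"pose": "a-pose", "symmetry": "asymmetry", "geometry_type": "smooth"},
#   ]
#   label_embeds = encoder.encode_label(labels)    # (2, 3, hidden_size)
#   uncond_embeds = encoder.empty_label_embeds     # (1, 3, hidden_size)
# ---------------------------------------------------------------------------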