import math
import os
from dataclasses import dataclass, field
from typing import List, Union

import numpy as np
import PIL.Image
import torch
import torch.nn.functional as F
import trimesh
from einops import rearrange
from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf
from PIL import Image

from .models.isosurface import MarchingCubeHelper
from .utils import (
    BaseModule,
    ImagePreprocessor,
    find_class,
    get_spherical_cameras,
    scale_tensor,
)


class TSR(BaseModule):
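    """Single-image-to-3D reconstruction system.

    The pipeline is assembled from the components named in the config: an
    image tokenizer, a token generator, a transformer backbone, a post
    processor, a decoder, and a renderer. ``forward`` turns conditioning
    images into scene codes, which ``render`` turns into novel views and
    ``extract_mesh`` into triangle meshes.

    Minimal usage sketch (the repo id, file names, and device below are
    placeholders, not part of this module):

        model = TSR.from_pretrained("org/repo", "config.yaml", "model.ckpt")
        model.to("cuda")
        scene_codes = model([input_image], device="cuda")
        views = model.render(scene_codes, n_views=4)
        meshes = model.extract_mesh(scene_codes, resolution=256)
    """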
    @dataclass
    class Config(BaseModule.Config):
        cond_image_size: int

        image_tokenizer_cls: str
        image_tokenizer: dict

        tokenizer_cls: str
        tokenizer: dict

        backbone_cls: str
        backbone: dict

        post_processor_cls: str
        post_processor: dict

        decoder_cls: str
        decoder: dict

        renderer_cls: str
        renderer: dict

    cfg: Config

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        config_name: str,
        weight_name: str,
        token=None,
    ):
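        """Load a TSR model from a local directory or a Hugging Face Hub repo.

        ``config_name`` and ``weight_name`` are the file names of the OmegaConf
        config and the checkpoint; ``token`` is an optional Hub access token
        used when the files have to be downloaded.
        """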
        if os.path.isdir(pretrained_model_name_or_path):
            config_path = os.path.join(pretrained_model_name_or_path, config_name)
            weight_path = os.path.join(pretrained_model_name_or_path, weight_name)
        else:
            config_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path, filename=config_name, token=token
            )
            weight_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path, filename=weight_name, token=token
            )

        cfg = OmegaConf.load(config_path)
        OmegaConf.resolve(cfg)
        model = cls(cfg)
        ckpt = torch.load(weight_path, map_location="cpu")
        model.load_state_dict(ckpt)
        return model

    def configure(self):
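        """Instantiate the sub-modules declared in the config via ``find_class``."""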
        self.image_tokenizer = find_class(self.cfg.image_tokenizer_cls)(
            self.cfg.image_tokenizer
        )
        self.tokenizer = find_class(self.cfg.tokenizer_cls)(self.cfg.tokenizer)
        self.backbone = find_class(self.cfg.backbone_cls)(self.cfg.backbone)
        self.post_processor = find_class(self.cfg.post_processor_cls)(
            self.cfg.post_processor
        )
        self.decoder = find_class(self.cfg.decoder_cls)(self.cfg.decoder)
        self.renderer = find_class(self.cfg.renderer_cls)(self.cfg.renderer)
        self.image_processor = ImagePreprocessor()
        self.isosurface_helper = None

    def forward(
        self,
        image: Union[
            PIL.Image.Image,
            np.ndarray,
            torch.FloatTensor,
            List[PIL.Image.Image],
            List[np.ndarray],
            List[torch.FloatTensor],
        ],
        device: str,
    ) -> torch.FloatTensor:
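        """Encode one or more conditioning images into scene codes.

        Images are preprocessed to ``cond_image_size`` and tokenized; the
        backbone processes the tokens produced by ``self.tokenizer``
        conditioned on the image tokens (passed as ``encoder_hidden_states``),
        and the detokenized output is post-processed into the scene codes
        consumed by ``render`` and ``extract_mesh``.
        """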
        rgb_cond = self.image_processor(image, self.cfg.cond_image_size)[:, None].to(
            device
        )
        batch_size = rgb_cond.shape[0]
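
        # Tokenize the conditioning view(s) and flatten the view/token axes.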
        input_image_tokens: torch.Tensor = self.image_tokenizer(
            rearrange(rgb_cond, "B Nv H W C -> B Nv C H W", Nv=1),
        )

        input_image_tokens = rearrange(
            input_image_tokens, "B Nv C Nt -> B (Nv Nt) C", Nv=1
        )

        tokens: torch.Tensor = self.tokenizer(batch_size)
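
        # Run the backbone over the tokens, conditioned on the image tokens.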
        tokens = self.backbone(
            tokens,
            encoder_hidden_states=input_image_tokens,
        )
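
        # Detokenize the backbone output and post-process it into scene codes.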
        scene_codes = self.post_processor(self.tokenizer.detokenize(tokens))
        return scene_codes

    def render(
        self,
        scene_codes,
        n_views: int,
        elevation_deg: float = 0.0,
        camera_distance: float = 1.9,
        fovy_deg: float = 40.0,
        height: int = 256,
        width: int = 256,
        return_type: str = "pil",
    ):
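        """Render each scene code from ``n_views`` cameras on a spherical orbit.

        Returns a nested list (one list of views per scene code) with each
        image in the format selected by ``return_type``: ``"pil"``, ``"np"``,
        or ``"pt"``.
        """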
        rays_o, rays_d = get_spherical_cameras(
            n_views, elevation_deg, camera_distance, fovy_deg, height, width
        )
        rays_o, rays_d = rays_o.to(scene_codes.device), rays_d.to(scene_codes.device)

        def process_output(image: torch.FloatTensor):
            if return_type == "pt":
                return image
            elif return_type == "np":
                return image.detach().cpu().numpy()
            elif return_type == "pil":
                return Image.fromarray(
                    (image.detach().cpu().numpy() * 255.0).astype(np.uint8)
                )
            else:
                raise NotImplementedError
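
        # Render every view of every scene; gradients are not needed here.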
        images = []
        for scene_code in scene_codes:
            images_ = []
            for i in range(n_views):
                with torch.no_grad():
                    image = self.renderer(
                        self.decoder, scene_code, rays_o[i], rays_d[i]
                    )
                images_.append(process_output(image))
            images.append(images_)

        return images

    def set_marching_cubes_resolution(self, resolution: int):
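        """(Re)build the marching cubes helper only when the resolution changes."""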
        if (
            self.isosurface_helper is not None
            and self.isosurface_helper.resolution == resolution
        ):
            return
        self.isosurface_helper = MarchingCubeHelper(resolution)

    def extract_mesh(self, scene_codes, resolution: int = 256, threshold: float = 25.0):
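        """Extract one triangle mesh per scene code via marching cubes.

        Densities are queried on the isosurface grid (rescaled to the
        renderer's bounding box), marching cubes is run at the given
        ``threshold``, vertices are mapped back to world space, and per-vertex
        colors are queried to build each ``trimesh.Trimesh``.
        """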
        self.set_marching_cubes_resolution(resolution)
        meshes = []
        for scene_code in scene_codes:
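            # Query the density field on the marching-cubes grid, rescaled to
            # the renderer's [-radius, radius] bounding box.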
            with torch.no_grad():
                density = self.renderer.query_triplane(
                    self.decoder,
                    scale_tensor(
                        self.isosurface_helper.grid_vertices.to(scene_codes.device),
                        self.isosurface_helper.points_range,
                        (-self.renderer.cfg.radius, self.renderer.cfg.radius),
                    ),
                    scene_code,
                )["density_act"]
            v_pos, t_pos_idx = self.isosurface_helper(-(density - threshold))
            v_pos = scale_tensor(
                v_pos,
                self.isosurface_helper.points_range,
                (-self.renderer.cfg.radius, self.renderer.cfg.radius),
            )
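            # Query per-vertex colors at the extracted surface positions.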
            with torch.no_grad():
                color = self.renderer.query_triplane(
                    self.decoder,
                    v_pos,
                    scene_code,
                )["color"]
            mesh = trimesh.Trimesh(
                vertices=v_pos.cpu().numpy(),
                faces=t_pos_idx.cpu().numpy(),
                vertex_colors=color.cpu().numpy(),
            )
            meshes.append(mesh)
        return meshes