| | import tempfile |
| |
|
| | import numpy as np |
| | import torch |
| | import trimesh |
| | from shap_e.diffusion.gaussian_diffusion import diffusion_from_config |
| | from shap_e.diffusion.sample import sample_latents |
| | from shap_e.models.download import load_config, load_model |
| | from shap_e.models.nn.camera import (DifferentiableCameraBatch, |
| | DifferentiableProjectiveCamera) |
| | from shap_e.models.transmitter.base import Transmitter, VectorDecoder |
| | from shap_e.rendering.torch_mesh import TorchMesh |
| | from shap_e.util.collections import AttrDict |
| | from shap_e.util.image_util import load_image |
| |
|
| |
|
| | |
def create_pan_cameras(size: int,
                       device: torch.device) -> DifferentiableCameraBatch:
    """Build a batch of 20 cameras arranged in a ring around the world origin.

    One camera is produced for each of 20 angles evenly spaced over
    [0, 2*pi] (both endpoints included). Every camera is placed 4 units
    from the world origin, with its unit-length ``z`` axis pointing back at
    the origin.

    Args:
        size: Used as both the width and the height of every camera.
        device: Device that the camera tensors are moved to.

    Returns:
        A ``DifferentiableCameraBatch`` of shape ``(1, 20)``.
    """

    def _stack_to_tensor(vectors) -> torch.Tensor:
        # One row per camera, converted to float32 on the target device.
        return torch.from_numpy(np.stack(vectors, axis=0)).float().to(device)

    frames = []
    for angle in np.linspace(0, 2 * np.pi, num=20):
        z_axis = np.array([np.sin(angle), np.cos(angle), -0.5])
        z_axis /= np.sqrt(np.sum(z_axis**2))
        x_axis = np.array([np.cos(angle), -np.sin(angle), 0.0])
        # Tuple layout: (origin, x, y, z). The origin lies along -z, so the
        # z axis points from the camera towards the world origin.
        frames.append((-z_axis * 4, x_axis, np.cross(z_axis, x_axis), z_axis))

    origins, xs, ys, zs = zip(*frames)

    flat_camera = DifferentiableProjectiveCamera(
        origin=_stack_to_tensor(origins),
        x=_stack_to_tensor(xs),
        y=_stack_to_tensor(ys),
        z=_stack_to_tensor(zs),
        width=size,
        height=size,
        x_fov=0.7,
        y_fov=0.7,
    )
    return DifferentiableCameraBatch(shape=(1, len(xs)),
                                     flat_camera=flat_camera)
| |
|
| |
|
| | |
@torch.no_grad()
def decode_latent_mesh(
    xm: Transmitter | VectorDecoder,
    latent: torch.Tensor,
) -> TorchMesh:
    """Decode a single latent into a mesh using the model's renderer.

    Args:
        xm: Model whose ``renderer`` performs the decoding. For a
            ``Transmitter`` the latent is converted to parameters by its
            ``encoder``; otherwise ``xm`` itself does the conversion.
        latent: One latent (no batch axis); a leading batch axis is added
            before it is handed to ``bottleneck_to_params``.

    Returns:
        The first entry of the ``raw_meshes`` produced by the renderer.
    """
    if isinstance(xm, Transmitter):
        params_source = xm.encoder
    else:
        params_source = xm

    # Only ``raw_meshes`` is read from the result, so the views are
    # presumably kept tiny (2x2) just to drive the renderer.
    cameras = create_pan_cameras(2, latent.device)

    rendered = xm.renderer.render_views(
        AttrDict(cameras=cameras),
        params=params_source.bottleneck_to_params(latent[None]),
        options=AttrDict(rendering_mode='stf', render_with_direction=False),
    )
    return rendered.raw_meshes[0]
| |
|
| |
|
class Model:
    """Shap-E generation pipeline producing GLB meshes from text or images.

    The transmitter and the diffusion process are created on construction.
    The 'text300M' and 'image300M' models are loaded lazily, the first time
    the corresponding ``run_*`` method (or ``load_model``) is called.
    """

    def __init__(self):
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.xm = load_model('transmitter', device=self.device)
        self.diffusion = diffusion_from_config(load_config('diffusion'))
        # Filled in on demand by ``load_model``.
        self.model_text = None
        self.model_image = None

    def load_model(self, model_name: str) -> None:
        """Load the named conditional model if it is not loaded yet.

        Args:
            model_name: Either 'text300M' or 'image300M'.

        Raises:
            ValueError: If ``model_name`` is not one of the two names above.
                An explicit exception is used instead of ``assert`` so the
                check still runs under ``python -O``.
        """
        if model_name == 'text300M':
            if self.model_text is None:
                self.model_text = load_model(model_name, device=self.device)
        elif model_name == 'image300M':
            if self.model_image is None:
                self.model_image = load_model(model_name, device=self.device)
        else:
            raise ValueError(
                f"model_name must be 'text300M' or 'image300M', "
                f'got {model_name!r}')

    def to_glb(self, latent: torch.Tensor) -> str:
        """Decode a latent into a mesh and export it as a GLB file.

        Args:
            latent: A single latent, as accepted by ``decode_latent_mesh``.

        Returns:
            Path of the written ``.glb`` file. The file is not removed by
            this method; the caller owns it.
        """
        tri_mesh = decode_latent_mesh(self.xm, latent).tri_mesh()

        # The intermediate PLY lives in a scratch directory so it is removed
        # once loaded. The handle is closed (and therefore flushed) before
        # the file is re-opened by name.
        with tempfile.TemporaryDirectory() as scratch_dir:
            with tempfile.NamedTemporaryFile(suffix='.ply',
                                             dir=scratch_dir,
                                             delete=False,
                                             mode='w+b') as ply_file:
                tri_mesh.write_ply(ply_file)
            mesh = trimesh.load(ply_file.name)

        # Rotate -90 degrees about X, then 180 degrees about Y
        # (presumably to match the viewer's axis convention).
        rot = trimesh.transformations.rotation_matrix(-np.pi / 2, [1, 0, 0])
        mesh = mesh.apply_transform(rot)
        rot = trimesh.transformations.rotation_matrix(np.pi, [0, 1, 0])
        mesh = mesh.apply_transform(rot)

        # Reserve a unique output path; close the handle right away so no
        # file descriptor leaks and the exporter can open the path itself.
        with tempfile.NamedTemporaryFile(suffix='.glb',
                                         delete=False) as glb_file:
            glb_path = glb_file.name
        mesh.export(glb_path, file_type='glb')

        return glb_path

    def _sample_latent(self, model, model_kwargs: dict,
                       guidance_scale: float,
                       num_steps: int) -> torch.Tensor:
        """Sample one latent from ``model`` with the shared sampler settings."""
        latents = sample_latents(
            batch_size=1,
            model=model,
            diffusion=self.diffusion,
            guidance_scale=guidance_scale,
            model_kwargs=model_kwargs,
            progress=True,
            clip_denoised=True,
            use_fp16=True,
            use_karras=True,
            karras_steps=num_steps,
            sigma_min=1e-3,
            sigma_max=160,
            s_churn=0,
        )
        return latents[0]

    def run_text(self,
                 prompt: str,
                 seed: int = 0,
                 guidance_scale: float = 15.0,
                 num_steps: int = 64) -> str:
        """Generate a GLB mesh from a text prompt.

        Args:
            prompt: Text passed to the 'text300M' model.
            seed: Value given to ``torch.manual_seed`` before sampling.
            guidance_scale: Guidance scale forwarded to the sampler.
            num_steps: Number of Karras sampling steps.

        Returns:
            Path of the generated ``.glb`` file.
        """
        self.load_model('text300M')
        torch.manual_seed(seed)

        latent = self._sample_latent(self.model_text,
                                     dict(texts=[prompt]),
                                     guidance_scale, num_steps)
        return self.to_glb(latent)

    def run_image(self,
                  image_path: str,
                  seed: int = 0,
                  guidance_scale: float = 3.0,
                  num_steps: int = 64) -> str:
        """Generate a GLB mesh from an input image.

        Args:
            image_path: Path of the image, read with ``load_image``.
            seed: Value given to ``torch.manual_seed`` before sampling.
            guidance_scale: Guidance scale forwarded to the sampler.
            num_steps: Number of Karras sampling steps.

        Returns:
            Path of the generated ``.glb`` file.
        """
        self.load_model('image300M')
        torch.manual_seed(seed)

        image = load_image(image_path)
        latent = self._sample_latent(self.model_image,
                                     dict(images=[image]),
                                     guidance_scale, num_steps)
        return self.to_glb(latent)
| |
|