| from typing import Optional |
|
|
| import torch |
| import pytorch3d |
|
|
|
|
| from pytorch3d.io import load_objs_as_meshes, load_obj, save_obj |
| from pytorch3d.ops import interpolate_face_attributes |
|
|
| from pytorch3d.structures import Meshes |
| from pytorch3d.renderer import ( |
| look_at_view_transform, |
| FoVPerspectiveCameras, |
| AmbientLights, |
| PointLights, |
| DirectionalLights, |
| Materials, |
| RasterizationSettings, |
| MeshRenderer, |
| MeshRasterizer, |
| SoftPhongShader, |
| SoftSilhouetteShader, |
| HardPhongShader, |
| TexturesVertex, |
| TexturesUV, |
| Materials, |
| ) |
| from pytorch3d.renderer.blending import BlendParams, hard_rgb_blend |
| from pytorch3d.renderer.utils import convert_to_tensors_and_broadcast, TensorProperties |
|
|
| from pytorch3d.renderer.lighting import AmbientLights |
| from pytorch3d.renderer.materials import Materials |
| from pytorch3d.renderer.mesh.shader import ShaderBase |
| from pytorch3d.renderer.mesh.shading import _apply_lighting, flat_shading |
| from pytorch3d.renderer.mesh.rasterizer import Fragments |
|
|
|
|
| """ |
| Customized version of PyTorch3D's hard flat shader (HardFlatShader) that supports N-channel flat shading. |
| """ |
|
|
|
|
class HardNChannelFlatShader(ShaderBase):
    """
    Per face lighting - the lighting model is applied using the average face
    position and the face normal. The blending function hard assigns
    the color of the closest face for each pixel.

    This variant of the PyTorch3D hard flat shader works with an arbitrary
    number of color channels (``channels``) instead of only RGB. The
    constructor substitutes channel-compatible defaults whenever the supplied
    arguments do not match ``channels``:

    - ``lights`` is replaced by ``AmbientLights`` with an all-ones ambient
      color unless it already is an ``AmbientLights`` whose ``ambient_color``
      has ``channels`` entries in its last dimension.
    - ``materials`` is replaced by a purely ambient ``Materials`` (ambient
      ones, diffuse/specular zeros, shininess 0) when it is missing or its
      ``ambient_color`` has the wrong number of channels.
    - ``blend_params`` gets an all-ones background color of length
      ``channels`` when it is not a ``BlendParams`` or its background color
      has the wrong length; ``sigma``/``gamma`` of a supplied ``BlendParams``
      are kept.

    To use the default values, simply initialize the shader with the desired
    device e.g.

    .. code-block::

        shader = HardNChannelFlatShader(device=torch.device("cuda:0"), channels=4)
    """

    def __init__(
        self,
        device="cpu",
        cameras: Optional[TensorProperties] = None,
        lights: Optional[TensorProperties] = None,
        materials: Optional[Materials] = None,
        blend_params: Optional[BlendParams] = None,
        channels: int = 3,
    ):
        """
        Args:
            device: device on which default lights/materials are created.
            cameras: forwarded to ``ShaderBase``.
            lights: kept only if it is an ``AmbientLights`` with ``channels``
                ambient channels; otherwise replaced (see class docstring).
            materials: kept only if its ``ambient_color`` has ``channels``
                channels; otherwise replaced (see class docstring).
            blend_params: kept only if it is a ``BlendParams`` whose
                background color has ``channels`` entries; otherwise its
                background color is replaced with all ones.
            channels: number of color channels produced by the shader.
        """
        # Local import: collections.abc.Sequence is the runtime-checkable ABC;
        # a subscripted typing generic (Sequence[float]) cannot be used in
        # isinstance and would raise TypeError.
        from collections.abc import Sequence

        ones = ((1.0,) * channels,)
        zeros = ((0.0,) * channels,)

        # Only ambient lighting with a matching channel count is accepted;
        # anything else falls back to unit ambient light in every channel.
        if (
            not isinstance(lights, AmbientLights)
            or lights.ambient_color.shape[-1] != channels
        ):
            lights = AmbientLights(ambient_color=ones, device=device)

        # Fall back to a purely ambient material so that the diffuse and
        # specular terms contribute nothing.
        if materials is None or materials.ambient_color.shape[-1] != channels:
            materials = Materials(
                device=device,
                diffuse_color=zeros,
                ambient_color=ones,
                specular_color=zeros,
                shininess=0.0,
            )

        default_background = (1.0,) * channels
        if not isinstance(blend_params, BlendParams):
            blend_params = BlendParams(background_color=default_background)
        else:
            background = blend_params.background_color
            if isinstance(background, torch.Tensor):
                mismatched = background.shape[-1] != channels
            elif isinstance(background, Sequence):
                mismatched = len(background) != channels
            else:
                mismatched = False
            if mismatched:
                # Replace only the background color; keep the caller's
                # sigma/gamma instead of silently resetting them.
                blend_params = BlendParams(
                    sigma=blend_params.sigma,
                    gamma=blend_params.gamma,
                    background_color=default_background,
                )

        super().__init__(
            device,
            cameras,
            lights,
            materials,
            blend_params,
        )
        # Number of color channels this shader was configured for.
        self.channels = channels

    def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
        """
        Flat-shade the rasterized fragments and hard-blend the result.

        Args:
            fragments: rasterizer output for ``meshes``.
            meshes: meshes whose textures are sampled at ``fragments``.
            **kwargs: passed to ``_get_cameras``; ``lights``, ``materials``
                and ``blend_params`` here override the stored values for this
                call. NOTE: these per-call overrides are NOT channel-checked
                the way constructor arguments are.

        Returns:
            The output of ``hard_rgb_blend`` for the flat-shaded colors
            (presumably the closest face's color per pixel plus an alpha
            channel — confirm against the installed PyTorch3D version).
        """
        cameras = super()._get_cameras(**kwargs)
        texels = meshes.sample_textures(fragments)
        lights = kwargs.get("lights", self.lights)
        materials = kwargs.get("materials", self.materials)
        blend_params = kwargs.get("blend_params", self.blend_params)
        colors = flat_shading(
            meshes=meshes,
            fragments=fragments,
            texels=texels,
            lights=lights,
            cameras=cameras,
            materials=materials,
        )
        return hard_rgb_blend(colors, fragments, blend_params)
|
|