Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,157 Bytes
a846205 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
import torch
from typing import Optional
import torch.nn.functional as Fn
import math
import copy
from . import register
from .base import Base
class BaseLight(Base):
"""
Base class for light models.
"""
def setup(self):
pass
def forward(self, x: Optional[torch.Tensor] = None):
"""
Get the light intensity.
Args:
x: positions of shape (..., 3).
Returns:
color: radiance intensity of shape (..., 3)
d: directions of shape (..., 3).
"""
raise NotImplementedError
@register("point-light")
class PointLight(BaseLight):
"""Point light definitions
"""
def setup(self):
"""Initialize point light.
Args:
position (float, float, float): World coordinate of the light.
color (float, float, float): Light color in (R, G, B).
power (float): Light power, it will be directly multiplied to each color channel.
"""
position = self.config.get("position", [0., 0., 10.])
color = self.config.get("color", [23.47, 21.31, 20.79])
power = self.config.get("power", 10.)
self.register_buffer("position", torch.tensor(position))
self.register_buffer("color", torch.tensor(color) * power)
def forward(self, x: Optional[torch.Tensor] = None):
"""Compute light radiance and direction.
Args:
x : World coordinate of the interacting surface. [B, H, W, 3]
Returns:
color: radiance intensity of shape [B, H, W, 3]
d: directions of shape [B, H, W, 3], V = (light_pos - world_pos)
"""
distance = torch.norm(self.position - x, dim=-1, keepdim=True)
attenuation = 1.0 / (distance ** 2)
radiance = self.color * attenuation
direction = Fn.normalize(self.position - x, dim=-1)
return radiance, direction
@register("distant-light")
class DistantLight(BaseLight):
"""Distant light definitions
"""
def setup(self):
"""Initialize distant light.
Args:
direction (float, float, float):The direction of light vector.
color (float, float, float): Light color in (R, G, B).
power (float): Light power, it will be directly multiplied to each color channel.
"""
direction = self.config.get("direction", [0., 0., 1.])
color = self.config.get("color", [23.47, 21.31, 20.79])
power = self.config.get("power", 0.1)
self.register_buffer("color", torch.tensor(color) * power)
self.register_buffer("direction", Fn.normalize(torch.tensor(direction), dim=0))
def forward(self, x: Optional[torch.Tensor] = None):
"""Compute light radiance and direction.
Args:
x : World coordinate of the interacting surface. [B, H, W, 3]
Returns:
color: radiance intensity of shape [B, H, W, 3]
d: directions of shape [B, H, W, 3]
"""
radiance = self.color.repeat(*x.shape[:-1], 1)
direction = self.direction.repeat(*x.shape[:-1], 1)
return radiance, direction |