| from typing import List |
| from pydantic import validator |
|
|
| from my.config import BaseConf, SingleOrList, dispatch |
| from my.utils.seed import seed_everything |
|
|
| import numpy as np |
| from voxnerf.vox import VOXRF_REGISTRY |
| from voxnerf.pipelines import train |
|
|
|
|
class VoxConfig(BaseConf):
    """Configuration for a model constructed through ``VOXRF_REGISTRY``.

    ``model_type`` selects the registry entry and ``bbox_len`` is turned into
    an axis-aligned bounding box; every other field is forwarded unchanged to
    the model constructor as a keyword argument (see :meth:`make`).
    """

    # Key looked up in VOXRF_REGISTRY to obtain the model constructor.
    model_type: str = "VoxRF"
    # Half side length of the cube [-bbox_len, bbox_len]^3 used as the aabb.
    bbox_len: float = 1.5
    # Either a single int (expanded to three equal entries) or a 3-element list.
    grid_size: SingleOrList(int) = [128, 128, 128]
    # The fields below are passed through to the model constructor as-is.
    step_ratio: float = 0.5
    density_shift: float = -10.
    ray_march_weight_thres: float = 0.0001
    c: int = 3  # presumably the number of output channels -- TODO confirm
    blend_bg_texture: bool = False
    bg_texture_hw: int = 64

    @validator("grid_size")
    def check_gsize(cls, grid_size):
        """Normalize ``grid_size`` to a list of exactly three entries.

        A single int is replicated for all three axes. A list is accepted
        only if it has length 3.

        Raises:
            ValueError: if a list of any other length is supplied. pydantic
                reports this to the caller as a ``ValidationError``.
        """
        if isinstance(grid_size, int):
            return [grid_size] * 3
        # Explicit raise instead of ``assert``: asserts are removed under
        # ``python -O``, which would let a malformed grid size through.
        if len(grid_size) != 3:
            raise ValueError(
                f"grid_size must be an int or a list of 3 ints, "
                f"got {len(grid_size)} entries"
            )
        return grid_size

    def make(self):
        """Instantiate the model described by this configuration.

        Returns:
            Whatever the constructor registered under ``model_type`` returns
            when called with ``aabb`` and the remaining config fields.
        """
        params = self.dict()
        m_type = params.pop("model_type")
        model_fn = VOXRF_REGISTRY.get(m_type)

        # Scale the corners of the unit cube to get a (2, 3) array holding
        # the min corner and the max corner of the bounding box.
        radius = params.pop('bbox_len')
        aabb = radius * np.array([
            [-1, -1, -1],
            [1, 1, 1]
        ])
        model = model_fn(aabb=aabb, **params)
        return model
|
|
|
|
class TrainerConfig(BaseConf):
    """Training configuration: a model config plus options passed to ``train``.

    Every field except ``model`` is handed to ``train`` as a keyword argument
    by :meth:`run`.
    """

    # Nested model configuration; turned into a model via ``VoxConfig.make``.
    model: VoxConfig = VoxConfig()
    scene: str = "lego"
    n_epoch: int = 2
    bs: int = 4096  # presumably the batch size -- TODO confirm against train()
    lr: float = 0.02  # presumably the learning rate -- TODO confirm

    def run(self):
        """Build the model and call ``train`` with the remaining fields."""
        # Everything but the nested model config becomes a keyword argument.
        train_kwargs = {
            key: value
            for key, value in self.dict().items()
            if key != "model"
        }

        net = self.model.make()
        train(net, **train_kwargs)
|
|
|
|
if __name__ == "__main__":
    # Seed first, so that everything dispatch() triggers runs after seeding.
    seed_everything(0)
    # Hand the top-level config class to the project's dispatcher.
    dispatch(TrainerConfig)
|
|