Spaces:
Paused
Paused
dylanebert
committed on
Commit
·
ddefedb
1
Parent(s):
02afac0
working checkpoint
Browse files- .gitattributes +2 -0
- .gitignore +4 -0
- Dockerfile +54 -0
- LICENSE +21 -0
- README.md +1 -0
- __pycache__/app.cpython-310.pyc +0 -0
- app.py +63 -0
- convert.py +548 -0
- core/__init__.py +0 -0
- core/__pycache__/__init__.cpython-310.pyc +0 -0
- core/__pycache__/gs.cpython-310.pyc +0 -0
- core/__pycache__/options.cpython-310.pyc +0 -0
- core/attention.py +156 -0
- core/gs.py +190 -0
- core/models.py +174 -0
- core/options.py +120 -0
- core/provider_objaverse.py +172 -0
- core/unet.py +319 -0
- core/utils.py +109 -0
- data_test/catstatue.ply +3 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
*.ply filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
*.whl filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
|
| 3 |
+
venv/
|
| 4 |
+
gradio_cached_examples/
|
Dockerfile
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM nvidia/cuda:12.1.1-cudnn8-devel-ubuntu22.04
|
| 2 |
+
|
| 3 |
+
# Configure environment
|
| 4 |
+
ENV DEBIAN_FRONTEND=noninteractive
|
| 5 |
+
|
| 6 |
+
# Install the required packages
|
| 7 |
+
RUN apt-get update && apt-get install -y \
|
| 8 |
+
software-properties-common
|
| 9 |
+
|
| 10 |
+
# Add the deadsnakes PPA
|
| 11 |
+
RUN add-apt-repository ppa:deadsnakes/ppa
|
| 12 |
+
|
| 13 |
+
# Install Python 3.10
|
| 14 |
+
RUN apt-get update && apt-get install -y \
|
| 15 |
+
python3.10 \
|
| 16 |
+
python3.10-dev \
|
| 17 |
+
python3.10-distutils \
|
| 18 |
+
python3.10-venv \
|
| 19 |
+
python3-pip
|
| 20 |
+
|
| 21 |
+
# Install other dependencies
|
| 22 |
+
RUN apt-get install -y \
|
| 23 |
+
git \
|
| 24 |
+
gcc \
|
| 25 |
+
g++ \
|
| 26 |
+
libgl1 \
|
| 27 |
+
libglib2.0.0 \
|
| 28 |
+
ffmpeg \
|
| 29 |
+
cmake \
|
| 30 |
+
libgtk2.0.0
|
| 31 |
+
|
| 32 |
+
# Create a non-root user and set the working directory
|
| 33 |
+
RUN useradd -m -u 1000 user
|
| 34 |
+
USER user
|
| 35 |
+
ENV HOME=/home/user \
|
| 36 |
+
PATH=/home/user/.local/bin:$PATH
|
| 37 |
+
WORKDIR $HOME/app
|
| 38 |
+
|
| 39 |
+
# Install the required Python packages
|
| 40 |
+
RUN pip install wheel
|
| 41 |
+
RUN pip install torch==2.1.0+cu121 torchvision==0.16.0+cu121 torchaudio==2.1.0+cu121 torchtext==0.16.0 torchdata==0.7.0 --extra-index-url https://download.pytorch.org/whl/cu121 -U
|
| 42 |
+
RUN sed -i 's/return caster.operator typename make_caster<T>::template cast_op_type<T>();/return caster;/' /home/user/.local/lib/python3.10/site-packages/torch/include/pybind11/cast.h
|
| 43 |
+
RUN pip install tyro kiui PyMCubes nerfacc trimesh pymeshlab ninja plyfile xatlas pygltflib gradio opencv-python scikit-learn
|
| 44 |
+
RUN pip install https://github.com/camenduru/LGM-replicate/releases/download/replicate/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl
|
| 45 |
+
RUN pip install https://github.com/camenduru/wheels/releases/download/colab/nvdiffrast-0.3.1-py3-none-any.whl
|
| 46 |
+
RUN pip install git+https://github.com/ashawkey/kiuikit.git
|
| 47 |
+
|
| 48 |
+
# Copy all files to the working directory
|
| 49 |
+
COPY --chown=user . $HOME/app
|
| 50 |
+
|
| 51 |
+
EXPOSE 7860
|
| 52 |
+
|
| 53 |
+
# Run the gradio app
|
| 54 |
+
CMD ["python3.10", "app.py"]
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2024 3D Topia
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
CHANGED
|
@@ -6,6 +6,7 @@ colorTo: yellow
|
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
| 8 |
license: mit
|
|
|
|
| 9 |
---
|
| 10 |
|
| 11 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
| 8 |
license: mit
|
| 9 |
+
app_port: 7860
|
| 10 |
---
|
| 11 |
|
| 12 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
__pycache__/app.cpython-310.pyc
ADDED
|
Binary file (1.95 kB). View file
|
|
|
app.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import subprocess
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def run(input_ply):
    """Convert a Gaussian-splat ``.ply`` file into a ``.glb`` mesh.

    Runs ``convert.py`` in a child process and returns the path of the file
    it is expected to write.

    Args:
        input_ply: Filesystem path of the input ``.ply`` file.

    Returns:
        ``input_ply`` with every ``.ply`` substring replaced by ``.glb``.
        This mirrors the rule ``convert.py`` uses to name its output, so the
        two must be kept in sync.

    Raises:
        subprocess.CalledProcessError: If ``convert.py`` exits with a
            non-zero status, instead of silently returning a path to a file
            that was never written.
    """
    # Pass an argument list without a shell so that spaces or shell
    # metacharacters in the uploaded file's path can neither split the
    # argument nor inject extra commands.
    subprocess.run(
        [
            "python3.10",
            "convert.py",
            "big",
            "--force-cuda-rast",
            "--test_path",
            input_ply,
        ],
        check=True,
    )
    return input_ply.replace(".ply", ".glb")
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def main():
    """Assemble the Splat-to-Mesh Gradio page and start serving it.

    The page holds a "duplicate space" button, a Markdown header, an input
    and an output ``Model3D`` viewer wired to ``run`` through a "Convert"
    button, and one cached example.  The server listens on 0.0.0.0:7860.
    """
    page_title = """Splat to Mesh"""

    blurb = """
    Converts Gaussian Splat (.ply) to Mesh (.glb) using [LGM](https://github.com/3DTopia/LGM).

    For faster inference without waiting in a queue, you may duplicate the space and upgrade to a GPU in the settings.
    """

    stylesheet = """
    #duplicate-button {
        margin: auto;
        color: white;
        background: #1565c0;
        border-radius: 100vh;
    }
    """

    # Heading and description are rendered together as one Markdown block.
    header_md = "# " + page_title + "\n" + blurb

    with gr.Blocks(title=page_title, css=stylesheet) as app:
        gr.DuplicateButton(
            elem_id="duplicate-button", value="Duplicate Space for private use"
        )

        with gr.Row():
            with gr.Column():
                gr.Markdown(header_md)

        with gr.Row(variant="panel"):
            # Left column: upload + trigger.
            with gr.Column():
                splat_in = gr.Model3D(label="Input Splat")
                convert_btn = gr.Button("Convert")

            # Right column: converted mesh.
            with gr.Column():
                mesh_out = gr.Model3D(label="Output GLB")

        convert_btn.click(run, inputs=[splat_in], outputs=[mesh_out])

        gr.Examples(
            ["data_test/catstatue.ply"],
            fn=lambda x: run(x),
            inputs=[splat_in],
            outputs=[mesh_out],
            label="Examples",
            cache_examples=True,
        )

    app.queue().launch(server_name="0.0.0.0", server_port=7860)


if __name__ == "__main__":
    main()
|
convert.py
ADDED
|
@@ -0,0 +1,548 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tyro
|
| 2 |
+
import tqdm
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
|
| 8 |
+
from core.options import AllConfigs, Options
|
| 9 |
+
from core.gs import GaussianRenderer
|
| 10 |
+
|
| 11 |
+
import mcubes
|
| 12 |
+
import nerfacc
|
| 13 |
+
import nvdiffrast.torch as dr
|
| 14 |
+
|
| 15 |
+
from kiui.mesh import Mesh
|
| 16 |
+
from kiui.mesh_utils import clean_mesh, decimate_mesh
|
| 17 |
+
from kiui.mesh_utils import normal_consistency
|
| 18 |
+
from kiui.op import uv_padding, safe_normalize, inverse_sigmoid
|
| 19 |
+
from kiui.cam import orbit_camera, get_perspective
|
| 20 |
+
from kiui.nn import MLP, trunc_exp
|
| 21 |
+
from kiui.gridencoder import GridEncoder
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def get_rays(pose, h, w, fovy, opengl=True):
    """Build one ray per pixel of an ``h`` x ``w`` image.

    Args:
        pose: Tensor whose ``[:3, :3]`` block is used to rotate the per-pixel
            directions and whose ``[:3, 3]`` column is used as the common ray
            origin (presumably a 4x4 camera-to-world matrix — confirm with
            callers).
        h: Image height in pixels.
        w: Image width in pixels.
        fovy: Vertical field of view in degrees.
        opengl: When true, the y component is negated and z is set to -1;
            otherwise y is kept and z is set to +1.

    Returns:
        ``(rays_o, rays_d)``, both ``[h * w, 3]``.  ``rays_d`` is passed
        through ``safe_normalize``.
    """
    device = pose.device
    # Sign applied to the y component and used as the constant z component.
    sign = -1.0 if opengl else 1.0

    cols, rows = torch.meshgrid(
        torch.arange(w, device=device),
        torch.arange(h, device=device),
        indexing="xy",
    )
    cols = cols.flatten()
    rows = rows.flatten()

    # Focal length in pixels, derived from the vertical field of view.
    focal = h * 0.5 / np.tan(0.5 * np.deg2rad(fovy))

    # The +0.5 samples each pixel at its centre.
    dir_x = (cols - w * 0.5 + 0.5) / focal
    dir_y = (rows - h * 0.5 + 0.5) / focal * sign

    dirs_xy = torch.stack([dir_x, dir_y], dim=-1)  # [hw, 2]
    camera_dirs = F.pad(dirs_xy, (0, 1), value=sign)  # [hw, 3]

    rotation = pose[:3, :3]
    rays_d = camera_dirs @ rotation.transpose(0, 1)  # [hw, 3]
    rays_o = pose[:3, 3].unsqueeze(0).expand_as(rays_d)  # [hw, 3]

    return rays_o, safe_normalize(rays_d)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
# Triple renderer of gaussians, nerf, and mesh.
|
| 59 |
+
# gaussian --> nerf --> mesh
|
| 60 |
+
class Converter(nn.Module):
    """Turn a loaded set of gaussians into a textured mesh.

    The conversion runs in three fitting stages, each supervised by images
    rendered from the gaussians with ``render_gs``:

    1. ``fit_nerf``    - fit a grid-encoded density/colour field.
    2. ``fit_mesh``    - extract a mesh with marching cubes and refine
                         per-vertex offsets and colour.
    3. ``fit_mesh_uv`` - unwrap UVs, bake the colour field into a texture and
                         refine the texture.

    ``export_mesh`` then writes the result to disk.

    Note: every ``fit_*`` method overwrites ``self.opt.output_size``.
    """

    def __init__(self, opt: Options):
        """Load the gaussians and build the three renderers.

        Args:
            opt: Options object. This class reads ``fovy``, ``znear``,
                ``zfar``, ``test_path``, ``force_cuda_rast``, ``output_size``
                and ``cam_radius`` from it.

        The device is hard-coded to ``"cuda"``.
        """
        super().__init__()

        self.opt = opt
        self.device = torch.device("cuda")

        # gs renderer: projection matrix built from fovy / znear / zfar
        self.tan_half_fov = np.tan(0.5 * np.deg2rad(opt.fovy))
        self.proj_matrix = torch.zeros(4, 4, dtype=torch.float32, device=self.device)
        self.proj_matrix[0, 0] = 1 / self.tan_half_fov
        self.proj_matrix[1, 1] = 1 / self.tan_half_fov
        self.proj_matrix[2, 2] = (opt.zfar + opt.znear) / (opt.zfar - opt.znear)
        self.proj_matrix[3, 2] = -(opt.zfar * opt.znear) / (opt.zfar - opt.znear)
        self.proj_matrix[2, 3] = 1

        self.gs_renderer = GaussianRenderer(opt)

        self.gaussians = self.gs_renderer.load_ply(opt.test_path).to(self.device)

        # rasterization context (used by render_mesh / fit_mesh_uv):
        # OpenGL by default, CUDA when opt.force_cuda_rast is set
        if not self.opt.force_cuda_rast:
            self.glctx = dr.RasterizeGLContext()
        else:
            self.glctx = dr.RasterizeCudaContext()

        # nerf renderer
        self.step = 0  # counts render_nerf calls made in training mode
        self.render_step_size = 5e-3
        self.aabb = torch.tensor([-1.0, -1.0, -1.0, 1.0, 1.0, 1.0], device=self.device)
        self.estimator = nerfacc.OccGridEstimator(
            roi_aabb=self.aabb, resolution=64, levels=1
        )

        # separate encoders/MLPs for density (1 output) and colour (3 outputs)
        self.encoder_density = GridEncoder(
            num_levels=12
        )  # VMEncoder(output_dim=16, mode='sum')
        self.encoder = GridEncoder(num_levels=12)
        self.mlp_density = MLP(self.encoder_density.output_dim, 1, 32, 2, bias=False)
        self.mlp = MLP(self.encoder.output_dim, 3, 32, 2, bias=False)

        # mesh renderer
        self.proj = (
            torch.from_numpy(get_perspective(self.opt.fovy)).float().to(self.device)
        )
        # mesh state, filled in by fit_mesh / fit_mesh_uv
        self.v = self.f = None  # vertices / faces
        self.vt = self.ft = None  # uv coordinates / uv faces
        self.deform = None  # per-vertex offsets added to self.v
        self.albedo = None  # texture, stored before the sigmoid

    @torch.no_grad()
    def render_gs(self, pose):
        """Render the loaded gaussians from one camera pose (no gradients).

        Args:
            pose: NumPy array converted with ``torch.from_numpy`` and inverted
                as a batch of one matrix (presumably a 4x4 camera pose as
                returned by ``orbit_camera`` — confirm).

        Returns:
            ``(image, alpha)`` with shapes ``[C, H, W]`` and ``[H, W]``.
        """

        cam_poses = torch.from_numpy(pose).unsqueeze(0).to(self.device)
        cam_poses[:, :3, 1:3] *= -1  # invert up & forward direction

        # cameras needed by gaussian rasterizer
        cam_view = torch.inverse(cam_poses).transpose(1, 2)  # [V, 4, 4]
        cam_view_proj = cam_view @ self.proj_matrix  # [V, 4, 4]
        # NOTE(review): the *negated* translation is passed as the camera
        # position — confirm this is what GaussianRenderer.render expects.
        cam_pos = -cam_poses[:, :3, 3]  # [V, 3]

        # every argument gets a leading batch dimension of 1
        out = self.gs_renderer.render(
            self.gaussians.unsqueeze(0),
            cam_view.unsqueeze(0),
            cam_view_proj.unsqueeze(0),
            cam_pos.unsqueeze(0),
        )
        image = out["image"].squeeze(1).squeeze(0)  # [C, H, W]
        alpha = out["alpha"].squeeze(2).squeeze(1).squeeze(0)  # [H, W]

        return image, alpha

    def get_density(self, xs):
        """Evaluate the density field at ``xs``.

        Args:
            xs: Tensor of points with shape ``[..., 3]``.

        Returns:
            Tensor of shape ``[..., 1]``: ``trunc_exp`` of the density MLP
            output.
        """
        # xs: [..., 3]
        prefix = xs.shape[:-1]
        xs = xs.view(-1, 3)
        feats = self.encoder_density(xs)
        density = trunc_exp(self.mlp_density(feats))
        density = density.view(*prefix, 1)
        return density

    def render_nerf(self, pose):
        """Volume-render the density/colour field from one camera pose.

        The image is square with side ``self.opt.output_size``.  In training
        mode the occupancy grid is refreshed through
        ``update_every_n_steps`` and ``self.step`` is incremented.

        Args:
            pose: NumPy camera pose; cast to float32 and passed to
                ``get_rays``.

        Returns:
            ``(color, alpha)`` with shapes ``[3, R, R]`` and ``[R, R]``, both
            clamped to ``[0, 1]``.  Colour is composited over white.
        """

        pose = torch.from_numpy(pose.astype(np.float32)).to(self.device)

        # get rays
        resolution = self.opt.output_size
        rays_o, rays_d = get_rays(pose, resolution, resolution, self.opt.fovy)

        # update occ grid
        if self.training:

            def occ_eval_fn(xs):
                sigmas = self.get_density(xs)
                return self.render_step_size * sigmas

            self.estimator.update_every_n_steps(
                self.step, occ_eval_fn=occ_eval_fn, occ_thre=0.01, n=8
            )
            self.step += 1

        # render
        def sigma_fn(t_starts, t_ends, ray_indices):
            # density at the midpoint of each sampled interval
            t_origins = rays_o[ray_indices]
            t_dirs = rays_d[ray_indices]
            xs = t_origins + t_dirs * (t_starts + t_ends)[:, None] / 2.0
            sigmas = self.get_density(xs)
            return sigmas.squeeze(-1)

        # sample positions are chosen without tracking gradients
        with torch.no_grad():
            ray_indices, t_starts, t_ends = self.estimator.sampling(
                rays_o,
                rays_d,
                sigma_fn=sigma_fn,
                near_plane=0.01,
                far_plane=100,
                render_step_size=self.render_step_size,
                stratified=self.training,
                cone_angle=0,
            )

        # re-evaluate density and colour at the interval midpoints, this time
        # with gradients enabled
        t_origins = rays_o[ray_indices]
        t_dirs = rays_d[ray_indices]
        xs = t_origins + t_dirs * (t_starts + t_ends)[:, None] / 2.0
        sigmas = self.get_density(xs).squeeze(-1)
        rgbs = torch.sigmoid(self.mlp(self.encoder(xs)))

        n_rays = rays_o.shape[0]
        weights, trans, alphas = nerfacc.render_weight_from_density(
            t_starts, t_ends, sigmas, ray_indices=ray_indices, n_rays=n_rays
        )
        color = nerfacc.accumulate_along_rays(
            weights, values=rgbs, ray_indices=ray_indices, n_rays=n_rays
        )
        alpha = nerfacc.accumulate_along_rays(
            weights, values=None, ray_indices=ray_indices, n_rays=n_rays
        )

        # composite over a white background
        color = color + 1 * (1.0 - alpha)

        color = (
            color.view(resolution, resolution, 3)
            .clamp(0, 1)
            .permute(2, 0, 1)
            .contiguous()
        )
        alpha = alpha.view(resolution, resolution).clamp(0, 1).contiguous()

        return color, alpha

    def fit_nerf(self, iters=512, resolution=128):
        """Fit the density/colour field to renders of the gaussians.

        Each iteration draws a random orbit camera, renders the gaussians as
        the target and the field as the prediction, and minimises image MSE
        plus 0.1 x alpha MSE with Adam.

        Args:
            iters: Number of optimisation steps.
            resolution: Render resolution; written to
                ``self.opt.output_size``.
        """

        self.opt.output_size = resolution

        optimizer = torch.optim.Adam(
            [
                {"params": self.encoder_density.parameters(), "lr": 1e-2},
                {"params": self.encoder.parameters(), "lr": 1e-2},
                {"params": self.mlp_density.parameters(), "lr": 1e-3},
                {"params": self.mlp.parameters(), "lr": 1e-3},
            ]
        )

        print("[INFO] fitting nerf...")
        pbar = tqdm.trange(iters)
        for i in pbar:

            # random camera: elevation, azimuth and radius
            ver = np.random.randint(-45, 45)
            hor = np.random.randint(-180, 180)
            rad = np.random.uniform(1.5, 3.0)

            pose = orbit_camera(ver, hor, rad)

            image_gt, alpha_gt = self.render_gs(pose)
            image_pred, alpha_pred = self.render_nerf(pose)

            # if i % 200 == 0:
            #     kiui.vis.plot_image(image_gt, alpha_gt, image_pred, alpha_pred)

            loss_mse = F.mse_loss(image_pred, image_gt) + 0.1 * F.mse_loss(
                alpha_pred, alpha_gt
            )
            loss = loss_mse  # + 0.1 * self.encoder_density.tv_loss() #+ 0.0001 * self.encoder_density.density_loss()

            loss.backward()
            # called between backward() and step(); presumably adds a
            # total-variation term to the encoder gradients — confirm
            self.encoder_density.grad_total_variation(1e-8)

            optimizer.step()
            optimizer.zero_grad()

            pbar.set_description(f"MSE = {loss_mse.item():.6f}")

        print("[INFO] finished fitting nerf!")

    def render_mesh(self, pose):
        """Rasterize the current mesh (``self.v + self.deform``) from a pose.

        Colour comes from the colour field queried at surface points while
        ``self.albedo`` is ``None``, and from the ``self.albedo`` texture
        afterwards.

        Args:
            pose: NumPy camera pose; cast to float32 and inverted to obtain
                the view transform.

        Returns:
            ``(image, alpha)`` with shapes ``[3, H, W]`` and ``[H, W]``, where
            ``H = W = self.opt.output_size``.  The image is composited over
            white.
        """

        h = w = self.opt.output_size

        v = self.v + self.deform
        f = self.f

        pose = torch.from_numpy(pose.astype(np.float32)).to(v.device)

        # get v_clip and render rgb
        v_cam = (
            torch.matmul(
                F.pad(v, pad=(0, 1), mode="constant", value=1.0), torch.inverse(pose).T
            )
            .float()
            .unsqueeze(0)
        )
        v_clip = v_cam @ self.proj.T

        rast, rast_db = dr.rasterize(self.glctx, v_clip, f, (h, w))

        alpha = torch.clamp(rast[..., -1:], 0, 1).contiguous()  # [1, H, W, 1]
        alpha = (
            dr.antialias(alpha, rast, v_clip, f).clamp(0, 1).squeeze(-1).squeeze(0)
        )  # [H, W] important to enable gradients!

        if self.albedo is None:
            # no texture yet: query the colour field at the covered pixels
            xyzs, _ = dr.interpolate(v.unsqueeze(0), rast, f)  # [1, H, W, 3]
            xyzs = xyzs.view(-1, 3)
            mask = (alpha > 0).view(-1)
            image = torch.zeros_like(xyzs, dtype=torch.float32)
            if mask.any():
                # positions are detached, so this branch sends no colour
                # gradient back to the vertices
                masked_albedo = torch.sigmoid(
                    self.mlp(self.encoder(xyzs[mask].detach(), bound=1))
                )
                image[mask] = masked_albedo.float()
        else:
            # textured: sample self.albedo at the interpolated uv coordinates
            texc, texc_db = dr.interpolate(
                self.vt.unsqueeze(0), rast, self.ft, rast_db=rast_db, diff_attrs="all"
            )
            image = torch.sigmoid(
                dr.texture(self.albedo.unsqueeze(0), texc, uv_da=texc_db)
            )  # [1, H, W, 3]

        image = image.view(1, h, w, 3)
        # image = dr.antialias(image, rast, v_clip, f).clamp(0, 1)
        image = image.squeeze(0).permute(2, 0, 1).contiguous()  # [3, H, W]
        # composite over a white background
        image = alpha * image + (1 - alpha)

        return image, alpha

    def fit_mesh(self, iters=2048, resolution=512, decimate_target=5e4):
        """Extract a mesh from the density field and refine it.

        The density is sampled on a 256^3 grid over ``[-1, 1]^3`` and
        meshed with marching cubes, then cleaned and optionally decimated.
        Per-vertex offsets and the colour field are optimised against renders
        of the gaussians.  On iterations 512, 1024, ... the deformed mesh is
        remeshed, the offsets are reset to zero, the encoder/MLP learning rate
        is halved and the optimizer is rebuilt.

        Args:
            iters: Number of optimisation steps.
            resolution: Render resolution; written to
                ``self.opt.output_size``.
            decimate_target: Face count above which the mesh is decimated.
        """

        self.opt.output_size = resolution

        # init mesh from nerf
        grid_size = 256
        sigmas = np.zeros([grid_size, grid_size, grid_size], dtype=np.float32)

        S = 128  # chunk length per axis, to bound the size of each query
        density_thresh = 10

        X = torch.linspace(-1, 1, grid_size).split(S)
        Y = torch.linspace(-1, 1, grid_size).split(S)
        Z = torch.linspace(-1, 1, grid_size).split(S)

        # NOTE(review): these queries run with autograd enabled; the result is
        # detached immediately, so torch.no_grad() would likely save memory.
        for xi, xs in enumerate(X):
            for yi, ys in enumerate(Y):
                for zi, zs in enumerate(Z):
                    xx, yy, zz = torch.meshgrid(xs, ys, zs, indexing="ij")
                    pts = torch.cat(
                        [xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)],
                        dim=-1,
                    )  # [S, 3]
                    val = self.get_density(pts.to(self.device))
                    sigmas[
                        xi * S : xi * S + len(xs),
                        yi * S : yi * S + len(ys),
                        zi * S : zi * S + len(zs),
                    ] = (
                        val.reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy()
                    )  # [S, 1] --> [x, y, z]

        print(
            f"[INFO] marching cubes thresh: {density_thresh} ({sigmas.min()} ~ {sigmas.max()})"
        )

        vertices, triangles = mcubes.marching_cubes(sigmas, density_thresh)
        # map grid indices [0, grid_size - 1] back to coordinates in [-1, 1]
        vertices = vertices / (grid_size - 1.0) * 2 - 1

        # clean
        vertices = vertices.astype(np.float32)
        triangles = triangles.astype(np.int32)
        vertices, triangles = clean_mesh(
            vertices, triangles, remesh=True, remesh_size=0.01
        )
        if triangles.shape[0] > decimate_target:
            vertices, triangles = decimate_mesh(
                vertices, triangles, decimate_target, optimalplacement=False
            )

        self.v = torch.from_numpy(vertices).contiguous().float().to(self.device)
        self.f = torch.from_numpy(triangles).contiguous().int().to(self.device)
        self.deform = nn.Parameter(torch.zeros_like(self.v)).to(self.device)

        # fit mesh from gs
        lr_factor = 1
        optimizer = torch.optim.Adam(
            [
                {"params": self.encoder.parameters(), "lr": 1e-3 * lr_factor},
                {"params": self.mlp.parameters(), "lr": 1e-3 * lr_factor},
                {"params": self.deform, "lr": 1e-4},
            ]
        )

        print("[INFO] fitting mesh...")
        pbar = tqdm.trange(iters)
        for i in pbar:

            # random azimuth, small elevation range, fixed radius
            ver = np.random.randint(-10, 10)
            hor = np.random.randint(-180, 180)
            rad = self.opt.cam_radius  # np.random.uniform(1, 2)

            pose = orbit_camera(ver, hor, rad)

            image_gt, alpha_gt = self.render_gs(pose)
            image_pred, alpha_pred = self.render_mesh(pose)

            loss_mse = F.mse_loss(image_pred, image_gt) + 0.1 * F.mse_loss(
                alpha_pred, alpha_gt
            )
            # loss_lap = laplacian_smooth_loss(self.v + self.deform, self.f)
            loss_normal = normal_consistency(self.v + self.deform, self.f)
            # penalise large vertex offsets
            loss_offsets = (self.deform**2).sum(-1).mean()
            loss = loss_mse + 0.001 * loss_normal + 0.1 * loss_offsets

            loss.backward()

            optimizer.step()
            optimizer.zero_grad()

            # remesh periodically
            if i > 0 and i % 512 == 0:
                # bake the learned offsets into the vertices before remeshing
                vertices = (self.v + self.deform).detach().cpu().numpy()
                triangles = self.f.detach().cpu().numpy()
                vertices, triangles = clean_mesh(
                    vertices, triangles, remesh=True, remesh_size=0.01
                )
                if triangles.shape[0] > decimate_target:
                    vertices, triangles = decimate_mesh(
                        vertices, triangles, decimate_target, optimalplacement=False
                    )
                self.v = torch.from_numpy(vertices).contiguous().float().to(self.device)
                self.f = torch.from_numpy(triangles).contiguous().int().to(self.device)
                self.deform = nn.Parameter(torch.zeros_like(self.v)).to(self.device)
                # the vertex count changed, so the optimizer is rebuilt around
                # the new offsets, with a halved encoder/MLP learning rate
                lr_factor *= 0.5
                optimizer = torch.optim.Adam(
                    [
                        {"params": self.encoder.parameters(), "lr": 1e-3 * lr_factor},
                        {"params": self.mlp.parameters(), "lr": 1e-3 * lr_factor},
                        {"params": self.deform, "lr": 1e-4},
                    ]
                )

            pbar.set_description(f"MSE = {loss_mse.item():.6f}")

        # last clean
        vertices = (self.v + self.deform).detach().cpu().numpy()
        triangles = self.f.detach().cpu().numpy()
        vertices, triangles = clean_mesh(vertices, triangles, remesh=False)
        self.v = torch.from_numpy(vertices).contiguous().float().to(self.device)
        self.f = torch.from_numpy(triangles).contiguous().int().to(self.device)
        self.deform = nn.Parameter(torch.zeros_like(self.v).to(self.device))

        print("[INFO] finished fitting mesh!")

    # uv mesh refine
    def fit_mesh_uv(
        self, iters=512, resolution=512, texture_resolution=1024, padding=2
    ):
        """Unwrap the mesh, bake a texture from the colour field and refine it.

        Sets ``self.vt``, ``self.ft`` and ``self.albedo``.  Once
        ``self.albedo`` is set, ``render_mesh`` samples the texture instead of
        querying the colour field.  Only the texture is optimised here.

        Args:
            iters: Number of optimisation steps.
            resolution: Render resolution; written to
                ``self.opt.output_size``.
            texture_resolution: Side length of the square texture.
            padding: Forwarded to ``uv_padding``.
        """

        self.opt.output_size = resolution

        # unwrap uv
        print("[INFO] uv unwrapping...")
        mesh = Mesh(v=self.v, f=self.f, albedo=None, device=self.device)
        mesh.auto_normal()
        mesh.auto_uv()

        self.vt = mesh.vt
        self.ft = mesh.ft

        # render uv maps: rasterize the mesh in uv space so that every texel
        # gets the 3D position it corresponds to
        h = w = texture_resolution
        uv = mesh.vt * 2.0 - 1.0  # uvs to range [-1, 1]
        uv = torch.cat(
            (uv, torch.zeros_like(uv[..., :1]), torch.ones_like(uv[..., :1])), dim=-1
        )  # [N, 4]

        rast, _ = dr.rasterize(
            self.glctx, uv.unsqueeze(0), mesh.ft, (h, w)
        )  # [1, h, w, 4]
        xyzs, _ = dr.interpolate(mesh.v.unsqueeze(0), rast, mesh.f)  # [1, h, w, 3]
        mask, _ = dr.interpolate(
            torch.ones_like(mesh.v[:, :1]).unsqueeze(0), rast, mesh.f
        )  # [1, h, w, 1]

        # masked query
        xyzs = xyzs.view(-1, 3)
        mask = (mask > 0).view(-1)

        albedo = torch.zeros(h * w, 3, device=self.device, dtype=torch.float32)

        if mask.any():
            print("[INFO] querying texture...")

            xyzs = xyzs[mask]  # [M, 3]

            # batched inference to avoid OOM
            batch = []
            head = 0
            while head < xyzs.shape[0]:
                tail = min(head + 640000, xyzs.shape[0])
                batch.append(
                    torch.sigmoid(self.mlp(self.encoder(xyzs[head:tail]))).float()
                )
                head += 640000

            albedo[mask] = torch.cat(batch, dim=0)

        albedo = albedo.view(h, w, -1)
        mask = mask.view(h, w)
        albedo = uv_padding(albedo, mask, padding)

        # optimize texture: stored through inverse_sigmoid, because
        # render_mesh and export_mesh apply a sigmoid when reading it
        self.albedo = nn.Parameter(inverse_sigmoid(albedo)).to(self.device)

        optimizer = torch.optim.Adam(
            [
                {"params": self.albedo, "lr": 1e-3},
            ]
        )

        print("[INFO] fitting mesh texture...")
        pbar = tqdm.trange(iters)
        for i in pbar:

            # shrink to front view as we care more about it...
            ver = np.random.randint(-5, 5)
            hor = np.random.randint(-15, 15)
            rad = self.opt.cam_radius  # np.random.uniform(1, 2)

            pose = orbit_camera(ver, hor, rad)

            image_gt, alpha_gt = self.render_gs(pose)
            image_pred, alpha_pred = self.render_mesh(pose)

            loss_mse = F.mse_loss(image_pred, image_gt)
            loss = loss_mse

            loss.backward()

            optimizer.step()
            optimizer.zero_grad()

            pbar.set_description(f"MSE = {loss_mse.item():.6f}")

        print("[INFO] finished fitting mesh texture!")

    @torch.no_grad()
    def export_mesh(self, path):
        """Write the textured mesh to ``path``.

        Reads ``self.albedo``, ``self.vt`` and ``self.ft``, which are set by
        ``fit_mesh_uv``; ``torch.sigmoid(self.albedo)`` fails while
        ``self.albedo`` is still ``None``.

        Args:
            path: Output file path, passed to ``Mesh.write``.
        """

        mesh = Mesh(
            v=self.v,
            f=self.f,
            vt=self.vt,
            ft=self.ft,
            albedo=torch.sigmoid(self.albedo),
            device=self.device,
        )
        mesh.auto_normal()
        mesh.write(path)
|
| 535 |
+
|
| 536 |
+
|
| 537 |
+
# Script entry: parse the command-line options, then run the full
# gaussians -> nerf -> mesh -> textured mesh pipeline on one .ply file.
opt = tyro.cli(AllConfigs)

# load a saved ply and convert to mesh
# Validate with an explicit raise rather than `assert`, which is stripped
# when Python runs with -O and would let a wrong file type through.
if not opt.test_path.endswith(".ply"):
    raise ValueError("--test_path must be a .ply file saved by infer.py")

converter = Converter(opt).cuda()
converter.fit_nerf()
converter.fit_mesh()
converter.fit_mesh_uv()
# The output path is the input path with ".ply" replaced by ".glb"; callers
# that look for the result (e.g. app.py) rely on this exact rule.
converter.export_mesh(opt.test_path.replace(".ply", ".glb"))
|
core/__init__.py
ADDED
|
File without changes
|
core/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (135 Bytes). View file
|
|
|
core/__pycache__/gs.cpython-310.pyc
ADDED
|
Binary file (5.45 kB). View file
|
|
|
core/__pycache__/options.cpython-310.pyc
ADDED
|
Binary file (2.49 kB). View file
|
|
|
core/attention.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# References:
|
| 7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
| 9 |
+
|
| 10 |
+
import os
|
| 11 |
+
import warnings
|
| 12 |
+
|
| 13 |
+
from torch import Tensor
|
| 14 |
+
from torch import nn
|
| 15 |
+
|
| 16 |
+
# Opt-out switch: xFormers is attempted unless the XFORMERS_DISABLED
# environment variable is set (to any value, including an empty string).
XFORMERS_ENABLED = "XFORMERS_DISABLED" not in os.environ

# Becomes True only when xFormers was both enabled and importable.
XFORMERS_AVAILABLE = False

if XFORMERS_ENABLED:
    try:
        from xformers.ops import memory_efficient_attention, unbind
    except ImportError:
        warnings.warn("xFormers is not available (Attention)")
    else:
        XFORMERS_AVAILABLE = True
        warnings.warn("xFormers is available (Attention)")
else:
    # Both messages are emitted when xFormers is explicitly disabled.
    warnings.warn("xFormers is disabled (Attention)")
    warnings.warn("xFormers is not available (Attention)")
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class Attention(nn.Module):
    """Multi-head self-attention with a fused QKV projection.

    Args:
        dim: channel width of the input and output tokens.
        num_heads: number of attention heads; each head gets
            ``dim // num_heads`` channels.
        qkv_bias: whether the fused QKV projection has a bias.
        proj_bias: whether the output projection has a bias.
        attn_drop: dropout probability applied to the attention weights.
        proj_drop: dropout probability applied after the output projection.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        # Softmax temperature: 1 / sqrt(per-head channel count).
        self.scale = (dim // num_heads) ** -0.5

        # Sub-modules are created in this order on purpose: it determines both
        # the parameter initialisation order and the state_dict key order.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: Tensor) -> Tensor:
        """Apply self-attention to ``x`` of shape [B, N, C]; returns [B, N, C]."""
        batch, tokens, channels = x.shape
        per_head = channels // self.num_heads

        # [B, N, 3*C] -> [B, N, 3, heads, C/heads] -> [3, B, heads, N, C/heads]
        packed = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, per_head)
        query, key, value = packed.permute(2, 0, 3, 1, 4).unbind(0)
        query = query * self.scale

        weights = (query @ key.transpose(-2, -1)).softmax(dim=-1)
        weights = self.attn_drop(weights)

        # [B, heads, N, C/heads] -> [B, N, heads, C/heads] -> [B, N, C]
        merged = (weights @ value).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(merged))
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class MemEffAttention(Attention):
    """Self-attention that uses xFormers' fused kernel when it is available.

    Falls back to the parent ``Attention.forward`` when ``XFORMERS_AVAILABLE``
    is false; in that case ``attn_bias`` is not supported.
    """

    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
        """Attend over ``x`` of shape [B, N, C]; returns [B, N, C].

        Raises:
            AssertionError: if ``attn_bias`` is given but xFormers is missing.
        """
        if XFORMERS_AVAILABLE:
            batch, tokens, channels = x.shape
            # xFormers takes [B, N, heads, C/heads], so no head permutation here.
            packed = self.qkv(x).reshape(
                batch, tokens, 3, self.num_heads, channels // self.num_heads
            )
            query, key, value = unbind(packed, 2)

            # Note: this path does not call self.attn_drop.
            attended = memory_efficient_attention(query, key, value, attn_bias=attn_bias)
            merged = attended.reshape([batch, tokens, channels])
            return self.proj_drop(self.proj(merged))

        if attn_bias is not None:
            raise AssertionError("xFormers is required for using nested tensors")
        return super().forward(x)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class CrossAttention(nn.Module):
    """Multi-head cross-attention with separate Q/K/V input widths.

    Args:
        dim: internal and output channel width.
        dim_q: channel width of the query input.
        dim_k: channel width of the key input.
        dim_v: channel width of the value input.
        num_heads: number of heads; each gets ``dim // num_heads`` channels.
        qkv_bias: whether the Q/K/V projections have a bias.
        proj_bias: whether the output projection has a bias.
        attn_drop: dropout probability applied to the attention weights.
        proj_drop: dropout probability applied after the output projection.
    """

    def __init__(
        self,
        dim: int,
        dim_q: int,
        dim_k: int,
        dim_v: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        # Softmax temperature: 1 / sqrt(per-head channel count).
        self.scale = (dim // num_heads) ** -0.5

        # Creation order fixes parameter init order and state_dict key order.
        self.to_q = nn.Linear(dim_q, dim, bias=qkv_bias)
        self.to_k = nn.Linear(dim_k, dim, bias=qkv_bias)
        self.to_v = nn.Linear(dim_v, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        """Attend from ``q`` [B, N, Cq] over ``k`` [B, M, Ck] / ``v`` [B, M, Cv].

        Returns a tensor of shape [B, N, C] with ``C == dim``.
        """
        batch, n_query, _ = q.shape
        n_context = k.shape[1]
        heads = self.num_heads
        per_head = self.dim // heads

        # Project, then split channels into heads: [B, L, C] -> [B, heads, L, C/heads]
        query = self.to_q(q).reshape(batch, n_query, heads, per_head).permute(0, 2, 1, 3)
        query = self.scale * query
        key = self.to_k(k).reshape(batch, n_context, heads, per_head).permute(0, 2, 1, 3)
        value = self.to_v(v).reshape(batch, n_context, heads, per_head).permute(0, 2, 1, 3)

        weights = (query @ key.transpose(-2, -1)).softmax(dim=-1)  # [B, heads, N, M]
        weights = self.attn_drop(weights)

        # [B, heads, N, C/heads] -> [B, N, heads, C/heads] -> [B, N, C]
        merged = (weights @ value).transpose(1, 2).reshape(batch, n_query, -1)
        return self.proj_drop(self.proj(merged))
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
class MemEffCrossAttention(CrossAttention):
    """Cross-attention that uses xFormers' fused kernel when it is available.

    Falls back to the parent ``CrossAttention.forward`` when
    ``XFORMERS_AVAILABLE`` is false; ``attn_bias`` is unsupported in that case.
    """

    def forward(self, q: Tensor, k: Tensor, v: Tensor, attn_bias=None) -> Tensor:
        """Attend from ``q`` [B, N, Cq] over ``k`` [B, M, Ck] / ``v`` [B, M, Cv].

        Returns:
            Tensor of shape [B, N, C] with ``C == self.dim``.

        Raises:
            AssertionError: if ``attn_bias`` is given but xFormers is missing.
        """
        if not XFORMERS_AVAILABLE:
            if attn_bias is not None:
                raise AssertionError("xFormers is required for using nested tensors")
            # The fallback must forward all three inputs; the previous
            # ``super().forward(x)`` referenced an undefined name and raised
            # NameError whenever xFormers was unavailable.
            return super().forward(q, k, v)

        B, N, _ = q.shape
        M = k.shape[1]
        head_dim = self.dim // self.num_heads

        # xFormers takes [B, L, heads, C/heads], so no head permutation here.
        # NOTE(review): q is pre-multiplied by self.scale, and xFormers'
        # memory_efficient_attention applies its own default softmax scale per
        # its documentation, so this path may scale differently from the
        # fallback. Left untouched because trained weights depend on it -
        # confirm before changing.
        q = self.scale * self.to_q(q).reshape(B, N, self.num_heads, head_dim)  # [B, N, nh, C/nh]
        k = self.to_k(k).reshape(B, M, self.num_heads, head_dim)  # [B, M, nh, C/nh]
        v = self.to_v(v).reshape(B, M, self.num_heads, head_dim)  # [B, M, nh, C/nh]

        # Note: this path does not call self.attn_drop.
        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
        x = x.reshape(B, N, -1)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
|
core/gs.py
ADDED
|
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
|
| 7 |
+
from diff_gaussian_rasterization import (
|
| 8 |
+
GaussianRasterizationSettings,
|
| 9 |
+
GaussianRasterizer,
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
from core.options import Options
|
| 13 |
+
|
| 14 |
+
import kiui
|
| 15 |
+
|
| 16 |
+
class GaussianRenderer:
    """Render, save and load batches of 3D Gaussians.

    Gaussians are stored as a ``[B, N, 14]`` tensor whose last axis is laid
    out as: position (0:3), opacity (3:4), scale (4:7), rotation (7:11) and
    RGB colour (11:14).
    """

    def __init__(self, opt: Options):
        """Store the options and precompute the camera intrinsics.

        Args:
            opt: options providing ``fovy``, ``znear``, ``zfar`` and
                ``output_size``.
        """
        self.opt = opt
        # Default background is white; note it is created on the "cuda" device.
        self.bg_color = torch.tensor([1, 1, 1], dtype=torch.float32, device="cuda")

        # intrinsics
        self.tan_half_fov = np.tan(0.5 * np.deg2rad(self.opt.fovy))
        self.proj_matrix = torch.zeros(4, 4, dtype=torch.float32)
        self.proj_matrix[0, 0] = 1 / self.tan_half_fov
        self.proj_matrix[1, 1] = 1 / self.tan_half_fov
        self.proj_matrix[2, 2] = (opt.zfar + opt.znear) / (opt.zfar - opt.znear)
        self.proj_matrix[3, 2] = - (opt.zfar * opt.znear) / (opt.zfar - opt.znear)
        self.proj_matrix[2, 3] = 1

    def render(self, gaussians, cam_view, cam_view_proj, cam_pos, bg_color=None, scale_modifier=1):
        """Rasterize every (batch, view) pair.

        Args:
            gaussians: ``[B, N, 14]`` Gaussian parameters.
            cam_view: ``[B, V, 4, 4]`` view matrices.
            cam_view_proj: ``[B, V, 4, 4]`` view-projection matrices.
            cam_pos: ``[B, V, 3]`` camera positions.
            bg_color: optional background colour; ``self.bg_color`` if None.
            scale_modifier: forwarded to the rasterizer settings.

        Returns:
            dict with ``"image"`` ``[B, V, 3, H, W]`` (clamped to [0, 1]) and
            ``"alpha"`` ``[B, V, 1, H, W]``, where ``H == W == opt.output_size``.
        """
        device = gaussians.device
        B, V = cam_view.shape[:2]
        size = self.opt.output_size
        background = self.bg_color if bg_color is None else bg_color

        # The rasterizer handles one camera at a time, hence the nested loop.
        images = []
        alphas = []
        for b in range(B):

            # pos, opacity, scale, rotation, rgb
            means3D = gaussians[b, :, 0:3].contiguous().float()
            opacity = gaussians[b, :, 3:4].contiguous().float()
            scales = gaussians[b, :, 4:7].contiguous().float()
            rotations = gaussians[b, :, 7:11].contiguous().float()
            rgbs = gaussians[b, :, 11:].contiguous().float()  # [N, 3]

            for v in range(V):

                # render novel views
                view_matrix = cam_view[b, v].float()
                view_proj_matrix = cam_view_proj[b, v].float()
                campos = cam_pos[b, v].float()

                raster_settings = GaussianRasterizationSettings(
                    image_height=size,
                    image_width=size,
                    tanfovx=self.tan_half_fov,
                    tanfovy=self.tan_half_fov,
                    bg=background,
                    scale_modifier=scale_modifier,
                    viewmatrix=view_matrix,
                    projmatrix=view_proj_matrix,
                    sh_degree=0,
                    campos=campos,
                    prefiltered=False,
                    debug=False,
                )

                rasterizer = GaussianRasterizer(raster_settings=raster_settings)

                # Rasterize visible Gaussians to image, obtain their radii (on screen).
                # Colours are passed precomputed (colors_precomp), so shs is None.
                rendered_image, _radii, _rendered_depth, rendered_alpha = rasterizer(
                    means3D=means3D,
                    means2D=torch.zeros_like(means3D, dtype=torch.float32, device=device),
                    shs=None,
                    colors_precomp=rgbs,
                    opacities=opacity,
                    scales=scales,
                    rotations=rotations,
                    cov3D_precomp=None,
                )

                images.append(rendered_image.clamp(0, 1))
                alphas.append(rendered_alpha)

        images = torch.stack(images, dim=0).view(B, V, 3, size, size)
        alphas = torch.stack(alphas, dim=0).view(B, V, 1, size, size)

        return {
            "image": images,  # [B, V, 3, H, W]
            "alpha": alphas,  # [B, V, 1, H, W]
        }

    def save_ply(self, gaussians, path, compatible=True):
        """Write a single batch of Gaussians to a PLY file.

        Gaussians with opacity below 0.005 are dropped before writing.

        Args:
            gaussians: ``[1, N, 14]`` Gaussian parameters (batch size must be 1).
            path: output file path.
            compatible: if True, store pre-activation values as in the original
                Gaussian-splatting PLY layout (inverse-sigmoid opacity, log
                scale, colour converted with the 0.28209479177387814 factor).
        """
        assert gaussians.shape[0] == 1, 'only support batch size 1'

        from plyfile import PlyData, PlyElement

        means3D = gaussians[0, :, 0:3].contiguous().float()
        opacity = gaussians[0, :, 3:4].contiguous().float()
        scales = gaussians[0, :, 4:7].contiguous().float()
        rotations = gaussians[0, :, 7:11].contiguous().float()
        shs = gaussians[0, :, 11:].unsqueeze(1).contiguous().float()  # [N, 1, 3]

        # prune by opacity
        mask = opacity.squeeze(-1) >= 0.005
        means3D = means3D[mask]
        opacity = opacity[mask]
        scales = scales[mask]
        rotations = rotations[mask]
        shs = shs[mask]

        # invert activation to make it compatible with the original ply format
        if compatible:
            opacity = kiui.op.inverse_sigmoid(opacity)
            scales = torch.log(scales + 1e-8)  # epsilon guards against log(0)
            shs = (shs - 0.5) / 0.28209479177387814

        xyzs = means3D.detach().cpu().numpy()
        f_dc = shs.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
        opacities = opacity.detach().cpu().numpy()
        scales = scales.detach().cpu().numpy()
        rotations = rotations.detach().cpu().numpy()

        # Column names, in the same order as the concatenated attribute array.
        field_names = ['x', 'y', 'z']
        # The DC colour channels (one column per channel).
        field_names += ['f_dc_{}'.format(i) for i in range(f_dc.shape[1])]
        field_names.append('opacity')
        field_names += ['scale_{}'.format(i) for i in range(scales.shape[1])]
        field_names += ['rot_{}'.format(i) for i in range(rotations.shape[1])]

        dtype_full = [(name, 'f4') for name in field_names]

        elements = np.empty(xyzs.shape[0], dtype=dtype_full)
        attributes = np.concatenate((xyzs, f_dc, opacities, scales, rotations), axis=1)
        elements[:] = list(map(tuple, attributes))
        el = PlyElement.describe(elements, 'vertex')

        PlyData([el]).write(path)

    @staticmethod
    def _indexed_property_names(vertex, prefix):
        """Return the property names starting with ``prefix``, ordered by numeric suffix."""
        names = [p.name for p in vertex.properties if p.name.startswith(prefix)]
        return sorted(names, key=lambda name: int(name.split("_")[-1]))

    def load_ply(self, path, compatible=True):
        """Read Gaussians from a PLY file.

        Args:
            path: PLY file to read; the first element is used.
            compatible: if True, apply the activations that undo the
                ``compatible`` transform of ``save_ply`` (sigmoid on opacity,
                exp on scale, colour conversion).

        Returns:
            CPU float tensor of shape ``[N, 14]`` (no batch axis).
        """
        from plyfile import PlyData

        plydata = PlyData.read(path)
        vertex = plydata.elements[0]

        xyz = np.stack((np.asarray(vertex["x"]),
                        np.asarray(vertex["y"]),
                        np.asarray(vertex["z"])), axis=1)
        print("Number of points at loading : ", xyz.shape[0])

        opacities = np.asarray(vertex["opacity"])[..., np.newaxis]

        shs = np.zeros((xyz.shape[0], 3))
        shs[:, 0] = np.asarray(vertex["f_dc_0"])
        shs[:, 1] = np.asarray(vertex["f_dc_1"])
        shs[:, 2] = np.asarray(vertex["f_dc_2"])

        # Order columns by their numeric suffix rather than by the order the
        # properties happen to appear in the file, so that the fixed column
        # slices used below always address the intended components.
        scale_names = self._indexed_property_names(vertex, "scale_")
        scales = np.zeros((xyz.shape[0], len(scale_names)))
        for idx, attr_name in enumerate(scale_names):
            scales[:, idx] = np.asarray(vertex[attr_name])

        rot_names = self._indexed_property_names(vertex, "rot_")
        rots = np.zeros((xyz.shape[0], len(rot_names)))
        for idx, attr_name in enumerate(rot_names):
            rots[:, idx] = np.asarray(vertex[attr_name])

        gaussians = np.concatenate([xyz, opacities, scales, rots, shs], axis=1)
        gaussians = torch.from_numpy(gaussians).float()  # cpu

        if compatible:
            gaussians[..., 3:4] = torch.sigmoid(gaussians[..., 3:4])
            gaussians[..., 4:7] = torch.exp(gaussians[..., 4:7])
            gaussians[..., 11:] = 0.28209479177387814 * gaussians[..., 11:] + 0.5

        return gaussians
|
core/models.py
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
import kiui
|
| 7 |
+
from kiui.lpips import LPIPS
|
| 8 |
+
|
| 9 |
+
from core.unet import UNet
|
| 10 |
+
from core.options import Options
|
| 11 |
+
from core.gs import GaussianRenderer
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class LGM(nn.Module):
    """U-Net that predicts per-pixel 3D Gaussians from multi-view inputs.

    ``forward_gaussians`` maps input views to a ``[B, N, 14]`` Gaussian tensor
    (position, opacity, scale, rotation, rgb); ``forward`` additionally renders
    them with ``GaussianRenderer`` and computes the training losses.
    """

    def __init__(
        self,
        opt: Options,
    ):
        super().__init__()

        self.opt = opt

        # unet (presumably 9 input channels and 14 output channels, matching
        # the tensor shapes noted in forward_gaussians - confirm against UNet)
        self.unet = UNet(
            9, 14,
            down_channels=self.opt.down_channels,
            down_attention=self.opt.down_attention,
            mid_attention=self.opt.mid_attention,
            up_channels=self.opt.up_channels,
            up_attention=self.opt.up_attention,
        )

        # last conv
        self.conv = nn.Conv2d(14, 14, kernel_size=1)  # NOTE: maybe remove it if train again

        # Gaussian Renderer
        self.gs = GaussianRenderer(opt)

        # activations applied to the raw 14-channel prediction
        self.pos_act = lambda x: x.clamp(-1, 1)
        self.scale_act = lambda x: 0.1 * F.softplus(x)
        self.opacity_act = lambda x: torch.sigmoid(x)
        self.rot_act = F.normalize
        self.rgb_act = lambda x: 0.5 * torch.tanh(x) + 0.5  # NOTE: may use sigmoid if train again

        # LPIPS loss (frozen; only built when it contributes to the loss)
        if self.opt.lambda_lpips > 0:
            self.lpips_loss = LPIPS(net='vgg')
            self.lpips_loss.requires_grad_(False)

    def state_dict(self, **kwargs):
        """Return the module state dict without any ``lpips_loss`` entries."""
        state_dict = super().state_dict(**kwargs)
        for k in list(state_dict.keys()):
            if 'lpips_loss' in k:
                del state_dict[k]
        return state_dict

    def prepare_default_rays(self, device, elevation=0):
        """Build ray embeddings for four cameras at azimuth 0/90/180/270 degrees.

        Args:
            device: device the returned tensor is moved to.
            elevation: elevation passed to ``orbit_camera`` for all cameras.

        Returns:
            Tensor of shape ``[V, 6, h, w]`` with ``V == 4``; the 6 channels
            are ``cross(rays_o, rays_d)`` followed by ``rays_d``.
        """
        from kiui.cam import orbit_camera
        from core.utils import get_rays

        cam_poses = np.stack([
            orbit_camera(elevation, azimuth, radius=self.opt.cam_radius)
            for azimuth in (0, 90, 180, 270)
        ], axis=0)  # [4, 4, 4]
        cam_poses = torch.from_numpy(cam_poses)

        rays_embeddings = []
        for i in range(cam_poses.shape[0]):
            rays_o, rays_d = get_rays(cam_poses[i], self.opt.input_size, self.opt.input_size, self.opt.fovy)  # [h, w, 3]
            rays_plucker = torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d], dim=-1)  # [h, w, 6]
            rays_embeddings.append(rays_plucker)

            ## visualize rays for plotting figure
            # kiui.vis.plot_image(rays_d * 0.5 + 0.5, save=True)

        rays_embeddings = torch.stack(rays_embeddings, dim=0).permute(0, 3, 1, 2).contiguous().to(device)  # [V, 6, h, w]

        return rays_embeddings

    def forward_gaussians(self, images):
        """Predict Gaussians from the input views.

        Args:
            images: ``[B, V, 9, H, W]`` input features (V is 4 in this setup).

        Returns:
            ``[B, N, 14]`` activated Gaussians, where
            ``N == V * splat_size * splat_size``.
        """
        B, V, C, H, W = images.shape
        images = images.view(B*V, C, H, W)

        x = self.unet(images)  # [B*V, 14, h, w]
        x = self.conv(x)  # [B*V, 14, h, w]

        # Use the actual view count instead of a hard-coded 4 so the reshape
        # cannot silently mix samples if V ever differs.
        x = x.reshape(B, V, 14, self.opt.splat_size, self.opt.splat_size)

        ## visualize multi-view gaussian features for plotting figure
        # tmp_alpha = self.opacity_act(x[0, :, 3:4])
        # tmp_img_rgb = self.rgb_act(x[0, :, 11:]) * tmp_alpha + (1 - tmp_alpha)
        # tmp_img_pos = self.pos_act(x[0, :, 0:3]) * 0.5 + 0.5
        # kiui.vis.plot_image(tmp_img_rgb, save=True)
        # kiui.vis.plot_image(tmp_img_pos, save=True)

        # Channels last, then flatten views and pixels into one Gaussian axis.
        x = x.permute(0, 1, 3, 4, 2).reshape(B, -1, 14)

        pos = self.pos_act(x[..., 0:3])  # [B, N, 3]
        opacity = self.opacity_act(x[..., 3:4])
        scale = self.scale_act(x[..., 4:7])
        rotation = self.rot_act(x[..., 7:11])
        rgbs = self.rgb_act(x[..., 11:])

        gaussians = torch.cat([pos, opacity, scale, rotation, rgbs], dim=-1)  # [B, N, 14]

        return gaussians

    def forward(self, data, step_ratio=1):
        """Predict Gaussians, render them and compute the losses.

        Args:
            data: dataloader output with keys ``input``, ``cam_view``,
                ``cam_view_proj``, ``cam_pos``, ``images_output`` and
                ``masks_output``.
            step_ratio: unused in this method.

        Returns:
            dict with ``gaussians``, ``image``, ``alpha``, ``images_pred``,
            ``alphas_pred``, ``loss``, ``psnr`` and, when LPIPS is enabled,
            ``loss_lpips``.
        """
        results = {}
        loss = 0

        images = data['input']  # [B, 4, 9, h, W], input features

        # predict gaussians from all input views
        gaussians = self.forward_gaussians(images)  # [B, N, 14]

        results['gaussians'] = gaussians

        # random bg for training
        if self.training:
            bg_color = torch.rand(3, dtype=torch.float32, device=gaussians.device)
        else:
            bg_color = torch.ones(3, dtype=torch.float32, device=gaussians.device)

        # Render the supervision views. Merge into `results` rather than
        # rebinding it, which previously discarded the 'gaussians' entry.
        results.update(self.gs.render(gaussians, data['cam_view'], data['cam_view_proj'], data['cam_pos'], bg_color=bg_color))
        pred_images = results['image']  # [B, V, C, output_size, output_size]
        pred_alphas = results['alpha']  # [B, V, 1, output_size, output_size]

        results['images_pred'] = pred_images
        results['alphas_pred'] = pred_alphas

        gt_images = data['images_output']  # [B, V, 3, output_size, output_size], ground-truth novel views
        gt_masks = data['masks_output']  # [B, V, 1, output_size, output_size], ground-truth masks

        # Composite the ground truth over the same background used for rendering.
        gt_images = gt_images * gt_masks + bg_color.view(1, 1, 3, 1, 1) * (1 - gt_masks)

        loss_mse = F.mse_loss(pred_images, gt_images) + F.mse_loss(pred_alphas, gt_masks)
        loss = loss + loss_mse

        if self.opt.lambda_lpips > 0:
            loss_lpips = self.lpips_loss(
                # gt_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1,
                # pred_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1,
                # resized to 256x256 (inputs rescaled to [-1, 1]) to reduce memory cost
                F.interpolate(gt_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1, (256, 256), mode='bilinear', align_corners=False),
                F.interpolate(pred_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1, (256, 256), mode='bilinear', align_corners=False),
            ).mean()
            results['loss_lpips'] = loss_lpips
            loss = loss + self.opt.lambda_lpips * loss_lpips

        results['loss'] = loss

        # metric
        with torch.no_grad():
            psnr = -10 * torch.log10(torch.mean((pred_images.detach() - gt_images) ** 2))
            results['psnr'] = psnr

        return results
|
core/options.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tyro
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from typing import Tuple, Literal, Dict, Optional
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
# Options: every tunable setting, grouped into model, dataset, training,
# testing and misc sections. Field order and defaults are part of the
# constructor interface. The per-field comments are kept verbatim because the
# CLI layer may surface them as help text (assumption - confirm against tyro).
@dataclass
class Options:
    ### model
    # Unet image input size
    input_size: int = 256
    # Unet definition
    down_channels: Tuple[int, ...] = (64, 128, 256, 512, 1024, 1024)
    down_attention: Tuple[bool, ...] = (False, False, False, True, True, True)
    mid_attention: bool = True
    up_channels: Tuple[int, ...] = (1024, 1024, 512, 256)
    up_attention: Tuple[bool, ...] = (True, True, True, False)
    # Unet output size, dependent on the input_size and U-Net structure!
    splat_size: int = 64
    # gaussian render size
    output_size: int = 256

    ### dataset
    # data mode (only support s3 now)
    data_mode: Literal['s3'] = 's3'
    # fovy of the dataset
    fovy: float = 49.1
    # camera near plane
    znear: float = 0.5
    # camera far plane
    zfar: float = 2.5
    # number of all views (input + output)
    num_views: int = 12
    # number of views
    num_input_views: int = 4
    # camera radius
    cam_radius: float = 1.5 # to better use [-1, 1]^3 space
    # num workers
    num_workers: int = 8

    ### training
    # workspace
    workspace: str = './workspace'
    # resume
    resume: Optional[str] = None
    # batch size (per-GPU)
    batch_size: int = 8
    # gradient accumulation
    gradient_accumulation_steps: int = 1
    # training epochs
    num_epochs: int = 30
    # lpips loss weight
    lambda_lpips: float = 1.0
    # gradient clip
    gradient_clip: float = 1.0
    # mixed precision
    mixed_precision: str = 'bf16'
    # learning rate
    lr: float = 4e-4
    # augmentation prob for grid distortion
    prob_grid_distortion: float = 0.5
    # augmentation prob for camera jitter
    prob_cam_jitter: float = 0.5

    ### testing
    # test image path
    test_path: Optional[str] = None

    ### misc
    # nvdiffrast backend setting
    force_cuda_rast: bool = False
    # render fancy video with gaussian scaling effect
    fancy_video: bool = False
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
# Named presets. Each preset appears under the same key in `config_doc`
# (one-line description) and `config_defaults` (the Options instance); both
# dicts are handed to tyro to build the subcommand type.
config_doc: Dict[str, str] = {
    'lrm': 'the default settings for LGM',
    'small': 'small model with lower resolution Gaussians',
    'big': 'big model with higher resolution Gaussians',
    'tiny': 'tiny model for ablation',
}

config_defaults: Dict[str, Options] = {
    'lrm': Options(),
    'small': Options(
        input_size=256,
        splat_size=64,
        output_size=256,
        batch_size=8,
        gradient_accumulation_steps=1,
        mixed_precision='bf16',
    ),
    'big': Options(
        input_size=256,
        up_channels=(1024, 1024, 512, 256, 128), # one more decoder
        up_attention=(True, True, True, False, False),
        splat_size=128,
        output_size=512, # render & supervise Gaussians at a higher resolution.
        batch_size=8,
        num_views=8,
        gradient_accumulation_steps=1,
        mixed_precision='bf16',
    ),
    'tiny': Options(
        input_size=256,
        down_channels=(32, 64, 128, 256, 512),
        down_attention=(False, False, False, False, True),
        up_channels=(512, 256, 128),
        # NOTE(review): four flags for three up_channels entries - the extra
        # one looks unused; confirm against the UNet before trimming.
        up_attention=(True, False, False, False),
        splat_size=64,
        output_size=256,
        batch_size=16,
        num_views=8,
        gradient_accumulation_steps=1,
        mixed_precision='bf16',
    ),
}

AllConfigs = tyro.extras.subcommand_type_from_defaults(config_defaults, config_doc)
|
core/provider_objaverse.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import cv2
|
| 3 |
+
import random
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
import torchvision.transforms.functional as TF
|
| 10 |
+
from torch.utils.data import Dataset
|
| 11 |
+
|
| 12 |
+
import kiui
|
| 13 |
+
from core.options import Options
|
| 14 |
+
from core.utils import get_rays, grid_distortion, orbit_camera_jitter
|
| 15 |
+
|
| 16 |
+
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
|
| 17 |
+
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class ObjaverseDataset(Dataset):
    """Example multi-view dataset for rendered Objaverse objects.

    This class is a template: ``__init__`` deliberately raises through
    ``_warn`` and the loading code relies on a ``self.client`` object that
    is never created here. Search for the keyword TODO and adapt those
    places to your own storage and rendering setup before use.

    Each item is a dict with the keys ``input``, ``images_output``,
    ``masks_output``, ``cam_view``, ``cam_view_proj`` and ``cam_pos``.
    """

    def _warn(self):
        """Abort construction until the template has been customized."""
        raise NotImplementedError('this dataset is just an example and cannot be used directly, you should modify it to your own setting! (search keyword TODO)')

    def __init__(self, opt: "Options", training=True):
        """Store the options, read the object list and build the projection.

        Args:
            opt: option object; this class reads ``batch_size``, ``fovy``,
                ``znear``, ``zfar``, ``num_views``, ``num_input_views``,
                ``input_size``, ``output_size``, ``cam_radius``,
                ``prob_grid_distortion`` and ``prob_cam_jitter`` from it.
            training: selects the training split and enables augmentation.
        """
        self.opt = opt
        self.training = training

        # TODO: remove this barrier
        self._warn()

        # TODO: load the list of objects for training
        with open('TODO: file containing the list', 'r') as f:
            self.items = [line.strip() for line in f]

        # naive split: the last `batch_size` entries are held out for evaluation
        if self.training:
            self.items = self.items[:-self.opt.batch_size]
        else:
            self.items = self.items[-self.opt.batch_size:]

        # default camera intrinsics
        self.tan_half_fov = np.tan(0.5 * np.deg2rad(self.opt.fovy))
        self.proj_matrix = torch.zeros(4, 4, dtype=torch.float32)
        self.proj_matrix[0, 0] = 1 / self.tan_half_fov
        self.proj_matrix[1, 1] = 1 / self.tan_half_fov
        self.proj_matrix[2, 2] = (self.opt.zfar + self.opt.znear) / (self.opt.zfar - self.opt.znear)
        self.proj_matrix[3, 2] = - (self.opt.zfar * self.opt.znear) / (self.opt.zfar - self.opt.znear)
        self.proj_matrix[2, 3] = 1

    def __len__(self):
        """Return the number of objects in the active split."""
        return len(self.items)

    def __getitem__(self, idx):
        """Load ``opt.num_views`` views of one object and build model inputs.

        Views that fail to load are skipped. If fewer than ``opt.num_views``
        load, the last valid view is repeated to pad the sample.

        Raises:
            RuntimeError: if not a single view of the object could be loaded.
        """
        uid = self.items[idx]
        results = {}

        # load num_views images
        images = []
        masks = []
        cam_poses = []

        vid_cnt = 0

        # TODO: choose views, based on your rendering settings
        if self.training:
            # input views are in (36, 72), other views are randomly selected
            vids = np.random.permutation(np.arange(36, 73))[:self.opt.num_input_views].tolist() + np.random.permutation(100).tolist()
        else:
            # fixed views
            vids = np.arange(36, 73, 4).tolist() + np.arange(100).tolist()

        for vid in vids:

            image_path = os.path.join(uid, 'rgb', f'{vid:03d}.png')
            camera_path = os.path.join(uid, 'pose', f'{vid:03d}.txt')

            try:
                # TODO: load data (modify self.client here)
                image = np.frombuffer(self.client.get(image_path), np.uint8)
                image = torch.from_numpy(cv2.imdecode(image, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255)  # [512, 512, 4] in [0, 1]
                c2w = [float(t) for t in self.client.get(camera_path).decode().strip().split(' ')]
                c2w = torch.tensor(c2w, dtype=torch.float32).reshape(4, 4)
            except Exception:
                # best effort: a missing or corrupt view is skipped, the
                # shortage is reported (or raised) after the loop
                continue

            # TODO: you may have a different camera system
            # blender world + opencv cam --> opengl world & cam
            c2w[1] *= -1
            c2w[[1, 2]] = c2w[[2, 1]]
            c2w[:3, 1:3] *= -1  # invert up and forward direction

            # scale up radius to fully use the [-1, 1]^3 space!
            c2w[:3, 3] *= self.opt.cam_radius / 1.5  # 1.5 is the default scale

            image = image.permute(2, 0, 1)  # [4, 512, 512]
            mask = image[3:4]  # [1, 512, 512]
            image = image[:3] * mask + (1 - mask)  # [3, 512, 512], to white bg
            image = image[[2, 1, 0]].contiguous()  # bgr to rgb

            images.append(image)
            masks.append(mask.squeeze(0))
            cam_poses.append(c2w)

            vid_cnt += 1
            if vid_cnt == self.opt.num_views:
                break

        if vid_cnt == 0:
            # padding below needs at least one view to repeat; fail with a
            # clear message instead of an IndexError on an empty list
            raise RuntimeError(f'dataset {uid}: no valid views could be loaded')

        if vid_cnt < self.opt.num_views:
            print(f'[WARN] dataset {uid}: not enough valid views, only {vid_cnt} views found!')
            n = self.opt.num_views - vid_cnt
            images = images + [images[-1]] * n
            masks = masks + [masks[-1]] * n
            cam_poses = cam_poses + [cam_poses[-1]] * n

        images = torch.stack(images, dim=0)  # [V, C, H, W]
        masks = torch.stack(masks, dim=0)  # [V, H, W]
        cam_poses = torch.stack(cam_poses, dim=0)  # [V, 4, 4]

        # normalized camera feats as in paper (transform the first pose to a fixed position)
        transform = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, self.opt.cam_radius], [0, 0, 0, 1]], dtype=torch.float32) @ torch.inverse(cam_poses[0])
        cam_poses = transform.unsqueeze(0) @ cam_poses  # [V, 4, 4]

        images_input = F.interpolate(images[:self.opt.num_input_views].clone(), size=(self.opt.input_size, self.opt.input_size), mode='bilinear', align_corners=False)  # [V, C, H, W]
        cam_poses_input = cam_poses[:self.opt.num_input_views].clone()

        # data augmentation
        if self.training:
            # apply random grid distortion to simulate 3D inconsistency
            if random.random() < self.opt.prob_grid_distortion:
                images_input[1:] = grid_distortion(images_input[1:])
            # apply camera jittering (only to input!)
            if random.random() < self.opt.prob_cam_jitter:
                cam_poses_input[1:] = orbit_camera_jitter(cam_poses_input[1:])

        images_input = TF.normalize(images_input, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)

        # resize render ground-truth images, range still in [0, 1]
        results['images_output'] = F.interpolate(images, size=(self.opt.output_size, self.opt.output_size), mode='bilinear', align_corners=False)  # [V, C, output_size, output_size]
        results['masks_output'] = F.interpolate(masks.unsqueeze(1), size=(self.opt.output_size, self.opt.output_size), mode='bilinear', align_corners=False)  # [V, 1, output_size, output_size]

        # build rays for input views
        rays_embeddings = []
        for i in range(self.opt.num_input_views):
            rays_o, rays_d = get_rays(cam_poses_input[i], self.opt.input_size, self.opt.input_size, self.opt.fovy)  # [h, w, 3]
            rays_plucker = torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d], dim=-1)  # [h, w, 6]
            rays_embeddings.append(rays_plucker)

        rays_embeddings = torch.stack(rays_embeddings, dim=0).permute(0, 3, 1, 2).contiguous()  # [V, 6, h, w]
        final_input = torch.cat([images_input, rays_embeddings], dim=1)  # [V=4, 9, H, W]
        results['input'] = final_input

        # opengl to colmap camera for gaussian renderer
        cam_poses[:, :3, 1:3] *= -1  # invert up & forward direction

        # cameras needed by gaussian rasterizer
        cam_view = torch.inverse(cam_poses).transpose(1, 2)  # [V, 4, 4]
        cam_view_proj = cam_view @ self.proj_matrix  # [V, 4, 4]
        cam_pos = - cam_poses[:, :3, 3]  # [V, 3]

        results['cam_view'] = cam_view
        results['cam_view_proj'] = cam_view_proj
        results['cam_pos'] = cam_pos

        return results
|
core/unet.py
ADDED
|
@@ -0,0 +1,319 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
from typing import Tuple, Literal
|
| 7 |
+
from functools import partial
|
| 8 |
+
|
| 9 |
+
from core.attention import MemEffAttention
|
| 10 |
+
|
| 11 |
+
class MVAttention(nn.Module):
    """Self-attention applied jointly over all views of a sample.

    The input is a batch of per-view feature maps laid out as
    ``[B*V, C, H, W]`` with ``V == num_frames``. The views of one sample
    are flattened into a single token sequence, passed through
    ``MemEffAttention`` and reshaped back to the input layout.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        groups: int = 32,
        eps: float = 1e-5,
        residual: bool = True,
        skip_scale: float = 1,
        num_frames: int = 4,  # WARN: hardcoded!
    ):
        """Build the group norm and the attention layer.

        Args:
            dim: channel count of the feature maps.
            num_heads, qkv_bias, proj_bias, attn_drop, proj_drop: forwarded
                positionally to ``MemEffAttention``.
            groups, eps: ``nn.GroupNorm`` settings.
            residual: add the input back to the attention output.
            skip_scale: factor applied to the residual sum.
            num_frames: number of views per sample in the batch dimension.
        """
        super().__init__()

        self.residual = residual
        self.skip_scale = skip_scale
        self.num_frames = num_frames

        self.norm = nn.GroupNorm(num_groups=groups, num_channels=dim, eps=eps, affine=True)
        self.attn = MemEffAttention(dim, num_heads, qkv_bias, proj_bias, attn_drop, proj_drop)

    def forward(self, x):
        """Attend across views.

        Args:
            x: tensor of shape ``[B*V, C, H, W]``.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: if the leading dimension is not a multiple of
                ``num_frames``.
        """
        BV, C, H, W = x.shape
        if BV % self.num_frames != 0:
            # fail early with a readable message instead of an opaque
            # reshape error further down
            raise ValueError(
                f'batch dimension {BV} is not divisible by num_frames={self.num_frames}'
            )
        B = BV // self.num_frames

        res = x
        x = self.norm(x)

        # [B*V, C, H, W] -> [B, V*H*W, C]: one token sequence per sample
        x = x.reshape(B, self.num_frames, C, H, W).permute(0, 1, 3, 4, 2).reshape(B, -1, C)
        x = self.attn(x)
        # back to [B*V, C, H, W]
        x = x.reshape(B, self.num_frames, H, W, C).permute(0, 1, 4, 2, 3).reshape(BV, C, H, W)

        if self.residual:
            x = (x + res) * self.skip_scale
        return x
|
| 50 |
+
|
| 51 |
+
class ResnetBlock(nn.Module):
    """Pre-activation residual block with optional 2x up/down resampling.

    Layout: GroupNorm -> SiLU -> (resample) -> 3x3 conv -> GroupNorm ->
    SiLU -> 3x3 conv, added to a shortcut of the (resampled) input and
    multiplied by ``skip_scale``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        resample: Literal['default', 'up', 'down'] = 'default',
        groups: int = 32,
        eps: float = 1e-5,
        skip_scale: float = 1,  # multiplied to output
    ):
        """Create the norm/conv layers, the resampler and the shortcut.

        Args:
            in_channels: channels of the input feature map.
            out_channels: channels of the output feature map.
            resample: 'up' for nearest 2x upsampling, 'down' for 2x average
                pooling; any other value leaves the resolution unchanged.
            groups, eps: ``nn.GroupNorm`` settings.
            skip_scale: factor applied to the residual sum.
        """
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.skip_scale = skip_scale

        self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        self.norm2 = nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        self.act = F.silu

        if resample == 'up':
            self.resample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
        elif resample == 'down':
            self.resample = nn.AvgPool2d(kernel_size=2, stride=2)
        else:
            self.resample = None

        # a 1x1 conv is only needed when the channel count changes
        if in_channels == out_channels:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=True)

    def forward(self, x):
        """Apply the block to ``x`` and return the scaled residual sum."""
        skip = x
        hidden = self.act(self.norm1(x))

        if self.resample is not None:
            # both branches must end up at the same resolution
            skip = self.resample(skip)
            hidden = self.resample(hidden)

        hidden = self.conv1(hidden)
        hidden = self.act(self.norm2(hidden))
        hidden = self.conv2(hidden)

        return (hidden + self.shortcut(skip)) * self.skip_scale
|
| 104 |
+
|
| 105 |
+
class DownBlock(nn.Module):
    """Encoder stage: ``num_layers`` ResnetBlocks, each optionally followed
    by multi-view attention, then an optional stride-2 convolution.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_layers: int = 1,
        downsample: bool = True,
        attention: bool = True,
        attention_heads: int = 16,
        skip_scale: float = 1,
    ):
        """Build the layers.

        Args:
            in_channels: channels fed to the first ResnetBlock.
            out_channels: channels produced by every layer of the stage.
            num_layers: number of ResnetBlock (+ attention) pairs.
            downsample: append a 3x3 stride-2 convolution.
            attention: follow each ResnetBlock with an ``MVAttention``.
            attention_heads: head count for ``MVAttention``.
            skip_scale: forwarded to the ResnetBlocks and attention layers.
        """
        super().__init__()

        resnets, attentions = [], []
        for layer_idx in range(num_layers):
            # only the first layer sees the stage's input width
            width_in = in_channels if layer_idx == 0 else out_channels
            resnets.append(ResnetBlock(width_in, out_channels, skip_scale=skip_scale))
            attentions.append(
                MVAttention(out_channels, attention_heads, skip_scale=skip_scale) if attention else None
            )
        self.nets = nn.ModuleList(resnets)
        self.attns = nn.ModuleList(attentions)

        self.downsample = (
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1)
            if downsample
            else None
        )

    def forward(self, x):
        """Run the stage.

        Returns:
            ``(x, xs)`` where ``x`` is the final output and ``xs`` lists the
            output of every layer (and of the downsampler, if present) in
            order.
        """
        collected = []

        for net, attn in zip(self.nets, self.attns):
            x = net(x)
            if attn is not None:
                x = attn(x)
            collected.append(x)

        if self.downsample is not None:
            x = self.downsample(x)
            collected.append(x)

        return x, collected
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class MidBlock(nn.Module):
    """Bottleneck stage: one ResnetBlock followed by ``num_layers`` pairs of
    (optional attention, ResnetBlock), all at a constant channel width.
    """

    def __init__(
        self,
        in_channels: int,
        num_layers: int = 1,
        attention: bool = True,
        attention_heads: int = 16,
        skip_scale: float = 1,
    ):
        """Build the layers.

        Args:
            in_channels: channel width kept throughout the stage.
            num_layers: number of (attention, ResnetBlock) pairs after the
                leading ResnetBlock.
            attention: insert an ``MVAttention`` before each later block.
            attention_heads: head count for ``MVAttention``.
            skip_scale: forwarded to the ResnetBlocks and attention layers.
        """
        super().__init__()

        # leading block, not paired with an attention layer
        resnets = [ResnetBlock(in_channels, in_channels, skip_scale=skip_scale)]
        attentions = []
        for _ in range(num_layers):
            resnets.append(ResnetBlock(in_channels, in_channels, skip_scale=skip_scale))
            attentions.append(
                MVAttention(in_channels, attention_heads, skip_scale=skip_scale) if attention else None
            )
        self.nets = nn.ModuleList(resnets)
        self.attns = nn.ModuleList(attentions)

    def forward(self, x):
        """Apply the leading block, then each attention/ResnetBlock pair."""
        x = self.nets[0](x)
        for attn, net in zip(self.attns, self.nets[1:]):
            if attn is not None:
                x = attn(x)
            x = net(x)
        return x
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
class UpBlock(nn.Module):
    """Decoder stage: each layer concatenates one skip tensor, applies a
    ResnetBlock and optional attention; the stage optionally ends with a
    nearest 2x upsampling followed by a 3x3 convolution.
    """

    def __init__(
        self,
        in_channels: int,
        prev_out_channels: int,
        out_channels: int,
        num_layers: int = 1,
        upsample: bool = True,
        attention: bool = True,
        attention_heads: int = 16,
        skip_scale: float = 1,
    ):
        """Build the layers.

        Args:
            in_channels: channels of the tensor entering the stage.
            prev_out_channels: channels of the skip tensor consumed by the
                last layer; earlier layers assume ``out_channels``.
            out_channels: channels produced by every layer of the stage.
            num_layers: number of ResnetBlock (+ attention) pairs.
            upsample: append the 2x upsampling + convolution.
            attention: follow each ResnetBlock with an ``MVAttention``.
            attention_heads: head count for ``MVAttention``.
            skip_scale: forwarded to the ResnetBlocks and attention layers.
        """
        super().__init__()

        resnets, attentions = [], []
        last = num_layers - 1
        for layer_idx in range(num_layers):
            width_in = in_channels if layer_idx == 0 else out_channels
            width_skip = prev_out_channels if layer_idx == last else out_channels

            resnets.append(ResnetBlock(width_in + width_skip, out_channels, skip_scale=skip_scale))
            attentions.append(
                MVAttention(out_channels, attention_heads, skip_scale=skip_scale) if attention else None
            )
        self.nets = nn.ModuleList(resnets)
        self.attns = nn.ModuleList(attentions)

        self.upsample = (
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
            if upsample
            else None
        )

    def forward(self, x, xs):
        """Run the stage.

        Args:
            x: tensor entering the stage.
            xs: skip tensors; they are consumed from the end, one per layer.
        """
        for depth, (attn, net) in enumerate(zip(self.attns, self.nets)):
            # layer `depth` takes the (depth + 1)-th skip counted from the back
            skip = xs[-1 - depth]
            x = net(torch.cat([x, skip], dim=1))
            if attn is not None:
                x = attn(x)

        if self.upsample is not None:
            x = F.interpolate(x, scale_factor=2.0, mode='nearest')
            x = self.upsample(x)

        return x
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
# it could be asymmetric!
|
| 234 |
+
class UNet(nn.Module):
    """U-Net built from DownBlock / MidBlock / UpBlock stages.

    The decoder may have fewer stages than the encoder, in which case the
    output resolution is lower than the input resolution.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_channels: Tuple[int, ...] = (64, 128, 256, 512, 1024),
        down_attention: Tuple[bool, ...] = (False, False, False, True, True),
        mid_attention: bool = True,
        up_channels: Tuple[int, ...] = (1024, 512, 256),
        up_attention: Tuple[bool, ...] = (True, True, False),
        layers_per_block: int = 2,
        skip_scale: float = np.sqrt(0.5),
    ):
        """Assemble the network.

        Args:
            in_channels: channels of the input tensor.
            out_channels: channels of the output tensor.
            down_channels: output width of each encoder stage.
            down_attention: per encoder stage, whether to use attention.
            mid_attention: whether the bottleneck uses attention.
            up_channels: output width of each decoder stage.
            up_attention: per decoder stage, whether to use attention.
            layers_per_block: layers per encoder stage; decoder stages get
                one more.
            skip_scale: residual scale forwarded to every stage.
        """
        super().__init__()

        # first
        self.conv_in = nn.Conv2d(in_channels, down_channels[0], kernel_size=3, stride=1, padding=1)

        # down
        encoder = []
        width_prev = down_channels[0]
        last_down = len(down_channels) - 1
        for stage, width in enumerate(down_channels):
            encoder.append(DownBlock(
                width_prev, width,
                num_layers=layers_per_block,
                downsample=(stage != last_down),  # not final layer
                attention=down_attention[stage],
                skip_scale=skip_scale,
            ))
            width_prev = width
        self.down_blocks = nn.ModuleList(encoder)

        # mid
        self.mid_block = MidBlock(down_channels[-1], attention=mid_attention, skip_scale=skip_scale)

        # up
        decoder = []
        width_prev = up_channels[0]
        last_up = len(up_channels) - 1
        for stage, width in enumerate(up_channels):
            # for asymmetric nets: clamp to the first encoder stage
            width_skip = down_channels[max(-2 - stage, -len(down_channels))]

            decoder.append(UpBlock(
                width_prev, width_skip, width,
                num_layers=layers_per_block + 1,  # one more layer for up
                upsample=(stage != last_up),  # not final layer
                attention=up_attention[stage],
                skip_scale=skip_scale,
            ))
            width_prev = width
        self.up_blocks = nn.ModuleList(decoder)

        # last
        self.norm_out = nn.GroupNorm(num_channels=up_channels[-1], num_groups=32, eps=1e-5)
        self.conv_out = nn.Conv2d(up_channels[-1], out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Map ``x`` of shape ``[B, Cin, H, W]`` to ``[B, Cout, H', W']``."""
        # first
        x = self.conv_in(x)

        # down: remember every intermediate output as a skip connection
        skips = [x]
        for down in self.down_blocks:
            x, produced = down(x)
            skips.extend(produced)

        # mid
        x = self.mid_block(x)

        # up: each stage consumes as many skips as it has layers
        for up in self.up_blocks:
            count = len(up.nets)
            stage_skips, skips = skips[-count:], skips[:-count]
            x = up(x, stage_skips)

        # last
        x = F.silu(self.norm_out(x))
        x = self.conv_out(x)  # [B, Cout, H', W']

        return x
|
core/utils.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
|
| 7 |
+
import roma
|
| 8 |
+
from kiui.op import safe_normalize
|
| 9 |
+
|
| 10 |
+
def get_rays(pose, h, w, fovy, opengl=True):
    """Compute per-pixel ray origins and unit directions for one camera.

    Args:
        pose: camera-to-world matrix; the top-left 3x3 is used as rotation
            and ``pose[:3, 3]`` as the camera position.
        h: image height in pixels.
        w: image width in pixels.
        fovy: vertical field of view in degrees.
        opengl: if True, flip the y and z components of the camera-space
            directions (sign -1); otherwise leave them positive.

    Returns:
        ``(rays_o, rays_d)``, both of shape ``[h, w, 3]``.
    """
    cols, rows = torch.meshgrid(
        torch.arange(w, device=pose.device),
        torch.arange(h, device=pose.device),
        indexing="xy",
    )
    cols = cols.flatten()
    rows = rows.flatten()

    center_x = w * 0.5
    center_y = h * 0.5
    focal = h * 0.5 / np.tan(0.5 * np.deg2rad(fovy))
    axis_sign = -1.0 if opengl else 1.0

    # pixel centers (hence the +0.5) on the z = axis_sign plane
    dirs_xy = torch.stack(
        [
            (cols - center_x + 0.5) / focal,
            (rows - center_y + 0.5) / focal * axis_sign,
        ],
        dim=-1,
    )
    camera_dirs = F.pad(dirs_xy, (0, 1), value=axis_sign)  # [hw, 3]

    # rotate into world space; every ray starts at the camera position
    rays_d = camera_dirs @ pose[:3, :3].transpose(0, 1)  # [hw, 3]
    rays_o = pose[:3, 3].unsqueeze(0).expand_as(rays_d)  # [hw, 3]

    return rays_o.view(h, w, 3), safe_normalize(rays_d).view(h, w, 3)
|
| 44 |
+
|
| 45 |
+
def orbit_camera_jitter(poses, strength=0.1):
    """Apply a random rotation about the origin to each camera pose.

    Args:
        poses: ``[B, 4, 4]`` camera matrices, assumed to be orbit cameras in
            opengl format.
        strength: scales the rotation range; the rotation about column 1 of
            each pose spans ``±strength * pi`` and the one about column 0
            spans ``±strength * pi / 2``.

    Returns:
        A new ``[B, 4, 4]`` tensor; ``poses`` itself is not modified.
    """
    batch = poses.shape[0]

    # uniform noise in [-1, 1), one draw per pose and per axis
    noise_a = torch.rand(batch, 1, device=poses.device) * 2 - 1
    rotvec_x = poses[:, :3, 1] * strength * np.pi * noise_a
    noise_b = torch.rand(batch, 1, device=poses.device) * 2 - 1
    rotvec_y = poses[:, :3, 0] * strength * np.pi / 2 * noise_b

    rot = roma.rotvec_to_rotmat(rotvec_x) @ roma.rotvec_to_rotmat(rotvec_y)

    # rotate both the orientation and the position
    jittered = poses.clone()
    jittered[:, :3, :3] = rot @ poses[:, :3, :3]
    jittered[:, :3, 3:] = rot @ poses[:, :3, 3:]

    return jittered
|
| 62 |
+
|
| 63 |
+
def grid_distortion(images, strength=0.5):
    """Warp each image with a random, piecewise-linear sampling grid.

    The number of grid knots per axis is drawn once per call from 8..16
    (inclusive); the knot positions are perturbed independently for every
    image and axis.

    Args:
        images: tensor of shape ``[B, C, H, W]``.
        strength: float in [0, 1], strength of distortion.

    Returns:
        The resampled images, same shape as ``images``.
    """
    B, C, H, W = images.shape

    num_steps = np.random.randint(8, 17)
    grid_steps = torch.linspace(-1, 1, num_steps)

    def _warped_coords(size):
        # knot positions along one axis, perturbed and converted to pixels
        knots = torch.linspace(0, 1, num_steps)  # [num_steps], inclusive
        knots = (knots + strength * (torch.rand_like(knots) - 0.5) / (num_steps - 1)).clamp(0, 1)
        knots = (knots * size).long()
        knots[0] = 0
        knots[-1] = size
        # each segment covers the same slice of [-1, 1] but a different
        # number of pixels, which is what stretches or squeezes the image
        pieces = [
            torch.linspace(grid_steps[k], grid_steps[k + 1], knots[k + 1] - knots[k])
            for k in range(num_steps - 1)
        ]
        return torch.cat(pieces, dim=0)  # [size]

    # have to loop batch...
    per_image = []
    for _ in range(B):
        xs = _warped_coords(W)  # [W]
        ys = _warped_coords(H)  # [H]
        grid_x, grid_y = torch.meshgrid(xs, ys, indexing='xy')  # [H, W]
        per_image.append(torch.stack([grid_x, grid_y], dim=-1))  # [H, W, 2]

    grids = torch.stack(per_image, dim=0).to(images.device)  # [B, H, W, 2]

    return F.grid_sample(images, grids, align_corners=False)
|
| 109 |
+
|
data_test/catstatue.ply
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:57dc6f5902301d7577c53a73ce4c9d1bbff2fca86bf93d015b6cdfa1d3de9b18
|
| 3 |
+
size 2390497
|