Vo Minh Vu committed on
Commit
4878904
·
1 Parent(s): 832704e

Initial 3d process project

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. LICENSE +21 -0
  2. LICENSE.md +51 -0
  3. README.md +40 -12
  4. __init__.py +201 -0
  5. app.py +75 -0
  6. comfyui/visualization.js +134 -0
  7. comfyui/web/visualizer.html +19 -0
  8. comfyui/web/visualizer.js +86 -0
  9. gradio_app.py +369 -0
  10. requirements-demo.txt +2 -0
  11. requirements-dev.txt +2 -0
  12. requirements.txt +20 -0
  13. ruff.toml +3 -0
  14. run.py +141 -0
  15. sf3d/__pycache__/system.cpython-313.pyc +0 -0
  16. sf3d/models/__pycache__/isosurface.cpython-313.pyc +0 -0
  17. sf3d/models/__pycache__/mesh.cpython-313.pyc +0 -0
  18. sf3d/models/__pycache__/utils.cpython-313.pyc +0 -0
  19. sf3d/models/camera.py +32 -0
  20. sf3d/models/global_estimator/multi_head_estimator.py +118 -0
  21. sf3d/models/image_estimator/clip_based_estimator.py +168 -0
  22. sf3d/models/isosurface.py +229 -0
  23. sf3d/models/mesh.py +289 -0
  24. sf3d/models/network.py +213 -0
  25. sf3d/models/tokenizers/dinov2.py +1196 -0
  26. sf3d/models/tokenizers/image.py +99 -0
  27. sf3d/models/tokenizers/triplane.py +49 -0
  28. sf3d/models/transformers/attention.py +31 -0
  29. sf3d/models/transformers/backbone.py +515 -0
  30. sf3d/models/utils.py +236 -0
  31. sf3d/system.py +532 -0
  32. sf3d/utils.py +105 -0
  33. texture_baker/README.md +26 -0
  34. texture_baker/requirements.txt +2 -0
  35. texture_baker/setup.py +131 -0
  36. texture_baker/texture_baker/__init__.py +4 -0
  37. texture_baker/texture_baker/baker.py +86 -0
  38. texture_baker/texture_baker/csrc/baker.cpp +548 -0
  39. texture_baker/texture_baker/csrc/baker.h +203 -0
  40. texture_baker/texture_baker/csrc/baker_kernel.cu +306 -0
  41. texture_baker/texture_baker/csrc/baker_kernel.metal +170 -0
  42. texture_baker/texture_baker/csrc/baker_kernel.mm +260 -0
  43. uv_unwrapper/README.md +0 -0
  44. uv_unwrapper/requirements.txt +2 -0
  45. uv_unwrapper/setup.py +80 -0
  46. uv_unwrapper/uv_unwrapper/__init__.py +6 -0
  47. uv_unwrapper/uv_unwrapper/csrc/bvh.cpp +381 -0
  48. uv_unwrapper/uv_unwrapper/csrc/bvh.h +118 -0
  49. uv_unwrapper/uv_unwrapper/csrc/common.h +493 -0
  50. uv_unwrapper/uv_unwrapper/csrc/intersect.cpp +702 -0
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Stability AI
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
LICENSE.md ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ STABILITY AI COMMUNITY LICENSE AGREEMENT
2
+ Last Updated: July 5, 2024
3
+
4
+
5
+ I. INTRODUCTION
6
+
7
+ This Agreement applies to any individual person or entity ("You", "Your" or "Licensee") that uses or distributes any portion or element of the Stability AI Materials or Derivative Works thereof for any Research & Non-Commercial or Commercial purpose. Capitalized terms not otherwise defined herein are defined in Section V below.
8
+
9
+
10
+ This Agreement is intended to allow research, non-commercial, and limited commercial uses of the Models free of charge. In order to ensure that certain limited commercial uses of the Models continue to be allowed, this Agreement preserves free access to the Models for people or organizations generating annual revenue of less than US $1,000,000 (or local currency equivalent).
11
+
12
+
13
+ By clicking "I Accept" or by using or distributing or using any portion or element of the Stability Materials or Derivative Works, You agree that You have read, understood and are bound by the terms of this Agreement. If You are acting on behalf of a company, organization or other entity, then "You" includes you and that entity, and You agree that You: (i) are an authorized representative of such entity with the authority to bind such entity to this Agreement, and (ii) You agree to the terms of this Agreement on that entity's behalf.
14
+
15
+ II. RESEARCH & NON-COMMERCIAL USE LICENSE
16
+
17
+ Subject to the terms of this Agreement, Stability AI grants You a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable and royalty-free limited license under Stability AI's intellectual property or other rights owned by Stability AI embodied in the Stability AI Materials to use, reproduce, distribute, and create Derivative Works of, and make modifications to, the Stability AI Materials for any Research or Non-Commercial Purpose. "Research Purpose" means academic or scientific advancement, and in each case, is not primarily intended for commercial advantage or monetary compensation to You or others. "Non-Commercial Purpose" means any purpose other than a Research Purpose that is not primarily intended for commercial advantage or monetary compensation to You or others, such as personal use (i.e., hobbyist) or evaluation and testing.
18
+
19
+ III. COMMERCIAL USE LICENSE
20
+
21
+ Subject to the terms of this Agreement (including the remainder of this Section III), Stability AI grants You a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable and royalty-free limited license under Stability AI's intellectual property or other rights owned by Stability AI embodied in the Stability AI Materials to use, reproduce, distribute, and create Derivative Works of, and make modifications to, the Stability AI Materials for any Commercial Purpose. "Commercial Purpose" means any purpose other than a Research Purpose or Non-Commercial Purpose that is primarily intended for commercial advantage or monetary compensation to You or others, including but not limited to, (i) creating, modifying, or distributing Your product or service, including via a hosted service or application programming interface, and (ii) for Your business's or organization's internal operations.
22
+ If You are using or distributing the Stability AI Materials for a Commercial Purpose, You must register with Stability AI at (https://stability.ai/community-license). If at any time You or Your Affiliate(s), either individually or in aggregate, generate more than USD $1,000,000 in annual revenue (or the equivalent thereof in Your local currency), regardless of whether that revenue is generated directly or indirectly from the Stability AI Materials or Derivative Works, any licenses granted to You under this Agreement shall terminate as of such date. You must request a license from Stability AI at (https://stability.ai/enterprise) , which Stability AI may grant to You in its sole discretion. If you receive Stability AI Materials, or any Derivative Works thereof, from a Licensee as part of an integrated end user product, then Section III of this Agreement will not apply to you.
23
+
24
+ IV. GENERAL TERMS
25
+
26
+ Your Research, Non-Commercial, and Commercial License(s) under this Agreement are subject to the following terms.
27
+ a. Distribution & Attribution. If You distribute or make available the Stability AI Materials or a Derivative Work to a third party, or a product or service that uses any portion of them, You shall: (i) provide a copy of this Agreement to that third party, (ii) retain the following attribution notice within a "Notice" text file distributed as a part of such copies: "This Stability AI Model is licensed under the Stability AI Community License, Copyright © Stability AI Ltd. All Rights Reserved", and (iii) prominently display "Powered by Stability AI" on a related website, user interface, blogpost, about page, or product documentation. If You create a Derivative Work, You may add your own attribution notice(s) to the "Notice" text file included with that Derivative Work, provided that You clearly indicate which attributions apply to the Stability AI Materials and state in the "Notice" text file that You changed the Stability AI Materials and how it was modified.
28
+ b. Use Restrictions. Your use of the Stability AI Materials and Derivative Works, including any output or results of the Stability AI Materials or Derivative Works, must comply with applicable laws and regulations (including Trade Control Laws and equivalent regulations) and adhere to the Documentation and Stability AI's AUP, which is hereby incorporated by reference. Furthermore, You will not use the Stability AI Materials or Derivative Works, or any output or results of the Stability AI Materials or Derivative Works, to create or improve any foundational generative AI model (excluding the Models or Derivative Works).
29
+ c. Intellectual Property.
30
+ (i) Trademark License. No trademark licenses are granted under this Agreement, and in connection with the Stability AI Materials or Derivative Works, You may not use any name or mark owned by or associated with Stability AI or any of its Affiliates, except as required under Section IV(a) herein.
31
+ (ii) Ownership of Derivative Works. As between You and Stability AI, You are the owner of Derivative Works You create, subject to Stability AI's ownership of the Stability AI Materials and any Derivative Works made by or for Stability AI.
32
+ (iii) Ownership of Outputs. As between You and Stability AI, You own any outputs generated from the Models or Derivative Works to the extent permitted by applicable law.
33
+ (iv) Disputes. If You or Your Affiliate(s) institute litigation or other proceedings against Stability AI (including a cross-claim or counterclaim in a lawsuit) alleging that the Stability AI Materials, Derivative Works or associated outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by You, then any licenses granted to You under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Stability AI from and against any claim by any third party arising out of or related to Your use or distribution of the Stability AI Materials or Derivative Works in violation of this Agreement.
34
+ (v) Feedback. From time to time, You may provide Stability AI with verbal and/or written suggestions, comments or other feedback related to Stability AI's existing or prospective technology, products or services (collectively, "Feedback"). You are not obligated to provide Stability AI with Feedback, but to the extent that You do, You hereby grant Stability AI a perpetual, irrevocable, royalty-free, fully-paid, sub-licensable, transferable, non-exclusive, worldwide right and license to exploit the Feedback in any manner without restriction. Your Feedback is provided "AS IS" and You make no warranties whatsoever about any Feedback.
35
+ d. Disclaimer Of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE STABILITY AI MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OR LAWFULNESS OF USING OR REDISTRIBUTING THE STABILITY AI MATERIALS, DERIVATIVE WORKS OR ANY OUTPUT OR RESULTS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE STABILITY AI MATERIALS, DERIVATIVE WORKS AND ANY OUTPUT AND RESULTS.
36
+ e. Limitation Of Liability. IN NO EVENT WILL STABILITY AI OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT, INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF STABILITY AI OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
37
+ f. Term And Termination. The term of this Agreement will commence upon Your acceptance of this Agreement or access to the Stability AI Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Stability AI may terminate this Agreement if You are in breach of any term or condition of this Agreement. Upon termination of this Agreement, You shall delete and cease use of any Stability AI Materials or Derivative Works. Section IV(d), (e), and (g) shall survive the termination of this Agreement.
38
+ g. Governing Law. This Agreement will be governed by and construed in accordance with the laws of the United States and the State of California without regard to choice of law principles, and the UN Convention on Contracts for International Sale of Goods does not apply to this Agreement.
39
+
40
+ V. DEFINITIONS
41
+
42
+ "Affiliate(s)" means any entity that directly or indirectly controls, is controlled by, or is under common control with the subject entity; for purposes of this definition, "control" means direct or indirect ownership or control of more than 50% of the voting interests of the subject entity.
43
+ "Agreement" means this Stability AI Community License Agreement.
44
+ "AUP" means the Stability AI Acceptable Use Policy available at https://stability.ai/use-policy, as may be updated from time to time.
45
+ "Derivative Work(s)" means (a) any derivative work of the Stability AI Materials as recognized by U.S. copyright laws and (b) any modifications to a Model, and any other model created which is based on or derived from the Model or the Model's output, including"fine tune" and "low-rank adaptation" models derived from a Model or a Model's output, but do not include the output of any Model.
46
+ "Documentation" means any specifications, manuals, documentation, and other written information provided by Stability AI related to the Software or Models.
47
+ "Model(s)" means, collectively, Stability AI's proprietary models and algorithms, including machine-learning models, trained model weights and other elements of the foregoing listed on Stability's Core Models Webpage available at, https://stability.ai/core-models, as may be updated from time to time.
48
+ "Stability AI" or "we" means Stability AI Ltd. and its Affiliates.
49
+ "Software" means Stability AI's proprietary software made available under this Agreement now or in the future.
50
+ "Stability AI Materials" means, collectively, Stability's proprietary Models, Software and Documentation (and any portion or combination thereof) made available under this Agreement.
51
+ "Trade Control Laws" means any applicable U.S. and non-U.S. export control and trade sanctions laws and regulations.
README.md CHANGED
@@ -1,12 +1,40 @@
1
- ---
2
- title: Demo3d StabilityAI
3
- emoji: 🏃
4
- colorFrom: gray
5
- colorTo: green
6
- sdk: gradio
7
- sdk_version: 5.25.0
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Stable Fast 3D
2
+
3
+ A powerful tool for converting 2D images to 3D models using Stable Diffusion and advanced 3D reconstruction techniques.
4
+
5
+ ## Features
6
+
7
+ - Convert 2D images to 3D models
8
+ - Text-guided 3D generation
9
+ - High-quality texture mapping
10
+ - UV unwrapping
11
+ - Export to GLB format
12
+
13
+ ## Usage
14
+
15
+ 1. Upload an image
16
+ 2. Enter a prompt describing the 3D model you want to generate
17
+ 3. Click "Generate 3D Model"
18
+ 4. Wait for the processing to complete
19
+ 5. Download the generated 3D model
20
+
21
+ ## Technical Details
22
+
23
+ This Space uses:
24
+ - Stable Diffusion for image-to-3D conversion
25
+ - Advanced texture baking techniques
26
+ - UV unwrapping for optimal texture mapping
27
+ - CPU-only inference for compatibility
28
+
29
+ ## Requirements
30
+
31
+ All dependencies are listed in `requirements.txt`. The main requirements are:
32
+ - PyTorch
33
+ - Transformers
34
+ - Gradio
35
+ - Pillow
36
+ - NumPy
37
+
38
+ ## License
39
+
40
+ This project is licensed under the MIT License - see the LICENSE file for details.
__init__.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import logging
3
+ import os
4
+ import sys
5
+ from contextlib import nullcontext
6
+
7
+ import comfy.model_management
8
+ import folder_paths
9
+ import numpy as np
10
+ import torch
11
+ import trimesh
12
+ from PIL import Image
13
+ from trimesh.exchange import gltf
14
+
15
+ sys.path.append(os.path.dirname(__file__))
16
+ from sf3d.system import SF3D
17
+ from sf3d.utils import resize_foreground
18
+
19
# Node category and the Hugging Face repository the weights are pulled from.
SF3D_CATEGORY = "StableFast3D"
SF3D_MODEL_NAME = "stabilityai/stable-fast-3d"


class StableFast3DLoader:
    """ComfyUI node: loads the Stable Fast 3D model onto the active torch device."""

    CATEGORY = SF3D_CATEGORY
    FUNCTION = "load"
    RETURN_NAMES = ("sf3d_model",)
    RETURN_TYPES = ("SF3D_MODEL",)

    @classmethod
    def INPUT_TYPES(cls):
        # The loader takes no inputs; the model name and weight files are fixed.
        return {"required": {}}

    def load(self):
        """Fetch the pretrained SF3D weights and return the model in eval mode."""
        target_device = comfy.model_management.get_torch_device()
        sf3d_model = SF3D.from_pretrained(
            SF3D_MODEL_NAME,
            config_name="config.yaml",
            weight_name="model.safetensors",
        )
        sf3d_model.to(target_device)
        sf3d_model.eval()
        return (sf3d_model,)
44
+
45
+
46
class StableFast3DPreview:
    """ComfyUI output node: encodes meshes as base64 GLB blobs for the web viewer."""

    CATEGORY = SF3D_CATEGORY
    FUNCTION = "preview"
    OUTPUT_NODE = True
    RETURN_TYPES = ()

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"mesh": ("MESH",)}}

    def preview(self, mesh):
        """Export each mesh to binary GLB and hand the base64 payloads to the UI."""
        encoded = []
        for single_mesh in mesh:
            wrapped = trimesh.Scene(single_mesh)
            binary_glb = gltf.export_glb(wrapped, include_normals=True)
            encoded.append(base64.b64encode(binary_glb).decode("utf-8"))
        return {"ui": {"glbs": encoded}}
64
+
65
+
66
class StableFast3DSampler:
    """ComfyUI node that runs SF3D inference on a single image and returns a mesh."""

    CATEGORY = SF3D_CATEGORY
    FUNCTION = "predict"
    RETURN_NAMES = ("mesh",)
    RETURN_TYPES = ("MESH",)

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("SF3D_MODEL",),
                "image": ("IMAGE",),
                "foreground_ratio": (
                    "FLOAT",
                    {"default": 0.85, "min": 0.0, "max": 1.0, "step": 0.01},
                ),
                "texture_resolution": (
                    "INT",
                    {"default": 1024, "min": 512, "max": 2048, "step": 256},
                ),
            },
            "optional": {
                "mask": ("MASK",),
                "remesh": (["none", "triangle", "quad"],),
                "vertex_count": (
                    "INT",
                    {"default": -1, "min": -1, "max": 20000, "step": 1},
                ),
            },
        }

    def predict(
        s,
        model,
        image,
        mask=None,
        foreground_ratio=0.85,
        texture_resolution=1024,
        remesh="none",
        vertex_count=-1,
    ):
        """Run SF3D on one image tensor and return the reconstructed mesh.

        Bug fix: ``mask`` is declared under "optional" in INPUT_TYPES, so ComfyUI
        may call this node without it; it must default to None rather than being
        a required positional (the subsequent parameters also get defaults
        matching INPUT_TYPES so the signature stays syntactically valid and
        backward-compatible).
        """
        if image.shape[0] != 1:
            raise ValueError("Only one image can be processed at a time")

        # Convert the 0-1 float tensor (batch of 1, HWC) into an 8-bit PIL image.
        pil_image = Image.fromarray(
            torch.clamp(torch.round(255.0 * image[0]), 0, 255)
            .type(torch.uint8)
            .cpu()
            .numpy()
        )

        if mask is not None:
            print("Using Mask")
            # Apply the mask as the alpha channel of the input image.
            mask_np = np.clip(255.0 * mask[0].detach().cpu().numpy(), 0, 255).astype(
                np.uint8
            )
            mask_pil = Image.fromarray(mask_np, mode="L")
            pil_image.putalpha(mask_pil)
        else:
            if image.shape[3] != 4:
                print("No mask or alpha channel detected, Converting to RGBA")
                pil_image = pil_image.convert("RGBA")

        # Center/scale the foreground subject before inference.
        pil_image = resize_foreground(pil_image, foreground_ratio)
        print(remesh)
        with torch.no_grad():
            # bfloat16 autocast is only meaningful on CUDA; use a no-op context
            # on CPU/MPS devices.
            autocast_ctx = (
                torch.autocast(device_type="cuda", dtype=torch.bfloat16)
                if "cuda" in comfy.model_management.get_torch_device().type
                else nullcontext()
            )
            with autocast_ctx:
                mesh, glob_dict = model.run_image(
                    pil_image,
                    bake_resolution=texture_resolution,
                    remesh=remesh,
                    vertex_count=vertex_count,
                )

        if mesh.vertices.shape[0] == 0:
            raise ValueError("No subject detected in the image")

        return ([mesh],)
146
+
147
+
148
class StableFast3DSave:
    """ComfyUI output node that writes each mesh as a GLB file in the output dir."""

    CATEGORY = SF3D_CATEGORY
    FUNCTION = "save"
    OUTPUT_NODE = True
    RETURN_TYPES = ()

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "mesh": ("MESH",),
                "filename_prefix": ("STRING", {"default": "SF3D"}),
            }
        }

    def __init__(self):
        # Marks this node as an output node for ComfyUI's save-path resolution.
        self.type = "output"

    def save(self, mesh, filename_prefix):
        """Export each mesh to GLB, save it under the resolved output path, and
        return the base64 payloads so the UI can also preview them."""
        output_dir = folder_paths.get_output_directory()
        glbs = []
        for idx, m in enumerate(mesh):
            scene = trimesh.Scene(m)
            glb_data = gltf.export_glb(scene, include_normals=True)
            logging.info(f"Generated GLB model with {len(glb_data)} bytes")

            full_output_folder, filename, counter, subfolder, filename_prefix = (
                folder_paths.get_save_image_path(filename_prefix, output_dir)
            )
            filename = filename.replace("%batch_num%", str(idx))
            # Bug fix: use the resolved `filename` (which honors
            # filename_prefix and %batch_num%) instead of a hard-coded
            # "(unknown)" placeholder that ignored it.
            out_path = os.path.join(full_output_folder, f"{filename}_{counter:05}_.glb")
            with open(out_path, "wb") as f:
                f.write(glb_data)
            glbs.append(base64.b64encode(glb_data).decode("utf-8"))
        return {"ui": {"glbs": glbs}}
183
+
184
+
185
+ NODE_DISPLAY_NAME_MAPPINGS = {
186
+ "StableFast3DLoader": "Stable Fast 3D Loader",
187
+ "StableFast3DPreview": "Stable Fast 3D Preview",
188
+ "StableFast3DSampler": "Stable Fast 3D Sampler",
189
+ "StableFast3DSave": "Stable Fast 3D Save",
190
+ }
191
+
192
+ NODE_CLASS_MAPPINGS = {
193
+ "StableFast3DLoader": StableFast3DLoader,
194
+ "StableFast3DPreview": StableFast3DPreview,
195
+ "StableFast3DSampler": StableFast3DSampler,
196
+ "StableFast3DSave": StableFast3DSave,
197
+ }
198
+
199
+ WEB_DIRECTORY = "./comfyui"
200
+
201
+ __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS", "WEB_DIRECTORY"]
app.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from PIL import Image
4
+ import numpy as np
5
+ import os
6
+ from pathlib import Path
7
+ import tempfile
8
+ from huggingface_hub import hf_hub_download
9
+
10
+ # Set environment variable to use CPU
11
+ os.environ["SF3D_USE_CPU"] = "1"
12
+
13
+ # Import the main pipeline
14
+ from stable_fast_3d import StableFast3D
15
+
16
+ # Initialize the model
17
+ model = StableFast3D()
18
+
19
def process_image(image, prompt):
    """Run the SF3D pipeline on `image` and return the path to the exported GLB.

    Args:
        image: Input image as a PIL.Image or numpy array.
        prompt: Text prompt forwarded to the model.

    Returns:
        str: Filesystem path of the generated GLB file.

    Bug fix: the original wrote the GLB inside a ``TemporaryDirectory`` context
    and returned the path *after* the context exited, so the file was already
    deleted by the time the caller used it. ``tempfile.mkdtemp`` is not
    auto-deleted, keeping the returned path valid for the caller.
    """
    # Normalize numpy input to PIL.
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    out_dir = Path(tempfile.mkdtemp(prefix="sf3d_"))
    output_path = out_dir / "output.glb"

    # Process the image through the global SF3D model instance.
    model.process_image(
        image=image,
        prompt=prompt,
        output_path=str(output_path),
    )

    return str(output_path)
37
+
38
def convert_2d_to_3d(image, prompt=None):
    """
    Placeholder entry point for 2D-to-3D conversion, wired to the Gradio button.

    Args:
        image (PIL.Image): Input 2D image (currently unused).
        prompt (str, optional): Text prompt to guide the 3D generation
            (currently unused).

    Returns:
        str: Currently a fixed status message. Once the conversion is
        implemented this will return the path to the generated GLB file.
        (Doc fix: the previous docstring claimed a GLB path was returned,
        which did not match the actual behavior.)
    """
    # TODO: Implement the actual 2D to 3D conversion logic
    # For now, return a placeholder message
    return "3D conversion will be implemented soon!"
52
+
53
+ # Create Gradio interface
54
+ with gr.Blocks() as demo:
55
+ gr.Markdown("# Stable Fast 3D - Convert 2D Images to 3D Models")
56
+ gr.Markdown("Upload a 2D image and get a 3D model in return. Optionally provide a text prompt to guide the generation.")
57
+
58
+ with gr.Row():
59
+ with gr.Column():
60
+ input_image = gr.Image(type="pil", label="Input Image")
61
+ text_prompt = gr.Textbox(label="Text Prompt (optional)", placeholder="Enter a description to guide the 3D generation...")
62
+ convert_btn = gr.Button("Convert to 3D")
63
+
64
+ with gr.Column():
65
+ output = gr.Text(label="Output")
66
+
67
+ convert_btn.click(
68
+ fn=convert_2d_to_3d,
69
+ inputs=[input_image, text_prompt],
70
+ outputs=output
71
+ )
72
+
73
+ # Launch the app
74
+ if __name__ == "__main__":
75
+ demo.launch()
comfyui/visualization.js ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Thanks to MrForExample ComfyUI-3D-Pack for the base code
2
+ // https://github.com/MrForExample/ComfyUI-3D-Pack/blob/main/web/visualization.js
3
+
4
+ import { app } from "/scripts/app.js"
5
+
6
class Visualizer {
    constructor(node, container) {
        this.node = node;
        // Bug fix: keep a reference to the container so remove() can detach it
        // (previously this.container was never assigned and remove() threw).
        this.container = container;
        this.iframe = document.createElement('iframe');
        Object.assign(this.iframe, {
            scrolling: "no",
            overflow: "hidden",
        });
        this.iframe.src = "/extensions/stable-fast-3d/web/visualizer.html";
        container.appendChild(this.iframe);
    }

    // Push a new base64 GLB into the embedded viewer iframe; the timestamp
    // attribute change is what triggers the viewer's reload loop.
    update(b64_glb) {
        const iframeDocument = this.iframe.contentWindow.document;
        const previewScript = iframeDocument.getElementById('visualizer');
        previewScript.setAttribute("b64_glb", b64_glb);
        previewScript.setAttribute("timestamp", Date.now().toString());
    }

    // Detach the viewer's container (and thus the iframe) from the DOM.
    remove() {
        this.container.remove();
    }
}
29
+
30
+ function createWidget(node, app) {
31
+ const widget = {
32
+ type: "StableFast3DViewer",
33
+ name: "preview",
34
+ callback: () => { },
35
+ draw: function (ctx, node, widgetWidth, widgetY, widgetHeight) {
36
+ const margin = 10;
37
+ const top_offset = 5;
38
+ const visible = app.canvas.ds.scale > 0.5;
39
+ const w = widgetWidth - margin * 4;
40
+ const clientRectBound = ctx.canvas.getBoundingClientRect();
41
+ const transform = new DOMMatrix()
42
+ .scaleSelf(
43
+ clientRectBound.width / ctx.canvas.width,
44
+ clientRectBound.height / ctx.canvas.height
45
+ )
46
+ .multiplySelf(ctx.getTransform())
47
+ .translateSelf(margin, margin + widgetY);
48
+
49
+ Object.assign(this.visualizer.style, {
50
+ left: `${transform.a * margin + transform.e}px`,
51
+ top: `${transform.d + transform.f + top_offset}px`,
52
+ width: `${(w * transform.a)}px`,
53
+ height: `${(w * transform.d - widgetHeight - (margin * 15) * transform.d)}px`,
54
+ position: "absolute",
55
+ overflow: "hidden",
56
+ zIndex: app.graph._nodes.indexOf(node),
57
+ });
58
+
59
+ Object.assign(this.visualizer.children[0].style, {
60
+ transformOrigin: "50% 50%",
61
+ width: '100%',
62
+ height: '100%',
63
+ border: '0 none',
64
+ });
65
+
66
+ this.visualizer.hidden = !visible;
67
+ },
68
+ };
69
+
70
+ const container = document.createElement('div');
71
+
72
+ node.visualizer = new Visualizer(node, container);
73
+ widget.visualizer = container;
74
+ widget.parent = node;
75
+ document.body.appendChild(widget.visualizer);
76
+ node.addCustomWidget(widget);
77
+
78
+ node.onDrawBackground = (ctx) => {
79
+ node.visualizer.iframe.hidden = this.flags.collapsed;
80
+ };
81
+
82
+ node.onRemoved = () => {
83
+ for (let w in node.widgets) {
84
+ if (node.widgets[w].visualizer) {
85
+ node.widgets[w].visualizer.remove();
86
+ }
87
+ }
88
+ };
89
+
90
+ node.onResize = () => {
91
+ let [w, h] = this.size;
92
+ if (w <= 600) w = 600;
93
+ if (h <= 500) h = 500;
94
+ if (w > 600) {
95
+ h = w - 100;
96
+ }
97
+ this.size = [w, h];
98
+ };
99
+
100
+ node.updateParameters = (b64_glb) => {
101
+ node.visualizer.update(b64_glb);
102
+ };
103
+
104
+ return { widget: widget }
105
+ }
106
+
107
+
108
+ function registerVisualizer(nodeType, nodeData) {
109
+ if (nodeData.name !== "StableFast3DSave" && nodeData.name !== "StableFast3DPreview")
110
+ return;
111
+
112
+ const originalOnNodeCreated = nodeType.prototype.onNodeCreated;
113
+
114
+ nodeType.prototype.onNodeCreated = async function () {
115
+ const result = originalOnNodeCreated?.apply(this, arguments);
116
+ await createWidget.apply(this, [this, app]);
117
+ this.setSize([512, 512]);
118
+ return result;
119
+ };
120
+
121
+ nodeType.prototype.onExecuted = function (message) {
122
+ if (message?.glbs?.[0]) {
123
+ this.updateParameters(message.glbs[0]);
124
+ }
125
+ };
126
+ }
127
+
128
+ app.registerExtension({
129
+ name: "StableFast3D.Visualizer",
130
+ async init(app) { },
131
+ async beforeRegisterNodeDef(nodeType, nodeData, app) {
132
+ registerVisualizer(nodeType, nodeData);
133
+ },
134
+ });
comfyui/web/visualizer.html ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="utf-8">
5
+ <meta name="viewport" content="width=device-width, user-scalable=no, minimum-scale=1.0, maximum-scale=1.0">
6
+ </head>
7
+ <body>
8
+ <div id="container"></div>
9
+ <script type="importmap">
10
+ {
11
+ "imports": {
12
+ "three": "https://unpkg.com/three@latest/build/three.module.js",
13
+ "three/addons/": "https://unpkg.com/three@latest/examples/jsm/"
14
+ }
15
+ }
16
+ </script>
17
+ <script id="visualizer" type="module" b64_glb="" timestamp="" crossorigin src="visualizer.js"></script>
18
+ </body>
19
+ </html>
comfyui/web/visualizer.js ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import * as THREE from 'three';
2
+ import { GLTFLoader } from 'three/addons/loaders/GLTFLoader.js';
3
+ import { OrbitControls } from 'three/addons/controls/OrbitControls.js';
4
+
5
+ const container = document.getElementById("container");
6
+ const visualizer = document.getElementById("visualizer");
7
+
8
+ const renderer = new THREE.WebGLRenderer({ antialias: true });
9
+ renderer.setClearColor(0x808080);
10
+ renderer.setPixelRatio(window.devicePixelRatio);
11
+ renderer.setSize(window.innerWidth, window.innerHeight);
12
+ container.appendChild(renderer.domElement);
13
+
14
+ const scene = new THREE.Scene();
15
+ const camera = new THREE.PerspectiveCamera(75, 1, 0.1, 1000);
16
+
17
+ const controls = new OrbitControls(camera, renderer.domElement);
18
+ controls.dampingFactor = 0.25;
19
+ controls.enableDamping = true;
20
+ controls.enableZoom = true;
21
+
22
+ const ambientLight = new THREE.AmbientLight(0xffffff, 0.75);
23
+ const directionalLight = new THREE.DirectionalLight(0xffffff, 1.0);
24
+ directionalLight.position.set(0.5, 1, -1.5);
25
+ const hemisphereLight = new THREE.HemisphereLight(0xffffbb, 0x080820, 0.5);
26
+
27
+
28
+ var lastTimestamp = "";
29
+
30
+ window.onresize = function () {
31
+ camera.aspect = window.innerWidth / window.innerHeight;
32
+ camera.updateProjectionMatrix();
33
+ renderer.setSize(window.innerWidth, window.innerHeight);
34
+ };
35
+
36
+ function render() {
37
+ var timestamp = visualizer.getAttribute("timestamp");
38
+ var b64_glb = visualizer.getAttribute("b64_glb");
39
+ if (timestamp != lastTimestamp) {
40
+ lastTimestamp = timestamp;
41
+ init(b64_glb);
42
+ }
43
+ controls.update();
44
+ renderer.render(scene, camera);
45
+ requestAnimationFrame(render);
46
+ }
47
+
48
+ async function init(b64_glb) {
49
+ scene.clear();
50
+ scene.add(ambientLight);
51
+ scene.add(camera);
52
+ scene.add(directionalLight);
53
+ scene.add(hemisphereLight);
54
+
55
+ if (b64_glb) {
56
+ const loader = new GLTFLoader();
57
+ const glbData = atob(b64_glb);
58
+ const glbBuffer = new Uint8Array(glbData.length);
59
+ for (let i = 0; i < glbData.length; i++) {
60
+ glbBuffer[i] = glbData.charCodeAt(i);
61
+ }
62
+
63
+ loader.parse(glbBuffer.buffer, '', (gltf) => {
64
+ scene.add(gltf.scene);
65
+
66
+ const box = new THREE.Box3().setFromObject(gltf.scene);
67
+ const center = box.getCenter(new THREE.Vector3());
68
+ const size = box.getSize(new THREE.Vector3());
69
+ const maxDim = Math.max(size.x, size.y, size.z);
70
+
71
+ const fov = camera.fov * (Math.PI / 180);
72
+ let cameraZ = Math.abs(maxDim / 2 / Math.tan(fov / 2));
73
+ camera.position.z = cameraZ * -1.5;
74
+ camera.lookAt(center);
75
+
76
+ controls.target.copy(center);
77
+ controls.update();
78
+ }, undefined, (error) => {
79
+ console.error('An error occurred loading GLB:', error);
80
+ });
81
+ }
82
+
83
+ render();
84
+ }
85
+
86
+ init();
gradio_app.py ADDED
@@ -0,0 +1,369 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import tempfile
3
+ import time
4
+ from contextlib import nullcontext
5
+ from functools import lru_cache
6
+ from typing import Any
7
+
8
+ import gradio as gr
9
+ import numpy as np
10
+ import rembg
11
+ import torch
12
+ from gradio_litmodel3d import LitModel3D
13
+ from PIL import Image
14
+
15
+ import sf3d.utils as sf3d_utils
16
+ from sf3d.system import SF3D
17
+
18
# Keep gradio's temp files in a predictable, per-run directory.
os.environ["GRADIO_TEMP_DIR"] = os.path.join(os.environ.get("TMPDIR", "/tmp"), "gradio")

# One shared rembg session for all background-removal requests.
rembg_session = rembg.new_session()

# Conditioning camera setup: image resolution, camera distance and FOV used
# to build the SF3D input batch.
COND_WIDTH = 512
COND_HEIGHT = 512
COND_DISTANCE = 1.6
COND_FOVY_DEG = 40
BACKGROUND_COLOR = [0.5, 0.5, 0.5]

# Cached. Doesn't change
c2w_cond = sf3d_utils.default_cond_c2w(COND_DISTANCE)
intrinsic, intrinsic_normed_cond = sf3d_utils.create_intrinsic_from_fov_deg(
    COND_FOVY_DEG, COND_HEIGHT, COND_WIDTH
)

# Paths of every mesh exported during this process (never cleaned up here).
generated_files = []

# Delete previous gradio temp dir folder
if os.path.exists(os.environ["GRADIO_TEMP_DIR"]):
    print(f"Deleting {os.environ['GRADIO_TEMP_DIR']}")
    import shutil

    shutil.rmtree(os.environ["GRADIO_TEMP_DIR"])

device = sf3d_utils.get_device()

# Load SF3D once at import time; inference-only (eval mode).
model = SF3D.from_pretrained(
    "stabilityai/stable-fast-3d",
    config_name="config.yaml",
    weight_name="model.safetensors",
)
model.eval()
model = model.to(device)

# Example images shipped with the demo.
example_files = [
    os.path.join("demo_files/examples", f) for f in os.listdir("demo_files/examples")
]
56
+
57
+
58
def run_model(input_image, remesh_option, vertex_count, texture_size):
    """Run SF3D on a preprocessed RGBA image and export the mesh to a temp GLB.

    Args:
        input_image: Preprocessed PIL image (background-removed, foreground-resized).
        remesh_option: One of "none", "triangle", "quad" (lowercased by caller).
        vertex_count: Target vertex count; -1 disables reduction.
        texture_size: Resolution of the baked texture atlas.

    Returns:
        Filesystem path of the exported .glb file.
    """
    start = time.time()
    with torch.no_grad():
        # bfloat16 autocast only on CUDA; other devices run in full precision.
        with torch.autocast(
            device_type=device, dtype=torch.bfloat16
        ) if "cuda" in device else nullcontext():
            model_batch = create_batch(input_image)
            model_batch = {k: v.to(device) for k, v in model_batch.items()}
            trimesh_mesh, _glob_dict = model.generate_mesh(
                model_batch, texture_size, remesh_option, vertex_count
            )
            trimesh_mesh = trimesh_mesh[0]

    # Create new tmp file. Close the handle immediately so the descriptor is
    # not leaked (the file itself is kept on disk for gradio to serve).
    tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".glb")
    tmp_file.close()

    trimesh_mesh.export(tmp_file.name, file_type="glb", include_normals=True)
    generated_files.append(tmp_file.name)

    print("Generation took:", time.time() - start, "s")

    return tmp_file.name
80
+
81
+
82
def create_batch(input_image: Image.Image) -> dict[str, Any]:
    """Build a single-element SF3D input batch from a preprocessed RGBA image.

    The image is resized to the conditioning resolution, split into RGB and
    alpha, and the RGB is composited over the neutral background color.
    Camera intrinsics/extrinsics come from the module-level cached tensors.
    """
    # (H, W, 4) float tensor in [0, 1]; assumes the input has an alpha
    # channel (callers pass RGBA) -- TODO confirm for all call sites.
    img_cond = (
        torch.from_numpy(
            np.asarray(input_image.resize((COND_WIDTH, COND_HEIGHT))).astype(np.float32)
            / 255.0
        )
        .float()
        .clip(0, 1)
    )
    # Alpha channel acts as the foreground mask.
    mask_cond = img_cond[:, :, -1:]
    # Blend the foreground RGB over the constant background color.
    rgb_cond = torch.lerp(
        torch.tensor(BACKGROUND_COLOR)[None, None, :], img_cond[:, :, :3], mask_cond
    )

    batch_elem = {
        "rgb_cond": rgb_cond,
        "mask_cond": mask_cond,
        "c2w_cond": c2w_cond.unsqueeze(0),
        "intrinsic_cond": intrinsic.unsqueeze(0),
        "intrinsic_normed_cond": intrinsic_normed_cond.unsqueeze(0),
    }
    # Add batch dim
    batched = {k: v.unsqueeze(0) for k, v in batch_elem.items()}
    return batched
106
+
107
+
108
@lru_cache
def checkerboard(squares: int, size: int, min_value: float = 0.5):
    """Build a (size, size, 3) grayscale checkerboard array.

    Alternating cells take the values ``min_value`` and ``1.0``. The result
    is cached because the demo always requests the same pattern.
    """
    # Small (squares x squares) tile with alternating values.
    pattern = np.full((squares, squares), min_value)
    pattern[1::2, ::2] = 1
    pattern[::2, 1::2] = 1

    # Blow each cell up to its pixel size, then broadcast to 3 channels.
    cell = size // squares
    upscaled = pattern.repeat(cell, axis=0).repeat(cell, axis=1)
    return np.repeat(upscaled[:, :, None], 3, axis=-1)
120
+
121
+
122
def remove_background(input_image: Image.Image) -> Image.Image:
    """Strip the background from the image using the shared rembg session."""
    return rembg.remove(input_image, session=rembg_session)
124
+
125
+
126
def show_mask_img(input_image: Image.Image) -> Image.Image:
    """Composite an RGBA image over a checkerboard to preview its alpha mask."""
    img_numpy = np.array(input_image)
    # Normalized alpha in [0, 1]; assumes a 4-channel, 512x512 input
    # (the checkerboard below is fixed at 512) -- TODO confirm callers.
    alpha = img_numpy[:, :, 3] / 255.0
    chkb = checkerboard(32, 512) * 255
    # Alpha-blend foreground RGB over the checkerboard background.
    new_img = img_numpy[..., :3] * alpha[:, :, None] + chkb * (1 - alpha[:, :, None])
    return Image.fromarray(new_img.astype(np.uint8), mode="RGB")
132
+
133
+
134
def run_button(
    run_btn,
    input_image,
    background_state,
    foreground_ratio,
    remesh_option,
    vertex_count,
    texture_size,
):
    """Dispatch the primary button click: generate a mesh or remove the background.

    The button label (``run_btn``) doubles as the state machine: "Run"
    triggers mesh generation, "Remove Background" runs the rembg pass first.

    Returns a 6-tuple of gr.update values matching the outputs wired in
    ``run_btn.click``: (run_btn, img_proc_state, background_remove_state,
    preview_removal, output_3d, hdr_row). Order is load-bearing.
    """
    if run_btn == "Run":
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()
        glb_file: str = run_model(
            background_state, remesh_option.lower(), vertex_count, texture_size
        )
        if torch.cuda.is_available():
            print("Peak Memory:", torch.cuda.max_memory_allocated() / 1024 / 1024, "MB")
        elif torch.backends.mps.is_available():
            print(
                "Peak Memory:", torch.mps.driver_allocated_memory() / 1024 / 1024, "MB"
            )

        # Show the generated model and the HDR lighting panel; leave the
        # rest of the UI untouched.
        return (
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(value=glb_file, visible=True),
            gr.update(visible=True),
        )
    elif run_btn == "Remove Background":
        rem_removed = remove_background(input_image)

        fr_res = sf3d_utils.resize_foreground(
            rem_removed, foreground_ratio, out_size=(COND_WIDTH, COND_HEIGHT)
        )

        # Switch the button to "Run", store both processed images, and show
        # the mask preview; hide any previous 3D output.
        return (
            gr.update(value="Run", visible=True),
            rem_removed,
            fr_res,
            gr.update(value=show_mask_img(fr_res), visible=True),
            gr.update(value=None, visible=False),
            gr.update(visible=False),
        )
179
+
180
+
181
def requires_bg_remove(image, fr):
    """Decide the next UI step after an image upload or clear.

    If the upload already contains transparency (any fully transparent
    pixel), background removal is skipped and the button becomes "Run";
    otherwise the button becomes "Remove Background".

    Returns the same 6-tuple of updates as ``run_button`` (order matters).
    """
    if image is None:
        # Image was cleared: hide everything again.
        return (
            gr.update(visible=False, value="Run"),
            None,
            None,
            gr.update(value=None, visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
        )
    alpha_channel = np.array(image.getchannel("A"))
    min_alpha = alpha_channel.min()

    # Any fully transparent pixel means the image already has a usable mask.
    if min_alpha == 0:
        print("Already has alpha")
        fr_res = sf3d_utils.resize_foreground(
            image, fr, out_size=(COND_WIDTH, COND_HEIGHT)
        )
        return (
            gr.update(value="Run", visible=True),
            image,
            fr_res,
            gr.update(value=show_mask_img(fr_res), visible=True),
            gr.update(visible=False),
            gr.update(visible=False),
        )
    return (
        gr.update(value="Remove Background", visible=True),
        None,
        None,
        gr.update(value=None, visible=False),
        gr.update(visible=False),
        gr.update(visible=False),
    )
215
+
216
+
217
def update_foreground_ratio(img_proc, fr):
    """Re-crop the processed image to the new foreground ratio and refresh
    the checkerboard mask preview."""
    resized = sf3d_utils.resize_foreground(
        img_proc, fr, out_size=(COND_WIDTH, COND_HEIGHT)
    )
    preview = gr.update(value=show_mask_img(resized))
    return (resized, preview)
225
+
226
+
227
# ---- Gradio UI -------------------------------------------------------------
with gr.Blocks() as demo:
    # States: the raw background-removed image, and the foreground-resized
    # version that is actually fed to SF3D.
    img_proc_state = gr.State()
    background_remove_state = gr.State()
    gr.Markdown("""
# SF3D: Stable Fast 3D Mesh Reconstruction with UV-unwrapping and Illumination Disentanglement

**SF3D** is a state-of-the-art method for 3D mesh reconstruction from a single image.
This demo allows you to upload an image and generate a 3D mesh model from it.

**Tips**
1. If the image already has an alpha channel, you can skip the background removal step.
2. You can adjust the foreground ratio to control the size of the foreground object. This can influence the shape
3. You can select the remeshing option to control the mesh topology. This can introduce artifacts in the mesh on thin surfaces and should be turned off in such cases.
4. You can upload your own HDR environment map to light the 3D model.
""")
    with gr.Row(variant="panel"):
        with gr.Column():
            with gr.Row():
                input_img = gr.Image(
                    type="pil", label="Input Image", sources="upload", image_mode="RGBA"
                )
                preview_removal = gr.Image(
                    label="Preview Background Removal",
                    type="pil",
                    image_mode="RGB",
                    interactive=False,
                    visible=False,
                )

            foreground_ratio = gr.Slider(
                label="Foreground Ratio",
                minimum=0.5,
                maximum=1.0,
                value=0.85,
                step=0.05,
            )

            # Re-crop the mask preview whenever the slider moves.
            foreground_ratio.change(
                update_foreground_ratio,
                inputs=[img_proc_state, foreground_ratio],
                outputs=[background_remove_state, preview_removal],
            )

            remesh_option = gr.Radio(
                choices=["None", "Triangle", "Quad"],
                label="Remeshing",
                value="None",
                visible=True,
            )

            vertex_count_slider = gr.Slider(
                label="Target Vertex Count",
                minimum=-1,
                maximum=20000,
                value=-1,
                visible=True,
            )

            texture_size = gr.Slider(
                label="Texture Size",
                minimum=512,
                maximum=2048,
                value=1024,
                step=256,
                visible=True,
            )

            run_btn = gr.Button("Run", variant="primary", visible=False)

        with gr.Column():
            output_3d = LitModel3D(
                label="3D Model",
                visible=False,
                clear_color=[0.0, 0.0, 0.0, 0.0],
                tonemapping="aces",
                contrast=1.0,
                scale=1.0,
            )
            # HDR panel is hidden until a mesh has been generated.
            with gr.Column(visible=False, scale=1.0) as hdr_row:
                gr.Markdown("""## HDR Environment Map

Select an HDR environment map to light the 3D model. You can also upload your own HDR environment maps.
""")

                with gr.Row():
                    hdr_illumination_file = gr.File(
                        label="HDR Env Map", file_types=[".hdr"], file_count="single"
                    )
                    example_hdris = [
                        os.path.join("demo_files/hdri", f)
                        for f in os.listdir("demo_files/hdri")
                    ]
                    hdr_illumination_example = gr.Examples(
                        examples=example_hdris,
                        inputs=hdr_illumination_file,
                    )

                # Swap the environment map on the 3D viewer when a file is
                # selected or cleared.
                hdr_illumination_file.change(
                    lambda x: gr.update(env_map=x.name if x is not None else None),
                    inputs=hdr_illumination_file,
                    outputs=[output_3d],
                )

    examples = gr.Examples(
        examples=example_files,
        inputs=input_img,
    )

    # Uploading/clearing an image decides whether the button reads
    # "Remove Background" or "Run" (see requires_bg_remove).
    input_img.change(
        requires_bg_remove,
        inputs=[input_img, foreground_ratio],
        outputs=[
            run_btn,
            img_proc_state,
            background_remove_state,
            preview_removal,
            output_3d,
            hdr_row,
        ],
    )

    # run_button returns updates in exactly this output order.
    run_btn.click(
        run_button,
        inputs=[
            run_btn,
            input_img,
            background_remove_state,
            foreground_ratio,
            remesh_option,
            vertex_count_slider,
            texture_size,
        ],
        outputs=[
            run_btn,
            img_proc_state,
            background_remove_state,
            preview_removal,
            output_3d,
            hdr_row,
        ],
    )

demo.queue().launch(share=False)
requirements-demo.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ gradio==4.41.0
2
+ gradio-litmodel3d==0.0.1
requirements-dev.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ ruff
2
+ pre-commit
requirements.txt ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch>=2.0.0
2
+ torchvision>=0.15.0
3
+ torchaudio>=2.0.0
4
+ einops>=0.6.0
5
+ jaxtyping>=0.2.0
6
+ omegaconf>=2.3.0
7
+ transformers>=4.30.0
8
+ open_clip_torch>=2.0.0
9
+ trimesh>=3.9.0
10
+ numpy>=1.24.0
11
+ huggingface-hub>=0.15.0
12
+ rembg>=2.0.0
13
+ onnxruntime>=1.14.0
14
+ pynanoinstantmeshes>=0.0.3
15
+ scipy>=1.10.0
16
+ scikit-image>=0.20.0
17
+ opencv-python-headless>=4.7.0
18
+ gradio>=3.50.0
19
+ Pillow>=9.0.0
20
+ # NOTE: do not depend on the PyPI "pathlib" package -- pathlib ships with the Python 3 standard library and the obsolete backport shadows it
ruff.toml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [lint]
2
+ ignore = ["F722"]
3
+ extend-select = ["I"]
run.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ from contextlib import nullcontext
4
+
5
+ import rembg
6
+ import torch
7
+ from PIL import Image
8
+ from tqdm import tqdm
9
+
10
+ from sf3d.system import SF3D
11
+ from sf3d.utils import get_device, remove_background, resize_foreground
12
+
13
if __name__ == "__main__":
    # ---- CLI -------------------------------------------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "image", type=str, nargs="+", help="Path to input image(s) or folder."
    )
    parser.add_argument(
        "--device",
        default=get_device(),
        type=str,
        help=f"Device to use. If no CUDA/MPS-compatible device is found, the baking will fail. Default: '{get_device()}'",
    )
    parser.add_argument(
        "--pretrained-model",
        default="stabilityai/stable-fast-3d",
        type=str,
        help="Path to the pretrained model. Could be either a huggingface model id is or a local path. Default: 'stabilityai/stable-fast-3d'",
    )
    parser.add_argument(
        "--foreground-ratio",
        default=0.85,
        type=float,
        help="Ratio of the foreground size to the image size. Only used when --no-remove-bg is not specified. Default: 0.85",
    )
    parser.add_argument(
        "--output-dir",
        default="output/",
        type=str,
        help="Output directory to save the results. Default: 'output/'",
    )
    parser.add_argument(
        "--texture-resolution",
        default=1024,
        type=int,
        help="Texture atlas resolution. Default: 1024",
    )
    parser.add_argument(
        "--remesh_option",
        choices=["none", "triangle", "quad"],
        default="none",
        help="Remeshing option",
    )
    parser.add_argument(
        "--target_vertex_count",
        type=int,
        help="Target vertex count. -1 does not perform a reduction.",
        default=-1,
    )
    parser.add_argument(
        "--batch_size", default=1, type=int, help="Batch size for inference"
    )
    args = parser.parse_args()

    # Ensure args.device contains cuda
    devices = ["cuda", "mps", "cpu"]
    if not any(args.device in device for device in devices):
        raise ValueError("Invalid device. Use cuda, mps or cpu")

    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)

    device = args.device
    # Fall back to CPU when no accelerator is actually available.
    if not (torch.cuda.is_available() or torch.backends.mps.is_available()):
        device = "cpu"

    print("Device used: ", device)

    # ---- Model -----------------------------------------------------------
    model = SF3D.from_pretrained(
        args.pretrained_model,
        config_name="config.yaml",
        weight_name="model.safetensors",
    )
    model.to(device)
    model.eval()

    # ---- Preprocess inputs ----------------------------------------------
    rembg_session = rembg.new_session()
    images = []
    idx = 0
    for image_path in args.image:
        # NOTE(review): this helper is re-defined on every loop iteration;
        # it could be hoisted above the loop.
        def handle_image(image_path, idx):
            # Remove the background, crop to the foreground, and save a copy
            # of the preprocessed input next to the future mesh output.
            image = remove_background(
                Image.open(image_path).convert("RGBA"), rembg_session
            )
            image = resize_foreground(image, args.foreground_ratio)
            os.makedirs(os.path.join(output_dir, str(idx)), exist_ok=True)
            image.save(os.path.join(output_dir, str(idx), "input.png"))
            images.append(image)

        if os.path.isdir(image_path):
            image_paths = [
                os.path.join(image_path, f)
                for f in os.listdir(image_path)
                if f.endswith((".png", ".jpg", ".jpeg"))
            ]
            for image_path in image_paths:
                handle_image(image_path, idx)
                idx += 1
        else:
            handle_image(image_path, idx)
            idx += 1

    # ---- Inference, one batch at a time ----------------------------------
    for i in tqdm(range(0, len(images), args.batch_size)):
        image = images[i : i + args.batch_size]
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()
        with torch.no_grad():
            # bfloat16 autocast only on CUDA; other devices run full precision.
            with torch.autocast(
                device_type=device, dtype=torch.bfloat16
            ) if "cuda" in device else nullcontext():
                mesh, glob_dict = model.run_image(
                    image,
                    bake_resolution=args.texture_resolution,
                    remesh=args.remesh_option,
                    vertex_count=args.target_vertex_count,
                )
        if torch.cuda.is_available():
            print("Peak Memory:", torch.cuda.max_memory_allocated() / 1024 / 1024, "MB")
        elif torch.backends.mps.is_available():
            print(
                "Peak Memory:", torch.mps.driver_allocated_memory() / 1024 / 1024, "MB"
            )

        # Export one GLB per input, in the same numbered folders as input.png.
        if len(image) == 1:
            out_mesh_path = os.path.join(output_dir, str(i), "mesh.glb")
            mesh.export(out_mesh_path, include_normals=True)
        else:
            for j in range(len(mesh)):
                out_mesh_path = os.path.join(output_dir, str(i + j), "mesh.glb")
                mesh[j].export(out_mesh_path, include_normals=True)
sf3d/__pycache__/system.cpython-313.pyc ADDED
Binary file (23.9 kB). View file
 
sf3d/models/__pycache__/isosurface.cpython-313.pyc ADDED
Binary file (10.9 kB). View file
 
sf3d/models/__pycache__/mesh.cpython-313.pyc ADDED
Binary file (14.1 kB). View file
 
sf3d/models/__pycache__/utils.cpython-313.pyc ADDED
Binary file (12.6 kB). View file
 
sf3d/models/camera.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from typing import List
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from sf3d.models.utils import BaseModule
8
+
9
+
10
class LinearCameraEmbedder(BaseModule):
    """Embeds flattened camera conditioning tensors with a single linear layer."""

    @dataclass
    class Config(BaseModule.Config):
        in_channels: int = 25
        out_channels: int = 768
        conditions: List[str] = field(default_factory=list)

    cfg: Config

    def configure(self) -> None:
        # Single affine projection of the concatenated camera parameters.
        self.linear = nn.Linear(self.cfg.in_channels, self.cfg.out_channels)

    def forward(self, **kwargs):
        # Each configured condition arrives as (B, Nv, ...); flatten the
        # trailing dims and concatenate along the feature axis.
        flattened = []
        for name in self.cfg.conditions:
            assert name in kwargs
            tensor = kwargs[name]
            flattened.append(tensor.view(*tensor.shape[:2], -1))
        stacked = torch.cat(flattened, dim=-1)
        assert stacked.shape[-1] == self.cfg.in_channels
        return self.linear(stacked)
sf3d/models/global_estimator/multi_head_estimator.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from typing import Any, List, Optional
3
+
4
+ import torch.nn as nn
5
+ from jaxtyping import Float
6
+ from torch import Tensor
7
+
8
+ from sf3d.models.network import get_activation
9
+ from sf3d.models.utils import BaseModule
10
+
11
+
12
@dataclass
class HeadSpec:
    """Configuration for one output head of MultiHeadEstimator."""

    name: str
    out_channels: int
    n_hidden_layers: int
    # Name of an activation applied to the head output (see get_activation).
    output_activation: Optional[str] = None
    # Constant added to the raw head output before the output activation.
    output_bias: float = 0.0
    # If True, the output key is prefixed with "decoder_".
    add_to_decoder_features: bool = False
    # Optional shape the flat output is reshaped to.
    shape: Optional[list[int]] = None
+
22
+
23
class MultiHeadEstimator(BaseModule):
    """Pools triplane features into a global vector and predicts per-head outputs."""

    @dataclass
    class Config(BaseModule.Config):
        # Per-plane feature count; the three planes are stacked along channels.
        triplane_features: int = 1024

        n_layers: int = 2
        hidden_features: int = 512
        activation: str = "relu"

        # Spatial pooling applied after the conv stack ("mean" or "max").
        pool: str = "max"
        # Literal["mean", "max"] = "mean" # noqa: F821

        heads: List[HeadSpec] = field(default_factory=lambda: [])

    cfg: Config

    def configure(self):
        # Strided conv stack that downsamples the stacked triplane
        # (3 planes concatenated along channels) into hidden feature maps.
        layers = []
        cur_features = self.cfg.triplane_features * 3
        for _ in range(self.cfg.n_layers):
            layers.append(
                nn.Conv2d(
                    cur_features,
                    self.cfg.hidden_features,
                    kernel_size=3,
                    padding=0,
                    stride=2,
                )
            )
            layers.append(self.make_activation(self.cfg.activation))

            cur_features = self.cfg.hidden_features

        self.layers = nn.Sequential(*layers)

        # One small MLP per configured head on top of the pooled vector.
        assert len(self.cfg.heads) > 0
        heads = {}
        for head in self.cfg.heads:
            head_layers = []
            for i in range(head.n_hidden_layers):
                head_layers += [
                    nn.Linear(
                        self.cfg.hidden_features,
                        self.cfg.hidden_features,
                    ),
                    self.make_activation(self.cfg.activation),
                ]
            head_layers += [
                nn.Linear(
                    self.cfg.hidden_features,
                    head.out_channels,
                ),
            ]
            heads[head.name] = nn.Sequential(*head_layers)
        self.heads = nn.ModuleDict(heads)

    def make_activation(self, activation):
        # Factory for the configured nonlinearity.
        if activation == "relu":
            return nn.ReLU(inplace=True)
        elif activation == "silu":
            return nn.SiLU(inplace=True)
        else:
            raise NotImplementedError

    def forward(
        self,
        triplane: Float[Tensor, "B 3 F Ht Wt"],
    ) -> dict[str, Any]:
        """Return one tensor per head; keys gain a "decoder_" prefix when
        ``add_to_decoder_features`` is set."""
        # Fold the plane axis into channels: (B, 3*F, Ht, Wt).
        x = self.layers(
            triplane.reshape(
                triplane.shape[0], -1, triplane.shape[-2], triplane.shape[-1]
            )
        )

        # Global spatial pooling down to (B, hidden_features).
        if self.cfg.pool == "max":
            x = x.amax(dim=[-2, -1])
        elif self.cfg.pool == "mean":
            x = x.mean(dim=[-2, -1])
        else:
            raise NotImplementedError

        # Apply each head, its output bias, and its output activation.
        out = {
            ("decoder_" if head.add_to_decoder_features else "")
            + head.name: get_activation(head.output_activation)(
                self.heads[head.name](x) + head.output_bias
            )
            for head in self.cfg.heads
        }
        # Optionally reshape flat head outputs to their configured shape.
        for head in self.cfg.heads:
            if head.shape:
                head_name = (
                    "decoder_" if head.add_to_decoder_features else ""
                ) + head.name
                out[head_name] = out[head_name].reshape(*head.shape)

        return out
sf3d/models/image_estimator/clip_based_estimator.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from typing import Any, List, Optional
3
+
4
+ import open_clip
5
+ import torch
6
+ import torch.nn as nn
7
+ from jaxtyping import Float
8
+ from torch import Tensor
9
+ from torchvision.transforms import Normalize
10
+
11
+ from sf3d.models.network import get_activation
12
+ from sf3d.models.utils import BaseModule
13
+
14
+
15
@dataclass
class HeadSpec:
    """Configuration for one output head of ClipBasedHeadEstimator.

    NOTE(review): duplicates the HeadSpec in multi_head_estimator.py; the
    two could share a single definition.
    """

    name: str
    out_channels: int
    n_hidden_layers: int
    # Name of an activation applied to the sampled output (see get_activation).
    output_activation: Optional[str] = None
    # Constant added to the raw distribution parameters before softplus.
    output_bias: float = 0.0
    # If True, the output key is prefixed with "decoder_".
    add_to_decoder_features: bool = False
    # Optional shape the sampled output is reshaped to.
    shape: Optional[list[int]] = None
24
+
25
+
26
class ClipBasedHeadEstimator(BaseModule):
    """Predicts global scalar quantities from the conditioning image.

    A frozen OpenCLIP image encoder produces features; each configured head
    maps them to the parameters of a probability distribution (beta or
    normal) from which the final value is taken (mode/mean/sample).
    """

    @dataclass
    class Config(BaseModule.Config):
        model: str = "ViT-B-32"
        pretrain: str = "laion2b_s34b_b79k"

        # Distribution family predicted by each head: "beta" or "normal".
        distribution: str = "beta"

        # ["mean", "mode", "sample", "sample_mean"]
        distribution_eval: str = "mode"

        activation: str = "relu"
        hidden_features: int = 512
        heads: List[HeadSpec] = field(default_factory=lambda: [])

    cfg: Config

    def configure(self):
        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
            self.cfg.model, pretrained=self.cfg.pretrain
        )
        self.model.eval()

        # Do not add the weights in self.model to the optimizer
        for param in self.model.parameters():
            param.requires_grad = False

        assert len(self.cfg.heads) > 0
        heads = {}
        for head in self.cfg.heads:
            head_layers = []

            for i in range(head.n_hidden_layers):
                head_layers += [
                    nn.Linear(
                        self.cfg.hidden_features,
                        self.cfg.hidden_features,
                    ),
                    self.make_activation(self.cfg.activation),
                ]

            # Each head is [shared trunk, branch for distribution param 1,
            # branch for distribution param 2].
            head_layers = [nn.Sequential(*head_layers)]
            head_layers += [
                nn.Sequential(
                    nn.Linear(
                        self.cfg.hidden_features,
                        self.cfg.hidden_features,
                    ),
                    self.make_activation(self.cfg.activation),
                    nn.Linear(self.cfg.hidden_features, 1),
                )
                for _ in range(2)
            ]
            heads[head.name] = nn.ModuleList(head_layers)
        self.heads = nn.ModuleDict(heads)

    def make_activation(self, activation):
        # Factory for the configured nonlinearity.
        if activation == "relu":
            return nn.ReLU(inplace=True)
        elif activation == "silu":
            return nn.SiLU(inplace=True)
        else:
            raise NotImplementedError

    def forward(
        self,
        cond_image: Float[Tensor, "B 1 H W 3"],
        sample: bool = True,
    ) -> dict[str, Any]:
        """Encode the image and return per-head values (and their
        ``{name}_dist`` distributions when ``sample`` is True)."""
        # Run the model
        # Resize cond_image to 224
        cond_image = nn.functional.interpolate(
            cond_image.flatten(0, 1).permute(0, 3, 1, 2).contiguous(),
            size=(224, 224),
            mode="bilinear",
            align_corners=False,
        )
        # Normalize with the OpenAI CLIP dataset statistics.
        cond_image = Normalize(
            mean=open_clip.constants.OPENAI_DATASET_MEAN,
            std=open_clip.constants.OPENAI_DATASET_STD,
        )(cond_image)
        image_features = self.model.encode_image(cond_image)

        # Run the heads
        outputs = {}

        for head_dict in self.cfg.heads:
            head_name = head_dict.name
            shared_head, d1_h, d2_h = self.heads[head_name]
            shared_features = shared_head(image_features)
            d1, d2 = [head(shared_features).squeeze(-1) for head in [d1_h, d2_h]]
            if self.cfg.distribution == "normal":
                mean = d1
                var = d2
                if mean.shape[-1] == 1:
                    outputs[head_name] = torch.distributions.Normal(
                        mean + head_dict.output_bias,
                        torch.nn.functional.softplus(var),
                    )
                else:
                    outputs[head_name] = torch.distributions.MultivariateNormal(
                        mean + head_dict.output_bias,
                        torch.nn.functional.softplus(var).diag_embed(),
                    )
            elif self.cfg.distribution == "beta":
                # softplus keeps both concentration parameters positive.
                outputs[head_name] = torch.distributions.Beta(
                    torch.nn.functional.softplus(d1 + head_dict.output_bias),
                    torch.nn.functional.softplus(d2 + head_dict.output_bias),
                )
            else:
                raise NotImplementedError

        if sample:
            # Collapse each distribution to a concrete value.
            for head_dict in self.cfg.heads:
                head_name = head_dict.name
                dist = outputs[head_name]

                if self.cfg.distribution_eval == "mean":
                    out = dist.mean
                elif self.cfg.distribution_eval == "mode":
                    out = dist.mode
                elif self.cfg.distribution_eval == "sample_mean":
                    # NOTE(review): mean(-1) reduces the last dim, not the
                    # 10-sample dim added by sample([10]) -- confirm intended.
                    out = dist.sample([10]).mean(-1)
                else:
                    # use rsample if gradient is needed
                    out = dist.rsample() if self.training else dist.sample()

                outputs[head_name] = get_activation(head_dict.output_activation)(out)
                outputs[f"{head_name}_dist"] = dist

        for head in self.cfg.heads:
            if head.shape:
                if not sample:
                    raise ValueError(
                        "Cannot reshape non-sampled probabilisitic outputs"
                    )
                outputs[head.name] = outputs[head.name].reshape(*head.shape)

            if head.add_to_decoder_features:
                # Rename so downstream code feeds these to the decoder.
                outputs[f"decoder_{head.name}"] = outputs[head.name]
                del outputs[head.name]

        return outputs
sf3d/models/isosurface.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Tuple
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ from jaxtyping import Float, Integer
7
+ from torch import Tensor
8
+
9
+ from .mesh import Mesh
10
+
11
+
12
class IsosurfaceHelper(nn.Module):
    """Base class for isosurface extraction helpers (e.g. marching tetrahedra)."""

    # Coordinate range spanned by the sampling grid vertices.
    points_range: Tuple[float, float] = (0, 1)

    @property
    def grid_vertices(self) -> Float[Tensor, "N 3"]:
        """Vertex positions of the underlying sampling grid."""
        raise NotImplementedError

    @property
    def requires_instance_per_batch(self) -> bool:
        """Whether a separate helper instance is needed per batch element."""
        return False
22
+
23
+
24
class MarchingTetrahedraHelper(IsosurfaceHelper):
    """Differentiable marching-tetrahedra isosurface extraction.

    Each tetrahedron's 4-bit SDF sign pattern indexes a lookup table to
    decide which of its 6 edges are crossed and which triangles to emit.
    The tetrahedral grid itself is loaded from a precomputed .npz file.
    """

    def __init__(self, resolution: int, tets_path: str):
        super().__init__()
        self.resolution = resolution
        # Path to an .npz containing "vertices" and "indices" of the tet grid.
        self.tets_path = tets_path

        # Maps each of the 16 occupancy patterns to up to 2 triangles; entries
        # are local edge ids, -1 marks an unused slot.
        self.triangle_table: Float[Tensor, "..."]
        self.register_buffer(
            "triangle_table",
            torch.as_tensor(
                [
                    [-1, -1, -1, -1, -1, -1],
                    [1, 0, 2, -1, -1, -1],
                    [4, 0, 3, -1, -1, -1],
                    [1, 4, 2, 1, 3, 4],
                    [3, 1, 5, -1, -1, -1],
                    [2, 3, 0, 2, 5, 3],
                    [1, 4, 0, 1, 5, 4],
                    [4, 2, 5, -1, -1, -1],
                    [4, 5, 2, -1, -1, -1],
                    [4, 1, 0, 4, 5, 1],
                    [3, 2, 0, 3, 5, 2],
                    [1, 3, 5, -1, -1, -1],
                    [4, 1, 2, 4, 3, 1],
                    [3, 0, 4, -1, -1, -1],
                    [2, 0, 1, -1, -1, -1],
                    [-1, -1, -1, -1, -1, -1],
                ],
                dtype=torch.long,
            ),
            persistent=False,
        )
        # Number of triangles emitted for each of the 16 occupancy patterns.
        self.num_triangles_table: Integer[Tensor, "..."]
        self.register_buffer(
            "num_triangles_table",
            torch.as_tensor(
                [0, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 1, 0], dtype=torch.long
            ),
            persistent=False,
        )
        # The 6 edges of a tetrahedron as pairs of local vertex indices.
        self.base_tet_edges: Integer[Tensor, "..."]
        self.register_buffer(
            "base_tet_edges",
            torch.as_tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long),
            persistent=False,
        )

        tets = np.load(self.tets_path)
        self._grid_vertices: Float[Tensor, "..."]
        self.register_buffer(
            "_grid_vertices",
            torch.from_numpy(tets["vertices"]).float(),
            persistent=False,
        )
        self.indices: Integer[Tensor, "..."]
        self.register_buffer(
            "indices", torch.from_numpy(tets["indices"]).long(), persistent=False
        )

        # Lazily computed unique-edge list (see all_edges property).
        self._all_edges: Optional[Integer[Tensor, "Ne 2"]] = None

        center_indices, boundary_indices = self.get_center_boundary_index(
            self._grid_vertices
        )
        self.center_indices: Integer[Tensor, "..."]
        self.register_buffer("center_indices", center_indices, persistent=False)
        self.boundary_indices: Integer[Tensor, "..."]
        self.register_buffer("boundary_indices", boundary_indices, persistent=False)

    def get_center_boundary_index(self, verts):
        """Return (index of vertex closest to origin, indices of vertices
        with any coordinate at the grid extremes)."""
        magn = torch.sum(verts**2, dim=-1)

        center_idx = torch.argmin(magn)
        # NOTE(review): "neg" compares against max() and "pos" against min();
        # the names look swapped, but the union below makes it harmless.
        boundary_neg = verts == verts.max()
        boundary_pos = verts == verts.min()

        boundary = torch.bitwise_or(boundary_pos, boundary_neg)
        boundary = torch.sum(boundary.float(), dim=-1)

        boundary_idx = torch.nonzero(boundary)
        return center_idx, boundary_idx.squeeze(dim=-1)

    def normalize_grid_deformation(
        self, grid_vertex_offsets: Float[Tensor, "Nv 3"]
    ) -> Float[Tensor, "Nv 3"]:
        # Bound per-vertex offsets to roughly one cell via tanh.
        return (
            (self.points_range[1] - self.points_range[0])
            / self.resolution  # half tet size is approximately 1 / self.resolution
            * torch.tanh(grid_vertex_offsets)
        )  # FIXME: hard-coded activation

    @property
    def grid_vertices(self) -> Float[Tensor, "Nv 3"]:
        return self._grid_vertices

    @property
    def all_edges(self) -> Integer[Tensor, "Ne 2"]:
        """Unique, direction-normalized edges of the tetrahedral grid."""
        if self._all_edges is None:
            # compute edges on GPU, or it would be VERY SLOW (basically due to the unique operation)
            edges = torch.tensor(
                [0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3],
                dtype=torch.long,
                device=self.indices.device,
            )
            _all_edges = self.indices[:, edges].reshape(-1, 2)
            _all_edges_sorted = torch.sort(_all_edges, dim=1)[0]
            _all_edges = torch.unique(_all_edges_sorted, dim=0)
            self._all_edges = _all_edges
        return self._all_edges

    def sort_edges(self, edges_ex2):
        # Canonicalize each edge so the smaller vertex index comes first.
        with torch.no_grad():
            order = (edges_ex2[:, 0] > edges_ex2[:, 1]).long()
            order = order.unsqueeze(dim=1)

            a = torch.gather(input=edges_ex2, index=order, dim=1)
            b = torch.gather(input=edges_ex2, index=1 - order, dim=1)

        return torch.stack([a, b], -1)

    def _forward(self, pos_nx3, sdf_n, tet_fx4):
        """Core marching-tetrahedra: returns (verts, faces) of the level set."""
        with torch.no_grad():
            occ_n = sdf_n > 0
            occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4)
            occ_sum = torch.sum(occ_fx4, -1)
            # Keep only tets whose vertices straddle the surface.
            valid_tets = (occ_sum > 0) & (occ_sum < 4)
            occ_sum = occ_sum[valid_tets]

            # find all vertices
            all_edges = tet_fx4[valid_tets][:, self.base_tet_edges].reshape(-1, 2)
            all_edges = self.sort_edges(all_edges)
            unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True)

            # An edge produces a vertex iff its endpoints have opposite signs.
            unique_edges = unique_edges.long()
            mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1
            mapping = (
                torch.ones(
                    (unique_edges.shape[0]), dtype=torch.long, device=pos_nx3.device
                )
                * -1
            )
            mapping[mask_edges] = torch.arange(
                mask_edges.sum(), dtype=torch.long, device=pos_nx3.device
            )
            idx_map = mapping[idx_map]  # map edges to verts

            interp_v = unique_edges[mask_edges]
        # Differentiable linear interpolation of the zero crossing on each
        # sign-changing edge (gradients flow through positions and SDF).
        edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1, 2, 3)
        edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1, 2, 1)
        edges_to_interp_sdf[:, -1] *= -1

        denominator = edges_to_interp_sdf.sum(1, keepdim=True)

        edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator
        verts = (edges_to_interp * edges_to_interp_sdf).sum(1)

        idx_map = idx_map.reshape(-1, 6)

        # 4-bit occupancy pattern per tet indexes the triangle tables.
        v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device=pos_nx3.device))
        tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1)
        num_triangles = self.num_triangles_table[tetindex]

        # Generate triangle indices
        faces = torch.cat(
            (
                torch.gather(
                    input=idx_map[num_triangles == 1],
                    dim=1,
                    index=self.triangle_table[tetindex[num_triangles == 1]][:, :3],
                ).reshape(-1, 3),
                torch.gather(
                    input=idx_map[num_triangles == 2],
                    dim=1,
                    index=self.triangle_table[tetindex[num_triangles == 2]][:, :6],
                ).reshape(-1, 3),
            ),
            dim=0,
        )

        return verts, faces

    def forward(
        self,
        level: Float[Tensor, "N3 1"],
        deformation: Optional[Float[Tensor, "N3 3"]] = None,
    ) -> Mesh:
        """Extract the zero level set, optionally deforming the grid first."""
        if deformation is not None:
            grid_vertices = self.grid_vertices + self.normalize_grid_deformation(
                deformation
            )
        else:
            grid_vertices = self.grid_vertices

        v_pos, t_pos_idx = self._forward(grid_vertices, level, self.indices)

        mesh = Mesh(
            v_pos=v_pos,
            t_pos_idx=t_pos_idx,
            # extras
            grid_vertices=grid_vertices,
            tet_edges=self.all_edges,
            grid_level=level,
            grid_deformation=deformation,
        )

        return mesh
sf3d/models/mesh.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import math
4
+ from typing import Any, Dict, Optional
5
+
6
+ import gpytoolbox
7
+ import numpy as np
8
+ import pynanoinstantmeshes
9
+ import torch
10
+ import torch.nn.functional as F
11
+ import trimesh
12
+ from jaxtyping import Float, Integer
13
+ from torch import Tensor
14
+
15
+ from sf3d.models.utils import dot
16
+
17
+ try:
18
+ from uv_unwrapper import Unwrapper
19
+ except ImportError:
20
+ import logging
21
+
22
+ logging.warning(
23
+ "Could not import uv_unwrapper. Please install it via `pip install uv_unwrapper/`"
24
+ )
25
+ # Exit early to avoid further errors
26
+ raise ImportError("uv_unwrapper not found")
27
+
28
+
29
class Mesh:
    """Triangle mesh with lazily-computed, cached derived attributes.

    Vertex normals (``v_nrm``), tangents (``v_tng``), UVs (``v_tex``) and
    unique edges (``edges``) are computed on first property access and cached
    in the corresponding underscore-prefixed slots. Arbitrary extra tensors
    (e.g. the tet grid the mesh was extracted from) live in ``extras``.
    """

    def __init__(
        self, v_pos: Float[Tensor, "Nv 3"], t_pos_idx: Integer[Tensor, "Nf 3"], **kwargs
    ) -> None:
        # Core geometry: vertex positions and per-triangle vertex indices.
        self.v_pos: Float[Tensor, "Nv 3"] = v_pos
        self.t_pos_idx: Integer[Tensor, "Nf 3"] = t_pos_idx
        # Lazy caches, filled by the properties below.
        self._v_nrm: Optional[Float[Tensor, "Nv 3"]] = None
        self._v_tng: Optional[Float[Tensor, "Nv 3"]] = None
        self._v_tex: Optional[Float[Tensor, "Nt 3"]] = None
        self._edges: Optional[Integer[Tensor, "Ne 2"]] = None
        self.extras: Dict[str, Any] = {}
        for k, v in kwargs.items():
            self.add_extra(k, v)

        self.unwrapper = Unwrapper()

    def add_extra(self, k, v) -> None:
        # Attach an auxiliary value under extras[k].
        self.extras[k] = v

    @property
    def requires_grad(self):
        # The mesh participates in autograd iff its vertex positions do.
        return self.v_pos.requires_grad

    @property
    def v_nrm(self):
        # Per-vertex normals (area-weighted average of incident face normals).
        if self._v_nrm is None:
            self._v_nrm = self._compute_vertex_normal()
        return self._v_nrm

    @property
    def v_tng(self):
        # Per-vertex tangents; accessing this may trigger UV unwrapping.
        if self._v_tng is None:
            self._v_tng = self._compute_vertex_tangent()
        return self._v_tng

    @property
    def v_tex(self):
        # Per-vertex UVs. NOTE: unwrap_uv() mutates v_pos/t_pos_idx (vertices
        # are duplicated along UV seams) on first access.
        if self._v_tex is None:
            self.unwrap_uv()
        return self._v_tex

    @property
    def edges(self):
        # Unique undirected edges of the triangulation.
        if self._edges is None:
            self._edges = self._compute_edges()
        return self._edges

    def _compute_vertex_normal(self):
        """Compute per-vertex normals by scatter-adding face normals.

        The unnormalized cross product's magnitude is proportional to the
        triangle area, so the accumulation is implicitly area-weighted.
        """
        i0 = self.t_pos_idx[:, 0]
        i1 = self.t_pos_idx[:, 1]
        i2 = self.t_pos_idx[:, 2]

        v0 = self.v_pos[i0, :]
        v1 = self.v_pos[i1, :]
        v2 = self.v_pos[i2, :]

        face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)

        # Splat face normals to vertices
        v_nrm = torch.zeros_like(self.v_pos)
        v_nrm.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals)
        v_nrm.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals)
        v_nrm.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals)

        # Normalize, replace zero (degenerated) normals with some default value
        v_nrm = torch.where(
            dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.as_tensor([0.0, 0.0, 1.0]).to(v_nrm)
        )
        v_nrm = F.normalize(v_nrm, dim=1)

        if torch.is_anomaly_enabled():
            assert torch.all(torch.isfinite(v_nrm))

        return v_nrm

    def _compute_vertex_tangent(self):
        """Compute per-vertex tangents from UV gradients (requires v_tex).

        Standard tangent-space construction: solve for the direction in which
        u increases across each triangle, accumulate per vertex, then
        Gram-Schmidt against the vertex normal.
        """
        vn_idx = [None] * 3
        pos = [None] * 3
        tex = [None] * 3
        for i in range(0, 3):
            pos[i] = self.v_pos[self.t_pos_idx[:, i]]
            tex[i] = self.v_tex[self.t_pos_idx[:, i]]
            # t_nrm_idx is always the same as t_pos_idx
            vn_idx[i] = self.t_pos_idx[:, i]

        tangents = torch.zeros_like(self.v_nrm)
        tansum = torch.zeros_like(self.v_nrm)

        # Compute tangent space for each triangle
        duv1 = tex[1] - tex[0]
        duv2 = tex[2] - tex[0]
        dpos1 = pos[1] - pos[0]
        dpos2 = pos[2] - pos[0]

        tng_nom = dpos1 * duv2[..., 1:2] - dpos2 * duv1[..., 1:2]

        denom = duv1[..., 0:1] * duv2[..., 1:2] - duv1[..., 1:2] * duv2[..., 0:1]

        # Avoid division by zero for degenerated texture coordinates
        # NOTE(review): clip(1e-6) also forces negative determinants (mirrored
        # UVs) positive, flipping the tangent sign for those faces — confirm
        # whether mirrored islands can occur here.
        denom_safe = denom.clip(1e-6)
        tang = tng_nom / denom_safe

        # Update all 3 vertices
        for i in range(0, 3):
            idx = vn_idx[i][:, None].repeat(1, 3)
            tangents.scatter_add_(0, idx, tang)  # tangents[n_i] = tangents[n_i] + tang
            tansum.scatter_add_(
                0, idx, torch.ones_like(tang)
            )  # tansum[n_i] = tansum[n_i] + 1
        # Also normalize it. Here we do not normalize the individual triangles first so larger area
        # triangles influence the tangent space more
        tangents = tangents / tansum

        # Normalize and make sure tangent is perpendicular to normal
        tangents = F.normalize(tangents, dim=1)
        tangents = F.normalize(tangents - dot(tangents, self.v_nrm) * self.v_nrm)

        if torch.is_anomaly_enabled():
            assert torch.all(torch.isfinite(tangents))

        return tangents

    def quad_remesh(
        self,
        quad_vertex_count: int = -1,
        quad_rosy: int = 4,
        quad_crease_angle: float = -1.0,
        quad_smooth_iter: int = 2,
        quad_align_to_boundaries: bool = False,
    ) -> Mesh:
        """Quad-dominant remesh via pynanoinstantmeshes (Instant Meshes).

        Args:
            quad_vertex_count: target vertex count; < 0 means "keep current".
            quad_rosy: rotational symmetry of the orientation field.
            quad_crease_angle: crease detection angle; negative disables.
            quad_smooth_iter: smoothing iterations.
            quad_align_to_boundaries: align the field to open boundaries.

        Returns:
            A new Mesh (self is not modified).
        """
        if quad_vertex_count < 0:
            quad_vertex_count = self.v_pos.shape[0]
        v_pos = self.v_pos.detach().cpu().numpy().astype(np.float32)
        t_pos_idx = self.t_pos_idx.detach().cpu().numpy().astype(np.uint32)

        # Target is divided by 4 — presumably because posy=4 quad extraction
        # multiplies the coarse vertex count back up; TODO confirm against
        # pynanoinstantmeshes docs.
        new_vert, new_faces = pynanoinstantmeshes.remesh(
            v_pos,
            t_pos_idx,
            quad_vertex_count // 4,
            rosy=quad_rosy,
            posy=4,
            creaseAngle=quad_crease_angle,
            align_to_boundaries=quad_align_to_boundaries,
            smooth_iter=quad_smooth_iter,
            deterministic=False,
        )

        # Briefly load in trimesh
        mesh = trimesh.Trimesh(vertices=new_vert, faces=new_faces.astype(np.int32))

        v_pos = torch.from_numpy(mesh.vertices).to(self.v_pos).contiguous()
        t_pos_idx = torch.from_numpy(mesh.faces).to(self.t_pos_idx).contiguous()

        # Create new mesh
        return Mesh(v_pos, t_pos_idx)

    def triangle_remesh(
        self,
        triangle_average_edge_length_multiplier: Optional[float] = None,
        triangle_remesh_steps: int = 10,
        triangle_vertex_count=-1,
    ):
        """Isotropic triangle remesh (Botsch-Kobbelt) via gpytoolbox.

        If ``triangle_vertex_count`` > 0 the mesh is first subdivided and/or
        decimated (in place on self) toward that count before remeshing.

        Args:
            triangle_average_edge_length_multiplier: target edge length as a
                multiple of the current mean edge length; None lets
                remesh_botsch pick its default.
            triangle_remesh_steps: number of remeshing iterations.
            triangle_vertex_count: optional vertex-count target.

        Returns:
            A new Mesh with the remeshed geometry.
        """
        if triangle_vertex_count > 0:
            reduction = triangle_vertex_count / self.v_pos.shape[0]
            print("Triangle reduction:", reduction)
            v_pos = self.v_pos.detach().cpu().numpy().astype(np.float32)
            t_pos_idx = self.t_pos_idx.detach().cpu().numpy().astype(np.int32)
            if reduction > 1.0:
                # Need more vertices: each subdivision roughly doubles the
                # count, so take ceil(log2(reduction)) rounds, then decimate
                # back down to the exact target.
                subdivide_iters = int(math.ceil(math.log(reduction) / math.log(2)))
                print("Subdivide iters:", subdivide_iters)
                v_pos, t_pos_idx = gpytoolbox.subdivide(
                    v_pos,
                    t_pos_idx,
                    iters=subdivide_iters,
                )
                reduction = triangle_vertex_count / v_pos.shape[0]

            # Simplify
            points_out, faces_out, _, _ = gpytoolbox.decimate(
                v_pos,
                t_pos_idx,
                face_ratio=reduction,
            )

            # Convert back to torch
            self.v_pos = torch.from_numpy(points_out).to(self.v_pos)
            self.t_pos_idx = torch.from_numpy(faces_out).to(self.t_pos_idx)
            self._edges = None
            # Force h=None below so remesh_botsch keeps the post-decimation
            # resolution instead of retargeting the edge length.
            triangle_average_edge_length_multiplier = None

        edges = self.edges
        if triangle_average_edge_length_multiplier is None:
            h = None
        else:
            # Target edge length = mean current edge length * multiplier.
            h = float(
                torch.linalg.norm(
                    self.v_pos[edges[:, 0]] - self.v_pos[edges[:, 1]], dim=1
                )
                .mean()
                .item()
                * triangle_average_edge_length_multiplier
            )

        # Convert to numpy
        v_pos = self.v_pos.detach().cpu().numpy().astype(np.float64)
        t_pos_idx = self.t_pos_idx.detach().cpu().numpy().astype(np.int32)

        # Remesh
        v_remesh, f_remesh = gpytoolbox.remesh_botsch(
            v_pos,
            t_pos_idx,
            triangle_remesh_steps,
            h,
        )

        # Convert back to torch
        v_pos = torch.from_numpy(v_remesh).to(self.v_pos).contiguous()
        t_pos_idx = torch.from_numpy(f_remesh).to(self.t_pos_idx).contiguous()

        # Create new mesh
        return Mesh(v_pos, t_pos_idx)

    @torch.no_grad()
    def unwrap_uv(
        self,
        island_padding: float = 0.02,
    ) -> None:
        """Unwrap UVs in place, splitting vertices along seams.

        Mutates v_pos/t_pos_idx (every triangle gets its own three vertices)
        and fills the _v_tex cache; normals and tangents are recomputed for
        the duplicated vertices. Returns nothing.

        Args:
            island_padding: padding between UV islands, in UV units.
        """
        uv, indices = self.unwrapper(
            self.v_pos, self.v_nrm, self.t_pos_idx, island_padding
        )

        # Do store per vertex UVs.
        # This means we need to duplicate some vertices at the seams
        individual_vertices = self.v_pos[self.t_pos_idx].reshape(-1, 3)
        individual_faces = torch.arange(
            individual_vertices.shape[0],
            device=individual_vertices.device,
            dtype=self.t_pos_idx.dtype,
        ).reshape(-1, 3)
        uv_flat = uv[indices].reshape((-1, 2))
        # uv_flat[:, 1] = 1 - uv_flat[:, 1]

        self.v_pos = individual_vertices
        self.t_pos_idx = individual_faces
        self._v_tex = uv_flat
        self._v_nrm = self._compute_vertex_normal()
        self._v_tng = self._compute_vertex_tangent()

    def _compute_edges(self):
        """Return the sorted, de-duplicated undirected edge list."""
        # Compute edges
        edges = torch.cat(
            [
                self.t_pos_idx[:, [0, 1]],
                self.t_pos_idx[:, [1, 2]],
                self.t_pos_idx[:, [2, 0]],
            ],
            dim=0,
        )
        # Sort each edge's endpoints so (a,b) and (b,a) collapse together.
        edges = edges.sort()[0]
        edges = torch.unique(edges, dim=0)
        return edges
sf3d/models/network.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from typing import Callable, List, Optional
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from einops import rearrange
8
+ from jaxtyping import Float
9
+ from torch import Tensor
10
+ from torch.amp import custom_bwd, custom_fwd
11
+ from torch.autograd import Function
12
+
13
+ from sf3d.models.utils import BaseModule, normalize
14
+ from sf3d.utils import get_device
15
+
16
+
17
def conditional_decorator(decorator_with_args, condition, *args, **kwargs):
    """Apply a decorator only when ``condition`` is truthy.

    If ``condition`` is falsy the target function is returned unchanged.
    Otherwise ``decorator_with_args`` is applied: directly (``decorator(fn)``)
    when no extra arguments were supplied, or as a decorator factory
    (``decorator(*args, **kwargs)(fn)``) when they were.

    Args:
        decorator_with_args: a decorator, or a decorator factory.
        condition: whether to apply the decorator at all.
        *args, **kwargs: forwarded to the factory when present.

    Returns:
        The (possibly decorated) function.
    """

    def wrapper(fn):
        if not condition:
            return fn
        if args or kwargs:
            # Factory style, e.g. custom_fwd(cast_inputs=..., device_type=...).
            return decorator_with_args(*args, **kwargs)(fn)
        # Plain decorator style. Bug fix: the previous code returned the
        # decorator object itself here (unapplied), silently replacing the
        # target function instead of wrapping it.
        return decorator_with_args(fn)

    return wrapper
27
+
28
+
29
class PixelShuffleUpsampleNetwork(BaseModule):
    """Upsamples triplane feature maps with a conv stack + PixelShuffle.

    All three planes are folded into the batch dimension, pushed through a
    small CNN whose last layer emits ``out_channels * scale_factor**2``
    channels, and spatially expanded by ``nn.PixelShuffle``.
    """

    @dataclass
    class Config(BaseModule.Config):
        in_channels: int = 1024
        out_channels: int = 40
        scale_factor: int = 4

        conv_layers: int = 4
        conv_kernel_size: int = 3

    cfg: Config

    def configure(self) -> None:
        # Channels produced by the final conv; PixelShuffle folds the
        # scale_factor**2 factor into spatial resolution.
        final_channels = self.cfg.out_channels * self.cfg.scale_factor**2
        same_padding = (self.cfg.conv_kernel_size - 1) // 2

        modules = []
        for layer_idx in range(self.cfg.conv_layers):
            is_last = layer_idx == self.cfg.conv_layers - 1
            modules.append(
                nn.Conv2d(
                    self.cfg.in_channels,
                    final_channels if is_last else self.cfg.in_channels,
                    self.cfg.conv_kernel_size,
                    padding=same_padding,
                )
            )
            # ReLU between convs, but not after the final projection.
            if not is_last:
                modules.append(nn.ReLU(inplace=True))

        modules.append(nn.PixelShuffle(self.cfg.scale_factor))
        self.upsample = nn.Sequential(*modules)

    def forward(
        self, triplanes: Float[Tensor, "B 3 Ci Hp Wp"]
    ) -> Float[Tensor, "B 3 Co Hp2 Wp2"]:
        """Upsample each of the 3 planes independently (planes share weights)."""
        flattened = rearrange(triplanes, "B Np Ci Hp Wp -> (B Np) Ci Hp Wp", Np=3)
        upsampled = self.upsample(flattened)
        return rearrange(upsampled, "(B Np) Co Hp Wp -> B Np Co Hp Wp", Np=3)
75
+
76
+
77
class _TruncExp(Function):  # pylint: disable=abstract-method
    """exp() whose gradient clamps the saved input at 15 to avoid overflow.

    Implementation from torch-ngp:
    https://github.com/ashawkey/torch-ngp/blob/93b08a0d4ec1cc6e69d85df7f0acdfb99603b628/activation.py

    The AMP decorators are applied only when running on CUDA (custom_fwd /
    custom_bwd are CUDA-autocast specific).
    """

    @staticmethod
    @conditional_decorator(
        custom_fwd,
        "cuda" in get_device(),
        cast_inputs=torch.float32,
        device_type="cuda",
    )
    def forward(ctx, x):  # pylint: disable=arguments-differ
        ctx.save_for_backward(x)
        return torch.exp(x)

    @staticmethod
    # Fix: torch.amp.custom_bwd requires the device_type keyword (the forward
    # decorator already passes it); without it the decorator application on
    # the CUDA path was malformed.
    @conditional_decorator(custom_bwd, "cuda" in get_device(), device_type="cuda")
    def backward(ctx, g):  # pylint: disable=arguments-differ
        x = ctx.saved_tensors[0]
        # Clamp so exp() cannot overflow when computing the gradient.
        return g * torch.exp(torch.clamp(x, max=15))
96
+
97
+
98
# Callable alias so the autograd op can be used like a plain function.
trunc_exp = _TruncExp.apply
99
+
100
+
101
def get_activation(name) -> Callable:
    """Resolve an activation function by (case-insensitive) name.

    Falls back to looking the name up on ``torch.nn.functional``; raises
    ``ValueError`` for names that exist nowhere.
    """
    if name is None:
        return lambda x: x

    name = name.lower()

    if name in ("none", "linear", "identity"):
        return lambda x: x
    if name == "lin2srgb":
        # Linear -> sRGB transfer curve, clamped to [0, 1].
        return lambda x: torch.where(
            x > 0.0031308,
            torch.pow(torch.clamp(x, min=0.0031308), 1.0 / 2.4) * 1.055 - 0.055,
            12.92 * x,
        ).clamp(0.0, 1.0)
    if name == "exp":
        return torch.exp
    if name == "shifted_exp":
        return lambda x: torch.exp(x - 1.0)
    if name == "trunc_exp":
        return trunc_exp
    if name == "shifted_trunc_exp":
        return lambda x: trunc_exp(x - 1.0)
    if name == "sigmoid":
        return torch.sigmoid
    if name == "tanh":
        return torch.tanh
    if name == "shifted_softplus":
        return lambda x: F.softplus(x - 1.0)
    if name == "scale_-11_01":
        # Map [-1, 1] to [0, 1].
        return lambda x: x * 0.5 + 0.5
    if name == "negative":
        return lambda x: -x
    if name == "normalize_channel_last":
        return lambda x: normalize(x)
    if name == "normalize_channel_first":
        return lambda x: normalize(x, dim=1)

    # Last resort: anything torch.nn.functional provides (relu, gelu, ...).
    try:
        return getattr(F, name)
    except AttributeError:
        raise ValueError(f"Unknown activation function: {name}")
140
+
141
+
142
@dataclass
class HeadSpec:
    """Declarative spec for one output head of MaterialMLP."""

    # Key under which this head's output appears in the forward() dict.
    name: str
    # Width of the head's final linear layer.
    out_channels: int
    # Number of (Linear + activation) hidden blocks before the output layer.
    n_hidden_layers: int
    # Output activation name, resolved via get_activation(); None = identity.
    output_activation: Optional[str] = None
    # Constant added to the raw head output *before* the output activation.
    out_bias: float = 0.0
150
+
151
class MaterialMLP(BaseModule):
    """Bundle of small per-property MLP heads over a shared input feature.

    Each ``HeadSpec`` in ``cfg.heads`` becomes an independent MLP; ``forward``
    returns a dict mapping head name to its activated output.
    """

    @dataclass
    class Config(BaseModule.Config):
        in_channels: int = 120
        n_neurons: int = 64
        activation: str = "silu"
        heads: List[HeadSpec] = field(default_factory=lambda: [])

    cfg: Config

    def configure(self) -> None:
        """Build one nn.Sequential per configured head."""
        assert len(self.cfg.heads) > 0
        heads = {}
        for head in self.cfg.heads:
            head_layers = []
            # Hidden stack: Linear + nonlinearity, repeated n_hidden_layers times.
            for i in range(head.n_hidden_layers):
                head_layers += [
                    nn.Linear(
                        self.cfg.in_channels if i == 0 else self.cfg.n_neurons,
                        self.cfg.n_neurons,
                    ),
                    self.make_activation(self.cfg.activation),
                ]
            # Output projection; the head's output_activation is applied later
            # in forward(), after out_bias is added.
            head_layers += [
                nn.Linear(
                    self.cfg.n_neurons,
                    head.out_channels,
                ),
            ]
            heads[head.name] = nn.Sequential(*head_layers)
        self.heads = nn.ModuleDict(heads)

    def make_activation(self, activation):
        # Hidden-layer nonlinearity factory; only relu/silu are supported.
        if activation == "relu":
            return nn.ReLU(inplace=True)
        elif activation == "silu":
            return nn.SiLU(inplace=True)
        else:
            raise NotImplementedError

    def keys(self):
        # Names of all configured heads.
        return self.heads.keys()

    def forward(
        self, x, include: Optional[List] = None, exclude: Optional[List] = None
    ):
        """Run the selected heads on ``x``.

        Args:
            x: shared input features for every head.
            include: if given, run only heads whose name is in this list.
            exclude: if given, run all heads except these.

        Returns:
            Dict of head name -> activated output.

        Raises:
            ValueError: if both ``include`` and ``exclude`` are given.
        """
        if include is not None and exclude is not None:
            raise ValueError("Cannot specify both include and exclude.")
        if include is not None:
            heads = [h for h in self.cfg.heads if h.name in include]
        elif exclude is not None:
            heads = [h for h in self.cfg.heads if h.name not in exclude]
        else:
            heads = self.cfg.heads

        # out_bias shifts the raw output before the output activation.
        out = {
            head.name: get_activation(head.output_activation)(
                self.heads[head.name](x) + head.out_bias
            )
            for head in heads
        }

        return out
sf3d/models/tokenizers/dinov2.py ADDED
@@ -0,0 +1,1196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch DINOv2 model."""
16
+
17
+ import collections.abc
18
+ import math
19
+ from dataclasses import dataclass
20
+ from typing import Dict, List, Optional, Set, Tuple, Union
21
+
22
+ import torch
23
+ import torch.nn.functional as F
24
+ import torch.utils.checkpoint
25
+ from torch import nn
26
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
27
+ from transformers.activations import ACT2FN
28
+ from transformers.modeling_outputs import (
29
+ BackboneOutput,
30
+ BaseModelOutput,
31
+ BaseModelOutputWithPooling,
32
+ ImageClassifierOutput,
33
+ )
34
+ from transformers.modeling_utils import PreTrainedModel
35
+ from transformers.models.dinov2.configuration_dinov2 import Dinov2Config
36
+ from transformers.pytorch_utils import (
37
+ find_pruneable_heads_and_indices,
38
+ prune_linear_layer,
39
+ )
40
+ from transformers.utils import (
41
+ add_code_sample_docstrings,
42
+ add_start_docstrings,
43
+ add_start_docstrings_to_model_forward,
44
+ logging,
45
+ replace_return_docstrings,
46
+ )
47
+ from transformers.utils.backbone_utils import BackboneMixin
48
+
49
+ logger = logging.get_logger(__name__)
50
+
51
+ # General docstring
52
+ _CONFIG_FOR_DOC = "Dinov2Config"
53
+
54
+ # Base docstring
55
+ _CHECKPOINT_FOR_DOC = "facebook/dinov2-base"
56
+ _EXPECTED_OUTPUT_SHAPE = [1, 257, 768]
57
+
58
+ # Image classification docstring
59
+ _IMAGE_CLASS_CHECKPOINT = "facebook/dinov2-base"
60
+
61
+
62
+ DINOV2_PRETRAINED_MODEL_ARCHIVE_LIST = [
63
+ "facebook/dinov2-base",
64
+ # See all DINOv2 models at https://huggingface.co/models?filter=dinov2
65
+ ]
66
+
67
+
68
class Dinov2Embeddings(nn.Module):
    """
    Construct the CLS token, mask token, position and patch embeddings.
    """

    def __init__(self, config: Dinov2Config) -> None:
        super().__init__()

        self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
        # register as mask token as it's not used in optimization
        # to avoid the use of find_unused_parameters_true
        # self.mask_token = nn.Parameter(torch.zeros(1, config.hidden_size))
        self.register_buffer("mask_token", torch.zeros(1, config.hidden_size))
        self.patch_embeddings = Dinov2PatchEmbeddings(config)
        num_patches = self.patch_embeddings.num_patches
        # +1 position for the CLS token.
        self.position_embeddings = nn.Parameter(
            torch.randn(1, num_patches + 1, config.hidden_size)
        )
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.config = config

    def interpolate_pos_encoding(
        self, embeddings: torch.Tensor, height: int, width: int
    ) -> torch.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
        resolution images.

        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
        """

        num_patches = embeddings.shape[1] - 1
        num_positions = self.position_embeddings.shape[1] - 1
        # Fast path: the input matches the pretraining grid, no resampling needed.
        if num_patches == num_positions and height == width:
            return self.position_embeddings
        # CLS position embedding is kept as-is; only the patch grid is resampled.
        class_pos_embed = self.position_embeddings[:, 0]
        patch_pos_embed = self.position_embeddings[:, 1:]
        dim = embeddings.shape[-1]
        # Convert pixel dims to patch-grid dims.
        height = height // self.config.patch_size
        width = width // self.config.patch_size
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        height, width = height + 0.1, width + 0.1
        # Pretrained grid is assumed square: sqrt(num_positions) per side.
        patch_pos_embed = patch_pos_embed.reshape(
            1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim
        )
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            scale_factor=(
                height / math.sqrt(num_positions),
                width / math.sqrt(num_positions),
            ),
            mode="bicubic",
            align_corners=False,
        )
        if (
            int(height) != patch_pos_embed.shape[-2]
            or int(width) != patch_pos_embed.shape[-1]
        ):
            raise ValueError(
                "Width or height does not match with the interpolated position embeddings"
            )
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

    def forward(
        self,
        pixel_values: torch.Tensor,
        bool_masked_pos: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Embed pixels into patch tokens, optionally masking patches, then
        prepend CLS and add (interpolated) position embeddings."""
        batch_size, _, height, width = pixel_values.shape
        patch_embeddings = self.patch_embeddings(pixel_values)
        embeddings = patch_embeddings

        if bool_masked_pos is not None:
            # Replace masked patch tokens with the (frozen) mask token.
            embeddings = torch.where(
                bool_masked_pos.unsqueeze(-1),
                self.mask_token.to(embeddings.dtype).unsqueeze(0),
                embeddings,
            )

        # add the [CLS] token to the embedded patch tokens
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        embeddings = torch.cat((cls_tokens, embeddings), dim=1)

        # add positional encoding to each token
        embeddings = embeddings + self.interpolate_pos_encoding(
            embeddings, height, width
        )

        embeddings = self.dropout(embeddings)

        return embeddings
163
+
164
+
165
class Dinov2PatchEmbeddings(nn.Module):
    """
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    """

    def __init__(self, config):
        super().__init__()
        image_size, patch_size = config.image_size, config.patch_size
        num_channels, hidden_size = config.num_channels, config.hidden_size

        # Normalize scalar sizes to (h, w) tuples.
        image_size = (
            image_size
            if isinstance(image_size, collections.abc.Iterable)
            else (image_size, image_size)
        )
        patch_size = (
            patch_size
            if isinstance(patch_size, collections.abc.Iterable)
            else (patch_size, patch_size)
        )
        num_patches = (image_size[1] // patch_size[1]) * (
            image_size[0] // patch_size[0]
        )
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.num_patches = num_patches

        # Non-overlapping conv = per-patch linear projection.
        self.projection = nn.Conv2d(
            num_channels, hidden_size, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        # NOTE(review): the upstream channel-count validation below was
        # deliberately disabled (kept as a string literal), presumably to
        # allow inputs whose channel count differs from the config — confirm.
        """
        num_channels = pixel_values.shape[1]
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
                f" Expected {self.num_channels} but got {num_channels}."
            )
        """
        # (B, C, H, W) -> (B, hidden, H/p, W/p) -> (B, num_patches, hidden)
        embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
        return embeddings
210
+
211
+
212
+ # Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->Dinov2
213
class Dinov2SelfAttention(nn.Module):
    """Multi-head self-attention for DINOv2.

    Uses `F.scaled_dot_product_attention` (fused/flash kernels) when the
    running torch provides it; otherwise falls back to the explicit
    matmul + softmax implementation. The fused path cannot return attention
    maps or apply a head mask.
    """

    def __init__(self, config: Dinov2Config) -> None:
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
            config, "embedding_size"
        ):
            # Fix: drop the stray trailing comma that made this print a tuple,
            # e.g. "(768,)" instead of "768".
            raise ValueError(
                f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
                f"heads {config.num_attention_heads}."
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.attention_probs_dropout_prob = config.attention_probs_dropout_prob

        self.query = nn.Linear(
            config.hidden_size, self.all_head_size, bias=config.qkv_bias
        )
        self.key = nn.Linear(
            config.hidden_size, self.all_head_size, bias=config.qkv_bias
        )
        self.value = nn.Linear(
            config.hidden_size, self.all_head_size, bias=config.qkv_bias
        )

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape (B, T, H*D) -> (B, H, T, D) for per-head attention."""
        new_x_shape = x.size()[:-1] + (
            self.num_attention_heads,
            self.attention_head_size,
        )
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        """Run self-attention.

        Returns a 1-tuple of the context tensor, or a 2-tuple including the
        attention probabilities when `output_attentions` is set (explicit
        path only).
        """
        mixed_query_layer = self.query(hidden_states)

        if hasattr(F, "scaled_dot_product_attention"):
            # The fused kernel cannot expose attention maps or take a mask.
            assert head_mask is None and not output_attentions
            new_size = hidden_states.size()[:-1] + (
                self.num_attention_heads,
                self.attention_head_size,
            )
            key_layer = self.key(hidden_states).reshape(new_size).transpose(1, 2)
            value_layer = self.value(hidden_states).reshape(new_size).transpose(1, 2)
            query_layer = mixed_query_layer.reshape(new_size).transpose(1, 2)
            context_layer = F.scaled_dot_product_attention(
                query_layer,
                key_layer,
                value_layer,
                # Fix: only apply attention dropout while training. The
                # previous code passed the dropout prob unconditionally, so
                # attention weights were being dropped at inference too
                # (the explicit branch below uses nn.Dropout, which is
                # correctly disabled by eval()).
                dropout_p=self.attention_probs_dropout_prob if self.training else 0.0,
                is_causal=False,
            )
            context_layer = context_layer.transpose(1, 2).reshape(
                *hidden_states.size()[:-1], -1
            )
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            query_layer = self.transpose_for_scores(mixed_query_layer)

            # Take the dot product between "query" and "key" to get the raw attention scores.
            attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

            attention_scores = attention_scores / math.sqrt(self.attention_head_size)

            # Normalize the attention scores to probabilities.
            attention_probs = nn.functional.softmax(attention_scores, dim=-1)

            # This is actually dropping out entire tokens to attend to, which might
            # seem a bit unusual, but is taken from the original Transformer paper.
            attention_probs = self.dropout(attention_probs)

            # Mask heads if we want to
            if head_mask is not None:
                attention_probs = attention_probs * head_mask

            context_layer = torch.matmul(attention_probs, value_layer)

            context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
            new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
            context_layer = context_layer.view(new_context_layer_shape)

        outputs = (
            (context_layer, attention_probs) if output_attentions else (context_layer,)
        )

        return outputs
308
+
309
+
310
+ # Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Dinov2
311
class Dinov2SelfOutput(nn.Module):
    """
    The residual connection is defined in Dinov2Layer instead of here (as is the case with other models), due to the
    layernorm applied before each block.
    """

    def __init__(self, config: Dinov2Config) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
    ) -> torch.Tensor:
        # `input_tensor` is intentionally unused here — the signature mirrors
        # the ViT implementation; the residual add happens in the caller.
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)

        return hidden_states
329
+
330
+
331
+ # Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->Dinov2
332
class Dinov2Attention(nn.Module):
    """Self-attention + output projection, with support for head pruning."""

    def __init__(self, config: Dinov2Config) -> None:
        super().__init__()
        self.attention = Dinov2SelfAttention(config)
        self.output = Dinov2SelfOutput(config)
        # Indices of heads removed by prune_heads(); kept so repeated pruning
        # calls index the remaining heads correctly.
        self.pruned_heads = set()

    def prune_heads(self, heads: Set[int]) -> None:
        """Remove the given attention heads and shrink the linear layers."""
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads,
            self.attention.num_attention_heads,
            self.attention.attention_head_size,
            self.pruned_heads,
        )

        # Prune linear layers
        self.attention.query = prune_linear_layer(self.attention.query, index)
        self.attention.key = prune_linear_layer(self.attention.key, index)
        self.attention.value = prune_linear_layer(self.attention.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.attention.num_attention_heads = self.attention.num_attention_heads - len(
            heads
        )
        self.attention.all_head_size = (
            self.attention.attention_head_size * self.attention.num_attention_heads
        )
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        """Run self-attention followed by the output projection."""
        self_outputs = self.attention(hidden_states, head_mask, output_attentions)

        attention_output = self.output(self_outputs[0], hidden_states)

        outputs = (attention_output,) + self_outputs[
            1:
        ]  # add attentions if we output them
        return outputs
378
+
379
+
380
class Dinov2LayerScale(nn.Module):
    """Learnable per-channel scaling (LayerScale) for a residual branch."""

    def __init__(self, config) -> None:
        super().__init__()
        # Initialized to a small constant so each block starts near identity.
        scale_init = config.layerscale_value * torch.ones(config.hidden_size)
        self.lambda1 = nn.Parameter(scale_init)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        scaled = hidden_state * self.lambda1
        return scaled
389
+
390
+
391
# Copied from transformers.models.beit.modeling_beit.drop_path
def drop_path(
    input: torch.Tensor, drop_prob: float = 0.0, training: bool = False
) -> torch.Tensor:
    """Stochastic Depth: randomly zero whole samples on the main residual path.

    As noted by Ross Wightman, this matches the DropConnect-style impl used in
    EfficientNet-like networks, but 'drop path' is the accurate name ('Drop
    Connect' is a different technique from another paper); see
    https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956.

    Returns the input unchanged when ``drop_prob == 0`` or not training;
    otherwise rescales survivors by ``1 / keep_prob`` to keep expectations equal.
    """
    if drop_prob == 0.0 or not training:
        return input
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast across all remaining dims
    # (works for any tensor rank, not just 2D ConvNet activations).
    mask_shape = (input.shape[0],) + (1,) * (input.ndim - 1)
    mask = keep_prob + torch.rand(mask_shape, dtype=input.dtype, device=input.device)
    mask.floor_()  # binarize: 1 with prob keep_prob, else 0
    return input.div(keep_prob) * mask
416
+
417
+
418
# Copied from transformers.models.beit.modeling_beit.BeitDropPath
class Dinov2DropPath(nn.Module):
    """Module wrapper around :func:`drop_path` (Stochastic Depth per sample)."""

    def __init__(self, drop_prob: Optional[float] = None) -> None:
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Only active while training; identity in eval mode.
        return drop_path(hidden_states, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return f"p={self.drop_prob}"
431
+
432
+
433
class Dinov2MLP(nn.Module):
    """Standard transformer feed-forward block: Linear -> activation -> Linear."""

    def __init__(self, config) -> None:
        super().__init__()
        width = config.hidden_size
        inner = int(config.hidden_size * config.mlp_ratio)
        self.fc1 = nn.Linear(width, inner, bias=True)
        # ``hidden_act`` may be an activation name (looked up in ACT2FN) or a
        # callable supplied directly.
        self.activation = (
            ACT2FN[config.hidden_act]
            if isinstance(config.hidden_act, str)
            else config.hidden_act
        )
        self.fc2 = nn.Linear(inner, width, bias=True)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        return self.fc2(self.activation(self.fc1(hidden_state)))
450
+
451
+
452
class Dinov2SwiGLUFFN(nn.Module):
    """SwiGLU feed-forward block: a SiLU-gated linear unit between two projections."""

    def __init__(self, config) -> None:
        super().__init__()
        width = config.hidden_size
        inner = int(config.hidden_size * config.mlp_ratio)
        # SwiGLU convention: shrink to 2/3 of the MLP width, then round up to
        # the next multiple of 8 for hardware-friendly sizes.
        inner = (int(inner * 2 / 3) + 7) // 8 * 8

        self.weights_in = nn.Linear(width, 2 * inner, bias=True)
        self.weights_out = nn.Linear(inner, width, bias=True)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        # Single input projection produces both the gate and the value halves.
        gate, value = self.weights_in(hidden_state).chunk(2, dim=-1)
        return self.weights_out(nn.functional.silu(gate) * value)
467
+
468
+
469
class Dinov2Layer(nn.Module):
    """One DINOv2 transformer block (the Block class in the original implementation).

    Pre-norm architecture: norm1 -> attention -> LayerScale -> residual, then
    norm2 -> MLP/SwiGLU -> LayerScale -> residual. Optional conditioning
    modules (AdaNorm-style) can be attached after each norm via
    ``register_ada_norm_modulation``.
    """

    def __init__(self, config: Dinov2Config) -> None:
        super().__init__()

        self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # Conditioning hook applied right after norm1; None until registered.
        self.norm1_modulation = None
        self.attention = Dinov2Attention(config)
        self.layer_scale1 = Dinov2LayerScale(config)
        # NOTE(review): drop_path1/drop_path2 are constructed here but never
        # called in forward() below, so stochastic depth is effectively a
        # no-op in this copy — confirm that is intentional.
        self.drop_path1 = (
            Dinov2DropPath(config.drop_path_rate)
            if config.drop_path_rate > 0.0
            else nn.Identity()
        )

        self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.norm2_modulation = None

        # FFN flavor is selected by config: SwiGLU or the plain two-layer MLP.
        if config.use_swiglu_ffn:
            self.mlp = Dinov2SwiGLUFFN(config)
        else:
            self.mlp = Dinov2MLP(config)
        self.layer_scale2 = Dinov2LayerScale(config)
        self.drop_path2 = (
            Dinov2DropPath(config.drop_path_rate)
            if config.drop_path_rate > 0.0
            else nn.Identity()
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        modulation_cond: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        # Pre-norm, optionally modulated by the conditioning vector.
        hidden_states_norm = self.norm1(hidden_states)
        if self.norm1_modulation is not None:
            assert modulation_cond is not None
            hidden_states_norm = self.norm1_modulation(
                hidden_states_norm, modulation_cond
            )
        self_attention_outputs = self.attention(
            hidden_states_norm,  # in Dinov2, layernorm is applied before self-attention
            head_mask,
            output_attentions=output_attentions,
        )
        attention_output = self_attention_outputs[0]

        attention_output = self.layer_scale1(attention_output)
        outputs = self_attention_outputs[
            1:
        ]  # add self attentions if we output attention weights

        # first residual connection
        hidden_states = attention_output + hidden_states

        # in Dinov2, layernorm is also applied after self-attention
        layer_output = self.norm2(hidden_states)
        if self.norm2_modulation is not None:
            assert modulation_cond is not None
            layer_output = self.norm2_modulation(layer_output, modulation_cond)
        layer_output = self.mlp(layer_output)
        layer_output = self.layer_scale2(layer_output)

        # second residual connection
        layer_output = layer_output + hidden_states

        outputs = (layer_output,) + outputs

        return outputs

    def register_ada_norm_modulation(self, norm1_mod: nn.Module, norm2_mod: nn.Module):
        # Attach conditioning modules (e.g. AdaLN) applied after each norm;
        # once set, forward() requires a non-None ``modulation_cond``.
        self.norm1_modulation = norm1_mod
        self.norm2_modulation = norm2_mod
545
+
546
+
547
# Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->Dinov2
class Dinov2Encoder(nn.Module):
    """Stack of ``Dinov2Layer`` blocks with optional gradient checkpointing."""

    def __init__(self, config: Dinov2Config) -> None:
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList(
            [Dinov2Layer(config) for _ in range(config.num_hidden_layers)]
        )
        # Toggled externally via Dinov2PreTrainedModel._set_gradient_checkpointing.
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        modulation_cond: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> Union[tuple, BaseModelOutput]:
        """Run all layers sequentially.

        ``head_mask`` is indexed per layer; ``modulation_cond`` is forwarded to
        each layer's registered AdaNorm modulation (if any). Hidden states and
        attentions are collected only when the corresponding flags are set.
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                # Record the input to this layer (output of the previous one).
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None

            if self.gradient_checkpointing and self.training:
                # Recompute activations during backward to save memory. The
                # wrapper closes over ``output_attentions`` because
                # checkpoint() only forwards positional tensor inputs.
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    layer_head_mask,
                    modulation_cond,
                    use_reentrant=False,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states, layer_head_mask, modulation_cond, output_attentions
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            # Append the final layer's output as the last entry.
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            # Tuple form: drop the None entries for unrequested outputs.
            return tuple(
                v
                for v in [hidden_states, all_hidden_states, all_self_attentions]
                if v is not None
            )
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
614
+
615
+
616
class Dinov2PreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = Dinov2Config
    base_model_prefix = "dinov2"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True

    def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
            # `trunc_normal_cpu` not implemented in `half` issues
            module.weight.data = nn.init.trunc_normal_(
                module.weight.data.to(torch.float32),
                mean=0.0,
                std=self.config.initializer_range,
            ).to(module.weight.dtype)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            # LayerNorm: unit scale, zero shift.
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, Dinov2Embeddings):
            # Truncated-normal init for the position table and CLS token,
            # using the same fp32 round-trip as above.
            module.position_embeddings.data = nn.init.trunc_normal_(
                module.position_embeddings.data.to(torch.float32),
                mean=0.0,
                std=self.config.initializer_range,
            ).to(module.position_embeddings.dtype)

            module.cls_token.data = nn.init.trunc_normal_(
                module.cls_token.data.to(torch.float32),
                mean=0.0,
                std=self.config.initializer_range,
            ).to(module.cls_token.dtype)

    def _set_gradient_checkpointing(
        self, module: Dinov2Encoder, value: bool = False
    ) -> None:
        # Only the encoder supports checkpointing; other modules are ignored.
        if isinstance(module, Dinov2Encoder):
            module.gradient_checkpointing = value
660
+
661
+
662
# Reusable docstring fragments injected into the model classes and their
# forward() methods by the `add_start_docstrings*` decorators further below.
DINOV2_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`Dinov2Config`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

DINOV2_BASE_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`BitImageProcessor.preprocess`] for details.

        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, sequence_length)`):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). Only relevant for
            pre-training.

        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

DINOV2_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`BitImageProcessor.preprocess`] for details.

        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
720
+
721
+
722
@dataclass
class CustomBaseModelOutputWithPooling(BaseModelOutputWithPooling):
    # Extends the standard pooled output with the raw (pre-encoder) patch
    # embedding sequence, which Dinov2Model.forward populates below.
    patch_embeddings: Optional[torch.FloatTensor] = None
725
+
726
+
727
@add_start_docstrings(
    "The bare DINOv2 Model transformer outputting raw hidden-states without any specific head on top.",
    DINOV2_START_DOCSTRING,
)
class Dinov2Model(Dinov2PreTrainedModel):
    def __init__(self, config: Dinov2Config):
        super().__init__(config)
        self.config = config

        self.embeddings = Dinov2Embeddings(config)
        self.encoder = Dinov2Encoder(config)

        # Final norm applied to the encoder output (pre-norm architecture).
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> Dinov2PatchEmbeddings:
        return self.embeddings.patch_embeddings

    def expand_input_channels(self, extra_input_channels: int) -> None:
        """Widen the patch-embedding conv to accept extra input channels.

        The weights for the original channels are copied over; the newly
        added channels keep the fresh conv's random initialization.
        """
        if extra_input_channels == 0:
            return
        conv_old = self.embeddings.patch_embeddings.projection
        conv_new = nn.Conv2d(
            self.config.num_channels + extra_input_channels,
            self.config.hidden_size,
            kernel_size=self.config.patch_size,
            stride=self.config.patch_size,
        ).to(self.device)
        with torch.no_grad():
            # NOTE(review): `:3` hardcodes 3 original channels — presumably
            # config.num_channels == 3 here; confirm, else this silently
            # copies the wrong slice.
            conv_new.weight[:, :3] = conv_old.weight
            conv_new.bias = conv_old.bias
        self.embeddings.patch_embeddings.projection = conv_new
        del conv_old

    def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(DINOV2_BASE_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPooling,
        config_class=_CONFIG_FOR_DOC,
        modality="vision",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        bool_masked_pos: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        modulation_cond: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        # Fall back to config defaults for unspecified output flags.
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            pixel_values, bool_masked_pos=bool_masked_pos
        )

        encoder_outputs = self.encoder(
            embedding_output,
            head_mask=head_mask,
            modulation_cond=modulation_cond,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        sequence_output = self.layernorm(sequence_output)
        # Pooled output = final [CLS] token embedding (position 0).
        pooled_output = sequence_output[:, 0, :]

        if not return_dict:
            head_outputs = (sequence_output, pooled_output)
            return head_outputs + encoder_outputs[1:]

        # Custom output type additionally exposes the raw patch embeddings.
        return CustomBaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            patch_embeddings=embedding_output,
        )

    def set_gradient_checkpointing(self, value: bool = False) -> None:
        # Public toggle; delegates to the PreTrainedModel helper on the encoder.
        self._set_gradient_checkpointing(self.encoder, value)
843
+
844
+
845
@add_start_docstrings(
    """
    Dinov2 Model transformer with an image classification head on top (a linear layer on top of the final hidden state
    of the [CLS] token) e.g. for ImageNet.
    """,
    DINOV2_START_DOCSTRING,
)
class Dinov2ForImageClassification(Dinov2PreTrainedModel):
    def __init__(self, config: Dinov2Config) -> None:
        super().__init__(config)

        self.num_labels = config.num_labels
        self.dinov2 = Dinov2Model(config)

        # Classifier head: input is [CLS] concatenated with the mean patch
        # token, hence hidden_size * 2. Identity when there are no labels.
        self.classifier = (
            nn.Linear(config.hidden_size * 2, config.num_labels)
            if config.num_labels > 0
            else nn.Identity()
        )

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(DINOV2_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=ImageClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, ImageClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.dinov2(
            pixel_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]  # batch_size, sequence_length, hidden_size

        # Split the sequence into the [CLS] token and the patch tokens.
        cls_token = sequence_output[:, 0]
        patch_tokens = sequence_output[:, 1:]

        # Classifier input: [CLS] embedding concatenated with mean patch token.
        linear_input = torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=1)

        logits = self.classifier(linear_input)

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(logits.device)
            # Infer the problem type once from label count/dtype, then cache
            # it on the config so subsequent calls reuse the same loss.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (
                    labels.dtype == torch.long or labels.dtype == torch.int
                ):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
948
+
949
+
950
@add_start_docstrings(
    """
    Dinov2 backbone, to be used with frameworks like DETR and MaskFormer.
    """,
    DINOV2_START_DOCSTRING,
)
class Dinov2Backbone(Dinov2PreTrainedModel, BackboneMixin):
    def __init__(self, config):
        super().__init__(config)
        super()._init_backbone(config)

        # ViT-style backbone: every stage has the same feature width.
        self.num_features = [
            config.hidden_size for _ in range(config.num_hidden_layers + 1)
        ]
        self.embeddings = Dinov2Embeddings(config)
        self.encoder = Dinov2Encoder(config)

        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> Dinov2PatchEmbeddings:
        return self.embeddings.patch_embeddings

    @add_start_docstrings_to_model_forward(DINOV2_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: torch.Tensor,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> BackboneOutput:
        """
        Returns:

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, AutoBackbone
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
        >>> model = AutoBackbone.from_pretrained(
        ...     "facebook/dinov2-base", out_features=["stage2", "stage5", "stage8", "stage11"]
        ... )

        >>> inputs = processor(image, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> feature_maps = outputs.feature_maps
        >>> list(feature_maps[-1].shape)
        [1, 768, 16, 16]
        ```"""
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )

        embedding_output = self.embeddings(pixel_values)

        # Hidden states are always requested internally: they are the stage
        # outputs from which the feature maps below are selected.
        outputs = self.encoder(
            embedding_output,
            output_hidden_states=True,
            output_attentions=output_attentions,
            return_dict=return_dict,
        )

        hidden_states = outputs.hidden_states if return_dict else outputs[1]

        feature_maps = ()
        for stage, hidden_state in zip(self.stage_names, hidden_states):
            if stage in self.out_features:
                if self.config.apply_layernorm:
                    hidden_state = self.layernorm(hidden_state)
                if self.config.reshape_hidden_states:
                    batch_size, _, height, width = pixel_values.shape
                    patch_size = self.config.patch_size
                    # Drop the [CLS] token, then fold the patch sequence back
                    # into a 2D grid and move channels first.
                    # NOTE(review): the reshape uses (width, height) order;
                    # for non-square inputs confirm this matches the patch
                    # layout (upstream HF uses height first).
                    hidden_state = hidden_state[:, 1:, :].reshape(
                        batch_size, width // patch_size, height // patch_size, -1
                    )
                    hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
                feature_maps += (hidden_state,)

        if not return_dict:
            if output_hidden_states:
                output = (feature_maps,) + outputs[1:]
            else:
                output = (feature_maps,) + outputs[2:]
            return output

        return BackboneOutput(
            feature_maps=feature_maps,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=outputs.attentions if output_attentions else None,
        )
1061
+
1062
+
1063
class CustomPatchEmbeddings(nn.Module):
    """
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    """

    def __init__(
        self, image_size: int, patch_size: int, num_channels: int, hidden_size: int
    ):
        super().__init__()

        def _as_pair(value):
            # Accept either a scalar or an (h, w) iterable.
            return (
                value
                if isinstance(value, collections.abc.Iterable)
                else (value, value)
            )

        image_size = _as_pair(image_size)
        patch_size = _as_pair(patch_size)

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.num_patches = (image_size[1] // patch_size[1]) * (
            image_size[0] // patch_size[0]
        )

        # Non-overlapping patch projection: kernel size == stride == patch size.
        self.projection = nn.Conv2d(
            num_channels, hidden_size, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        channels = pixel_values.shape[1]
        if channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
                f" Expected {self.num_channels} but got {channels}."
            )
        # (B, C, H, W) -> (B, hidden, H/p, W/p) -> (B, num_patches, hidden)
        return self.projection(pixel_values).flatten(2).transpose(1, 2)
1106
+
1107
+
1108
class CustomEmbeddings(nn.Module):
    """
    Construct the CLS token, position and patch embeddings.
    """

    def __init__(
        self, image_size: int, patch_size: int, num_channels: int, hidden_size: int
    ) -> None:
        super().__init__()

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.hidden_size = hidden_size

        # Learned classification token prepended to the patch sequence.
        self.cls_token = nn.Parameter(torch.randn(1, 1, self.hidden_size))

        self.patch_embeddings = CustomPatchEmbeddings(
            image_size, patch_size, num_channels, hidden_size
        )
        num_patches = self.patch_embeddings.num_patches
        # One position per patch plus one for the CLS token.
        self.position_embeddings = nn.Parameter(
            torch.randn(1, num_patches + 1, self.hidden_size)
        )

    def interpolate_pos_encoding(
        self, embeddings: torch.Tensor, height: int, width: int
    ) -> torch.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
        resolution images.

        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
        """

        num_patches = embeddings.shape[1] - 1
        num_positions = self.position_embeddings.shape[1] - 1
        # Fast path: resolution matches the learned table, nothing to do.
        if num_patches == num_positions and height == width:
            return self.position_embeddings
        # Keep the CLS position as-is; only the patch grid is interpolated.
        class_pos_embed = self.position_embeddings[:, 0]
        patch_pos_embed = self.position_embeddings[:, 1:]
        dim = embeddings.shape[-1]
        height = height // self.patch_size
        width = width // self.patch_size
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        height, width = height + 0.1, width + 0.1
        # Assumes the stored table is a square grid of sqrt(num_positions)^2.
        patch_pos_embed = patch_pos_embed.reshape(
            1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim
        )
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            scale_factor=(
                height / math.sqrt(num_positions),
                width / math.sqrt(num_positions),
            ),
            mode="bicubic",
            align_corners=False,
        )
        # Sanity check that the interpolated grid has the expected size.
        if (
            int(height) != patch_pos_embed.shape[-2]
            or int(width) != patch_pos_embed.shape[-1]
        ):
            raise ValueError(
                "Width or height does not match with the interpolated position embeddings"
            )
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

    def forward(
        self,
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        batch_size, _, height, width = pixel_values.shape
        patch_embeddings = self.patch_embeddings(pixel_values)
        embeddings = patch_embeddings

        # add the [CLS] token to the embedded patch tokens
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        embeddings = torch.cat((cls_tokens, embeddings), dim=1)

        # add positional encoding to each token
        embeddings = embeddings + self.interpolate_pos_encoding(
            embeddings, height, width
        )

        return embeddings
sf3d/models/tokenizers/image.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Optional
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from einops import rearrange
7
+ from jaxtyping import Float
8
+ from torch import Tensor
9
+
10
+ from sf3d.models.tokenizers.dinov2 import Dinov2Model
11
+ from sf3d.models.transformers.attention import Modulation
12
+ from sf3d.models.utils import BaseModule
13
+
14
+
15
class DINOV2SingleImageTokenizer(BaseModule):
    """Tokenizes images with a frozen DINOv2 backbone whose layer norms are
    steered by per-layer AdaNorm `Modulation` modules driven by a condition
    vector."""

    @dataclass
    class Config(BaseModule.Config):
        pretrained_model_name_or_path: str = "facebook/dinov2-large"
        width: int = 512
        height: int = 512
        modulation_cond_dim: int = 768

    cfg: Config

    def configure(self) -> None:
        self.model = Dinov2Model.from_pretrained(self.cfg.pretrained_model_name_or_path)

        # The DINOv2 backbone stays frozen; only the modulation layers train.
        for param in self.model.parameters():
            param.requires_grad_(False)
        self.model.eval()

        self.model.set_gradient_checkpointing(False)

        # Attach a pair of zero-initialized modulation layers to each encoder
        # block (one per norm), so they start as identity mappings.
        modulations = []
        for layer in self.model.encoder.layer:
            mod_pair = [
                Modulation(
                    self.model.config.hidden_size,
                    self.cfg.modulation_cond_dim,
                    zero_init=True,
                    single_layer=True,
                )
                for _ in range(2)
            ]
            layer.register_ada_norm_modulation(mod_pair[0], mod_pair[1])
            modulations.extend(mod_pair)
        self.modulations = nn.ModuleList(modulations)

        # ImageNet normalization constants, broadcastable over (B, N, C, H, W).
        self.register_buffer(
            "image_mean",
            torch.as_tensor([0.485, 0.456, 0.406]).reshape(1, 1, 3, 1, 1),
            persistent=False,
        )
        self.register_buffer(
            "image_std",
            torch.as_tensor([0.229, 0.224, 0.225]).reshape(1, 1, 3, 1, 1),
            persistent=False,
        )

    def forward(
        self,
        images: Float[Tensor, "B *N C H W"],
        modulation_cond: Optional[Float[Tensor, "B *N Cc"]],
        **kwargs,
    ) -> Float[Tensor, "B *N Ct Nt"]:
        model = self.model

        # Accept both (B, C, H, W) and (B, N, C, H, W); remember which.
        packed = images.ndim == 4
        if packed:
            images = images.unsqueeze(1)
            if modulation_cond is not None:
                assert modulation_cond.ndim == 2
                modulation_cond = modulation_cond.unsqueeze(1)

        batch_size, n_input_views = images.shape[:2]
        images = (images - self.image_mean) / self.image_std

        cond_flat = (
            rearrange(modulation_cond, "B N Cc -> (B N) Cc")
            if modulation_cond is not None
            else None
        )
        out = model(
            rearrange(images, "B N C H W -> (B N) C H W"),
            modulation_cond=cond_flat,
        )

        # (BN, Nt, Ct) -> (B, N, Ct, Nt)
        local_features = out.last_hidden_state.permute(0, 2, 1)
        local_features = rearrange(
            local_features, "(B N) Ct Nt -> B N Ct Nt", B=batch_size
        )
        if packed:
            local_features = local_features.squeeze(1)

        return local_features

    def detokenize(self, *args, **kwargs):
        # Image tokens cannot be mapped back to pixels.
        raise NotImplementedError
sf3d/models/tokenizers/triplane.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from dataclasses import dataclass
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from einops import rearrange, repeat
7
+ from jaxtyping import Float
8
+ from torch import Tensor
9
+
10
+ from sf3d.models.utils import BaseModule
11
+
12
+
13
class TriplaneLearnablePositionalEmbedding(BaseModule):
    """A learnable triplane (3 x C x H x W) flattened into a token sequence
    for the transformer; `detokenize` folds tokens back into plane maps."""

    @dataclass
    class Config(BaseModule.Config):
        plane_size: int = 96
        num_channels: int = 1024

    cfg: Config

    def configure(self) -> None:
        # Scale the init by 1/sqrt(C) so token magnitudes start near unit size.
        shape = (3, self.cfg.num_channels, self.cfg.plane_size, self.cfg.plane_size)
        self.embeddings = nn.Parameter(
            torch.randn(shape, dtype=torch.float32) / math.sqrt(self.cfg.num_channels)
        )

    def forward(self, batch_size: int) -> Float[Tensor, "B Ct Nt"]:
        """Return the triplane tiled to `batch_size` and flattened to (B, C, 3*H*W)."""
        tiled = repeat(self.embeddings, "Np Ct Hp Wp -> B Np Ct Hp Wp", B=batch_size)
        return rearrange(tiled, "B Np Ct Hp Wp -> B Ct (Np Hp Wp)")

    def detokenize(
        self, tokens: Float[Tensor, "B Ct Nt"]
    ) -> Float[Tensor, "B 3 Ct Hp Wp"]:
        """Fold a flat token sequence back into per-plane feature maps."""
        batch_size, Ct, Nt = tokens.shape
        assert Nt == self.cfg.plane_size**2 * 3
        assert Ct == self.cfg.num_channels
        return rearrange(
            tokens,
            "B Ct (Np Hp Wp) -> B Np Ct Hp Wp",
            Np=3,
            Hp=self.cfg.plane_size,
            Wp=self.cfg.plane_size,
        )
sf3d/models/transformers/attention.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
class Modulation(nn.Module):
    """FiLM-style scale/shift modulation of token features from a condition.

    The condition is passed through an (optional) hidden linear layer and a
    SiLU, then projected to a per-sample scale and shift that are broadcast
    over the sequence dimension of `x`.
    """

    def __init__(
        self,
        embedding_dim: int,
        condition_dim: int,
        zero_init: bool = False,
        single_layer: bool = False,
    ):
        super().__init__()
        self.silu = nn.SiLU()
        # With `single_layer`, the condition feeds the output projection directly.
        self.linear1 = (
            nn.Identity() if single_layer else nn.Linear(condition_dim, condition_dim)
        )
        self.linear2 = nn.Linear(condition_dim, embedding_dim * 2)

        # Zero-initializing only the last layer makes the module start out as
        # an exact identity mapping (scale=0, shift=0).
        if zero_init:
            nn.init.zeros_(self.linear2.weight)
            nn.init.zeros_(self.linear2.bias)

    def forward(self, x: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
        emb = self.linear2(self.silu(self.linear1(condition)))
        scale, shift = emb.chunk(2, dim=1)
        return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
sf3d/models/transformers/backbone.py ADDED
@@ -0,0 +1,515 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Optional
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torch import nn
7
+
8
+ from sf3d.models.utils import BaseModule
9
+
10
+
11
class GEGLU(nn.Module):
    r"""
    A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        # One projection produces both the value and the gate halves.
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
        # mps: gelu is not implemented for float16, so round-trip via float32.
        if gate.device.type == "mps":
            return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
        return F.gelu(gate)

    def forward(self, hidden_states, scale: float = 1.0):
        value, gate = self.proj(hidden_states).chunk(2, dim=-1)
        return value * self.gelu(gate)
34
+
35
+
36
+ class CrossAttention(nn.Module):
37
+ def __init__(
38
+ self,
39
+ dim,
40
+ kv_dim=None,
41
+ num_heads=16,
42
+ qkv_bias=False,
43
+ attn_drop=0.0,
44
+ proj_drop=0.0,
45
+ ):
46
+ super().__init__()
47
+ self.num_heads = num_heads
48
+ head_dim = dim // num_heads
49
+ self.scale = head_dim**-0.5
50
+ kv_dim = dim if not kv_dim else kv_dim
51
+ self.wq = nn.Linear(dim, dim, bias=qkv_bias)
52
+ self.wk = nn.Linear(kv_dim, dim, bias=qkv_bias)
53
+ self.wv = nn.Linear(kv_dim, dim, bias=qkv_bias)
54
+ self.attn_drop = attn_drop
55
+ self.proj = nn.Linear(dim, dim)
56
+ self.proj_drop = nn.Dropout(proj_drop)
57
+
58
+ def forward(self, x_q, x_kv):
59
+ B, N_q, C = x_q.shape
60
+ B, N_kv, _ = x_kv.shape
61
+ # [B, N_q, C] -> [B, N_q, H, C/H]
62
+ q = self.wq(x_q).reshape(B, N_q, self.num_heads, C // self.num_heads)
63
+ # [B, N_kv, C] -> [B, N_kv, H, C/H]
64
+ k = self.wk(x_kv).reshape(B, N_kv, self.num_heads, C // self.num_heads)
65
+ v = self.wv(x_kv).reshape(B, N_kv, self.num_heads, C // self.num_heads)
66
+
67
+ # attention
68
+ x = torch.nn.functional.scaled_dot_product_attention(
69
+ q.permute(0, 2, 1, 3),
70
+ k.permute(0, 2, 1, 3),
71
+ v.permute(0, 2, 1, 3),
72
+ attn_mask=None,
73
+ dropout_p=self.attn_drop,
74
+ scale=self.scale,
75
+ ).permute(0, 2, 1, 3)
76
+
77
+ # [B, N_q, H, C/H] -> [B, N_q, C]
78
+ x = x.reshape(B, N_q, C)
79
+ x = self.proj(x)
80
+ x = self.proj_drop(x)
81
+ return x
82
+
83
+
84
class FeedForward(nn.Module):
    """Transformer MLP: GEGLU expansion -> dropout -> linear projection."""

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        dropout: float = 0.0,
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        out_dim = dim if dim_out is None else dim_out
        # Same layer order (and state-dict keys) as appending one by one.
        self.net = nn.ModuleList(
            [
                GEGLU(dim, inner_dim),
                nn.Dropout(dropout),
                nn.Linear(inner_dim, out_dim),
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.net:
            x = layer(x)
        return x
105
+
106
+
107
class BasicBlock(nn.Module):
    """Pre-norm transformer block: self-attention, cross-attention and
    feed-forward, each wrapped in a residual connection."""

    def __init__(
        self,
        dim: int,
        kv_dim: Optional[int] = None,
        num_heads: int = 16,
        qkv_bias: bool = False,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        ff_drop: float = 0.0,
    ):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn1 = CrossAttention(
            dim,
            kv_dim=dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
        )
        self.norm2 = nn.LayerNorm(dim)
        self.attn2 = CrossAttention(
            dim,
            kv_dim=kv_dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
        )
        self.norm3 = nn.LayerNorm(dim)
        self.ff = FeedForward(dim, dropout=ff_drop)

    def forward(self, z, x):
        # Self-attention over z.
        h = self.norm1(z)
        z = z + self.attn1(h, h)
        # Cross-attention to x; falls back to self-attention when x is None.
        # TODO: do we need to have the second attention when x is None?
        h = self.norm2(z)
        z = z + self.attn2(h, x if x is not None else h)
        # Feed-forward.
        z = z + self.ff(self.norm3(z))
        return z
149
+
150
+
151
class SingleStreamTransformer(BaseModule):
    """A stack of pre-norm transformer blocks over a (B, C, N) token stream,
    with optional cross-attention to encoder tokens and a final residual."""

    @dataclass
    class Config(BaseModule.Config):
        num_attention_heads: int = 16
        attention_head_dim: int = 88
        in_channels: Optional[int] = None
        out_channels: Optional[int] = None
        num_layers: int = 16
        dropout: float = 0.0
        norm_num_groups: int = 32
        cross_attention_dim: Optional[int] = None
        attention_bias: bool = False

    cfg: Config

    def configure(self) -> None:
        self.num_attention_heads = self.cfg.num_attention_heads
        self.attention_head_dim = self.cfg.attention_head_dim
        inner_dim = self.num_attention_heads * self.attention_head_dim

        # Input layers: group-norm over channels, then project to inner_dim.
        self.norm = torch.nn.GroupNorm(
            num_groups=self.cfg.norm_num_groups,
            num_channels=self.cfg.in_channels,
            eps=1e-6,
            affine=True,
        )
        self.proj_in = nn.Linear(self.cfg.in_channels, inner_dim)

        # The transformer stack itself.
        self.transformer_blocks = nn.ModuleList(
            BasicBlock(
                inner_dim,
                kv_dim=self.cfg.cross_attention_dim,
                num_heads=self.num_attention_heads,
                qkv_bias=self.cfg.attention_bias,
                proj_drop=self.cfg.dropout,
                ff_drop=self.cfg.dropout,
            )
            for _ in range(self.cfg.num_layers)
        )

        # Output projection back to the input channel count.
        self.proj_out = nn.Linear(inner_dim, self.cfg.in_channels)

    def forward(self, hidden_states, encoder_hidden_states=None, **kwargs):
        residual = hidden_states
        # (B, C, N) -> (B, N, C) so the attention blocks see tokens last-dim.
        hidden_states = self.norm(hidden_states).permute(0, 2, 1)
        hidden_states = self.proj_in(hidden_states)
        for block in self.transformer_blocks:
            hidden_states = block(hidden_states, encoder_hidden_states)
        hidden_states = self.proj_out(hidden_states).permute(0, 2, 1).contiguous()
        # TODO: do we really need to add the residual?
        return hidden_states + residual
209
+
210
+
211
class FuseBlock(nn.Module):
    """
    Fuse X in to Z with cross attention
    """

    def __init__(
        self,
        dim_z: int,
        dim_x: int,
        num_heads: int = 16,
        qkv_bias: bool = False,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        ff_drop: float = 0.0,
        norm_x_input: bool = True,
    ):
        super().__init__()
        self.norm_x_input = norm_x_input
        if self.norm_x_input:
            self.norm_x = nn.LayerNorm(dim_x)
        self.attn = CrossAttention(
            dim_z,
            kv_dim=dim_x,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
        )
        self.norm_z1 = nn.LayerNorm(dim_z)
        self.norm_z2 = nn.LayerNorm(dim_z)
        self.ff = FeedForward(dim_z, dropout=ff_drop)

    def forward(self, z, x):
        # Cross-attend from z into (optionally layer-normed) x, then MLP;
        # both steps are residual.
        # TODO: do we need to normalize x?
        keys = self.norm_x(x) if self.norm_x_input else x
        z = z + self.attn(self.norm_z1(z), keys)
        z = z + self.ff(self.norm_z2(z))
        return z
248
+
249
+
250
@torch.no_grad()
def get_triplane_attention_mask(res):
    """Build an additive attention bias over flattened triplane tokens.

    A token at (plane p, row i, col j) may only attend to the rows/columns of
    the other two planes that share a spatial axis with it.

    Args:
        res: Side length of each plane; token count is N = 3 * res * res.

    Returns:
        Float tensor of shape (N, N) with 0.0 where attention is allowed and
        -inf elsewhere, suitable for `scaled_dot_product_attention`'s
        `attn_mask`.
    """
    N = 3 * res * res
    attn_mask = torch.zeros(3, res, res, 3, res, res)

    # Fix: pass indexing="ij" explicitly — this keeps the historical meshgrid
    # semantics and silences the PyTorch >= 1.10 deprecation warning.
    i, j = torch.meshgrid(torch.arange(res), torch.arange(res), indexing="ij")

    attn_mask[0, i, j, 1, i, :] = 1.0
    attn_mask[0, i, j, 2, j, :] = 1.0
    attn_mask[1, i, j, 0, i, :] = 1.0
    attn_mask[1, i, j, 2, :, j] = 1.0
    attn_mask[2, i, j, 0, :, i] = 1.0
    attn_mask[2, i, j, 1, :, j] = 1.0
    attn_mask = attn_mask.bool()

    # Convert the boolean mask into an additive bias.
    attn_bias = torch.empty_like(attn_mask, dtype=torch.float)
    attn_bias.masked_fill_(attn_mask, 0.0)
    attn_bias.masked_fill_(~attn_mask, float("-inf"))

    return attn_bias.reshape(N, N)
270
+
271
+
272
+ class TriplaneAttention(nn.Module):
273
+ def __init__(
274
+ self,
275
+ dim: int,
276
+ resolution: int,
277
+ num_heads: int = 16,
278
+ qkv_bias: bool = False,
279
+ attn_drop: float = 0.0,
280
+ proj_drop: float = 0.0,
281
+ full_attention: bool = False,
282
+ ):
283
+ super().__init__()
284
+ self.num_heads = num_heads
285
+ head_dim = dim // num_heads
286
+ self.scale = head_dim**-0.5
287
+ self.wq = nn.Linear(dim, dim, bias=qkv_bias)
288
+ self.wk = nn.Linear(dim, dim, bias=qkv_bias)
289
+ self.wv = nn.Linear(dim, dim, bias=qkv_bias)
290
+ self.attn_drop = attn_drop
291
+ self.proj = nn.Linear(dim, dim)
292
+ self.proj_drop = nn.Dropout(proj_drop)
293
+
294
+ self.resolution = resolution
295
+ self.full_attention = full_attention
296
+ self.attn_mask = (
297
+ get_triplane_attention_mask(resolution) if not full_attention else None
298
+ )
299
+
300
+ def forward(self, x):
301
+ B, N, C = x.shape
302
+ # [B, N, C] -> [B, N, H, C/H]
303
+ q = self.wq(x).reshape(B, N, self.num_heads, C // self.num_heads)
304
+ k = self.wk(x).reshape(B, N, self.num_heads, C // self.num_heads)
305
+ v = self.wv(x).reshape(B, N, self.num_heads, C // self.num_heads)
306
+
307
+ # detokenize the planes
308
+ assert N == self.resolution**2 * 3
309
+ attn_bias = (
310
+ self.attn_mask.to(q)
311
+ .unsqueeze(0)
312
+ .unsqueeze(0)
313
+ .expand(B, self.num_heads, -1, -1)
314
+ if not self.full_attention
315
+ else None
316
+ )
317
+
318
+ # full attention
319
+ x = torch.nn.functional.scaled_dot_product_attention(
320
+ q.permute(0, 2, 1, 3),
321
+ k.permute(0, 2, 1, 3),
322
+ v.permute(0, 2, 1, 3),
323
+ attn_mask=attn_bias,
324
+ dropout_p=self.attn_drop,
325
+ scale=self.scale,
326
+ ).permute(0, 2, 1, 3)
327
+
328
+ # [B, N_q, H, C/H] -> [B, N_q, C]
329
+ x = x.reshape(B, N, C)
330
+ x = self.proj(x)
331
+ x = self.proj_drop(x)
332
+ return x
333
+
334
+
335
class TwoStreamBlock(nn.Module):
    """One interleave step: fuse input tokens into the latent stream, run a
    short transformer stack on the latent, then fuse the latent back into the
    input stream."""

    def __init__(
        self,
        dim_latent: int,
        dim_input: int,
        num_basic_blocks: int = 4,
        num_heads: int = 16,
        qkv_bias: bool = False,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        ff_drop: float = 0.0,
        norm_x_input: bool = True,
        dim_cross: Optional[int] = None,
    ):
        super().__init__()

        # input -> latent fusion
        self.fuse_block_in = FuseBlock(
            dim_latent,
            dim_input,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            ff_drop=ff_drop,
            norm_x_input=norm_x_input,
        )

        # latent-only transformer stack (with optional cross-attention)
        self.transformer_block = nn.ModuleList(
            BasicBlock(
                dim_latent,
                kv_dim=dim_cross,
                num_heads=num_heads,
                qkv_bias=qkv_bias,
                proj_drop=proj_drop,
                ff_drop=ff_drop,
            )
            for _ in range(num_basic_blocks)
        )

        # latent -> input fusion
        self.fuse_block_out = FuseBlock(
            dim_input,
            dim_latent,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            ff_drop=ff_drop,
            norm_x_input=norm_x_input,
        )

    def forward(self, latent, input, cross_input):
        latent = self.fuse_block_in(latent, input)
        for block in self.transformer_block:
            latent = block(latent, cross_input)
        input = self.fuse_block_out(input, latent)
        return latent, input
396
+
397
+
398
class TwoStreamInterleaveTransformer(BaseModule):
    """Two-stream transformer that interleaves a latent token stream (seeded
    from learned latents, optionally mixed with image tokens) with the
    triplane token stream, exchanging information through fuse blocks."""

    @dataclass
    class Config(BaseModule.Config):
        num_attention_heads: int = 16
        attention_head_dim: int = 64
        raw_triplane_channels: int = 1024
        triplane_channels: int = 1024
        raw_image_channels: int = 1024
        num_latents: int = 1792
        num_blocks: int = 4
        num_basic_blocks: int = 3
        dropout: float = 0.0
        latent_init_std: float = 0.02
        norm_num_groups: int = 32
        attention_bias: bool = False
        norm_x_input: bool = False
        cross_attention_dim: int = 1024
        mix_latent: bool = True

    cfg: Config

    def configure(self) -> None:
        self.mix_latent = self.cfg.mix_latent

        # Derived dimensions.
        self.num_attention_heads = self.cfg.num_attention_heads
        self.attention_head_dim = self.cfg.attention_head_dim
        self.num_latents = self.cfg.num_latents
        self.latent_dim = self.num_attention_heads * self.attention_head_dim

        # Triplane input normalization + projection. norm_num_groups <= 0
        # selects LayerNorm instead of GroupNorm.
        if self.cfg.norm_num_groups > 0:
            self.norm_triplane = torch.nn.GroupNorm(
                num_groups=self.cfg.norm_num_groups,
                num_channels=self.cfg.raw_triplane_channels,
                eps=1e-6,
                affine=True,
            )
        else:
            self.norm_triplane = nn.LayerNorm(self.cfg.raw_triplane_channels)
        self.proj_triplane = nn.Linear(
            self.cfg.raw_triplane_channels, self.cfg.triplane_channels
        )

        # Image-token projection, only used when mixing them into the latent.
        if self.mix_latent:
            self.norm_image = nn.LayerNorm(self.cfg.raw_image_channels)
            self.proj_image = nn.Linear(self.cfg.raw_image_channels, self.latent_dim)
        self.norm_latent = nn.LayerNorm(self.latent_dim)
        self.proj_latent = nn.Linear(self.latent_dim, self.latent_dim)

        # Learned latent seed tokens.
        self.latent_init = nn.Parameter(
            torch.zeros(1, self.num_latents, self.latent_dim)
        )
        nn.init.normal_(self.latent_init, std=self.cfg.latent_init_std)

        # Interleaved two-stream blocks.
        self.main_blocks = nn.ModuleList(
            TwoStreamBlock(
                self.latent_dim,
                self.cfg.triplane_channels,
                num_basic_blocks=self.cfg.num_basic_blocks,
                num_heads=self.num_attention_heads,
                qkv_bias=self.cfg.attention_bias,
                proj_drop=self.cfg.dropout,
                ff_drop=self.cfg.dropout,
                norm_x_input=self.cfg.norm_x_input,
                dim_cross=self.cfg.cross_attention_dim,
            )
            for _ in range(self.cfg.num_blocks)
        )

        # Output projection back to the raw triplane channel count.
        self.proj_out = nn.Linear(
            self.cfg.triplane_channels, self.cfg.raw_triplane_channels
        )

    def forward(self, hidden_states, encoder_hidden_states, **kwargs):
        # hidden_states: [B, triplane_dim, N_triplane] triplane tokens
        # encoder_hidden_states: [B, N_image, image_dim] image tokens
        if isinstance(self.norm_triplane, nn.GroupNorm):
            # GroupNorm runs channels-first, then switch to token-major.
            triplane_tokens = self.norm_triplane(hidden_states).permute(0, 2, 1)
        elif isinstance(self.norm_triplane, nn.LayerNorm):
            triplane_tokens = self.norm_triplane(hidden_states.permute(0, 2, 1))
        else:
            raise ValueError("Unknown normalization layer")
        triplane_tokens = self.proj_triplane(triplane_tokens)

        # Seed the latent stream.
        if self.mix_latent:
            image_tokens = self.proj_image(self.norm_image(encoder_hidden_states))
        init_latents = self.latent_init.expand(hidden_states.shape[0], -1, -1)
        init_latents = self.proj_latent(self.norm_latent(init_latents))
        if self.mix_latent:
            latent_tokens = torch.cat([image_tokens, init_latents], dim=1)
        else:
            latent_tokens = init_latents

        # Interleaved processing of the two streams.
        for block in self.main_blocks:
            latent_tokens, triplane_tokens = block(
                latent_tokens, triplane_tokens, encoder_hidden_states
            )

        # Project back to raw triplane channels and add the residual.
        triplane_tokens = self.proj_out(triplane_tokens).permute(0, 2, 1).contiguous()
        return triplane_tokens + hidden_states
sf3d/models/utils.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ import importlib
3
+ from dataclasses import dataclass
4
+ from typing import Any, List, Optional, Tuple, Union
5
+
6
+ import numpy as np
7
+ import PIL
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from jaxtyping import Float, Int, Num
12
+ from omegaconf import DictConfig, OmegaConf
13
+ from torch import Tensor
14
+
15
+
16
class BaseModule(nn.Module):
    """nn.Module with a structured dataclass config: subclasses declare a
    nested `Config` dataclass and build their submodules in `configure()`
    instead of overriding `__init__`."""

    @dataclass
    class Config:
        pass

    cfg: Config  # add this to every subclass of BaseModule to enable static type checking

    def __init__(
        self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs
    ) -> None:
        super().__init__()
        # Validate/merge the raw config into the subclass's Config schema.
        self.cfg = parse_structured(self.Config, cfg)
        self.configure(*args, **kwargs)

    def configure(self, *args, **kwargs) -> None:
        # Subclasses must implement this to build their layers.
        raise NotImplementedError
32
+
33
+
34
def find_class(cls_string):
    """Resolve a dotted path like "pkg.mod.ClassName" to the named attribute."""
    module_path, _, attr_name = cls_string.rpartition(".")
    module = importlib.import_module(module_path, package=None)
    return getattr(module, attr_name)
40
+
41
+
42
def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any:
    """Merge `cfg` into the structured dataclass `fields`, silently dropping
    any keys the dataclass does not declare."""
    cfg_ = cfg.copy()

    allowed = {f.name for f in dataclasses.fields(fields)}
    for key in list(cfg_.keys()):
        if key not in allowed:
            # This is helpful when swapping out modules from CLI
            print(f"Ignoring {key} as it's not supported by {fields}")
            cfg_.pop(key)
    return OmegaConf.merge(OmegaConf.structured(fields), cfg_)
55
+
56
+
57
# Per-dtype epsilon used by `normalize` to avoid division by near-zero norms;
# half-precision types need a much larger floor than fp32/fp64.
EPS_DTYPE = {
    torch.float16: 1e-4,
    torch.bfloat16: 1e-4,
    torch.float32: 1e-7,
    torch.float64: 1e-8,
}
63
+
64
+
65
def dot(x, y, dim=-1):
    """Inner product of x and y along `dim`, keeping the reduced dimension."""
    return (x * y).sum(dim, keepdim=True)
67
+
68
+
69
def reflect(x, n):
    """Reflect direction x about the (unit) normal n: x - 2 (x . n) n."""
    x_dot_n = torch.sum(x * n, -1, keepdim=True)
    return x - 2 * x_dot_n * n
71
+
72
+
73
def normalize(x, dim=-1, eps=None):
    """L2-normalize `x` along `dim`; eps defaults to a dtype-appropriate floor
    from EPS_DTYPE."""
    eps = EPS_DTYPE[x.dtype] if eps is None else eps
    return F.normalize(x, dim=dim, p=2, eps=eps)
77
+
78
+
79
+ ValidScale = Union[Tuple[float, float], Num[Tensor, "2 D"]]
80
+
81
+
82
def scale_tensor(
    dat: Num[Tensor, "... D"], inp_scale: ValidScale, tgt_scale: ValidScale
):
    """Affinely remap `dat` from range `inp_scale` to range `tgt_scale`
    (either range may be None, meaning (0, 1))."""
    inp_scale = (0, 1) if inp_scale is None else inp_scale
    tgt_scale = (0, 1) if tgt_scale is None else tgt_scale
    if isinstance(tgt_scale, Tensor):
        assert dat.shape[-1] == tgt_scale.shape[-1]
    # Normalize to [0, 1], then stretch into the target range.
    unit = (dat - inp_scale[0]) / (inp_scale[1] - inp_scale[0])
    return unit * (tgt_scale[1] - tgt_scale[0]) + tgt_scale[0]
94
+
95
+
96
def dilate_fill(img, mask, iterations=10):
    """Grow the valid region of `img` outward by `iterations` pixels.

    Each step dilates `mask` by one pixel (3x3 max-pool) and fills the newly
    exposed ring with the local 3x3 mean color of the previously valid pixels.
    Assumes img is (1, 3, H, W) and mask is a binary (1, 1, H, W) tensor —
    TODO confirm shapes against callers.
    """
    cur_mask = mask.float()
    cur_img = img

    # 3x3 all-ones kernel used to count contributing patches per pixel.
    ones_kernel = torch.ones(
        (1, 1, 3, 3),
        dtype=cur_mask.dtype,
        device=cur_mask.device,
    )

    for _ in range(iterations):
        # Dilate the valid region by one pixel.
        grown_mask = torch.nn.functional.max_pool2d(cur_mask, 3, 1, 1)

        # Unfold 3x3 neighborhoods of the image, old mask and grown mask.
        img_patches = F.unfold(cur_img, (3, 3)).view(1, 3, 3 * 3, -1)
        mask_patches = F.unfold(cur_mask, (3, 3)).view(1, 1, 3 * 3, -1)
        grown_patches = F.unfold(grown_mask, (3, 3)).view(1, 1, 3 * 3, -1)

        # Mean color over the valid pixels of each neighborhood...
        mean_color = (
            img_patches.sum(dim=2) / mask_patches.sum(dim=2).clip(1)
        ).unsqueeze(2)
        # ...spread onto the dilated footprint.
        fill_color = (mean_color * grown_patches).view(1, 3 * 3 * 3, -1)

        # Per-pixel count of overlapping patches, used to renormalize fold().
        patch_count = F.conv2d(grown_mask, ones_kernel, padding=1)
        filled = F.fold(
            fill_color, (img.shape[-2], img.shape[-1]), (3, 3)
        ) / patch_count.clamp(1)

        # Blend the fill only into the newly added ring of pixels.
        ring = grown_mask - cur_mask
        cur_mask = grown_mask
        cur_img = torch.lerp(cur_img, filled, ring)

    return cur_img
134
+
135
+
136
def float32_to_uint8_np(
    x: "Float[np.ndarray, '*B H W C']",
    dither: bool = True,
    dither_mask: "Optional[Float[np.ndarray, '*B H W C']]" = None,
    dither_strength: float = 1.0,
) -> "Int[np.ndarray, '*B H W C']":
    """Quantize float images in [0, 1] to uint8, optionally adding uniform
    dither noise before rounding.

    Args:
        x: Float array scaled to [0, 1].
        dither: Whether to add uniform noise in [-0.5, 0.5) (one value per
            pixel, broadcast across channels) before flooring.
        dither_mask: Optional per-pixel multiplier on the noise.
        dither_strength: Scales the noise amplitude.

    Returns:
        uint8 array of the same shape.

    Bug fix: the no-dither path previously called `.astype(torch.uint8)` on a
    numpy array, which raises a TypeError; it now uses `np.uint8`. The noise
    variable no longer shadows the `dither` flag.
    """
    if dither:
        noise = (
            dither_strength * np.random.rand(*x[..., :1].shape).astype(np.float32)
            - 0.5
        )
        if dither_mask is not None:
            noise = noise * dither_mask
        return np.clip(np.floor(256.0 * x + noise), 0, 255).astype(np.uint8)
    return np.clip(np.floor(256.0 * x), 0, 255).astype(np.uint8)
150
+
151
+
152
def convert_data(data):
    """Recursively convert tensors (and lists/dicts of them) to numpy arrays.

    Half-precision tensors are upcast to float32 first (numpy has no
    bfloat16). None passes through unchanged; unsupported types raise
    TypeError.
    """
    if data is None:
        return None
    if isinstance(data, np.ndarray):
        return data
    if isinstance(data, torch.Tensor):
        tensor = data.float() if data.dtype in (torch.float16, torch.bfloat16) else data
        return tensor.detach().cpu().numpy()
    if isinstance(data, list):
        return [convert_data(item) for item in data]
    if isinstance(data, dict):
        return {key: convert_data(value) for key, value in data.items()}
    raise TypeError(
        "Data must be in type numpy.ndarray, torch.Tensor, list or dict, getting",
        type(data),
    )
170
+
171
+
172
class ImageProcessor:
    """Converts PIL images / numpy arrays / tensors to channels-last float
    tensors in [0, 1] and resizes them to a square `size` (bilinear,
    antialiased)."""

    def convert_and_resize(
        self,
        image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
        size: int,
    ):
        # Normalize the input to a float tensor; uint8 sources are rescaled
        # from [0, 255] to [0, 1].
        if isinstance(image, PIL.Image.Image):
            image = torch.from_numpy(np.array(image).astype(np.float32) / 255.0)
        elif isinstance(image, np.ndarray):
            if image.dtype == np.uint8:
                image = torch.from_numpy(image.astype(np.float32) / 255.0)
            else:
                image = torch.from_numpy(image)
        elif isinstance(image, torch.Tensor):
            pass

        batched = image.ndim == 4
        if not batched:
            image = image[None, ...]

        # Resize in channels-first layout, then restore channels-last.
        image = F.interpolate(
            image.permute(0, 3, 1, 2),
            (size, size),
            mode="bilinear",
            align_corners=False,
            antialias=True,
        ).permute(0, 2, 3, 1)

        if not batched:
            image = image[0]
        return image

    def __call__(
        self,
        image: Union[
            PIL.Image.Image,
            np.ndarray,
            torch.FloatTensor,
            List[PIL.Image.Image],
            List[np.ndarray],
            List[torch.FloatTensor],
        ],
        size: int,
    ) -> Any:
        # A pre-batched array/tensor is resized directly; everything else is
        # treated as (a list of) single images and stacked into a batch.
        if isinstance(image, (np.ndarray, torch.FloatTensor)) and image.ndim == 4:
            return self.convert_and_resize(image, size)
        if not isinstance(image, list):
            image = [image]
        resized = [self.convert_and_resize(im, size) for im in image]
        return torch.stack(resized, dim=0)
223
+
224
+
225
def get_intrinsic_from_fov(fov, H, W, bs=-1):
    """Build a pinhole intrinsic matrix from a field of view (radians).

    The focal length is derived from `fov` and the image height, and the
    principal point sits at the image center. If bs > 0, the matrix is tiled
    to shape (bs, 3, 3). Returns a float32 torch tensor.
    """
    focal = 0.5 * H / np.tan(0.5 * fov)
    K = np.identity(3, dtype=np.float32)
    K[0, 0] = focal
    K[1, 1] = focal
    K[0, 2] = W / 2.0
    K[1, 2] = H / 2.0

    if bs > 0:
        K = K[None].repeat(bs, axis=0)

    return torch.from_numpy(K)
sf3d/system.py ADDED
@@ -0,0 +1,532 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from contextlib import nullcontext
3
+ from dataclasses import dataclass, field
4
+ from typing import Any, List, Literal, Optional, Tuple, Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+ import trimesh
10
+ from einops import rearrange
11
+ from huggingface_hub import hf_hub_download
12
+ from jaxtyping import Float
13
+ from omegaconf import OmegaConf
14
+ from PIL import Image
15
+ from safetensors.torch import load_model
16
+ from torch import Tensor
17
+
18
+ from sf3d.models.isosurface import MarchingTetrahedraHelper
19
+ from sf3d.models.mesh import Mesh
20
+ from sf3d.models.utils import (
21
+ BaseModule,
22
+ ImageProcessor,
23
+ convert_data,
24
+ dilate_fill,
25
+ find_class,
26
+ float32_to_uint8_np,
27
+ normalize,
28
+ scale_tensor,
29
+ )
30
+ from sf3d.utils import create_intrinsic_from_fov_deg, default_cond_c2w, get_device
31
+
32
+ try:
33
+ from texture_baker import TextureBaker
34
+ except ImportError:
35
+ import logging
36
+
37
+ logging.warning(
38
+ "Could not import texture_baker. Please install it via `pip install texture-baker/`"
39
+ )
40
+ # Exit early to avoid further errors
41
+ raise ImportError("texture_baker not found")
42
+
43
+
44
class SF3D(BaseModule):
    """Single-image-to-3D system.

    Pipeline (see `run_image`): conditioning image(s) are tokenized, a
    transformer backbone produces triplane scene codes, marching tetrahedra
    extracts a mesh from the decoded density field, and material/normal maps
    are baked into a textured `trimesh.Trimesh`.
    """

    @dataclass
    class Config(BaseModule.Config):
        # Square side length (px) conditioning images are resized to.
        cond_image_size: int
        # Grid resolution for marching-tetrahedra surface extraction.
        isosurface_resolution: int
        # Density level used as the iso-surface threshold.
        isosurface_threshold: float = 10.0
        # Half extent of the cubic scene bounding volume (world units).
        radius: float = 1.0
        # RGB used to composite RGBA inputs onto an opaque background.
        background_color: list[float] = field(default_factory=lambda: [0.5, 0.5, 0.5])
        # Default camera used when no explicit camera is supplied.
        default_fovy_deg: float = 40.0
        default_distance: float = 1.6

        # Each sub-module is specified as a dotted class path (`*_cls`) plus a
        # config dict; instances are created in `configure` via `find_class`.
        camera_embedder_cls: str = ""
        camera_embedder: dict = field(default_factory=dict)

        image_tokenizer_cls: str = ""
        image_tokenizer: dict = field(default_factory=dict)

        tokenizer_cls: str = ""
        tokenizer: dict = field(default_factory=dict)

        backbone_cls: str = ""
        backbone: dict = field(default_factory=dict)

        post_processor_cls: str = ""
        post_processor: dict = field(default_factory=dict)

        decoder_cls: str = ""
        decoder: dict = field(default_factory=dict)

        image_estimator_cls: str = ""
        image_estimator: dict = field(default_factory=dict)

        global_estimator_cls: str = ""
        global_estimator: dict = field(default_factory=dict)

    cfg: Config

    @classmethod
    def from_pretrained(
        cls, pretrained_model_name_or_path: str, config_name: str, weight_name: str
    ):
        """Load config + safetensors weights from a local directory or a
        Hugging Face Hub repo id, and return the instantiated model."""
        if os.path.isdir(pretrained_model_name_or_path):
            config_path = os.path.join(pretrained_model_name_or_path, config_name)
            weight_path = os.path.join(pretrained_model_name_or_path, weight_name)
        else:
            # Not a local dir: treat it as a Hub repo id and download.
            config_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path, filename=config_name
            )
            weight_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path, filename=weight_name
            )

        cfg = OmegaConf.load(config_path)
        OmegaConf.resolve(cfg)
        model = cls(cfg)
        load_model(model, weight_path)
        return model

    @property
    def device(self):
        # Device of the first parameter; assumes all parameters share it.
        return next(self.parameters()).device

    def configure(self):
        """Instantiate all configured sub-modules and fixed buffers.

        Called by BaseModule during __init__ — presumably; confirm against
        BaseModule's contract.
        """
        self.image_tokenizer = find_class(self.cfg.image_tokenizer_cls)(
            self.cfg.image_tokenizer
        )
        self.tokenizer = find_class(self.cfg.tokenizer_cls)(self.cfg.tokenizer)
        self.camera_embedder = find_class(self.cfg.camera_embedder_cls)(
            self.cfg.camera_embedder
        )
        self.backbone = find_class(self.cfg.backbone_cls)(self.cfg.backbone)
        self.post_processor = find_class(self.cfg.post_processor_cls)(
            self.cfg.post_processor
        )
        self.decoder = find_class(self.cfg.decoder_cls)(self.cfg.decoder)
        self.image_estimator = find_class(self.cfg.image_estimator_cls)(
            self.cfg.image_estimator
        )
        self.global_estimator = find_class(self.cfg.global_estimator_cls)(
            self.cfg.global_estimator
        )

        # Axis-aligned scene bounds [-radius, radius]^3, kept as a buffer so
        # it moves with the module across devices.
        self.bbox: Float[Tensor, "2 3"]
        self.register_buffer(
            "bbox",
            torch.as_tensor(
                [
                    [-self.cfg.radius, -self.cfg.radius, -self.cfg.radius],
                    [self.cfg.radius, self.cfg.radius, self.cfg.radius],
                ],
                dtype=torch.float32,
            ),
        )
        # Precomputed tetrahedral grid for the configured resolution, loaded
        # from a bundled .npz next to the package ("load/tets/...").
        self.isosurface_helper = MarchingTetrahedraHelper(
            self.cfg.isosurface_resolution,
            os.path.join(
                os.path.dirname(__file__),
                "..",
                "load",
                "tets",
                f"{self.cfg.isosurface_resolution}_tets.npz",
            ),
        )

        self.baker = TextureBaker()
        self.image_processor = ImageProcessor()

    def triplane_to_meshes(
        self, triplanes: Float[Tensor, "B 3 Cp Hp Wp"]
    ) -> list[Mesh]:
        """Extract one mesh per triplane via marching tetrahedra.

        The decoded density minus `isosurface_threshold` acts as the SDF-like
        field; decoded vertex offsets deform the tet grid before extraction.
        """
        meshes = []
        for i in range(triplanes.shape[0]):
            triplane = triplanes[i]
            # Tet grid vertices mapped from grid space into world bbox space.
            grid_vertices = scale_tensor(
                self.isosurface_helper.grid_vertices.to(triplanes.device),
                self.isosurface_helper.points_range,
                self.bbox,
            )

            values = self.query_triplane(grid_vertices, triplane)
            decoded = self.decoder(values, include=["vertex_offset", "density"])
            # Shift density so the zero level set sits at the threshold.
            sdf = decoded["density"] - self.cfg.isosurface_threshold

            deform = decoded["vertex_offset"].squeeze(0)

            mesh: Mesh = self.isosurface_helper(
                sdf.view(-1, 1), deform.view(-1, 3) if deform is not None else None
            )
            # Extracted vertices are in grid space; map back to world space.
            mesh.v_pos = scale_tensor(
                mesh.v_pos, self.isosurface_helper.points_range, self.bbox
            )

            meshes.append(mesh)

        return meshes

    def query_triplane(
        self,
        positions: Float[Tensor, "*B N 3"],
        triplanes: Float[Tensor, "*B 3 Cp Hp Wp"],
    ) -> Float[Tensor, "*B N F"]:
        """Sample triplane features at 3D positions.

        Each position is projected onto the XY, XZ and YZ planes, features are
        bilinearly sampled from the corresponding plane, and the three results
        are concatenated along the feature dimension.
        """
        batched = positions.ndim == 3
        if not batched:
            # no batch dimension
            triplanes = triplanes[None, ...]
            positions = positions[None, ...]
        assert triplanes.ndim == 5 and positions.ndim == 3

        # World coords -> normalized [-1, 1] grid_sample coordinates.
        positions = scale_tensor(
            positions, (-self.cfg.radius, self.cfg.radius), (-1, 1)
        )

        indices2D: Float[Tensor, "B 3 N 2"] = torch.stack(
            (positions[..., [0, 1]], positions[..., [0, 2]], positions[..., [1, 2]]),
            dim=-3,
        ).to(triplanes.dtype)
        out: Float[Tensor, "B3 Cp 1 N"] = F.grid_sample(
            rearrange(triplanes, "B Np Cp Hp Wp -> (B Np) Cp Hp Wp", Np=3).float(),
            rearrange(indices2D, "B Np N Nd -> (B Np) () N Nd", Np=3).float(),
            align_corners=True,
            mode="bilinear",
        )
        out = rearrange(out, "(B Np) Cp () N -> B N (Np Cp)", Np=3)

        return out

    def get_scene_codes(self, batch) -> Float[Tensor, "B 3 C H W"]:
        """Encode the conditioning batch into triplane scene codes.

        Returns (scene_codes, direct_codes): the post-processed codes used
        for geometry/material queries, and the raw detokenized codes (fed to
        the global estimator). Note: mutates `batch` in place when adding the
        view dimension.
        """
        # if batch[rgb_cond] is only one view, add a view dimension
        if len(batch["rgb_cond"].shape) == 4:
            batch["rgb_cond"] = batch["rgb_cond"].unsqueeze(1)
            batch["mask_cond"] = batch["mask_cond"].unsqueeze(1)
            batch["c2w_cond"] = batch["c2w_cond"].unsqueeze(1)
            batch["intrinsic_cond"] = batch["intrinsic_cond"].unsqueeze(1)
            batch["intrinsic_normed_cond"] = batch["intrinsic_normed_cond"].unsqueeze(1)

        batch_size, n_input_views = batch["rgb_cond"].shape[:2]

        # Camera embeddings modulate the image tokenizer.
        camera_embeds: Optional[Float[Tensor, "B Nv Cc"]]
        camera_embeds = self.camera_embedder(**batch)

        input_image_tokens: Float[Tensor, "B Nv Cit Nit"] = self.image_tokenizer(
            rearrange(batch["rgb_cond"], "B Nv H W C -> B Nv C H W"),
            modulation_cond=camera_embeds,
        )

        # Flatten all views into one token sequence per batch element.
        input_image_tokens = rearrange(
            input_image_tokens, "B Nv C Nt -> B (Nv Nt) C", Nv=n_input_views
        )

        # Learned query tokens, cross-attended against the image tokens.
        tokens: Float[Tensor, "B Ct Nt"] = self.tokenizer(batch_size)

        tokens = self.backbone(
            tokens,
            encoder_hidden_states=input_image_tokens,
            modulation_cond=None,
        )

        direct_codes = self.tokenizer.detokenize(tokens)
        scene_codes = self.post_processor(direct_codes)
        return scene_codes, direct_codes

    def run_image(
        self,
        image: Union[Image.Image, List[Image.Image]],
        bake_resolution: int,
        remesh: Literal["none", "triangle", "quad"] = "none",
        vertex_count: int = -1,
        estimate_illumination: bool = False,
    ) -> Tuple[Union[trimesh.Trimesh, List[trimesh.Trimesh]], dict[str, Any]]:
        """High-level entry point: RGBA image(s) -> textured trimesh(es).

        Args:
            image: One RGBA PIL image or a list of them (batched).
            bake_resolution: Side length of the baked texture maps.
            remesh: Optional remeshing mode applied before baking.
            vertex_count: Target vertex count for remeshing (-1 = auto);
                ignored when remesh == "none".
            estimate_illumination: Also run the global illumination estimator.

        Returns:
            (mesh, global_dict) for a single input, or (meshes, global_dict)
            for a list input.
        """
        if isinstance(image, list):
            rgb_cond = []
            mask_cond = []
            for img in image:
                mask, rgb = self.prepare_image(img)
                mask_cond.append(mask)
                rgb_cond.append(rgb)
            rgb_cond = torch.stack(rgb_cond, 0)
            mask_cond = torch.stack(mask_cond, 0)
            batch_size = rgb_cond.shape[0]
        else:
            mask_cond, rgb_cond = self.prepare_image(image)
            batch_size = 1

        # Synthesize the default conditioning camera (fixed distance + FOV).
        c2w_cond = default_cond_c2w(self.cfg.default_distance).to(self.device)
        intrinsic, intrinsic_normed_cond = create_intrinsic_from_fov_deg(
            self.cfg.default_fovy_deg,
            self.cfg.cond_image_size,
            self.cfg.cond_image_size,
        )

        batch = {
            "rgb_cond": rgb_cond,
            "mask_cond": mask_cond,
            "c2w_cond": c2w_cond.view(1, 1, 4, 4).repeat(batch_size, 1, 1, 1),
            "intrinsic_cond": intrinsic.to(self.device)
            .view(1, 1, 3, 3)
            .repeat(batch_size, 1, 1, 1),
            "intrinsic_normed_cond": intrinsic_normed_cond.to(self.device)
            .view(1, 1, 3, 3)
            .repeat(batch_size, 1, 1, 1),
        }

        meshes, global_dict = self.generate_mesh(
            batch, bake_resolution, remesh, vertex_count, estimate_illumination
        )
        # Preserve the single-input convenience shape of the return value.
        if batch_size == 1:
            return meshes[0], global_dict
        else:
            return meshes, global_dict

    def prepare_image(self, image):
        """Resize an RGBA PIL image and split it into (mask, rgb) tensors.

        The RGB is alpha-composited onto `cfg.background_color` so the
        network never sees undefined color in transparent regions.

        Raises:
            ValueError: If the image is not in RGBA mode.
        """
        if image.mode != "RGBA":
            raise ValueError("Image must be in RGBA mode")
        img_cond = (
            torch.from_numpy(
                np.asarray(
                    image.resize((self.cfg.cond_image_size, self.cfg.cond_image_size))
                ).astype(np.float32)
                / 255.0
            )
            .float()
            .clip(0, 1)
            .to(self.device)
        )
        mask_cond = img_cond[:, :, -1:]
        # lerp(background, foreground, alpha) = alpha compositing.
        rgb_cond = torch.lerp(
            torch.tensor(self.cfg.background_color, device=self.device)[None, None, :],
            img_cond[:, :, :3],
            mask_cond,
        )

        return mask_cond, rgb_cond

    def generate_mesh(
        self,
        batch,
        bake_resolution: int,
        remesh: Literal["none", "triangle", "quad"] = "none",
        vertex_count: int = -1,
        estimate_illumination: bool = False,
    ) -> Tuple[List[trimesh.Trimesh], dict[str, Any]]:
        """Core generation: encode the batch, extract meshes, bake textures.

        Returns a list of textured trimesh.Trimesh (one per batch element;
        empty Trimesh for degenerate extractions) plus `global_dict` with
        per-image/global estimator outputs. Mutates `batch` in place.
        """
        batch["rgb_cond"] = self.image_processor(
            batch["rgb_cond"], self.cfg.cond_image_size
        )
        batch["mask_cond"] = self.image_processor(
            batch["mask_cond"], self.cfg.cond_image_size
        )
        scene_codes, non_postprocessed_codes = self.get_scene_codes(batch)

        global_dict = {}
        if self.image_estimator is not None:
            # Estimator sees the masked (background-free) image.
            global_dict.update(
                self.image_estimator(batch["rgb_cond"] * batch["mask_cond"])
            )
        if self.global_estimator is not None and estimate_illumination:
            global_dict.update(self.global_estimator(non_postprocessed_codes))

        device = get_device()
        with torch.no_grad():
            # Disable autocast on CUDA: surface extraction + baking need
            # full precision for stable results.
            with torch.autocast(
                device_type=device, enabled=False
            ) if "cuda" in device else nullcontext():
                meshes = self.triplane_to_meshes(scene_codes)

                rets = []
                for i, mesh in enumerate(meshes):
                    # Check for empty mesh
                    if mesh.v_pos.shape[0] == 0:
                        rets.append(trimesh.Trimesh())
                        continue

                    if remesh == "triangle":
                        mesh = mesh.triangle_remesh(triangle_vertex_count=vertex_count)
                    elif remesh == "quad":
                        mesh = mesh.quad_remesh(quad_vertex_count=vertex_count)
                    else:
                        if vertex_count > 0:
                            print(
                                "Warning: vertex_count is ignored when remesh is none"
                            )

                    print("After Remesh", mesh.v_pos.shape[0], mesh.t_pos_idx.shape[0])
                    mesh.unwrap_uv()

                    # Build textures
                    rast = self.baker.rasterize(
                        mesh.v_tex, mesh.t_pos_idx, bake_resolution
                    )
                    bake_mask = self.baker.get_mask(rast)

                    # World positions at every covered texel.
                    pos_bake = self.baker.interpolate(
                        mesh.v_pos,
                        rast,
                        mesh.t_pos_idx,
                    )
                    gb_pos = pos_bake[bake_mask]

                    # Decode material attributes at those positions.
                    tri_query = self.query_triplane(gb_pos, scene_codes[i])[0]
                    decoded = self.decoder(
                        tri_query, exclude=["density", "vertex_offset"]
                    )

                    nrm = self.baker.interpolate(
                        mesh.v_nrm,
                        rast,
                        mesh.t_pos_idx,
                    )
                    gb_nrm = F.normalize(nrm[bake_mask], dim=-1)
                    decoded["normal"] = gb_nrm

                    # Merge per-image estimator outputs named "decoder_*"
                    # into the decoded attribute dict.
                    for k, v in global_dict.items():
                        if k.startswith("decoder_"):
                            decoded[k.replace("decoder_", "")] = v[i]

                    mat_out = {
                        "albedo": decoded["features"],
                        "roughness": decoded["roughness"],
                        "metallic": decoded["metallic"],
                        "normal": normalize(decoded["perturb_normal"]),
                        "bump": None,
                    }

                    for k, v in mat_out.items():
                        if v is None:
                            continue
                        if v.shape[0] == 1:
                            # Skip and directly add a single value
                            mat_out[k] = v[0]
                        else:
                            # Scatter the per-texel values into a full map.
                            f = torch.zeros(
                                bake_resolution,
                                bake_resolution,
                                v.shape[-1],
                                dtype=v.dtype,
                                device=v.device,
                            )
                            if v.shape == f.shape:
                                continue
                            if k == "normal":
                                # Use un-normalized tangents here so that larger smaller tris
                                # Don't effect the tangents that much
                                tng = self.baker.interpolate(
                                    mesh.v_tng,
                                    rast,
                                    mesh.t_pos_idx,
                                )
                                gb_tng = tng[bake_mask]
                                gb_tng = F.normalize(gb_tng, dim=-1)
                                gb_btng = F.normalize(
                                    torch.cross(gb_nrm, gb_tng, dim=-1), dim=-1
                                )
                                normal = F.normalize(mat_out["normal"], dim=-1)

                                # Create tangent space matrix and transform normal
                                tangent_matrix = torch.stack(
                                    [gb_tng, gb_btng, gb_nrm], dim=-1
                                )
                                normal_tangent = torch.bmm(
                                    tangent_matrix.transpose(1, 2), normal.unsqueeze(-1)
                                ).squeeze(-1)

                                # Convert from [-1,1] to [0,1] range for storage
                                normal_tangent = (normal_tangent * 0.5 + 0.5).clamp(
                                    0, 1
                                )

                                f[bake_mask] = normal_tangent.view(-1, 3)
                                mat_out["bump"] = f
                            else:
                                f[bake_mask] = v.view(-1, v.shape[-1])
                                mat_out[k] = f

                    def uv_padding(arr):
                        # Dilate texel values outward so bilinear sampling at
                        # UV seams doesn't bleed in the zero background.
                        if arr.ndim == 1:
                            return arr
                        return (
                            dilate_fill(
                                arr.permute(2, 0, 1)[None, ...].contiguous(),
                                bake_mask.unsqueeze(0).unsqueeze(0),
                                iterations=bake_resolution // 150,
                            )
                            .squeeze(0)
                            .permute(1, 2, 0)
                            .contiguous()
                        )

                    verts_np = convert_data(mesh.v_pos)
                    faces = convert_data(mesh.t_pos_idx)
                    uvs = convert_data(mesh.v_tex)

                    basecolor_tex = Image.fromarray(
                        float32_to_uint8_np(convert_data(uv_padding(mat_out["albedo"])))
                    ).convert("RGB")
                    basecolor_tex.format = "JPEG"

                    # Roughness/metallic are scalars (PBR factors), not maps.
                    metallic = mat_out["metallic"].squeeze().cpu().item()
                    roughness = mat_out["roughness"].squeeze().cpu().item()

                    if "bump" in mat_out and mat_out["bump"] is not None:
                        bump_np = convert_data(uv_padding(mat_out["bump"]))
                        bump_up = np.ones_like(bump_np)
                        bump_up[..., :2] = 0.5
                        bump_up[..., 2:] = 1
                        bump_tex = Image.fromarray(
                            float32_to_uint8_np(
                                bump_np,
                                dither=True,
                                # Do not dither if something is perfectly flat
                                dither_mask=np.all(
                                    bump_np == bump_up, axis=-1, keepdims=True
                                ).astype(np.float32),
                            )
                        ).convert("RGB")
                        bump_tex.format = (
                            "JPEG"  # PNG would be better but the assets are larger
                        )
                    else:
                        bump_tex = None

                    material = trimesh.visual.material.PBRMaterial(
                        baseColorTexture=basecolor_tex,
                        roughnessFactor=roughness,
                        metallicFactor=metallic,
                        normalTexture=bump_tex,
                    )

                    tmesh = trimesh.Trimesh(
                        vertices=verts_np,
                        faces=faces,
                        visual=trimesh.visual.texture.TextureVisuals(
                            uv=uvs, material=material
                        ),
                    )
                    # Rotate from the model's coordinate frame into the
                    # glTF-style convention (-90° about X, then +90° about Y).
                    rot = trimesh.transformations.rotation_matrix(
                        np.radians(-90), [1, 0, 0]
                    )
                    tmesh.apply_transform(rot)
                    tmesh.apply_transform(
                        trimesh.transformations.rotation_matrix(
                            np.radians(90), [0, 1, 0]
                        )
                    )

                    # Flip face winding/normals for correct outward-facing
                    # orientation after the transforms above.
                    tmesh.invert()

                    rets.append(tmesh)

        return rets, global_dict
sf3d/utils.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Any, Union
3
+
4
+ import numpy as np
5
+ import rembg
6
+ import torch
7
+ import torchvision.transforms.functional as torchvision_F
8
+ from PIL import Image
9
+
10
+ import sf3d.models.utils as sf3d_utils
11
+
12
+
13
def get_device():
    """Select the torch device string to run on.

    The SF3D_USE_CPU=1 environment variable forces CPU regardless of any
    available accelerator; otherwise CUDA is preferred, then Apple MPS,
    falling back to CPU.
    """
    if os.environ.get("SF3D_USE_CPU", "0") == "1":
        return "cpu"
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"
23
+
24
+
25
def create_intrinsic_from_fov_deg(fov_deg: float, cond_height: int, cond_width: int):
    """Build pixel-space and normalized intrinsic matrices from a FOV in degrees.

    Returns:
        (intrinsic, intrinsic_normed_cond): the pixel-space matrix and a copy
        whose focal lengths / principal point are divided by the image size.
    """
    intrinsic = sf3d_utils.get_intrinsic_from_fov(
        np.deg2rad(fov_deg),
        H=cond_height,
        W=cond_width,
    )
    # Normalize focal lengths and principal point by the image dimensions.
    intrinsic_normed_cond = intrinsic.clone()
    intrinsic_normed_cond[..., 0, 0] /= cond_width
    intrinsic_normed_cond[..., 0, 2] /= cond_width
    intrinsic_normed_cond[..., 1, 1] /= cond_height
    intrinsic_normed_cond[..., 1, 2] /= cond_height

    return intrinsic, intrinsic_normed_cond
38
+
39
+
40
def default_cond_c2w(distance: float):
    """Camera-to-world matrix of the default conditioning camera.

    The camera sits `distance` units along the world +X axis (translation in
    column 3 of the first row), with a fixed axis-permutation rotation.
    """
    rows = [
        [0, 0, 1, distance],
        [1, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 0, 1],
    ]
    return torch.as_tensor(rows).float()
50
+
51
+
52
def remove_background(
    image: Image,
    rembg_session: Any = None,
    force: bool = False,
    **rembg_kwargs,
) -> Image:
    """Segment the foreground with rembg unless the image already has one.

    An RGBA image whose alpha channel contains any transparency is treated
    as pre-segmented and returned untouched, unless `force` is set.
    """
    already_masked = image.mode == "RGBA" and image.getextrema()[3][0] < 255
    if force or not already_masked:
        image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
    return image
65
+
66
+
67
def get_1d_bounds(arr):
    """Return the first and last indices of nonzero entries in a 1-D array."""
    nonzero_idx = np.nonzero(np.ravel(arr))[0]
    return nonzero_idx[0], nonzero_idx[-1]
70
+
71
+
72
def get_bbox_from_mask(mask, thr=0.5):
    """Tight bounding box (x0, y0, x1, y1) of mask values above `thr`.

    Column occupancy yields the x bounds, row occupancy the y bounds.
    Asserts the thresholded mask is non-empty.
    """
    binary = (mask > thr).astype(np.float32)
    assert binary.sum() > 0, "Empty mask!"
    x0, x1 = get_1d_bounds(binary.sum(axis=-2))
    y0, y1 = get_1d_bounds(binary.sum(axis=-1))
    return x0, y0, x1, y1
78
+
79
+
80
def resize_foreground(
    image: Union[Image.Image, np.ndarray],
    ratio: float,
    out_size=None,
) -> Image:
    """Center-crop around the alpha-mask bounding box so the foreground fills
    `ratio` of the crop.

    Args:
        image: RGBA PIL image (an ndarray is converted to one); the alpha
            channel is used as the foreground mask.
        ratio: Target foreground-extent / crop-side ratio.
        out_size: Optional size to resize the crop to (PIL resize argument).

    Returns:
        The cropped (and optionally resized) image.
    """
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image, mode="RGBA")
    assert image.mode == "RGBA"
    # Get bounding box
    mask_np = np.array(image)[:, :, -1]
    x1, y1, x2, y2 = get_bbox_from_mask(mask_np, thr=0.5)
    h, w = y2 - y1, x2 - x1
    yc, xc = (y1 + y2) / 2, (x1 + x2) / 2
    # Square crop side chosen so max(h, w) / side == ratio.
    scale = max(h, w) / ratio

    # NOTE(review): torchvision crop pads out-of-image regions, so a crop
    # extending past the border yields filled padding — presumably intended
    # for edge-touching foregrounds; confirm.
    new_image = torchvision_F.crop(
        image,
        top=int(yc - scale / 2),
        left=int(xc - scale / 2),
        height=int(scale),
        width=int(scale),
    )
    if out_size is not None:
        new_image = new_image.resize(out_size)

    return new_image
texture_baker/README.md ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Texture baker
2
+
3
+ Small texture baker which rasterizes barycentric coordinates to a tensor.
4
+ It also includes an interpolation module which can then be used to bake vertex attributes into textures.
5
+
6
+ ## Usage
7
+
8
+ The baker can quickly bake vertex attributes to a texture atlas based on the UV coordinates.
9
+ It supports baking on the CPU and GPU.
10
+
11
+ ```python
12
+ from texture_baker import TextureBaker
13
+
14
+ mesh = ...
15
+ uv = mesh.uv # num_vertex, 2
16
+ triangle_idx = mesh.faces # num_faces, 3
17
+ vertices = mesh.vertices # num_vertex, 3
18
+
19
+ tb = TextureBaker()
20
+ # First get the barycentric coordinates
21
+ rast = tb.rasterize(
22
+ uv=uv, face_indices=triangle_idx, bake_resolution=1024
23
+ )
24
+ # Then interpolate vertex attributes
25
+ position_bake = tb.interpolate(attr=vertices, rast=rast, face_indices=triangle_idx)
26
+ ```
texture_baker/requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ torch
2
+ numpy
texture_baker/setup.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import os
3
+ import platform
4
+
5
+ import torch
6
+ from setuptools import find_packages, setup
7
+ from torch.utils.cpp_extension import (
8
+ CUDA_HOME,
9
+ BuildExtension,
10
+ CppExtension,
11
+ CUDAExtension,
12
+ )
13
+
14
+ library_name = "texture_baker"
15
+
16
+
17
+ def get_extensions():
18
+ debug_mode = os.getenv("DEBUG", "0") == "1"
19
+ use_cuda = os.getenv("USE_CUDA", "1" if torch.cuda.is_available() else "0") == "1"
20
+ use_metal = (
21
+ os.getenv("USE_METAL", "1" if torch.backends.mps.is_available() else "0") == "1"
22
+ )
23
+ use_native_arch = os.getenv("USE_NATIVE_ARCH", "1") == "1"
24
+ if debug_mode:
25
+ print("Compiling in debug mode")
26
+
27
+ use_cuda = use_cuda and CUDA_HOME is not None
28
+ extension = CUDAExtension if use_cuda else CppExtension
29
+
30
+ is_hip_extension = True if ((os.environ.get('ROCM_HOME') is not None) and (torch.version.hip is not None)) else False
31
+
32
+ extra_link_args = []
33
+ extra_compile_args = {
34
+ "cxx": [
35
+ "-O3" if not debug_mode else "-O0",
36
+ "-fdiagnostics-color=always",
37
+ "-fopenmp",
38
+ ]
39
+ + ["-march=native"]
40
+ if use_native_arch
41
+ else [],
42
+ "nvcc": [
43
+ "-O3" if not debug_mode else "-O0",
44
+ ],
45
+ }
46
+ if debug_mode:
47
+ extra_compile_args["cxx"].append("-g")
48
+ if platform.system() == "Windows":
49
+ extra_compile_args["cxx"].append("/Z7")
50
+ extra_compile_args["cxx"].append("/Od")
51
+ extra_link_args.extend(["/DEBUG"])
52
+ extra_compile_args["cxx"].append("-UNDEBUG")
53
+ extra_compile_args["nvcc"].append("-UNDEBUG")
54
+ extra_compile_args["nvcc"].append("-g")
55
+ extra_link_args.extend(["-O0", "-g"])
56
+
57
+ define_macros = []
58
+ extensions = []
59
+ libraries = []
60
+
61
+ this_dir = os.path.dirname(os.path.curdir)
62
+ sources = glob.glob(
63
+ os.path.join(this_dir, library_name, "csrc", "**", "*.cpp"), recursive=True
64
+ )
65
+
66
+ if len(sources) == 0:
67
+ print("No source files found for extension, skipping extension compilation")
68
+ return None
69
+
70
+ if use_cuda:
71
+ define_macros += [
72
+ ("THRUST_IGNORE_CUB_VERSION_CHECK", None),
73
+ ]
74
+ sources += glob.glob(
75
+ os.path.join(this_dir, library_name, "csrc", "**", "*.cu"), recursive=True
76
+ )
77
+
78
+ if not is_hip_extension:
79
+ libraries += ["cudart", "c10_cuda"]
80
+
81
+ if use_metal:
82
+ define_macros += [
83
+ ("WITH_MPS", None),
84
+ ]
85
+ sources += glob.glob(
86
+ os.path.join(this_dir, library_name, "csrc", "**", "*.mm"), recursive=True
87
+ )
88
+ extra_compile_args.update({"cxx": ["-O3", "-arch", "arm64", "-mmacosx-version-min=10.15"]})
89
+ extra_link_args += ["-arch", "arm64"]
90
+
91
+ extensions.append(
92
+ extension(
93
+ name=f"{library_name}._C",
94
+ sources=sources,
95
+ define_macros=define_macros,
96
+ extra_compile_args=extra_compile_args,
97
+ extra_link_args=extra_link_args,
98
+ libraries=libraries
99
+ + [
100
+ "c10",
101
+ "torch",
102
+ "torch_cpu",
103
+ "torch_python",
104
+ ],
105
+ )
106
+ )
107
+
108
+ for ext in extensions:
109
+ ext.libraries = ["cudart_static" if x == "cudart" else x for x in ext.libraries]
110
+
111
+ print(extensions)
112
+
113
+ return extensions
114
+
115
+
116
# Build/packaging entry point. `get_extensions()` may return None when no
# C++ sources are present, in which case no extension modules are built.
setup(
    name=library_name,
    version="0.0.1",
    packages=find_packages(where="."),
    package_dir={"": "."},
    ext_modules=get_extensions(),
    install_requires=[],
    # Ship the C++ headers and Metal shaders for source/JIT builds.
    package_data={
        library_name: [os.path.join("csrc", "*.h"), os.path.join("csrc", "*.metal")],
    },
    description="Small texture baker which rasterizes barycentric coordinates to a tensor.",
    long_description=open("README.md").read(),
    long_description_content_type="text/markdown",
    url="https://github.com/Stability-AI/texture_baker",
    # BuildExtension supplies the torch-aware compiler/flag handling.
    cmdclass={"build_ext": BuildExtension},
)
texture_baker/texture_baker/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ import torch # noqa: F401
2
+
3
+ from . import _C # noqa: F401
4
+ from .baker import TextureBaker # noqa: F401
texture_baker/texture_baker/baker.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch import Tensor
4
+
5
+
6
class TextureBaker(nn.Module):
    """Bakes per-vertex attributes into a UV texture atlas.

    Thin Python wrapper over the `texture_baker_cpp` custom ops: `rasterize`
    produces barycentric coordinates + triangle ids per texel, `interpolate`
    blends vertex attributes with them.
    """

    def __init__(self):
        super().__init__()

    def rasterize(
        self,
        uv: Tensor,
        face_indices: Tensor,
        bake_resolution: int,
    ) -> Tensor:
        """
        Rasterize the UV coordinates to a barycentric coordinates
        & Triangle idxs texture map

        Args:
            uv (Tensor, num_vertices 2, float): UV coordinates of the mesh
            face_indices (Tensor, num_faces 3, int): Face indices of the mesh
            bake_resolution (int): Resolution of the bake

        Returns:
            Tensor, bake_resolution bake_resolution 4, float: Rasterized map
        """
        return torch.ops.texture_baker_cpp.rasterize(
            uv, face_indices.to(torch.int32), bake_resolution
        )

    def get_mask(self, rast: Tensor) -> Tensor:
        """
        Get the occupancy mask from the rasterized map

        Args:
            rast (Tensor, bake_resolution bake_resolution 4, float): Rasterized map

        Returns:
            Tensor, bake_resolution bake_resolution, bool: Mask
        """
        # Last channel holds the triangle id; uncovered texels carry -1.
        return rast[..., -1] >= 0

    def interpolate(
        self,
        attr: Tensor,
        rast: Tensor,
        face_indices: Tensor,
    ) -> Tensor:
        """
        Interpolate the attributes using the rasterized map

        Args:
            attr (Tensor, num_vertices 3, float): Attributes of the mesh
            rast (Tensor, bake_resolution bake_resolution 4, float): Rasterized map
            face_indices (Tensor, num_faces 3, int): Face indices of the mesh

        Returns:
            Tensor, bake_resolution bake_resolution 3, float: Interpolated attributes
        """
        return torch.ops.texture_baker_cpp.interpolate(
            attr, face_indices.to(torch.int32), rast
        )

    def forward(
        self,
        attr: Tensor,
        uv: Tensor,
        face_indices: Tensor,
        bake_resolution: int,
    ) -> Tensor:
        """
        Bake the texture

        Args:
            attr (Tensor, num_vertices 3, float): Attributes of the mesh
            uv (Tensor, num_vertices 2, float): UV coordinates of the mesh
            face_indices (Tensor, num_faces 3, int): Face indices of the mesh
            bake_resolution (int): Resolution of the bake

        Returns:
            Tensor, bake_resolution bake_resolution 3, float: Baked texture
        """
        rast = self.rasterize(uv, face_indices, bake_resolution)
        # Bug fix: `interpolate` takes (attr, rast, face_indices); the
        # original also passed `uv`, raising TypeError on every forward call.
        return self.interpolate(attr, rast, face_indices)
texture_baker/texture_baker/csrc/baker.cpp ADDED
@@ -0,0 +1,548 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <ATen/ATen.h>
2
+ #include <ATen/Context.h>
3
+ #include <chrono>
4
+ #include <cmath>
5
+ #include <omp.h>
6
+ #include <torch/extension.h>
7
+ #ifndef __ARM_ARCH_ISA_A64
8
+ #include <immintrin.h>
9
+ #endif
10
+
11
+ #include "baker.h"
12
+
13
+ // #define TIMING
14
+ #define BINS 8
15
+
16
+ namespace texture_baker_cpp {
17
+ // Calculate the centroid of a triangle
18
+ tb_float2 triangle_centroid(const tb_float2 &v0, const tb_float2 &v1,
19
+ const tb_float2 &v2) {
20
+ return {(v0.x + v1.x + v2.x) * 0.3333f, (v0.y + v1.y + v2.y) * 0.3333f};
21
+ }
22
+
23
+ float BVH::find_best_split_plane(const BVHNode &node, int &best_axis,
24
+ int &best_pos, AABB &centroidBounds) {
25
+ float best_cost = std::numeric_limits<float>::max();
26
+
27
+ for (int axis = 0; axis < 2; ++axis) // We use 2 as we have only x and y
28
+ {
29
+ float boundsMin = centroidBounds.min[axis];
30
+ float boundsMax = centroidBounds.max[axis];
31
+ if (boundsMin == boundsMax) {
32
+ continue;
33
+ }
34
+
35
+ // Populate the bins
36
+ float scale = BINS / (boundsMax - boundsMin);
37
+ float leftCountArea[BINS - 1], rightCountArea[BINS - 1];
38
+ int leftSum = 0, rightSum = 0;
39
+
40
+ #ifndef __ARM_ARCH_ISA_A64
41
+ #ifndef _MSC_VER
42
+ if (__builtin_cpu_supports("sse"))
43
+ #elif (defined(_M_AMD64) || defined(_M_X64))
44
+ // SSE supported on Windows
45
+ if constexpr (true)
46
+ #endif
47
+ {
48
+ __m128 min4[BINS], max4[BINS];
49
+ unsigned int count[BINS];
50
+ for (unsigned int i = 0; i < BINS; i++)
51
+ min4[i] = _mm_set_ps1(1e30f), max4[i] = _mm_set_ps1(-1e30f),
52
+ count[i] = 0;
53
+ for (int i = node.start; i < node.end; i++) {
54
+ int tri_idx = triangle_indices[i];
55
+ const Triangle &triangle = triangles[tri_idx];
56
+
57
+ int binIdx = std::min(
58
+ BINS - 1, (int)((triangle.centroid[axis] - boundsMin) * scale));
59
+ count[binIdx]++;
60
+ __m128 v0 = _mm_set_ps(triangle.v0.x, triangle.v0.y, 0.0f, 0.0f);
61
+ __m128 v1 = _mm_set_ps(triangle.v1.x, triangle.v1.y, 0.0f, 0.0f);
62
+ __m128 v2 = _mm_set_ps(triangle.v2.x, triangle.v2.y, 0.0f, 0.0f);
63
+ min4[binIdx] = _mm_min_ps(min4[binIdx], v0);
64
+ max4[binIdx] = _mm_max_ps(max4[binIdx], v0);
65
+ min4[binIdx] = _mm_min_ps(min4[binIdx], v1);
66
+ max4[binIdx] = _mm_max_ps(max4[binIdx], v1);
67
+ min4[binIdx] = _mm_min_ps(min4[binIdx], v2);
68
+ max4[binIdx] = _mm_max_ps(max4[binIdx], v2);
69
+ }
70
+ // gather data for the 7 planes between the 8 bins
71
+ __m128 leftMin4 = _mm_set_ps1(1e30f), rightMin4 = leftMin4;
72
+ __m128 leftMax4 = _mm_set_ps1(-1e30f), rightMax4 = leftMax4;
73
+ for (int i = 0; i < BINS - 1; i++) {
74
+ leftSum += count[i];
75
+ rightSum += count[BINS - 1 - i];
76
+ leftMin4 = _mm_min_ps(leftMin4, min4[i]);
77
+ rightMin4 = _mm_min_ps(rightMin4, min4[BINS - 2 - i]);
78
+ leftMax4 = _mm_max_ps(leftMax4, max4[i]);
79
+ rightMax4 = _mm_max_ps(rightMax4, max4[BINS - 2 - i]);
80
+ float le[4], re[4];
81
+ _mm_store_ps(le, _mm_sub_ps(leftMax4, leftMin4));
82
+ _mm_store_ps(re, _mm_sub_ps(rightMax4, rightMin4));
83
+ // SSE order goes from back to front
84
+ leftCountArea[i] = leftSum * (le[2] * le[3]); // 2D area calculation
85
+ rightCountArea[BINS - 2 - i] =
86
+ rightSum * (re[2] * re[3]); // 2D area calculation
87
+ }
88
+ }
89
+ #else
90
+ if constexpr (false) {
91
+ }
92
+ #endif
93
+ else {
94
+ struct Bin {
95
+ AABB bounds;
96
+ int triCount = 0;
97
+ } bins[BINS];
98
+
99
+ for (int i = node.start; i < node.end; i++) {
100
+ int tri_idx = triangle_indices[i];
101
+ const Triangle &triangle = triangles[tri_idx];
102
+
103
+ int binIdx = std::min(
104
+ BINS - 1, (int)((triangle.centroid[axis] - boundsMin) * scale));
105
+ bins[binIdx].triCount++;
106
+ bins[binIdx].bounds.grow(triangle.v0);
107
+ bins[binIdx].bounds.grow(triangle.v1);
108
+ bins[binIdx].bounds.grow(triangle.v2);
109
+ }
110
+
111
+ // Gather data for the planes between the bins
112
+ AABB leftBox, rightBox;
113
+
114
+ for (int i = 0; i < BINS - 1; i++) {
115
+ leftSum += bins[i].triCount;
116
+ leftBox.grow(bins[i].bounds);
117
+ leftCountArea[i] = leftSum * leftBox.area();
118
+
119
+ rightSum += bins[BINS - 1 - i].triCount;
120
+ rightBox.grow(bins[BINS - 1 - i].bounds);
121
+ rightCountArea[BINS - 2 - i] = rightSum * rightBox.area();
122
+ }
123
+ }
124
+
125
+ // Calculate SAH cost for the planes
126
+ scale = (boundsMax - boundsMin) / BINS;
127
+ for (int i = 0; i < BINS - 1; i++) {
128
+ float planeCost = leftCountArea[i] + rightCountArea[i];
129
+ if (planeCost < best_cost) {
130
+ best_axis = axis;
131
+ best_pos = i + 1;
132
+ best_cost = planeCost;
133
+ }
134
+ }
135
+ }
136
+
137
+ return best_cost;
138
+ }
139
+
140
+ void BVH::update_node_bounds(BVHNode &node, AABB &centroidBounds) {
141
+ #ifndef __ARM_ARCH_ISA_A64
142
+ #ifndef _MSC_VER
143
+ if (__builtin_cpu_supports("sse"))
144
+ #elif (defined(_M_AMD64) || defined(_M_X64))
145
+ // SSE supported on Windows
146
+ if constexpr (true)
147
+ #endif
148
+ {
149
+ __m128 min4 = _mm_set_ps1(1e30f), max4 = _mm_set_ps1(-1e30f);
150
+ __m128 cmin4 = _mm_set_ps1(1e30f), cmax4 = _mm_set_ps1(-1e30f);
151
+
152
+ for (int i = node.start; i < node.end; i += 2) {
153
+ int tri_idx1 = triangle_indices[i];
154
+ const Triangle &leafTri1 = triangles[tri_idx1];
155
+ // Check if the second actually exists in the node
156
+ __m128 v0, v1, v2, centroid;
157
+ if (i + 1 < node.end) {
158
+ int tri_idx2 = triangle_indices[i + 1];
159
+ const Triangle leafTri2 = triangles[tri_idx2];
160
+
161
+ v0 = _mm_set_ps(leafTri1.v0.x, leafTri1.v0.y, leafTri2.v0.x,
162
+ leafTri2.v0.y);
163
+ v1 = _mm_set_ps(leafTri1.v1.x, leafTri1.v1.y, leafTri2.v1.x,
164
+ leafTri2.v1.y);
165
+ v2 = _mm_set_ps(leafTri1.v2.x, leafTri1.v2.y, leafTri2.v2.x,
166
+ leafTri2.v2.y);
167
+ centroid = _mm_set_ps(leafTri1.centroid.x, leafTri1.centroid.y,
168
+ leafTri2.centroid.x, leafTri2.centroid.y);
169
+ } else {
170
+ // Otherwise do some duplicated work
171
+ v0 = _mm_set_ps(leafTri1.v0.x, leafTri1.v0.y, leafTri1.v0.x,
172
+ leafTri1.v0.y);
173
+ v1 = _mm_set_ps(leafTri1.v1.x, leafTri1.v1.y, leafTri1.v1.x,
174
+ leafTri1.v1.y);
175
+ v2 = _mm_set_ps(leafTri1.v2.x, leafTri1.v2.y, leafTri1.v2.x,
176
+ leafTri1.v2.y);
177
+ centroid = _mm_set_ps(leafTri1.centroid.x, leafTri1.centroid.y,
178
+ leafTri1.centroid.x, leafTri1.centroid.y);
179
+ }
180
+
181
+ min4 = _mm_min_ps(min4, v0);
182
+ max4 = _mm_max_ps(max4, v0);
183
+ min4 = _mm_min_ps(min4, v1);
184
+ max4 = _mm_max_ps(max4, v1);
185
+ min4 = _mm_min_ps(min4, v2);
186
+ max4 = _mm_max_ps(max4, v2);
187
+ cmin4 = _mm_min_ps(cmin4, centroid);
188
+ cmax4 = _mm_max_ps(cmax4, centroid);
189
+ }
190
+
191
+ float min_values[4], max_values[4], cmin_values[4], cmax_values[4];
192
+ _mm_store_ps(min_values, min4);
193
+ _mm_store_ps(max_values, max4);
194
+ _mm_store_ps(cmin_values, cmin4);
195
+ _mm_store_ps(cmax_values, cmax4);
196
+
197
+ node.bbox.min.x = std::min(min_values[3], min_values[1]);
198
+ node.bbox.min.y = std::min(min_values[2], min_values[0]);
199
+ node.bbox.max.x = std::max(max_values[3], max_values[1]);
200
+ node.bbox.max.y = std::max(max_values[2], max_values[0]);
201
+
202
+ centroidBounds.min.x = std::min(cmin_values[3], cmin_values[1]);
203
+ centroidBounds.min.y = std::min(cmin_values[2], cmin_values[0]);
204
+ centroidBounds.max.x = std::max(cmax_values[3], cmax_values[1]);
205
+ centroidBounds.max.y = std::max(cmax_values[2], cmax_values[0]);
206
+ }
207
+ #else
208
+ if constexpr (false) {
209
+ }
210
+ #endif
211
+ {
212
+ node.bbox.invalidate();
213
+ centroidBounds.invalidate();
214
+
215
+ // Calculate the bounding box for the node
216
+ for (int i = node.start; i < node.end; ++i) {
217
+ int tri_idx = triangle_indices[i];
218
+ const Triangle &tri = triangles[tri_idx];
219
+ node.bbox.grow(tri.v0);
220
+ node.bbox.grow(tri.v1);
221
+ node.bbox.grow(tri.v2);
222
+ centroidBounds.grow(tri.centroid);
223
+ }
224
+ }
225
+ }
226
+
227
// Build the BVH over `num_indices` UV triangles.
// `vertices` is the UV coordinate buffer; `indices` holds one tb_int3 of
// vertex indices per triangle.  Construction is iterative (breadth-first
// via an explicit queue) and splits nodes with binned SAH.
void BVH::build(const tb_float2 *vertices, const tb_int3 *indices,
                const int64_t &num_indices) {
#ifdef TIMING
  auto start = std::chrono::high_resolution_clock::now();
#endif
  // Create triangles
  for (size_t i = 0; i < num_indices; ++i) {
    tb_int3 idx = indices[i];
    triangles.push_back(
        {vertices[idx.x], vertices[idx.y], vertices[idx.z], static_cast<int>(i),
         triangle_centroid(vertices[idx.x], vertices[idx.y], vertices[idx.z])});
  }

  // Initialize triangle_indices
  triangle_indices.resize(triangles.size());
  std::iota(triangle_indices.begin(), triangle_indices.end(), 0);

  // Build BVH nodes
  // Reserve extra capacity to fix windows specific crashes
  // (a binary tree over N leaves has at most 2N-1 nodes, so this also
  // guarantees the push_back calls below never reallocate `nodes`).
  nodes.reserve(triangles.size() * 2 + 1);
  nodes.push_back({}); // Create the root node
  root = 0;

  // Define a struct for queue entries
  struct QueueEntry {
    int node_idx; // index into `nodes`
    int start;    // triangle range [start, end) in triangle_indices
    int end;
  };

  // Queue for breadth-first traversal
  std::queue<QueueEntry> node_queue;
  node_queue.push({root, 0, (int)triangles.size()});

  // Process each node in the queue
  while (!node_queue.empty()) {
    QueueEntry current = node_queue.front();
    node_queue.pop();

    int node_idx = current.node_idx;
    int start = current.start;
    int end = current.end;

    BVHNode &node = nodes[node_idx];
    node.start = start;
    node.end = end;

    // Calculate the bounding box for the node
    AABB centroidBounds;
    update_node_bounds(node, centroidBounds);

    // Determine the best split using SAH
    int best_axis, best_pos;

    float splitCost =
        find_best_split_plane(node, best_axis, best_pos, centroidBounds);
    float nosplitCost = node.calculate_node_cost();

    // Stop condition: if the best cost is greater than or equal to the parent's
    // cost
    if (splitCost >= nosplitCost) {
      // Leaf node
      node.left = node.right = -1;
      continue;
    }

    // Partition triangle_indices[start, end) around the chosen bin plane.
    float scale =
        BINS / (centroidBounds.max[best_axis] - centroidBounds.min[best_axis]);
    int i = node.start;
    int j = node.end - 1;

    // Sort the triangle_indices in the range [start, end) based on the best
    // axis
    while (i <= j) {
      // use the exact calculation we used for binning to prevent rare
      // inaccuracies
      int tri_idx = triangle_indices[i];
      tb_float2 tcentr = triangles[tri_idx].centroid;
      int binIdx = std::min(
          BINS - 1,
          (int)((tcentr[best_axis] - centroidBounds.min[best_axis]) * scale));
      if (binIdx < best_pos)
        i++;
      else
        std::swap(triangle_indices[i], triangle_indices[j--]);
    }
    int leftCount = i - node.start;
    if (leftCount == 0 || leftCount == node.num_triangles()) {
      // Leaf node: the plane failed to actually separate the triangles.
      node.left = node.right = -1;
      continue;
    }

    int mid = i;

    // Create and set left child
    node.left = nodes.size();
    nodes.push_back({});
    node_queue.push({node.left, start, mid});

    // Create and set right child
    // NOTE(review): `node` is a reference and the reserve() above prevents
    // reallocation, so this self-assignment appears to be a no-op guard —
    // confirm before removing.
    node = nodes[node_idx]; // Update the node - Potentially stale reference
    node.right = nodes.size();
    nodes.push_back({});
    node_queue.push({node.right, mid, end});
  }
#ifdef TIMING
  auto end = std::chrono::high_resolution_clock::now();
  std::chrono::duration<double> elapsed = end - start;
  std::cout << "BVH build time: " << elapsed.count() << "s" << std::endl;
#endif
}
339
+
340
// Clamp `val` into the inclusive range [minVal, maxVal].
float clamp(float val, float minVal, float maxVal) {
  return std::clamp(val, minVal, maxVal);
}
344
+
345
+ // Function to check if a point (xy) is inside a triangle defined by vertices
346
+ // v1, v2, v3
347
+ bool barycentric_coordinates(tb_float2 xy, tb_float2 v1, tb_float2 v2,
348
+ tb_float2 v3, float &u, float &v, float &w) {
349
+ // Vectors from v1 to v2, v3 and xy
350
+ tb_float2 v1v2 = {v2.x - v1.x, v2.y - v1.y};
351
+ tb_float2 v1v3 = {v3.x - v1.x, v3.y - v1.y};
352
+ tb_float2 xyv1 = {xy.x - v1.x, xy.y - v1.y};
353
+
354
+ // Dot products of the vectors
355
+ float d00 = v1v2.x * v1v2.x + v1v2.y * v1v2.y;
356
+ float d01 = v1v2.x * v1v3.x + v1v2.y * v1v3.y;
357
+ float d11 = v1v3.x * v1v3.x + v1v3.y * v1v3.y;
358
+ float d20 = xyv1.x * v1v2.x + xyv1.y * v1v2.y;
359
+ float d21 = xyv1.x * v1v3.x + xyv1.y * v1v3.y;
360
+
361
+ // Calculate the barycentric coordinates
362
+ float denom = d00 * d11 - d01 * d01;
363
+ v = (d11 * d20 - d01 * d21) / denom;
364
+ w = (d00 * d21 - d01 * d20) / denom;
365
+ u = 1.0f - v - w;
366
+
367
+ // Check if the point is inside the triangle
368
+ return (v >= 0.0f) && (w >= 0.0f) && (v + w <= 1.0f);
369
+ }
370
+
371
+ bool BVH::intersect(const tb_float2 &point, float &u, float &v, float &w,
372
+ int &index) const {
373
+ const int max_stack_size = 64;
374
+ int node_stack[max_stack_size];
375
+ int stack_size = 0;
376
+
377
+ node_stack[stack_size++] = root;
378
+
379
+ while (stack_size > 0) {
380
+ int node_idx = node_stack[--stack_size];
381
+ const BVHNode &node = nodes[node_idx];
382
+
383
+ if (node.is_leaf()) {
384
+ for (int i = node.start; i < node.end; ++i) {
385
+ const Triangle &tri = triangles[triangle_indices[i]];
386
+ if (barycentric_coordinates(point, tri.v0, tri.v1, tri.v2, u, v, w)) {
387
+ index = tri.index;
388
+ return true;
389
+ }
390
+ }
391
+ } else {
392
+ if (nodes[node.right].bbox.overlaps(point)) {
393
+ if (stack_size < max_stack_size) {
394
+ node_stack[stack_size++] = node.right;
395
+ } else {
396
+ // Handle stack overflow
397
+ throw std::runtime_error("Node stack overflow");
398
+ }
399
+ }
400
+ if (nodes[node.left].bbox.overlaps(point)) {
401
+ if (stack_size < max_stack_size) {
402
+ node_stack[stack_size++] = node.left;
403
+ } else {
404
+ // Handle stack overflow
405
+ throw std::runtime_error("Node stack overflow");
406
+ }
407
+ }
408
+ }
409
+ }
410
+
411
+ return false;
412
+ }
413
+
414
+ torch::Tensor rasterize_cpu(torch::Tensor uv, torch::Tensor indices,
415
+ int64_t bake_resolution) {
416
+ int width = bake_resolution;
417
+ int height = bake_resolution;
418
+ int num_pixels = width * height;
419
+ torch::Tensor rast_result = torch::empty(
420
+ {bake_resolution, bake_resolution, 4},
421
+ torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCPU));
422
+
423
+ float *rast_result_ptr = rast_result.contiguous().data_ptr<float>();
424
+ const tb_float2 *vertices = (tb_float2 *)uv.data_ptr<float>();
425
+ const tb_int3 *tris = (tb_int3 *)indices.data_ptr<int>();
426
+
427
+ BVH bvh;
428
+ bvh.build(vertices, tris, indices.size(0));
429
+
430
+ #ifdef TIMING
431
+ auto start = std::chrono::high_resolution_clock::now();
432
+ #endif
433
+
434
+ #pragma omp parallel for
435
+ for (int idx = 0; idx < num_pixels; ++idx) {
436
+ int x = idx / height;
437
+ int y = idx % height;
438
+ int idx_ = idx * 4; // Note: *4 because we're storing float4 per pixel
439
+
440
+ tb_float2 pixel_coord = {float(y) / height, float(x) / width};
441
+ pixel_coord.x = clamp(pixel_coord.x, 0.0f, 1.0f);
442
+ pixel_coord.y = 1.0f - clamp(pixel_coord.y, 0.0f, 1.0f);
443
+
444
+ float u, v, w;
445
+ int triangle_idx;
446
+ if (bvh.intersect(pixel_coord, u, v, w, triangle_idx)) {
447
+ rast_result_ptr[idx_ + 0] = u;
448
+ rast_result_ptr[idx_ + 1] = v;
449
+ rast_result_ptr[idx_ + 2] = w;
450
+ rast_result_ptr[idx_ + 3] = static_cast<float>(triangle_idx);
451
+ } else {
452
+ rast_result_ptr[idx_ + 0] = 0.0f;
453
+ rast_result_ptr[idx_ + 1] = 0.0f;
454
+ rast_result_ptr[idx_ + 2] = 0.0f;
455
+ rast_result_ptr[idx_ + 3] = -1.0f;
456
+ }
457
+ }
458
+
459
+ #ifdef TIMING
460
+ auto end = std::chrono::high_resolution_clock::now();
461
+ std::chrono::duration<double> elapsed = end - start;
462
+ std::cout << "Rasterization time: " << elapsed.count() << "s" << std::endl;
463
+ #endif
464
+ return rast_result;
465
+ }
466
+
467
+ torch::Tensor interpolate_cpu(torch::Tensor attr, torch::Tensor indices,
468
+ torch::Tensor rast) {
469
+ #ifdef TIMING
470
+ auto start = std::chrono::high_resolution_clock::now();
471
+ #endif
472
+ int height = rast.size(0);
473
+ int width = rast.size(1);
474
+ torch::Tensor pos_bake = torch::empty(
475
+ {height, width, 3},
476
+ torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCPU));
477
+
478
+ const float *attr_ptr = attr.contiguous().data_ptr<float>();
479
+ const int *indices_ptr = indices.contiguous().data_ptr<int>();
480
+ const float *rast_ptr = rast.contiguous().data_ptr<float>();
481
+ float *output_ptr = pos_bake.contiguous().data_ptr<float>();
482
+
483
+ int num_pixels = width * height;
484
+
485
+ #pragma omp parallel for
486
+ for (int idx = 0; idx < num_pixels; ++idx) {
487
+ int idx_ = idx * 4; // Index into the float4 array (4 floats per pixel)
488
+ tb_float3 barycentric = {
489
+ rast_ptr[idx_ + 0],
490
+ rast_ptr[idx_ + 1],
491
+ rast_ptr[idx_ + 2],
492
+ };
493
+ int triangle_idx = static_cast<int>(rast_ptr[idx_ + 3]);
494
+
495
+ if (triangle_idx < 0) {
496
+ output_ptr[idx * 3 + 0] = 0.0f;
497
+ output_ptr[idx * 3 + 1] = 0.0f;
498
+ output_ptr[idx * 3 + 2] = 0.0f;
499
+ continue;
500
+ }
501
+
502
+ tb_int3 triangle = {indices_ptr[3 * triangle_idx + 0],
503
+ indices_ptr[3 * triangle_idx + 1],
504
+ indices_ptr[3 * triangle_idx + 2]};
505
+ tb_float3 v1 = {attr_ptr[3 * triangle.x + 0], attr_ptr[3 * triangle.x + 1],
506
+ attr_ptr[3 * triangle.x + 2]};
507
+ tb_float3 v2 = {attr_ptr[3 * triangle.y + 0], attr_ptr[3 * triangle.y + 1],
508
+ attr_ptr[3 * triangle.y + 2]};
509
+ tb_float3 v3 = {attr_ptr[3 * triangle.z + 0], attr_ptr[3 * triangle.z + 1],
510
+ attr_ptr[3 * triangle.z + 2]};
511
+
512
+ tb_float3 interpolated;
513
+ interpolated.x =
514
+ v1.x * barycentric.x + v2.x * barycentric.y + v3.x * barycentric.z;
515
+ interpolated.y =
516
+ v1.y * barycentric.x + v2.y * barycentric.y + v3.y * barycentric.z;
517
+ interpolated.z =
518
+ v1.z * barycentric.x + v2.z * barycentric.y + v3.z * barycentric.z;
519
+
520
+ output_ptr[idx * 3 + 0] = interpolated.x;
521
+ output_ptr[idx * 3 + 1] = interpolated.y;
522
+ output_ptr[idx * 3 + 2] = interpolated.z;
523
+ }
524
+
525
+ #ifdef TIMING
526
+ auto end = std::chrono::high_resolution_clock::now();
527
+ std::chrono::duration<double> elapsed = end - start;
528
+ std::cout << "Interpolation time: " << elapsed.count() << "s" << std::endl;
529
+ #endif
530
+ return pos_bake;
531
+ }
532
+
533
// Registers _C as a Python extension module.
// The module itself is intentionally empty; all entry points are exposed
// through the TORCH_LIBRARY operator registry below.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}

// Defines the operators
// (dispatcher schemas; the CPU kernels are bound below, CUDA kernels in
// baker_kernel.cu)
TORCH_LIBRARY(texture_baker_cpp, m) {
  m.def("rasterize(Tensor uv, Tensor indices, int bake_resolution) -> Tensor");
  m.def("interpolate(Tensor attr, Tensor indices, Tensor rast) -> Tensor");
}

// Registers CPP implementations
TORCH_LIBRARY_IMPL(texture_baker_cpp, CPU, m) {
  m.impl("rasterize", &rasterize_cpu);
  m.impl("interpolate", &interpolate_cpu);
}
547
+
548
+ } // namespace texture_baker_cpp
texture_baker/texture_baker/csrc/baker.h ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #if defined(__NVCC__) || defined(__HIPCC__) || defined(__METAL__)
4
+ #define CUDA_ENABLED
5
+ #ifndef __METAL__
6
+ #define CUDA_HOST_DEVICE __host__ __device__
7
+ #define CUDA_DEVICE __device__
8
+ #define METAL_CONSTANT_MEM
9
+ #define METAL_THREAD_MEM
10
+ #else
11
+ #define tb_float2 float2
12
+ #define CUDA_HOST_DEVICE
13
+ #define CUDA_DEVICE
14
+ #define METAL_CONSTANT_MEM constant
15
+ #define METAL_THREAD_MEM thread
16
+ #endif
17
+ #else
18
+ #define CUDA_HOST_DEVICE
19
+ #define CUDA_DEVICE
20
+ #define METAL_CONSTANT_MEM
21
+ #define METAL_THREAD_MEM
22
+ #include <cfloat>
23
+ #include <limits>
24
+ #include <vector>
25
+ #endif
26
+
27
+ namespace texture_baker_cpp {
28
// Structure to represent a 2D point or vector
// (Metal supplies native float2/float3/float4, so these unions are only
// compiled for host/CUDA builds.)
#ifndef __METAL__
union alignas(8) tb_float2 {
  struct {
    float x, y;
  };

  float data[2];

  // Bounds-checked component access (0 = x, 1 = y).
  float &operator[](size_t idx) {
    if (idx > 1)
      throw std::runtime_error("bad index");
    return data[idx];
  }

  const float &operator[](size_t idx) const {
    if (idx > 1)
      throw std::runtime_error("bad index");
    return data[idx];
  }

  bool operator==(const tb_float2 &rhs) const {
    return x == rhs.x && y == rhs.y;
  }
};

// 3-component float vector; alignas(4) keeps it tightly packed.
union alignas(4) tb_float3 {
  struct {
    float x, y, z;
  };

  float data[3];

  // Bounds-checked component access (0 = x, 1 = y, 2 = z).
  float &operator[](size_t idx) {
    if (idx > 2)
      throw std::runtime_error("bad index");
    return data[idx];
  }

  const float &operator[](size_t idx) const {
    if (idx > 2)
      throw std::runtime_error("bad index");
    return data[idx];
  }
};

// 4-component float vector, 16-byte aligned.
union alignas(16) tb_float4 {
  struct {
    float x, y, z, w;
  };

  float data[4];

  // Bounds-checked component access (0..3 = x, y, z, w).
  float &operator[](size_t idx) {
    if (idx > 3)
      throw std::runtime_error("bad index");
    return data[idx];
  }

  const float &operator[](size_t idx) const {
    if (idx > 3)
      throw std::runtime_error("bad index");
    return data[idx];
  }
};
#endif

// 3-component int vector (used for triangle vertex indices).
union alignas(4) tb_int3 {
  struct {
    int x, y, z;
  };

  int data[3];
#ifndef __METAL__
  // Bounds-checked component access (host/CUDA only; Metal has no
  // exceptions).
  int &operator[](size_t idx) {
    if (idx > 2)
      throw std::runtime_error("bad index");
    return data[idx];
  }
#endif
};
109
+
110
+ // BVH structure to accelerate point-triangle intersection
111
+ struct alignas(16) AABB {
112
+ // Init bounding boxes with max/min
113
+ tb_float2 min = {FLT_MAX, FLT_MAX};
114
+ tb_float2 max = {FLT_MIN, FLT_MIN};
115
+
116
+ #ifndef CUDA_ENABLED
117
+ // grow the AABB to include a point
118
+ void grow(const tb_float2 &p) {
119
+ min.x = std::min(min.x, p.x);
120
+ min.y = std::min(min.y, p.y);
121
+ max.x = std::max(max.x, p.x);
122
+ max.y = std::max(max.y, p.y);
123
+ }
124
+
125
+ void grow(const AABB &b) {
126
+ if (b.min.x != FLT_MAX) {
127
+ grow(b.min);
128
+ grow(b.max);
129
+ }
130
+ }
131
+ #endif
132
+
133
+ // Check if two AABBs overlap
134
+ bool overlaps(const METAL_THREAD_MEM AABB &other) const {
135
+ return min.x <= other.max.x && max.x >= other.min.x &&
136
+ min.y <= other.max.y && max.y >= other.min.y;
137
+ }
138
+
139
+ bool overlaps(const METAL_THREAD_MEM tb_float2 &point) const {
140
+ return point.x >= min.x && point.x <= max.x && point.y >= min.y &&
141
+ point.y <= max.y;
142
+ }
143
+
144
+ #if defined(__NVCC__) || defined(__HIPCC__)
145
+ CUDA_DEVICE bool overlaps(const float2 &point) const {
146
+ return point.x >= min.x && point.x <= max.x && point.y >= min.y &&
147
+ point.y <= max.y;
148
+ }
149
+ #endif
150
+
151
+ // Initialize AABB to an invalid state
152
+ void invalidate() {
153
+ min = {FLT_MAX, FLT_MAX};
154
+ max = {FLT_MIN, FLT_MIN};
155
+ }
156
+
157
+ // Calculate the area of the AABB
158
+ float area() const {
159
+ tb_float2 extent = {max.x - min.x, max.y - min.y};
160
+ return extent.x * extent.y;
161
+ }
162
+ };
163
+
164
+ struct BVHNode {
165
+ AABB bbox;
166
+ int start, end;
167
+ int left, right;
168
+
169
+ int num_triangles() const { return end - start; }
170
+
171
+ CUDA_HOST_DEVICE bool is_leaf() const { return left == -1 && right == -1; }
172
+
173
+ float calculate_node_cost() {
174
+ float area = bbox.area();
175
+ return num_triangles() * area;
176
+ }
177
+ };
178
+
179
// One UV triangle plus the metadata the BVH needs.
struct Triangle {
  tb_float2 v0, v1, v2; // vertex positions in UV space
  // Index of this triangle in the original index buffer — the id written
  // into the rasterization output.
  int index;
  tb_float2 centroid; // precomputed centroid, used for SAH binning
};
184
+
185
#ifndef __METAL__
// Host-side 2D BVH over UV triangles.  The flat arrays are also copied
// verbatim to the GPU by baker_kernel.cu, which traverses them directly.
struct BVH {
  std::vector<BVHNode> nodes;        // flattened node array
  std::vector<Triangle> triangles;   // triangle soup referenced by nodes
  std::vector<int> triangle_indices; // per-node triangle ordering
  int root;                          // index of the root node in `nodes`

  // Build the tree over `num_indices` triangles (see baker.cpp).
  void build(const tb_float2 *vertices, const tb_int3 *indices,
             const int64_t &num_indices);
  // Point-in-triangle query: barycentrics in u/v/w, triangle id in `index`.
  bool intersect(const tb_float2 &point, float &u, float &v, float &w,
                 int &index) const;

  // Internal build helpers.
  void update_node_bounds(BVHNode &node, AABB &centroidBounds);
  float find_best_split_plane(const BVHNode &node, int &best_axis,
                              int &best_pos, AABB &centroidBounds);
};
#endif
202
+
203
+ } // namespace texture_baker_cpp
texture_baker/texture_baker/csrc/baker_kernel.cu ADDED
@@ -0,0 +1,306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <ATen/ATen.h>
2
+ #include <ATen/Context.h>
3
+ #include <ATen/cuda/CUDAContext.h>
4
+ #include <torch/extension.h>
5
+
6
+ #include "baker.h"
7
+
8
+ // #define TIMING
9
+
10
+ #define STRINGIFY(x) #x
11
+ #define STR(x) STRINGIFY(x)
12
+ #define FILE_LINE __FILE__ ":" STR(__LINE__)
13
+ #define CUDA_CHECK_THROW(x) \
14
+ do { \
15
+ cudaError_t _result = x; \
16
+ if (_result != cudaSuccess) \
17
+ throw std::runtime_error(std::string(FILE_LINE " check failed " #x " failed: ") + cudaGetErrorString(_result)); \
18
+ } while(0)
19
+
20
+ #if defined(__HIPCC__)
21
+ #define cudaMallocAsync hipMallocAsync
22
+ #define cudaFreeAsync hipFreeAsync
23
+ #endif
24
+
25
+ namespace texture_baker_cpp
26
+ {
27
+
28
// Component-wise addition for CUDA float3 (not provided by the runtime).
__device__ float3 operator+(const float3 &a, const float3 &b)
{
    return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
}
32
+
33
// xy: 2D test position
// v1: vertex position 1
// v2: vertex position 2
// v3: vertex position 3
//
// Device-side mirror of the CPU barycentric_coordinates in baker.cpp; the
// two must stay numerically identical.
__forceinline__ __device__ bool barycentric_coordinates(const float2 &xy, const tb_float2 &v1, const tb_float2 &v2, const tb_float2 &v3, float &u, float &v, float &w)
{
    // Return true if the point (xy) is inside the triangle defined by the vertices v1, v2, v3.
    // If the point is inside the triangle, the barycentric coordinates are stored in u, v, and w.
    float2 v1v2 = make_float2(v2.x - v1.x, v2.y - v1.y);
    float2 v1v3 = make_float2(v3.x - v1.x, v3.y - v1.y);
    float2 xyv1 = make_float2(xy.x - v1.x, xy.y - v1.y);

    // Dot products for Cramer's rule.
    float d00 = v1v2.x * v1v2.x + v1v2.y * v1v2.y;
    float d01 = v1v2.x * v1v3.x + v1v2.y * v1v3.y;
    float d11 = v1v3.x * v1v3.x + v1v3.y * v1v3.y;
    float d20 = xyv1.x * v1v2.x + xyv1.y * v1v2.y;
    float d21 = xyv1.x * v1v3.x + xyv1.y * v1v3.y;

    // A degenerate triangle gives denom == 0; the resulting non-finite
    // u/v/w fail the comparisons below, i.e. the point is rejected.
    float denom = d00 * d11 - d01 * d01;
    v = (d11 * d20 - d01 * d21) / denom;
    w = (d00 * d21 - d01 * d20) / denom;
    u = 1.0f - v - w;

    return (v >= 0.0f) && (w >= 0.0f) && (v + w <= 1.0f);
}
59
+
60
// One thread per pixel: reads the rasterization result (u, v, w,
// triangle_id) and writes the barycentric-weighted combination of the
// triangle's three vertex attributes; (0,0,0) for uncovered pixels.
__global__ void kernel_interpolate(const float3* __restrict__ attr, const int3* __restrict__ indices, const float4* __restrict__ rast, float3* __restrict__ output, int width, int height)
{
    // Interpolate the attr into output based on the rast result (barycentric coordinates, + triangle idx)
    //int idx = x * width + y;
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int x = idx / width;
    int y = idx % width;

    // NOTE(review): x = idx / width can reach height - 1, yet the guard is
    // `x >= width`; this is only equivalent because the bake target is
    // square (width == height everywhere in this extension) — confirm if
    // non-square targets are ever introduced.
    if (x >= width || y >= height)
        return;

    float4 barycentric = rast[idx];
    int triangle_idx = int(barycentric.w);

    // triangle_idx < 0 marks a pixel covered by no triangle.
    if (triangle_idx < 0)
    {
        output[idx] = make_float3(0.0f, 0.0f, 0.0f);
        return;
    }

    float3 v1 = attr[indices[triangle_idx].x];
    float3 v2 = attr[indices[triangle_idx].y];
    float3 v3 = attr[indices[triangle_idx].z];

    // Weighted sum of the corner attributes (uses the float3 operator+
    // defined above).
    output[idx] = make_float3(v1.x * barycentric.x, v1.y * barycentric.x, v1.z * barycentric.x)
                + make_float3(v2.x * barycentric.y, v2.y * barycentric.y, v2.z * barycentric.y)
                + make_float3(v3.x * barycentric.z, v3.y * barycentric.z, v3.z * barycentric.z);
}
88
+
89
// Device-side BVH traversal mirroring BVH::intersect in baker.cpp: walk
// the (host-built, device-copied) node/triangle arrays with an explicit
// stack and report the triangle containing `point` plus its barycentrics.
__device__ bool bvh_intersect(
    const BVHNode* __restrict__ nodes,
    const Triangle* __restrict__ triangles,
    const int* __restrict__ triangle_indices,
    const int root,
    const float2 &point,
    float &u, float &v, float &w,
    int &index)
{
    constexpr int max_stack_size = 64;
    int node_stack[max_stack_size];
    int stack_size = 0;

    node_stack[stack_size++] = root;

    while (stack_size > 0)
    {
        int node_idx = node_stack[--stack_size];
        const BVHNode &node = nodes[node_idx];

        if (node.is_leaf())
        {
            // Leaf: test each triangle owned by this node.
            for (int i = node.start; i < node.end; ++i)
            {
                const Triangle &tri = triangles[triangle_indices[i]];
                if (barycentric_coordinates(point, tri.v0, tri.v1, tri.v2, u, v, w))
                {
                    index = tri.index;
                    return true;
                }
            }
        }
        else
        {
            // Push right first so the left child is popped (visited) first.
            if (nodes[node.right].bbox.overlaps(point))
            {
                if (stack_size < max_stack_size)
                {
                    node_stack[stack_size++] = node.right;
                }
                else
                {
                    // Handle stack overflow
                    // Make sure NDEBUG is not defined (see setup.py)
                    assert(0 && "Node stack overflow");
                }
            }
            if (nodes[node.left].bbox.overlaps(point))
            {
                if (stack_size < max_stack_size)
                {
                    node_stack[stack_size++] = node.left;
                }
                else
                {
                    // Handle stack overflow
                    // Make sure NDEBUG is not defined (see setup.py)
                    assert(0 && "Node stack overflow");
                }
            }
        }
    }

    return false;
}
154
+
155
// One thread per pixel: intersect the pixel's UV coordinate against the
// BVH and write (u, v, w, triangle_id) — or (0, 0, 0, -1) on a miss —
// mirroring the CPU loop in rasterize_cpu.
// NOTE(review): uv, indices and num_indices are unused here — geometry is
// read exclusively from the prebuilt BVH arrays; presumably kept for
// interface symmetry. Confirm before removing.
__global__ void kernel_bake_uv(
    float2* __restrict__ uv,
    int3* __restrict__ indices,
    float4* __restrict__ output,
    const BVHNode* __restrict__ nodes,
    const Triangle* __restrict__ triangles,
    const int* __restrict__ triangle_indices,
    const int root,
    const int width,
    const int height,
    const int num_indices)
{
    //int idx = x * width + y;
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int x = idx / width;
    int y = idx % width;

    if (y >= width || x >= height)
        return;

    // We index x,y but the original coords are HW. So swap them
    // (v is also flipped: image rows grow downward, UV v grows upward).
    float2 pixel_coord = make_float2(float(y) / height, float(x) / width);
    pixel_coord.x = fminf(fmaxf(pixel_coord.x, 0.0f), 1.0f);
    pixel_coord.y = 1.0f - fminf(fmaxf(pixel_coord.y, 0.0f), 1.0f);

    float u, v, w;
    int triangle_idx;
    bool hit = bvh_intersect(nodes, triangles, triangle_indices, root, pixel_coord, u, v, w, triangle_idx);

    if (hit)
    {
        output[idx] = make_float4(u, v, w, float(triangle_idx));
        return;
    }

    output[idx] = make_float4(0.0f, 0.0f, 0.0f, -1.0f);
}
192
+
193
// GPU rasterizer: builds the BVH on the host, uploads its flat arrays, and
// launches kernel_bake_uv to fill a (R, R, 4) float32 CUDA tensor with
// (u, v, w, triangle_id) per pixel (triangle_id == -1 on miss).
torch::Tensor rasterize_gpu(
    torch::Tensor uv,
    torch::Tensor indices,
    int64_t bake_resolution)
{
#ifdef TIMING
    auto start = std::chrono::high_resolution_clock::now();
#endif
    constexpr int block_size = 16 * 16;
    // Fix: round up so the final partial block is still launched when the
    // pixel count is not a multiple of block_size (the kernel bounds-checks
    // every thread, so the surplus threads are harmless).
    int grid_size = (int)((bake_resolution * bake_resolution + block_size - 1) / block_size);
    dim3 block_dims(block_size, 1, 1);
    dim3 grid_dims(grid_size, 1, 1);

    int num_indices = indices.size(0);

    int width = bake_resolution;
    int height = bake_resolution;

    // Step 1: create an empty tensor to store the output.
    torch::Tensor rast_result = torch::empty({bake_resolution, bake_resolution, 4}, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA));

    // Host copies for the BVH build (kept alive while their pointers are used).
    auto vertices_cpu = uv.contiguous().cpu();
    auto indices_cpu = indices.contiguous().cpu();

    const tb_float2 *vertices_cpu_ptr = (tb_float2*)vertices_cpu.data_ptr<float>();
    const tb_int3 *tris_cpu_ptr = (tb_int3*)indices_cpu.data_ptr<int>();

    BVH bvh;
    bvh.build(vertices_cpu_ptr, tris_cpu_ptr, indices.size(0));

    BVHNode *nodes_gpu = nullptr;
    Triangle *triangles_gpu = nullptr;
    int *triangle_indices_gpu = nullptr;
    const int bvh_root = bvh.root;
    cudaStream_t cuda_stream = at::cuda::getCurrentCUDAStream();

    // Upload the flattened BVH (stream-ordered alloc + copy).
    CUDA_CHECK_THROW(cudaMallocAsync(&nodes_gpu, sizeof(BVHNode) * bvh.nodes.size(), cuda_stream));
    CUDA_CHECK_THROW(cudaMallocAsync(&triangles_gpu, sizeof(Triangle) * bvh.triangles.size(), cuda_stream));
    CUDA_CHECK_THROW(cudaMallocAsync(&triangle_indices_gpu, sizeof(int) * bvh.triangle_indices.size(), cuda_stream));

    CUDA_CHECK_THROW(cudaMemcpyAsync(nodes_gpu, bvh.nodes.data(), sizeof(BVHNode) * bvh.nodes.size(), cudaMemcpyHostToDevice, cuda_stream));
    CUDA_CHECK_THROW(cudaMemcpyAsync(triangles_gpu, bvh.triangles.data(), sizeof(Triangle) * bvh.triangles.size(), cudaMemcpyHostToDevice, cuda_stream));
    CUDA_CHECK_THROW(cudaMemcpyAsync(triangle_indices_gpu, bvh.triangle_indices.data(), sizeof(int) * bvh.triangle_indices.size(), cudaMemcpyHostToDevice, cuda_stream));

    // Fix: keep the contiguous device tensors alive across the asynchronous
    // kernel launch.  Passing `t.contiguous().data_ptr()` inline created a
    // temporary that could be freed while the kernel still reads it when the
    // input was not already contiguous.
    auto uv_dev = uv.contiguous();
    auto indices_dev = indices.contiguous();

    kernel_bake_uv<<<grid_dims, block_dims, 0, cuda_stream>>>(
        (float2 *)uv_dev.data_ptr<float>(),
        (int3 *)indices_dev.data_ptr<int>(),
        (float4 *)rast_result.data_ptr<float>(),
        nodes_gpu,
        triangles_gpu,
        triangle_indices_gpu,
        bvh_root,
        width,
        height,
        num_indices);

    // Stream-ordered frees: executed after the kernel on the same stream.
    CUDA_CHECK_THROW(cudaFreeAsync(nodes_gpu, cuda_stream));
    CUDA_CHECK_THROW(cudaFreeAsync(triangles_gpu, cuda_stream));
    CUDA_CHECK_THROW(cudaFreeAsync(triangle_indices_gpu, cuda_stream));

#ifdef TIMING
    CUDA_CHECK_THROW(cudaStreamSynchronize(cuda_stream));
    auto end = std::chrono::high_resolution_clock::now();
    std::chrono::duration<double> elapsed = end - start;
    std::cout << "Rasterization time (CUDA): " << elapsed.count() << "s" << std::endl;
#endif
    return rast_result;
}
261
+
262
// GPU attribute interpolation over a rasterization result; mirrors
// interpolate_cpu.  Returns a (H, W, 3) float32 CUDA tensor.
torch::Tensor interpolate_gpu(
    torch::Tensor attr,
    torch::Tensor indices,
    torch::Tensor rast)
{
#ifdef TIMING
    auto start = std::chrono::high_resolution_clock::now();
#endif
    constexpr int block_size = 16 * 16;
    // Fix: size the grid from rows * cols (was size(0) * size(0), wrong for
    // non-square inputs) and round up so a truncated division does not leave
    // the tail pixels uncovered (the kernel bounds-checks every thread).
    int grid_size = (int)((rast.size(0) * rast.size(1) + block_size - 1) / block_size);
    dim3 block_dims(block_size, 1, 1);
    dim3 grid_dims(grid_size, 1, 1);

    // Step 1: create an empty tensor to store the output.
    torch::Tensor pos_bake = torch::empty({rast.size(0), rast.size(1), 3}, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA));

    int width = rast.size(0);
    int height = rast.size(1);

    cudaStream_t cuda_stream = at::cuda::getCurrentCUDAStream();

    // Fix: hold the contiguous tensors in locals so their storage outlives
    // the asynchronous kernel launch (inline `contiguous().data_ptr()`
    // dangles for non-contiguous inputs).
    auto attr_dev = attr.contiguous();
    auto indices_dev = indices.contiguous();
    auto rast_dev = rast.contiguous();

    kernel_interpolate<<<grid_dims, block_dims, 0, cuda_stream>>>(
        (float3 *)attr_dev.data_ptr<float>(),
        (int3 *)indices_dev.data_ptr<int>(),
        (float4 *)rast_dev.data_ptr<float>(),
        (float3 *)pos_bake.data_ptr<float>(),
        width,
        height);
#ifdef TIMING
    CUDA_CHECK_THROW(cudaStreamSynchronize(cuda_stream));
    auto end = std::chrono::high_resolution_clock::now();
    std::chrono::duration<double> elapsed = end - start;
    std::cout << "Interpolation time (CUDA): " << elapsed.count() << "s" << std::endl;
#endif
    return pos_bake;
}
298
+
299
// Registers CUDA implementations
// (schemas are declared once in baker.cpp's TORCH_LIBRARY block)
TORCH_LIBRARY_IMPL(texture_baker_cpp, CUDA, m)
{
    m.impl("rasterize", &rasterize_gpu);
    m.impl("interpolate", &interpolate_gpu);
}
305
+
306
+ }
texture_baker/texture_baker/csrc/baker_kernel.metal ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <metal_stdlib>
2
+ using namespace metal;
3
+
4
+ // This header is inlined manually
5
+ //#include "baker.h"
6
+
7
+ // Use the texture_baker_cpp so it can use the classes from baker.h
8
+ using namespace texture_baker_cpp;
9
+
10
// Computes the barycentric coordinates (u, v, w) of `xy` with respect to the
// triangle (v1, v2, v3) and reports whether the point lies inside it.
// For a degenerate triangle `denom` is 0, the coordinates become NaN and every
// comparison below is false, so the function safely reports "outside".
bool barycentric_coordinates(float2 xy, float2 v1, float2 v2, float2 v3, thread float &u, thread float &v, thread float &w) {
    const float2 e0 = v2 - v1;
    const float2 e1 = v3 - v1;
    const float2 p = xy - v1;

    const float d00 = dot(e0, e0);
    const float d01 = dot(e0, e1);
    const float d11 = dot(e1, e1);
    const float d20 = dot(p, e0);
    const float d21 = dot(p, e1);

    const float denom = d00 * d11 - d01 * d01;
    v = (d11 * d20 - d01 * d21) / denom;
    w = (d00 * d21 - d01 * d20) / denom;
    u = 1.0f - v - w;

    return (v >= 0.0f) && (w >= 0.0f) && (v + w <= 1.0f);
}
29
+
30
// Kernel function for interpolation.
// For every texel, reads the rasterized barycentrics (rast.xyz) and triangle
// id (rast.w), then blends the three vertex attributes of that triangle.
kernel void kernel_interpolate(constant packed_float3 *attr [[buffer(0)]],
                               constant packed_int3 *indices [[buffer(1)]],
                               constant packed_float4 *rast [[buffer(2)]],
                               device packed_float3 *output [[buffer(3)]],
                               constant int &width [[buffer(4)]],
                               constant int &height [[buffer(5)]],
                               uint3 blockIdx [[threadgroup_position_in_grid]],
                               uint3 threadIdx [[thread_position_in_threadgroup]],
                               uint3 blockDim [[threads_per_threadgroup]])
{
    // Global texel position from threadgroup and thread indices.
    const int px = blockIdx.x * blockDim.x + threadIdx.x;
    const int py = blockIdx.y * blockDim.y + threadIdx.y;

    // Guard the partially filled edge threadgroups.
    if (px >= width || py >= height)
        return;

    const int texel = py * width + px;
    const float4 bary = rast[texel];
    const int tri = int(bary.w);

    // rast.w < 0 marks "no triangle covers this texel".
    if (tri < 0) {
        output[texel] = float3(0.0f, 0.0f, 0.0f);
        return;
    }

    const float3 a0 = attr[indices[tri].x];
    const float3 a1 = attr[indices[tri].y];
    const float3 a2 = attr[indices[tri].z];

    output[texel] = a0 * bary.x + a1 * bary.y + a2 * bary.z;
}
62
+
63
// Iteratively walks the BVH and returns true on the first triangle whose
// barycentric test contains `point`, filling in (u, v, w) and `index`.
// Uses a fixed-size explicit stack; on overflow it bails out with a miss
// (Metal has no assert — enable Metal validation layers to debug).
bool bvh_intersect(
    constant BVHNode* nodes,
    constant Triangle* triangles,
    constant int* triangle_indices,
    const thread int root,
    const thread float2 &point,
    thread float &u, thread float &v, thread float &w,
    thread int &index)
{
    const int max_depth = 64;
    thread int stack[max_depth];
    int top = 0;

    stack[top++] = root;

    while (top > 0)
    {
        const int current = stack[--top];
        BVHNode node = nodes[current];

        if (node.is_leaf())
        {
            // Test every triangle referenced by this leaf.
            for (int i = node.start; i < node.end; ++i)
            {
                constant Triangle &tri = triangles[triangle_indices[i]];
                if (barycentric_coordinates(point, tri.v0, tri.v1, tri.v2, u, v, w))
                {
                    index = tri.index;
                    return true;
                }
            }
        }
        else
        {
            // Push whichever children overlap the query point
            // (right first, matching the original traversal order).
            BVHNode probe = nodes[node.right];
            if (probe.bbox.overlaps(point))
            {
                if (top >= max_depth)
                    return false; // stack overflow — give up
                stack[top++] = node.right;
            }
            probe = nodes[node.left];
            if (probe.bbox.overlaps(point))
            {
                if (top >= max_depth)
                    return false; // stack overflow — give up
                stack[top++] = node.left;
            }
        }
    }

    return false;
}
129
+
130
+
131
// Kernel function for baking UV.
// One thread per output texel: map the texel into UV space, query the BVH for
// the covering triangle, and write (u, v, w, triangle_index) — or -1.0 in .w
// when no triangle covers the texel.
kernel void kernel_bake_uv(constant packed_float2 *uv [[buffer(0)]],
                           constant packed_int3 *indices [[buffer(1)]],
                           device packed_float4 *output [[buffer(2)]],
                           constant BVHNode *nodes [[buffer(3)]],
                           constant Triangle *triangles [[buffer(4)]],
                           constant int *triangle_indices [[buffer(5)]],
                           constant int &root [[buffer(6)]],
                           constant int &width [[buffer(7)]],
                           constant int &height [[buffer(8)]],
                           constant int &num_indices [[buffer(9)]],
                           uint3 blockIdx [[threadgroup_position_in_grid]],
                           uint3 threadIdx [[thread_position_in_threadgroup]],
                           uint3 blockDim [[threads_per_threadgroup]])
{
    // Global texel position from threadgroup and thread indices.
    const int x = blockIdx.x * blockDim.x + threadIdx.x;
    const int y = blockIdx.y * blockDim.y + threadIdx.y;

    // Guard the partially filled edge threadgroups.
    if (x >= width || y >= height)
        return;

    // NOTE(review): the output is indexed transposed (x * width + y) to match
    // the swapped pixel coordinates below; this is only self-consistent
    // because the bake target is square (width == height) — confirm if
    // non-square targets are ever allowed.
    const int idx = x * width + y;

    // Swap coordinates, normalize to [0, 1], and flip vertically.
    float2 pixel_coord = float2(float(y) / float(height), float(x) / float(width));
    pixel_coord = clamp(pixel_coord, 0.0f, 1.0f);
    pixel_coord.y = 1.0f - pixel_coord.y;

    float u, v, w;
    int tri_idx;
    const bool hit = bvh_intersect(nodes, triangles, triangle_indices, root,
                                   pixel_coord, u, v, w, tri_idx);

    if (hit) {
        output[idx] = float4(u, v, w, float(tri_idx));
        return;
    }

    output[idx] = float4(0.0f, 0.0f, 0.0f, -1.0f);
}
texture_baker/texture_baker/csrc/baker_kernel.mm ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <torch/extension.h>
2
+ #include <ATen/ATen.h>
3
+ #include <ATen/Context.h>
4
+ #include "baker.h"
5
+
6
+ #import <Foundation/Foundation.h>
7
+ #import <Metal/Metal.h>
8
+ #include <filesystem>
9
+
10
// Helper that exposes the `MTLBuffer` backing a PyTorch MPS tensor's storage.
// The storage pointer of an MPS tensor *is* the MTLBuffer object, so a
// bit-cast is sufficient — no ownership is transferred.
static inline id<MTLBuffer> getMTLBufferStorage(const torch::Tensor& tensor) {
  return __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
}
14
+
15
// Helper function to create a compute pipeline state object (PSO).
// Compiles `fullSource` with the __METAL__ macro defined (so the shared
// baker.h header can distinguish host and shader compilation) and looks up
// the entry point named `kernel_name`.
static inline id<MTLComputePipelineState> createComputePipelineState(id<MTLDevice> device, NSString* fullSource, std::string kernel_name) {
    NSError *error = nil;

    MTLCompileOptions *options = [[MTLCompileOptions alloc] init];
    options.preprocessorMacros = @{@"__METAL__": @""};

    id<MTLLibrary> customKernelLibrary = [device newLibraryWithSource:fullSource options:options error:&error];
    TORCH_CHECK(customKernelLibrary, "Failed to create custom kernel library, error: ", error.localizedDescription.UTF8String);

    id<MTLFunction> customKernelFunction = [customKernelLibrary newFunctionWithName:[NSString stringWithUTF8String:kernel_name.c_str()]];
    TORCH_CHECK(customKernelFunction, "Failed to create function state object for ", kernel_name.c_str());

    id<MTLComputePipelineState> pso = [device newComputePipelineStateWithFunction:customKernelFunction error:&error];
    TORCH_CHECK(pso, error.localizedDescription.UTF8String);

    return pso;
}
34
+
35
// Resolves the installed `texture_baker` package directory via the Python
// C API and returns its csrc/ subdirectory (where the shader sources live).
// Throws std::runtime_error when the module cannot be imported or located.
std::filesystem::path get_extension_path() {
    // Ensure the GIL is held before calling any Python C API function.
    PyGILState_STATE gstate = PyGILState_Ensure();

    const char* module_name = "texture_baker";

    // Import the module by name.
    PyObject* module = PyImport_ImportModule(module_name);
    if (!module) {
        PyGILState_Release(gstate);
        throw std::runtime_error("Could not import the module: " + std::string(module_name));
    }

    // Get the filename of the module.
    PyObject* filename_obj = PyModule_GetFilenameObject(module);
    if (!filename_obj) {
        Py_DECREF(module); // the original leaked this new reference
        PyGILState_Release(gstate);
        throw std::runtime_error("Could not retrieve the module filename.");
    }

    // PyUnicode_AsUTF8 may return nullptr on encoding failure; constructing a
    // std::string from nullptr is UB, so guard it explicitly.
    const char* utf8 = PyUnicode_AsUTF8(filename_obj);
    std::string path = utf8 ? utf8 : "";
    Py_DECREF(filename_obj);
    Py_DECREF(module);
    PyGILState_Release(gstate);

    if (path.empty()) {
        throw std::runtime_error("Could not decode the module filename.");
    }

    // Drop the trailing __init__.py, then descend into csrc/.
    std::filesystem::path module_path = std::filesystem::path(path).parent_path();
    module_path /= "csrc";
    return module_path;
}
67
+
68
// Loads baker.h and baker_kernel.metal from the package's csrc/ directory and
// returns their concatenation (the header is inlined manually because the
// runtime Metal compiler has no include path for it).
NSString *get_shader_sources_as_string()
{
    const std::filesystem::path csrc_path = get_extension_path();
    const std::string shader_path = (csrc_path / "baker_kernel.metal").string();
    const std::string shader_header_path = (csrc_path / "baker.h").string();

    NSError *error = nil;

    // Per Cocoa conventions the NSError out-parameter is only meaningful when
    // the call returns nil, so test the return value rather than `error`
    // (the original tested `error`, which is not guaranteed to stay nil on
    // success).
    NSString* shaderHeaderSource = [
        NSString stringWithContentsOfFile:[NSString stringWithUTF8String:shader_header_path.c_str()]
        encoding:NSUTF8StringEncoding
        error:&error];
    if (!shaderHeaderSource) {
        throw std::runtime_error("Failed to load baker.h: " + std::string(error.localizedDescription.UTF8String));
    }

    NSString* shaderSource = [
        NSString stringWithContentsOfFile:[NSString stringWithUTF8String:shader_path.c_str()]
        encoding:NSUTF8StringEncoding
        error:&error];
    if (!shaderSource) {
        throw std::runtime_error("Failed to load Metal shader: " + std::string(error.localizedDescription.UTF8String));
    }

    NSString *fullSource = [shaderHeaderSource stringByAppendingString:shaderSource];

    return fullSource;
}
96
+
97
+ namespace texture_baker_cpp
98
+ {
99
// Rasterizes the UV layout into a (res, res, 4) map of (u, v, w, tri_index)
// via the kernel_bake_uv Metal kernel. The BVH acceleration structure is
// built on the CPU and shared with the GPU through no-copy buffers.
torch::Tensor rasterize_gpu(
    torch::Tensor uv,
    torch::Tensor indices,
    int64_t bake_resolution)
{
    TORCH_CHECK(uv.device().is_mps(), "uv must be a MPS tensor");
    TORCH_CHECK(uv.is_contiguous(), "uv must be contiguous");
    TORCH_CHECK(indices.is_contiguous(), "indices must be contiguous");

    // (The original message printed indices' dtype for the uv check.)
    TORCH_CHECK(uv.scalar_type() == torch::kFloat32, "Unsupported data type: ", uv.scalar_type());
    TORCH_CHECK(indices.scalar_type() == torch::kInt32, "Unsupported data type: ", indices.scalar_type());

    torch::Tensor rast_result = torch::empty({bake_resolution, bake_resolution, 4}, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kMPS)).contiguous();

    @autoreleasepool {
        // The BVH build runs on the CPU, so pull the geometry over once.
        auto vertices_cpu = uv.contiguous().cpu();
        auto indices_cpu = indices.contiguous().cpu();

        const tb_float2 *vertices_cpu_ptr = (tb_float2*)vertices_cpu.contiguous().data_ptr<float>();
        const tb_int3 *tris_cpu_ptr = (tb_int3*)indices_cpu.contiguous().data_ptr<int>();

        BVH bvh;
        bvh.build(vertices_cpu_ptr, tris_cpu_ptr, indices.size(0));

        id<MTLDevice> device = MTLCreateSystemDefaultDevice();

        NSString *fullSource = get_shader_sources_as_string();

        // Create a compute pipeline state object using the helper function.
        id<MTLComputePipelineState> bake_uv_PSO = createComputePipelineState(device, fullSource, "kernel_bake_uv");

        // Command buffer and dispatch queue of the MPS stream (the queue
        // serializes our encoding with PyTorch's own MPS work).
        id<MTLCommandBuffer> commandBuffer = torch::mps::get_command_buffer();
        TORCH_CHECK(commandBuffer, "Failed to retrieve command buffer reference");

        dispatch_queue_t serialQueue = torch::mps::get_dispatch_queue();

        dispatch_sync(serialQueue, ^(){
            id<MTLComputeCommandEncoder> computeEncoder = [commandBuffer computeCommandEncoder];
            TORCH_CHECK(computeEncoder, "Failed to create compute command encoder");

            // Metal buffers straight from the PyTorch tensors (no copies).
            auto uv_buf = getMTLBufferStorage(uv.contiguous());
            auto indices_buf = getMTLBufferStorage(indices.contiguous());
            auto rast_result_buf = getMTLBufferStorage(rast_result);

            const int width = bake_resolution;
            const int height = bake_resolution;
            const int num_indices = indices.size(0);
            const int bvh_root = bvh.root;

            // Wrap the BVH's CPU arrays in shared-memory Metal buffers.
            // NOTE(review): newBufferWithBytesNoCopy requires page-aligned
            // storage; std::vector data is not guaranteed page-aligned —
            // confirm this on the target OS or fall back to newBufferWithBytes.
            id<MTLBuffer> nodesBuffer = [device newBufferWithBytesNoCopy:(void*)bvh.nodes.data() length:sizeof(BVHNode) * bvh.nodes.size() options:MTLResourceStorageModeShared deallocator:nil];
            id<MTLBuffer> trianglesBuffer = [device newBufferWithBytesNoCopy:(void*)bvh.triangles.data() length:sizeof(Triangle) * bvh.triangles.size() options:MTLResourceStorageModeShared deallocator:nil];
            id<MTLBuffer> triangleIndicesBuffer = [device newBufferWithBytesNoCopy:(void*)bvh.triangle_indices.data() length:sizeof(int) * bvh.triangle_indices.size() options:MTLResourceStorageModeShared deallocator:nil];

            [computeEncoder setComputePipelineState:bake_uv_PSO];
            [computeEncoder setBuffer:uv_buf offset:uv.storage_offset() * uv.element_size() atIndex:0];
            [computeEncoder setBuffer:indices_buf offset:indices.storage_offset() * indices.element_size() atIndex:1];
            [computeEncoder setBuffer:rast_result_buf offset:rast_result.storage_offset() * rast_result.element_size() atIndex:2];
            [computeEncoder setBuffer:nodesBuffer offset:0 atIndex:3];
            [computeEncoder setBuffer:trianglesBuffer offset:0 atIndex:4];
            [computeEncoder setBuffer:triangleIndicesBuffer offset:0 atIndex:5];
            [computeEncoder setBytes:&bvh_root length:sizeof(int) atIndex:6];
            [computeEncoder setBytes:&width length:sizeof(int) atIndex:7];
            [computeEncoder setBytes:&height length:sizeof(int) atIndex:8];
            [computeEncoder setBytes:&num_indices length:sizeof(int) atIndex:9];

            // Round the threadgroup count UP so edge pixels are covered when
            // bake_resolution is not a multiple of 16; the kernel itself
            // bounds-checks x/y. The original truncated and silently skipped
            // the trailing rows/columns.
            const int block_size = 16;
            const int groups = (int)((bake_resolution + block_size - 1) / block_size);
            MTLSize threadgroupSize = MTLSizeMake(block_size, block_size, 1);
            MTLSize numThreadgroups = MTLSizeMake(groups, groups, 1);

            [computeEncoder dispatchThreadgroups:numThreadgroups threadsPerThreadgroup:threadgroupSize];
            [computeEncoder endEncoding];

            // Commit the work.
            torch::mps::commit();
        });
    }

    return rast_result;
}
185
+
186
// Interpolates per-vertex attributes across the rasterized map produced by
// rasterize_gpu, yielding a (H, W, 3) position/attribute bake.
torch::Tensor interpolate_gpu(
    torch::Tensor attr,
    torch::Tensor indices,
    torch::Tensor rast)
{
    TORCH_CHECK(attr.is_contiguous(), "attr must be contiguous");
    TORCH_CHECK(indices.is_contiguous(), "indices must be contiguous");
    TORCH_CHECK(rast.is_contiguous(), "rast must be contiguous");

    torch::Tensor pos_bake = torch::empty({rast.size(0), rast.size(1), 3}, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kMPS)).contiguous();
    // (The original also computed get_extension_path() here; the result was
    // never used — get_shader_sources_as_string() resolves it internally.)

    @autoreleasepool {
        id<MTLDevice> device = MTLCreateSystemDefaultDevice();

        NSString *fullSource = get_shader_sources_as_string();
        // Create a compute pipeline state object using the helper function.
        id<MTLComputePipelineState> interpolate_PSO = createComputePipelineState(device, fullSource, "kernel_interpolate");

        // Command buffer and dispatch queue of the MPS stream.
        id<MTLCommandBuffer> commandBuffer = torch::mps::get_command_buffer();
        TORCH_CHECK(commandBuffer, "Failed to retrieve command buffer reference");

        dispatch_queue_t serialQueue = torch::mps::get_dispatch_queue();

        dispatch_sync(serialQueue, ^(){
            id<MTLComputeCommandEncoder> computeEncoder = [commandBuffer computeCommandEncoder];
            TORCH_CHECK(computeEncoder, "Failed to create compute command encoder");

            // Metal buffers straight from the PyTorch tensors (no copies).
            auto attr_buf = getMTLBufferStorage(attr.contiguous());
            auto indices_buf = getMTLBufferStorage(indices.contiguous());
            auto rast_buf = getMTLBufferStorage(rast.contiguous());
            auto pos_bake_buf = getMTLBufferStorage(pos_bake);

            const int width = rast.size(0);
            const int height = rast.size(1);

            [computeEncoder setComputePipelineState:interpolate_PSO];
            [computeEncoder setBuffer:attr_buf offset:attr.storage_offset() * attr.element_size() atIndex:0];
            [computeEncoder setBuffer:indices_buf offset:indices.storage_offset() * indices.element_size() atIndex:1];
            [computeEncoder setBuffer:rast_buf offset:rast.storage_offset() * rast.element_size() atIndex:2];
            [computeEncoder setBuffer:pos_bake_buf offset:pos_bake.storage_offset() * pos_bake.element_size() atIndex:3];
            [computeEncoder setBytes:&width length:sizeof(int) atIndex:4];
            [computeEncoder setBytes:&height length:sizeof(int) atIndex:5];

            // Round the threadgroup counts UP and size each grid dimension
            // from its own extent; the original used rast.size(0) for both
            // axes and truncated, skipping edge texels for sizes that are not
            // multiples of 16. The kernel bounds-checks x/y.
            const int block_size = 16;
            MTLSize threadgroupSize = MTLSizeMake(block_size, block_size, 1);
            MTLSize numThreadgroups = MTLSizeMake(
                (width + block_size - 1) / block_size,
                (height + block_size - 1) / block_size,
                1);

            [computeEncoder dispatchThreadgroups:numThreadgroups threadsPerThreadgroup:threadgroupSize];

            [computeEncoder endEncoding];

            // Commit the work.
            torch::mps::commit();
        });
    }

    return pos_bake;
}
252
+
253
// Route the texture_baker_cpp ops to these Metal implementations whenever the
// dispatcher sees MPS tensors.
TORCH_LIBRARY_IMPL(texture_baker_cpp, MPS, m)
{
    m.impl("rasterize", &rasterize_gpu);
    m.impl("interpolate", &interpolate_gpu);
}
259
+
260
+ }
uv_unwrapper/README.md ADDED
File without changes
uv_unwrapper/requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ torch
2
+ numpy
uv_unwrapper/setup.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import os
3
+
4
+ import torch
5
+ from setuptools import find_packages, setup
6
+ from torch.utils.cpp_extension import (
7
+ BuildExtension,
8
+ CppExtension,
9
+ )
10
+
11
+ library_name = "uv_unwrapper"
12
+
13
+
14
def get_extensions():
    """Build the list of C++ extension modules for uv_unwrapper.

    Returns:
        None when no C++ sources are found (setup() then installs a
        pure-Python package), otherwise a one-element list containing the
        ``uv_unwrapper._C`` CppExtension.
    """
    debug_mode = os.getenv("DEBUG", "0") == "1"
    if debug_mode:
        print("Compiling in debug mode")

    # MPS availability is used as a proxy for "building on macOS".
    is_mac = torch.backends.mps.is_available()
    use_native_arch = not is_mac and os.getenv("USE_NATIVE_ARCH", "1") == "1"
    extension = CppExtension

    # The original built this list with a chained conditional expression whose
    # precedence dropped the base flags entirely on macOS (and produced []
    # for non-native non-mac builds). Build it step by step instead.
    cxx_flags = [
        "-O3" if not debug_mode else "-O0",
        "-fdiagnostics-color=always",
    ]
    if is_mac:
        # clang needs -Xclang and -fopenmp as two separate argv entries;
        # the original concatenated them into one string with a space.
        cxx_flags += ["-Xclang", "-fopenmp", "-mmacosx-version-min=10.15"]
    else:
        cxx_flags.append("-fopenmp")
    if use_native_arch:
        cxx_flags.append("-march=native")

    extra_compile_args = {"cxx": cxx_flags}
    extra_link_args = []
    if debug_mode:
        extra_compile_args["cxx"].append("-g")
        extra_compile_args["cxx"].append("-UNDEBUG")
        extra_link_args.extend(["-O0", "-g"])

    define_macros = []
    extensions = []

    this_dir = os.path.dirname(os.path.curdir)
    sources = glob.glob(
        os.path.join(this_dir, library_name, "csrc", "**", "*.cpp"), recursive=True
    )

    if len(sources) == 0:
        print("No source files found for extension, skipping extension compilation")
        return None

    extensions.append(
        extension(
            name=f"{library_name}._C",
            sources=sources,
            define_macros=define_macros,
            extra_compile_args=extra_compile_args,
            extra_link_args=extra_link_args,
            # Same precedence bug existed here: the torch libraries were
            # dropped off-mac and "omp" was only added together with them.
            libraries=["c10", "torch", "torch_cpu", "torch_python"]
            + (["omp"] if is_mac else []),
        )
    )

    print(extensions)

    return extensions
68
+
69
+
70
# Read the long description up front with an explicit encoding, closing the
# file handle (the original's inline open().read() leaked it and relied on
# the platform default encoding).
with open("README.md", encoding="utf-8") as _readme:
    _long_description = _readme.read()

setup(
    name=library_name,
    version="0.0.1",
    packages=find_packages(),
    ext_modules=get_extensions(),
    install_requires=[],
    description="Box projection based UV unwrapper",
    long_description=_long_description,
    long_description_content_type="text/markdown",
    cmdclass={"build_ext": BuildExtension},
)
uv_unwrapper/uv_unwrapper/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ import torch # noqa: F401
2
+
3
+ from . import _C # noqa: F401
4
+ from .unwrap import Unwrapper
5
+
6
+ __all__ = ["Unwrapper"]
uv_unwrapper/uv_unwrapper/csrc/bvh.cpp ADDED
@@ -0,0 +1,381 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ #include "bvh.h"
4
+ #include "common.h"
5
+ #include <cstring>
6
+ #include <iostream>
7
+ #include <queue>
8
+ #include <tuple>
9
+ #include <utility>
10
+
11
+ namespace UVUnwrapper {
12
+ BVH::BVH(Triangle *tri, int *actual_idx, const size_t &num_indices) {
13
+ // Copty tri to triangle
14
+ triangle = new Triangle[num_indices];
15
+ memcpy(triangle, tri, num_indices * sizeof(Triangle));
16
+
17
+ // Copy actual_idx to actualIdx
18
+ actualIdx = new int[num_indices];
19
+ memcpy(actualIdx, actual_idx, num_indices * sizeof(int));
20
+
21
+ triIdx = new int[num_indices];
22
+ triCount = num_indices;
23
+
24
+ bvhNode = new BVHNode[triCount * 2 + 64];
25
+ nodesUsed = 2;
26
+ memset(bvhNode, 0, triCount * 2 * sizeof(BVHNode));
27
+
28
+ // populate triangle index array
29
+ for (int i = 0; i < triCount; i++)
30
+ triIdx[i] = i;
31
+
32
+ BVHNode &root = bvhNode[0];
33
+
34
+ root.start = 0, root.end = triCount;
35
+ AABB centroidBounds;
36
+ UpdateNodeBounds(0, centroidBounds);
37
+
38
+ // subdivide recursively
39
+ Subdivide(0, nodesUsed, centroidBounds);
40
+ }
41
+
42
+ BVH::BVH(const BVH &other)
43
+ : BVH(other.triangle, other.triIdx, other.triCount) {}
44
+
45
+ BVH::BVH(BVH &&other) noexcept // move constructor
46
+ : triIdx(std::exchange(other.triIdx, nullptr)),
47
+ actualIdx(std::exchange(other.actualIdx, nullptr)),
48
+ triangle(std::exchange(other.triangle, nullptr)),
49
+ bvhNode(std::exchange(other.bvhNode, nullptr)) {}
50
+
51
+ BVH &BVH::operator=(const BVH &other) // copy assignment
52
+ {
53
+ return *this = BVH(other);
54
+ }
55
+
56
+ BVH &BVH::operator=(BVH &&other) noexcept // move assignment
57
+ {
58
+ std::swap(triIdx, other.triIdx);
59
+ std::swap(actualIdx, other.actualIdx);
60
+ std::swap(triangle, other.triangle);
61
+ std::swap(bvhNode, other.bvhNode);
62
+ std::swap(triCount, other.triCount);
63
+ std::swap(nodesUsed, other.nodesUsed);
64
+ return *this;
65
+ }
66
+
67
+ BVH::~BVH() {
68
+ if (triIdx)
69
+ delete[] triIdx;
70
+ if (triangle)
71
+ delete[] triangle;
72
+ if (actualIdx)
73
+ delete[] actualIdx;
74
+ if (bvhNode)
75
+ delete[] bvhNode;
76
+ }
77
+
78
+ void BVH::UpdateNodeBounds(unsigned int nodeIdx, AABB &centroidBounds) {
79
+ BVHNode &node = bvhNode[nodeIdx];
80
+ #ifndef __ARM_ARCH_ISA_A64
81
+ #ifndef _MSC_VER
82
+ if (__builtin_cpu_supports("sse"))
83
+ #elif (defined(_M_AMD64) || defined(_M_X64))
84
+ // SSE supported on Windows
85
+ if constexpr (true)
86
+ #endif
87
+ {
88
+ __m128 min4 = _mm_set_ps1(FLT_MAX), max4 = _mm_set_ps1(FLT_MIN);
89
+ __m128 cmin4 = _mm_set_ps1(FLT_MAX), cmax4 = _mm_set_ps1(FLT_MIN);
90
+ for (int i = node.start; i < node.end; i += 2) {
91
+ Triangle &leafTri1 = triangle[triIdx[i]];
92
+ __m128 v0, v1, v2, centroid;
93
+ if (i + 1 < node.end) {
94
+ const Triangle leafTri2 = triangle[triIdx[i + 1]];
95
+
96
+ v0 = _mm_set_ps(leafTri1.v0.x, leafTri1.v0.y, leafTri2.v0.x,
97
+ leafTri2.v0.y);
98
+ v1 = _mm_set_ps(leafTri1.v1.x, leafTri1.v1.y, leafTri2.v1.x,
99
+ leafTri2.v1.y);
100
+ v2 = _mm_set_ps(leafTri1.v2.x, leafTri1.v2.y, leafTri2.v2.x,
101
+ leafTri2.v2.y);
102
+ centroid = _mm_set_ps(leafTri1.centroid.x, leafTri1.centroid.y,
103
+ leafTri2.centroid.x, leafTri2.centroid.y);
104
+ } else {
105
+ // Otherwise do some duplicated work
106
+ v0 = _mm_set_ps(leafTri1.v0.x, leafTri1.v0.y, leafTri1.v0.x,
107
+ leafTri1.v0.y);
108
+ v1 = _mm_set_ps(leafTri1.v1.x, leafTri1.v1.y, leafTri1.v1.x,
109
+ leafTri1.v1.y);
110
+ v2 = _mm_set_ps(leafTri1.v2.x, leafTri1.v2.y, leafTri1.v2.x,
111
+ leafTri1.v2.y);
112
+ centroid = _mm_set_ps(leafTri1.centroid.x, leafTri1.centroid.y,
113
+ leafTri1.centroid.x, leafTri1.centroid.y);
114
+ }
115
+
116
+ min4 = _mm_min_ps(min4, v0);
117
+ max4 = _mm_max_ps(max4, v0);
118
+ min4 = _mm_min_ps(min4, v1);
119
+ max4 = _mm_max_ps(max4, v1);
120
+ min4 = _mm_min_ps(min4, v2);
121
+ max4 = _mm_max_ps(max4, v2);
122
+ cmin4 = _mm_min_ps(cmin4, centroid);
123
+ cmax4 = _mm_max_ps(cmax4, centroid);
124
+ }
125
+ float min_values[4], max_values[4], cmin_values[4], cmax_values[4];
126
+ _mm_store_ps(min_values, min4);
127
+ _mm_store_ps(max_values, max4);
128
+ _mm_store_ps(cmin_values, cmin4);
129
+ _mm_store_ps(cmax_values, cmax4);
130
+
131
+ node.bbox.min.x = std::min(min_values[3], min_values[1]);
132
+ node.bbox.min.y = std::min(min_values[2], min_values[0]);
133
+ node.bbox.max.x = std::max(max_values[3], max_values[1]);
134
+ node.bbox.max.y = std::max(max_values[2], max_values[0]);
135
+
136
+ centroidBounds.min.x = std::min(cmin_values[3], cmin_values[1]);
137
+ centroidBounds.min.y = std::min(cmin_values[2], cmin_values[0]);
138
+ centroidBounds.max.x = std::max(cmax_values[3], cmax_values[1]);
139
+ centroidBounds.max.y = std::max(cmax_values[2], cmax_values[0]);
140
+ }
141
+ #else
142
+ if constexpr (false) {
143
+ }
144
+ #endif
145
+ else {
146
+ node.bbox.invalidate();
147
+ centroidBounds.invalidate();
148
+
149
+ // Calculate the bounding box for the node
150
+ for (int i = node.start; i < node.end; ++i) {
151
+ const Triangle &tri = triangle[triIdx[i]];
152
+ node.bbox.grow(tri.v0);
153
+ node.bbox.grow(tri.v1);
154
+ node.bbox.grow(tri.v2);
155
+ centroidBounds.grow(tri.centroid);
156
+ }
157
+ }
158
+ }
159
+
160
+ void BVH::Subdivide(unsigned int root_idx, unsigned int &nodePtr,
161
+ AABB &rootCentroidBounds) {
162
+ // Create a queue for the nodes to be subdivided
163
+ std::queue<std::tuple<unsigned int, AABB>> nodeQueue;
164
+ nodeQueue.push(std::make_tuple(root_idx, rootCentroidBounds));
165
+
166
+ while (!nodeQueue.empty()) {
167
+ // Get the next node to process from the queue
168
+ auto [node_idx, centroidBounds] = nodeQueue.front();
169
+ nodeQueue.pop();
170
+ BVHNode &node = bvhNode[node_idx];
171
+
172
+ // Check if left is -1 and right not or vice versa
173
+
174
+ int axis, splitPos;
175
+ float cost = FindBestSplitPlane(node, axis, splitPos, centroidBounds);
176
+
177
+ if (cost >= node.calculate_node_cost()) {
178
+ node.left = node.right = -1;
179
+ continue; // Move on to the next node in the queue
180
+ }
181
+
182
+ int i = node.start;
183
+ int j = node.end - 1;
184
+ float scale = BINS / (centroidBounds.max[axis] - centroidBounds.min[axis]);
185
+ while (i <= j) {
186
+ int binIdx =
187
+ std::min(BINS - 1, (int)((triangle[triIdx[i]].centroid[axis] -
188
+ centroidBounds.min[axis]) *
189
+ scale));
190
+ if (binIdx < splitPos)
191
+ i++;
192
+ else
193
+ std::swap(triIdx[i], triIdx[j--]);
194
+ }
195
+
196
+ int leftCount = i - node.start;
197
+ if (leftCount == 0 || leftCount == (int)node.num_triangles()) {
198
+ node.left = node.right = -1;
199
+ continue; // Move on to the next node in the queue
200
+ }
201
+
202
+ int mid = i;
203
+
204
+ // Create child nodes
205
+ int leftChildIdx = nodePtr++;
206
+ int rightChildIdx = nodePtr++;
207
+ bvhNode[leftChildIdx].start = node.start;
208
+ bvhNode[leftChildIdx].end = mid;
209
+ bvhNode[rightChildIdx].start = mid;
210
+ bvhNode[rightChildIdx].end = node.end;
211
+ node.left = leftChildIdx;
212
+ node.right = rightChildIdx;
213
+
214
+ // Update the bounds for the child nodes and push them onto the queue
215
+ UpdateNodeBounds(leftChildIdx, centroidBounds);
216
+ nodeQueue.push(std::make_tuple(leftChildIdx, centroidBounds));
217
+
218
+ UpdateNodeBounds(rightChildIdx, centroidBounds);
219
+ nodeQueue.push(std::make_tuple(rightChildIdx, centroidBounds));
220
+ }
221
+ }
222
+
223
+ float BVH::FindBestSplitPlane(BVHNode &node, int &best_axis, int &best_pos,
224
+ AABB &centroidBounds) {
225
+ float best_cost = FLT_MAX;
226
+
227
+ for (int axis = 0; axis < 2; ++axis) // We use 2 as we have only x and y
228
+ {
229
+ float boundsMin = centroidBounds.min[axis];
230
+ float boundsMax = centroidBounds.max[axis];
231
+ // Or floating point precision
232
+ if ((boundsMin == boundsMax) || (boundsMax - boundsMin < 1e-8f)) {
233
+ continue;
234
+ }
235
+
236
+ // populate the bins
237
+ float scale = BINS / (boundsMax - boundsMin);
238
+ float leftCountArea[BINS - 1], rightCountArea[BINS - 1];
239
+ int leftSum = 0, rightSum = 0;
240
+ #ifndef __ARM_ARCH_ISA_A64
241
+ #ifndef _MSC_VER
242
+ if (__builtin_cpu_supports("sse"))
243
+ #elif (defined(_M_AMD64) || defined(_M_X64))
244
+ // SSE supported on Windows
245
+ if constexpr (true)
246
+ #endif
247
+ {
248
+ __m128 min4[BINS], max4[BINS];
249
+ unsigned int count[BINS];
250
+ for (unsigned int i = 0; i < BINS; i++)
251
+ min4[i] = _mm_set_ps1(FLT_MAX), max4[i] = _mm_set_ps1(FLT_MIN),
252
+ count[i] = 0;
253
+ for (int i = node.start; i < node.end; i++) {
254
+ Triangle &tri = triangle[triIdx[i]];
255
+ int binIdx =
256
+ std::min(BINS - 1, (int)((tri.centroid[axis] - boundsMin) * scale));
257
+ count[binIdx]++;
258
+
259
+ __m128 v0 = _mm_set_ps(tri.v0.x, tri.v0.y, 0.0f, 0.0f);
260
+ __m128 v1 = _mm_set_ps(tri.v1.x, tri.v1.y, 0.0f, 0.0f);
261
+ __m128 v2 = _mm_set_ps(tri.v2.x, tri.v2.y, 0.0f, 0.0f);
262
+ min4[binIdx] = _mm_min_ps(min4[binIdx], v0);
263
+ max4[binIdx] = _mm_max_ps(max4[binIdx], v0);
264
+ min4[binIdx] = _mm_min_ps(min4[binIdx], v1);
265
+ max4[binIdx] = _mm_max_ps(max4[binIdx], v1);
266
+ min4[binIdx] = _mm_min_ps(min4[binIdx], v2);
267
+ max4[binIdx] = _mm_max_ps(max4[binIdx], v2);
268
+ }
269
+ // gather data for the 7 planes between the 8 bins
270
+ __m128 leftMin4 = _mm_set_ps1(FLT_MAX), rightMin4 = leftMin4;
271
+ __m128 leftMax4 = _mm_set_ps1(FLT_MIN), rightMax4 = leftMax4;
272
+ for (int i = 0; i < BINS - 1; i++) {
273
+ leftSum += count[i];
274
+ rightSum += count[BINS - 1 - i];
275
+ leftMin4 = _mm_min_ps(leftMin4, min4[i]);
276
+ rightMin4 = _mm_min_ps(rightMin4, min4[BINS - 2 - i]);
277
+ leftMax4 = _mm_max_ps(leftMax4, max4[i]);
278
+ rightMax4 = _mm_max_ps(rightMax4, max4[BINS - 2 - i]);
279
+ float le[4], re[4];
280
+ _mm_store_ps(le, _mm_sub_ps(leftMax4, leftMin4));
281
+ _mm_store_ps(re, _mm_sub_ps(rightMax4, rightMin4));
282
+ // SSE order goes from back to front
283
+ leftCountArea[i] = leftSum * (le[2] * le[3]); // 2D area calculation
284
+ rightCountArea[BINS - 2 - i] =
285
+ rightSum * (re[2] * re[3]); // 2D area calculation
286
+ }
287
+ }
288
+ #else
289
+ if constexpr (false) {
290
+ }
291
+ #endif
292
+ else {
293
+ struct Bin {
294
+ AABB bounds;
295
+ int triCount = 0;
296
+ } bin[BINS];
297
+ for (int i = node.start; i < node.end; i++) {
298
+ Triangle &tri = triangle[triIdx[i]];
299
+ int binIdx =
300
+ std::min(BINS - 1, (int)((tri.centroid[axis] - boundsMin) * scale));
301
+ bin[binIdx].triCount++;
302
+ bin[binIdx].bounds.grow(tri.v0);
303
+ bin[binIdx].bounds.grow(tri.v1);
304
+ bin[binIdx].bounds.grow(tri.v2);
305
+ }
306
+ // gather data for the 7 planes between the 8 bins
307
+ AABB leftBox, rightBox;
308
+ for (int i = 0; i < BINS - 1; i++) {
309
+ leftSum += bin[i].triCount;
310
+ leftBox.grow(bin[i].bounds);
311
+ leftCountArea[i] = leftSum * leftBox.area();
312
+ rightSum += bin[BINS - 1 - i].triCount;
313
+ rightBox.grow(bin[BINS - 1 - i].bounds);
314
+ rightCountArea[BINS - 2 - i] = rightSum * rightBox.area();
315
+ }
316
+ }
317
+
318
+ // calculate SAH cost for the 7 planes
319
+ scale = (boundsMax - boundsMin) / BINS;
320
+ for (int i = 0; i < BINS - 1; i++) {
321
+ const float planeCost = leftCountArea[i] + rightCountArea[i];
322
+ if (planeCost < best_cost)
323
+ best_axis = axis, best_pos = i + 1, best_cost = planeCost;
324
+ }
325
+ }
326
+ return best_cost;
327
+ }
328
+
329
+ std::vector<int> BVH::Intersect(Triangle &tri_intersect) {
330
+ /**
331
+ * @brief Intersect a triangle with the BVH
332
+ *
333
+ * @param triangle the triangle to intersect
334
+ *
335
+ * @return -1 for no intersection, the index of the intersected triangle
336
+ * otherwise
337
+ */
338
+
339
+ const int max_stack_size = 64;
340
+ int node_stack[max_stack_size];
341
+ int stack_size = 0;
342
+ std::vector<int> intersected_triangles;
343
+
344
+ node_stack[stack_size++] = 0; // Start with the root node (index 0)
345
+ while (stack_size > 0) {
346
+ int node_idx = node_stack[--stack_size];
347
+ const BVHNode &node = bvhNode[node_idx];
348
+ if (node.is_leaf()) {
349
+ for (int i = node.start; i < node.end; ++i) {
350
+ const Triangle &tri = triangle[triIdx[i]];
351
+ // Check that the triangle is not the same as the intersected triangle
352
+ if (tri == tri_intersect)
353
+ continue;
354
+ if (tri_intersect.overlaps(tri)) {
355
+ intersected_triangles.push_back(actualIdx[triIdx[i]]);
356
+ }
357
+ }
358
+ } else {
359
+ // Check right child first
360
+ if (bvhNode[node.right].bbox.overlaps(tri_intersect)) {
361
+ if (stack_size < max_stack_size) {
362
+ node_stack[stack_size++] = node.right;
363
+ } else {
364
+ throw std::runtime_error("Node stack overflow");
365
+ }
366
+ }
367
+
368
+ // Check left child
369
+ if (bvhNode[node.left].bbox.overlaps(tri_intersect)) {
370
+ if (stack_size < max_stack_size) {
371
+ node_stack[stack_size++] = node.left;
372
+ } else {
373
+ throw std::runtime_error("Node stack overflow");
374
+ }
375
+ }
376
+ }
377
+ }
378
+ return intersected_triangles; // Return all intersected triangle indices
379
+ }
380
+
381
+ } // namespace UVUnwrapper
uv_unwrapper/uv_unwrapper/csrc/bvh.h ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <cfloat>
4
+ #include <cmath>
5
+ #ifndef __ARM_ARCH_ISA_A64
6
+ #include <immintrin.h>
7
+ #endif
8
+ #include <limits>
9
+ #include <vector>
10
+
11
+ #include "common.h"
12
+ #include "intersect.h"
13
+ /**
14
+ * Based on https://github.com/jbikker/bvh_article released under the unlicense.
15
+ */
16
+
17
+ // bin count for binned BVH building
18
+ #define BINS 8
19
+
20
+ namespace UVUnwrapper {
21
+ // minimalist triangle struct
22
+ struct alignas(32) Triangle {
23
+ uv_float2 v0;
24
+ uv_float2 v1;
25
+ uv_float2 v2;
26
+ uv_float2 centroid;
27
+
28
+ bool overlaps(const Triangle &other) {
29
+ // return tri_tri_overlap_test_2d(v0, v1, v2, other.v0, other.v1, other.v2);
30
+ return triangle_triangle_intersection(v0, v1, v2, other.v0, other.v1,
31
+ other.v2);
32
+ }
33
+
34
+ bool operator==(const Triangle &rhs) const {
35
+ return v0 == rhs.v0 && v1 == rhs.v1 && v2 == rhs.v2;
36
+ }
37
+ };
38
+
39
+ // minimalist AABB struct with grow functionality
40
+ struct alignas(16) AABB {
41
+ // Init bounding boxes with max/min
42
+ uv_float2 min = {FLT_MAX, FLT_MAX};
43
+ uv_float2 max = {FLT_MIN, FLT_MIN};
44
+
45
+ void grow(const uv_float2 &p) {
46
+ min.x = std::min(min.x, p.x);
47
+ min.y = std::min(min.y, p.y);
48
+ max.x = std::max(max.x, p.x);
49
+ max.y = std::max(max.y, p.y);
50
+ }
51
+
52
+ void grow(const AABB &b) {
53
+ if (b.min.x != FLT_MAX) {
54
+ grow(b.min);
55
+ grow(b.max);
56
+ }
57
+ }
58
+
59
+ bool overlaps(const Triangle &tri) {
60
+ return triangle_aabb_intersection(min, max, tri.v0, tri.v1, tri.v2);
61
+ }
62
+
63
+ float area() const {
64
+ uv_float2 extent = {max.x - min.x, max.y - min.y};
65
+ return extent.x * extent.y;
66
+ }
67
+
68
+ void invalidate() {
69
+ min = {FLT_MAX, FLT_MAX};
70
+ max = {FLT_MIN, FLT_MIN};
71
+ }
72
+ };
73
+
74
+ // 32-byte BVH node struct
75
+ struct alignas(32) BVHNode {
76
+ AABB bbox; // 16
77
+ int start = 0, end = 0; // 8
78
+ int left, right;
79
+
80
+ int num_triangles() const { return end - start; }
81
+
82
+ bool is_leaf() const { return left == -1 && right == -1; }
83
+
84
+ float calculate_node_cost() {
85
+ float area = bbox.area();
86
+ return num_triangles() * area;
87
+ }
88
+ };
89
+
90
+ class BVH {
91
+ public:
92
+ BVH() = default;
93
+ BVH(BVH &&other) noexcept;
94
+ BVH(const BVH &other);
95
+ BVH &operator=(const BVH &other);
96
+ BVH &operator=(BVH &&other) noexcept;
97
+ BVH(Triangle *tri, int *actual_idx, const size_t &num_indices);
98
+ ~BVH();
99
+
100
+ std::vector<int> Intersect(Triangle &triangle);
101
+
102
+ private:
103
+ void Subdivide(unsigned int node_idx, unsigned int &nodePtr,
104
+ AABB &centroidBounds);
105
+ void UpdateNodeBounds(unsigned int nodeIdx, AABB &centroidBounds);
106
+ float FindBestSplitPlane(BVHNode &node, int &axis, int &splitPos,
107
+ AABB &centroidBounds);
108
+
109
+ public:
110
+ int *triIdx = nullptr;
111
+ int *actualIdx = nullptr;
112
+ unsigned int triCount;
113
+ unsigned int nodesUsed;
114
+ BVHNode *bvhNode = nullptr;
115
+ Triangle *triangle = nullptr;
116
+ };
117
+
118
+ } // namespace UVUnwrapper
uv_unwrapper/uv_unwrapper/csrc/common.h ADDED
@@ -0,0 +1,493 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <array>
4
+ #include <cmath>
5
+ #include <iostream>
6
+ #include <stdexcept>
7
+
8
+ const float EPSILON = 1e-7f;
9
+
10
+ // Structure to represent a 2D point or vector
11
union alignas(8) uv_float2 {
  struct {
    float x, y;
  };

  float data[2];

  // Bounds-checked element access; valid indices are 0 and 1.
  float &operator[](size_t idx) {
    if (idx >= 2)
      throw std::runtime_error("bad index");
    return data[idx];
  }

  const float &operator[](size_t idx) const {
    if (idx >= 2)
      throw std::runtime_error("bad index");
    return data[idx];
  }

  // Exact component-wise equality.
  bool operator==(const uv_float2 &rhs) const {
    return x == rhs.x && y == rhs.y;
  }
};
34
+
35
+ // Do not align as this is specifically tweaked for BVHNode
36
union uv_float3 {
  struct {
    float x, y, z;
  };

  float data[3];

  // Bounds-checked access; valid indices are 0..2.
  // FIX: the original checked `idx > 3`, which let idx == 3 read one float
  // past the end of data[3] (out-of-bounds, undefined behavior).
  float &operator[](size_t idx) {
    if (idx > 2)
      throw std::runtime_error("bad index");
    return data[idx];
  }

  const float &operator[](size_t idx) const {
    if (idx > 2)
      throw std::runtime_error("bad index");
    return data[idx];
  }

  // Exact component-wise equality.
  bool operator==(const uv_float3 &rhs) const {
    return x == rhs.x && y == rhs.y && z == rhs.z;
  }
};
59
+
60
// 4D float vector: named access (x, y, z, w) overlaid with array access.
union alignas(16) uv_float4 {
  struct {
    float x, y, z, w;
  };

  float data[4];

  // Bounds-checked element access; valid indices are 0..3.
  float &operator[](size_t idx) {
    if (idx >= 4)
      throw std::runtime_error("bad index");
    return data[idx];
  }

  const float &operator[](size_t idx) const {
    if (idx >= 4)
      throw std::runtime_error("bad index");
    return data[idx];
  }

  // Exact component-wise equality.
  bool operator==(const uv_float4 &rhs) const {
    return x == rhs.x && y == rhs.y && z == rhs.z && w == rhs.w;
  }
};
83
+
84
// 2D integer pair: named access (x, y) overlaid with array access.
union alignas(8) uv_int2 {
  struct {
    int x, y;
  };

  int data[2];

  // Bounds-checked element access; valid indices are 0 and 1.
  int &operator[](size_t idx) {
    if (idx >= 2)
      throw std::runtime_error("bad index");
    return data[idx];
  }

  const int &operator[](size_t idx) const {
    if (idx >= 2)
      throw std::runtime_error("bad index");
    return data[idx];
  }

  // Exact component-wise equality.
  bool operator==(const uv_int2 &rhs) const { return x == rhs.x && y == rhs.y; }
};
105
+
106
// 3D integer triple: named access (x, y, z) overlaid with array access.
union alignas(4) uv_int3 {
  struct {
    int x, y, z;
  };

  int data[3];

  // Bounds-checked element access; valid indices are 0..2.
  int &operator[](size_t idx) {
    if (idx >= 3)
      throw std::runtime_error("bad index");
    return data[idx];
  }

  const int &operator[](size_t idx) const {
    if (idx >= 3)
      throw std::runtime_error("bad index");
    return data[idx];
  }

  // Exact component-wise equality.
  bool operator==(const uv_int3 &rhs) const {
    return x == rhs.x && y == rhs.y && z == rhs.z;
  }
};
129
+
130
// 4D integer quadruple: named access (x, y, z, w) overlaid with array access.
union alignas(16) uv_int4 {
  struct {
    int x, y, z, w;
  };

  int data[4];

  // Bounds-checked element access; valid indices are 0..3.
  int &operator[](size_t idx) {
    if (idx >= 4)
      throw std::runtime_error("bad index");
    return data[idx];
  }

  const int &operator[](size_t idx) const {
    if (idx >= 4)
      throw std::runtime_error("bad index");
    return data[idx];
  }

  // Exact component-wise equality.
  bool operator==(const uv_int4 &rhs) const {
    return x == rhs.x && y == rhs.y && z == rhs.z && w == rhs.w;
  }
};
153
+
154
// Arithmetic mean of three values (used for triangle centroids).
inline float calc_mean(float a, float b, float c) {
  const float sum = a + b + c;
  return sum / 3;
}
155
+
156
+ // Create a triangle centroid
157
+ inline uv_float2 triangle_centroid(const uv_float2 &v0, const uv_float2 &v1,
158
+ const uv_float2 &v2) {
159
+ return {calc_mean(v0.x, v1.x, v2.x), calc_mean(v0.y, v1.y, v2.y)};
160
+ }
161
+
162
+ inline uv_float3 triangle_centroid(const uv_float3 &v0, const uv_float3 &v1,
163
+ const uv_float3 &v2) {
164
+ return {calc_mean(v0.x, v1.x, v2.x), calc_mean(v0.y, v1.y, v2.y),
165
+ calc_mean(v0.z, v1.z, v2.z)};
166
+ }
167
+
168
+ // Helper functions for vector math
169
+ inline uv_float2 operator-(const uv_float2 &a, const uv_float2 &b) {
170
+ return {a.x - b.x, a.y - b.y};
171
+ }
172
+
173
+ inline uv_float3 operator-(const uv_float3 &a, const uv_float3 &b) {
174
+ return {a.x - b.x, a.y - b.y, a.z - b.z};
175
+ }
176
+
177
+ inline uv_float2 operator+(const uv_float2 &a, const uv_float2 &b) {
178
+ return {a.x + b.x, a.y + b.y};
179
+ }
180
+
181
+ inline uv_float3 operator+(const uv_float3 &a, const uv_float3 &b) {
182
+ return {a.x + b.x, a.y + b.y, a.z + b.z};
183
+ }
184
+
185
+ inline uv_float2 operator*(const uv_float2 &a, float scalar) {
186
+ return {a.x * scalar, a.y * scalar};
187
+ }
188
+
189
+ inline uv_float3 operator*(const uv_float3 &a, float scalar) {
190
+ return {a.x * scalar, a.y * scalar, a.z * scalar};
191
+ }
192
+
193
+ inline float dot(const uv_float2 &a, const uv_float2 &b) {
194
+ return a.x * b.x + a.y * b.y;
195
+ }
196
+
197
+ inline float dot(const uv_float3 &a, const uv_float3 &b) {
198
+ return a.x * b.x + a.y * b.y + a.z * b.z;
199
+ }
200
+
201
+ inline float cross(const uv_float2 &a, const uv_float2 &b) {
202
+ return a.x * b.y - a.y * b.x;
203
+ }
204
+
205
+ inline uv_float3 cross(const uv_float3 &a, const uv_float3 &b) {
206
+ return {a.y * b.z - a.z * b.y, a.z * b.x - a.x * b.z, a.x * b.y - a.y * b.x};
207
+ }
208
+
209
+ inline uv_float2 abs_vec(const uv_float2 &v) {
210
+ return {std::abs(v.x), std::abs(v.y)};
211
+ }
212
+
213
+ inline uv_float2 min_vec(const uv_float2 &a, const uv_float2 &b) {
214
+ return {std::min(a.x, b.x), std::min(a.y, b.y)};
215
+ }
216
+
217
+ inline uv_float2 max_vec(const uv_float2 &a, const uv_float2 &b) {
218
+ return {std::max(a.x, b.x), std::max(a.y, b.y)};
219
+ }
220
+
221
+ inline float distance_to(const uv_float2 &a, const uv_float2 &b) {
222
+ return std::sqrt(std::pow(a.x - b.x, 2) + std::pow(a.y - b.y, 2));
223
+ }
224
+
225
+ inline float distance_to(const uv_float3 &a, const uv_float3 &b) {
226
+ return std::sqrt(std::pow(a.x - b.x, 2) + std::pow(a.y - b.y, 2) +
227
+ std::pow(a.z - b.z, 2));
228
+ }
229
+
230
+ inline uv_float2 normalize(const uv_float2 &v) {
231
+ float len = std::sqrt(v.x * v.x + v.y * v.y);
232
+ return {v.x / len, v.y / len};
233
+ }
234
+
235
+ inline uv_float3 normalize(const uv_float3 &v) {
236
+ float len = std::sqrt(v.x * v.x + v.y * v.y + v.z * v.z);
237
+ return {v.x / len, v.y / len, v.z / len};
238
+ }
239
+
240
+ inline float magnitude(const uv_float3 &v) {
241
+ return std::sqrt(v.x * v.x + v.y * v.y + v.z * v.z);
242
+ }
243
+
244
// Row-major 4x4 float matrix with just enough operations for the
// orientation/intersection predicates in this module.
struct Matrix4 {
  std::array<std::array<float, 4>, 4> m;

  // Zero-fills all entries, then sets m[3][3] = 1. Note: this is NOT a full
  // identity matrix — only the bottom-right entry is 1.
  Matrix4() {
    for (auto &row : m) {
      row.fill(0.0f);
    }
    m[3][3] = 1.0f; // Identity matrix for 4th row and column
  }

  // Assign all 16 entries in row-major order (mRC = row R, column C).
  void set(float m00, float m01, float m02, float m03, float m10, float m11,
           float m12, float m13, float m20, float m21, float m22, float m23,
           float m30, float m31, float m32, float m33) {
    m[0][0] = m00;
    m[0][1] = m01;
    m[0][2] = m02;
    m[0][3] = m03;
    m[1][0] = m10;
    m[1][1] = m11;
    m[1][2] = m12;
    m[1][3] = m13;
    m[2][0] = m20;
    m[2][1] = m21;
    m[2][2] = m22;
    m[2][3] = m23;
    m[3][0] = m30;
    m[3][1] = m31;
    m[3][2] = m32;
    m[3][3] = m33;
  }

  // Determinant via the fully expanded 24-term cofactor formula.
  float determinant() const {
    return m[0][3] * m[1][2] * m[2][1] * m[3][0] -
           m[0][2] * m[1][3] * m[2][1] * m[3][0] -
           m[0][3] * m[1][1] * m[2][2] * m[3][0] +
           m[0][1] * m[1][3] * m[2][2] * m[3][0] +
           m[0][2] * m[1][1] * m[2][3] * m[3][0] -
           m[0][1] * m[1][2] * m[2][3] * m[3][0] -
           m[0][3] * m[1][2] * m[2][0] * m[3][1] +
           m[0][2] * m[1][3] * m[2][0] * m[3][1] +
           m[0][3] * m[1][0] * m[2][2] * m[3][1] -
           m[0][0] * m[1][3] * m[2][2] * m[3][1] -
           m[0][2] * m[1][0] * m[2][3] * m[3][1] +
           m[0][0] * m[1][2] * m[2][3] * m[3][1] +
           m[0][3] * m[1][1] * m[2][0] * m[3][2] -
           m[0][1] * m[1][3] * m[2][0] * m[3][2] -
           m[0][3] * m[1][0] * m[2][1] * m[3][2] +
           m[0][0] * m[1][3] * m[2][1] * m[3][2] +
           m[0][1] * m[1][0] * m[2][3] * m[3][2] -
           m[0][0] * m[1][1] * m[2][3] * m[3][2] -
           m[0][2] * m[1][1] * m[2][0] * m[3][3] +
           m[0][1] * m[1][2] * m[2][0] * m[3][3] +
           m[0][2] * m[1][0] * m[2][1] * m[3][3] -
           m[0][0] * m[1][2] * m[2][1] * m[3][3] -
           m[0][1] * m[1][0] * m[2][2] * m[3][3] +
           m[0][0] * m[1][1] * m[2][2] * m[3][3];
  }

  // Standard matrix product: result = (*this) * other.
  Matrix4 operator*(const Matrix4 &other) const {
    Matrix4 result;
    for (int row = 0; row < 4; ++row) {
      for (int col = 0; col < 4; ++col) {
        result.m[row][col] =
            m[row][0] * other.m[0][col] + m[row][1] * other.m[1][col] +
            m[row][2] * other.m[2][col] + m[row][3] * other.m[3][col];
      }
    }
    return result;
  }

  // Scalar multiplication of every entry.
  Matrix4 operator*(float scalar) const {
    Matrix4 result = *this;
    for (auto &row : result.m) {
      for (auto &element : row) {
        element *= scalar;
      }
    }
    return result;
  }

  // Entry-wise sum.
  Matrix4 operator+(const Matrix4 &other) const {
    Matrix4 result;
    for (int i = 0; i < 4; ++i) {
      for (int j = 0; j < 4; ++j) {
        result.m[i][j] = m[i][j] + other.m[i][j];
      }
    }
    return result;
  }

  // Entry-wise difference.
  Matrix4 operator-(const Matrix4 &other) const {
    Matrix4 result;
    for (int i = 0; i < 4; ++i) {
      for (int j = 0; j < 4; ++j) {
        result.m[i][j] = m[i][j] - other.m[i][j];
      }
    }
    return result;
  }

  // Sum of the diagonal entries.
  float trace() const { return m[0][0] + m[1][1] + m[2][2] + m[3][3]; }

  // Returns a fresh 4x4 identity matrix (does not modify *this).
  Matrix4 identity() const {
    Matrix4 identity;
    identity.set(1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1);
    return identity;
  }

  // Naive matrix power by repeated multiplication; power(0) is identity.
  // Negative exponents are treated like exponent 1 (falls through to the
  // loop, which does not execute) — callers should pass exp >= 0.
  Matrix4 power(int exp) const {
    if (exp == 0)
      return identity();
    if (exp == 1)
      return *this;

    Matrix4 result = *this;
    for (int i = 1; i < exp; ++i) {
      result = result * (*this);
    }
    return result;
  }

  // Debug dump to stdout, one row per line.
  void print() {
    // Print all entries in 4 rows with 4 columns
    for (int i = 0; i < 4; ++i) {
      for (int j = 0; j < 4; ++j) {
        std::cout << m[i][j] << " ";
      }
      std::cout << std::endl;
    }
  }

  // In-place inversion via the adjugate / cofactor expansion, computed in
  // double precision for accuracy. Returns false (leaving m unchanged) when
  // |det| < 1e-6, i.e. the matrix is treated as singular.
  bool invert() {
    double inv[16], det;
    double mArr[16];

    // Convert the matrix to a 1D array for easier manipulation
    for (int i = 0; i < 4; ++i) {
      for (int j = 0; j < 4; ++j) {
        mArr[i * 4 + j] = static_cast<double>(m[i][j]);
      }
    }

    // Each inv[k] below is the cofactor of entry k (already transposed, so
    // inv is the adjugate laid out row-major).
    inv[0] = mArr[5] * mArr[10] * mArr[15] - mArr[5] * mArr[11] * mArr[14] -
             mArr[9] * mArr[6] * mArr[15] + mArr[9] * mArr[7] * mArr[14] +
             mArr[13] * mArr[6] * mArr[11] - mArr[13] * mArr[7] * mArr[10];

    inv[4] = -mArr[4] * mArr[10] * mArr[15] + mArr[4] * mArr[11] * mArr[14] +
             mArr[8] * mArr[6] * mArr[15] - mArr[8] * mArr[7] * mArr[14] -
             mArr[12] * mArr[6] * mArr[11] + mArr[12] * mArr[7] * mArr[10];

    inv[8] = mArr[4] * mArr[9] * mArr[15] - mArr[4] * mArr[11] * mArr[13] -
             mArr[8] * mArr[5] * mArr[15] + mArr[8] * mArr[7] * mArr[13] +
             mArr[12] * mArr[5] * mArr[11] - mArr[12] * mArr[7] * mArr[9];

    inv[12] = -mArr[4] * mArr[9] * mArr[14] + mArr[4] * mArr[10] * mArr[13] +
              mArr[8] * mArr[5] * mArr[14] - mArr[8] * mArr[6] * mArr[13] -
              mArr[12] * mArr[5] * mArr[10] + mArr[12] * mArr[6] * mArr[9];

    inv[1] = -mArr[1] * mArr[10] * mArr[15] + mArr[1] * mArr[11] * mArr[14] +
             mArr[9] * mArr[2] * mArr[15] - mArr[9] * mArr[3] * mArr[14] -
             mArr[13] * mArr[2] * mArr[11] + mArr[13] * mArr[3] * mArr[10];

    inv[5] = mArr[0] * mArr[10] * mArr[15] - mArr[0] * mArr[11] * mArr[14] -
             mArr[8] * mArr[2] * mArr[15] + mArr[8] * mArr[3] * mArr[14] +
             mArr[12] * mArr[2] * mArr[11] - mArr[12] * mArr[3] * mArr[10];

    inv[9] = -mArr[0] * mArr[9] * mArr[15] + mArr[0] * mArr[11] * mArr[13] +
             mArr[8] * mArr[1] * mArr[15] - mArr[8] * mArr[3] * mArr[13] -
             mArr[12] * mArr[1] * mArr[11] + mArr[12] * mArr[3] * mArr[9];

    inv[13] = mArr[0] * mArr[9] * mArr[14] - mArr[0] * mArr[10] * mArr[13] -
              mArr[8] * mArr[1] * mArr[14] + mArr[8] * mArr[2] * mArr[13] +
              mArr[12] * mArr[1] * mArr[10] - mArr[12] * mArr[2] * mArr[9];

    inv[2] = mArr[1] * mArr[6] * mArr[15] - mArr[1] * mArr[7] * mArr[14] -
             mArr[5] * mArr[2] * mArr[15] + mArr[5] * mArr[3] * mArr[14] +
             mArr[13] * mArr[2] * mArr[7] - mArr[13] * mArr[3] * mArr[6];

    inv[6] = -mArr[0] * mArr[6] * mArr[15] + mArr[0] * mArr[7] * mArr[14] +
             mArr[4] * mArr[2] * mArr[15] - mArr[4] * mArr[3] * mArr[14] -
             mArr[12] * mArr[2] * mArr[7] + mArr[12] * mArr[3] * mArr[6];

    inv[10] = mArr[0] * mArr[5] * mArr[15] - mArr[0] * mArr[7] * mArr[13] -
              mArr[4] * mArr[1] * mArr[15] + mArr[4] * mArr[3] * mArr[13] +
              mArr[12] * mArr[1] * mArr[7] - mArr[12] * mArr[3] * mArr[5];

    inv[14] = -mArr[0] * mArr[5] * mArr[14] + mArr[0] * mArr[6] * mArr[13] +
              mArr[4] * mArr[1] * mArr[14] - mArr[4] * mArr[2] * mArr[13] -
              mArr[12] * mArr[1] * mArr[6] + mArr[12] * mArr[2] * mArr[5];

    inv[3] = -mArr[1] * mArr[6] * mArr[11] + mArr[1] * mArr[7] * mArr[10] +
             mArr[5] * mArr[2] * mArr[11] - mArr[5] * mArr[3] * mArr[10] -
             mArr[9] * mArr[2] * mArr[7] + mArr[9] * mArr[3] * mArr[6];

    inv[7] = mArr[0] * mArr[6] * mArr[11] - mArr[0] * mArr[7] * mArr[10] -
             mArr[4] * mArr[2] * mArr[11] + mArr[4] * mArr[3] * mArr[10] +
             mArr[8] * mArr[2] * mArr[7] - mArr[8] * mArr[3] * mArr[6];

    inv[11] = -mArr[0] * mArr[5] * mArr[11] + mArr[0] * mArr[7] * mArr[9] +
              mArr[4] * mArr[1] * mArr[11] - mArr[4] * mArr[3] * mArr[9] -
              mArr[8] * mArr[1] * mArr[7] + mArr[8] * mArr[3] * mArr[5];

    inv[15] = mArr[0] * mArr[5] * mArr[10] - mArr[0] * mArr[6] * mArr[9] -
              mArr[4] * mArr[1] * mArr[10] + mArr[4] * mArr[2] * mArr[9] +
              mArr[8] * mArr[1] * mArr[6] - mArr[8] * mArr[2] * mArr[5];

    // Determinant by expansion along the first row of cofactors.
    det = mArr[0] * inv[0] + mArr[1] * inv[4] + mArr[2] * inv[8] +
          mArr[3] * inv[12];

    if (fabs(det) < 1e-6) {
      return false;
    }

    det = 1.0 / det;

    for (int i = 0; i < 16; i++) {
      inv[i] *= det;
    }

    // Convert the 1D array back to the 4x4 matrix
    for (int i = 0; i < 4; ++i) {
      for (int j = 0; j < 4; ++j) {
        m[i][j] = static_cast<float>(inv[i * 4 + j]);
      }
    }

    return true;
  }
};
473
+
474
+ inline void apply_matrix4(uv_float3 &v, const Matrix4 matrix) {
475
+ float newX = v.x * matrix.m[0][0] + v.y * matrix.m[0][1] +
476
+ v.z * matrix.m[0][2] + matrix.m[0][3];
477
+ float newY = v.x * matrix.m[1][0] + v.y * matrix.m[1][1] +
478
+ v.z * matrix.m[1][2] + matrix.m[1][3];
479
+ float newZ = v.x * matrix.m[2][0] + v.y * matrix.m[2][1] +
480
+ v.z * matrix.m[2][2] + matrix.m[2][3];
481
+ float w = v.x * matrix.m[3][0] + v.y * matrix.m[3][1] + v.z * matrix.m[3][2] +
482
+ matrix.m[3][3];
483
+
484
+ if (std::fabs(w) > EPSILON) {
485
+ newX /= w;
486
+ newY /= w;
487
+ newZ /= w;
488
+ }
489
+
490
+ v.x = newX;
491
+ v.y = newY;
492
+ v.z = newZ;
493
+ }
uv_unwrapper/uv_unwrapper/csrc/intersect.cpp ADDED
@@ -0,0 +1,702 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include "intersect.h"
2
+ #include "bvh.h"
3
+ #include <algorithm>
4
+ #include <cmath>
5
+ #include <iostream>
6
+ #include <stdexcept>
7
+ #include <vector>
8
+
9
/**
 * @brief 2D overlap test between triangle (v0, v1, v2) and the axis-aligned
 * box [aabbMin, aabbMax].
 *
 * Three phases:
 *  1. Compute a 4-bit region code per vertex (bit 1: x > left, bit 2:
 *     y > top, bit 4: x > right, bit 8: y > bottom). Code 3 means the
 *     vertex lies strictly inside the box.
 *  2. For each triangle edge whose endpoints have different codes, intersect
 *     the edge's supporting line with each box side indicated by the
 *     differing bits.
 *  3. If the triangle's bounding box fully contains the AABB, test the box
 *     corners against the triangle with barycentric coordinates (catches
 *     the box-entirely-inside-triangle case).
 *
 * NOTE(review): "t" is aabbMin.y and "b" is aabbMax.y, which matches a
 * y-down (top = smaller y) UV convention — confirm against callers.
 */
bool triangle_aabb_intersection(const uv_float2 &aabbMin,
                                const uv_float2 &aabbMax, const uv_float2 &v0,
                                const uv_float2 &v1, const uv_float2 &v2) {
  // Convert the min and max aabb defintion to left, right, top, bottom
  float l = aabbMin.x;
  float r = aabbMax.x;
  float t = aabbMin.y;
  float b = aabbMax.y;

  // Phase 1: region codes; any vertex inside the box is an immediate hit.
  int b0 = ((v0.x > l) ? 1 : 0) | ((v0.y > t) ? 2 : 0) | ((v0.x > r) ? 4 : 0) |
           ((v0.y > b) ? 8 : 0);
  if (b0 == 3)
    return true;

  int b1 = ((v1.x > l) ? 1 : 0) | ((v1.y > t) ? 2 : 0) | ((v1.x > r) ? 4 : 0) |
           ((v1.y > b) ? 8 : 0);
  if (b1 == 3)
    return true;

  int b2 = ((v2.x > l) ? 1 : 0) | ((v2.y > t) ? 2 : 0) | ((v2.x > r) ? 4 : 0) |
           ((v2.y > b) ? 8 : 0);
  if (b2 == 3)
    return true;

  // m: edge slope, c: y-intercept, s: coordinate of the side intersection.
  float m, c, s;

  // Phase 2, edge v0-v1: XOR of the codes marks the box sides this edge
  // crosses; test each flagged side for an intersection within its extent.
  int i0 = b0 ^ b1;
  if (i0 != 0) {
    if (v1.x != v0.x) {
      m = (v1.y - v0.y) / (v1.x - v0.x);
      c = v0.y - (m * v0.x);
      if (i0 & 1) {
        s = m * l + c;
        if (s >= t && s <= b)
          return true;
      }
      if (i0 & 2) {
        s = (t - c) / m;
        if (s >= l && s <= r)
          return true;
      }
      if (i0 & 4) {
        s = m * r + c;
        if (s >= t && s <= b)
          return true;
      }
      if (i0 & 8) {
        s = (b - c) / m;
        if (s >= l && s <= r)
          return true;
      }
    } else {
      // Vertical edge: overlaps if its x lies on or between the box sides.
      if (l == v0.x || r == v0.x)
        return true;
      if (v0.x > l && v0.x < r)
        return true;
    }
  }

  // Edge v1-v2 (same side-crossing test).
  int i1 = b1 ^ b2;
  if (i1 != 0) {
    if (v2.x != v1.x) {
      m = (v2.y - v1.y) / (v2.x - v1.x);
      c = v1.y - (m * v1.x);
      if (i1 & 1) {
        s = m * l + c;
        if (s >= t && s <= b)
          return true;
      }
      if (i1 & 2) {
        s = (t - c) / m;
        if (s >= l && s <= r)
          return true;
      }
      if (i1 & 4) {
        s = m * r + c;
        if (s >= t && s <= b)
          return true;
      }
      if (i1 & 8) {
        s = (b - c) / m;
        if (s >= l && s <= r)
          return true;
      }
    } else {
      if (l == v1.x || r == v1.x)
        return true;
      if (v1.x > l && v1.x < r)
        return true;
    }
  }

  // Edge v0-v2 (same side-crossing test).
  int i2 = b0 ^ b2;
  if (i2 != 0) {
    if (v2.x != v0.x) {
      m = (v2.y - v0.y) / (v2.x - v0.x);
      c = v0.y - (m * v0.x);
      if (i2 & 1) {
        s = m * l + c;
        if (s >= t && s <= b)
          return true;
      }
      if (i2 & 2) {
        s = (t - c) / m;
        if (s >= l && s <= r)
          return true;
      }
      if (i2 & 4) {
        s = m * r + c;
        if (s >= t && s <= b)
          return true;
      }
      if (i2 & 8) {
        s = (b - c) / m;
        if (s >= l && s <= r)
          return true;
      }
    } else {
      if (l == v0.x || r == v0.x)
        return true;
      if (v0.x > l && v0.x < r)
        return true;
    }
  }

  // Phase 3: triangle bounding box check — only if it fully contains the
  // AABB can the box be inside the triangle without any edge crossing.
  // Bounding box check
  float tbb_l = std::min(v0.x, std::min(v1.x, v2.x));
  float tbb_t = std::min(v0.y, std::min(v1.y, v2.y));
  float tbb_r = std::max(v0.x, std::max(v1.x, v2.x));
  float tbb_b = std::max(v0.y, std::max(v1.y, v2.y));

  if (tbb_l <= l && tbb_r >= r && tbb_t <= t && tbb_b >= b) {
    // Barycentric point-in-triangle test for each box corner.
    float v0x = v2.x - v0.x;
    float v0y = v2.y - v0.y;
    float v1x = v1.x - v0.x;
    float v1y = v1.y - v0.y;
    float v2x, v2y;

    float dot00, dot01, dot02, dot11, dot12, invDenom, u, v;

    // Top-left corner
    v2x = l - v0.x;
    v2y = t - v0.y;

    dot00 = v0x * v0x + v0y * v0y;
    dot01 = v0x * v1x + v0y * v1y;
    dot02 = v0x * v2x + v0y * v2y;
    dot11 = v1x * v1x + v1y * v1y;
    dot12 = v1x * v2x + v1y * v2y;

    invDenom = 1.0f / (dot00 * dot11 - dot01 * dot01);
    u = (dot11 * dot02 - dot01 * dot12) * invDenom;
    v = (dot00 * dot12 - dot01 * dot02) * invDenom;

    if (u >= 0 && v >= 0 && (u + v) <= 1)
      return true;

    // Bottom-left corner
    v2x = l - v0.x;
    v2y = b - v0.y;

    dot02 = v0x * v2x + v0y * v2y;
    dot12 = v1x * v2x + v1y * v2y;

    u = (dot11 * dot02 - dot01 * dot12) * invDenom;
    v = (dot00 * dot12 - dot01 * dot02) * invDenom;

    if (u >= 0 && v >= 0 && (u + v) <= 1)
      return true;

    // Bottom-right corner
    v2x = r - v0.x;
    v2y = b - v0.y;

    dot02 = v0x * v2x + v0y * v2y;
    dot12 = v1x * v2x + v1y * v2y;

    u = (dot11 * dot02 - dot01 * dot12) * invDenom;
    v = (dot00 * dot12 - dot01 * dot02) * invDenom;

    if (u >= 0 && v >= 0 && (u + v) <= 1)
      return true;

    // Top-right corner
    v2x = r - v0.x;
    v2y = t - v0.y;

    dot02 = v0x * v2x + v0y * v2y;
    dot12 = v1x * v2x + v1y * v2y;

    u = (dot11 * dot02 - dot01 * dot12) * invDenom;
    v = (dot00 * dot12 - dot01 * dot02) * invDenom;

    if (u >= 0 && v >= 0 && (u + v) <= 1)
      return true;
  }

  return false;
}
+ }
208
+
209
+ void tri_winding(uv_float2 &a, uv_float2 &b, uv_float2 &c) {
210
+ float det = (a.x * (b.y - c.y) + b.x * (c.y - a.y) + c.x * (a.y - b.y));
211
+
212
+ // If the determinant is negative, the triangle is oriented clockwise
213
+ if (det < 0) {
214
+ // Swap vertices b and c to ensure counter-clockwise winding
215
+ std::swap(b, c);
216
+ }
217
+ }
218
+
219
+ struct Triangle {
220
+ uv_float3 a, b, c;
221
+
222
+ Triangle(const uv_float2 &p1, const uv_float2 &q1, const uv_float2 &r1)
223
+ : a({p1.x, p1.y, 0}), b({q1.x, q1.y, 0}), c({r1.x, r1.y, 0}) {}
224
+
225
+ Triangle(const uv_float3 &p1, const uv_float3 &q1, const uv_float3 &r1)
226
+ : a(p1), b(q1), c(r1) {}
227
+
228
+ void getNormal(uv_float3 &normal) const {
229
+ uv_float3 u = b - a;
230
+ uv_float3 v = c - a;
231
+ normal = normalize(cross(u, v));
232
+ }
233
+ };
234
+
235
+ bool isTriDegenerated(const Triangle &tri) {
236
+ uv_float3 u = tri.a - tri.b;
237
+ uv_float3 v = tri.a - tri.c;
238
+ uv_float3 cr = cross(u, v);
239
+ return fabs(cr.x) < EPSILON && fabs(cr.y) < EPSILON && fabs(cr.z) < EPSILON;
240
+ }
241
+
242
+ int orient3D(const uv_float3 &a, const uv_float3 &b, const uv_float3 &c,
243
+ const uv_float3 &d) {
244
+ Matrix4 _matrix4;
245
+ _matrix4.set(a.x, a.y, a.z, 1, b.x, b.y, b.z, 1, c.x, c.y, c.z, 1, d.x, d.y,
246
+ d.z, 1);
247
+ float det = _matrix4.determinant();
248
+
249
+ if (det < -EPSILON)
250
+ return -1;
251
+ else if (det > EPSILON)
252
+ return 1;
253
+ else
254
+ return 0;
255
+ }
256
+
257
+ int orient2D(const uv_float2 &a, const uv_float2 &b, const uv_float2 &c) {
258
+ float det = (a.x * (b.y - c.y) + b.x * (c.y - a.y) + c.x * (a.y - b.y));
259
+
260
+ if (det < -EPSILON)
261
+ return -1;
262
+ else if (det > EPSILON)
263
+ return 1;
264
+ else
265
+ return 0;
266
+ }
267
+
268
+ int orient2D(const uv_float3 &a, const uv_float3 &b, const uv_float3 &c) {
269
+ uv_float2 a_2d = {a.x, a.y};
270
+ uv_float2 b_2d = {b.x, b.y};
271
+ uv_float2 c_2d = {c.x, c.y};
272
+ return orient2D(a_2d, b_2d, c_2d);
273
+ }
274
+
275
+ void permuteTriLeft(Triangle &tri) {
276
+ uv_float3 tmp = tri.a;
277
+ tri.a = tri.b;
278
+ tri.b = tri.c;
279
+ tri.c = tmp;
280
+ }
281
+
282
+ void permuteTriRight(Triangle &tri) {
283
+ uv_float3 tmp = tri.c;
284
+ tri.c = tri.b;
285
+ tri.b = tri.a;
286
+ tri.a = tmp;
287
+ }
288
+
289
+ void makeTriCounterClockwise(Triangle &tri) {
290
+ if (orient2D(tri.a, tri.b, tri.c) < 0) {
291
+ uv_float3 tmp = tri.c;
292
+ tri.c = tri.b;
293
+ tri.b = tmp;
294
+ }
295
+ }
296
+
297
+ void intersectPlane(const uv_float3 &a, const uv_float3 &b, const uv_float3 &p,
298
+ const uv_float3 &n, uv_float3 &target) {
299
+ uv_float3 u = b - a;
300
+ uv_float3 v = a - p;
301
+ float dot1 = dot(n, u);
302
+ float dot2 = dot(n, v);
303
+ u = u * (-dot2 / dot1);
304
+ target = a + u;
305
+ }
306
+
307
/**
 * @brief Compute the segment where two crossing triangles intersect,
 * appending its endpoint(s) to `target` (one point if the endpoints
 * coincide within EPSILON).
 *
 * Assumes both triangles have already been canonicalized by crossIntersect
 * (vertex `a` alone on the positive side of the other triangle's plane).
 * NOTE(review): the branch selection appears to follow the interval-overlap
 * case analysis of the Devillers–Guigue triangle intersection test —
 * confirm before modifying.
 */
void computeLineIntersection(const Triangle &t1, const Triangle &t2,
                             std::vector<uv_float3> &target) {
  uv_float3 n1, n2;
  t1.getNormal(n1);
  t2.getNormal(n2);

  // Orientations deciding which edges bracket the shared line interval.
  int o1 = orient3D(t1.a, t1.c, t2.b, t2.a);
  int o2 = orient3D(t1.a, t1.b, t2.c, t2.a);

  uv_float3 i1, i2;

  // Pick, per case, the two edge/plane intersections that bound the
  // overlap interval on the line of intersection of the two planes.
  if (o1 > 0) {
    if (o2 > 0) {
      intersectPlane(t1.a, t1.c, t2.a, n2, i1);
      intersectPlane(t2.a, t2.c, t1.a, n1, i2);
    } else {
      intersectPlane(t1.a, t1.c, t2.a, n2, i1);
      intersectPlane(t1.a, t1.b, t2.a, n2, i2);
    }
  } else {
    if (o2 > 0) {
      intersectPlane(t2.a, t2.b, t1.a, n1, i1);
      intersectPlane(t2.a, t2.c, t1.a, n1, i2);
    } else {
      intersectPlane(t2.a, t2.b, t1.a, n1, i1);
      intersectPlane(t1.a, t1.b, t2.a, n2, i2);
    }
  }

  // Always emit the first endpoint; skip the second if degenerate.
  target.push_back(i1);
  if (distance_to(i1, i2) >= EPSILON) {
    target.push_back(i2);
  }
}
341
+
342
/**
 * @brief Permute tri's vertices in place so that vertex `a` is the one
 * alone on its side of the other triangle's plane.
 *
 * oa/ob/oc are the precomputed orientations (-1/0/+1) of a, b, c relative
 * to that plane. When all three orientations differ (one vertex lies on the
 * plane), the vertex on the strictly positive side is rotated into `a`.
 */
void makeTriAVertexAlone(Triangle &tri, int oa, int ob, int oc) {
  // Permute a, b, c so that a is alone on its side
  if (oa == ob) {
    // c is alone, permute right so c becomes a
    permuteTriRight(tri);
  } else if (oa == oc) {
    // b is alone, permute so b becomes a
    permuteTriLeft(tri);
  } else if (ob != oc) {
    // In case a, b, c have different orientation, put a on positive side
    if (ob > 0) {
      permuteTriLeft(tri);
    } else if (oc > 0) {
      permuteTriRight(tri);
    }
  }
}
359
+
360
+ void makeTriAVertexPositive(Triangle &tri, const Triangle &other) {
361
+ int o = orient3D(other.a, other.b, other.c, tri.a);
362
+ if (o < 0) {
363
+ std::swap(tri.b, tri.c);
364
+ }
365
+ }
366
+
367
/**
 * @brief Decide whether two non-coplanar triangles intersect; optionally
 * compute the intersection segment.
 *
 * @param t1, t2 the triangles (permuted IN PLACE during canonicalization,
 *        even when the function ultimately returns false)
 * @param o1a, o1b, o1c precomputed orientations of t1's vertices relative
 *        to t2's plane
 * @param target if non-null, the intersection segment endpoints are
 *        appended to it on a positive result
 * @return true if the triangles intersect
 */
bool crossIntersect(Triangle &t1, Triangle &t2, int o1a, int o1b, int o1c,
                    std::vector<uv_float3> *target = nullptr) {
  // Orientations of t2's vertices relative to t1's plane.
  int o2a = orient3D(t1.a, t1.b, t1.c, t2.a);
  int o2b = orient3D(t1.a, t1.b, t1.c, t2.b);
  int o2c = orient3D(t1.a, t1.b, t1.c, t2.c);

  // All of t2 strictly on one side of t1's plane: no intersection possible.
  if (o2a == o2b && o2a == o2c) {
    return false;
  }

  // Make a vertex alone on its side for both triangles
  makeTriAVertexAlone(t1, o1a, o1b, o1c);
  makeTriAVertexAlone(t2, o2a, o2b, o2c);

  // Ensure the vertex on the positive side
  makeTriAVertexPositive(t2, t1);
  makeTriAVertexPositive(t1, t2);

  // Interval-overlap test on the shared plane-intersection line: the
  // triangles intersect iff both orientation predicates are non-positive.
  int o1 = orient3D(t1.a, t1.b, t2.a, t2.b);
  int o2 = orient3D(t1.a, t1.c, t2.c, t2.a);

  if (o1 <= 0 && o2 <= 0) {
    if (target) {
      computeLineIntersection(t1, t2, *target);
    }
    return true;
  }

  return false;
}
397
+
398
+ void linesIntersect2d(const uv_float3 &a1, const uv_float3 &b1,
399
+ const uv_float3 &a2, const uv_float3 &b2,
400
+ uv_float3 &target) {
401
+ float dx1 = a1.x - b1.x;
402
+ float dx2 = a2.x - b2.x;
403
+ float dy1 = a1.y - b1.y;
404
+ float dy2 = a2.y - b2.y;
405
+
406
+ float D = dx1 * dy2 - dx2 * dy1;
407
+
408
+ float n1 = a1.x * b1.y - a1.y * b1.x;
409
+ float n2 = a2.x * b2.y - a2.y * b2.x;
410
+
411
+ target.x = (n1 * dx2 - n2 * dx1) / D;
412
+ target.y = (n1 * dy2 - n2 * dy1) / D;
413
+ target.z = 0;
414
+ }
415
+
416
+ void clipTriangle(const Triangle &t1, const Triangle &t2,
417
+ std::vector<uv_float3> &target) {
418
+ std::vector<uv_float3> clip = {t1.a, t1.b, t1.c};
419
+ std::vector<uv_float3> output = {t2.a, t2.b, t2.c};
420
+ std::vector<int> orients(output.size() * 3, 0);
421
+ uv_float3 inter;
422
+
423
+ for (int i = 0; i < 3; ++i) {
424
+ const int i_prev = (i + 2) % 3;
425
+ std::vector<uv_float3> input;
426
+ std::copy(output.begin(), output.end(), std::back_inserter(input));
427
+ output.clear();
428
+
429
+ for (size_t j = 0; j < input.size(); ++j) {
430
+ orients[j] = orient2D(clip[i_prev], clip[i], input[j]);
431
+ }
432
+
433
+ for (size_t j = 0; j < input.size(); ++j) {
434
+ const int j_prev = (j - 1 + input.size()) % input.size();
435
+
436
+ if (orients[j] >= 0) {
437
+ if (orients[j_prev] < 0) {
438
+ linesIntersect2d(clip[i_prev], clip[i], input[j_prev], input[j],
439
+ inter);
440
+ output.push_back({inter.x, inter.y, inter.z});
441
+ }
442
+ output.push_back({input[j].x, input[j].y, input[j].z});
443
+ } else if (orients[j_prev] >= 0) {
444
+ linesIntersect2d(clip[i_prev], clip[i], input[j_prev], input[j], inter);
445
+ output.push_back({inter.x, inter.y, inter.z});
446
+ }
447
+ }
448
+ }
449
+
450
+ // Clear duplicated points
451
+ for (const auto &point : output) {
452
+ int j = 0;
453
+ bool sameFound = false;
454
+ while (!sameFound && j < target.size()) {
455
+ sameFound = distance_to(point, target[j]) <= 1e-6;
456
+ j++;
457
+ }
458
+
459
+ if (!sameFound) {
460
+ target.push_back(point);
461
+ }
462
+ }
463
+ }
464
+
465
+ bool intersectionTypeR1(const Triangle &t1, const Triangle &t2) {
466
+ const uv_float3 &p1 = t1.a;
467
+ const uv_float3 &q1 = t1.b;
468
+ const uv_float3 &r1 = t1.c;
469
+ const uv_float3 &p2 = t2.a;
470
+ const uv_float3 &r2 = t2.c;
471
+
472
+ if (orient2D(r2, p2, q1) >= 0) { // I
473
+ if (orient2D(r2, p1, q1) >= 0) { // II.a
474
+ if (orient2D(p1, p2, q1) >= 0) { // III.a
475
+ return true;
476
+ } else {
477
+ if (orient2D(p1, p2, r1) >= 0) { // IV.a
478
+ if (orient2D(q1, r1, p2) >= 0) { // V
479
+ return true;
480
+ }
481
+ }
482
+ }
483
+ }
484
+ } else {
485
+ if (orient2D(r2, p2, r1) >= 0) { // II.b
486
+ if (orient2D(q1, r1, r2) >= 0) { // III.b
487
+ if (orient2D(p1, p2, r1) >= 0) { // IV.b (diverges from paper)
488
+ return true;
489
+ }
490
+ }
491
+ }
492
+ }
493
+
494
+ return false;
495
+ }
496
+
497
// Region-R2 decision procedure for coplanar triangle intersection
// (Devillers & Guigue): decides whether canonicalized triangle t1 intersects
// t2 when two of t2's edge orientations were negative. The roman-numeral
// markers track the levels of the paper's decision tree; each orient2D sign
// test narrows which region of t2's plane p1/q1/r1 fall into.
bool intersectionTypeR2(const Triangle &t1, const Triangle &t2) {
  const uv_float3 &p1 = t1.a;
  const uv_float3 &q1 = t1.b;
  const uv_float3 &r1 = t1.c;
  const uv_float3 &p2 = t2.a;
  const uv_float3 &q2 = t2.b;
  const uv_float3 &r2 = t2.c;

  if (orient2D(r2, p2, q1) >= 0) { // I
    if (orient2D(q2, r2, q1) >= 0) { // II.a
      if (orient2D(p1, p2, q1) >= 0) { // III.a
        if (orient2D(p1, q2, q1) <= 0) { // IV.a
          return true;
        }
      } else {
        if (orient2D(p1, p2, r1) >= 0) { // IV.b
          if (orient2D(r2, p2, r1) <= 0) { // V.a
            return true;
          }
        }
      }
    } else {
      if (orient2D(p1, q2, q1) <= 0) { // III.b
        if (orient2D(q2, r2, r1) >= 0) { // IV.c
          if (orient2D(q1, r1, q2) >= 0) { // V.b
            return true;
          }
        }
      }
    }
  } else {
    if (orient2D(r2, p2, r1) >= 0) { // II.b
      if (orient2D(q1, r1, r2) >= 0) { // III.c
        if (orient2D(r1, p1, p2) >= 0) { // IV.d
          return true;
        }
      } else {
        if (orient2D(q1, r1, q2) >= 0) { // IV.e
          if (orient2D(q2, r2, r1) >= 0) { // V.c
            return true;
          }
        }
      }
    }
  }

  // No accepting leaf reached: the triangles do not intersect.
  return false;
}
545
+
546
// Test whether two coplanar 3D triangles intersect. Both triangles are
// transformed into a 2D frame spanned by t1 (so all z components can be
// ignored), canonicalized to counter-clockwise winding, then dispatched to
// the R1/R2 region tests based on where t1.a falls relative to t2's edges.
// On intersection, optionally clips the triangles against each other and
// appends the overlap polygon (mapped back to 3D) to *target.
// NOTE: mutates t1 and t2 in place (transform + permutations) — callers
// must pass copies they own.
bool coplanarIntersect(Triangle &t1, Triangle &t2,
                       std::vector<uv_float3> *target = nullptr) {
  uv_float3 normal, u, v;
  // Build an orthonormal basis (u, v, normal) on t1's plane.
  t1.getNormal(normal);
  normal = normalize(normal);
  u = normalize(t1.a - t1.b);
  v = cross(normal, u);

  // Move basis to t1.a
  u = u + t1.a;
  v = v + t1.a;
  normal = normal + t1.a;

  // Matrix mapping the canonical frame onto (t1.a, u, v, normal),
  // written column-wise with homogeneous 1s in the last row.
  Matrix4 _matrix;
  _matrix.set(t1.a.x, u.x, v.x, normal.x, t1.a.y, u.y, v.y, normal.y, t1.a.z,
              u.z, v.z, normal.z, 1, 1, 1, 1);

  Matrix4 _affineMatrix;
  _affineMatrix.set(0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1);

  _matrix.invert(); // Invert the _matrix
  // Compose so that applying _matrix maps plane points into 2D (z ~ 0).
  _matrix = _affineMatrix * _matrix;

  // Apply transformation
  apply_matrix4(t1.a, _matrix);
  apply_matrix4(t1.b, _matrix);
  apply_matrix4(t1.c, _matrix);
  apply_matrix4(t2.a, _matrix);
  apply_matrix4(t2.b, _matrix);
  apply_matrix4(t2.c, _matrix);

  // Canonical CCW winding is required by the region tests below.
  makeTriCounterClockwise(t1);
  makeTriCounterClockwise(t2);

  const uv_float3 &p1 = t1.a;
  const uv_float3 &p2 = t2.a;
  const uv_float3 &q2 = t2.b;
  const uv_float3 &r2 = t2.c;

  // Locate p1 relative to each directed edge of t2.
  int o_p2q2 = orient2D(p2, q2, p1);
  int o_q2r2 = orient2D(q2, r2, p1);
  int o_r2p2 = orient2D(r2, p2, p1);

  // Dispatch on the sign pattern; permutations rotate t2 so the
  // negative edge(s) land where the R1/R2 tests expect them.
  bool intersecting = false;
  if (o_p2q2 >= 0) {
    if (o_q2r2 >= 0) {
      if (o_r2p2 >= 0) {
        // + + +
        intersecting = true;
      } else {
        // + + -
        intersecting = intersectionTypeR1(t1, t2);
      }
    } else {
      if (o_r2p2 >= 0) {
        // + - +
        permuteTriRight(t2);
        intersecting = intersectionTypeR1(t1, t2);
      } else {
        // + - -
        intersecting = intersectionTypeR2(t1, t2);
      }
    }
  } else {
    if (o_q2r2 >= 0) {
      if (o_r2p2 >= 0) {
        // - + +
        permuteTriLeft(t2);
        intersecting = intersectionTypeR1(t1, t2);
      } else {
        // - + -
        permuteTriLeft(t2);
        intersecting = intersectionTypeR2(t1, t2);
      }
    } else {
      if (o_r2p2 >= 0) {
        // - - +
        permuteTriRight(t2);
        intersecting = intersectionTypeR2(t1, t2);
      } else {
        // - - -
        // p1 cannot be on the negative side of all three CCW edges.
        std::cerr << "Triangles should not be flat." << std::endl;
        return false;
      }
    }
  }

  if (intersecting && target) {
    // Compute the overlap polygon in 2D ...
    clipTriangle(t1, t2, *target);

    _matrix.invert();
    // Apply the transform to each target point
    // ... then map the points back into the original 3D frame.
    for (int i = 0; i < target->size(); ++i) {
      apply_matrix4(target->at(i), _matrix);
    }
  }

  return intersecting;
}
645
+
646
+ // Helper function to calculate the area of a polygon
647
+ float polygon_area(const std::vector<uv_float3> &polygon) {
648
+ if (polygon.size() < 3)
649
+ return 0.0f; // Not a polygon
650
+
651
+ uv_float3 normal = {0.0f, 0.0f, 0.0f}; // Initialize normal vector
652
+
653
+ // Calculate the cross product of edges around the polygon
654
+ for (size_t i = 0; i < polygon.size(); ++i) {
655
+ uv_float3 p1 = polygon[i];
656
+ uv_float3 p2 = polygon[(i + 1) % polygon.size()];
657
+
658
+ normal = normal + cross(p1, p2); // Accumulate the normal vector
659
+ }
660
+
661
+ float area =
662
+ magnitude(normal) / 2.0f; // Area is half the magnitude of the normal
663
+ return area;
664
+ }
665
+
666
+ bool triangle_triangle_intersection(uv_float2 p1, uv_float2 q1, uv_float2 r1,
667
+ uv_float2 p2, uv_float2 q2, uv_float2 r2) {
668
+ Triangle t1(p1, q1, r1);
669
+ Triangle t2(p2, q2, r2);
670
+
671
+ if (isTriDegenerated(t1) || isTriDegenerated(t2)) {
672
+ // std::cerr << "Degenerated triangles provided, skipping." << std::endl;
673
+ return false;
674
+ }
675
+
676
+ int o1a = orient3D(t2.a, t2.b, t2.c, t1.a);
677
+ int o1b = orient3D(t2.a, t2.b, t2.c, t1.b);
678
+ int o1c = orient3D(t2.a, t2.b, t2.c, t1.c);
679
+
680
+ std::vector<uv_float3> intersections;
681
+ bool intersects;
682
+
683
+ if (o1a == o1b && o1a == o1c) // [[likely]]
684
+ {
685
+ intersects = o1a == 0 && coplanarIntersect(t1, t2, &intersections);
686
+ } else // [[unlikely]]
687
+ {
688
+ intersects = crossIntersect(t1, t2, o1a, o1b, o1c, &intersections);
689
+ }
690
+
691
+ if (intersects) {
692
+ float area = polygon_area(intersections);
693
+
694
+ // std::cout << "Intersection area: " << area << std::endl;
695
+ if (area < 1e-10f || std::isfinite(area) == false) {
696
+ // std::cout<<"Invalid area: " << area << std::endl;
697
+ return false; // Ignore intersection if the area is too small
698
+ }
699
+ }
700
+
701
+ return intersects;
702
+ }