Spaces:
Running
on
Zero
Running
on
Zero
xinjie.wang
commited on
Commit
·
5638c1f
1
Parent(s):
22afe09
update
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- README.md +1 -1
- embodied_gen/models/sam3d.py +3 -2
- embodied_gen/utils/monkey_patches.py +4 -8
- thirdparty/sam3d/sam3d/.gitignore +1 -0
- thirdparty/sam3d/sam3d/CODE_OF_CONDUCT.md +80 -0
- thirdparty/sam3d/sam3d/CONTRIBUTING.md +39 -0
- thirdparty/sam3d/sam3d/LICENSE +52 -0
- thirdparty/sam3d/sam3d/README.md +152 -0
- thirdparty/sam3d/sam3d/checkpoints/.gitignore +2 -0
- thirdparty/sam3d/sam3d/demo.py +21 -0
- thirdparty/sam3d/sam3d/doc/setup.md +58 -0
- thirdparty/sam3d/sam3d/environments/default.yml +216 -0
- thirdparty/sam3d/sam3d/notebook/demo_3db_mesh_alignment.ipynb +149 -0
- thirdparty/sam3d/sam3d/notebook/demo_multi_object.ipynb +162 -0
- thirdparty/sam3d/sam3d/notebook/demo_single_object.ipynb +164 -0
- thirdparty/sam3d/sam3d/notebook/inference.py +414 -0
- thirdparty/sam3d/sam3d/notebook/mesh_alignment.py +469 -0
- thirdparty/sam3d/sam3d/patching/hydra +16 -0
- thirdparty/sam3d/sam3d/pyproject.toml +30 -0
- thirdparty/sam3d/sam3d/requirements.dev.txt +4 -0
- thirdparty/sam3d/sam3d/requirements.inference.txt +4 -0
- thirdparty/sam3d/sam3d/requirements.p3d.txt +2 -0
- thirdparty/sam3d/sam3d/requirements.txt +88 -0
- thirdparty/sam3d/sam3d/sam3d_objects/__init__.py +6 -0
- thirdparty/sam3d/sam3d/sam3d_objects/config/__init__.py +1 -0
- thirdparty/sam3d/sam3d/sam3d_objects/config/utils.py +174 -0
- thirdparty/sam3d/sam3d/sam3d_objects/data/__init__.py +1 -0
- thirdparty/sam3d/sam3d/sam3d_objects/data/dataset/__init__.py +1 -0
- thirdparty/sam3d/sam3d/sam3d_objects/data/dataset/tdfy/__init__.py +1 -0
- thirdparty/sam3d/sam3d/sam3d_objects/data/dataset/tdfy/img_and_mask_transforms.py +986 -0
- thirdparty/sam3d/sam3d/sam3d_objects/data/dataset/tdfy/img_processing.py +189 -0
- thirdparty/sam3d/sam3d/sam3d_objects/data/dataset/tdfy/pose_target.py +784 -0
- thirdparty/sam3d/sam3d/sam3d_objects/data/dataset/tdfy/preprocessor.py +203 -0
- thirdparty/sam3d/sam3d/sam3d_objects/data/dataset/tdfy/transforms_3d.py +50 -0
- thirdparty/sam3d/sam3d/sam3d_objects/data/utils.py +243 -0
- thirdparty/sam3d/sam3d/sam3d_objects/model/__init__.py +1 -0
- thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/__init__.py +1 -0
- thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/dit/__init__.py +1 -0
- thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/dit/embedder/__init__.py +1 -0
- thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/dit/embedder/dino.py +142 -0
- thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/dit/embedder/embedder_fuser.py +238 -0
- thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/dit/embedder/point_remapper.py +78 -0
- thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/dit/embedder/pointmap.py +238 -0
- thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/generator/__init__.py +1 -0
- thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/generator/base.py +65 -0
- thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/generator/classifier_free_guidance.py +259 -0
- thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/generator/flow_matching/__init__.py +1 -0
- thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/generator/flow_matching/model.py +363 -0
- thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/generator/flow_matching/solver.py +126 -0
- thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/generator/shortcut/__init__.py +1 -0
README.md
CHANGED
|
@@ -10,7 +10,7 @@ pinned: false
|
|
| 10 |
license: apache-2.0
|
| 11 |
short_description: Generate physically plausible 3D model from single image.
|
| 12 |
paper: https://huggingface.co/papers/2506.10600
|
| 13 |
-
startup_duration_timeout:
|
| 14 |
---
|
| 15 |
|
| 16 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
| 10 |
license: apache-2.0
|
| 11 |
short_description: Generate physically plausible 3D model from single image.
|
| 12 |
paper: https://huggingface.co/papers/2506.10600
|
| 13 |
+
startup_duration_timeout: 4h
|
| 14 |
---
|
| 15 |
|
| 16 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
embodied_gen/models/sam3d.py
CHANGED
|
@@ -94,9 +94,10 @@ class Sam3dInference:
|
|
| 94 |
) -> dict:
|
| 95 |
if isinstance(image, Image.Image):
|
| 96 |
image = np.array(image)
|
|
|
|
| 97 |
return self.pipeline.run(
|
| 98 |
image,
|
| 99 |
-
|
| 100 |
seed,
|
| 101 |
stage1_only=False,
|
| 102 |
with_mesh_postprocess=False,
|
|
@@ -132,7 +133,7 @@ if __name__ == "__main__":
|
|
| 132 |
|
| 133 |
start = time()
|
| 134 |
|
| 135 |
-
output = pipeline(image, mask, seed=42)
|
| 136 |
print(f"Running cost: {round(time()-start, 1)}")
|
| 137 |
|
| 138 |
if torch.cuda.is_available():
|
|
|
|
| 94 |
) -> dict:
|
| 95 |
if isinstance(image, Image.Image):
|
| 96 |
image = np.array(image)
|
| 97 |
+
image = self.merge_mask_to_rgba(image, mask)
|
| 98 |
return self.pipeline.run(
|
| 99 |
image,
|
| 100 |
+
None,
|
| 101 |
seed,
|
| 102 |
stage1_only=False,
|
| 103 |
with_mesh_postprocess=False,
|
|
|
|
| 133 |
|
| 134 |
start = time()
|
| 135 |
|
| 136 |
+
output = pipeline.run(image, mask, seed=42)
|
| 137 |
print(f"Running cost: {round(time()-start, 1)}")
|
| 138 |
|
| 139 |
if torch.cuda.is_available():
|
embodied_gen/utils/monkey_patches.py
CHANGED
|
@@ -397,17 +397,13 @@ def monkey_patch_sam3d():
|
|
| 397 |
exc_info=True,
|
| 398 |
)
|
| 399 |
|
| 400 |
-
|
| 401 |
-
logger.info("Finished!")
|
| 402 |
-
|
| 403 |
-
return {
|
| 404 |
**ss_return_dict,
|
| 405 |
**outputs,
|
| 406 |
-
"pointmap": pts.cpu().permute((1, 2, 0)),
|
| 407 |
-
"pointmap_colors": pts_colors.cpu().permute(
|
| 408 |
-
(1, 2, 0)
|
| 409 |
-
), # HxWx3
|
| 410 |
}
|
|
|
|
| 411 |
|
| 412 |
InferencePipelinePointMap.run = patch_run
|
| 413 |
|
|
|
|
| 397 |
exc_info=True,
|
| 398 |
)
|
| 399 |
|
| 400 |
+
result = {
|
|
|
|
|
|
|
|
|
|
| 401 |
**ss_return_dict,
|
| 402 |
**outputs,
|
| 403 |
+
"pointmap": pts.cpu().permute((1, 2, 0)),
|
| 404 |
+
"pointmap_colors": pts_colors.cpu().permute((1, 2, 0)),
|
|
|
|
|
|
|
| 405 |
}
|
| 406 |
+
return result
|
| 407 |
|
| 408 |
InferencePipelinePointMap.run = patch_run
|
| 409 |
|
thirdparty/sam3d/sam3d/.gitignore
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
__pycache__
|
thirdparty/sam3d/sam3d/CODE_OF_CONDUCT.md
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Code of Conduct
|
| 2 |
+
|
| 3 |
+
## Our Pledge
|
| 4 |
+
|
| 5 |
+
In the interest of fostering an open and welcoming environment, we as
|
| 6 |
+
contributors and maintainers pledge to make participation in our project and
|
| 7 |
+
our community a harassment-free experience for everyone, regardless of age, body
|
| 8 |
+
size, disability, ethnicity, sex characteristics, gender identity and expression,
|
| 9 |
+
level of experience, education, socio-economic status, nationality, personal
|
| 10 |
+
appearance, race, religion, or sexual identity and orientation.
|
| 11 |
+
|
| 12 |
+
## Our Standards
|
| 13 |
+
|
| 14 |
+
Examples of behavior that contributes to creating a positive environment
|
| 15 |
+
include:
|
| 16 |
+
|
| 17 |
+
* Using welcoming and inclusive language
|
| 18 |
+
* Being respectful of differing viewpoints and experiences
|
| 19 |
+
* Gracefully accepting constructive criticism
|
| 20 |
+
* Focusing on what is best for the community
|
| 21 |
+
* Showing empathy towards other community members
|
| 22 |
+
|
| 23 |
+
Examples of unacceptable behavior by participants include:
|
| 24 |
+
|
| 25 |
+
* The use of sexualized language or imagery and unwelcome sexual attention or
|
| 26 |
+
advances
|
| 27 |
+
* Trolling, insulting/derogatory comments, and personal or political attacks
|
| 28 |
+
* Public or private harassment
|
| 29 |
+
* Publishing others' private information, such as a physical or electronic
|
| 30 |
+
address, without explicit permission
|
| 31 |
+
* Other conduct which could reasonably be considered inappropriate in a
|
| 32 |
+
professional setting
|
| 33 |
+
|
| 34 |
+
## Our Responsibilities
|
| 35 |
+
|
| 36 |
+
Project maintainers are responsible for clarifying the standards of acceptable
|
| 37 |
+
behavior and are expected to take appropriate and fair corrective action in
|
| 38 |
+
response to any instances of unacceptable behavior.
|
| 39 |
+
|
| 40 |
+
Project maintainers have the right and responsibility to remove, edit, or
|
| 41 |
+
reject comments, commits, code, wiki edits, issues, and other contributions
|
| 42 |
+
that are not aligned to this Code of Conduct, or to ban temporarily or
|
| 43 |
+
permanently any contributor for other behaviors that they deem inappropriate,
|
| 44 |
+
threatening, offensive, or harmful.
|
| 45 |
+
|
| 46 |
+
## Scope
|
| 47 |
+
|
| 48 |
+
This Code of Conduct applies within all project spaces, and it also applies when
|
| 49 |
+
an individual is representing the project or its community in public spaces.
|
| 50 |
+
Examples of representing a project or community include using an official
|
| 51 |
+
project e-mail address, posting via an official social media account, or acting
|
| 52 |
+
as an appointed representative at an online or offline event. Representation of
|
| 53 |
+
a project may be further defined and clarified by project maintainers.
|
| 54 |
+
|
| 55 |
+
This Code of Conduct also applies outside the project spaces when there is a
|
| 56 |
+
reasonable belief that an individual's behavior may have a negative impact on
|
| 57 |
+
the project or its community.
|
| 58 |
+
|
| 59 |
+
## Enforcement
|
| 60 |
+
|
| 61 |
+
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
| 62 |
+
reported by contacting the project team at <opensource-conduct@meta.com>. All
|
| 63 |
+
complaints will be reviewed and investigated and will result in a response that
|
| 64 |
+
is deemed necessary and appropriate to the circumstances. The project team is
|
| 65 |
+
obligated to maintain confidentiality with regard to the reporter of an incident.
|
| 66 |
+
Further details of specific enforcement policies may be posted separately.
|
| 67 |
+
|
| 68 |
+
Project maintainers who do not follow or enforce the Code of Conduct in good
|
| 69 |
+
faith may face temporary or permanent repercussions as determined by other
|
| 70 |
+
members of the project's leadership.
|
| 71 |
+
|
| 72 |
+
## Attribution
|
| 73 |
+
|
| 74 |
+
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
|
| 75 |
+
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
|
| 76 |
+
|
| 77 |
+
[homepage]: https://www.contributor-covenant.org
|
| 78 |
+
|
| 79 |
+
For answers to common questions about this code of conduct, see
|
| 80 |
+
https://www.contributor-covenant.org/faq
|
thirdparty/sam3d/sam3d/CONTRIBUTING.md
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Contributing to sam-3d-objects
|
| 2 |
+
We want to make contributing to this project as easy and transparent as
|
| 3 |
+
possible.
|
| 4 |
+
|
| 5 |
+
## Our Development Process
|
| 6 |
+
... (in particular how this is synced with internal changes to the project)
|
| 7 |
+
|
| 8 |
+
## Pull Requests
|
| 9 |
+
We actively welcome your pull requests.
|
| 10 |
+
|
| 11 |
+
1. Fork the repo and create your branch from `main`.
|
| 12 |
+
2. If you've added code that should be tested, add tests.
|
| 13 |
+
3. If you've changed APIs, update the documentation.
|
| 14 |
+
4. Ensure the test suite passes.
|
| 15 |
+
5. Make sure your code lints.
|
| 16 |
+
6. If you haven't already, complete the Contributor License Agreement ("CLA").
|
| 17 |
+
|
| 18 |
+
## Contributor License Agreement ("CLA")
|
| 19 |
+
In order to accept your pull request, we need you to submit a CLA. You only need
|
| 20 |
+
to do this once to work on any of Meta's open source projects.
|
| 21 |
+
|
| 22 |
+
Complete your CLA here: <https://code.facebook.com/cla>
|
| 23 |
+
|
| 24 |
+
## Issues
|
| 25 |
+
We use GitHub issues to track public bugs. Please ensure your description is
|
| 26 |
+
clear and has sufficient instructions to be able to reproduce the issue.
|
| 27 |
+
|
| 28 |
+
Meta has a [bounty program](https://bugbounty.meta.com/) for the safe
|
| 29 |
+
disclosure of security bugs. In those cases, please go through the process
|
| 30 |
+
outlined on that page and do not file a public issue.
|
| 31 |
+
|
| 32 |
+
## Coding Style
|
| 33 |
+
* 2 spaces for indentation rather than tabs
|
| 34 |
+
* 80 character line length
|
| 35 |
+
* ...
|
| 36 |
+
|
| 37 |
+
## License
|
| 38 |
+
By contributing to sam-3d-objects, you agree that your contributions will be licensed
|
| 39 |
+
under the LICENSE file in the root directory of this source tree.
|
thirdparty/sam3d/sam3d/LICENSE
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
SAM License
|
| 2 |
+
Last Updated: November 19, 2025
|
| 3 |
+
|
| 4 |
+
“Agreement” means the terms and conditions for use, reproduction, distribution and modification of the SAM Materials set forth herein.
|
| 5 |
+
|
| 6 |
+
“SAM Materials” means, collectively, Documentation and the models, software and algorithms, including machine-learning model code, trained model weights, inference-enabling code, training-enabling code, fine-tuning enabling code, and other elements of the foregoing distributed by Meta and made available under this Agreement.
|
| 7 |
+
|
| 8 |
+
“Documentation” means the specifications, manuals and documentation accompanying
|
| 9 |
+
SAM Materials distributed by Meta.
|
| 10 |
+
|
| 11 |
+
“Licensee” or “you” means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity’s behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf.
|
| 12 |
+
|
| 13 |
+
“Meta” or “we” means Meta Platforms Ireland Limited (if you are located in or, if you are an entity, your principal place of business is in the EEA or Switzerland) or Meta Platforms, Inc. (if you are located outside of the EEA or Switzerland).
|
| 14 |
+
|
| 15 |
+
“Sanctions” means any economic or trade sanctions or restrictions administered or enforced by the United States (including the Office of Foreign Assets Control of the U.S. Department of the Treasury (“OFAC”), the U.S. Department of State and the U.S. Department of Commerce), the United Nations, the European Union, or the United Kingdom.
|
| 16 |
+
|
| 17 |
+
“Trade Controls” means any of the following: Sanctions and applicable export and import controls.
|
| 18 |
+
|
| 19 |
+
By using or distributing any portion or element of the SAM Materials, you agree to be bound by this Agreement.
|
| 20 |
+
|
| 21 |
+
1. License Rights and Redistribution.
|
| 22 |
+
|
| 23 |
+
a. Grant of Rights. You are granted a non-exclusive, worldwide, non-transferable and royalty-free limited license under Meta’s intellectual property or other rights owned by Meta embodied in the SAM Materials to use, reproduce, distribute, copy, create derivative works of, and make modifications to the SAM Materials.
|
| 24 |
+
|
| 25 |
+
i. Grant of Patent License. Subject to the terms and conditions of this License, you are granted a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by Meta that are necessarily infringed alone or by combination of their contribution(s) with the SAM 3 Materials. If you institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the SAM 3 Materials incorporated within the work constitutes direct or contributory patent infringement, then any patent licenses granted to you under this License for that work shall terminate as of the date such litigation is filed.
|
| 26 |
+
|
| 27 |
+
b. Redistribution and Use.
|
| 28 |
+
|
| 29 |
+
i. Distribution of SAM Materials, and any derivative works thereof, are subject to the terms of this Agreement. If you distribute or make the SAM Materials, or any derivative works thereof, available to a third party, you may only do so under the terms of this Agreement and you shall provide a copy of this Agreement with any such SAM Materials.
|
| 30 |
+
|
| 31 |
+
ii. If you submit for publication the results of research you perform on, using, or otherwise in connection with SAM Materials, you must acknowledge the use of SAM Materials in your publication.
|
| 32 |
+
|
| 33 |
+
iii. Your use of the SAM Materials must comply with applicable laws and regulations, including Trade Control Laws and applicable privacy and data protection laws.
|
| 34 |
+
iv. Your use of the SAM Materials will not involve or encourage others to reverse engineer, decompile or discover the underlying components of the SAM Materials.
|
| 35 |
+
v. You are not the target of Trade Controls and your use of SAM Materials must comply with Trade Controls. You agree not to use, or permit others to use, SAM Materials for any activities subject to the International Traffic in Arms Regulations (ITAR) or end uses prohibited by Trade Controls, including those related to military or warfare purposes, nuclear industries or applications, espionage, or the development or use of guns or illegal weapons.
|
| 36 |
+
2. User Support. Your use of the SAM Materials is done at your own discretion; Meta does not process any information nor provide any service in relation to such use. Meta is under no obligation to provide any support services for the SAM Materials. Any support provided is “as is”, “with all faults”, and without warranty of any kind.
|
| 37 |
+
|
| 38 |
+
3. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE SAM MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN “AS IS” BASIS, WITHOUT WARRANTIES OF ANY KIND, AND META DISCLAIMS ALL WARRANTIES OF ANY KIND, BOTH EXPRESS AND IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE SAM MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE SAM MATERIALS AND ANY OUTPUT AND RESULTS.
|
| 39 |
+
|
| 40 |
+
4. Limitation of Liability. IN NO EVENT WILL META OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT OR INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
|
| 41 |
+
|
| 42 |
+
5. Intellectual Property.
|
| 43 |
+
|
| 44 |
+
a. Subject to Meta’s ownership of SAM Materials and derivatives made by or for Meta, with respect to any derivative works and modifications of the SAM Materials that are made by you, as between you and Meta, you are and will be the owner of such derivative works and modifications.
|
| 45 |
+
|
| 46 |
+
b. If you institute litigation or other proceedings against Meta or any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the SAM Materials, outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Meta from and against any claim by any third party arising out of or related to your use or distribution of the SAM Materials.
|
| 47 |
+
|
| 48 |
+
6. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the SAM Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Meta may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of the SAM Materials. Sections 3, 4 and 7 shall survive the termination of this Agreement.
|
| 49 |
+
|
| 50 |
+
7. Governing Law and Jurisdiction. This Agreement will be governed and construed under the laws of the State of California without regard to choice of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement. The courts of California shall have exclusive jurisdiction of any dispute arising out of this Agreement.
|
| 51 |
+
|
| 52 |
+
8. Modifications and Amendments. Meta may modify this Agreement from time to time; provided that they are similar in spirit to the current version of the Agreement, but may differ in detail to address new problems or concerns. All such changes will be effective immediately. Your continued use of the SAM Materials after any modification to this Agreement constitutes your agreement to such modification. Except as provided in this Agreement, no modification or addition to any provision of this Agreement will be binding unless it is in writing and signed by an authorized representative of both you and Meta.
|
thirdparty/sam3d/sam3d/README.md
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SAM 3D
|
| 2 |
+
|
| 3 |
+
SAM 3D Objects is one part of SAM 3D, a pair of models for object and human mesh reconstruction. If you’re looking for SAM 3D Body, [click here](https://github.com/facebookresearch/sam-3d-body).
|
| 4 |
+
|
| 5 |
+
# SAM 3D Objects
|
| 6 |
+
|
| 7 |
+
**SAM 3D Team**, [Xingyu Chen](https://scholar.google.com/citations?user=gjSHr6YAAAAJ&hl=en&oi=sra)\*, [Fu-Jen Chu](https://fujenchu.github.io/)\*, [Pierre Gleize](https://scholar.google.com/citations?user=4imOcw4AAAAJ&hl=en&oi=ao)\*, [Kevin J Liang](https://kevinjliang.github.io/)\*, [Alexander Sax](https://alexsax.github.io/)\*, [Hao Tang](https://scholar.google.com/citations?user=XY6Nh9YAAAAJ&hl=en&oi=sra)\*, [Weiyao Wang](https://sites.google.com/view/weiyaowang/home)\*, [Michelle Guo](https://scholar.google.com/citations?user=lyjjpNMAAAAJ&hl=en&oi=ao), [Thibaut Hardin](https://github.com/Thibaut-H), [Xiang Li](https://ryanxli.github.io/)⚬, [Aohan Lin](https://github.com/linaohan), [Jia-Wei Liu](https://jia-wei-liu.github.io/), [Ziqi Ma](https://ziqi-ma.github.io/)⚬, [Anushka Sagar](https://www.linkedin.com/in/anushkasagar/), [Bowen Song](https://scholar.google.com/citations?user=QQKVkfcAAAAJ&hl=en&oi=sra)⚬, [Xiaodong Wang](https://scholar.google.com/citations?authuser=2&user=rMpcFYgAAAAJ), [Jianing Yang](https://jedyang.com/)⚬, [Bowen Zhang](http://home.ustc.edu.cn/~zhangbowen/)⚬, [Piotr Dollár](https://pdollar.github.io/)†, [Georgia Gkioxari](https://georgiagkioxari.com/)†, [Matt Feiszli](https://scholar.google.com/citations?user=A-wA73gAAAAJ&hl=en&oi=ao)†§, [Jitendra Malik](https://people.eecs.berkeley.edu/~malik/)†§
|
| 8 |
+
|
| 9 |
+
***Meta Superintelligence Labs***
|
| 10 |
+
|
| 11 |
+
*Core contributor (Alphabetical, Equal Contribution), ⚬Intern, †Project leads, §Equal Contribution
|
| 12 |
+
|
| 13 |
+
[[`Paper`](https://ai.meta.com/research/publications/sam-3d-3dfy-anything-in-images/)] [[`Code`](https://github.com/facebookresearch/sam-3d-objects)] [[`Website`](https://ai.meta.com/sam3d/)] [[`Demo`](https://www.aidemos.meta.com/segment-anything/editor/convert-image-to-3d)] [[`Blog`](https://ai.meta.com/blog/sam-3d/)] [[`BibTeX`](#citing-sam-3d-objects)] [[`Roboflow`](https://blog.roboflow.com/sam-3d/)]
|
| 14 |
+
|
| 15 |
+
**SAM 3D Objects** is a foundation model that reconstructs full 3D shape geometry, texture, and layout from a single image, excelling in real-world scenarios with occlusion and clutter by using progressive training and a data engine with human feedback. It outperforms prior 3D generation models in human preference tests on real-world objects and scenes. We released code, weights, online demo, and a new challenging benchmark.
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
<p align="center"><img src="doc/intro.png"/></p>
|
| 19 |
+
|
| 20 |
+
-----
|
| 21 |
+
|
| 22 |
+
<p align="center"><img src="doc/arch.png"/></p>
|
| 23 |
+
|
| 24 |
+
## Latest updates
|
| 25 |
+
|
| 26 |
+
**11/19/2025** - Checkpoints Launched, Web Demo and Paper are out.
|
| 27 |
+
|
| 28 |
+
## Installation
|
| 29 |
+
|
| 30 |
+
Follow the [setup](doc/setup.md) steps before running the following.
|
| 31 |
+
|
| 32 |
+
## Single or Multi-Object 3D Generation
|
| 33 |
+
|
| 34 |
+
SAM 3D Objects can convert masked objects in an image, into 3D models with pose, shape, texture, and layout. SAM 3D is designed to be robust in challenging natural images, handling small objects and occlusions, unusual poses, and difficult situations encountered in uncurated natural scenes like this kidsroom:
|
| 35 |
+
|
| 36 |
+
<p align="center">
|
| 37 |
+
<img src="notebook/images/shutterstock_stylish_kidsroom_1640806567/image.png" width="55%"/>
|
| 38 |
+
<img src="doc/kidsroom_transparent.gif" width="40%"/>
|
| 39 |
+
</p>
|
| 40 |
+
|
| 41 |
+
For a quick start, run `python demo.py` or use the the following lines of code:
|
| 42 |
+
|
| 43 |
+
```python
|
| 44 |
+
import sys
|
| 45 |
+
|
| 46 |
+
# import inference code
|
| 47 |
+
sys.path.append("notebook")
|
| 48 |
+
from inference import Inference, load_image, load_single_mask
|
| 49 |
+
|
| 50 |
+
# load model
|
| 51 |
+
tag = "hf"
|
| 52 |
+
config_path = f"checkpoints/{tag}/pipeline.yaml"
|
| 53 |
+
inference = Inference(config_path, compile=False)
|
| 54 |
+
|
| 55 |
+
# load image and mask
|
| 56 |
+
image = load_image("notebook/images/shutterstock_stylish_kidsroom_1640806567/image.png")
|
| 57 |
+
mask = load_single_mask("notebook/images/shutterstock_stylish_kidsroom_1640806567", index=14)
|
| 58 |
+
|
| 59 |
+
# run model
|
| 60 |
+
output = inference(image, mask, seed=42)
|
| 61 |
+
|
| 62 |
+
# export gaussian splat
|
| 63 |
+
output["gs"].save_ply(f"splat.ply")
|
| 64 |
+
```
|
| 65 |
+
|
| 66 |
+
For more details and multi-object reconstruction, please take a look at out two jupyter notebooks:
|
| 67 |
+
* [single object](notebook/demo_single_object.ipynb)
|
| 68 |
+
* [multi object](notebook/demo_multi_object.ipynb)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
## SAM 3D Body
|
| 72 |
+
|
| 73 |
+
[SAM 3D Body (3DB)](https://github.com/facebookresearch/sam-3d-body) is a robust promptable foundation model for single-image 3D human mesh recovery (HMR).
|
| 74 |
+
|
| 75 |
+
As a way to combine the strengths of both **SAM 3D Objects** and **SAM 3D Body**, we provide an example notebook that demonstrates how to combine the results of both models such that they are aligned in the same frame of reference. Check it out [here](notebook/demo_3db_mesh_alignment.ipynb).
|
| 76 |
+
|
| 77 |
+
## License
|
| 78 |
+
|
| 79 |
+
The SAM 3D Objects model checkpoints and code are licensed under [SAM License](./LICENSE).
|
| 80 |
+
|
| 81 |
+
## Contributing
|
| 82 |
+
|
| 83 |
+
See [contributing](CONTRIBUTING.md) and the [code of conduct](CODE_OF_CONDUCT.md).
|
| 84 |
+
|
| 85 |
+
## Contributors
|
| 86 |
+
|
| 87 |
+
The SAM 3D Objects project was made possible with the help of many contributors.
|
| 88 |
+
|
| 89 |
+
Robbie Adkins,
|
| 90 |
+
Paris Baptiste,
|
| 91 |
+
Karen Bergan,
|
| 92 |
+
Kai Brown,
|
| 93 |
+
Michelle Chan,
|
| 94 |
+
Ida Cheng,
|
| 95 |
+
Khadijat Durojaiye,
|
| 96 |
+
Patrick Edwards,
|
| 97 |
+
Daniella Factor,
|
| 98 |
+
Facundo Figueroa,
|
| 99 |
+
Rene de la Fuente,
|
| 100 |
+
Eva Galper,
|
| 101 |
+
Cem Gokmen,
|
| 102 |
+
Alex He,
|
| 103 |
+
Enmanuel Hernandez,
|
| 104 |
+
Dex Honsa,
|
| 105 |
+
Leonna Jones,
|
| 106 |
+
Arpit Kalla,
|
| 107 |
+
Kris Kitani,
|
| 108 |
+
Helen Klein,
|
| 109 |
+
Kei Koyama,
|
| 110 |
+
Robert Kuo,
|
| 111 |
+
Vivian Lee,
|
| 112 |
+
Alex Lende,
|
| 113 |
+
Jonny Li,
|
| 114 |
+
Kehan Lyu,
|
| 115 |
+
Faye Ma,
|
| 116 |
+
Mallika Malhotra,
|
| 117 |
+
Sasha Mitts,
|
| 118 |
+
William Ngan,
|
| 119 |
+
George Orlin,
|
| 120 |
+
Peter Park,
|
| 121 |
+
Don Pinkus,
|
| 122 |
+
Roman Radle,
|
| 123 |
+
Nikhila Ravi,
|
| 124 |
+
Azita Shokrpour,
|
| 125 |
+
Jasmine Shone,
|
| 126 |
+
Zayida Suber,
|
| 127 |
+
Phillip Thomas,
|
| 128 |
+
Tatum Turner,
|
| 129 |
+
Joseph Walker,
|
| 130 |
+
Meng Wang,
|
| 131 |
+
Claudette Ward,
|
| 132 |
+
Andrew Westbury,
|
| 133 |
+
Lea Wilken,
|
| 134 |
+
Nan Yang,
|
| 135 |
+
Yael Yungster
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
## Citing SAM 3D Objects
|
| 139 |
+
|
| 140 |
+
If you use SAM 3D Objects in your research, please use the following BibTeX entry.
|
| 141 |
+
|
| 142 |
+
```
|
| 143 |
+
@article{sam3dteam2025sam3d3dfyimages,
|
| 144 |
+
title={SAM 3D: 3Dfy Anything in Images},
|
| 145 |
+
author={SAM 3D Team and Xingyu Chen and Fu-Jen Chu and Pierre Gleize and Kevin J Liang and Alexander Sax and Hao Tang and Weiyao Wang and Michelle Guo and Thibaut Hardin and Xiang Li and Aohan Lin and Jiawei Liu and Ziqi Ma and Anushka Sagar and Bowen Song and Xiaodong Wang and Jianing Yang and Bowen Zhang and Piotr Dollár and Georgia Gkioxari and Matt Feiszli and Jitendra Malik},
|
| 146 |
+
year={2025},
|
| 147 |
+
eprint={2511.16624},
|
| 148 |
+
archivePrefix={arXiv},
|
| 149 |
+
primaryClass={cs.CV},
|
| 150 |
+
url={https://arxiv.org/abs/2511.16624},
|
| 151 |
+
}
|
| 152 |
+
```
|
thirdparty/sam3d/sam3d/checkpoints/.gitignore
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*
|
| 2 |
+
!.gitignore
|
thirdparty/sam3d/sam3d/demo.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
|
| 3 |
+
# import inference code
|
| 4 |
+
sys.path.append("notebook")
|
| 5 |
+
from inference import Inference, load_image, load_single_mask
|
| 6 |
+
|
| 7 |
+
# load model
|
| 8 |
+
tag = "hf"
|
| 9 |
+
config_path = f"checkpoints/{tag}/pipeline.yaml"
|
| 10 |
+
inference = Inference(config_path, compile=False)
|
| 11 |
+
|
| 12 |
+
# load image (RGBA only, mask is embedded in the alpha channel)
|
| 13 |
+
image = load_image("notebook/images/shutterstock_stylish_kidsroom_1640806567/image.png")
|
| 14 |
+
mask = load_single_mask("notebook/images/shutterstock_stylish_kidsroom_1640806567", index=14)
|
| 15 |
+
|
| 16 |
+
# run model
|
| 17 |
+
output = inference(image, mask, seed=42)
|
| 18 |
+
|
| 19 |
+
# export gaussian splat
|
| 20 |
+
output["gs"].save_ply(f"splat.ply")
|
| 21 |
+
print("Your reconstruction has been saved to splat.ply")
|
thirdparty/sam3d/sam3d/doc/setup.md
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Setup
|
| 2 |
+
|
| 3 |
+
## Prerequisites
|
| 4 |
+
|
| 5 |
+
* A linux 64-bits architecture (i.e. `linux-64` platform in `mamba info`).
|
| 6 |
+
* A NVIDIA GPU with at least 32 Gb of VRAM.
|
| 7 |
+
|
| 8 |
+
## 1. Setup Python Environment
|
| 9 |
+
|
| 10 |
+
The following will install the default environment. If you use `conda` instead of `mamba`, replace its name in the first two lines. Note that you may have to build the environment on a compute node with GPU (e.g., you may get a `RuntimeError: Not compiled with GPU support` error when running certain parts of the code that use Pytorch3D).
|
| 11 |
+
|
| 12 |
+
```bash
|
| 13 |
+
# create sam3d-objects environment
|
| 14 |
+
mamba env create -f environments/default.yml
|
| 15 |
+
mamba activate sam3d-objects
|
| 16 |
+
|
| 17 |
+
# for pytorch/cuda dependencies
|
| 18 |
+
export PIP_EXTRA_INDEX_URL="https://pypi.ngc.nvidia.com https://download.pytorch.org/whl/cu121"
|
| 19 |
+
|
| 20 |
+
# install sam3d-objects and core dependencies
|
| 21 |
+
pip install -e '.[dev]'
|
| 22 |
+
pip install -e '.[p3d]' # pytorch3d dependency on pytorch is broken, this 2-step approach solves it
|
| 23 |
+
|
| 24 |
+
# for inference
|
| 25 |
+
export PIP_FIND_LINKS="https://nvidia-kaolin.s3.us-east-2.amazonaws.com/torch-2.5.1_cu121.html"
|
| 26 |
+
pip install -e '.[inference]'
|
| 27 |
+
|
| 28 |
+
# patch things that aren't yet in official pip packages
|
| 29 |
+
./patching/hydra # https://github.com/facebookresearch/hydra/pull/2863
|
| 30 |
+
```
|
| 31 |
+
|
| 32 |
+
## 2. Getting Checkpoints
|
| 33 |
+
|
| 34 |
+
### From HuggingFace
|
| 35 |
+
|
| 36 |
+
⚠️ Before using SAM 3D Objects, please request access to the checkpoints on the SAM 3D Objects
|
| 37 |
+
Hugging Face [repo](https://huggingface.co/facebook/sam-3d-objects). Once accepted, you
|
| 38 |
+
need to be authenticated to download the checkpoints. You can do this by running
|
| 39 |
+
the following [steps](https://huggingface.co/docs/huggingface_hub/en/quick-start#authentication)
|
| 40 |
+
(e.g. `hf auth login` after generating an access token).
|
| 41 |
+
|
| 42 |
+
⚠️ SAM 3D Objects is available via HuggingFace globally, **except** in comprehensively sanctioned jurisdictions.
|
| 43 |
+
Sanctioned jurisdiction will result in requests being **rejected**.
|
| 44 |
+
|
| 45 |
+
```bash
|
| 46 |
+
pip install 'huggingface-hub[cli]<1.0'
|
| 47 |
+
|
| 48 |
+
TAG=hf
|
| 49 |
+
hf download \
|
| 50 |
+
--repo-type model \
|
| 51 |
+
--local-dir checkpoints/${TAG}-download \
|
| 52 |
+
--max-workers 1 \
|
| 53 |
+
facebook/sam-3d-objects
|
| 54 |
+
mv checkpoints/${TAG}-download/checkpoints checkpoints/${TAG}
|
| 55 |
+
rm -rf checkpoints/${TAG}-download
|
| 56 |
+
```
|
| 57 |
+
|
| 58 |
+
|
thirdparty/sam3d/sam3d/environments/default.yml
ADDED
|
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: sam3d-objects
|
| 2 |
+
channels:
|
| 3 |
+
- conda-forge
|
| 4 |
+
dependencies:
|
| 5 |
+
- _libgcc_mutex=0.1=conda_forge
|
| 6 |
+
- _openmp_mutex=4.5=2_gnu
|
| 7 |
+
- alsa-lib=1.2.13=hb9d3cd8_0
|
| 8 |
+
- attr=2.5.1=h166bdaf_1
|
| 9 |
+
- binutils=2.43=h4852527_4
|
| 10 |
+
- binutils_impl_linux-64=2.43=h4bf12b8_4
|
| 11 |
+
- binutils_linux-64=2.43=h4852527_4
|
| 12 |
+
- bzip2=1.0.8=h4bc722e_7
|
| 13 |
+
- c-compiler=1.7.0=hd590300_1
|
| 14 |
+
- ca-certificates=2025.1.31=hbcca054_0
|
| 15 |
+
- cairo=1.18.0=h3faef2a_0
|
| 16 |
+
- cuda-cccl=12.1.109=ha770c72_0
|
| 17 |
+
- cuda-cccl-impl=2.0.1=ha770c72_1
|
| 18 |
+
- cuda-cccl_linux-64=12.1.109=ha770c72_0
|
| 19 |
+
- cuda-command-line-tools=12.1.1=ha770c72_0
|
| 20 |
+
- cuda-compiler=12.1.1=hbad6d8a_0
|
| 21 |
+
- cuda-cudart=12.1.105=hd3aeb46_0
|
| 22 |
+
- cuda-cudart-dev=12.1.105=hd3aeb46_0
|
| 23 |
+
- cuda-cudart-dev_linux-64=12.1.105=h59595ed_0
|
| 24 |
+
- cuda-cudart-static=12.1.105=hd3aeb46_0
|
| 25 |
+
- cuda-cudart-static_linux-64=12.1.105=h59595ed_0
|
| 26 |
+
- cuda-cudart_linux-64=12.1.105=h59595ed_0
|
| 27 |
+
- cuda-cuobjdump=12.1.111=h59595ed_0
|
| 28 |
+
- cuda-cupti=12.1.105=h59595ed_0
|
| 29 |
+
- cuda-cupti-dev=12.1.105=h59595ed_0
|
| 30 |
+
- cuda-cuxxfilt=12.1.105=h59595ed_0
|
| 31 |
+
- cuda-driver-dev=12.1.105=hd3aeb46_0
|
| 32 |
+
- cuda-driver-dev_linux-64=12.1.105=h59595ed_0
|
| 33 |
+
- cuda-gdb=12.1.105=hd47b8d6_0
|
| 34 |
+
- cuda-libraries=12.1.1=ha770c72_0
|
| 35 |
+
- cuda-libraries-dev=12.1.1=ha770c72_0
|
| 36 |
+
- cuda-nsight=12.1.105=ha770c72_0
|
| 37 |
+
- cuda-nvcc=12.1.105=hcdd1206_1
|
| 38 |
+
- cuda-nvcc-dev_linux-64=12.1.105=ha770c72_0
|
| 39 |
+
- cuda-nvcc-impl=12.1.105=hd3aeb46_0
|
| 40 |
+
- cuda-nvcc-tools=12.1.105=hd3aeb46_0
|
| 41 |
+
- cuda-nvcc_linux-64=12.1.105=h8a487aa_1
|
| 42 |
+
- cuda-nvdisasm=12.1.105=h59595ed_0
|
| 43 |
+
- cuda-nvml-dev=12.1.105=h59595ed_0
|
| 44 |
+
- cuda-nvprof=12.1.105=h59595ed_0
|
| 45 |
+
- cuda-nvprune=12.1.105=h59595ed_0
|
| 46 |
+
- cuda-nvrtc=12.1.105=hd3aeb46_0
|
| 47 |
+
- cuda-nvrtc-dev=12.1.105=hd3aeb46_0
|
| 48 |
+
- cuda-nvtx=12.1.105=h59595ed_0
|
| 49 |
+
- cuda-nvvp=12.1.105=h59595ed_0
|
| 50 |
+
- cuda-opencl=12.1.105=h59595ed_0
|
| 51 |
+
- cuda-opencl-dev=12.1.105=h59595ed_0
|
| 52 |
+
- cuda-profiler-api=12.1.105=ha770c72_0
|
| 53 |
+
- cuda-sanitizer-api=12.1.105=h59595ed_0
|
| 54 |
+
- cuda-toolkit=12.1.1=ha804496_0
|
| 55 |
+
- cuda-tools=12.1.1=ha770c72_0
|
| 56 |
+
- cuda-version=12.1=h1d6eff3_3
|
| 57 |
+
- cuda-visual-tools=12.1.1=ha770c72_0
|
| 58 |
+
- cxx-compiler=1.7.0=h00ab1b0_1
|
| 59 |
+
- dbus=1.13.6=h5008d03_3
|
| 60 |
+
- expat=2.6.4=h5888daf_0
|
| 61 |
+
- font-ttf-dejavu-sans-mono=2.37=hab24e00_0
|
| 62 |
+
- font-ttf-inconsolata=3.000=h77eed37_0
|
| 63 |
+
- font-ttf-source-code-pro=2.038=h77eed37_0
|
| 64 |
+
- font-ttf-ubuntu=0.83=h77eed37_3
|
| 65 |
+
- fontconfig=2.15.0=h7e30c49_1
|
| 66 |
+
- fonts-conda-ecosystem=1=0
|
| 67 |
+
- fonts-conda-forge=1=0
|
| 68 |
+
- freetype=2.13.3=h48d6fc4_0
|
| 69 |
+
- gcc=12.4.0=h236703b_2
|
| 70 |
+
- gcc_impl_linux-64=12.4.0=h26ba24d_2
|
| 71 |
+
- gcc_linux-64=12.4.0=h6b7512a_8
|
| 72 |
+
- gds-tools=1.6.1.9=hd3aeb46_0
|
| 73 |
+
- gettext=0.23.1=h5888daf_0
|
| 74 |
+
- gettext-tools=0.23.1=h5888daf_0
|
| 75 |
+
- glib=2.82.2=h07242d1_1
|
| 76 |
+
- glib-tools=2.82.2=h4833e2c_1
|
| 77 |
+
- gmp=6.3.0=hac33072_2
|
| 78 |
+
- graphite2=1.3.13=h59595ed_1003
|
| 79 |
+
- gst-plugins-base=1.24.4=h9ad1361_0
|
| 80 |
+
- gstreamer=1.24.4=haf2f30d_0
|
| 81 |
+
- gxx=12.4.0=h236703b_2
|
| 82 |
+
- gxx_impl_linux-64=12.4.0=h3ff227c_2
|
| 83 |
+
- gxx_linux-64=12.4.0=h8489865_8
|
| 84 |
+
- harfbuzz=8.5.0=hfac3d4d_0
|
| 85 |
+
- icu=73.2=h59595ed_0
|
| 86 |
+
- kernel-headers_linux-64=3.10.0=he073ed8_18
|
| 87 |
+
- keyutils=1.6.1=h166bdaf_0
|
| 88 |
+
- krb5=1.21.3=h659f571_0
|
| 89 |
+
- lame=3.100=h166bdaf_1003
|
| 90 |
+
- ld_impl_linux-64=2.43=h712a8e2_4
|
| 91 |
+
- libasprintf=0.23.1=h8e693c7_0
|
| 92 |
+
- libasprintf-devel=0.23.1=h8e693c7_0
|
| 93 |
+
- libcap=2.75=h39aace5_0
|
| 94 |
+
- libclang-cpp15=15.0.7=default_h127d8a8_5
|
| 95 |
+
- libclang13=19.1.2=default_h9c6a7e4_1
|
| 96 |
+
- libcublas=12.1.3.1=hd3aeb46_0
|
| 97 |
+
- libcublas-dev=12.1.3.1=hd3aeb46_0
|
| 98 |
+
- libcufft=11.0.2.54=hd3aeb46_0
|
| 99 |
+
- libcufft-dev=11.0.2.54=hd3aeb46_0
|
| 100 |
+
- libcufile=1.6.1.9=hd3aeb46_0
|
| 101 |
+
- libcufile-dev=1.6.1.9=hd3aeb46_0
|
| 102 |
+
- libcups=2.3.3=h4637d8d_4
|
| 103 |
+
- libcurand=10.3.2.106=hd3aeb46_0
|
| 104 |
+
- libcurand-dev=10.3.2.106=hd3aeb46_0
|
| 105 |
+
- libcusolver=11.4.5.107=hd3aeb46_0
|
| 106 |
+
- libcusolver-dev=11.4.5.107=hd3aeb46_0
|
| 107 |
+
- libcusparse=12.1.0.106=hd3aeb46_0
|
| 108 |
+
- libcusparse-dev=12.1.0.106=hd3aeb46_0
|
| 109 |
+
- libedit=3.1.20250104=pl5321h7949ede_0
|
| 110 |
+
- libevent=2.1.12=hf998b51_1
|
| 111 |
+
- libexpat=2.6.4=h5888daf_0
|
| 112 |
+
- libffi=3.4.6=h2dba641_0
|
| 113 |
+
- libflac=1.4.3=h59595ed_0
|
| 114 |
+
- libgcc=14.2.0=h767d61c_2
|
| 115 |
+
- libgcc-devel_linux-64=12.4.0=h1762d19_102
|
| 116 |
+
- libgcc-ng=14.2.0=h69a702a_2
|
| 117 |
+
- libgcrypt-lib=1.11.0=hb9d3cd8_2
|
| 118 |
+
- libgettextpo=0.23.1=h5888daf_0
|
| 119 |
+
- libgettextpo-devel=0.23.1=h5888daf_0
|
| 120 |
+
- libglib=2.82.2=h2ff4ddf_1
|
| 121 |
+
- libgomp=14.2.0=h767d61c_2
|
| 122 |
+
- libgpg-error=1.51=hbd13f7d_1
|
| 123 |
+
- libiconv=1.18=h4ce23a2_1
|
| 124 |
+
- libjpeg-turbo=3.0.0=hd590300_1
|
| 125 |
+
- libllvm15=15.0.7=hb3ce162_4
|
| 126 |
+
- libllvm19=19.1.2=ha7bfdaf_0
|
| 127 |
+
- liblzma=5.6.4=hb9d3cd8_0
|
| 128 |
+
- liblzma-devel=5.6.4=hb9d3cd8_0
|
| 129 |
+
- libnpp=12.1.0.40=hd3aeb46_0
|
| 130 |
+
- libnpp-dev=12.1.0.40=hd3aeb46_0
|
| 131 |
+
- libnsl=2.0.1=hd590300_0
|
| 132 |
+
- libnuma=2.0.18=h4ab18f5_2
|
| 133 |
+
- libnvjitlink=12.1.105=hd3aeb46_0
|
| 134 |
+
- libnvjitlink-dev=12.1.105=hd3aeb46_0
|
| 135 |
+
- libnvjpeg=12.2.0.2=h59595ed_0
|
| 136 |
+
- libnvjpeg-dev=12.2.0.2=ha770c72_0
|
| 137 |
+
- libogg=1.3.5=h4ab18f5_0
|
| 138 |
+
- libopus=1.3.1=h7f98852_1
|
| 139 |
+
- libpng=1.6.47=h943b412_0
|
| 140 |
+
- libpq=16.8=h87c4ccc_0
|
| 141 |
+
- libsanitizer=12.4.0=ha732cd4_2
|
| 142 |
+
- libsndfile=1.2.2=hc60ed4a_1
|
| 143 |
+
- libsqlite=3.49.1=hee588c1_2
|
| 144 |
+
- libstdcxx=14.2.0=h8f9b012_2
|
| 145 |
+
- libstdcxx-devel_linux-64=12.4.0=h1762d19_102
|
| 146 |
+
- libstdcxx-ng=14.2.0=h4852527_2
|
| 147 |
+
- libsystemd0=257.4=h4e0b6ca_1
|
| 148 |
+
- libuuid=2.38.1=h0b41bf4_0
|
| 149 |
+
- libvorbis=1.3.7=h9c3ff4c_0
|
| 150 |
+
- libxcb=1.15=h0b41bf4_0
|
| 151 |
+
- libxkbcommon=1.7.0=h662e7e4_0
|
| 152 |
+
- libxkbfile=1.1.0=h166bdaf_1
|
| 153 |
+
- libxml2=2.12.7=h4c95cb1_3
|
| 154 |
+
- libzlib=1.3.1=hb9d3cd8_2
|
| 155 |
+
- lz4-c=1.10.0=h5888daf_1
|
| 156 |
+
- mpg123=1.32.9=hc50e24c_0
|
| 157 |
+
- mysql-common=8.3.0=h70512c7_5
|
| 158 |
+
- mysql-libs=8.3.0=ha479ceb_5
|
| 159 |
+
- ncurses=6.5=h2d0b736_3
|
| 160 |
+
- nsight-compute=2023.1.1.4=h3718151_0
|
| 161 |
+
- nspr=4.36=h5888daf_0
|
| 162 |
+
- nss=3.108=h159eef7_0
|
| 163 |
+
- ocl-icd=2.3.2=hb9d3cd8_2
|
| 164 |
+
- opencl-headers=2024.10.24=h5888daf_0
|
| 165 |
+
- openssl=3.4.1=h7b32b05_0
|
| 166 |
+
- packaging=24.2=pyhd8ed1ab_2
|
| 167 |
+
- pcre2=10.44=hba22ea6_2
|
| 168 |
+
- pip=25.0.1=pyh8b19718_0
|
| 169 |
+
- pixman=0.44.2=h29eaf8c_0
|
| 170 |
+
- pthread-stubs=0.4=hb9d3cd8_1002
|
| 171 |
+
- pulseaudio-client=17.0=hb77b528_0
|
| 172 |
+
- python=3.11.0=he550d4f_1_cpython
|
| 173 |
+
- qt-main=5.15.8=hc9dc06e_21
|
| 174 |
+
- readline=8.2=h8c095d6_2
|
| 175 |
+
- setuptools=75.8.2=pyhff2d567_0
|
| 176 |
+
- sysroot_linux-64=2.17=h0157908_18
|
| 177 |
+
- tk=8.6.13=noxft_h4845f30_101
|
| 178 |
+
- tzdata=2025b=h78e105d_0
|
| 179 |
+
- wayland=1.23.1=h3e06ad9_0
|
| 180 |
+
- wheel=0.45.1=pyhd8ed1ab_1
|
| 181 |
+
- xcb-util=0.4.0=hd590300_1
|
| 182 |
+
- xcb-util-image=0.4.0=h8ee46fc_1
|
| 183 |
+
- xcb-util-keysyms=0.4.0=h8ee46fc_1
|
| 184 |
+
- xcb-util-renderutil=0.3.9=hd590300_1
|
| 185 |
+
- xcb-util-wm=0.4.1=h8ee46fc_1
|
| 186 |
+
- xkeyboard-config=2.42=h4ab18f5_0
|
| 187 |
+
- xorg-compositeproto=0.4.2=hb9d3cd8_1002
|
| 188 |
+
- xorg-damageproto=1.2.1=hb9d3cd8_1003
|
| 189 |
+
- xorg-fixesproto=5.0=hb9d3cd8_1003
|
| 190 |
+
- xorg-inputproto=2.3.2=hb9d3cd8_1003
|
| 191 |
+
- xorg-kbproto=1.0.7=hb9d3cd8_1003
|
| 192 |
+
- xorg-libice=1.1.2=hb9d3cd8_0
|
| 193 |
+
- xorg-libsm=1.2.6=he73a12e_0
|
| 194 |
+
- xorg-libx11=1.8.9=h8ee46fc_0
|
| 195 |
+
- xorg-libxau=1.0.12=hb9d3cd8_0
|
| 196 |
+
- xorg-libxcomposite=0.4.6=h0b41bf4_1
|
| 197 |
+
- xorg-libxdamage=1.1.5=h7f98852_1
|
| 198 |
+
- xorg-libxdmcp=1.1.5=hb9d3cd8_0
|
| 199 |
+
- xorg-libxext=1.3.4=h0b41bf4_2
|
| 200 |
+
- xorg-libxfixes=5.0.3=h7f98852_1004
|
| 201 |
+
- xorg-libxi=1.7.10=h4bc722e_1
|
| 202 |
+
- xorg-libxrandr=1.5.2=h7f98852_1
|
| 203 |
+
- xorg-libxrender=0.9.11=hd590300_0
|
| 204 |
+
- xorg-libxtst=1.2.5=h4bc722e_0
|
| 205 |
+
- xorg-randrproto=1.5.0=hb9d3cd8_1002
|
| 206 |
+
- xorg-recordproto=1.14.2=hb9d3cd8_1003
|
| 207 |
+
- xorg-renderproto=0.11.1=hb9d3cd8_1003
|
| 208 |
+
- xorg-util-macros=1.20.2=hb9d3cd8_0
|
| 209 |
+
- xorg-xextproto=7.3.0=hb9d3cd8_1004
|
| 210 |
+
- xorg-xf86vidmodeproto=2.3.1=hb9d3cd8_1005
|
| 211 |
+
- xorg-xproto=7.0.31=hb9d3cd8_1008
|
| 212 |
+
- xz=5.6.4=hbcc6ac9_0
|
| 213 |
+
- xz-gpl-tools=5.6.4=hbcc6ac9_0
|
| 214 |
+
- xz-tools=5.6.4=hb9d3cd8_0
|
| 215 |
+
- zlib=1.3.1=hb9d3cd8_2
|
| 216 |
+
- zstd=1.5.7=hb8e6e7a_2
|
thirdparty/sam3d/sam3d/notebook/demo_3db_mesh_alignment.ipynb
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {},
|
| 6 |
+
"source": [
|
| 7 |
+
"# SAM 3D Body (3DB) Mesh Alignment to SAM 3D Object Scale\n",
|
| 8 |
+
"\n",
|
| 9 |
+
"This notebook processes a single 3DB mesh and aligns it to the SAM 3D Objects scale.\n",
|
| 10 |
+
"\n",
|
| 11 |
+
"**Input Data:**\n",
|
| 12 |
+
"- `images/human_object/image.jpg` - Input image for MoGe\n",
|
| 13 |
+
"- `meshes/human_object/3DB_results/mask_human.png` - Human mask\n",
|
| 14 |
+
"- `meshes/human_object/3DB_results/human.ply` - Single 3DB mesh in OpenGL coordinates\n",
|
| 15 |
+
"- `meshes/human_object/3DB_results/focal_length.json` - 3DB focal length\n",
|
| 16 |
+
"\n",
|
| 17 |
+
"**Output:**\n",
|
| 18 |
+
"- `meshes/human_object/aligned_meshes/human_aligned.ply` - Aligned 3DB mesh in OpenGL coordinates"
|
| 19 |
+
]
|
| 20 |
+
},
|
| 21 |
+
{
|
| 22 |
+
"cell_type": "code",
|
| 23 |
+
"execution_count": null,
|
| 24 |
+
"metadata": {},
|
| 25 |
+
"outputs": [],
|
| 26 |
+
"source": [
|
| 27 |
+
"import os\n",
|
| 28 |
+
"import torch\n",
|
| 29 |
+
"import matplotlib.pyplot as plt\n",
|
| 30 |
+
"from PIL import Image\n",
|
| 31 |
+
"from mesh_alignment import process_and_save_alignment\n",
|
| 32 |
+
"\n",
|
| 33 |
+
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
|
| 34 |
+
"print(f\"Using device: {device}\")\n",
|
| 35 |
+
"PATH = os.getcwd()\n",
|
| 36 |
+
"print(f\"Current working directory: {PATH}\")\n",
|
| 37 |
+
"\n",
|
| 38 |
+
"# Please inference the SAM 3D Body (3DB) Repo (https://github.com/facebookresearch/sam-3d-body) to get the 3DB Results\n",
|
| 39 |
+
"image_path = f\"{PATH}/images/human_object/image.png\"\n",
|
| 40 |
+
"mask_path = f\"{PATH}/meshes/human_object/3DB_results/mask_human.png\"\n",
|
| 41 |
+
"mesh_path = f\"{PATH}/meshes/human_object/3DB_results/human.ply\"\n",
|
| 42 |
+
"focal_length_json_path = f\"{PATH}/meshes/human_object/3DB_results/focal_length.json\"\n",
|
| 43 |
+
"output_dir = f\"{PATH}/meshes/human_object/aligned_meshes\"\n",
|
| 44 |
+
"os.makedirs(output_dir, exist_ok=True)\n"
|
| 45 |
+
]
|
| 46 |
+
},
|
| 47 |
+
{
|
| 48 |
+
"cell_type": "markdown",
|
| 49 |
+
"metadata": {},
|
| 50 |
+
"source": [
|
| 51 |
+
"## 1. Load and Display Input Data"
|
| 52 |
+
]
|
| 53 |
+
},
|
| 54 |
+
{
|
| 55 |
+
"cell_type": "code",
|
| 56 |
+
"execution_count": null,
|
| 57 |
+
"metadata": {},
|
| 58 |
+
"outputs": [],
|
| 59 |
+
"source": [
|
| 60 |
+
"input_image = Image.open(image_path)\n",
|
| 61 |
+
"mask = Image.open(mask_path).convert('L')\n",
|
| 62 |
+
"fig, axes = plt.subplots(1, 2, figsize=(10, 5))\n",
|
| 63 |
+
"axes[0].imshow(input_image)\n",
|
| 64 |
+
"axes[0].set_title('Input Image')\n",
|
| 65 |
+
"axes[0].axis('off')\n",
|
| 66 |
+
"axes[1].imshow(mask, cmap='gray')\n",
|
| 67 |
+
"axes[1].set_title('Mask')\n",
|
| 68 |
+
"axes[1].axis('off')\n",
|
| 69 |
+
"plt.tight_layout()\n",
|
| 70 |
+
"plt.show()"
|
| 71 |
+
]
|
| 72 |
+
},
|
| 73 |
+
{
|
| 74 |
+
"cell_type": "markdown",
|
| 75 |
+
"metadata": {},
|
| 76 |
+
"source": [
|
| 77 |
+
"## 2. Process and Save Aligned Mesh"
|
| 78 |
+
]
|
| 79 |
+
},
|
| 80 |
+
{
|
| 81 |
+
"cell_type": "code",
|
| 82 |
+
"execution_count": null,
|
| 83 |
+
"metadata": {},
|
| 84 |
+
"outputs": [],
|
| 85 |
+
"source": [
|
| 86 |
+
"\n",
|
| 87 |
+
"success, output_mesh_path, result = process_and_save_alignment(\n",
|
| 88 |
+
" mesh_path=mesh_path,\n",
|
| 89 |
+
" mask_path=mask_path,\n",
|
| 90 |
+
" image_path=image_path,\n",
|
| 91 |
+
" output_dir=output_dir,\n",
|
| 92 |
+
" device=device,\n",
|
| 93 |
+
" focal_length_json_path=focal_length_json_path\n",
|
| 94 |
+
")\n",
|
| 95 |
+
"\n",
|
| 96 |
+
"if success:\n",
|
| 97 |
+
" print(f\"Alignment completed successfully! Output: {output_mesh_path}\")\n",
|
| 98 |
+
"else:\n",
|
| 99 |
+
" print(\"Alignment failed!\")"
|
| 100 |
+
]
|
| 101 |
+
},
|
| 102 |
+
{
|
| 103 |
+
"cell_type": "markdown",
|
| 104 |
+
"metadata": {},
|
| 105 |
+
"source": [
|
| 106 |
+
"## 3. Interactive 3D Visualization\n"
|
| 107 |
+
]
|
| 108 |
+
},
|
| 109 |
+
{
|
| 110 |
+
"cell_type": "code",
|
| 111 |
+
"execution_count": null,
|
| 112 |
+
"metadata": {},
|
| 113 |
+
"outputs": [],
|
| 114 |
+
"source": [
|
| 115 |
+
"from mesh_alignment import visualize_meshes_interactive\n",
|
| 116 |
+
"\n",
|
| 117 |
+
"aligned_mesh_path = f\"{PATH}/meshes/human_object/aligned_meshes/human_aligned.ply\"\n",
|
| 118 |
+
"dfy_mesh_path = f\"{PATH}/meshes/human_object/3Dfy_results/0.glb\"\n",
|
| 119 |
+
"\n",
|
| 120 |
+
"demo, combined_glb_path = visualize_meshes_interactive(\n",
|
| 121 |
+
" aligned_mesh_path=aligned_mesh_path,\n",
|
| 122 |
+
" dfy_mesh_path=dfy_mesh_path,\n",
|
| 123 |
+
" share=True\n",
|
| 124 |
+
")"
|
| 125 |
+
]
|
| 126 |
+
}
|
| 127 |
+
],
|
| 128 |
+
"metadata": {
|
| 129 |
+
"kernelspec": {
|
| 130 |
+
"display_name": "sam3d_objects-3dfy",
|
| 131 |
+
"language": "python",
|
| 132 |
+
"name": "python3"
|
| 133 |
+
},
|
| 134 |
+
"language_info": {
|
| 135 |
+
"codemirror_mode": {
|
| 136 |
+
"name": "ipython",
|
| 137 |
+
"version": 3
|
| 138 |
+
},
|
| 139 |
+
"file_extension": ".py",
|
| 140 |
+
"mimetype": "text/x-python",
|
| 141 |
+
"name": "python",
|
| 142 |
+
"nbconvert_exporter": "python",
|
| 143 |
+
"pygments_lexer": "ipython3",
|
| 144 |
+
"version": "3.11.0"
|
| 145 |
+
}
|
| 146 |
+
},
|
| 147 |
+
"nbformat": 4,
|
| 148 |
+
"nbformat_minor": 4
|
| 149 |
+
}
|
thirdparty/sam3d/sam3d/notebook/demo_multi_object.ipynb
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": null,
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"outputs": [],
|
| 8 |
+
"source": [
|
| 9 |
+
"# Copyright (c) Meta Platforms, Inc. and affiliates."
|
| 10 |
+
]
|
| 11 |
+
},
|
| 12 |
+
{
|
| 13 |
+
"cell_type": "markdown",
|
| 14 |
+
"metadata": {},
|
| 15 |
+
"source": [
|
| 16 |
+
"## 1. Imports and Model Loading"
|
| 17 |
+
]
|
| 18 |
+
},
|
| 19 |
+
{
|
| 20 |
+
"cell_type": "code",
|
| 21 |
+
"execution_count": null,
|
| 22 |
+
"metadata": {},
|
| 23 |
+
"outputs": [],
|
| 24 |
+
"source": [
|
| 25 |
+
"import os\n",
|
| 26 |
+
"import uuid\n",
|
| 27 |
+
"import imageio\n",
|
| 28 |
+
"import numpy as np\n",
|
| 29 |
+
"from IPython.display import Image as ImageDisplay\n",
|
| 30 |
+
"\n",
|
| 31 |
+
"from inference import Inference, ready_gaussian_for_video_rendering, load_image, load_masks, display_image, make_scene, render_video, interactive_visualizer"
|
| 32 |
+
]
|
| 33 |
+
},
|
| 34 |
+
{
|
| 35 |
+
"cell_type": "code",
|
| 36 |
+
"execution_count": null,
|
| 37 |
+
"metadata": {},
|
| 38 |
+
"outputs": [],
|
| 39 |
+
"source": [
|
| 40 |
+
"PATH = os.getcwd()\n",
|
| 41 |
+
"TAG = \"hf\"\n",
|
| 42 |
+
"config_path = f\"{PATH}/../checkpoints/{TAG}/pipeline.yaml\"\n",
|
| 43 |
+
"inference = Inference(config_path, compile=False)"
|
| 44 |
+
]
|
| 45 |
+
},
|
| 46 |
+
{
|
| 47 |
+
"cell_type": "markdown",
|
| 48 |
+
"metadata": {},
|
| 49 |
+
"source": [
|
| 50 |
+
"## 2. Load input image to lift to 3D (multiple objects)"
|
| 51 |
+
]
|
| 52 |
+
},
|
| 53 |
+
{
|
| 54 |
+
"cell_type": "code",
|
| 55 |
+
"execution_count": null,
|
| 56 |
+
"metadata": {},
|
| 57 |
+
"outputs": [],
|
| 58 |
+
"source": [
|
| 59 |
+
"IMAGE_PATH = f\"{PATH}/images/shutterstock_stylish_kidsroom_1640806567/image.png\"\n",
|
| 60 |
+
"IMAGE_NAME = os.path.basename(os.path.dirname(IMAGE_PATH))\n",
|
| 61 |
+
"\n",
|
| 62 |
+
"image = load_image(IMAGE_PATH)\n",
|
| 63 |
+
"masks = load_masks(os.path.dirname(IMAGE_PATH), extension=\".png\")\n",
|
| 64 |
+
"display_image(image, masks)"
|
| 65 |
+
]
|
| 66 |
+
},
|
| 67 |
+
{
|
| 68 |
+
"cell_type": "markdown",
|
| 69 |
+
"metadata": {},
|
| 70 |
+
"source": [
|
| 71 |
+
"## 3. Generate Gaussian Splats"
|
| 72 |
+
]
|
| 73 |
+
},
|
| 74 |
+
{
|
| 75 |
+
"cell_type": "code",
|
| 76 |
+
"execution_count": null,
|
| 77 |
+
"metadata": {},
|
| 78 |
+
"outputs": [],
|
| 79 |
+
"source": [
|
| 80 |
+
"outputs = [inference(image, mask, seed=42) for mask in masks]"
|
| 81 |
+
]
|
| 82 |
+
},
|
| 83 |
+
{
|
| 84 |
+
"cell_type": "markdown",
|
| 85 |
+
"metadata": {},
|
| 86 |
+
"source": [
|
| 87 |
+
"## 4. Visualize Gaussian Splat of the Scene\n",
|
| 88 |
+
"### a. Animated Gif"
|
| 89 |
+
]
|
| 90 |
+
},
|
| 91 |
+
{
|
| 92 |
+
"cell_type": "code",
|
| 93 |
+
"execution_count": null,
|
| 94 |
+
"metadata": {},
|
| 95 |
+
"outputs": [],
|
| 96 |
+
"source": [
|
| 97 |
+
"scene_gs = make_scene(*outputs)\n",
|
| 98 |
+
"scene_gs = ready_gaussian_for_video_rendering(scene_gs)\n",
|
| 99 |
+
"\n",
|
| 100 |
+
"# export gaussian splatting (as point cloud)\n",
|
| 101 |
+
"scene_gs.save_ply(f\"{PATH}/gaussians/multi/{IMAGE_NAME}.ply\")\n",
|
| 102 |
+
"\n",
|
| 103 |
+
"video = render_video(\n",
|
| 104 |
+
" scene_gs,\n",
|
| 105 |
+
" r=1,\n",
|
| 106 |
+
" fov=60,\n",
|
| 107 |
+
" resolution=512,\n",
|
| 108 |
+
")[\"color\"]\n",
|
| 109 |
+
"\n",
|
| 110 |
+
"# save video as gif\n",
|
| 111 |
+
"imageio.mimsave(\n",
|
| 112 |
+
" os.path.join(f\"{PATH}/gaussians/multi/{IMAGE_NAME}.gif\"),\n",
|
| 113 |
+
" video,\n",
|
| 114 |
+
" format=\"GIF\",\n",
|
| 115 |
+
" duration=1000 / 30, # default assuming 30fps from the input MP4\n",
|
| 116 |
+
" loop=0, # 0 means loop indefinitely\n",
|
| 117 |
+
")\n",
|
| 118 |
+
"\n",
|
| 119 |
+
"# notebook display\n",
|
| 120 |
+
"ImageDisplay(url=f\"gaussians/multi/{IMAGE_NAME}.gif?cache_invalidator={uuid.uuid4()}\",)"
|
| 121 |
+
]
|
| 122 |
+
},
|
| 123 |
+
{
|
| 124 |
+
"cell_type": "markdown",
|
| 125 |
+
"metadata": {},
|
| 126 |
+
"source": [
|
| 127 |
+
"### b. Interactive Visualizer"
|
| 128 |
+
]
|
| 129 |
+
},
|
| 130 |
+
{
|
| 131 |
+
"cell_type": "code",
|
| 132 |
+
"execution_count": null,
|
| 133 |
+
"metadata": {},
|
| 134 |
+
"outputs": [],
|
| 135 |
+
"source": [
|
| 136 |
+
"# might take a while to load (black screen)\n",
|
| 137 |
+
"interactive_visualizer(f\"{PATH}/gaussians/multi/{IMAGE_NAME}.ply\")"
|
| 138 |
+
]
|
| 139 |
+
}
|
| 140 |
+
],
|
| 141 |
+
"metadata": {
|
| 142 |
+
"kernelspec": {
|
| 143 |
+
"display_name": "sam3d-objects",
|
| 144 |
+
"language": "python",
|
| 145 |
+
"name": "python3"
|
| 146 |
+
},
|
| 147 |
+
"language_info": {
|
| 148 |
+
"codemirror_mode": {
|
| 149 |
+
"name": "ipython",
|
| 150 |
+
"version": 3
|
| 151 |
+
},
|
| 152 |
+
"file_extension": ".py",
|
| 153 |
+
"mimetype": "text/x-python",
|
| 154 |
+
"name": "python",
|
| 155 |
+
"nbconvert_exporter": "python",
|
| 156 |
+
"pygments_lexer": "ipython3",
|
| 157 |
+
"version": "3.11.0"
|
| 158 |
+
}
|
| 159 |
+
},
|
| 160 |
+
"nbformat": 4,
|
| 161 |
+
"nbformat_minor": 2
|
| 162 |
+
}
|
thirdparty/sam3d/sam3d/notebook/demo_single_object.ipynb
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": null,
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"outputs": [],
|
| 8 |
+
"source": [
|
| 9 |
+
"# Copyright (c) Meta Platforms, Inc. and affiliates."
|
| 10 |
+
]
|
| 11 |
+
},
|
| 12 |
+
{
|
| 13 |
+
"cell_type": "markdown",
|
| 14 |
+
"metadata": {},
|
| 15 |
+
"source": [
|
| 16 |
+
"## 1. Imports and Model Loading"
|
| 17 |
+
]
|
| 18 |
+
},
|
| 19 |
+
{
|
| 20 |
+
"cell_type": "code",
|
| 21 |
+
"execution_count": null,
|
| 22 |
+
"metadata": {},
|
| 23 |
+
"outputs": [],
|
| 24 |
+
"source": [
|
| 25 |
+
"import os\n",
|
| 26 |
+
"import imageio\n",
|
| 27 |
+
"import uuid\n",
|
| 28 |
+
"from IPython.display import Image as ImageDisplay\n",
|
| 29 |
+
"from inference import Inference, ready_gaussian_for_video_rendering, render_video, load_image, load_single_mask, display_image, make_scene, interactive_visualizer"
|
| 30 |
+
]
|
| 31 |
+
},
|
| 32 |
+
{
|
| 33 |
+
"cell_type": "code",
|
| 34 |
+
"execution_count": null,
|
| 35 |
+
"metadata": {},
|
| 36 |
+
"outputs": [],
|
| 37 |
+
"source": [
|
| 38 |
+
"PATH = os.getcwd()\n",
|
| 39 |
+
"TAG = \"hf\"\n",
|
| 40 |
+
"config_path = f\"{PATH}/../checkpoints/{TAG}/pipeline.yaml\"\n",
|
| 41 |
+
"inference = Inference(config_path, compile=False)"
|
| 42 |
+
]
|
| 43 |
+
},
|
| 44 |
+
{
|
| 45 |
+
"cell_type": "markdown",
|
| 46 |
+
"metadata": {},
|
| 47 |
+
"source": [
|
| 48 |
+
"## 2. Load input image to lift to 3D (single object)"
|
| 49 |
+
]
|
| 50 |
+
},
|
| 51 |
+
{
|
| 52 |
+
"cell_type": "code",
|
| 53 |
+
"execution_count": null,
|
| 54 |
+
"metadata": {},
|
| 55 |
+
"outputs": [],
|
| 56 |
+
"source": [
|
| 57 |
+
"IMAGE_PATH = f\"{PATH}/images/shutterstock_stylish_kidsroom_1640806567/image.png\"\n",
|
| 58 |
+
"IMAGE_NAME = os.path.basename(os.path.dirname(IMAGE_PATH))\n",
|
| 59 |
+
"\n",
|
| 60 |
+
"image = load_image(IMAGE_PATH)\n",
|
| 61 |
+
"mask = load_single_mask(os.path.dirname(IMAGE_PATH), index=14)\n",
|
| 62 |
+
"display_image(image, masks=[mask])"
|
| 63 |
+
]
|
| 64 |
+
},
|
| 65 |
+
{
|
| 66 |
+
"cell_type": "markdown",
|
| 67 |
+
"metadata": {},
|
| 68 |
+
"source": [
|
| 69 |
+
"## 3. Generate Gaussian Splat"
|
| 70 |
+
]
|
| 71 |
+
},
|
| 72 |
+
{
|
| 73 |
+
"cell_type": "code",
|
| 74 |
+
"execution_count": null,
|
| 75 |
+
"metadata": {},
|
| 76 |
+
"outputs": [],
|
| 77 |
+
"source": [
|
| 78 |
+
"# run model\n",
|
| 79 |
+
"output = inference(image, mask, seed=42)\n",
|
| 80 |
+
"\n",
|
| 81 |
+
"# export gaussian splat (as point cloud)\n",
|
| 82 |
+
"output[\"gs\"].save_ply(f\"{PATH}/gaussians/single/{IMAGE_NAME}.ply\")"
|
| 83 |
+
]
|
| 84 |
+
},
|
| 85 |
+
{
|
| 86 |
+
"cell_type": "markdown",
|
| 87 |
+
"metadata": {},
|
| 88 |
+
"source": [
|
| 89 |
+
"## 4. Visualize Gaussian Splat\n",
|
| 90 |
+
"### a. Animated Gif"
|
| 91 |
+
]
|
| 92 |
+
},
|
| 93 |
+
{
|
| 94 |
+
"cell_type": "code",
|
| 95 |
+
"execution_count": null,
|
| 96 |
+
"metadata": {},
|
| 97 |
+
"outputs": [],
|
| 98 |
+
"source": [
|
| 99 |
+
"# render gaussian splat\n",
|
| 100 |
+
"scene_gs = make_scene(output)\n",
|
| 101 |
+
"scene_gs = ready_gaussian_for_video_rendering(scene_gs)\n",
|
| 102 |
+
"\n",
|
| 103 |
+
"video = render_video(\n",
|
| 104 |
+
" scene_gs,\n",
|
| 105 |
+
" r=1,\n",
|
| 106 |
+
" fov=60,\n",
|
| 107 |
+
" pitch_deg=15,\n",
|
| 108 |
+
" yaw_start_deg=-45,\n",
|
| 109 |
+
" resolution=512,\n",
|
| 110 |
+
")[\"color\"]\n",
|
| 111 |
+
"\n",
|
| 112 |
+
"# save video as gif\n",
|
| 113 |
+
"imageio.mimsave(\n",
|
| 114 |
+
" os.path.join(f\"{PATH}/gaussians/single/{IMAGE_NAME}.gif\"),\n",
|
| 115 |
+
" video,\n",
|
| 116 |
+
" format=\"GIF\",\n",
|
| 117 |
+
" duration=1000 / 30, # default assuming 30fps from the input MP4\n",
|
| 118 |
+
" loop=0, # 0 means loop indefinitely\n",
|
| 119 |
+
")\n",
|
| 120 |
+
"\n",
|
| 121 |
+
"# notebook display\n",
|
| 122 |
+
"ImageDisplay(url=f\"gaussians/single/{IMAGE_NAME}.gif?cache_invalidator={uuid.uuid4()}\")"
|
| 123 |
+
]
|
| 124 |
+
},
|
| 125 |
+
{
|
| 126 |
+
"cell_type": "markdown",
|
| 127 |
+
"metadata": {},
|
| 128 |
+
"source": [
|
| 129 |
+
"### b. Interactive Visualizer"
|
| 130 |
+
]
|
| 131 |
+
},
|
| 132 |
+
{
|
| 133 |
+
"cell_type": "code",
|
| 134 |
+
"execution_count": null,
|
| 135 |
+
"metadata": {},
|
| 136 |
+
"outputs": [],
|
| 137 |
+
"source": [
|
| 138 |
+
"# might take a while to load (black screen)\n",
|
| 139 |
+
"interactive_visualizer(f\"{PATH}/gaussians/single/{IMAGE_NAME}.ply\")"
|
| 140 |
+
]
|
| 141 |
+
}
|
| 142 |
+
],
|
| 143 |
+
"metadata": {
|
| 144 |
+
"kernelspec": {
|
| 145 |
+
"display_name": "sam3d_objects-3dfy",
|
| 146 |
+
"language": "python",
|
| 147 |
+
"name": "python3"
|
| 148 |
+
},
|
| 149 |
+
"language_info": {
|
| 150 |
+
"codemirror_mode": {
|
| 151 |
+
"name": "ipython",
|
| 152 |
+
"version": 3
|
| 153 |
+
},
|
| 154 |
+
"file_extension": ".py",
|
| 155 |
+
"mimetype": "text/x-python",
|
| 156 |
+
"name": "python",
|
| 157 |
+
"nbconvert_exporter": "python",
|
| 158 |
+
"pygments_lexer": "ipython3",
|
| 159 |
+
"version": "3.11.0"
|
| 160 |
+
}
|
| 161 |
+
},
|
| 162 |
+
"nbformat": 4,
|
| 163 |
+
"nbformat_minor": 2
|
| 164 |
+
}
|
thirdparty/sam3d/sam3d/notebook/inference.py
ADDED
|
@@ -0,0 +1,414 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
# not ideal to put that here
|
| 5 |
+
os.environ["CUDA_HOME"] = os.environ["CONDA_PREFIX"]
|
| 6 |
+
os.environ["LIDRA_SKIP_INIT"] = "true"
|
| 7 |
+
|
| 8 |
+
import sys
|
| 9 |
+
from typing import Union, Optional, List, Callable
|
| 10 |
+
import numpy as np
|
| 11 |
+
from PIL import Image
|
| 12 |
+
from omegaconf import OmegaConf, DictConfig, ListConfig
|
| 13 |
+
from hydra.utils import instantiate, get_method
|
| 14 |
+
import torch
|
| 15 |
+
import math
|
| 16 |
+
import utils3d
|
| 17 |
+
import shutil
|
| 18 |
+
import subprocess
|
| 19 |
+
import seaborn as sns
|
| 20 |
+
from PIL import Image
|
| 21 |
+
import numpy as np
|
| 22 |
+
import gradio as gr
|
| 23 |
+
import matplotlib.pyplot as plt
|
| 24 |
+
from copy import deepcopy
|
| 25 |
+
from kaolin.visualize import IpyTurntableVisualizer
|
| 26 |
+
from kaolin.render.camera import Camera, CameraExtrinsics, PinholeIntrinsics
|
| 27 |
+
import builtins
|
| 28 |
+
from pytorch3d.transforms import quaternion_multiply, quaternion_invert
|
| 29 |
+
|
| 30 |
+
import sam3d_objects # REMARK(Pierre) : do not remove this import
|
| 31 |
+
from sam3d_objects.pipeline.inference_pipeline_pointmap import InferencePipelinePointMap
|
| 32 |
+
from sam3d_objects.model.backbone.tdfy_dit.utils import render_utils
|
| 33 |
+
|
| 34 |
+
from sam3d_objects.utils.visualization import SceneVisualizer
|
| 35 |
+
|
| 36 |
+
__all__ = ["Inference"]
|
| 37 |
+
|
| 38 |
+
WHITELIST_FILTERS = [
|
| 39 |
+
lambda target: target.split(".", 1)[0] in {"sam3d_objects", "torch", "torchvision", "moge"},
|
| 40 |
+
]
|
| 41 |
+
|
| 42 |
+
BLACKLIST_FILTERS = [
|
| 43 |
+
lambda target: get_method(target)
|
| 44 |
+
in {
|
| 45 |
+
builtins.exec,
|
| 46 |
+
builtins.eval,
|
| 47 |
+
builtins.__import__,
|
| 48 |
+
os.kill,
|
| 49 |
+
os.system,
|
| 50 |
+
os.putenv,
|
| 51 |
+
os.remove,
|
| 52 |
+
os.removedirs,
|
| 53 |
+
os.rmdir,
|
| 54 |
+
os.fchdir,
|
| 55 |
+
os.setuid,
|
| 56 |
+
os.fork,
|
| 57 |
+
os.forkpty,
|
| 58 |
+
os.killpg,
|
| 59 |
+
os.rename,
|
| 60 |
+
os.renames,
|
| 61 |
+
os.truncate,
|
| 62 |
+
os.replace,
|
| 63 |
+
os.unlink,
|
| 64 |
+
os.fchmod,
|
| 65 |
+
os.fchown,
|
| 66 |
+
os.chmod,
|
| 67 |
+
os.chown,
|
| 68 |
+
os.chroot,
|
| 69 |
+
os.fchdir,
|
| 70 |
+
os.lchown,
|
| 71 |
+
os.getcwd,
|
| 72 |
+
os.chdir,
|
| 73 |
+
shutil.rmtree,
|
| 74 |
+
shutil.move,
|
| 75 |
+
shutil.chown,
|
| 76 |
+
subprocess.Popen,
|
| 77 |
+
builtins.help,
|
| 78 |
+
},
|
| 79 |
+
]
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class Inference:
|
| 83 |
+
# public facing inference API
|
| 84 |
+
# only put publicly exposed arguments here
|
| 85 |
+
def __init__(self, config_file: str, compile: bool = False):
|
| 86 |
+
# load inference pipeline
|
| 87 |
+
config = OmegaConf.load(config_file)
|
| 88 |
+
config.rendering_engine = "pytorch3d" # overwrite to disable nvdiffrast
|
| 89 |
+
config.compile_model = compile
|
| 90 |
+
config.workspace_dir = os.path.dirname(config_file)
|
| 91 |
+
check_hydra_safety(config, WHITELIST_FILTERS, BLACKLIST_FILTERS)
|
| 92 |
+
self._pipeline: InferencePipelinePointMap = instantiate(config)
|
| 93 |
+
|
| 94 |
+
def merge_mask_to_rgba(self, image, mask):
|
| 95 |
+
mask = mask.astype(np.uint8) * 255
|
| 96 |
+
mask = mask[..., None]
|
| 97 |
+
# embed mask in alpha channel
|
| 98 |
+
rgba_image = np.concatenate([image[..., :3], mask], axis=-1)
|
| 99 |
+
return rgba_image
|
| 100 |
+
|
| 101 |
+
def __call__(
|
| 102 |
+
self,
|
| 103 |
+
image: Union[Image.Image, np.ndarray],
|
| 104 |
+
mask: Optional[Union[None, Image.Image, np.ndarray]],
|
| 105 |
+
seed: Optional[int] = None,
|
| 106 |
+
pointmap=None,
|
| 107 |
+
) -> dict:
|
| 108 |
+
image = self.merge_mask_to_rgba(image, mask)
|
| 109 |
+
return self._pipeline.run(
|
| 110 |
+
image,
|
| 111 |
+
None,
|
| 112 |
+
seed,
|
| 113 |
+
stage1_only=False,
|
| 114 |
+
with_mesh_postprocess=False,
|
| 115 |
+
with_texture_baking=False,
|
| 116 |
+
with_layout_postprocess=True,
|
| 117 |
+
use_vertex_color=True,
|
| 118 |
+
stage1_inference_steps=None,
|
| 119 |
+
pointmap=pointmap,
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def _yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitchs, rs, fovs):
|
| 124 |
+
is_list = isinstance(yaws, list)
|
| 125 |
+
if not is_list:
|
| 126 |
+
yaws = [yaws]
|
| 127 |
+
pitchs = [pitchs]
|
| 128 |
+
if not isinstance(rs, list):
|
| 129 |
+
rs = [rs] * len(yaws)
|
| 130 |
+
if not isinstance(fovs, list):
|
| 131 |
+
fovs = [fovs] * len(yaws)
|
| 132 |
+
extrinsics = []
|
| 133 |
+
intrinsics = []
|
| 134 |
+
for yaw, pitch, r, fov in zip(yaws, pitchs, rs, fovs):
|
| 135 |
+
fov = torch.deg2rad(torch.tensor(float(fov))).cuda()
|
| 136 |
+
yaw = torch.tensor(float(yaw)).cuda()
|
| 137 |
+
pitch = torch.tensor(float(pitch)).cuda()
|
| 138 |
+
orig = (
|
| 139 |
+
torch.tensor(
|
| 140 |
+
[
|
| 141 |
+
torch.sin(yaw) * torch.cos(pitch),
|
| 142 |
+
torch.sin(pitch),
|
| 143 |
+
torch.cos(yaw) * torch.cos(pitch),
|
| 144 |
+
]
|
| 145 |
+
).cuda()
|
| 146 |
+
* r
|
| 147 |
+
)
|
| 148 |
+
extr = utils3d.torch.extrinsics_look_at(
|
| 149 |
+
orig,
|
| 150 |
+
torch.tensor([0, 0, 0]).float().cuda(),
|
| 151 |
+
torch.tensor([0, 1, 0]).float().cuda(),
|
| 152 |
+
)
|
| 153 |
+
intr = utils3d.torch.intrinsics_from_fov_xy(fov, fov)
|
| 154 |
+
extrinsics.append(extr)
|
| 155 |
+
intrinsics.append(intr)
|
| 156 |
+
if not is_list:
|
| 157 |
+
extrinsics = extrinsics[0]
|
| 158 |
+
intrinsics = intrinsics[0]
|
| 159 |
+
return extrinsics, intrinsics
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def render_video(
|
| 163 |
+
sample,
|
| 164 |
+
resolution=512,
|
| 165 |
+
bg_color=(0, 0, 0),
|
| 166 |
+
num_frames=300,
|
| 167 |
+
r=2.0,
|
| 168 |
+
fov=40,
|
| 169 |
+
pitch_deg=0,
|
| 170 |
+
yaw_start_deg=-90,
|
| 171 |
+
**kwargs,
|
| 172 |
+
):
|
| 173 |
+
|
| 174 |
+
yaws = (
|
| 175 |
+
torch.linspace(0, 2 * torch.pi, num_frames) + math.radians(yaw_start_deg)
|
| 176 |
+
).tolist()
|
| 177 |
+
pitch = [math.radians(pitch_deg)] * num_frames
|
| 178 |
+
|
| 179 |
+
extr, intr = _yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitch, r, fov)
|
| 180 |
+
|
| 181 |
+
return render_utils.render_frames(
|
| 182 |
+
sample,
|
| 183 |
+
extr,
|
| 184 |
+
intr,
|
| 185 |
+
{"resolution": resolution, "bg_color": bg_color, "backend": "gsplat"},
|
| 186 |
+
**kwargs,
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def ready_gaussian_for_video_rendering(scene_gs, in_place=False, fix_alignment=False):
|
| 191 |
+
if fix_alignment:
|
| 192 |
+
scene_gs = _fix_gaussian_alignment(scene_gs, in_place=in_place)
|
| 193 |
+
scene_gs = normalized_gaussian(scene_gs, in_place=fix_alignment)
|
| 194 |
+
return scene_gs
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def _fix_gaussian_alignment(scene_gs, in_place=False):
|
| 198 |
+
if not in_place:
|
| 199 |
+
scene_gs = deepcopy(scene_gs)
|
| 200 |
+
|
| 201 |
+
device = scene_gs._xyz.device
|
| 202 |
+
dtype = scene_gs._xyz.dtype
|
| 203 |
+
scene_gs._xyz = (
|
| 204 |
+
scene_gs._xyz
|
| 205 |
+
@ torch.tensor(
|
| 206 |
+
[
|
| 207 |
+
[-1, 0, 0],
|
| 208 |
+
[0, 0, 1],
|
| 209 |
+
[0, 1, 0],
|
| 210 |
+
],
|
| 211 |
+
device=device,
|
| 212 |
+
dtype=dtype,
|
| 213 |
+
).T
|
| 214 |
+
)
|
| 215 |
+
return scene_gs
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def normalized_gaussian(scene_gs, in_place=False, outlier_percentile=None):
|
| 219 |
+
if not in_place:
|
| 220 |
+
scene_gs = deepcopy(scene_gs)
|
| 221 |
+
|
| 222 |
+
orig_xyz = scene_gs.get_xyz
|
| 223 |
+
orig_scale = scene_gs.get_scaling
|
| 224 |
+
|
| 225 |
+
active_mask = (scene_gs.get_opacity > 0.9).squeeze()
|
| 226 |
+
inv_scale = (
|
| 227 |
+
orig_xyz[active_mask].max(dim=0)[0] - orig_xyz[active_mask].min(dim=0)[0]
|
| 228 |
+
).max()
|
| 229 |
+
norm_scale = orig_scale / inv_scale
|
| 230 |
+
norm_xyz = orig_xyz / inv_scale
|
| 231 |
+
|
| 232 |
+
if outlier_percentile is None:
|
| 233 |
+
lower_bound_xyz = torch.min(norm_xyz[active_mask], dim=0)[0]
|
| 234 |
+
upper_bound_xyz = torch.max(norm_xyz[active_mask], dim=0)[0]
|
| 235 |
+
else:
|
| 236 |
+
lower_bound_xyz = torch.quantile(
|
| 237 |
+
norm_xyz[active_mask],
|
| 238 |
+
outlier_percentile,
|
| 239 |
+
dim=0,
|
| 240 |
+
)
|
| 241 |
+
upper_bound_xyz = torch.quantile(
|
| 242 |
+
norm_xyz[active_mask],
|
| 243 |
+
1.0 - outlier_percentile,
|
| 244 |
+
dim=0,
|
| 245 |
+
)
|
| 246 |
+
|
| 247 |
+
center = (lower_bound_xyz + upper_bound_xyz) / 2
|
| 248 |
+
norm_xyz = norm_xyz - center
|
| 249 |
+
scene_gs.from_xyz(norm_xyz)
|
| 250 |
+
scene_gs.mininum_kernel_size /= inv_scale.item()
|
| 251 |
+
scene_gs.from_scaling(norm_scale)
|
| 252 |
+
return scene_gs
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def make_scene(*outputs, in_place=False):
|
| 256 |
+
if not in_place:
|
| 257 |
+
outputs = [deepcopy(output) for output in outputs]
|
| 258 |
+
|
| 259 |
+
all_outs = []
|
| 260 |
+
minimum_kernel_size = float("inf")
|
| 261 |
+
for output in outputs:
|
| 262 |
+
# move gaussians to scene frame of reference
|
| 263 |
+
PC = SceneVisualizer.object_pointcloud(
|
| 264 |
+
points_local=output["gaussian"][0].get_xyz.unsqueeze(0),
|
| 265 |
+
quat_l2c=output["rotation"],
|
| 266 |
+
trans_l2c=output["translation"],
|
| 267 |
+
scale_l2c=output["scale"],
|
| 268 |
+
)
|
| 269 |
+
output["gaussian"][0].from_xyz(PC.points_list()[0])
|
| 270 |
+
# must ... ROTATE
|
| 271 |
+
output["gaussian"][0].from_rotation(
|
| 272 |
+
quaternion_multiply(
|
| 273 |
+
quaternion_invert(output["rotation"]),
|
| 274 |
+
output["gaussian"][0].get_rotation,
|
| 275 |
+
)
|
| 276 |
+
)
|
| 277 |
+
scale = output["gaussian"][0].get_scaling
|
| 278 |
+
adjusted_scale = scale * output["scale"]
|
| 279 |
+
assert (
|
| 280 |
+
output["scale"][0, 0].item()
|
| 281 |
+
== output["scale"][0, 1].item()
|
| 282 |
+
== output["scale"][0, 2].item()
|
| 283 |
+
)
|
| 284 |
+
output["gaussian"][0].mininum_kernel_size *= output["scale"][0, 0].item()
|
| 285 |
+
adjusted_scale = torch.maximum(
|
| 286 |
+
adjusted_scale,
|
| 287 |
+
torch.tensor(
|
| 288 |
+
output["gaussian"][0].mininum_kernel_size * 1.1,
|
| 289 |
+
device=adjusted_scale.device,
|
| 290 |
+
),
|
| 291 |
+
)
|
| 292 |
+
output["gaussian"][0].from_scaling(adjusted_scale)
|
| 293 |
+
minimum_kernel_size = min(
|
| 294 |
+
minimum_kernel_size,
|
| 295 |
+
output["gaussian"][0].mininum_kernel_size,
|
| 296 |
+
)
|
| 297 |
+
all_outs.append(output)
|
| 298 |
+
|
| 299 |
+
# merge gaussians
|
| 300 |
+
scene_gs = all_outs[0]["gaussian"][0]
|
| 301 |
+
scene_gs.mininum_kernel_size = minimum_kernel_size
|
| 302 |
+
for out in all_outs[1:]:
|
| 303 |
+
out_gs = out["gaussian"][0]
|
| 304 |
+
scene_gs._xyz = torch.cat([scene_gs._xyz, out_gs._xyz], dim=0)
|
| 305 |
+
scene_gs._features_dc = torch.cat(
|
| 306 |
+
[scene_gs._features_dc, out_gs._features_dc], dim=0
|
| 307 |
+
)
|
| 308 |
+
scene_gs._scaling = torch.cat([scene_gs._scaling, out_gs._scaling], dim=0)
|
| 309 |
+
scene_gs._rotation = torch.cat([scene_gs._rotation, out_gs._rotation], dim=0)
|
| 310 |
+
scene_gs._opacity = torch.cat([scene_gs._opacity, out_gs._opacity], dim=0)
|
| 311 |
+
|
| 312 |
+
return scene_gs
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
def check_target(
|
| 316 |
+
target: str,
|
| 317 |
+
whitelist_filters: List[Callable],
|
| 318 |
+
blacklist_filters: List[Callable],
|
| 319 |
+
):
|
| 320 |
+
if any(filt(target) for filt in whitelist_filters):
|
| 321 |
+
if not any(filt(target) for filt in blacklist_filters):
|
| 322 |
+
return
|
| 323 |
+
raise RuntimeError(
|
| 324 |
+
f"target '{target}' is not allowed to be hydra instantiated, if this is a mistake, please do modify the whitelist_filters / blacklist_filters"
|
| 325 |
+
)
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
def check_hydra_safety(
|
| 329 |
+
config: DictConfig,
|
| 330 |
+
whitelist_filters: List[Callable],
|
| 331 |
+
blacklist_filters: List[Callable],
|
| 332 |
+
):
|
| 333 |
+
to_check = [config]
|
| 334 |
+
while len(to_check) > 0:
|
| 335 |
+
node = to_check.pop()
|
| 336 |
+
if isinstance(node, DictConfig):
|
| 337 |
+
to_check.extend(list(node.values()))
|
| 338 |
+
if "_target_" in node:
|
| 339 |
+
check_target(node["_target_"], whitelist_filters, blacklist_filters)
|
| 340 |
+
elif isinstance(node, ListConfig):
|
| 341 |
+
to_check.extend(list(node))
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
def load_image(path):
|
| 345 |
+
image = Image.open(path)
|
| 346 |
+
image = np.array(image)
|
| 347 |
+
image = image.astype(np.uint8)
|
| 348 |
+
return image
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
def load_mask(path):
|
| 352 |
+
mask = load_image(path)
|
| 353 |
+
mask = mask > 0
|
| 354 |
+
if mask.ndim == 3:
|
| 355 |
+
mask = mask[..., -1]
|
| 356 |
+
return mask
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
def load_single_mask(folder_path, index=0, extension=".png"):
|
| 360 |
+
masks = load_masks(folder_path, [index], extension)
|
| 361 |
+
return masks[0]
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
def load_masks(folder_path, indices_list=None, extension=".png"):
|
| 365 |
+
masks = []
|
| 366 |
+
indices_list = [] if indices_list is None else list(indices_list)
|
| 367 |
+
if not len(indices_list) > 0: # get all all masks if not provided
|
| 368 |
+
idx = 0
|
| 369 |
+
while os.path.exists(os.path.join(folder_path, f"{idx}{extension}")):
|
| 370 |
+
indices_list.append(idx)
|
| 371 |
+
idx += 1
|
| 372 |
+
|
| 373 |
+
for idx in indices_list:
|
| 374 |
+
mask_path = os.path.join(folder_path, f"{idx}{extension}")
|
| 375 |
+
assert os.path.exists(mask_path), f"Mask path {mask_path} does not exist"
|
| 376 |
+
mask = load_mask(mask_path)
|
| 377 |
+
masks.append(mask)
|
| 378 |
+
return masks
|
| 379 |
+
|
| 380 |
+
|
| 381 |
+
def display_image(image, masks=None):
|
| 382 |
+
def imshow(image, ax):
|
| 383 |
+
ax.axis("off")
|
| 384 |
+
ax.imshow(image)
|
| 385 |
+
|
| 386 |
+
grid = (1, 1) if masks is None else (2, 2)
|
| 387 |
+
fig, axes = plt.subplots(*grid)
|
| 388 |
+
if masks is not None:
|
| 389 |
+
mask_colors = sns.color_palette("husl", len(masks))
|
| 390 |
+
black_image = np.zeros_like(image[..., :3], dtype=float) # background
|
| 391 |
+
mask_display = np.copy(black_image)
|
| 392 |
+
mask_union = np.zeros_like(image[..., :3])
|
| 393 |
+
for i, mask in enumerate(masks):
|
| 394 |
+
mask_display[mask] = mask_colors[i]
|
| 395 |
+
mask_union |= mask[..., None] if mask.ndim == 2 else mask
|
| 396 |
+
imshow(black_image, axes[0, 1])
|
| 397 |
+
imshow(mask_display, axes[1, 0])
|
| 398 |
+
imshow(image * mask_union, axes[1, 1])
|
| 399 |
+
|
| 400 |
+
image_axe = axes if masks is None else axes[0, 0]
|
| 401 |
+
imshow(image, image_axe)
|
| 402 |
+
|
| 403 |
+
fig.tight_layout(pad=0)
|
| 404 |
+
fig.show()
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
def interactive_visualizer(ply_path):
|
| 408 |
+
with gr.Blocks() as demo:
|
| 409 |
+
gr.Markdown("# 3D Gaussian Splatting (black-screen loading might take a while)")
|
| 410 |
+
gr.Model3D(
|
| 411 |
+
value=ply_path, # splat file
|
| 412 |
+
label="3D Scene",
|
| 413 |
+
)
|
| 414 |
+
demo.launch(share=True)
|
thirdparty/sam3d/sam3d/notebook/mesh_alignment.py
ADDED
|
@@ -0,0 +1,469 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
"""
|
| 3 |
+
SAM 3D Body (3DB) Mesh Alignment Utilities
|
| 4 |
+
Handles alignment of 3DB meshes to SAM 3D Object, same as MoGe point cloud scale.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import math
|
| 9 |
+
import json
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
import trimesh
|
| 13 |
+
from PIL import Image
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
from pytorch3d.structures import Meshes
|
| 16 |
+
from pytorch3d.renderer import PerspectiveCameras, RasterizationSettings, MeshRasterizer, TexturesVertex
|
| 17 |
+
from moge.model.v1 import MoGeModel
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def load_3db_mesh(mesh_path, device='cuda'):
|
| 21 |
+
"""Load 3DB mesh and convert from OpenGL to PyTorch3D coordinates."""
|
| 22 |
+
mesh = trimesh.load(mesh_path)
|
| 23 |
+
vertices = np.array(mesh.vertices)
|
| 24 |
+
faces = np.array(mesh.faces)
|
| 25 |
+
|
| 26 |
+
# Convert from OpenGL to PyTorch3D coordinates
|
| 27 |
+
vertices[:, 0] *= -1 # Flip X
|
| 28 |
+
vertices[:, 2] *= -1 # Flip Z
|
| 29 |
+
|
| 30 |
+
vertices = torch.from_numpy(vertices).float().to(device)
|
| 31 |
+
faces = torch.from_numpy(faces).long().to(device)
|
| 32 |
+
return vertices, faces
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def get_moge_pointcloud(image_tensor, device='cuda'):
|
| 36 |
+
"""Generate MoGe point cloud from image tensor."""
|
| 37 |
+
moge_model = MoGeModel.from_pretrained("Ruicheng/moge-vitl").to(device)
|
| 38 |
+
moge_model.eval()
|
| 39 |
+
with torch.no_grad():
|
| 40 |
+
moge_output = moge_model.infer(image_tensor)
|
| 41 |
+
return moge_output
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def denormalize_intrinsics(norm_K, height, width):
|
| 45 |
+
"""Convert normalized intrinsics to absolute pixel coordinates."""
|
| 46 |
+
cx_norm, cy_norm = norm_K[0, 2], norm_K[1, 2]
|
| 47 |
+
fx_norm, fy_norm = norm_K[0, 0], norm_K[1, 1]
|
| 48 |
+
|
| 49 |
+
fx_abs = fx_norm * width
|
| 50 |
+
fy_abs = fy_norm * height
|
| 51 |
+
cx_abs = cx_norm * width
|
| 52 |
+
cy_abs = cy_norm * height
|
| 53 |
+
fx_abs = fy_abs
|
| 54 |
+
|
| 55 |
+
return np.array([
|
| 56 |
+
[fx_abs, 0.0, cx_abs],
|
| 57 |
+
[0.0, fy_abs, cy_abs],
|
| 58 |
+
[0.0, 0.0, 1.0]
|
| 59 |
+
])
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def crop_mesh_with_mask(vertices, faces, focal_length, mask, device='cuda'):
|
| 63 |
+
"""Crop mesh vertices to only those visible in the mask."""
|
| 64 |
+
textures = TexturesVertex(verts_features=torch.ones_like(vertices)[None])
|
| 65 |
+
mesh = Meshes(verts=[vertices], faces=[faces], textures=textures)
|
| 66 |
+
|
| 67 |
+
H, W = mask.shape[-2:]
|
| 68 |
+
fx = fy = focal_length
|
| 69 |
+
cx, cy = W / 2.0, H / 2.0
|
| 70 |
+
|
| 71 |
+
camera = PerspectiveCameras(
|
| 72 |
+
focal_length=((fx, fy),),
|
| 73 |
+
principal_point=((cx, cy),),
|
| 74 |
+
image_size=((H, W),),
|
| 75 |
+
in_ndc=False, device=device
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
raster_settings = RasterizationSettings(
|
| 79 |
+
image_size=(H, W), blur_radius=0.0, faces_per_pixel=1,
|
| 80 |
+
cull_backfaces=False, bin_size=0,
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
rasterizer = MeshRasterizer(cameras=camera, raster_settings=raster_settings)
|
| 84 |
+
fragments = rasterizer(mesh)
|
| 85 |
+
|
| 86 |
+
face_indices = fragments.pix_to_face[0, ..., 0] # (H, W)
|
| 87 |
+
visible_mask = (mask > 0) & (face_indices >= 0)
|
| 88 |
+
visible_face_ids = face_indices[visible_mask]
|
| 89 |
+
|
| 90 |
+
visible_faces = faces[visible_face_ids]
|
| 91 |
+
visible_vert_ids = torch.unique(visible_faces)
|
| 92 |
+
verts_cropped = vertices[visible_vert_ids]
|
| 93 |
+
|
| 94 |
+
return verts_cropped, visible_mask
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def extract_target_points(pointmap, visible_mask):
|
| 98 |
+
"""Extract target points from MoGe pointmap using visible mask."""
|
| 99 |
+
target_points = pointmap[visible_mask.bool()]
|
| 100 |
+
|
| 101 |
+
# Convert from MoGe coordinates to PyTorch3D coordinates
|
| 102 |
+
target_points[:, 0] *= -1
|
| 103 |
+
target_points[:, 1] *= -1
|
| 104 |
+
|
| 105 |
+
# Remove flying points using adaptive quantile filtering
|
| 106 |
+
z_range = torch.max(target_points[:, 2]) - torch.min(target_points[:, 2])
|
| 107 |
+
if z_range > 6.0:
|
| 108 |
+
thresh = 0.90
|
| 109 |
+
elif z_range > 2.0:
|
| 110 |
+
thresh = 0.93
|
| 111 |
+
else:
|
| 112 |
+
thresh = 0.95
|
| 113 |
+
|
| 114 |
+
depth_quantile = torch.quantile(target_points[:, 2], thresh)
|
| 115 |
+
target_points = target_points[target_points[:, 2] <= depth_quantile]
|
| 116 |
+
|
| 117 |
+
# Remove infinite values
|
| 118 |
+
finite_mask = torch.isfinite(target_points).all(dim=1)
|
| 119 |
+
target_points = target_points[finite_mask]
|
| 120 |
+
|
| 121 |
+
return target_points
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def align_mesh_to_pointcloud(vertices, target_points):
|
| 125 |
+
"""Align mesh vertices to target point cloud using scale and translation."""
|
| 126 |
+
if target_points.shape[0] == 0:
|
| 127 |
+
print("[WARNING] No target points for alignment!")
|
| 128 |
+
return vertices, torch.tensor(1.0), torch.zeros(3)
|
| 129 |
+
|
| 130 |
+
# Scale alignment based on height
|
| 131 |
+
height_src = torch.max(vertices[:, 1]) - torch.min(vertices[:, 1])
|
| 132 |
+
height_tgt = torch.max(target_points[:, 1]) - torch.min(target_points[:, 1])
|
| 133 |
+
scale_factor = height_tgt / height_src
|
| 134 |
+
|
| 135 |
+
vertices_scaled = vertices * scale_factor
|
| 136 |
+
|
| 137 |
+
# Translation alignment based on centers
|
| 138 |
+
center_src = torch.mean(vertices_scaled, dim=0)
|
| 139 |
+
center_tgt = torch.mean(target_points, dim=0)
|
| 140 |
+
translation = center_tgt - center_src
|
| 141 |
+
|
| 142 |
+
vertices_aligned = vertices_scaled + translation
|
| 143 |
+
return vertices_aligned, scale_factor, translation
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def load_mask_for_alignment(mask_path):
|
| 147 |
+
"""Load mask image as numpy array."""
|
| 148 |
+
mask = Image.open(mask_path).convert('L')
|
| 149 |
+
mask_array = np.array(mask) / 255.0
|
| 150 |
+
return mask_array
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def load_focal_length_from_json(json_path):
|
| 154 |
+
"""Load focal length from JSON file."""
|
| 155 |
+
try:
|
| 156 |
+
with open(json_path, 'r') as f:
|
| 157 |
+
data = json.load(f)
|
| 158 |
+
focal_length = data.get('focal_length')
|
| 159 |
+
if focal_length is None:
|
| 160 |
+
raise ValueError("'focal_length' key not found in JSON file")
|
| 161 |
+
print(f"[INFO] Loaded focal length from {json_path}: {focal_length}")
|
| 162 |
+
return focal_length
|
| 163 |
+
except Exception as e:
|
| 164 |
+
print(f"[ERROR] Failed to load focal length from {json_path}: {e}")
|
| 165 |
+
raise
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def process_3db_alignment(mesh_path, mask_path, image_path, device='cuda', focal_length_json_path=None):
|
| 169 |
+
"""Complete pipeline for aligning 3DB mesh to MoGe scale."""
|
| 170 |
+
print(f"[INFO] Processing alignment...")
|
| 171 |
+
|
| 172 |
+
# Load input data
|
| 173 |
+
vertices, faces = load_3db_mesh(mesh_path, device)
|
| 174 |
+
|
| 175 |
+
# Load and preprocess image
|
| 176 |
+
image = Image.open(image_path).convert('RGB')
|
| 177 |
+
image_tensor = torch.from_numpy(np.array(image)).float().permute(2, 0, 1) / 255.0
|
| 178 |
+
image_tensor = image_tensor.to(device)
|
| 179 |
+
|
| 180 |
+
# Load mask and resize to match image
|
| 181 |
+
H, W = image_tensor.shape[1:]
|
| 182 |
+
mask = load_mask_for_alignment(mask_path)
|
| 183 |
+
if mask.shape != (H, W):
|
| 184 |
+
mask = Image.fromarray((mask * 255).astype(np.uint8))
|
| 185 |
+
mask = mask.resize((W, H), Image.NEAREST)
|
| 186 |
+
mask = np.array(mask) / 255.0
|
| 187 |
+
mask = torch.from_numpy(mask).float().to(device)
|
| 188 |
+
|
| 189 |
+
# Generate MoGe point cloud
|
| 190 |
+
print("[INFO] Generating MoGe point cloud...")
|
| 191 |
+
moge_output = get_moge_pointcloud(image_tensor, device)
|
| 192 |
+
|
| 193 |
+
# Load focal length from JSON if provided, otherwise compute from MoGe intrinsics
|
| 194 |
+
if focal_length_json_path is not None:
|
| 195 |
+
focal_length = load_focal_length_from_json(focal_length_json_path)
|
| 196 |
+
else:
|
| 197 |
+
# Compute camera parameters from MoGe intrinsics (fallback)
|
| 198 |
+
intrinsics = denormalize_intrinsics(moge_output['intrinsics'].cpu().numpy(), H, W)
|
| 199 |
+
focal_length = intrinsics[1, 1] # Use fy
|
| 200 |
+
print(f"[INFO] Using computed focal length from MoGe: {focal_length}")
|
| 201 |
+
|
| 202 |
+
# Crop mesh using mask
|
| 203 |
+
print("[INFO] Cropping mesh with mask...")
|
| 204 |
+
verts_cropped, visible_mask = crop_mesh_with_mask(vertices, faces, focal_length, mask, device)
|
| 205 |
+
|
| 206 |
+
# Extract target points from MoGe
|
| 207 |
+
print("[INFO] Extracting target points...")
|
| 208 |
+
target_points = extract_target_points(moge_output['points'], visible_mask)
|
| 209 |
+
|
| 210 |
+
if target_points.shape[0] == 0:
|
| 211 |
+
print("[ERROR] No valid target points found!")
|
| 212 |
+
return None
|
| 213 |
+
|
| 214 |
+
# Perform alignment
|
| 215 |
+
print("[INFO] Aligning mesh to point cloud...")
|
| 216 |
+
aligned_vertices, scale_factor, translation = align_mesh_to_pointcloud(verts_cropped, target_points)
|
| 217 |
+
|
| 218 |
+
# Apply alignment to full mesh
|
| 219 |
+
full_aligned_vertices = (vertices * scale_factor) + translation
|
| 220 |
+
|
| 221 |
+
# Convert back to OpenGL coordinates for final output
|
| 222 |
+
final_vertices_opengl = full_aligned_vertices.cpu().numpy()
|
| 223 |
+
final_vertices_opengl[:, 0] *= -1
|
| 224 |
+
final_vertices_opengl[:, 2] *= -1
|
| 225 |
+
|
| 226 |
+
results = {
|
| 227 |
+
'aligned_vertices_opengl': final_vertices_opengl,
|
| 228 |
+
'faces': faces.cpu().numpy(),
|
| 229 |
+
'scale_factor': scale_factor.item(),
|
| 230 |
+
'translation': translation.cpu().numpy(),
|
| 231 |
+
'focal_length': focal_length,
|
| 232 |
+
'target_points_count': target_points.shape[0],
|
| 233 |
+
'cropped_vertices_count': verts_cropped.shape[0]
|
| 234 |
+
}
|
| 235 |
+
|
| 236 |
+
print(f"[INFO] Alignment completed - Scale: {scale_factor.item():.4f}, Target points: {target_points.shape[0]}")
|
| 237 |
+
return results
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
def process_and_save_alignment(mesh_path, mask_path, image_path, output_dir, device='cuda', focal_length_json_path=None):
|
| 241 |
+
"""
|
| 242 |
+
Complete pipeline for processing 3DB alignment and saving the result.
|
| 243 |
+
|
| 244 |
+
Args:
|
| 245 |
+
mesh_path: Path to input 3DB mesh (.ply)
|
| 246 |
+
mask_path: Path to mask image (.png)
|
| 247 |
+
image_path: Path to input image (.jpg)
|
| 248 |
+
output_dir: Directory to save aligned mesh
|
| 249 |
+
device: Device to use ('cuda' or 'cpu')
|
| 250 |
+
focal_length_json_path: Optional path to focal length JSON file
|
| 251 |
+
|
| 252 |
+
Returns:
|
| 253 |
+
tuple: (success: bool, output_mesh_path: str or None, result_info: dict or None)
|
| 254 |
+
"""
|
| 255 |
+
try:
|
| 256 |
+
print("[INFO] Starting 3DB mesh alignment pipeline...")
|
| 257 |
+
|
| 258 |
+
# Ensure output directory exists
|
| 259 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 260 |
+
|
| 261 |
+
# Process alignment
|
| 262 |
+
result = process_3db_alignment(
|
| 263 |
+
mesh_path=mesh_path,
|
| 264 |
+
mask_path=mask_path,
|
| 265 |
+
image_path=image_path,
|
| 266 |
+
device=device,
|
| 267 |
+
focal_length_json_path=focal_length_json_path
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
+
if result is not None:
|
| 271 |
+
# Save aligned mesh
|
| 272 |
+
output_mesh_path = os.path.join(output_dir, 'human_aligned.ply')
|
| 273 |
+
aligned_mesh = trimesh.Trimesh(
|
| 274 |
+
vertices=result['aligned_vertices_opengl'],
|
| 275 |
+
faces=result['faces']
|
| 276 |
+
)
|
| 277 |
+
aligned_mesh.export(output_mesh_path)
|
| 278 |
+
|
| 279 |
+
print(f" SUCCESS! Saved aligned mesh to: {output_mesh_path}")
|
| 280 |
+
return True, output_mesh_path, result
|
| 281 |
+
else:
|
| 282 |
+
print(" ERROR: Failed to process mesh alignment")
|
| 283 |
+
return False, None, None
|
| 284 |
+
|
| 285 |
+
except Exception as e:
|
| 286 |
+
print(f" ERROR: Exception during processing: {e}")
|
| 287 |
+
import traceback
|
| 288 |
+
traceback.print_exc()
|
| 289 |
+
return False, None, None
|
| 290 |
+
|
| 291 |
+
finally:
|
| 292 |
+
print(" Processing complete!")
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def visualize_meshes_interactive(aligned_mesh_path, dfy_mesh_path, output_dir=None, share=True, height=600):
|
| 296 |
+
"""
|
| 297 |
+
Interactive Gradio-based 3D visualization of aligned human and object meshes.
|
| 298 |
+
|
| 299 |
+
Args:
|
| 300 |
+
aligned_mesh_path: Path to aligned mesh PLY file
|
| 301 |
+
dfy_mesh_path: Path to 3Dfy GLB file
|
| 302 |
+
output_dir: Directory to save combined GLB file (defaults to same dir as aligned_mesh_path)
|
| 303 |
+
share: Whether to create a public shareable link (default: True)
|
| 304 |
+
height: Height of the 3D viewer in pixels (default: 600)
|
| 305 |
+
|
| 306 |
+
Returns:
|
| 307 |
+
tuple: (demo, combined_glb_path) - Gradio demo object and path to combined GLB file
|
| 308 |
+
"""
|
| 309 |
+
import gradio as gr
|
| 310 |
+
|
| 311 |
+
print("Loading meshes for interactive visualization...")
|
| 312 |
+
|
| 313 |
+
try:
|
| 314 |
+
# Load aligned mesh (PLY)
|
| 315 |
+
aligned_mesh = trimesh.load(aligned_mesh_path)
|
| 316 |
+
print(f"Loaded aligned mesh: {len(aligned_mesh.vertices)} vertices")
|
| 317 |
+
|
| 318 |
+
# Load 3Dfy mesh (GLB - handle scene structure)
|
| 319 |
+
dfy_scene = trimesh.load(dfy_mesh_path)
|
| 320 |
+
|
| 321 |
+
if hasattr(dfy_scene, 'dump'): # It's a scene
|
| 322 |
+
dfy_meshes = [geom for geom in dfy_scene.geometry.values() if hasattr(geom, 'vertices')]
|
| 323 |
+
if len(dfy_meshes) == 1:
|
| 324 |
+
dfy_mesh = dfy_meshes[0]
|
| 325 |
+
elif len(dfy_meshes) > 1:
|
| 326 |
+
dfy_mesh = trimesh.util.concatenate(dfy_meshes)
|
| 327 |
+
else:
|
| 328 |
+
raise ValueError("No valid meshes in GLB file")
|
| 329 |
+
else:
|
| 330 |
+
dfy_mesh = dfy_scene
|
| 331 |
+
|
| 332 |
+
print(f"Loaded 3Dfy mesh: {len(dfy_mesh.vertices)} vertices")
|
| 333 |
+
|
| 334 |
+
# Create combined scene
|
| 335 |
+
scene = trimesh.Scene()
|
| 336 |
+
|
| 337 |
+
# Add both meshes with different colors
|
| 338 |
+
aligned_copy = aligned_mesh.copy()
|
| 339 |
+
aligned_copy.visual.vertex_colors = [255, 0, 0, 200] # Red for aligned human
|
| 340 |
+
scene.add_geometry(aligned_copy, node_name="sam3d_aligned_human")
|
| 341 |
+
|
| 342 |
+
dfy_copy = dfy_mesh.copy()
|
| 343 |
+
dfy_copy.visual.vertex_colors = [0, 0, 255, 200] # Blue for 3Dfy object
|
| 344 |
+
scene.add_geometry(dfy_copy, node_name="dfy_object")
|
| 345 |
+
|
| 346 |
+
# Determine output path
|
| 347 |
+
if output_dir is None:
|
| 348 |
+
output_dir = os.path.dirname(aligned_mesh_path)
|
| 349 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 350 |
+
|
| 351 |
+
combined_glb_path = os.path.join(output_dir, 'combined_scene.glb')
|
| 352 |
+
scene.export(combined_glb_path)
|
| 353 |
+
print(f"Exported combined scene to: {combined_glb_path}")
|
| 354 |
+
|
| 355 |
+
# Create interactive Gradio viewer
|
| 356 |
+
with gr.Blocks() as demo:
|
| 357 |
+
gr.Markdown("# 3D Mesh Alignment Visualization")
|
| 358 |
+
gr.Markdown("**Red**: SAM 3D Body Aligned Human | **Blue**: 3Dfy Object")
|
| 359 |
+
gr.Model3D(
|
| 360 |
+
value=combined_glb_path,
|
| 361 |
+
label="Combined 3D Scene (Interactive)",
|
| 362 |
+
height=height
|
| 363 |
+
)
|
| 364 |
+
|
| 365 |
+
# Launch the viewer
|
| 366 |
+
print("Launching interactive 3D viewer...")
|
| 367 |
+
demo.launch(share=share)
|
| 368 |
+
|
| 369 |
+
return demo, combined_glb_path
|
| 370 |
+
|
| 371 |
+
except Exception as e:
|
| 372 |
+
print(f"ERROR in visualization: {e}")
|
| 373 |
+
import traceback
|
| 374 |
+
traceback.print_exc()
|
| 375 |
+
return None, None
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
def visualize_meshes_comparison(aligned_mesh_path, dfy_mesh_path, use_interactive=False):
|
| 379 |
+
"""
|
| 380 |
+
Simple visualization of both meshes in a single 3D plot.
|
| 381 |
+
|
| 382 |
+
DEPRECATED: Use visualize_meshes_interactive() for better interactive visualization.
|
| 383 |
+
|
| 384 |
+
Args:
|
| 385 |
+
aligned_mesh_path: Path to aligned mesh PLY file
|
| 386 |
+
dfy_mesh_path: Path to 3Dfy GLB file
|
| 387 |
+
use_interactive: Whether to attempt trimesh scene viewer (default: False)
|
| 388 |
+
|
| 389 |
+
Returns:
|
| 390 |
+
tuple: (aligned_mesh, dfy_mesh) trimesh objects or (None, None) if failed
|
| 391 |
+
"""
|
| 392 |
+
import matplotlib.pyplot as plt
|
| 393 |
+
|
| 394 |
+
print("Loading meshes for visualization...")
|
| 395 |
+
|
| 396 |
+
try:
|
| 397 |
+
# Load aligned mesh (PLY)
|
| 398 |
+
aligned_mesh = trimesh.load(aligned_mesh_path)
|
| 399 |
+
print(f"Loaded aligned mesh: {len(aligned_mesh.vertices)} vertices")
|
| 400 |
+
|
| 401 |
+
# Load 3Dfy mesh (GLB - handle scene structure)
|
| 402 |
+
dfy_scene = trimesh.load(dfy_mesh_path)
|
| 403 |
+
|
| 404 |
+
if hasattr(dfy_scene, 'dump'): # It's a scene
|
| 405 |
+
dfy_meshes = [geom for geom in dfy_scene.geometry.values() if hasattr(geom, 'vertices')]
|
| 406 |
+
if len(dfy_meshes) == 1:
|
| 407 |
+
dfy_mesh = dfy_meshes[0]
|
| 408 |
+
elif len(dfy_meshes) > 1:
|
| 409 |
+
dfy_mesh = trimesh.util.concatenate(dfy_meshes)
|
| 410 |
+
else:
|
| 411 |
+
raise ValueError("No valid meshes in GLB file")
|
| 412 |
+
else:
|
| 413 |
+
dfy_mesh = dfy_scene
|
| 414 |
+
|
| 415 |
+
print(f"Loaded 3Dfy mesh: {len(dfy_mesh.vertices)} vertices")
|
| 416 |
+
|
| 417 |
+
# Create single 3D plot with both meshes
|
| 418 |
+
fig = plt.figure(figsize=(12, 10))
|
| 419 |
+
ax = fig.add_subplot(111, projection='3d')
|
| 420 |
+
|
| 421 |
+
# Plot both meshes in the same space
|
| 422 |
+
ax.scatter(dfy_mesh.vertices[:, 0],
|
| 423 |
+
dfy_mesh.vertices[:, 1],
|
| 424 |
+
dfy_mesh.vertices[:, 2],
|
| 425 |
+
c='blue', s=0.1, alpha=0.6, label='3Dfy Original')
|
| 426 |
+
|
| 427 |
+
ax.scatter(aligned_mesh.vertices[:, 0],
|
| 428 |
+
aligned_mesh.vertices[:, 1],
|
| 429 |
+
aligned_mesh.vertices[:, 2],
|
| 430 |
+
c='red', s=0.1, alpha=0.6, label='SAM 3D Body Aligned')
|
| 431 |
+
|
| 432 |
+
ax.set_title('Mesh Comparison: 3Dfy vs SAM 3D Body Aligned', fontsize=16, fontweight='bold')
|
| 433 |
+
ax.set_xlabel('X')
|
| 434 |
+
ax.set_ylabel('Y')
|
| 435 |
+
ax.set_zlabel('Z')
|
| 436 |
+
ax.legend()
|
| 437 |
+
|
| 438 |
+
plt.tight_layout()
|
| 439 |
+
plt.show()
|
| 440 |
+
|
| 441 |
+
# Optional trimesh scene viewer
|
| 442 |
+
if use_interactive:
|
| 443 |
+
try:
|
| 444 |
+
print("Creating trimesh scene...")
|
| 445 |
+
scene = trimesh.Scene()
|
| 446 |
+
|
| 447 |
+
# Add both meshes with different colors
|
| 448 |
+
aligned_copy = aligned_mesh.copy()
|
| 449 |
+
aligned_copy.visual.vertex_colors = [255, 0, 0, 200] # Red
|
| 450 |
+
scene.add_geometry(aligned_copy, node_name="sam3d_aligned")
|
| 451 |
+
|
| 452 |
+
dfy_copy = dfy_mesh.copy()
|
| 453 |
+
dfy_copy.visual.vertex_colors = [0, 0, 255, 200] # Blue
|
| 454 |
+
scene.add_geometry(dfy_copy, node_name="dfy_original")
|
| 455 |
+
|
| 456 |
+
print("Opening interactive trimesh viewer...")
|
| 457 |
+
scene.show()
|
| 458 |
+
|
| 459 |
+
except Exception as e:
|
| 460 |
+
print(f"Trimesh viewer not available: {e}")
|
| 461 |
+
|
| 462 |
+
print("Visualization complete")
|
| 463 |
+
return aligned_mesh, dfy_mesh
|
| 464 |
+
|
| 465 |
+
except Exception as e:
|
| 466 |
+
print(f"ERROR in visualization: {e}")
|
| 467 |
+
import traceback
|
| 468 |
+
traceback.print_exc()
|
| 469 |
+
return None, None
|
thirdparty/sam3d/sam3d/patching/hydra
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import hydra
|
| 5 |
+
import urllib.request
|
| 6 |
+
|
| 7 |
+
if hydra.__version__ != "1.3.2":
|
| 8 |
+
raise RuntimeError("different hydra version has been found, cannot patch")
|
| 9 |
+
|
| 10 |
+
hydra_root = os.path.dirname(hydra.__file__)
|
| 11 |
+
utils_path = os.path.join(hydra_root, "core", "utils.py")
|
| 12 |
+
|
| 13 |
+
urllib.request.urlretrieve(
|
| 14 |
+
"https://raw.githubusercontent.com/gleize/hydra/78f00766b5f37672aa7232ebbf01bdd74246bd60/hydra/core/utils.py",
|
| 15 |
+
utils_path,
|
| 16 |
+
)
|
thirdparty/sam3d/sam3d/pyproject.toml
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["hatchling", "hatch-requirements-txt"]
|
| 3 |
+
build-backend = "hatchling.build"
|
| 4 |
+
|
| 5 |
+
[tool.hatch.envs.default.env-vars]
|
| 6 |
+
PIP_EXTRA_INDEX_URL = "https://pypi.ngc.nvidia.com https://download.pytorch.org/whl/cu121"
|
| 7 |
+
|
| 8 |
+
[tool.hatch.metadata]
|
| 9 |
+
# for git-referenced dependencies
|
| 10 |
+
allow-direct-references = true
|
| 11 |
+
|
| 12 |
+
[project]
|
| 13 |
+
name = "sam3d_objects"
|
| 14 |
+
version = "0.0.1"
|
| 15 |
+
# required for "hatch-requirements-txt" to work
|
| 16 |
+
dynamic = ["dependencies", "optional-dependencies"]
|
| 17 |
+
|
| 18 |
+
[tool.hatch.build]
|
| 19 |
+
ignore-vcs = true
|
| 20 |
+
include = ["**/*.py"]
|
| 21 |
+
exclude = ["conftest.py", "*_test.py"]
|
| 22 |
+
packages = ["sam3d_objects"]
|
| 23 |
+
|
| 24 |
+
[tool.hatch.metadata.hooks.requirements_txt]
|
| 25 |
+
files = ["requirements.txt"]
|
| 26 |
+
|
| 27 |
+
[tool.hatch.metadata.hooks.requirements_txt.optional-dependencies]
|
| 28 |
+
p3d = ["requirements.p3d.txt"]
|
| 29 |
+
inference = ["requirements.inference.txt"]
|
| 30 |
+
dev = ["requirements.dev.txt"]
|
thirdparty/sam3d/sam3d/requirements.dev.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
pytest
|
| 2 |
+
findpydeps
|
| 3 |
+
pipdeptree
|
| 4 |
+
lovely_tensors
|
thirdparty/sam3d/sam3d/requirements.inference.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
kaolin==0.17.0
|
| 2 |
+
gsplat @ git+https://github.com/nerfstudio-project/gsplat.git@2323de5905d5e90e035f792fe65bad0fedd413e7
|
| 3 |
+
seaborn==0.13.2
|
| 4 |
+
gradio==5.49.0
|
thirdparty/sam3d/sam3d/requirements.p3d.txt
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
pytorch3d @ git+https://github.com/facebookresearch/pytorch3d.git@75ebeeaea0908c5527e7b1e305fbc7681382db47
|
| 2 |
+
flash_attn==2.8.3
|
thirdparty/sam3d/sam3d/requirements.txt
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
astor==0.8.1
|
| 2 |
+
async-timeout==4.0.3
|
| 3 |
+
auto_gptq==0.7.1
|
| 4 |
+
autoflake==2.3.1
|
| 5 |
+
av==12.0.0
|
| 6 |
+
bitsandbytes==0.43.0
|
| 7 |
+
black==24.3.0
|
| 8 |
+
bpy==4.3.0
|
| 9 |
+
colorama==0.4.6
|
| 10 |
+
conda-pack==0.7.1
|
| 11 |
+
crcmod==1.7
|
| 12 |
+
cuda-python==12.1.0
|
| 13 |
+
dataclasses==0.6
|
| 14 |
+
decord==0.6.0
|
| 15 |
+
deprecation==2.1.0
|
| 16 |
+
easydict==1.13
|
| 17 |
+
einops-exts==0.0.4
|
| 18 |
+
exceptiongroup==1.2.0
|
| 19 |
+
fastavro==1.9.4
|
| 20 |
+
fasteners==0.19
|
| 21 |
+
flake8==7.0.0
|
| 22 |
+
Flask==3.0.3
|
| 23 |
+
fqdn==1.5.1
|
| 24 |
+
ftfy==6.2.0
|
| 25 |
+
fvcore==0.1.5.post20221221
|
| 26 |
+
gdown==5.2.0
|
| 27 |
+
h5py==3.12.1
|
| 28 |
+
hdfs==2.7.3
|
| 29 |
+
httplib2==0.22.0
|
| 30 |
+
hydra-core==1.3.2
|
| 31 |
+
hydra-submitit-launcher==1.2.0
|
| 32 |
+
igraph==0.11.8
|
| 33 |
+
imath==0.0.2
|
| 34 |
+
isoduration==20.11.0
|
| 35 |
+
jsonlines==4.0.0
|
| 36 |
+
jsonpickle==3.0.4
|
| 37 |
+
jsonpointer==2.4
|
| 38 |
+
jupyter==1.1.1
|
| 39 |
+
librosa==0.10.1
|
| 40 |
+
lightning==2.3.3
|
| 41 |
+
loguru==0.7.2
|
| 42 |
+
mosaicml-streaming==0.7.5
|
| 43 |
+
nvidia-cuda-nvcc-cu12==12.1.105
|
| 44 |
+
nvidia-pyindex==1.0.9
|
| 45 |
+
objsize==0.7.0
|
| 46 |
+
open3d==0.18.0
|
| 47 |
+
opencv-python==4.9.0.80
|
| 48 |
+
OpenEXR==3.3.3
|
| 49 |
+
optimum==1.18.1
|
| 50 |
+
optree==0.14.1
|
| 51 |
+
orjson==3.10.0
|
| 52 |
+
panda3d-gltf==1.2.1
|
| 53 |
+
pdoc3==0.10.0
|
| 54 |
+
peft==0.10.0
|
| 55 |
+
pip-system-certs==4.0
|
| 56 |
+
point-cloud-utils==0.29.5
|
| 57 |
+
polyscope==2.3.0
|
| 58 |
+
pycocotools==2.0.7
|
| 59 |
+
pydot==1.4.2
|
| 60 |
+
pymeshfix==0.17.0
|
| 61 |
+
pymongo==4.6.3
|
| 62 |
+
pyrender==0.1.45
|
| 63 |
+
PySocks==1.7.1
|
| 64 |
+
pytest==8.1.1
|
| 65 |
+
python-pycg==0.9.2
|
| 66 |
+
randomname==0.2.1
|
| 67 |
+
roma==1.5.1
|
| 68 |
+
rootutils==1.0.7
|
| 69 |
+
Rtree==1.3.0
|
| 70 |
+
sagemaker==2.242.0
|
| 71 |
+
scikit-image==0.23.1
|
| 72 |
+
sentence-transformers==2.6.1
|
| 73 |
+
simplejson==3.19.2
|
| 74 |
+
smplx==0.1.28
|
| 75 |
+
spconv-cu121==2.3.8
|
| 76 |
+
tensorboard==2.16.2
|
| 77 |
+
timm==0.9.16
|
| 78 |
+
tomli==2.0.1
|
| 79 |
+
torchaudio==2.5.1+cu121
|
| 80 |
+
uri-template==1.3.0
|
| 81 |
+
usort==1.0.8.post1
|
| 82 |
+
wandb==0.20.0
|
| 83 |
+
webcolors==1.13
|
| 84 |
+
webdataset==0.2.86
|
| 85 |
+
Werkzeug==3.0.6
|
| 86 |
+
xatlas==0.0.9
|
| 87 |
+
xformers==0.0.28.post3
|
| 88 |
+
MoGe @ git+https://github.com/microsoft/MoGe.git@a8c37341bc0325ca99b9d57981cc3bb2bd3e255b
|
thirdparty/sam3d/sam3d/sam3d_objects/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
# Allow skipping initialization for lightweight tools
|
| 5 |
+
if not os.environ.get('LIDRA_SKIP_INIT'):
|
| 6 |
+
import sam3d_objects.init
|
thirdparty/sam3d/sam3d/sam3d_objects/config/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
thirdparty/sam3d/sam3d/sam3d_objects/config/utils.py
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
import functools
|
| 3 |
+
from typing import Any, Callable, Union
|
| 4 |
+
|
| 5 |
+
from omegaconf import DictConfig, ListConfig, OmegaConf
|
| 6 |
+
from hydra.utils import instantiate
|
| 7 |
+
|
| 8 |
+
TargetType = Union[str, type, Callable[..., Any]]
|
| 9 |
+
ClassOrCallableType = Union[type, Callable[..., Any]]
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def dump_config(config: DictConfig, path: str = "./config.yaml"):
|
| 13 |
+
txt = OmegaConf.to_yaml(config, sort_keys=True)
|
| 14 |
+
with open(path, "w") as f:
|
| 15 |
+
f.write(txt)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def locate(path: str) -> Any:
|
| 19 |
+
if path == "":
|
| 20 |
+
raise ImportError("Empty path")
|
| 21 |
+
|
| 22 |
+
import builtins
|
| 23 |
+
from importlib import import_module
|
| 24 |
+
|
| 25 |
+
parts = [part for part in path.split(".") if part]
|
| 26 |
+
|
| 27 |
+
# load module part
|
| 28 |
+
module = None
|
| 29 |
+
for n in reversed(range(len(parts))):
|
| 30 |
+
try:
|
| 31 |
+
mod = ".".join(parts[:n])
|
| 32 |
+
module = import_module(mod)
|
| 33 |
+
except Exception as e:
|
| 34 |
+
if n == 0:
|
| 35 |
+
raise ImportError(f"Error loading module '{path}'") from e
|
| 36 |
+
continue
|
| 37 |
+
if module:
|
| 38 |
+
break
|
| 39 |
+
|
| 40 |
+
if module:
|
| 41 |
+
obj = module
|
| 42 |
+
else:
|
| 43 |
+
obj = builtins
|
| 44 |
+
|
| 45 |
+
# load object path in module
|
| 46 |
+
for part in parts[n:]:
|
| 47 |
+
mod = mod + "." + part
|
| 48 |
+
if not hasattr(obj, part):
|
| 49 |
+
try:
|
| 50 |
+
import_module(mod)
|
| 51 |
+
except Exception as e:
|
| 52 |
+
raise ImportError(
|
| 53 |
+
f"Encountered error: `{e}` when loading module '{path}'"
|
| 54 |
+
) from e
|
| 55 |
+
obj = getattr(obj, part)
|
| 56 |
+
|
| 57 |
+
return obj
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def full_instance_name(instance: Any) -> str:
|
| 61 |
+
return full_class_name(instance.__class__)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def full_class_name(klass: Any) -> str:
|
| 65 |
+
module = klass.__module__
|
| 66 |
+
if module == "builtins":
|
| 67 |
+
return klass.__qualname__ # avoid outputs like 'builtins.str'
|
| 68 |
+
return module + "." + klass.__qualname__
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def ensure_is_subclass(child_class: type, parent_class: type) -> None:
|
| 72 |
+
if not issubclass(child_class, parent_class):
|
| 73 |
+
raise RuntimeError(
|
| 74 |
+
f"class {full_class_name(child_class)} should be a subclass of {full_class_name(parent_class)}"
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def find_class_or_callable_from_target(
|
| 79 |
+
target: TargetType,
|
| 80 |
+
) -> ClassOrCallableType:
|
| 81 |
+
if isinstance(target, str):
|
| 82 |
+
obj = locate(target)
|
| 83 |
+
else:
|
| 84 |
+
obj = target
|
| 85 |
+
|
| 86 |
+
if (not isinstance(obj, type)) and (not callable(obj)):
|
| 87 |
+
raise ValueError(f"Invalid type ({type(obj)}) found for {target}")
|
| 88 |
+
|
| 89 |
+
return obj
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def find_and_ensure_is_subclass(target: TargetType, type_: type) -> ClassOrCallableType:
|
| 93 |
+
klass = find_class_or_callable_from_target(target)
|
| 94 |
+
ensure_is_subclass(klass, type_)
|
| 95 |
+
return klass
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class StrictPartial:
|
| 99 |
+
# remark : the `/` will handle the `path` argument name conflict (e.g. calling StrictPartial("a.b.c", ..., path="/a/b/c"))
|
| 100 |
+
def __init__(self, path, /, *args, **kwargs):
|
| 101 |
+
class_or_callable = find_class_or_callable_from_target(path)
|
| 102 |
+
self._partial = functools.partial(class_or_callable, *args, **kwargs)
|
| 103 |
+
|
| 104 |
+
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
| 105 |
+
return self._partial(*args, **kwargs)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
class RecursivePartial:
|
| 109 |
+
@staticmethod
|
| 110 |
+
def replace_keys(config, key_mapping):
|
| 111 |
+
def recurse(data):
|
| 112 |
+
if isinstance(data, DictConfig):
|
| 113 |
+
new_data = {
|
| 114 |
+
key_mapping[k] if k in key_mapping else k: recurse(v)
|
| 115 |
+
for k, v in data.items()
|
| 116 |
+
}
|
| 117 |
+
new_data = DictConfig(new_data)
|
| 118 |
+
elif isinstance(data, ListConfig):
|
| 119 |
+
new_data = ListConfig([recurse(item) for item in data])
|
| 120 |
+
elif type(data) in {bool, str, int, float, type(None)}:
|
| 121 |
+
new_data = data
|
| 122 |
+
else:
|
| 123 |
+
raise RuntimeError(f"unknow type found : {type(data)}")
|
| 124 |
+
|
| 125 |
+
return new_data
|
| 126 |
+
|
| 127 |
+
return recurse(config)
|
| 128 |
+
|
| 129 |
+
def __init__(self, config):
|
| 130 |
+
self.config = RecursivePartial.replace_keys(
|
| 131 |
+
config, {"_rpartial_target_": "_target_"}
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
| 135 |
+
return instantiate(self.config)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
class Partial(StrictPartial):
|
| 139 |
+
# remark : allow `path` argument to be exposed for easier use
|
| 140 |
+
def __init__(self, path, *args, **kwargs):
|
| 141 |
+
super().__init__(path, *args, **kwargs)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def subkey(mapping, key):
|
| 145 |
+
return mapping[key]
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def make_set(*args):
|
| 149 |
+
return set(args)
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def make_tuple(*args):
|
| 153 |
+
return tuple(args)
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def make_list_from_kwargs(**kwargs):
|
| 157 |
+
# Filter out None/null values to avoid issues with callbacks
|
| 158 |
+
return [v for v in kwargs.values() if v is not None]
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def make_string(value):
|
| 162 |
+
return str(value)
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def make_dict(**kwargs):
|
| 166 |
+
return dict(kwargs)
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def get_item(data, key: str):
|
| 170 |
+
return data[key]
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def get_attr(data, key: str):
|
| 174 |
+
return getattr(data, key)
|
thirdparty/sam3d/sam3d/sam3d_objects/data/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
thirdparty/sam3d/sam3d/sam3d_objects/data/dataset/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
thirdparty/sam3d/sam3d/sam3d_objects/data/dataset/tdfy/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
thirdparty/sam3d/sam3d/sam3d_objects/data/dataset/tdfy/img_and_mask_transforms.py
ADDED
|
@@ -0,0 +1,986 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
from collections import namedtuple
|
| 3 |
+
import random
|
| 4 |
+
from typing import Optional, Dict
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import matplotlib.pyplot as plt
|
| 8 |
+
import torchvision.transforms.functional
|
| 9 |
+
from sam3d_objects.data.dataset.tdfy.img_processing import pad_to_square_centered
|
| 10 |
+
from sam3d_objects.model.backbone.dit.embedder.point_remapper import PointRemapper
|
| 11 |
+
from typing import Optional, Dict
|
| 12 |
+
from loguru import logger
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
import torchvision
|
| 16 |
+
import torchvision.transforms as tv_transforms
|
| 17 |
+
import torchvision.transforms.functional
|
| 18 |
+
import torchvision.transforms.functional as TF
|
| 19 |
+
|
| 20 |
+
from sam3d_objects.data.dataset.tdfy.img_processing import pad_to_square_centered
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def UNNORMALIZE(mean, std):
|
| 24 |
+
mean = torch.tensor(mean).reshape((3, 1, 1))
|
| 25 |
+
std = torch.tensor(std).reshape((3, 1, 1))
|
| 26 |
+
|
| 27 |
+
def unnormalize_img(img):
|
| 28 |
+
assert img.ndim == 3 and img.shape[0] == 3
|
| 29 |
+
|
| 30 |
+
return img * std.to(img.device) + mean.to(img.device)
|
| 31 |
+
|
| 32 |
+
return unnormalize_img
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
IMAGENET_MEAN = (0.485, 0.456, 0.406)
|
| 36 |
+
IMAGENET_STD = (0.229, 0.224, 0.225)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
IMAGENET_NORMALIZATION = tv_transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)
|
| 40 |
+
IMAGENET_UNNORMALIZATION = UNNORMALIZE(IMAGENET_MEAN, IMAGENET_STD)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class BoundingBoxError(Exception):
|
| 44 |
+
pass
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def check_bounding_box(bbox_w, bbox_h):
|
| 48 |
+
if bbox_w < 2 or bbox_h < 2:
|
| 49 |
+
raise BoundingBoxError("Bounding box dimensions must be at least 2x2.")
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class RGBAImageProcessor:
|
| 53 |
+
def __init__(
|
| 54 |
+
self,
|
| 55 |
+
resize_and_make_square_kwargs: Optional[Dict] = None,
|
| 56 |
+
object_crop_kwargs: Optional[Dict] = None,
|
| 57 |
+
remove_background: bool = False,
|
| 58 |
+
imagenet_normalization: bool = False,
|
| 59 |
+
):
|
| 60 |
+
self.remove_background = remove_background
|
| 61 |
+
self.resize_and_pad_kwargs = resize_and_make_square_kwargs
|
| 62 |
+
self.object_crop_kwargs = object_crop_kwargs
|
| 63 |
+
self.imagenet_normalization = imagenet_normalization
|
| 64 |
+
if resize_and_make_square_kwargs is not None:
|
| 65 |
+
self.transforms = resize_and_make_square(**resize_and_make_square_kwargs)
|
| 66 |
+
|
| 67 |
+
def __call__(
|
| 68 |
+
self, image: torch.Tensor, mask: Optional[torch.Tensor] = None
|
| 69 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 70 |
+
if mask is None:
|
| 71 |
+
assert (
|
| 72 |
+
image.shape[0] == 4
|
| 73 |
+
), f"Requires 4 channels (RGB + alpha), got {image.shape[0]=}"
|
| 74 |
+
image, mask = split_rgba(image)
|
| 75 |
+
else:
|
| 76 |
+
assert (
|
| 77 |
+
image.shape[0] == 3
|
| 78 |
+
), f"Requires 3 channels (RGB), got {image.shape[0]=}"
|
| 79 |
+
assert mask.dim() == 2, f"Requires 2D mask, got {mask.dim()=}"
|
| 80 |
+
|
| 81 |
+
if not self.object_crop_kwargs in [None, False]:
|
| 82 |
+
image, mask = crop_around_mask_with_padding(
|
| 83 |
+
image, mask, **self.object_crop_kwargs
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
if self.remove_background:
|
| 87 |
+
image, mask = rembg(image, mask)
|
| 88 |
+
|
| 89 |
+
image = self.transforms["img_transform"](image)
|
| 90 |
+
mask = self.transforms["mask_transform"](mask.unsqueeze(0))
|
| 91 |
+
|
| 92 |
+
if self.imagenet_normalization:
|
| 93 |
+
image = IMAGENET_NORMALIZATION(image)
|
| 94 |
+
return image, mask
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def load_rgb(fpath: str) -> torch.Tensor:
|
| 98 |
+
"""
|
| 99 |
+
Load a RGB(A) image from a file path.
|
| 100 |
+
"""
|
| 101 |
+
image = plt.imread(fpath) # Why use matplotlib?
|
| 102 |
+
if image.dtype == "uint8":
|
| 103 |
+
image = image / 255
|
| 104 |
+
image = image.astype(np.float32)
|
| 105 |
+
image = torch.from_numpy(image)
|
| 106 |
+
image = image.permute(2, 0, 1).contiguous()
|
| 107 |
+
return image
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def concat_rgba(
|
| 111 |
+
rgb_image: torch.Tensor,
|
| 112 |
+
mask: torch.Tensor,
|
| 113 |
+
) -> torch.Tensor:
|
| 114 |
+
"""
|
| 115 |
+
Create a 4-channel RGBA image from a 3-channel RGB image and a mask.
|
| 116 |
+
"""
|
| 117 |
+
assert rgb_image.dim() == 3, f"{rgb_image.shape=}"
|
| 118 |
+
assert mask.dim() == 2, f"{mask.shape=}"
|
| 119 |
+
assert rgb_image.shape[0] == 3, f"{rgb_image.shape[0]=}"
|
| 120 |
+
assert rgb_image.shape[1:] == mask.shape, f"{rgb_image.shape[1:]=} != {mask.shape=}"
|
| 121 |
+
return torch.cat((rgb_image, mask[None, ...]), dim=0)
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def split_rgba(rgba_image: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
| 125 |
+
"""
|
| 126 |
+
Split a 4-channel RGBA image into a 3-channel RGB image and a 1-channel mask.
|
| 127 |
+
|
| 128 |
+
Args:
|
| 129 |
+
rgba_image: A 4-channel RGBA image.
|
| 130 |
+
|
| 131 |
+
Returns:
|
| 132 |
+
A tuple of (rgb_image, mask).
|
| 133 |
+
"""
|
| 134 |
+
assert rgba_image.dim() == 3, f"{rgba_image.shape=}"
|
| 135 |
+
assert rgba_image.shape[0] == 4, f"{rgba_image.shape[0]=}"
|
| 136 |
+
return rgba_image[:3], rgba_image[3]
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def get_mask(
|
| 140 |
+
rgb_image: torch.Tensor,
|
| 141 |
+
depth_image: torch.Tensor,
|
| 142 |
+
mask_source: str,
|
| 143 |
+
) -> torch.Tensor:
|
| 144 |
+
"""
|
| 145 |
+
Extract a mask from either the alpha channel of an RGB image or a depth image.
|
| 146 |
+
|
| 147 |
+
Args:
|
| 148 |
+
rgb_image: Tensor of shape (B, C, H, W) or (C, H, W) where C >= 4 if using alpha channel
|
| 149 |
+
depth_image: Tensor of shape (B, 1, H, W) or (1, H, W) containing depth information
|
| 150 |
+
mask_source: Source of the mask, either "ALPHA_CHANNEL" or "DEPTH"
|
| 151 |
+
|
| 152 |
+
Returns:
|
| 153 |
+
mask: Tensor of shape (B, 1, H, W) or (1, H, W) containing the extracted mask
|
| 154 |
+
"""
|
| 155 |
+
# Handle unbatched inputs (add batch dimension if needed)
|
| 156 |
+
is_batched = len(rgb_image.shape) == 4
|
| 157 |
+
|
| 158 |
+
if not is_batched:
|
| 159 |
+
rgb_image = rgb_image.unsqueeze(0)
|
| 160 |
+
if depth_image is not None:
|
| 161 |
+
depth_image = depth_image.unsqueeze(0)
|
| 162 |
+
|
| 163 |
+
if mask_source == "ALPHA_CHANNEL":
|
| 164 |
+
if rgb_image.shape[1] != 4:
|
| 165 |
+
logger.warning(f"No ALPHA CHANNEL for the image, cannot read mask.")
|
| 166 |
+
mask = None
|
| 167 |
+
else:
|
| 168 |
+
mask = rgb_image[:, 3:4, :, :]
|
| 169 |
+
elif mask_source == "DEPTH":
|
| 170 |
+
mask = depth_image
|
| 171 |
+
else:
|
| 172 |
+
raise ValueError(f"Invalid mask source: {mask_source}")
|
| 173 |
+
|
| 174 |
+
# Remove batch dimension if input was unbatched
|
| 175 |
+
if not is_batched:
|
| 176 |
+
mask = mask.squeeze(0)
|
| 177 |
+
|
| 178 |
+
return mask
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def rembg(image, mask, pointmap=None):
|
| 182 |
+
"""
|
| 183 |
+
Remove the background from an image using a mask.
|
| 184 |
+
For pointmaps, sets background regions to NaN.
|
| 185 |
+
|
| 186 |
+
This function follows the standard transform pattern:
|
| 187 |
+
- If called with (image, mask), returns (image, mask)
|
| 188 |
+
- If called with (image, mask, pointmap), returns (image, mask, pointmap)
|
| 189 |
+
"""
|
| 190 |
+
masked_image = image * mask
|
| 191 |
+
|
| 192 |
+
if pointmap is not None:
|
| 193 |
+
masked_pointmap = torch.where(mask > 0, pointmap, torch.nan)
|
| 194 |
+
return masked_image, mask, masked_pointmap
|
| 195 |
+
|
| 196 |
+
return masked_image, mask
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def resize_and_make_square(
|
| 200 |
+
img_size: int,
|
| 201 |
+
make_square: bool | str = False,
|
| 202 |
+
):
|
| 203 |
+
"""
|
| 204 |
+
Create image and mask transforms based on configuration.
|
| 205 |
+
|
| 206 |
+
Returns:
|
| 207 |
+
dict: {"img_transform": img_transform, "mask_transform": mask_transform}
|
| 208 |
+
"""
|
| 209 |
+
if isinstance(make_square, str):
|
| 210 |
+
make_square = make_square.lower()
|
| 211 |
+
assert make_square in ["pad", "crop", False]
|
| 212 |
+
pre_resize_transform = tv_transforms.Lambda(lambda x: x)
|
| 213 |
+
post_resize_transform = tv_transforms.Lambda(lambda x: x)
|
| 214 |
+
if make_square == "pad":
|
| 215 |
+
pre_resize_transform = pad_to_square_centered
|
| 216 |
+
elif make_square == "crop":
|
| 217 |
+
post_resize_transform = tv_transforms.CenterCrop(img_size)
|
| 218 |
+
|
| 219 |
+
img_resize = tv_transforms.Resize(img_size)
|
| 220 |
+
mask_resize = tv_transforms.Resize(
|
| 221 |
+
img_size,
|
| 222 |
+
interpolation=tv_transforms.InterpolationMode.BILINEAR,
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
img_transform = tv_transforms.Compose(
|
| 226 |
+
[
|
| 227 |
+
pre_resize_transform,
|
| 228 |
+
img_resize,
|
| 229 |
+
post_resize_transform,
|
| 230 |
+
]
|
| 231 |
+
)
|
| 232 |
+
|
| 233 |
+
mask_transform = tv_transforms.Compose(
|
| 234 |
+
[
|
| 235 |
+
pre_resize_transform,
|
| 236 |
+
mask_resize,
|
| 237 |
+
post_resize_transform,
|
| 238 |
+
]
|
| 239 |
+
)
|
| 240 |
+
|
| 241 |
+
return {
|
| 242 |
+
"img_transform": img_transform,
|
| 243 |
+
"mask_transform": mask_transform,
|
| 244 |
+
}
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
def crop_around_mask_with_random_box_size_factor(
|
| 248 |
+
loaded_image: torch.Tensor,
|
| 249 |
+
mask: torch.Tensor,
|
| 250 |
+
random_box_size_factor: float = 1.0,
|
| 251 |
+
pointmap: Optional[torch.Tensor] = None,
|
| 252 |
+
) -> np.ndarray:
|
| 253 |
+
return crop_around_mask_with_padding(
|
| 254 |
+
loaded_image,
|
| 255 |
+
mask,
|
| 256 |
+
box_size_factor=1.0 + random.uniform(0, 1) * random_box_size_factor,
|
| 257 |
+
padding_factor=0.0,
|
| 258 |
+
pointmap=pointmap,
|
| 259 |
+
)
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
def crop_around_mask_with_padding(
|
| 263 |
+
loaded_image: torch.Tensor,
|
| 264 |
+
mask: torch.Tensor,
|
| 265 |
+
box_size_factor: float = 1.6,
|
| 266 |
+
padding_factor: float = 0.1,
|
| 267 |
+
pointmap: Optional[torch.Tensor] = None,
|
| 268 |
+
) -> np.ndarray:
|
| 269 |
+
# cast to ensure the function can be called normally
|
| 270 |
+
cast_mask = False
|
| 271 |
+
if mask.dim() == 3:
|
| 272 |
+
assert mask.shape[0] == 1, "cannot take mask with channel dimension not 1"
|
| 273 |
+
mask = mask[0]
|
| 274 |
+
cast_mask = True
|
| 275 |
+
loaded_image = concat_rgba(loaded_image, mask)
|
| 276 |
+
|
| 277 |
+
bbox = compute_mask_bbox(mask, box_size_factor)
|
| 278 |
+
loaded_image = torchvision.transforms.functional.crop(
|
| 279 |
+
loaded_image, bbox[1], bbox[0], bbox[3] - bbox[1], bbox[2] - bbox[0]
|
| 280 |
+
)
|
| 281 |
+
|
| 282 |
+
# Crop pointmap if provided
|
| 283 |
+
if pointmap is not None:
|
| 284 |
+
pointmap = torchvision.transforms.functional.crop(
|
| 285 |
+
pointmap, bbox[1], bbox[0], bbox[3] - bbox[1], bbox[2] - bbox[0]
|
| 286 |
+
)
|
| 287 |
+
|
| 288 |
+
C, H, W = loaded_image.shape
|
| 289 |
+
max_dim = max(H, W) # Get the larger dimension
|
| 290 |
+
|
| 291 |
+
# Step 1: Pad to square shape
|
| 292 |
+
pad_h = (max_dim - H) // 2
|
| 293 |
+
pad_w = (max_dim - W) // 2
|
| 294 |
+
pad_h_extra = (max_dim - H) - pad_h # To ensure even padding
|
| 295 |
+
pad_w_extra = (max_dim - W) - pad_w
|
| 296 |
+
|
| 297 |
+
loaded_image = torch.nn.functional.pad(
|
| 298 |
+
loaded_image, (pad_w, pad_w_extra, pad_h, pad_h_extra), mode="constant", value=0
|
| 299 |
+
)
|
| 300 |
+
if pointmap is not None:
|
| 301 |
+
pointmap = torch.nn.functional.pad(
|
| 302 |
+
pointmap,
|
| 303 |
+
(pad_w, pad_w_extra, pad_h, pad_h_extra),
|
| 304 |
+
mode="constant",
|
| 305 |
+
value=float("nan"),
|
| 306 |
+
)
|
| 307 |
+
|
| 308 |
+
# Step 2: Extend by 10% on each side; idk but this seems to have better results overall
|
| 309 |
+
if padding_factor > 0:
|
| 310 |
+
extend_size = int(max_dim * padding_factor) # 10% extension on each side
|
| 311 |
+
loaded_image = torch.nn.functional.pad(
|
| 312 |
+
loaded_image,
|
| 313 |
+
(extend_size, extend_size, extend_size, extend_size),
|
| 314 |
+
mode="constant",
|
| 315 |
+
value=0,
|
| 316 |
+
)
|
| 317 |
+
|
| 318 |
+
if pointmap is not None:
|
| 319 |
+
pointmap = torch.nn.functional.pad(
|
| 320 |
+
pointmap,
|
| 321 |
+
(extend_size, extend_size, extend_size, extend_size),
|
| 322 |
+
mode="constant",
|
| 323 |
+
value=float("nan"),
|
| 324 |
+
)
|
| 325 |
+
|
| 326 |
+
rgb_image, mask = split_rgba(loaded_image)
|
| 327 |
+
if cast_mask:
|
| 328 |
+
mask = mask[None]
|
| 329 |
+
|
| 330 |
+
if pointmap is not None:
|
| 331 |
+
return rgb_image, mask, pointmap
|
| 332 |
+
return rgb_image, mask
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
def compute_mask_bbox(
|
| 336 |
+
mask: torch.Tensor, box_size_factor: float = 1.0
|
| 337 |
+
) -> tuple[float, float, float, float]:
|
| 338 |
+
"""
|
| 339 |
+
Compute a bounding box around a binary mask with optional size adjustment.
|
| 340 |
+
|
| 341 |
+
Args:
|
| 342 |
+
mask: A 2D binary tensor where non-zero values represent the object of interest.
|
| 343 |
+
box_size_factor: Factor to scale the bounding box size. Values > 1.0 create a larger box.
|
| 344 |
+
Default is 1.0 (tight bounding box).
|
| 345 |
+
|
| 346 |
+
Returns:
|
| 347 |
+
A tuple of (x1, y1, x2, y2) coordinates representing the bounding box,
|
| 348 |
+
where (x1, y1) is the top-left corner and (x2, y2) is the bottom-right corner.
|
| 349 |
+
|
| 350 |
+
Raises:
|
| 351 |
+
ValueError: If mask is not a torch.Tensor or not a 2D tensor.
|
| 352 |
+
"""
|
| 353 |
+
if not isinstance(mask, torch.Tensor):
|
| 354 |
+
raise ValueError("Mask must be a torch.Tensor")
|
| 355 |
+
if not mask.dim() == 2:
|
| 356 |
+
raise ValueError("Mask must be a 2D tensor")
|
| 357 |
+
bbox_indices = torch.nonzero(mask)
|
| 358 |
+
if bbox_indices.numel() == 0:
|
| 359 |
+
# Handle empty mask case
|
| 360 |
+
return (0, 0, 0, 0)
|
| 361 |
+
|
| 362 |
+
y_indices = bbox_indices[:, 0]
|
| 363 |
+
x_indices = bbox_indices[:, 1]
|
| 364 |
+
|
| 365 |
+
min_x = torch.min(x_indices).item()
|
| 366 |
+
min_y = torch.min(y_indices).item()
|
| 367 |
+
max_x = torch.max(x_indices).item()
|
| 368 |
+
max_y = torch.max(y_indices).item()
|
| 369 |
+
|
| 370 |
+
bbox = (min_x, min_y, max_x, max_y)
|
| 371 |
+
|
| 372 |
+
center_x = (bbox[0] + bbox[2]) / 2
|
| 373 |
+
center_y = (bbox[1] + bbox[3]) / 2
|
| 374 |
+
|
| 375 |
+
bbox_w, bbox_h = bbox[2] - bbox[0], bbox[3] - bbox[1]
|
| 376 |
+
|
| 377 |
+
check_bounding_box(bbox_w, bbox_h)
|
| 378 |
+
|
| 379 |
+
size = max(bbox_w, bbox_h, 2)
|
| 380 |
+
size = int(size * box_size_factor)
|
| 381 |
+
|
| 382 |
+
bbox = (
|
| 383 |
+
int(center_x - size // 2),
|
| 384 |
+
int(center_y - size // 2),
|
| 385 |
+
int(center_x + size // 2),
|
| 386 |
+
int(center_y + size // 2),
|
| 387 |
+
)
|
| 388 |
+
# bbox = tuple(map(int, bbox))
|
| 389 |
+
return bbox
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
def crop_and_pad(image, bbox):
|
| 393 |
+
"""
|
| 394 |
+
Crop an image using a bounding box and pad with zeros if out of bounds.
|
| 395 |
+
|
| 396 |
+
Args:
|
| 397 |
+
image (torch.Tensor): CxHxW image.
|
| 398 |
+
bbox (tuple): (x1, y1, x2, y2) bounding box.
|
| 399 |
+
|
| 400 |
+
Returns:
|
| 401 |
+
torch.Tensor: Cropped and zero-padded image.
|
| 402 |
+
"""
|
| 403 |
+
C, H, W = image.shape
|
| 404 |
+
x1, y1, x2, y2 = bbox
|
| 405 |
+
|
| 406 |
+
# Ensure coordinates are integers
|
| 407 |
+
x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
|
| 408 |
+
|
| 409 |
+
# Compute cropping coordinates
|
| 410 |
+
x1_pad, y1_pad = max(0, -x1), max(0, -y1)
|
| 411 |
+
x2_pad, y2_pad = max(0, x2 - W), max(0, y2 - H)
|
| 412 |
+
|
| 413 |
+
# Compute valid region in the original image
|
| 414 |
+
x1_crop, y1_crop = max(0, x1), max(0, y1)
|
| 415 |
+
x2_crop, y2_crop = min(W, x2), min(H, y2)
|
| 416 |
+
|
| 417 |
+
# Extract the valid part
|
| 418 |
+
cropped = image[:, y1_crop:y2_crop, x1_crop:x2_crop]
|
| 419 |
+
|
| 420 |
+
# Create a zero-padded output
|
| 421 |
+
padded = torch.zeros((C, y2 - y1, x2 - x1), dtype=image.dtype)
|
| 422 |
+
|
| 423 |
+
# Place the cropped image into the zero-padded array
|
| 424 |
+
padded[
|
| 425 |
+
:, y1_pad : y1_pad + cropped.shape[1], x1_pad : x1_pad + cropped.shape[2]
|
| 426 |
+
] = cropped
|
| 427 |
+
|
| 428 |
+
return padded
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
def resize_all_to_same_size(
|
| 432 |
+
rgb_image: torch.Tensor,
|
| 433 |
+
mask: torch.Tensor,
|
| 434 |
+
pointmap: Optional[torch.Tensor] = None,
|
| 435 |
+
target_size: Optional[tuple[int, int]] = None,
|
| 436 |
+
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
| 437 |
+
"""
|
| 438 |
+
Resize RGB image, mask, and pointmap to the same size.
|
| 439 |
+
|
| 440 |
+
This is crucial when pointmaps have different resolution than RGB images,
|
| 441 |
+
which must be done BEFORE any cropping operations.
|
| 442 |
+
|
| 443 |
+
Args:
|
| 444 |
+
rgb_image: RGB image tensor of shape (C, H, W)
|
| 445 |
+
mask: Mask tensor of shape (H, W) or (1, H, W)
|
| 446 |
+
pointmap: Optional pointmap tensor of shape (C_p, H_p, W_p)
|
| 447 |
+
target_size: Target size as (H, W). If None, uses RGB image size.
|
| 448 |
+
|
| 449 |
+
Returns:
|
| 450 |
+
Tuple of (resized_rgb, resized_mask, resized_pointmap)
|
| 451 |
+
"""
|
| 452 |
+
squeeze_mask = (mask.dim() == 2)
|
| 453 |
+
if squeeze_mask:
|
| 454 |
+
mask = mask.unsqueeze(0)
|
| 455 |
+
|
| 456 |
+
if target_size is None:
|
| 457 |
+
target_size = (rgb_image.shape[1], rgb_image.shape[2]) # (H, W)
|
| 458 |
+
|
| 459 |
+
rgb_needs_resize = (rgb_image.shape[1], rgb_image.shape[2]) != target_size
|
| 460 |
+
if rgb_needs_resize:
|
| 461 |
+
rgb_image = torchvision.transforms.functional.resize(
|
| 462 |
+
rgb_image, target_size, interpolation=torchvision.transforms.InterpolationMode.BILINEAR
|
| 463 |
+
)
|
| 464 |
+
mask = torchvision.transforms.functional.resize(
|
| 465 |
+
mask, target_size, interpolation=torchvision.transforms.InterpolationMode.NEAREST
|
| 466 |
+
)
|
| 467 |
+
|
| 468 |
+
if pointmap is not None:
|
| 469 |
+
pointmap_size = (pointmap.shape[1], pointmap.shape[2])
|
| 470 |
+
if pointmap_size != target_size:
|
| 471 |
+
# Handle NaN values in pointmap during resizing
|
| 472 |
+
# Direct resize would propagate NaN values, so we need special handling
|
| 473 |
+
nan_mask = torch.isnan(pointmap).any(dim=0)
|
| 474 |
+
pointmap_clean = torch.where(torch.isnan(pointmap), torch.zeros_like(pointmap), pointmap)
|
| 475 |
+
pointmap_resized = torchvision.transforms.functional.resize(
|
| 476 |
+
pointmap_clean, target_size, interpolation=torchvision.transforms.InterpolationMode.BILINEAR
|
| 477 |
+
)
|
| 478 |
+
|
| 479 |
+
# Resize the nan mask to identify which regions should remain invalid
|
| 480 |
+
nan_mask_resized = torchvision.transforms.functional.resize(
|
| 481 |
+
nan_mask.unsqueeze(0).float(), target_size,
|
| 482 |
+
interpolation=torchvision.transforms.InterpolationMode.NEAREST
|
| 483 |
+
).squeeze(0) > 0.5
|
| 484 |
+
|
| 485 |
+
# Restore NaN values in regions that were originally invalid
|
| 486 |
+
pointmap = torch.where(
|
| 487 |
+
nan_mask_resized.unsqueeze(0).expand_as(pointmap_resized),
|
| 488 |
+
torch.full_like(pointmap_resized, float('nan')),
|
| 489 |
+
pointmap_resized
|
| 490 |
+
)
|
| 491 |
+
|
| 492 |
+
if squeeze_mask:
|
| 493 |
+
mask = mask.squeeze(0)
|
| 494 |
+
|
| 495 |
+
if pointmap is not None:
|
| 496 |
+
return rgb_image, mask, pointmap
|
| 497 |
+
return rgb_image, mask
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
SSINormalizedPointmap = namedtuple("SSINormalizedPointmap", ["pointmap", "scale", "shift"])
|
| 501 |
+
class SSIPointmapNormalizer:
|
| 502 |
+
|
| 503 |
+
def normalize(self, pointmap: torch.Tensor, mask: torch.Tensor,
|
| 504 |
+
scale: Optional[torch.Tensor] = None, shift: Optional[torch.Tensor] = None,
|
| 505 |
+
) -> SSINormalizedPointmap:
|
| 506 |
+
if scale is None or shift is None:
|
| 507 |
+
normalized_pointmap, scale, shift = normalize_pointmap_ssi(pointmap)
|
| 508 |
+
else:
|
| 509 |
+
assert scale.shape == (3,) and shift.shape == (3,), "scale and shift must be in (3,) format"
|
| 510 |
+
normalized_pointmap = _apply_metric_to_ssi(pointmap, scale, shift)
|
| 511 |
+
return SSINormalizedPointmap(normalized_pointmap, scale, shift)
|
| 512 |
+
|
| 513 |
+
def denormalize(self, pointmap: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor) -> torch.Tensor:
|
| 514 |
+
pointmap = _apply_metric_to_ssi(pointmap, scale, shift, apply_inverse=True)
|
| 515 |
+
return pointmap
|
| 516 |
+
|
| 517 |
+
|
| 518 |
+
|
| 519 |
+
class ObjectCentricSSI(SSIPointmapNormalizer):
|
| 520 |
+
def __init__(self,
|
| 521 |
+
use_scene_scale: bool = True,
|
| 522 |
+
quantile_drop_threshold: float = 0.1,
|
| 523 |
+
clip_beyond_scale: Optional[float] = None,
|
| 524 |
+
# scale_factor: float = 3.8076, # e^(1.337); empirical mean of R3+Artist train
|
| 525 |
+
scale_factor: float = 1.0, # e^(1.337); empirical mean of R3+Artist train
|
| 526 |
+
allow_scale_and_shift_override: bool = False,
|
| 527 |
+
raise_on_no_valid_points: bool = False,
|
| 528 |
+
):
|
| 529 |
+
self.use_scene_scale = use_scene_scale
|
| 530 |
+
self.quantile_drop_threshold = quantile_drop_threshold
|
| 531 |
+
self.clip_beyond_scale = clip_beyond_scale
|
| 532 |
+
self.scale_factor = scale_factor
|
| 533 |
+
self.allow_scale_and_shift_override = allow_scale_and_shift_override
|
| 534 |
+
self.raise_on_no_valid_points = raise_on_no_valid_points
|
| 535 |
+
|
| 536 |
+
def _compute_scale_and_shift(self, pointmap: torch.Tensor, mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
| 537 |
+
pointmap_size = (pointmap.shape[1], pointmap.shape[2])
|
| 538 |
+
|
| 539 |
+
|
| 540 |
+
mask_resized = torchvision.transforms.functional.resize(
|
| 541 |
+
mask, pointmap_size,
|
| 542 |
+
interpolation=torchvision.transforms.InterpolationMode.NEAREST
|
| 543 |
+
).squeeze(0)
|
| 544 |
+
|
| 545 |
+
pointmap_flat = pointmap.reshape(3, -1)
|
| 546 |
+
# Get valid points from the mask
|
| 547 |
+
mask_bool = mask_resized.reshape(-1) > 0.5
|
| 548 |
+
mask_points = pointmap_flat[:, mask_bool]
|
| 549 |
+
|
| 550 |
+
if mask_points.isfinite().max() == 0:
|
| 551 |
+
if self.raise_on_no_valid_points:
|
| 552 |
+
raise ValueError(f"No valid points found in mask")
|
| 553 |
+
logger.warning(f"No valid points found in mask; setting scale to {self.scale_factor} and shift to 0")
|
| 554 |
+
return torch.ones_like(pointmap_flat[:,0]) * self.scale_factor, torch.zeros_like(pointmap_flat[:,0])
|
| 555 |
+
|
| 556 |
+
# Compute median for shift
|
| 557 |
+
shift = mask_points.nanmedian(dim=-1).values
|
| 558 |
+
# logger.info(f"{pointmap.shape=} {mask_resized.shape=} {shift.shape=}")
|
| 559 |
+
|
| 560 |
+
|
| 561 |
+
if self.use_scene_scale == True:
|
| 562 |
+
# Normalize by the scene scale
|
| 563 |
+
points_centered = pointmap_flat - shift.unsqueeze(-1)
|
| 564 |
+
max_dims = points_centered.abs().max(dim=0).values
|
| 565 |
+
scale = max_dims.nanmedian(dim=-1).values
|
| 566 |
+
elif self.use_scene_scale == False:
|
| 567 |
+
# Normalize by the object scale
|
| 568 |
+
shifted_mask_points = mask_points - shift.unsqueeze(-1)
|
| 569 |
+
norm = shifted_mask_points.norm(dim=0)
|
| 570 |
+
quantiles = torch.nanquantile(norm,
|
| 571 |
+
torch.tensor([self.quantile_drop_threshold, 1. - self.quantile_drop_threshold],
|
| 572 |
+
device=shifted_mask_points.device),
|
| 573 |
+
dim=-1)
|
| 574 |
+
scale = (quantiles[1] - quantiles[0]).max(dim=-1).values * 2.0
|
| 575 |
+
elif self.use_scene_scale.upper() == "OBJECT_NORM_MEDIAN":
|
| 576 |
+
# Normalize by the object scale
|
| 577 |
+
shifted_mask_points = mask_points - shift.unsqueeze(-1)
|
| 578 |
+
norm = shifted_mask_points.norm(dim=0)
|
| 579 |
+
scale = norm.nanmedian(dim=-1).values
|
| 580 |
+
else:
|
| 581 |
+
raise ValueError(f"Invalid use_scene_scale: {self.use_scene_scale}")
|
| 582 |
+
scale = scale.expand_as(shift) # per-dim scaling
|
| 583 |
+
scale = scale * self.scale_factor
|
| 584 |
+
return scale, shift
|
| 585 |
+
|
| 586 |
+
def normalize(self, pointmap: torch.Tensor, mask: torch.Tensor,
|
| 587 |
+
scale: Optional[torch.Tensor] = None, shift: Optional[torch.Tensor] = None,
|
| 588 |
+
) -> torch.Tensor:
|
| 589 |
+
# 1. resize mask to size of pointmap using nearest interpolation
|
| 590 |
+
# 2. get mask points: pointmap[mask > 0.5]
|
| 591 |
+
# 3. shift = mask_points.median() # xyz
|
| 592 |
+
# 4. scale = # filter. If no points, then
|
| 593 |
+
# logger.info(f"{pointmap.shape=} {mask.shape=}")
|
| 594 |
+
assert pointmap.shape[0] == 3, "pointmap must be in (3, H, W) format"
|
| 595 |
+
pointmap_size = (pointmap.shape[1], pointmap.shape[2])
|
| 596 |
+
|
| 597 |
+
_scale, _shift = self._compute_scale_and_shift(pointmap, mask)
|
| 598 |
+
if scale is not None and self.allow_scale_and_shift_override:
|
| 599 |
+
_scale = scale
|
| 600 |
+
if shift is not None and self.allow_scale_and_shift_override:
|
| 601 |
+
_shift = shift
|
| 602 |
+
return_scale, return_shift = _scale, _shift
|
| 603 |
+
|
| 604 |
+
# Apply normalization
|
| 605 |
+
pointmap_normalized = _apply_metric_to_ssi(pointmap, return_scale, return_shift)
|
| 606 |
+
|
| 607 |
+
if self.clip_beyond_scale is not None and self.clip_beyond_scale > 0:
|
| 608 |
+
new_norm = pointmap_normalized.norm(dim=0)
|
| 609 |
+
pointmap_normalized = torch.where(
|
| 610 |
+
new_norm > self.clip_beyond_scale,
|
| 611 |
+
torch.full_like(pointmap_normalized, float('nan')),
|
| 612 |
+
pointmap_normalized
|
| 613 |
+
)
|
| 614 |
+
|
| 615 |
+
return SSINormalizedPointmap(pointmap_normalized, return_scale, return_shift)
|
| 616 |
+
|
| 617 |
+
|
| 618 |
+
class ObjectApparentSizeSSI(SSIPointmapNormalizer):
|
| 619 |
+
def __init__(self,
|
| 620 |
+
clip_beyond_scale: Optional[float] = None,
|
| 621 |
+
use_scene_scale: bool = True,
|
| 622 |
+
scale_factor: float = 1.0, # e^(1.337); empirical mean of R3+Artist train
|
| 623 |
+
):
|
| 624 |
+
self.clip_beyond_scale = clip_beyond_scale
|
| 625 |
+
self.use_scene_scale = use_scene_scale
|
| 626 |
+
self.scale_factor = scale_factor
|
| 627 |
+
|
| 628 |
+
def _get_scale_and_shift(self, pointmap: torch.Tensor, mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
| 629 |
+
pointmap_size = (pointmap.shape[1], pointmap.shape[2])
|
| 630 |
+
pointmap_flat = pointmap.reshape(3, -1)
|
| 631 |
+
|
| 632 |
+
if not self.use_scene_scale:
|
| 633 |
+
# Get valid points from the mask
|
| 634 |
+
mask_resized = torchvision.transforms.functional.resize(
|
| 635 |
+
mask, pointmap_size,
|
| 636 |
+
interpolation=torchvision.transforms.InterpolationMode.NEAREST
|
| 637 |
+
).squeeze(0)
|
| 638 |
+
mask_bool = mask_resized.reshape(-1) > 0.5
|
| 639 |
+
pointmap_flat = pointmap_flat[:, mask_bool]
|
| 640 |
+
|
| 641 |
+
# Median z-distance
|
| 642 |
+
median_z = pointmap_flat[-1, ...].nanmedian().unsqueeze(0)
|
| 643 |
+
scale = median_z.expand(3) * self.scale_factor
|
| 644 |
+
shift = torch.zeros_like(scale)
|
| 645 |
+
# logger.info(f'median z = {median_z}')
|
| 646 |
+
return scale, shift
|
| 647 |
+
|
| 648 |
+
def normalize(self,
|
| 649 |
+
pointmap: torch.Tensor,
|
| 650 |
+
mask: torch.Tensor,
|
| 651 |
+
scale: Optional[torch.Tensor] = None,
|
| 652 |
+
shift: Optional[torch.Tensor] = None,
|
| 653 |
+
) -> torch.Tensor:
|
| 654 |
+
assert pointmap.shape[0] == 3, "pointmap must be in (3, H, W) format"
|
| 655 |
+
pointmap_size = (pointmap.shape[1], pointmap.shape[2])
|
| 656 |
+
|
| 657 |
+
if scale is None or shift is None:
|
| 658 |
+
scale, shift = self._get_scale_and_shift(pointmap, mask)
|
| 659 |
+
else:
|
| 660 |
+
assert scale.shape == (3,) and shift.shape == (3,), "scale and shift must be in (3,) format"
|
| 661 |
+
|
| 662 |
+
# Apply normalization and clip
|
| 663 |
+
pointmap_normalized = _apply_metric_to_ssi(pointmap, scale, shift)
|
| 664 |
+
# logger.info(f"{pointmap_normalized.shape=}")
|
| 665 |
+
|
| 666 |
+
if self.clip_beyond_scale is not None and self.clip_beyond_scale > 0:
|
| 667 |
+
pointmap_normalized = torch.where(
|
| 668 |
+
pointmap_normalized[-1, ...] > self.clip_beyond_scale,
|
| 669 |
+
torch.full_like(pointmap_normalized, float('nan')),
|
| 670 |
+
pointmap_normalized
|
| 671 |
+
)
|
| 672 |
+
|
| 673 |
+
# return pointmap_normalized, scale, shift
|
| 674 |
+
return SSINormalizedPointmap(pointmap_normalized, scale, shift)
|
| 675 |
+
|
| 676 |
+
|
| 677 |
+
class NormalizedDisparitySpaceSSI(SSIPointmapNormalizer):
|
| 678 |
+
def __init__(self,
|
| 679 |
+
clip_beyond_scale: Optional[float] = None,
|
| 680 |
+
use_scene_scale: bool = True,
|
| 681 |
+
log_disparity_shift: float = 0.0,
|
| 682 |
+
):
|
| 683 |
+
self.clip_beyond_scale = clip_beyond_scale
|
| 684 |
+
self.use_scene_scale = use_scene_scale
|
| 685 |
+
self.point_remapper = PointRemapper(remap_type="exp_disparity")
|
| 686 |
+
self.log_disparity_shift = log_disparity_shift
|
| 687 |
+
|
| 688 |
+
def normalize(self, pointmap: torch.Tensor, mask: torch.Tensor,
|
| 689 |
+
scale: Optional[torch.Tensor] = None, shift: Optional[torch.Tensor] = None,
|
| 690 |
+
) -> torch.Tensor:
|
| 691 |
+
assert pointmap.shape[0] == 3, "pointmap must be in (3, H, W) format"
|
| 692 |
+
|
| 693 |
+
|
| 694 |
+
disparity_space_pointmap = self.point_remapper.forward(pointmap.permute(1, 2, 0)).permute(2, 0, 1)
|
| 695 |
+
if scale is None or shift is None:
|
| 696 |
+
scale, shift = self._get_scale_and_shift(disparity_space_pointmap, mask)
|
| 697 |
+
else:
|
| 698 |
+
assert scale.shape == (3,) and shift.shape == (3,), "scale and shift must be in (3,) format"
|
| 699 |
+
|
| 700 |
+
# pointmap_normalized = pointmap.clone().detach()
|
| 701 |
+
pointmap_normalized = _apply_metric_to_ssi(disparity_space_pointmap, scale, shift)
|
| 702 |
+
# logger.info(f"{pointmap_normalized.shape=}")
|
| 703 |
+
|
| 704 |
+
if self.clip_beyond_scale is not None and self.clip_beyond_scale > 0:
|
| 705 |
+
pointmap_normalized = torch.where(
|
| 706 |
+
pointmap_normalized[2, ...].abs() > self.clip_beyond_scale,
|
| 707 |
+
torch.full_like(pointmap_normalized, float('nan')),
|
| 708 |
+
pointmap_normalized
|
| 709 |
+
)
|
| 710 |
+
|
| 711 |
+
# return pointmap_normalized, scale, shift
|
| 712 |
+
return SSINormalizedPointmap(pointmap_normalized, scale, shift)
|
| 713 |
+
|
| 714 |
+
def denormalize(self, pointmap: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor) -> torch.Tensor:
|
| 715 |
+
pointmap = _apply_metric_to_ssi(pointmap, scale, shift, apply_inverse=True)
|
| 716 |
+
pointmap = self.point_remapper.inverse(pointmap.permute(1, 2, 0)).permute(2, 0, 1)
|
| 717 |
+
return pointmap
|
| 718 |
+
|
| 719 |
+
def _get_scale_and_shift(self, pointmap: torch.Tensor, mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
| 720 |
+
pointmap_size = (pointmap.shape[1], pointmap.shape[2])
|
| 721 |
+
mask_resized = torchvision.transforms.functional.resize(
|
| 722 |
+
mask, pointmap_size,
|
| 723 |
+
interpolation=torchvision.transforms.InterpolationMode.NEAREST
|
| 724 |
+
).squeeze(0)
|
| 725 |
+
|
| 726 |
+
pointmap_flat = pointmap.reshape(3, -1)
|
| 727 |
+
if self.use_scene_scale:
|
| 728 |
+
median_z = pointmap_flat[-1, ...].nanmedian().unsqueeze(0)
|
| 729 |
+
shift = torch.zeros_like(median_z.expand(3))
|
| 730 |
+
shift[-1, ...] = median_z[0] + self.log_disparity_shift
|
| 731 |
+
else:
|
| 732 |
+
# Get valid points from the mask (shift, x/z, y/z, log(z))
|
| 733 |
+
mask_bool = mask_resized.reshape(-1) > 0.5
|
| 734 |
+
pointmap_flat = pointmap_flat[:, mask_bool]
|
| 735 |
+
shift = pointmap_flat.nanmedian(dim=-1).values
|
| 736 |
+
|
| 737 |
+
scale = torch.ones_like(shift)
|
| 738 |
+
# logger.info(f'median z = {median_z}')
|
| 739 |
+
return scale, shift
|
| 740 |
+
|
| 741 |
+
def normalize_pointmap_ssi(pointmap: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 742 |
+
"""
|
| 743 |
+
Normalize pointmap using Scale-Shift Invariant (SSI) normalization.
|
| 744 |
+
|
| 745 |
+
Args:
|
| 746 |
+
pointmap: Pointmap tensor of shape (H, W, 3) or (3, H, W)
|
| 747 |
+
|
| 748 |
+
Returns:
|
| 749 |
+
Tuple of (normalized_pointmap, scale, shift)
|
| 750 |
+
"""
|
| 751 |
+
from sam3d_objects.data.dataset.tdfy.pose_target import ScaleShiftInvariant
|
| 752 |
+
|
| 753 |
+
# Convert to (H, W, 3) if needed for get_scale_and_shift
|
| 754 |
+
if pointmap.shape[0] == 3:
|
| 755 |
+
pointmap_hw3 = pointmap.permute(1, 2, 0)
|
| 756 |
+
original_format = 'chw'
|
| 757 |
+
else:
|
| 758 |
+
pointmap_hw3 = pointmap
|
| 759 |
+
original_format = 'hwc'
|
| 760 |
+
|
| 761 |
+
# Get scale and shift using existing method
|
| 762 |
+
scale, shift = ScaleShiftInvariant.get_scale_and_shift(pointmap_hw3)
|
| 763 |
+
|
| 764 |
+
pointmap_normalized = _apply_metric_to_ssi(pointmap, scale, shift)
|
| 765 |
+
return pointmap_normalized, scale, shift
|
| 766 |
+
|
| 767 |
+
def _apply_metric_to_ssi(pointmap: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor, apply_inverse: bool = False) -> torch.Tensor:
|
| 768 |
+
"""
|
| 769 |
+
Normalize pointmap using Scale-Shift Invariant (SSI) normalization.
|
| 770 |
+
|
| 771 |
+
Args:
|
| 772 |
+
pointmap: Pointmap tensor of shape (H, W, 3) or (3, H, W)
|
| 773 |
+
|
| 774 |
+
Returns:
|
| 775 |
+
Tuple of (normalized_pointmap, scale, shift)
|
| 776 |
+
"""
|
| 777 |
+
from sam3d_objects.data.dataset.tdfy.pose_target import ScaleShiftInvariant
|
| 778 |
+
|
| 779 |
+
# Convert to (H, W, 3) if needed for get_scale_and_shift
|
| 780 |
+
if pointmap.shape[0] == 3:
|
| 781 |
+
pointmap_hw3 = pointmap.permute(1, 2, 0)
|
| 782 |
+
original_format = 'chw'
|
| 783 |
+
else:
|
| 784 |
+
pointmap_hw3 = pointmap
|
| 785 |
+
original_format = 'hwc'
|
| 786 |
+
|
| 787 |
+
# Apply normalization
|
| 788 |
+
ssi_to_metric = ScaleShiftInvariant.ssi_to_metric(scale, shift)
|
| 789 |
+
metric_to_ssi = ssi_to_metric.inverse()
|
| 790 |
+
transform_to_apply = metric_to_ssi
|
| 791 |
+
|
| 792 |
+
if apply_inverse:
|
| 793 |
+
transform_to_apply = ssi_to_metric
|
| 794 |
+
|
| 795 |
+
pointmap_flat = pointmap_hw3.reshape(-1, 3)
|
| 796 |
+
pointmap_normalized = transform_to_apply.transform_points(pointmap_flat)
|
| 797 |
+
|
| 798 |
+
# Reshape back to original format
|
| 799 |
+
if original_format == 'chw':
|
| 800 |
+
pointmap_normalized = pointmap_normalized.reshape(pointmap.shape[1], pointmap.shape[2], 3).permute(2, 0, 1)
|
| 801 |
+
else:
|
| 802 |
+
pointmap_normalized = pointmap_normalized.reshape(pointmap_hw3.shape)
|
| 803 |
+
|
| 804 |
+
return pointmap_normalized
|
| 805 |
+
|
| 806 |
+
|
| 807 |
+
def perturb_mask_translation(
|
| 808 |
+
image: torch.Tensor,
|
| 809 |
+
mask: torch.Tensor,
|
| 810 |
+
max_px_delta: int = 5,
|
| 811 |
+
):
|
| 812 |
+
"""
|
| 813 |
+
Applies data augmentation to the mask by randomly translating the mask.
|
| 814 |
+
|
| 815 |
+
Args:
|
| 816 |
+
image: (C, H, W) float32 [0, 1] tensor.
|
| 817 |
+
mask: (1, H, W) float32 [0, 1] tensor.
|
| 818 |
+
max_px_delta: The maximum number of pixels we will randomly shift by in each 2D direction.
|
| 819 |
+
"""
|
| 820 |
+
dx = random.randint(-max_px_delta, max_px_delta)
|
| 821 |
+
dy = random.randint(-max_px_delta, max_px_delta)
|
| 822 |
+
|
| 823 |
+
mask = mask.squeeze(0)
|
| 824 |
+
mask = torch.roll(mask, shifts=(dy, dx), dims=(0, 1))
|
| 825 |
+
|
| 826 |
+
# Zero out wrapped regions
|
| 827 |
+
if dy > 0:
|
| 828 |
+
mask[:dy, :] = 0
|
| 829 |
+
elif dy < 0:
|
| 830 |
+
mask[dy:, :] = 0
|
| 831 |
+
if dx > 0:
|
| 832 |
+
mask[:, :dx] = 0
|
| 833 |
+
elif dx < 0:
|
| 834 |
+
mask[:, dx:] = 0
|
| 835 |
+
|
| 836 |
+
mask = mask.unsqueeze(0)
|
| 837 |
+
return image, mask
|
| 838 |
+
|
| 839 |
+
|
| 840 |
+
def perturb_mask_boundary(
|
| 841 |
+
image: torch.Tensor,
|
| 842 |
+
mask: torch.Tensor,
|
| 843 |
+
kernel_range: tuple[int, int] = (2, 5),
|
| 844 |
+
p_erode: float = 0.1,
|
| 845 |
+
p_dilate: float = 0.8,
|
| 846 |
+
**kwargs,
|
| 847 |
+
):
|
| 848 |
+
"""
|
| 849 |
+
Applies data augmentation to the mask by randomly eroding or dilating the mask.
|
| 850 |
+
|
| 851 |
+
Args:
|
| 852 |
+
image: (C, H, W) float32 [0, 1] tensor.
|
| 853 |
+
mask: (1, H, W) float32 [0, 1] tensor.
|
| 854 |
+
kernel_range: Range of kernel sizes to sample from.
|
| 855 |
+
p_erode: Probability of erosion.
|
| 856 |
+
p_dilate: Probability of dilation.
|
| 857 |
+
kwargs: Kwargs for the cv2 erode/dilate function.
|
| 858 |
+
"""
|
| 859 |
+
import cv2
|
| 860 |
+
|
| 861 |
+
C, H, W = image.shape
|
| 862 |
+
assert mask.shape == (1, H, W)
|
| 863 |
+
assert mask.dtype == torch.float32
|
| 864 |
+
assert torch.all((mask == 0) | (mask == 1)), "Mask must be binary (0 or 1)"
|
| 865 |
+
|
| 866 |
+
p_none = 1.0 - p_erode - p_dilate
|
| 867 |
+
assert 0 <= p_none <= 1, "Probabilities must sum to 1 and be valid."
|
| 868 |
+
|
| 869 |
+
# Sample operation.
|
| 870 |
+
op = random.choices(["erode", "dilate", "none"], weights=[p_erode, p_dilate, p_none], k=1)[0]
|
| 871 |
+
|
| 872 |
+
if op == "none":
|
| 873 |
+
pass
|
| 874 |
+
else:
|
| 875 |
+
# Sample kernel size
|
| 876 |
+
ksize = random.randint(*kernel_range)
|
| 877 |
+
kernel = np.ones((ksize, ksize), np.uint8)
|
| 878 |
+
|
| 879 |
+
mask = mask.squeeze().cpu().numpy().astype(np.uint8) # (H, W)
|
| 880 |
+
|
| 881 |
+
if op == "erode":
|
| 882 |
+
mask = cv2.erode(mask, kernel, **kwargs)
|
| 883 |
+
elif op == "dilate":
|
| 884 |
+
mask = cv2.dilate(mask, kernel, **kwargs)
|
| 885 |
+
else:
|
| 886 |
+
raise NotImplementedError
|
| 887 |
+
|
| 888 |
+
mask = torch.from_numpy(mask).float()[None] # (1, H, W)
|
| 889 |
+
|
| 890 |
+
return image, mask
|
| 891 |
+
|
| 892 |
+
|
| 893 |
+
def resolution_blur(
|
| 894 |
+
image: torch.Tensor,
|
| 895 |
+
mask: torch.Tensor,
|
| 896 |
+
scale_range=(0.05, 0.95),
|
| 897 |
+
interpolation_down=tv_transforms.InterpolationMode.BICUBIC,
|
| 898 |
+
interpolation_up=tv_transforms.InterpolationMode.BICUBIC,
|
| 899 |
+
):
|
| 900 |
+
"""
|
| 901 |
+
Blur the input image by applying upsample(downsample(x)).
|
| 902 |
+
|
| 903 |
+
Args:
|
| 904 |
+
image (torch.Tensor): Image tensor of shape (C, H, W), float32, with values in [0, 1].
|
| 905 |
+
mask (torch.Tensor): Mask tensor of shape (1, H, W), float32, with values in [0, 1]. The mask is returned unchanged.
|
| 906 |
+
scale_range: Tuple of (min_scale, max_scale) for downsampling.
|
| 907 |
+
interpolation_down: Interpolation mode for downsampling.
|
| 908 |
+
interpolation_up: Interpolation mode for upsampling.
|
| 909 |
+
"""
|
| 910 |
+
C, H, W = image.shape
|
| 911 |
+
scale = random.uniform(*scale_range)
|
| 912 |
+
new_H, new_W = max(1, int(H * scale)), max(1, int(W * scale))
|
| 913 |
+
|
| 914 |
+
# Downsample
|
| 915 |
+
image = TF.resize(image, size=[new_H, new_W], interpolation=interpolation_down)
|
| 916 |
+
|
| 917 |
+
# Upsample back to original size
|
| 918 |
+
image = TF.resize(image, size=[H, W], interpolation=interpolation_up)
|
| 919 |
+
|
| 920 |
+
return image, mask
|
| 921 |
+
|
| 922 |
+
|
| 923 |
+
def gaussian_blur(
|
| 924 |
+
image: torch.Tensor,
|
| 925 |
+
mask: torch.Tensor,
|
| 926 |
+
kernel_range: tuple[int, int] = (3, 15),
|
| 927 |
+
sigma_range: tuple[int, int] = (0.1, 4.0),
|
| 928 |
+
):
|
| 929 |
+
"""
|
| 930 |
+
Apply gaussian blur to the input image.
|
| 931 |
+
|
| 932 |
+
Args:
|
| 933 |
+
image (torch.Tensor): Image tensor of shape (C, H, W), float32, with values in [0, 1].
|
| 934 |
+
mask (torch.Tensor): Mask tensor of shape (1, H, W), float32, with values in [0, 1]. The mask is returned unchanged.
|
| 935 |
+
kernel_range (tuple): Range of odd kernel sizes to sample from for the Gaussian blur (min, max).
|
| 936 |
+
sigma_range (tuple): Range of sigma values (standard deviation) to sample from for the Gaussian kernel (min, max).
|
| 937 |
+
"""
|
| 938 |
+
kernel_size = random.choice([k for k in range(kernel_range[0], kernel_range[1]+1) if k % 2 == 1])
|
| 939 |
+
sigma = random.uniform(*sigma_range)
|
| 940 |
+
pad = kernel_size // 2
|
| 941 |
+
|
| 942 |
+
# Step 1: Pad the image
|
| 943 |
+
image = F.pad(image.unsqueeze(0), (pad, pad, pad, pad), mode='replicate')
|
| 944 |
+
|
| 945 |
+
# Step 2: Apply gaussian blur
|
| 946 |
+
image = TF.gaussian_blur(image, kernel_size=[kernel_size, kernel_size], sigma=sigma)
|
| 947 |
+
|
| 948 |
+
# Step 3: Unpad to get back to original size
|
| 949 |
+
image = image[:, :, pad:-pad, pad:-pad]
|
| 950 |
+
|
| 951 |
+
return image.squeeze(0), mask
|
| 952 |
+
|
| 953 |
+
|
| 954 |
+
def apply_blur_augmentation(
|
| 955 |
+
image: torch.Tensor,
|
| 956 |
+
mask: torch.Tensor,
|
| 957 |
+
p_resolution: float = 0.33,
|
| 958 |
+
p_gaussian: float = 0.33,
|
| 959 |
+
gaussian_kwargs: dict = None,
|
| 960 |
+
resolution_kwargs: dict = None,
|
| 961 |
+
):
|
| 962 |
+
"""Apply blur augmentation with configurable parameters"""
|
| 963 |
+
|
| 964 |
+
# Handle None defaults BEFORE unpacking
|
| 965 |
+
if gaussian_kwargs is None:
|
| 966 |
+
gaussian_kwargs = {}
|
| 967 |
+
if resolution_kwargs is None:
|
| 968 |
+
resolution_kwargs = {}
|
| 969 |
+
|
| 970 |
+
p_none = 1.0 - p_gaussian - p_resolution
|
| 971 |
+
assert 0 <= p_none <= 1, "Probabilities must sum to 1 and be valid."
|
| 972 |
+
|
| 973 |
+
operation = random.choices(
|
| 974 |
+
["gaussian", "resolution", "none"],
|
| 975 |
+
weights=[p_gaussian, p_resolution, p_none],
|
| 976 |
+
k=1
|
| 977 |
+
)[0]
|
| 978 |
+
|
| 979 |
+
if operation == "gaussian":
|
| 980 |
+
return gaussian_blur(image, mask, **gaussian_kwargs)
|
| 981 |
+
elif operation == "resolution":
|
| 982 |
+
return resolution_blur(image, mask, **resolution_kwargs)
|
| 983 |
+
elif operation == "none":
|
| 984 |
+
return image, mask
|
| 985 |
+
else:
|
| 986 |
+
raise NotImplementedError
|
thirdparty/sam3d/sam3d/sam3d_objects/data/dataset/tdfy/img_processing.py
ADDED
|
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
import math
|
| 3 |
+
|
| 4 |
+
import random
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
|
| 9 |
+
from torchvision import transforms
|
| 10 |
+
from torchvision.transforms import functional as tv_F
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class RandomResizedCrop(transforms.RandomResizedCrop):
|
| 14 |
+
"""
|
| 15 |
+
RandomResizedCrop for matching TF/TPU implementation: no for-loop is used.
|
| 16 |
+
This may lead to results different with torchvision's version.
|
| 17 |
+
Following BYOL's TF code:
|
| 18 |
+
https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
@staticmethod
|
| 22 |
+
def get_params(img, scale, ratio):
|
| 23 |
+
width, height = tv_F._get_image_size(img)
|
| 24 |
+
area = height * width
|
| 25 |
+
|
| 26 |
+
target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
|
| 27 |
+
log_ratio = torch.log(torch.tensor(ratio))
|
| 28 |
+
aspect_ratio = torch.exp(
|
| 29 |
+
torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
|
| 30 |
+
).item()
|
| 31 |
+
|
| 32 |
+
w = int(round(math.sqrt(target_area * aspect_ratio)))
|
| 33 |
+
h = int(round(math.sqrt(target_area / aspect_ratio)))
|
| 34 |
+
|
| 35 |
+
w = min(w, width)
|
| 36 |
+
h = min(h, height)
|
| 37 |
+
|
| 38 |
+
i = torch.randint(0, height - h + 1, size=(1,)).item()
|
| 39 |
+
j = torch.randint(0, width - w + 1, size=(1,)).item()
|
| 40 |
+
|
| 41 |
+
return i, j, h, w
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# following PT3D CO3D data to pad image
|
| 45 |
+
def pad_to_square(image, value=0):
|
| 46 |
+
_, _, h, w = image.shape # Assuming image is in (B, C, H, W) format
|
| 47 |
+
if h == w:
|
| 48 |
+
return image # The image is already square
|
| 49 |
+
|
| 50 |
+
# Calculate the padding
|
| 51 |
+
diff = abs(h - w)
|
| 52 |
+
pad2 = diff
|
| 53 |
+
|
| 54 |
+
# Pad the image to make it square
|
| 55 |
+
if h > w:
|
| 56 |
+
padding = (0, pad2, 0, 0) # Pad width (left, right, top, bottom)
|
| 57 |
+
else:
|
| 58 |
+
padding = (0, 0, 0, pad2) # Pad height
|
| 59 |
+
# Apply padding
|
| 60 |
+
padded_image = torch.nn.functional.pad(image, padding, mode="constant", value=value)
|
| 61 |
+
return padded_image
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def preprocess_img(
|
| 65 |
+
x,
|
| 66 |
+
mask=None,
|
| 67 |
+
img_target_shape=224,
|
| 68 |
+
mask_target_shape=256,
|
| 69 |
+
normalize=False,
|
| 70 |
+
):
|
| 71 |
+
if x.shape[1] != x.shape[2]:
|
| 72 |
+
x = pad_to_square(x)
|
| 73 |
+
if mask is not None and mask.shape[1] != mask.shape[2]:
|
| 74 |
+
mask = pad_to_square(mask)
|
| 75 |
+
if x.shape[2] != img_target_shape:
|
| 76 |
+
x = F.interpolate(
|
| 77 |
+
x,
|
| 78 |
+
size=(img_target_shape, img_target_shape),
|
| 79 |
+
# scale_factor=float(img_target_shape)/x.shape[2],
|
| 80 |
+
mode="bilinear",
|
| 81 |
+
)
|
| 82 |
+
if mask is not None and mask.shape[2] != mask_target_shape:
|
| 83 |
+
if mask is not None:
|
| 84 |
+
mask = F.interpolate(
|
| 85 |
+
mask,
|
| 86 |
+
size=(mask_target_shape, mask_target_shape),
|
| 87 |
+
# scale_factor=float(mask_target_shape)/mask.shape[2],
|
| 88 |
+
mode="nearest",
|
| 89 |
+
)
|
| 90 |
+
if normalize:
|
| 91 |
+
imgs_normed = resnet_img_normalization(x)
|
| 92 |
+
else:
|
| 93 |
+
imgs_normed = x
|
| 94 |
+
return imgs_normed, mask
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def resnet_img_normalization(x):
|
| 98 |
+
resnet_mean = torch.tensor([0.485, 0.456, 0.406], device=x.device).reshape(
|
| 99 |
+
(3, 1, 1)
|
| 100 |
+
)
|
| 101 |
+
resnet_std = torch.tensor([0.229, 0.224, 0.225], device=x.device).reshape((3, 1, 1))
|
| 102 |
+
if x.ndim == 4:
|
| 103 |
+
resnet_mean = resnet_mean[None]
|
| 104 |
+
resnet_std = resnet_std[None]
|
| 105 |
+
x = (x - resnet_mean) / resnet_std
|
| 106 |
+
return x
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
# pad image to be centered for unprojecting depth
|
| 110 |
+
def pad_to_square_centered(image, value=0, pointmap=None):
|
| 111 |
+
h, w = image.shape[-2], image.shape[-1] # Assuming image is in (B, C, H, W) format
|
| 112 |
+
if h == w:
|
| 113 |
+
if pointmap is not None:
|
| 114 |
+
return image, pointmap
|
| 115 |
+
return image # The image is already square
|
| 116 |
+
|
| 117 |
+
# Calculate the padding
|
| 118 |
+
diff = abs(h - w)
|
| 119 |
+
pad1 = diff // 2
|
| 120 |
+
pad2 = diff - pad1
|
| 121 |
+
|
| 122 |
+
# Pad the image to make it square
|
| 123 |
+
if h > w:
|
| 124 |
+
padding = (pad1, pad2, 0, 0) # Pad width (left, right, top, bottom)
|
| 125 |
+
else:
|
| 126 |
+
padding = (0, 0, pad1, pad2) # Pad height
|
| 127 |
+
# Apply padding to image
|
| 128 |
+
padded_image = F.pad(image, padding, mode="constant", value=value)
|
| 129 |
+
|
| 130 |
+
# Apply padding to pointmap if provided
|
| 131 |
+
if pointmap is not None:
|
| 132 |
+
# Pad pointmap using torch functional with NaN fill value
|
| 133 |
+
padded_pointmap = F.pad(pointmap, padding, mode="constant", value=float("nan"))
|
| 134 |
+
|
| 135 |
+
return padded_image, padded_pointmap
|
| 136 |
+
return padded_image
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def crop_img_to_obj(mask, context_size):
|
| 140 |
+
nonzeros = torch.nonzero(mask)
|
| 141 |
+
if len(nonzeros) > 0:
|
| 142 |
+
r_max, c_max = nonzeros.max(dim=0)[0]
|
| 143 |
+
r_min, c_min = nonzeros.min(dim=0)[0]
|
| 144 |
+
box_h = max(1, r_max - r_min)
|
| 145 |
+
box_w = max(1, c_max - c_min)
|
| 146 |
+
left = max(0, c_min - int(box_w * context_size))
|
| 147 |
+
right = min(mask.shape[-1], c_max + int(box_w * context_size))
|
| 148 |
+
top = max(0, r_min - int(box_h * context_size))
|
| 149 |
+
bot = min(mask.shape[-2], r_max + int(box_h * context_size))
|
| 150 |
+
return left, right, top, bot
|
| 151 |
+
return None, None, None, None
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def random_pad(img, mask=None, max_ratio=0.0, pointmap=None):
|
| 155 |
+
max_size = int(max(img.shape) * max_ratio)
|
| 156 |
+
padding = tuple([random.randint(0, max_size) for _ in range(4)])
|
| 157 |
+
img = F.pad(img, padding)
|
| 158 |
+
if mask is not None:
|
| 159 |
+
mask = F.pad(mask, padding)
|
| 160 |
+
|
| 161 |
+
if pointmap is not None:
|
| 162 |
+
pointmap = F.pad(pointmap, padding, mode="constant", value=float("nan"))
|
| 163 |
+
return img, mask, pointmap
|
| 164 |
+
return img, mask
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def get_img_color_augmentation(
|
| 168 |
+
color_jit_prob=0.5,
|
| 169 |
+
gaussian_blur_prob=0.1,
|
| 170 |
+
):
|
| 171 |
+
transform = transforms.Compose(
|
| 172 |
+
[
|
| 173 |
+
# (a) Random Color Jitter
|
| 174 |
+
transforms.RandomApply(
|
| 175 |
+
[
|
| 176 |
+
transforms.ColorJitter(
|
| 177 |
+
brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1
|
| 178 |
+
)
|
| 179 |
+
],
|
| 180 |
+
p=color_jit_prob,
|
| 181 |
+
),
|
| 182 |
+
# (b) Randomly apply GaussianBlur
|
| 183 |
+
transforms.RandomApply(
|
| 184 |
+
[transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0))],
|
| 185 |
+
p=gaussian_blur_prob,
|
| 186 |
+
),
|
| 187 |
+
]
|
| 188 |
+
)
|
| 189 |
+
return transform
|
thirdparty/sam3d/sam3d/sam3d_objects/data/dataset/tdfy/pose_target.py
ADDED
|
@@ -0,0 +1,784 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
import torch
|
| 3 |
+
from typing import Dict, Optional, Tuple, Any
|
| 4 |
+
from dataclasses import dataclass, asdict, field
|
| 5 |
+
from loguru import logger
|
| 6 |
+
|
| 7 |
+
from sam3d_objects.data.utils import expand_as_right, tree_tensor_map
|
| 8 |
+
from sam3d_objects.data.dataset.tdfy.transforms_3d import compose_transform, decompose_transform
|
| 9 |
+
from pytorch3d.transforms import Transform3d, quaternion_to_matrix, matrix_to_quaternion
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@dataclass
|
| 13 |
+
class InstancePose:
|
| 14 |
+
"""
|
| 15 |
+
Stores the pose of an object.
|
| 16 |
+
Also, stores some information about the scene that was used to normalize the pose.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
instance_scale_l2c: torch.Tensor
|
| 20 |
+
instance_position_l2c: torch.Tensor
|
| 21 |
+
instance_quaternion_l2c: torch.Tensor
|
| 22 |
+
scene_scale: torch.Tensor
|
| 23 |
+
scene_shift: torch.Tensor
|
| 24 |
+
|
| 25 |
+
@classmethod
|
| 26 |
+
def _broadcast_postcompose(
|
| 27 |
+
cls,
|
| 28 |
+
scale: torch.Tensor,
|
| 29 |
+
rotation: torch.Tensor,
|
| 30 |
+
translation: torch.Tensor,
|
| 31 |
+
transform_to_postcompose: Transform3d,
|
| 32 |
+
) -> Transform3d:
|
| 33 |
+
"""
|
| 34 |
+
Assumes scale, rotation, translation are of shape:
|
| 35 |
+
B, K, C
|
| 36 |
+
---
|
| 37 |
+
B: batch size
|
| 38 |
+
K: number of objects
|
| 39 |
+
C: number of channels
|
| 40 |
+
|
| 41 |
+
Takes a transform where
|
| 42 |
+
get_matrix() has shape (B, 3, 3)
|
| 43 |
+
|
| 44 |
+
Returns pose.compose(transform_to_postcompose)
|
| 45 |
+
"""
|
| 46 |
+
scale_c = scale.shape[-1]
|
| 47 |
+
ndim_orig = scale.ndim
|
| 48 |
+
if ndim_orig == 3:
|
| 49 |
+
b, k, _ = scale.shape
|
| 50 |
+
elif ndim_orig == 2:
|
| 51 |
+
b = scale.shape[0]
|
| 52 |
+
k = 1
|
| 53 |
+
elif ndim_orig == 1:
|
| 54 |
+
b = 1
|
| 55 |
+
k = 1
|
| 56 |
+
else:
|
| 57 |
+
raise ValueError(f"Invalid scale shape: {scale.shape}")
|
| 58 |
+
|
| 59 |
+
# Create transform of shape (B * K)
|
| 60 |
+
wide = {"scale": scale, "rotation": rotation, "translation": translation}
|
| 61 |
+
shapes_orig = {k: v.shape for k, v in wide.items()}
|
| 62 |
+
long = tree_tensor_map(lambda x: x.reshape(b * k, x.shape[-1]), wide)
|
| 63 |
+
long["rotation"] = quaternion_to_matrix(long["rotation"])
|
| 64 |
+
if scale_c == 1:
|
| 65 |
+
long["scale"] = long["scale"].expand(b * k, 3)
|
| 66 |
+
|
| 67 |
+
composed = compose_transform(**long)
|
| 68 |
+
|
| 69 |
+
# Apply transform to shape (B * K)
|
| 70 |
+
pc_transform = transform_to_postcompose.get_matrix()
|
| 71 |
+
pc_transform = pc_transform.repeat(k, 1, 1)
|
| 72 |
+
stacked_pc_transform = Transform3d(matrix=pc_transform)
|
| 73 |
+
assert stacked_pc_transform.get_matrix().shape == composed.get_matrix().shape
|
| 74 |
+
postcomposed = composed.compose(stacked_pc_transform)
|
| 75 |
+
|
| 76 |
+
# Decompose transform to shape (B, K, C)
|
| 77 |
+
scale, rotation, translation = decompose_transform(postcomposed)
|
| 78 |
+
rotation = matrix_to_quaternion(rotation)
|
| 79 |
+
pc_long = {"scale": scale, "rotation": rotation, "translation": translation}
|
| 80 |
+
pc_wide = tree_tensor_map(lambda x: x.reshape(b, k, x.shape[-1]), pc_long)
|
| 81 |
+
if scale_c == 1:
|
| 82 |
+
pc_wide["scale"] = pc_wide["scale"][..., 0].unsqueeze(-1)
|
| 83 |
+
for k, shape in shapes_orig.items():
|
| 84 |
+
pc_wide[k] = pc_wide[k].reshape(*shape)
|
| 85 |
+
return pc_wide["scale"], pc_wide["rotation"], pc_wide["translation"]
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
@dataclass
|
| 89 |
+
class PoseTarget:
|
| 90 |
+
x_instance_scale: torch.Tensor
|
| 91 |
+
x_instance_rotation: torch.Tensor
|
| 92 |
+
x_instance_translation: torch.Tensor
|
| 93 |
+
x_scene_scale: torch.Tensor
|
| 94 |
+
x_scene_center: torch.Tensor
|
| 95 |
+
x_translation_scale: torch.Tensor
|
| 96 |
+
pose_target_convention: str = field(default="unknown")
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
@dataclass
|
| 100 |
+
class InvariantPoseTarget:
|
| 101 |
+
"""
|
| 102 |
+
This is the canonical representation of pose targets, used for computing metrics.
|
| 103 |
+
instance_pose <-> invariant_pose_targets <-> all other pose_target_conventions
|
| 104 |
+
|
| 105 |
+
Background:
|
| 106 |
+
---
|
| 107 |
+
We want to estimate a transformation T: R³ → R³ despite scene scale ambiguity.
|
| 108 |
+
|
| 109 |
+
The transformation taking object points to scene points is defined as
|
| 110 |
+
T(x) = s · R(q) · x + t
|
| 111 |
+
where:
|
| 112 |
+
- x is a point in the object coordinate frame,
|
| 113 |
+
- q is a unit quaternion representing rotation,
|
| 114 |
+
- s is the object-to-scene scale, and
|
| 115 |
+
- t is the translation.
|
| 116 |
+
|
| 117 |
+
However, there is an inherent scale ambiguity in the scene, denoted as s_scene;
|
| 118 |
+
This ambiguity introduces irreducible error that complicates both evaluation and training.
|
| 119 |
+
|
| 120 |
+
To decouple the scene scale from the invariant quantities, we define:
|
| 121 |
+
T(x) = s_scene · |t_rel| [ s_tilde · R(q) · x + t_unit ]
|
| 122 |
+
where we define
|
| 123 |
+
t_rel = t / s_scene
|
| 124 |
+
s_rel = s / s_scene
|
| 125 |
+
s_tilde = s_rel / |t_rel|
|
| 126 |
+
t_unit = t_rel / |t_rel|
|
| 127 |
+
|
| 128 |
+
During training, you would predict (q, s_tilde, t_unit), leaving s_scene separate.
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
Hand-wavy error analysis:
|
| 132 |
+
---
|
| 133 |
+
1. Naive (coupled) estimate:
|
| 134 |
+
T(x) = s_scene [ s_rel · R(q) · x + t_rel ]
|
| 135 |
+
|
| 136 |
+
We can define:
|
| 137 |
+
U = ln(s_rel)
|
| 138 |
+
V = ln(|t_rel|)
|
| 139 |
+
so that the error is governed by Var(U + V).
|
| 140 |
+
|
| 141 |
+
2. In the decoupled case, we have:
|
| 142 |
+
T(x) = s_scene · |t_rel| [ s_tilde · R(q) · x + t_unit ]
|
| 143 |
+
= s_scene · |t_rel| [ (s_rel / |t_rel|) R(q) · x + t_unit ]
|
| 144 |
+
Then ln(s_tilde) = ln(s_rel) - ln(|t_rel|) = U - V, and the error is
|
| 145 |
+
Var(U - V) = Var(U) + Var(V) - 2Cov(U, V).
|
| 146 |
+
|
| 147 |
+
"""
|
| 148 |
+
|
| 149 |
+
# These are invariant
|
| 150 |
+
q: torch.Tensor
|
| 151 |
+
t_unit: torch.Tensor
|
| 152 |
+
s_scene: torch.Tensor
|
| 153 |
+
t_scene_center: Optional[torch.Tensor] = None
|
| 154 |
+
t_rel_norm: Optional[torch.Tensor] = None
|
| 155 |
+
s_tilde: Optional[torch.Tensor] = None
|
| 156 |
+
s_rel: Optional[torch.Tensor] = None
|
| 157 |
+
|
| 158 |
+
def __post_init__(self):
|
| 159 |
+
# Check that fields that are required always have values.
|
| 160 |
+
if self.q is None:
|
| 161 |
+
raise ValueError("Field 'q' (quaternion) must be provided.")
|
| 162 |
+
if self.s_scene is None:
|
| 163 |
+
raise ValueError("Field 's_scene' must be provided.")
|
| 164 |
+
if self.s_rel is None:
|
| 165 |
+
if self.s_tilde is not None:
|
| 166 |
+
self.s_rel = self.s_tilde * self.t_rel_norm
|
| 167 |
+
else:
|
| 168 |
+
raise ValueError("Field 's_rel' or 's_tilde' must be provided.")
|
| 169 |
+
if self.t_unit is None:
|
| 170 |
+
raise ValueError("Field 't_unit' must be provided.")
|
| 171 |
+
|
| 172 |
+
if self.t_scene_center is None:
|
| 173 |
+
self.t_scene_center = torch.zeros_like(self.t_unit)
|
| 174 |
+
|
| 175 |
+
# There is a simple relationship between s_tilde and t_rel_norm:
|
| 176 |
+
# s_tilde = s_rel / t_rel_norm
|
| 177 |
+
#
|
| 178 |
+
# If one of these is missing and the other is provided, we can compute the missing field.
|
| 179 |
+
if self.s_tilde is None and self.t_rel_norm is not None:
|
| 180 |
+
self.s_tilde = self.s_rel / self.t_rel_norm
|
| 181 |
+
elif self.t_rel_norm is None and self.s_tilde is not None:
|
| 182 |
+
self.t_rel_norm = self.s_rel / self.s_tilde
|
| 183 |
+
|
| 184 |
+
# If both are provided, we check for consistency.
|
| 185 |
+
if self.s_tilde is not None and self.t_rel_norm is not None:
|
| 186 |
+
computed_s_tilde = self.s_rel / self.t_rel_norm
|
| 187 |
+
# If the provided s_tilde deviates from what is computed, update it.
|
| 188 |
+
if not torch.allclose(self.s_tilde, computed_s_tilde, atol=1e-6):
|
| 189 |
+
logger.warning(
|
| 190 |
+
f"s_tilde and t_rel_norm are provided, but they are not consistent. "
|
| 191 |
+
f"Updating s_tilde to {computed_s_tilde}."
|
| 192 |
+
)
|
| 193 |
+
self.s_tilde = computed_s_tilde
|
| 194 |
+
|
| 195 |
+
self._validate_fields()
|
| 196 |
+
|
| 197 |
+
def _validate_fields(self):
|
| 198 |
+
for field in self.__dict__:
|
| 199 |
+
if self.__dict__[field] is None:
|
| 200 |
+
raise ValueError(f"Field '{field}' must be provided.")
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
@staticmethod
|
| 204 |
+
def from_instance_pose(instance_pose: InstancePose) -> "InvariantPoseTarget":
|
| 205 |
+
q = instance_pose.instance_quaternion_l2c
|
| 206 |
+
s_obj_to_scene = instance_pose.instance_scale_l2c # (..., 1) or (..., 3)
|
| 207 |
+
t_obj_to_scene = instance_pose.instance_position_l2c # (..., 3)
|
| 208 |
+
s_scene = instance_pose.scene_scale # (..., 1) or scalar-broadcastable
|
| 209 |
+
t_scene_center = instance_pose.scene_shift # (..., 3)
|
| 210 |
+
|
| 211 |
+
# Normalize to scene scale (per the derivation)
|
| 212 |
+
if not ( s_obj_to_scene.ndim == (s_scene.ndim + 1)):
|
| 213 |
+
raise ValueError(f"s_scene should be ND [...,3] and s_obj_to_scene should be (N+1)D [...,K,3], but got {s_scene.shape=} {s_obj_to_scene.shape=}")
|
| 214 |
+
if not (t_obj_to_scene.ndim == (s_scene.ndim + 1)):
|
| 215 |
+
raise ValueError(f"t_scene_center should be ND [B,3] and t_obj_to_scene should be (N+1)D [B,K,3], but got {t_scene_center.shape=} {t_obj_to_scene.shape=}")
|
| 216 |
+
s_scene_exp = s_scene.unsqueeze(-2)
|
| 217 |
+
|
| 218 |
+
s_rel = s_obj_to_scene / s_scene_exp
|
| 219 |
+
t_rel = t_obj_to_scene / s_scene_exp
|
| 220 |
+
|
| 221 |
+
# Robust norms
|
| 222 |
+
eps = 1e-8
|
| 223 |
+
t_rel_norm = t_rel.norm(dim=-1, keepdim=True).clamp_min(eps)
|
| 224 |
+
|
| 225 |
+
s_tilde = s_rel / t_rel_norm
|
| 226 |
+
t_unit = t_rel / t_rel_norm
|
| 227 |
+
|
| 228 |
+
return InvariantPoseTarget(
|
| 229 |
+
q=q,
|
| 230 |
+
s_scene=s_scene,
|
| 231 |
+
t_scene_center=t_scene_center,
|
| 232 |
+
s_rel=s_rel,
|
| 233 |
+
s_tilde=s_tilde,
|
| 234 |
+
t_unit=t_unit,
|
| 235 |
+
t_rel_norm=t_rel_norm,
|
| 236 |
+
)
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
@staticmethod
|
| 240 |
+
def to_instance_pose(invariant_targets: "InvariantPoseTarget") -> InstancePose:
|
| 241 |
+
# scale factor per the derivation: s_scene * |t_rel|
|
| 242 |
+
# Normalize to scene scale (per the derivation)
|
| 243 |
+
t_rel_norm_ndim = invariant_targets.t_rel_norm.ndim
|
| 244 |
+
if not (invariant_targets.s_scene.ndim == (t_rel_norm_ndim - 1)) :
|
| 245 |
+
raise ValueError(f"s_scene should be ND [...,3] and t_rel_norm should be (N+1)D [...,K,3], but got {invariant_targets.s_scene.shape=} {invariant_targets.t_rel_norm.shape=}")
|
| 246 |
+
|
| 247 |
+
scale = invariant_targets.s_scene.unsqueeze(-2) * invariant_targets.t_rel_norm
|
| 248 |
+
return InstancePose(
|
| 249 |
+
instance_scale_l2c=invariant_targets.s_tilde * scale,
|
| 250 |
+
instance_position_l2c=invariant_targets.t_unit * scale,
|
| 251 |
+
instance_quaternion_l2c=invariant_targets.q,
|
| 252 |
+
scene_scale=invariant_targets.s_scene,
|
| 253 |
+
scene_shift=invariant_targets.t_scene_center,
|
| 254 |
+
)
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
class PoseTargetConvention:
|
| 258 |
+
"""
|
| 259 |
+
Converts pose_targets <-> instance_pose <-> invariant_pose_targets
|
| 260 |
+
"""
|
| 261 |
+
|
| 262 |
+
pose_target_convention: str
|
| 263 |
+
|
| 264 |
+
@classmethod
|
| 265 |
+
def from_invariant(cls, invariant_targets: InvariantPoseTarget) -> PoseTarget:
|
| 266 |
+
raise NotImplementedError("Implement this in a subclass")
|
| 267 |
+
|
| 268 |
+
@classmethod
|
| 269 |
+
def to_invariant(cls, instance_pose: InstancePose) -> InvariantPoseTarget:
|
| 270 |
+
raise NotImplementedError("Implement this in a subclass")
|
| 271 |
+
|
| 272 |
+
@classmethod
|
| 273 |
+
def from_instance_pose(cls, instance_pose: InstancePose) -> PoseTarget:
|
| 274 |
+
invariant_targets = InvariantPoseTarget.from_instance_pose(instance_pose)
|
| 275 |
+
return cls.from_invariant(invariant_targets)
|
| 276 |
+
|
| 277 |
+
@classmethod
|
| 278 |
+
def to_instance_pose(cls, pose_target: PoseTarget) -> InstancePose:
|
| 279 |
+
invariant_targets = cls.to_invariant(pose_target)
|
| 280 |
+
return InvariantPoseTarget.to_instance_pose(invariant_targets)
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
class ScaleShiftInvariant(PoseTargetConvention):
|
| 284 |
+
"""
|
| 285 |
+
|
| 286 |
+
Midas eq. (6): https://arxiv.org/pdf/1907.01341v3
|
| 287 |
+
But for pointmaps (see MoGe): https://arxiv.org/pdf/2410.19115
|
| 288 |
+
"""
|
| 289 |
+
|
| 290 |
+
pose_target_convention: str = "ScaleShiftInvariant"
|
| 291 |
+
scale_mean = torch.tensor([1.0232692956924438, 1.0232691764831543, 1.0232692956924438]).to(torch.float32)
|
| 292 |
+
scale_std = torch.tensor([1.3773751258850098, 1.3773752450942993, 1.3773750066757202]).to(torch.float32)
|
| 293 |
+
translation_mean = torch.tensor([0.003191213821992278, 0.017236359417438507, 0.9401122331619263]).to(torch.float32)
|
| 294 |
+
translation_std = torch.tensor([1.341888666152954, 0.7665449380874634, 3.175130605697632]).to(torch.float32)
|
| 295 |
+
|
| 296 |
+
@classmethod
|
| 297 |
+
def from_instance_pose(cls, instance_pose: InstancePose, normalize: bool = False) -> PoseTarget:
|
| 298 |
+
metric_to_ssi = cls.ssi_to_metric(
|
| 299 |
+
instance_pose.scene_scale, instance_pose.scene_shift
|
| 300 |
+
).inverse()
|
| 301 |
+
|
| 302 |
+
ssi_scale, ssi_rotation, ssi_translation = InstancePose._broadcast_postcompose(
|
| 303 |
+
scale=instance_pose.instance_scale_l2c,
|
| 304 |
+
rotation=instance_pose.instance_quaternion_l2c,
|
| 305 |
+
translation=instance_pose.instance_position_l2c,
|
| 306 |
+
transform_to_postcompose=metric_to_ssi,
|
| 307 |
+
)
|
| 308 |
+
# logger.info(f"{normalize=} {ssi_scale.shape=} {ssi_rotation.shape=} {ssi_translation.shape=}")
|
| 309 |
+
if normalize:
|
| 310 |
+
device = ssi_scale.device
|
| 311 |
+
ssi_scale = (ssi_scale - cls.scale_mean.to(device)) / cls.scale_std.to(device)
|
| 312 |
+
ssi_translation = (ssi_translation - cls.translation_mean.to(device)) / cls.translation_std.to(device)
|
| 313 |
+
|
| 314 |
+
return PoseTarget(
|
| 315 |
+
x_instance_scale=ssi_scale,
|
| 316 |
+
x_instance_rotation=ssi_rotation,
|
| 317 |
+
x_instance_translation=ssi_translation,
|
| 318 |
+
x_scene_scale=instance_pose.scene_scale,
|
| 319 |
+
x_scene_center=instance_pose.scene_shift,
|
| 320 |
+
x_translation_scale=torch.ones_like(ssi_scale)[..., 0].unsqueeze(-1),
|
| 321 |
+
pose_target_convention=cls.pose_target_convention,
|
| 322 |
+
)
|
| 323 |
+
|
| 324 |
+
@classmethod
|
| 325 |
+
def to_instance_pose(cls, pose_target: PoseTarget, normalize: bool = False) -> InstancePose:
|
| 326 |
+
scene_scale = pose_target.x_scene_scale
|
| 327 |
+
scene_shift = pose_target.x_scene_center
|
| 328 |
+
ssi_to_metric = cls.ssi_to_metric(scene_scale, scene_shift)
|
| 329 |
+
|
| 330 |
+
if normalize:
|
| 331 |
+
device = pose_target.x_instance_scale.device
|
| 332 |
+
pose_target.x_instance_scale = pose_target.x_instance_scale * cls.scale_std.to(device) + cls.scale_mean.to(device)
|
| 333 |
+
pose_target.x_instance_translation = pose_target.x_instance_translation * cls.translation_std.to(device) + cls.translation_mean.to(device)
|
| 334 |
+
|
| 335 |
+
ins_scale, ins_rotation, ins_translation = InstancePose._broadcast_postcompose(
|
| 336 |
+
scale=pose_target.x_instance_scale,
|
| 337 |
+
rotation=pose_target.x_instance_rotation,
|
| 338 |
+
translation=pose_target.x_instance_translation,
|
| 339 |
+
transform_to_postcompose=ssi_to_metric,
|
| 340 |
+
)
|
| 341 |
+
|
| 342 |
+
return InstancePose(
|
| 343 |
+
instance_scale_l2c=ins_scale,
|
| 344 |
+
instance_position_l2c=ins_translation,
|
| 345 |
+
instance_quaternion_l2c=ins_rotation,
|
| 346 |
+
scene_scale=scene_scale,
|
| 347 |
+
scene_shift=scene_shift,
|
| 348 |
+
)
|
| 349 |
+
|
| 350 |
+
@classmethod
|
| 351 |
+
def to_invariant(cls, pose_target: PoseTarget, normalize: bool = False) -> InvariantPoseTarget:
|
| 352 |
+
instance_pose = cls.to_instance_pose(pose_target, normalize=normalize)
|
| 353 |
+
return InvariantPoseTarget.from_instance_pose(instance_pose)
|
| 354 |
+
|
| 355 |
+
@classmethod
|
| 356 |
+
def from_invariant(cls, invariant_targets: InvariantPoseTarget, normalize: bool = False) -> PoseTarget:
|
| 357 |
+
instance_pose = InvariantPoseTarget.to_instance_pose(invariant_targets)
|
| 358 |
+
return cls.from_instance_pose(instance_pose, normalize=normalize)
|
| 359 |
+
|
| 360 |
+
@classmethod
|
| 361 |
+
def get_scale_and_shift(cls, pointmap):
|
| 362 |
+
shift_z = pointmap[..., -1].nanmedian().unsqueeze(0)
|
| 363 |
+
shift = torch.zeros_like(shift_z.expand(1, 3))
|
| 364 |
+
shift[..., -1] = shift_z
|
| 365 |
+
|
| 366 |
+
shifted_pointmap = pointmap - shift
|
| 367 |
+
scale = shifted_pointmap.abs().nanmean().to(shift.device)
|
| 368 |
+
|
| 369 |
+
shift = shift.reshape(3)
|
| 370 |
+
scale = scale.expand(3)
|
| 371 |
+
|
| 372 |
+
return scale, shift
|
| 373 |
+
|
| 374 |
+
@staticmethod
|
| 375 |
+
def ssi_to_metric(scale: torch.Tensor, shift: torch.Tensor):
|
| 376 |
+
if scale.ndim == 1:
|
| 377 |
+
scale = scale.unsqueeze(0)
|
| 378 |
+
if shift.ndim == 1:
|
| 379 |
+
shift = shift.unsqueeze(0)
|
| 380 |
+
return Transform3d().scale(scale).translate(shift).to(shift.device)
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
class ScaleShiftInvariantWTranslationScale(PoseTargetConvention):
|
| 384 |
+
"""
|
| 385 |
+
|
| 386 |
+
Midas eq. (6): https://arxiv.org/pdf/1907.01341v3
|
| 387 |
+
But for pointmaps (see MoGe): https://arxiv.org/pdf/2410.19115
|
| 388 |
+
"""
|
| 389 |
+
|
| 390 |
+
pose_target_convention: str = "ScaleShiftInvariantWTranslationScale"
|
| 391 |
+
scale_mean = torch.tensor([1.0232692956924438, 1.0232691764831543, 1.0232692956924438]).to(torch.float32)
|
| 392 |
+
scale_std = torch.tensor([1.3773751258850098, 1.3773752450942993, 1.3773750066757202]).to(torch.float32)
|
| 393 |
+
translation_mean = torch.tensor([0.003191213821992278, 0.017236359417438507, 0.9401122331619263]).to(torch.float32)
|
| 394 |
+
translation_std = torch.tensor([1.341888666152954, 0.7665449380874634, 3.175130605697632]).to(torch.float32)
|
| 395 |
+
|
| 396 |
+
@classmethod
|
| 397 |
+
def from_instance_pose(cls, instance_pose: InstancePose, normalize: bool = False) -> PoseTarget:
|
| 398 |
+
metric_to_ssi = cls.ssi_to_metric(
|
| 399 |
+
instance_pose.scene_scale, instance_pose.scene_shift
|
| 400 |
+
).inverse()
|
| 401 |
+
|
| 402 |
+
ssi_scale, ssi_rotation, ssi_translation = InstancePose._broadcast_postcompose(
|
| 403 |
+
scale=instance_pose.instance_scale_l2c,
|
| 404 |
+
rotation=instance_pose.instance_quaternion_l2c,
|
| 405 |
+
translation=instance_pose.instance_position_l2c,
|
| 406 |
+
transform_to_postcompose=metric_to_ssi,
|
| 407 |
+
)
|
| 408 |
+
|
| 409 |
+
ssi_translation_scale = ssi_translation.norm(dim=-1, keepdim=True)
|
| 410 |
+
ssi_translation_unit = ssi_translation / ssi_translation_scale.clamp_min(1e-7)
|
| 411 |
+
|
| 412 |
+
return PoseTarget(
|
| 413 |
+
x_instance_scale=ssi_scale,
|
| 414 |
+
x_instance_rotation=ssi_rotation,
|
| 415 |
+
x_instance_translation=ssi_translation_unit,
|
| 416 |
+
x_scene_scale=instance_pose.scene_scale,
|
| 417 |
+
x_scene_center=instance_pose.scene_shift,
|
| 418 |
+
x_translation_scale=ssi_translation_scale,
|
| 419 |
+
pose_target_convention=cls.pose_target_convention,
|
| 420 |
+
)
|
| 421 |
+
|
| 422 |
+
@classmethod
|
| 423 |
+
def to_instance_pose(cls, pose_target: PoseTarget, normalize: bool = False) -> InstancePose:
|
| 424 |
+
scene_scale = pose_target.x_scene_scale
|
| 425 |
+
scene_shift = pose_target.x_scene_center
|
| 426 |
+
ssi_to_metric = cls.ssi_to_metric(scene_scale, scene_shift)
|
| 427 |
+
|
| 428 |
+
ins_translation_unit = pose_target.x_instance_translation / pose_target.x_instance_translation.norm(dim=-1, keepdim=True)
|
| 429 |
+
ins_translation = ins_translation_unit * pose_target.x_translation_scale
|
| 430 |
+
|
| 431 |
+
|
| 432 |
+
ins_scale, ins_rotation, ins_translation = InstancePose._broadcast_postcompose(
|
| 433 |
+
scale=pose_target.x_instance_scale,
|
| 434 |
+
rotation=pose_target.x_instance_rotation,
|
| 435 |
+
translation=ins_translation,
|
| 436 |
+
transform_to_postcompose=ssi_to_metric,
|
| 437 |
+
)
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
return InstancePose(
|
| 441 |
+
instance_scale_l2c=ins_scale,
|
| 442 |
+
instance_position_l2c=ins_translation,
|
| 443 |
+
instance_quaternion_l2c=ins_rotation,
|
| 444 |
+
scene_scale=scene_scale,
|
| 445 |
+
scene_shift=scene_shift,
|
| 446 |
+
)
|
| 447 |
+
|
| 448 |
+
@classmethod
|
| 449 |
+
def to_invariant(cls, pose_target: PoseTarget) -> InvariantPoseTarget:
|
| 450 |
+
instance_pose = cls.to_instance_pose(pose_target)
|
| 451 |
+
return InvariantPoseTarget.from_instance_pose(instance_pose)
|
| 452 |
+
|
| 453 |
+
@classmethod
|
| 454 |
+
def from_invariant(cls, invariant_targets: InvariantPoseTarget) -> PoseTarget:
|
| 455 |
+
instance_pose = InvariantPoseTarget.to_instance_pose(invariant_targets)
|
| 456 |
+
return cls.from_instance_pose(instance_pose)
|
| 457 |
+
|
| 458 |
+
@classmethod
|
| 459 |
+
def get_scale_and_shift(cls, pointmap):
|
| 460 |
+
shift_z = pointmap[..., -1].nanmedian().unsqueeze(0)
|
| 461 |
+
shift = torch.zeros_like(shift_z.expand(1, 3))
|
| 462 |
+
shift[..., -1] = shift_z
|
| 463 |
+
|
| 464 |
+
shifted_pointmap = pointmap - shift
|
| 465 |
+
scale = shifted_pointmap.abs().nanmean().to(shift.device)
|
| 466 |
+
|
| 467 |
+
shift = shift.reshape(3)
|
| 468 |
+
scale = scale.expand(3)
|
| 469 |
+
|
| 470 |
+
return scale, shift
|
| 471 |
+
|
| 472 |
+
@staticmethod
|
| 473 |
+
def ssi_to_metric(scale: torch.Tensor, shift: torch.Tensor):
|
| 474 |
+
if scale.ndim == 1:
|
| 475 |
+
scale = scale.unsqueeze(0)
|
| 476 |
+
if shift.ndim == 1:
|
| 477 |
+
shift = shift.unsqueeze(0)
|
| 478 |
+
return Transform3d().scale(scale).translate(shift).to(shift.device)
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
class DisparitySpace(PoseTargetConvention):
|
| 482 |
+
pose_target_convention: str = "DisparitySpace"
|
| 483 |
+
|
| 484 |
+
@classmethod
|
| 485 |
+
def from_instance_pose(cls, instance_pose: InstancePose, normalize: bool = False) -> PoseTarget:
|
| 486 |
+
|
| 487 |
+
# x_instance_scale = orig_scale / scene_scale
|
| 488 |
+
# x_instance_translation = [x/z, y/z, 0] / scene_scale
|
| 489 |
+
# x_translation_scale = z / scene_scale
|
| 490 |
+
assert torch.allclose(instance_pose.scene_scale, torch.ones_like(instance_pose.scene_scale))
|
| 491 |
+
|
| 492 |
+
if not instance_pose.scene_shift.ndim == instance_pose.instance_position_l2c.ndim - 1:
|
| 493 |
+
raise ValueError(f"scene_shift must be (N+1)D and instance_position_l2c must be (N+1)D, but got {instance_pose.scene_shift.ndim} and {instance_pose.instance_position_l2c.ndim}")
|
| 494 |
+
shift_xy, shift_z_log = instance_pose.scene_shift.unsqueeze(-2).split([2, 1], dim=-1)
|
| 495 |
+
|
| 496 |
+
|
| 497 |
+
pose_xy, pose_z = instance_pose.instance_position_l2c.split([2, 1], dim=-1)
|
| 498 |
+
# Handle batch dimensions properly
|
| 499 |
+
if shift_xy.ndim < pose_xy.ndim:
|
| 500 |
+
shift_xy = shift_xy.unsqueeze(-2)
|
| 501 |
+
pose_xy_scaled = pose_xy / pose_z - shift_xy
|
| 502 |
+
|
| 503 |
+
pose_z_scaled_log = torch.log(pose_z) - shift_z_log
|
| 504 |
+
x_instance_scale_log = torch.log(instance_pose.instance_scale_l2c) - torch.log(pose_z)
|
| 505 |
+
|
| 506 |
+
x_instance_translation = torch.cat([pose_xy_scaled, torch.zeros_like(pose_z)], dim=-1)
|
| 507 |
+
x_translation_scale = torch.exp(pose_z_scaled_log)
|
| 508 |
+
x_instance_scale = torch.exp(x_instance_scale_log)
|
| 509 |
+
|
| 510 |
+
|
| 511 |
+
|
| 512 |
+
return PoseTarget(
|
| 513 |
+
x_instance_scale=x_instance_scale,
|
| 514 |
+
x_instance_translation=x_instance_translation,
|
| 515 |
+
x_instance_rotation=instance_pose.instance_quaternion_l2c,
|
| 516 |
+
x_scene_scale=instance_pose.scene_scale,
|
| 517 |
+
x_scene_center=instance_pose.scene_shift,
|
| 518 |
+
x_translation_scale=x_translation_scale,
|
| 519 |
+
pose_target_convention=cls.pose_target_convention,
|
| 520 |
+
)
|
| 521 |
+
|
| 522 |
+
@classmethod
|
| 523 |
+
def to_instance_pose(cls, pose_target: PoseTarget, normalize: bool = False) -> InstancePose:
|
| 524 |
+
scene_scale = pose_target.x_scene_scale
|
| 525 |
+
scene_shift = pose_target.x_scene_center
|
| 526 |
+
|
| 527 |
+
if not pose_target.x_scene_center.ndim == pose_target.x_instance_translation.ndim - 1:
|
| 528 |
+
raise ValueError(f"x_scene_center must be (N+1)D and x_instance_translation must be (N+1)D, but got {pose_target.x_scene_center.ndim} and {pose_target.x_instance_translation.ndim}")
|
| 529 |
+
shift_xy, shift_z_log = pose_target.x_scene_center.unsqueeze(-2).split([2, 1], dim=-1)
|
| 530 |
+
scene_z_scale = torch.exp(shift_z_log)
|
| 531 |
+
|
| 532 |
+
z = pose_target.x_translation_scale
|
| 533 |
+
ins_translation = pose_target.x_instance_translation.clone()
|
| 534 |
+
ins_translation[...,2] = 1.0
|
| 535 |
+
ins_translation[...,:2] = ins_translation[...,:2] + shift_xy
|
| 536 |
+
ins_translation = ins_translation * z * scene_z_scale
|
| 537 |
+
|
| 538 |
+
ins_scale = pose_target.x_instance_scale * z * scene_z_scale
|
| 539 |
+
|
| 540 |
+
return InstancePose(
|
| 541 |
+
instance_scale_l2c=ins_scale * scene_scale,
|
| 542 |
+
instance_position_l2c=ins_translation * scene_scale,
|
| 543 |
+
instance_quaternion_l2c=pose_target.x_instance_rotation,
|
| 544 |
+
scene_scale=scene_scale,
|
| 545 |
+
scene_shift=scene_shift,
|
| 546 |
+
)
|
| 547 |
+
|
| 548 |
+
@classmethod
|
| 549 |
+
def to_invariant(cls, pose_target: PoseTarget, normalize: bool = False) -> InvariantPoseTarget:
|
| 550 |
+
instance_pose = cls.to_instance_pose(pose_target, normalize=normalize)
|
| 551 |
+
return InvariantPoseTarget.from_instance_pose(instance_pose)
|
| 552 |
+
|
| 553 |
+
@classmethod
|
| 554 |
+
def from_invariant(cls, invariant_targets: InvariantPoseTarget, normalize: bool = False) -> PoseTarget:
|
| 555 |
+
instance_pose = InvariantPoseTarget.to_instance_pose(invariant_targets)
|
| 556 |
+
return cls.from_instance_pose(instance_pose, normalize=normalize)
|
| 557 |
+
|
| 558 |
+
|
| 559 |
+
|
| 560 |
+
class NormalizedSceneScale(PoseTargetConvention):
|
| 561 |
+
"""
|
| 562 |
+
x_instance_scale and x_translation_scale are normalized to x_scene_scale
|
| 563 |
+
"""
|
| 564 |
+
|
| 565 |
+
pose_target_convention: str = "NormalizedSceneScale"
|
| 566 |
+
|
| 567 |
+
@classmethod
|
| 568 |
+
def from_invariant(cls, invariant_targets: InvariantPoseTarget):
|
| 569 |
+
translation = invariant_targets.t_unit * invariant_targets.t_rel_norm
|
| 570 |
+
return PoseTarget(
|
| 571 |
+
x_instance_scale=invariant_targets.s_rel,
|
| 572 |
+
x_instance_rotation=invariant_targets.q,
|
| 573 |
+
x_instance_translation=translation,
|
| 574 |
+
x_scene_scale=invariant_targets.s_scene,
|
| 575 |
+
x_scene_center=invariant_targets.t_scene_center,
|
| 576 |
+
x_translation_scale=torch.ones_like(invariant_targets.t_rel_norm),
|
| 577 |
+
pose_target_convention=cls.pose_target_convention,
|
| 578 |
+
)
|
| 579 |
+
|
| 580 |
+
@classmethod
|
| 581 |
+
def to_invariant(cls, pose_target: PoseTarget):
|
| 582 |
+
t_rel_norm = torch.norm(
|
| 583 |
+
pose_target.x_instance_translation, dim=-1, keepdim=True
|
| 584 |
+
)
|
| 585 |
+
return InvariantPoseTarget(
|
| 586 |
+
s_scene=pose_target.x_scene_scale,
|
| 587 |
+
s_rel=pose_target.x_instance_scale,
|
| 588 |
+
q=pose_target.x_instance_rotation,
|
| 589 |
+
t_unit=pose_target.x_instance_translation / t_rel_norm,
|
| 590 |
+
t_rel_norm=t_rel_norm,
|
| 591 |
+
t_scene_center=pose_target.x_scene_center,
|
| 592 |
+
)
|
| 593 |
+
|
| 594 |
+
|
| 595 |
+
class Naive(PoseTargetConvention):
|
| 596 |
+
pose_target_convention: str = "Naive"
|
| 597 |
+
|
| 598 |
+
@classmethod
|
| 599 |
+
def from_invariant(cls, invariant_targets: InvariantPoseTarget):
|
| 600 |
+
s_scene = invariant_targets.s_rel * invariant_targets.s_scene
|
| 601 |
+
t_scene = invariant_targets.t_unit * invariant_targets.t_rel_norm
|
| 602 |
+
return PoseTarget(
|
| 603 |
+
x_instance_scale=s_scene,
|
| 604 |
+
x_instance_rotation=invariant_targets.q,
|
| 605 |
+
x_instance_translation=t_scene,
|
| 606 |
+
x_scene_scale=invariant_targets.s_scene,
|
| 607 |
+
x_scene_center=invariant_targets.t_scene_center,
|
| 608 |
+
x_translation_scale=torch.ones_like(invariant_targets.t_rel_norm),
|
| 609 |
+
pose_target_convention=cls.pose_target_convention,
|
| 610 |
+
)
|
| 611 |
+
|
| 612 |
+
@classmethod
|
| 613 |
+
def to_invariant(cls, pose_target: PoseTarget):
|
| 614 |
+
s_scene = pose_target.x_scene_scale
|
| 615 |
+
t_rel_norm = torch.norm(
|
| 616 |
+
pose_target.x_instance_translation, dim=-1, keepdim=True
|
| 617 |
+
)
|
| 618 |
+
return InvariantPoseTarget(
|
| 619 |
+
s_scene=s_scene,
|
| 620 |
+
t_scene_center=pose_target.x_scene_center,
|
| 621 |
+
s_rel=pose_target.x_instance_scale / s_scene,
|
| 622 |
+
q=pose_target.x_instance_rotation,
|
| 623 |
+
t_unit=pose_target.x_instance_translation / t_rel_norm,
|
| 624 |
+
t_rel_norm=t_rel_norm,
|
| 625 |
+
)
|
| 626 |
+
|
| 627 |
+
|
| 628 |
+
class NormalizedSceneScaleAndTranslation(PoseTargetConvention):
|
| 629 |
+
"""
|
| 630 |
+
x_instance_scale and x_translation_scale are normalized to x_scene_scale
|
| 631 |
+
x_instance_translation is unit
|
| 632 |
+
"""
|
| 633 |
+
|
| 634 |
+
pose_target_convention: str = "NormalizedSceneScaleAndTranslation"
|
| 635 |
+
|
| 636 |
+
@classmethod
|
| 637 |
+
def from_invariant(cls, invariant_targets: InvariantPoseTarget):
|
| 638 |
+
return PoseTarget(
|
| 639 |
+
x_instance_scale=invariant_targets.s_rel,
|
| 640 |
+
x_instance_rotation=invariant_targets.q,
|
| 641 |
+
x_instance_translation=invariant_targets.t_unit,
|
| 642 |
+
x_scene_scale=invariant_targets.s_scene,
|
| 643 |
+
x_scene_center=invariant_targets.t_scene_center,
|
| 644 |
+
x_translation_scale=invariant_targets.t_rel_norm,
|
| 645 |
+
pose_target_convention=cls.pose_target_convention,
|
| 646 |
+
)
|
| 647 |
+
|
| 648 |
+
@classmethod
|
| 649 |
+
def to_invariant(cls, pose_target: PoseTarget):
|
| 650 |
+
return InvariantPoseTarget(
|
| 651 |
+
s_scene=pose_target.x_scene_scale,
|
| 652 |
+
t_scene_center=pose_target.x_scene_center,
|
| 653 |
+
s_rel=pose_target.x_instance_scale,
|
| 654 |
+
q=pose_target.x_instance_rotation,
|
| 655 |
+
t_unit=pose_target.x_instance_translation,
|
| 656 |
+
t_rel_norm=pose_target.x_translation_scale,
|
| 657 |
+
)
|
| 658 |
+
|
| 659 |
+
|
| 660 |
+
class ApparentSize(PoseTargetConvention):
|
| 661 |
+
pose_target_convention: str = "ApparentSize"
|
| 662 |
+
|
| 663 |
+
@classmethod
|
| 664 |
+
def from_invariant(cls, invariant_targets: InvariantPoseTarget):
|
| 665 |
+
return PoseTarget(
|
| 666 |
+
x_instance_scale=invariant_targets.s_tilde,
|
| 667 |
+
x_instance_rotation=invariant_targets.q,
|
| 668 |
+
x_instance_translation=invariant_targets.t_unit,
|
| 669 |
+
x_scene_scale=invariant_targets.s_scene,
|
| 670 |
+
x_scene_center=invariant_targets.t_scene_center,
|
| 671 |
+
x_translation_scale=invariant_targets.t_rel_norm,
|
| 672 |
+
pose_target_convention=cls.pose_target_convention,
|
| 673 |
+
)
|
| 674 |
+
|
| 675 |
+
@classmethod
|
| 676 |
+
def to_invariant(cls, pose_target: PoseTarget):
|
| 677 |
+
return InvariantPoseTarget(
|
| 678 |
+
s_scene=pose_target.x_scene_scale,
|
| 679 |
+
t_scene_center=pose_target.x_scene_center,
|
| 680 |
+
s_tilde=pose_target.x_instance_scale,
|
| 681 |
+
q=pose_target.x_instance_rotation,
|
| 682 |
+
t_unit=pose_target.x_instance_translation,
|
| 683 |
+
t_rel_norm=pose_target.x_translation_scale,
|
| 684 |
+
)
|
| 685 |
+
|
| 686 |
+
|
| 687 |
+
class Identity(PoseTargetConvention):
|
| 688 |
+
"""
|
| 689 |
+
Identity convention - no transformation applied.
|
| 690 |
+
Direct passthrough mapping between instance pose and pose target values.
|
| 691 |
+
This preserves all values including scene_scale and scene_shift.
|
| 692 |
+
"""
|
| 693 |
+
|
| 694 |
+
pose_target_convention: str = "Identity"
|
| 695 |
+
|
| 696 |
+
@classmethod
|
| 697 |
+
def from_instance_pose(cls, instance_pose: InstancePose) -> PoseTarget:
|
| 698 |
+
return PoseTarget(
|
| 699 |
+
x_instance_scale=instance_pose.instance_scale_l2c,
|
| 700 |
+
x_instance_rotation=instance_pose.instance_quaternion_l2c,
|
| 701 |
+
x_instance_translation=instance_pose.instance_position_l2c,
|
| 702 |
+
x_scene_scale=instance_pose.scene_scale,
|
| 703 |
+
x_scene_center=instance_pose.scene_shift,
|
| 704 |
+
x_translation_scale=torch.ones_like(instance_pose.instance_scale_l2c)[..., 0].unsqueeze(-1),
|
| 705 |
+
pose_target_convention=cls.pose_target_convention,
|
| 706 |
+
)
|
| 707 |
+
|
| 708 |
+
@classmethod
|
| 709 |
+
def to_instance_pose(cls, pose_target: PoseTarget) -> InstancePose:
|
| 710 |
+
return InstancePose(
|
| 711 |
+
instance_scale_l2c=pose_target.x_instance_scale,
|
| 712 |
+
instance_position_l2c=pose_target.x_instance_translation,
|
| 713 |
+
instance_quaternion_l2c=pose_target.x_instance_rotation,
|
| 714 |
+
scene_scale=pose_target.x_scene_scale,
|
| 715 |
+
scene_shift=pose_target.x_scene_center,
|
| 716 |
+
)
|
| 717 |
+
|
| 718 |
+
@classmethod
|
| 719 |
+
def to_invariant(cls, pose_target: PoseTarget) -> InvariantPoseTarget:
|
| 720 |
+
instance_pose = cls.to_instance_pose(pose_target)
|
| 721 |
+
return InvariantPoseTarget.from_instance_pose(instance_pose)
|
| 722 |
+
|
| 723 |
+
@classmethod
|
| 724 |
+
def from_invariant(cls, invariant_targets: InvariantPoseTarget) -> PoseTarget:
|
| 725 |
+
instance_pose = InvariantPoseTarget.to_instance_pose(invariant_targets)
|
| 726 |
+
return cls.from_instance_pose(instance_pose)
|
| 727 |
+
|
| 728 |
+
|
| 729 |
+
class PoseTargetConverter:
|
| 730 |
+
@staticmethod
|
| 731 |
+
def pose_target_to_instance_pose(pose_target: PoseTarget, normalize: bool = False) -> InstancePose:
|
| 732 |
+
_convention_class = globals()[pose_target.pose_target_convention]
|
| 733 |
+
if _convention_class == ScaleShiftInvariant:
|
| 734 |
+
return _convention_class.to_instance_pose(pose_target, normalize=normalize)
|
| 735 |
+
else:
|
| 736 |
+
return _convention_class.to_instance_pose(pose_target)
|
| 737 |
+
|
| 738 |
+
@staticmethod
|
| 739 |
+
def instance_pose_to_pose_target(
|
| 740 |
+
instance_pose: InstancePose, pose_target_convention: str, normalize: bool = False
|
| 741 |
+
) -> PoseTarget:
|
| 742 |
+
_convention_class = globals()[pose_target_convention]
|
| 743 |
+
if _convention_class == ScaleShiftInvariant:
|
| 744 |
+
return _convention_class.from_instance_pose(instance_pose, normalize=normalize)
|
| 745 |
+
else:
|
| 746 |
+
return _convention_class.from_instance_pose(instance_pose)
|
| 747 |
+
|
| 748 |
+
@staticmethod
|
| 749 |
+
def dicts_instance_pose_to_pose_target(
|
| 750 |
+
pose_target_convention: str,
|
| 751 |
+
**kwargs,
|
| 752 |
+
):
|
| 753 |
+
instance_pose = InstancePose(**kwargs)
|
| 754 |
+
pose_target = PoseTargetConverter.instance_pose_to_pose_target(
|
| 755 |
+
instance_pose, pose_target_convention
|
| 756 |
+
)
|
| 757 |
+
return asdict(pose_target)
|
| 758 |
+
|
| 759 |
+
@staticmethod
|
| 760 |
+
def dicts_pose_target_to_instance_pose(
|
| 761 |
+
**kwargs,
|
| 762 |
+
):
|
| 763 |
+
pose_target_convention = kwargs.get("pose_target_convention")
|
| 764 |
+
_convention_class = globals()[pose_target_convention]
|
| 765 |
+
assert (
|
| 766 |
+
_convention_class.pose_target_convention == pose_target_convention
|
| 767 |
+
), f"Normalization name mismatch: {_convention_class.pose_target_convention} != {pose_target_convention}"
|
| 768 |
+
|
| 769 |
+
normalize = kwargs.pop("normalize", False)
|
| 770 |
+
pose_target = PoseTarget(**kwargs)
|
| 771 |
+
instance_pose = PoseTargetConverter.pose_target_to_instance_pose(pose_target, normalize)
|
| 772 |
+
return asdict(instance_pose)
|
| 773 |
+
|
| 774 |
+
|
| 775 |
+
class LogScaleShiftNormalizer:
|
| 776 |
+
def __init__(self, shift_log: torch.Tensor = 0.0, scale_log: torch.Tensor = 1.0):
|
| 777 |
+
self.shift_log = shift_log
|
| 778 |
+
self.scale_log = scale_log
|
| 779 |
+
|
| 780 |
+
def normalize(self, value: torch.Tensor):
|
| 781 |
+
return torch.log(value) - self.shift_log / self.scale_log
|
| 782 |
+
|
| 783 |
+
def denormalize(self, value: torch.Tensor):
|
| 784 |
+
return torch.exp(value * self.scale_log + self.shift_log)
|
thirdparty/sam3d/sam3d/sam3d_objects/data/dataset/tdfy/preprocessor.py
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
import warnings
|
| 3 |
+
import torch
|
| 4 |
+
from loguru import logger
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
from typing import Callable, Optional
|
| 7 |
+
import warnings
|
| 8 |
+
|
| 9 |
+
from .img_and_mask_transforms import (
|
| 10 |
+
SSIPointmapNormalizer,
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# Load and process data
|
| 15 |
+
@dataclass
|
| 16 |
+
class PreProcessor:
|
| 17 |
+
"""
|
| 18 |
+
Preprocessor configuration for image, mask, and pointmap transforms.
|
| 19 |
+
|
| 20 |
+
Transform application order:
|
| 21 |
+
1. Pointmap normalization (if normalize_pointmap=True)
|
| 22 |
+
2. Joint transforms (img_mask_pointmap_joint_transform or img_mask_joint_transform)
|
| 23 |
+
3. Individual transforms (img_transform, mask_transform, pointmap_transform)
|
| 24 |
+
|
| 25 |
+
For backward compatibility, img_mask_joint_transform is preserved. When both
|
| 26 |
+
img_mask_pointmap_joint_transform and img_mask_joint_transform are present,
|
| 27 |
+
img_mask_pointmap_joint_transform takes priority.
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
img_transform: Callable = (None,)
|
| 31 |
+
mask_transform: Callable = (None,)
|
| 32 |
+
img_mask_joint_transform: list[Callable] = (None,)
|
| 33 |
+
rgb_img_mask_joint_transform: list[Callable] = (None,)
|
| 34 |
+
|
| 35 |
+
# New fields for pointmap support
|
| 36 |
+
pointmap_transform: Callable = (None,)
|
| 37 |
+
img_mask_pointmap_joint_transform: list[Callable] = (None,)
|
| 38 |
+
|
| 39 |
+
# Pointmap normalization option
|
| 40 |
+
normalize_pointmap: bool = False
|
| 41 |
+
pointmap_normalizer: Optional[Callable] = None
|
| 42 |
+
rgb_pointmap_normalizer: Optional[Callable] = None
|
| 43 |
+
|
| 44 |
+
def __post_init__(self):
|
| 45 |
+
if self.pointmap_normalizer is None:
|
| 46 |
+
self.pointmap_normalizer = SSIPointmapNormalizer()
|
| 47 |
+
if self.normalize_pointmap == False:
|
| 48 |
+
warnings.warn("normalize_pointmap is also set to False, which means we will return the moments but not normalize the pointmap. This supports old unnormalized pointmap models, but this is dangerous behavior.", DeprecationWarning, stacklevel=2)
|
| 49 |
+
|
| 50 |
+
if self.rgb_pointmap_normalizer is None:
|
| 51 |
+
logger.warning("No rgb pointmap normalizer provided, using scale + shift ")
|
| 52 |
+
self.rgb_pointmap_normalizer = self.pointmap_normalizer
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def _normalize_pointmap(
|
| 56 |
+
self, pointmap: torch.Tensor,
|
| 57 |
+
mask: torch.Tensor,
|
| 58 |
+
pointmap_normalizer: Callable,
|
| 59 |
+
scale: Optional[torch.Tensor] = None,
|
| 60 |
+
shift: Optional[torch.Tensor] = None,
|
| 61 |
+
):
|
| 62 |
+
if pointmap is None:
|
| 63 |
+
return pointmap, None, None
|
| 64 |
+
|
| 65 |
+
if self.normalize_pointmap == False:
|
| 66 |
+
# old behavior: Pose is normalized to the pointmap center, but pointmap is not
|
| 67 |
+
_, pointmap_scale, pointmap_shift = pointmap_normalizer.normalize(pointmap, mask)
|
| 68 |
+
return pointmap, pointmap_scale, pointmap_shift
|
| 69 |
+
|
| 70 |
+
if scale is not None or shift is not None:
|
| 71 |
+
return pointmap_normalizer.normalize(pointmap, mask, scale, shift)
|
| 72 |
+
|
| 73 |
+
return pointmap_normalizer.normalize(pointmap, mask)
|
| 74 |
+
|
| 75 |
+
def _process_image_mask_pointmap_mess(
|
| 76 |
+
self, rgb_image, rgb_image_mask, pointmap=None
|
| 77 |
+
):
|
| 78 |
+
"""Extended version that handles pointmaps"""
|
| 79 |
+
|
| 80 |
+
# Apply pointmap normalization if enabled
|
| 81 |
+
pointmap_for_crop, pointmap_scale, pointmap_shift = self._normalize_pointmap(
|
| 82 |
+
pointmap, rgb_image_mask, self.pointmap_normalizer
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
# Apply transforms to the original full rgb image and mask.
|
| 86 |
+
rgb_image, rgb_image_mask = self._preprocess_rgb_image_mask(rgb_image, rgb_image_mask)
|
| 87 |
+
|
| 88 |
+
# These two are typically used for getting cropped images of the object
|
| 89 |
+
# : first apply joint transforms
|
| 90 |
+
processed_rgb_image, processed_mask, processed_pointmap = (
|
| 91 |
+
self._preprocess_image_mask_pointmap(rgb_image, rgb_image_mask, pointmap_for_crop)
|
| 92 |
+
)
|
| 93 |
+
# : then apply individual transforms on top of the joint transforms
|
| 94 |
+
processed_rgb_image = self._apply_transform(
|
| 95 |
+
processed_rgb_image, self.img_transform
|
| 96 |
+
)
|
| 97 |
+
processed_mask = self._apply_transform(processed_mask, self.mask_transform)
|
| 98 |
+
if processed_pointmap is not None:
|
| 99 |
+
processed_pointmap = self._apply_transform(
|
| 100 |
+
processed_pointmap, self.pointmap_transform
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
# This version is typically the full version of the image
|
| 104 |
+
# : apply individual transforms only
|
| 105 |
+
rgb_image = self._apply_transform(rgb_image, self.img_transform)
|
| 106 |
+
rgb_image_mask = self._apply_transform(rgb_image_mask, self.mask_transform)
|
| 107 |
+
|
| 108 |
+
rgb_pointmap, rgb_pointmap_scale, rgb_pointmap_shift = self._normalize_pointmap(
|
| 109 |
+
pointmap, rgb_image_mask, self.rgb_pointmap_normalizer, pointmap_scale, pointmap_shift
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
if rgb_pointmap is not None:
|
| 113 |
+
rgb_pointmap = self._apply_transform(rgb_pointmap, self.pointmap_transform)
|
| 114 |
+
|
| 115 |
+
result = {
|
| 116 |
+
"mask": processed_mask,
|
| 117 |
+
"image": processed_rgb_image,
|
| 118 |
+
"rgb_image": rgb_image,
|
| 119 |
+
"rgb_image_mask": rgb_image_mask,
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
# Add pointmap results if available
|
| 123 |
+
if processed_pointmap is not None:
|
| 124 |
+
result.update(
|
| 125 |
+
{
|
| 126 |
+
"pointmap": processed_pointmap,
|
| 127 |
+
"rgb_pointmap": rgb_pointmap,
|
| 128 |
+
}
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
# Add normalization parameters if normalization was applied
|
| 132 |
+
if pointmap_scale is not None and pointmap_shift is not None:
|
| 133 |
+
result.update(
|
| 134 |
+
{
|
| 135 |
+
"pointmap_scale": pointmap_scale,
|
| 136 |
+
"pointmap_shift": pointmap_shift,
|
| 137 |
+
"rgb_pointmap_scale": rgb_pointmap_scale,
|
| 138 |
+
"rgb_pointmap_shift": rgb_pointmap_shift,
|
| 139 |
+
}
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
return result
|
| 143 |
+
|
| 144 |
+
def _process_image_and_mask_mess(self, rgb_image, rgb_image_mask):
|
| 145 |
+
"""Original method - calls extended version without pointmap"""
|
| 146 |
+
return self._process_image_mask_pointmap_mess(rgb_image, rgb_image_mask, None)
|
| 147 |
+
|
| 148 |
+
def _preprocess_rgb_image_mask(self, rgb_image: torch.Tensor, rgb_image_mask: torch.Tensor):
|
| 149 |
+
"""Apply joint transforms to rgb_image and rgb_image_mask."""
|
| 150 |
+
if (
|
| 151 |
+
self.rgb_img_mask_joint_transform != (None,)
|
| 152 |
+
and self.rgb_img_mask_joint_transform is not None
|
| 153 |
+
):
|
| 154 |
+
for trans in self.rgb_img_mask_joint_transform:
|
| 155 |
+
rgb_image, rgb_image_mask = trans(rgb_image, rgb_image_mask)
|
| 156 |
+
return rgb_image, rgb_image_mask
|
| 157 |
+
|
| 158 |
+
def _preprocess_image_mask_pointmap(self, rgb_image, mask_image, pointmap=None):
|
| 159 |
+
"""Apply joint transforms with priority: triple transforms > dual transforms."""
|
| 160 |
+
# Priority: img_mask_pointmap_joint_transform when pointmap is provided
|
| 161 |
+
if (
|
| 162 |
+
self.img_mask_pointmap_joint_transform != (None,)
|
| 163 |
+
and self.img_mask_pointmap_joint_transform is not None
|
| 164 |
+
and pointmap is not None
|
| 165 |
+
):
|
| 166 |
+
for trans in self.img_mask_pointmap_joint_transform:
|
| 167 |
+
rgb_image, mask_image, pointmap = trans(
|
| 168 |
+
rgb_image, mask_image, pointmap=pointmap
|
| 169 |
+
)
|
| 170 |
+
return rgb_image, mask_image, pointmap
|
| 171 |
+
|
| 172 |
+
# Fallback: img_mask_joint_transform (existing behavior)
|
| 173 |
+
elif (
|
| 174 |
+
self.img_mask_joint_transform != (None,)
|
| 175 |
+
and self.img_mask_joint_transform is not None
|
| 176 |
+
):
|
| 177 |
+
for trans in self.img_mask_joint_transform:
|
| 178 |
+
rgb_image, mask_image = trans(rgb_image, mask_image)
|
| 179 |
+
return rgb_image, mask_image, pointmap
|
| 180 |
+
|
| 181 |
+
return rgb_image, mask_image, pointmap
|
| 182 |
+
|
| 183 |
+
def _preprocess_image_and_mask(self, rgb_image, mask_image):
|
| 184 |
+
"""Backward compatibility wrapper - only applies dual transforms"""
|
| 185 |
+
rgb_image, mask_image, _ = self._preprocess_image_mask_pointmap(
|
| 186 |
+
rgb_image, mask_image, None
|
| 187 |
+
)
|
| 188 |
+
return rgb_image, mask_image
|
| 189 |
+
|
| 190 |
+
# keep here for backward compatibility
|
| 191 |
+
def _preprocess_image_and_mask_inference(self, rgb_image, mask_image):
|
| 192 |
+
warnings.warn(
|
| 193 |
+
"The _preprocess_image_and_mask_inference is deprecated! Please use _preprocess_image_and_mask",
|
| 194 |
+
category=DeprecationWarning,
|
| 195 |
+
stacklevel=2,
|
| 196 |
+
)
|
| 197 |
+
return self._preprocess_image_and_mask(rgb_image, mask_image)
|
| 198 |
+
|
| 199 |
+
def _apply_transform(self, input: torch.Tensor, transform):
|
| 200 |
+
if input is not None and transform is not None and transform != (None,):
|
| 201 |
+
input = transform(input)
|
| 202 |
+
|
| 203 |
+
return input
|
thirdparty/sam3d/sam3d/sam3d_objects/data/dataset/tdfy/transforms_3d.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
from collections import namedtuple
|
| 3 |
+
import math
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from pytorch3d.transforms import (
|
| 7 |
+
Rotate,
|
| 8 |
+
Translate,
|
| 9 |
+
Scale,
|
| 10 |
+
Transform3d,
|
| 11 |
+
quaternion_to_matrix,
|
| 12 |
+
axis_angle_to_quaternion,
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
DecomposedTransform = namedtuple(
|
| 16 |
+
"DecomposedTransform", ["scale", "rotation", "translation"]
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def compose_transform(
|
| 21 |
+
scale: torch.Tensor, rotation: torch.Tensor, translation: torch.Tensor
|
| 22 |
+
) -> Transform3d:
|
| 23 |
+
"""
|
| 24 |
+
Args:
|
| 25 |
+
scale: (..., 3) tensor of scale factors
|
| 26 |
+
rotation: (..., 3, 3) tensor of rotation matrices
|
| 27 |
+
translation: (..., 3) tensor of translation vectors
|
| 28 |
+
"""
|
| 29 |
+
tfm = Transform3d(dtype=scale.dtype, device=scale.device)
|
| 30 |
+
return tfm.scale(scale).rotate(rotation).translate(translation)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def decompose_transform(transform: Transform3d) -> DecomposedTransform:
|
| 34 |
+
"""
|
| 35 |
+
Returns:
|
| 36 |
+
scale: (..., 3) tensor of scale factors
|
| 37 |
+
rotation: (..., 3, 3) tensor of rotation matrices
|
| 38 |
+
translation: (..., 3) tensor of translation vectors
|
| 39 |
+
"""
|
| 40 |
+
matrices = transform.get_matrix()
|
| 41 |
+
scale = torch.norm(matrices[:, :3, :3], dim=-1)
|
| 42 |
+
rotation = matrices[:, :3, :3] / scale.unsqueeze(-1) # Normalize rotation matrix
|
| 43 |
+
translation = matrices[:, 3, :3] # Extract translation vector
|
| 44 |
+
return DecomposedTransform(scale, rotation, translation)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def get_rotation_about_x_axis(angle: float = math.pi / 2) -> torch.Tensor:
|
| 48 |
+
axis = torch.tensor([1.0, 0.0, 0.0])
|
| 49 |
+
axis_angle = axis * angle
|
| 50 |
+
return axis_angle_to_quaternion(axis_angle)
|
thirdparty/sam3d/sam3d/sam3d_objects/data/utils.py
ADDED
|
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
from typing import Any, Iterable, Tuple, Union, Dict, Sequence, Mapping, Container
|
| 3 |
+
import optree
|
| 4 |
+
import torch
|
| 5 |
+
from collections.abc import Iterable
|
| 6 |
+
import inspect
|
| 7 |
+
import ast
|
| 8 |
+
import astor
|
| 9 |
+
from torch.utils import _pytree
|
| 10 |
+
|
| 11 |
+
# None = root, Iterable[Any] = path, Any = path of one
|
| 12 |
+
ChildPathType = Union[None, Iterable[Any], Any]
|
| 13 |
+
ArgsType = Iterable[ChildPathType]
|
| 14 |
+
KwargsType = Mapping[str, ChildPathType]
|
| 15 |
+
ArgsKwargsType = Tuple[ArgsType, KwargsType]
|
| 16 |
+
MappingType = Union[None, ArgsKwargsType, ArgsType, KwargsType]
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def tree_transpose_level_one(
|
| 20 |
+
structure,
|
| 21 |
+
check_children=False,
|
| 22 |
+
map_fn=None,
|
| 23 |
+
is_leaf=None,
|
| 24 |
+
):
|
| 25 |
+
_, outer_spec = optree.tree_flatten(
|
| 26 |
+
structure,
|
| 27 |
+
is_leaf=lambda x: x is not structure,
|
| 28 |
+
none_is_leaf=True,
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
spec = optree.tree_structure(structure, none_is_leaf=True, is_leaf=is_leaf)
|
| 32 |
+
children_spec = spec.children()
|
| 33 |
+
if len(children_spec) > 0:
|
| 34 |
+
inner_spec = children_spec[0]
|
| 35 |
+
if check_children:
|
| 36 |
+
for child_spec in children_spec[1:]:
|
| 37 |
+
assert (
|
| 38 |
+
inner_spec == child_spec
|
| 39 |
+
), f"one child was found having a different tree structure ({inner_spec} != {child_spec})"
|
| 40 |
+
|
| 41 |
+
structure = optree.tree_transpose(outer_spec, inner_spec, structure)
|
| 42 |
+
|
| 43 |
+
if map_fn is not None:
|
| 44 |
+
structure = optree.tree_map(
|
| 45 |
+
map_fn,
|
| 46 |
+
structure,
|
| 47 |
+
is_leaf=lambda x: optree.tree_structure(
|
| 48 |
+
x, is_leaf=is_leaf, none_is_leaf=True
|
| 49 |
+
)
|
| 50 |
+
== outer_spec,
|
| 51 |
+
none_is_leaf=True,
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
return structure
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
@staticmethod
|
| 58 |
+
def tree_tensor_map(fn, tree, *rest):
|
| 59 |
+
return optree.tree_map(
|
| 60 |
+
fn,
|
| 61 |
+
tree,
|
| 62 |
+
*rest,
|
| 63 |
+
is_leaf=lambda x: isinstance(x, torch.Tensor),
|
| 64 |
+
none_is_leaf=False,
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def to_device(obj, device):
|
| 69 |
+
"""Recursively moves all tensors in obj to the specified device.
|
| 70 |
+
|
| 71 |
+
Args:
|
| 72 |
+
obj: Object to move to device - can be a tensor, list, tuple, dict or any nested combination
|
| 73 |
+
device: Target device (e.g. 'cuda', 'cpu', torch.device('cuda:0') etc.)
|
| 74 |
+
|
| 75 |
+
Returns:
|
| 76 |
+
Same object structure with all contained tensors moved to specified device
|
| 77 |
+
"""
|
| 78 |
+
to_fn = lambda x: x.to(device)
|
| 79 |
+
return optree.tree_map(to_fn, obj, is_leaf=torch.is_tensor, none_is_leaf=False)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def expand_right(tensor, target_shape):
|
| 83 |
+
"""
|
| 84 |
+
e.g. Takes tensor of (a, b, c) and returns a tensor of (a, b, c, 1, 1, ...)
|
| 85 |
+
"""
|
| 86 |
+
current_shape = tensor.shape
|
| 87 |
+
dims_to_add = len(target_shape) - len(current_shape)
|
| 88 |
+
result = tensor
|
| 89 |
+
for _ in range(dims_to_add):
|
| 90 |
+
result = result.unsqueeze(-1)
|
| 91 |
+
expand_shape = list(current_shape) + [-1] * dims_to_add
|
| 92 |
+
for i in range(len(target_shape)):
|
| 93 |
+
if i < len(expand_shape) and expand_shape[i] == -1:
|
| 94 |
+
expand_shape[i] = target_shape[i]
|
| 95 |
+
return result.expand(*expand_shape)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def expand_as_right(tensor, target):
|
| 99 |
+
return expand_right(tensor, target.shape)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def as_keys(path: ChildPathType):
|
| 103 |
+
if isinstance(path, Iterable) and (not isinstance(path, str)):
|
| 104 |
+
return tuple(path)
|
| 105 |
+
elif path is None:
|
| 106 |
+
return ()
|
| 107 |
+
return (path,)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def get_child(obj: Any, *keys: Iterable[Any]):
|
| 111 |
+
for key in keys:
|
| 112 |
+
obj = obj[key]
|
| 113 |
+
return obj
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def set_child(obj: Any, value: Any, *keys: Iterable[Any]):
|
| 117 |
+
parent = None
|
| 118 |
+
for key in keys:
|
| 119 |
+
parent = obj
|
| 120 |
+
obj = obj[key]
|
| 121 |
+
if parent is None:
|
| 122 |
+
obj = value
|
| 123 |
+
else:
|
| 124 |
+
parent[key] = value
|
| 125 |
+
return obj
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def build_args_batch_extractor(args_mapping: ArgsType):
|
| 129 |
+
def extract_fn(batch):
|
| 130 |
+
return tuple(get_child(batch, *as_keys(path)) for path in args_mapping)
|
| 131 |
+
|
| 132 |
+
return extract_fn
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def build_kwargs_batch_extractor(kwargs_mapping: KwargsType):
|
| 136 |
+
def extract_fn(batch):
|
| 137 |
+
return {
|
| 138 |
+
name: get_child(batch, *as_keys(path))
|
| 139 |
+
for name, path in kwargs_mapping.items()
|
| 140 |
+
}
|
| 141 |
+
|
| 142 |
+
return extract_fn
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
empty_mapping = object()
|
| 146 |
+
kwargs_identity_mapping = object()
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def build_batch_extractor(mapping: MappingType):
|
| 150 |
+
extract_args_fn = lambda x: ()
|
| 151 |
+
extract_kwargs_fn = lambda x: {}
|
| 152 |
+
|
| 153 |
+
if mapping is None:
|
| 154 |
+
|
| 155 |
+
def extract_args_fn(batch):
|
| 156 |
+
return (batch,)
|
| 157 |
+
|
| 158 |
+
elif mapping is empty_mapping:
|
| 159 |
+
pass
|
| 160 |
+
elif mapping is kwargs_identity_mapping:
|
| 161 |
+
extract_kwargs_fn = lambda x: x
|
| 162 |
+
elif isinstance(mapping, Sequence) and (not isinstance(mapping, str)):
|
| 163 |
+
if (
|
| 164 |
+
len(mapping) == 2
|
| 165 |
+
and isinstance(mapping[0], Sequence)
|
| 166 |
+
and isinstance(mapping[1], Dict)
|
| 167 |
+
):
|
| 168 |
+
extract_args_fn = build_args_batch_extractor(mapping[0])
|
| 169 |
+
extract_kwargs_fn = build_kwargs_batch_extractor(mapping[1])
|
| 170 |
+
else:
|
| 171 |
+
extract_args_fn = build_args_batch_extractor(mapping)
|
| 172 |
+
elif isinstance(mapping, Mapping):
|
| 173 |
+
extract_kwargs_fn = build_kwargs_batch_extractor(mapping)
|
| 174 |
+
else:
|
| 175 |
+
|
| 176 |
+
def extract_args_fn(batch):
|
| 177 |
+
return (get_child(batch, *as_keys(mapping)),)
|
| 178 |
+
|
| 179 |
+
def extract_fn(batch):
|
| 180 |
+
return extract_args_fn(batch), extract_kwargs_fn(batch)
|
| 181 |
+
|
| 182 |
+
return extract_fn
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
# >
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def right_broadcasting(arr, target):
|
| 189 |
+
return arr.reshape(arr.shape + (1,) * (target.ndim - arr.ndim))
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def get_stats(tensor: torch.Tensor):
|
| 193 |
+
float_tensor = tensor.float()
|
| 194 |
+
return {
|
| 195 |
+
"shape": tuple(tensor.shape),
|
| 196 |
+
"min": tensor.min().item(),
|
| 197 |
+
"max": tensor.max().item(),
|
| 198 |
+
"mean": float_tensor.mean().item(),
|
| 199 |
+
"median": tensor.median().item(),
|
| 200 |
+
"std": float_tensor.std().item(),
|
| 201 |
+
}
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def _get_caller_arg_name(argnum=0, parent_frame=1):
|
| 205 |
+
try:
|
| 206 |
+
frame = inspect.currentframe() # current frame
|
| 207 |
+
frame = inspect.getouterframes(frame)[1 + parent_frame] # parent frame
|
| 208 |
+
code = inspect.getframeinfo(frame[0]).code_context[0].strip() # get code line
|
| 209 |
+
|
| 210 |
+
tree = ast.parse(code)
|
| 211 |
+
|
| 212 |
+
for node in ast.walk(tree):
|
| 213 |
+
if isinstance(node, ast.Call):
|
| 214 |
+
args = node.args
|
| 215 |
+
break # only get the first parent call
|
| 216 |
+
|
| 217 |
+
# get first argument string (do not handle '=')
|
| 218 |
+
label = astor.to_source(args[argnum]).strip()
|
| 219 |
+
except:
|
| 220 |
+
# TODO(Pierre) log exception
|
| 221 |
+
label = "{label}"
|
| 222 |
+
return label
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def print_stats(tensor, label=None):
|
| 226 |
+
if label is None:
|
| 227 |
+
label = _get_caller_arg_name(argnum=0)
|
| 228 |
+
stats = get_stats(tensor)
|
| 229 |
+
string = f"{label}:\n" + "\n".join(f"- {k}: {v}" for k, v in stats.items())
|
| 230 |
+
print(string)
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
def tree_reduce_unique(fn, tree, ensure_unique=True, **kwargs):
|
| 234 |
+
values = _pytree.tree_flatten(tree, **kwargs)[0]
|
| 235 |
+
values = tuple(map(fn, values))
|
| 236 |
+
first = values[0]
|
| 237 |
+
if ensure_unique:
|
| 238 |
+
for value in values[1:]:
|
| 239 |
+
if value != first:
|
| 240 |
+
raise RuntimeError(
|
| 241 |
+
f"different values found, {value} and {first} should be the same"
|
| 242 |
+
)
|
| 243 |
+
return first
|
thirdparty/sam3d/sam3d/sam3d_objects/model/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/dit/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/dit/embedder/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/dit/embedder/dino.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
import torch
|
| 3 |
+
from typing import Optional, Dict, Any
|
| 4 |
+
import warnings
|
| 5 |
+
from torchvision.transforms import Normalize
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from loguru import logger
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class Dino(torch.nn.Module):
|
| 11 |
+
def __init__(
|
| 12 |
+
self,
|
| 13 |
+
input_size: int = 224,
|
| 14 |
+
repo_or_dir: str = "facebookresearch/dinov2",
|
| 15 |
+
dino_model: str = "dinov2_vitb14",
|
| 16 |
+
source: str = "github",
|
| 17 |
+
backbone_kwargs: Optional[Dict[str, Any]] = None,
|
| 18 |
+
normalize_images: bool = True,
|
| 19 |
+
# for backward compatible
|
| 20 |
+
prenorm_features: bool = False,
|
| 21 |
+
freeze_backbone: bool = True,
|
| 22 |
+
prune_network: bool = False, # False for backward compatible
|
| 23 |
+
):
|
| 24 |
+
super().__init__()
|
| 25 |
+
if backbone_kwargs is None:
|
| 26 |
+
backbone_kwargs = {}
|
| 27 |
+
|
| 28 |
+
with warnings.catch_warnings():
|
| 29 |
+
warnings.simplefilter("ignore")
|
| 30 |
+
|
| 31 |
+
logger.info(f"Loading DINO model: {dino_model} from {repo_or_dir} (source: {source})")
|
| 32 |
+
if backbone_kwargs:
|
| 33 |
+
logger.info(f"DINO backbone kwargs: {backbone_kwargs}")
|
| 34 |
+
|
| 35 |
+
self.backbone = torch.hub.load(
|
| 36 |
+
repo_or_dir=repo_or_dir,
|
| 37 |
+
model=dino_model,
|
| 38 |
+
source=source,
|
| 39 |
+
verbose=False,
|
| 40 |
+
**backbone_kwargs,
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
# Log model properties after loading
|
| 44 |
+
logger.info(f"Loaded DINO model - type: {type(self.backbone)}, "
|
| 45 |
+
f"embed_dim: {self.backbone.embed_dim}, "
|
| 46 |
+
f"patch_size: {getattr(self.backbone.patch_embed, 'patch_size', 'N/A')}")
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
self.resize_input_size = (input_size, input_size)
|
| 50 |
+
self.embed_dim = self.backbone.embed_dim
|
| 51 |
+
self.input_size = input_size
|
| 52 |
+
self.input_channels = 3
|
| 53 |
+
self.normalize_images = normalize_images
|
| 54 |
+
self.prenorm_features = prenorm_features
|
| 55 |
+
self.register_buffer('mean', torch.as_tensor([[0.485, 0.456, 0.406]]).view(-1, 1, 1), persistent=False)
|
| 56 |
+
self.register_buffer('std', torch.as_tensor([[0.229, 0.224, 0.225]]).view(-1, 1, 1), persistent=False)
|
| 57 |
+
|
| 58 |
+
# freeze
|
| 59 |
+
if freeze_backbone:
|
| 60 |
+
self.requires_grad_(False)
|
| 61 |
+
self.eval()
|
| 62 |
+
elif not prune_network:
|
| 63 |
+
logger.warning(
|
| 64 |
+
"Unfreeze encoder w/o prune parameter may lead to error in ddp/fp16 training"
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
if prune_network:
|
| 68 |
+
self._prune_network()
|
| 69 |
+
|
| 70 |
+
def _preprocess_input(self, x):
|
| 71 |
+
_resized_images = torch.nn.functional.interpolate(
|
| 72 |
+
x,
|
| 73 |
+
size=self.resize_input_size,
|
| 74 |
+
mode="bilinear",
|
| 75 |
+
align_corners=False,
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
if x.shape[1] == 1:
|
| 79 |
+
_resized_images = _resized_images.repeat(1, 3, 1, 1)
|
| 80 |
+
|
| 81 |
+
if self.normalize_images:
|
| 82 |
+
_resized_images = _resized_images.sub_(self.mean).div_(self.std)
|
| 83 |
+
|
| 84 |
+
return _resized_images
|
| 85 |
+
|
| 86 |
+
def _forward_intermediate_layers(
|
| 87 |
+
self, input_img, intermediate_layers, cls_token=True
|
| 88 |
+
):
|
| 89 |
+
return self.backbone.get_intermediate_layers(
|
| 90 |
+
input_img,
|
| 91 |
+
intermediate_layers,
|
| 92 |
+
return_class_token=cls_token,
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
def _forward_last_layer(self, input_img):
|
| 96 |
+
output = self.backbone.forward_features(input_img)
|
| 97 |
+
if self.prenorm_features:
|
| 98 |
+
features = output["x_prenorm"]
|
| 99 |
+
tokens = F.layer_norm(features, features.shape[-1:])
|
| 100 |
+
else:
|
| 101 |
+
tokens = torch.cat(
|
| 102 |
+
[
|
| 103 |
+
output["x_norm_clstoken"].unsqueeze(1),
|
| 104 |
+
output["x_norm_patchtokens"],
|
| 105 |
+
],
|
| 106 |
+
dim=1,
|
| 107 |
+
)
|
| 108 |
+
return tokens
|
| 109 |
+
|
| 110 |
+
def forward(self, x, **kwargs):
|
| 111 |
+
_resized_images = self._preprocess_input(x)
|
| 112 |
+
tokens = self._forward_last_layer(_resized_images)
|
| 113 |
+
return tokens.to(x.dtype)
|
| 114 |
+
|
| 115 |
+
def _prune_network(self):
|
| 116 |
+
"""
|
| 117 |
+
Ran this script:
|
| 118 |
+
out = model(input)
|
| 119 |
+
loss = out.sum()
|
| 120 |
+
loss.backward()
|
| 121 |
+
|
| 122 |
+
for name, p in dino_model.named_parameters():
|
| 123 |
+
if p.grad is None:
|
| 124 |
+
print(name)
|
| 125 |
+
model.zero_grad()
|
| 126 |
+
"""
|
| 127 |
+
self.backbone.mask_token = None
|
| 128 |
+
if self.prenorm_features:
|
| 129 |
+
self.backbone.norm = torch.nn.Identity()
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
class DinoForMasks(torch.nn.Module):
|
| 133 |
+
def __init__(
|
| 134 |
+
self,
|
| 135 |
+
backbone: Dino,
|
| 136 |
+
):
|
| 137 |
+
super().__init__()
|
| 138 |
+
self.backbone = backbone
|
| 139 |
+
self.embed_dim = self.backbone.embed_dim
|
| 140 |
+
|
| 141 |
+
def forward(self, image, mask):
|
| 142 |
+
return self.backbone.forward(mask)
|
thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/dit/embedder/embedder_fuser.py
ADDED
|
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
import math
|
| 3 |
+
import torch
|
| 4 |
+
from loguru import logger
|
| 5 |
+
from torch import nn
|
| 6 |
+
from typing import Optional, Tuple, List, Literal, Dict
|
| 7 |
+
from sam3d_objects.model.layers.llama3.ff import FeedForward
|
| 8 |
+
from omegaconf import OmegaConf
|
| 9 |
+
|
| 10 |
+
class EmbedderFuser(torch.nn.Module):
|
| 11 |
+
"""
|
| 12 |
+
Fusing individual condition embedder. Require kwargs for the forward!
|
| 13 |
+
Args:
|
| 14 |
+
embedder_list: List of Tuples. Each Tuple consists of a condition_embedder
|
| 15 |
+
and a list of tuple. In the list, each tuple consists of a string, indicating
|
| 16 |
+
a kward, and astring, indicating the group of positional encoding to be used.
|
| 17 |
+
use_pos_embedding: whether to add positional embedding. If add, follow the index in
|
| 18 |
+
embedder_list. Choices of None (no pos emb), random, and learned.
|
| 19 |
+
projection_pre_norm: pre-normalize features before feeding into projector layers.
|
| 20 |
+
projection_net_hidden_dim_multiplier: hidden dimension for projection layer. If 0, don't use.
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
def __init__(
|
| 24 |
+
self,
|
| 25 |
+
embedder_list: List[Tuple[nn.Module, List[Tuple[str, Optional[str]]]]],
|
| 26 |
+
use_pos_embedding: Optional[Literal["random", "learned"]] = "learned",
|
| 27 |
+
projection_pre_norm: bool = True,
|
| 28 |
+
projection_net_hidden_dim_multiplier: float = 4.0,
|
| 29 |
+
compression_projection_multiplier: float = 0,
|
| 30 |
+
freeze: bool = False,
|
| 31 |
+
drop_modalities_weight: Dict[List[str], float] = None,
|
| 32 |
+
dropout_prob: float = 0.0,
|
| 33 |
+
force_drop_modalities: List[str] = None,
|
| 34 |
+
):
|
| 35 |
+
super().__init__()
|
| 36 |
+
# torch.compile does not support OmegaConf.ListConfig, so we convert to a list
|
| 37 |
+
if not isinstance(embedder_list, List):
|
| 38 |
+
self.embedder_list = OmegaConf.to_container(embedder_list)
|
| 39 |
+
else:
|
| 40 |
+
self.embedder_list = embedder_list
|
| 41 |
+
|
| 42 |
+
self.embed_dims = 0
|
| 43 |
+
self.compression_projection_multiplier = compression_projection_multiplier
|
| 44 |
+
self.concate_embed_dims = 0
|
| 45 |
+
# keep moduleList to be compatible with nn module
|
| 46 |
+
self.module_list = []
|
| 47 |
+
max_positional_embed_idx = 0
|
| 48 |
+
self.positional_embed_map = {}
|
| 49 |
+
for condition_embedder, kwargs_info in self.embedder_list:
|
| 50 |
+
self.embed_dims = max(self.embed_dims, condition_embedder.embed_dim)
|
| 51 |
+
self.module_list.append(condition_embedder)
|
| 52 |
+
for _, pos_group in kwargs_info:
|
| 53 |
+
self.concate_embed_dims += condition_embedder.embed_dim
|
| 54 |
+
if pos_group is not None:
|
| 55 |
+
if pos_group not in self.positional_embed_map:
|
| 56 |
+
self.positional_embed_map[pos_group] = max_positional_embed_idx
|
| 57 |
+
max_positional_embed_idx += 1
|
| 58 |
+
self.module_list = nn.ModuleList(self.module_list)
|
| 59 |
+
self.use_pos_embedding = use_pos_embedding
|
| 60 |
+
if self.use_pos_embedding == "random":
|
| 61 |
+
idx_emb = torch.randn(max_positional_embed_idx + 1, 1, self.embed_dims)
|
| 62 |
+
self.register_buffer("idx_emb", idx_emb)
|
| 63 |
+
elif self.use_pos_embedding == "learned":
|
| 64 |
+
self.idx_emb = nn.Parameter(
|
| 65 |
+
torch.empty(max_positional_embed_idx + 1, self.embed_dims)
|
| 66 |
+
)
|
| 67 |
+
nn.init.normal_(
|
| 68 |
+
self.idx_emb, mean=0.0, std=1.0 / math.sqrt(self.embed_dims)
|
| 69 |
+
)
|
| 70 |
+
else:
|
| 71 |
+
raise NotImplementedError(f"Unknown pos embedding {self.use_pos_embedding}")
|
| 72 |
+
|
| 73 |
+
self.projection_pre_norm = projection_pre_norm
|
| 74 |
+
self.projection_net_hidden_dim_multiplier = projection_net_hidden_dim_multiplier
|
| 75 |
+
if projection_net_hidden_dim_multiplier > 0:
|
| 76 |
+
self.projection_nets = []
|
| 77 |
+
for condition_embedder, _ in self.embedder_list:
|
| 78 |
+
self.projection_nets.append(
|
| 79 |
+
self._make_projection_net(
|
| 80 |
+
condition_embedder.embed_dim,
|
| 81 |
+
self.embed_dims,
|
| 82 |
+
self.projection_net_hidden_dim_multiplier,
|
| 83 |
+
)
|
| 84 |
+
)
|
| 85 |
+
self.projection_nets = nn.ModuleList(self.projection_nets)
|
| 86 |
+
|
| 87 |
+
if compression_projection_multiplier > 0:
|
| 88 |
+
self.compression_projector = self._make_projection_net(
|
| 89 |
+
self.concate_embed_dims,
|
| 90 |
+
self.embed_dims,
|
| 91 |
+
self.compression_projection_multiplier,
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
self.drop_modalities_weight = drop_modalities_weight if drop_modalities_weight is not None else []
|
| 95 |
+
self.dropout_prob = dropout_prob
|
| 96 |
+
self.force_drop_modalities = force_drop_modalities
|
| 97 |
+
|
| 98 |
+
if freeze:
|
| 99 |
+
self.requires_grad_(False)
|
| 100 |
+
self.eval()
|
| 101 |
+
|
| 102 |
+
def _make_projection_net(
|
| 103 |
+
self,
|
| 104 |
+
input_embed_dim,
|
| 105 |
+
output_embed_dim: int,
|
| 106 |
+
multiplier: int,
|
| 107 |
+
):
|
| 108 |
+
if self.projection_pre_norm:
|
| 109 |
+
pre_norm = nn.LayerNorm(input_embed_dim)
|
| 110 |
+
else:
|
| 111 |
+
pre_norm = nn.Identity()
|
| 112 |
+
|
| 113 |
+
# Per-token projection + gated activation
|
| 114 |
+
ff_net = FeedForward(
|
| 115 |
+
dim=input_embed_dim,
|
| 116 |
+
hidden_dim=int(multiplier * output_embed_dim),
|
| 117 |
+
output_dim=output_embed_dim,
|
| 118 |
+
)
|
| 119 |
+
return nn.Sequential(pre_norm, ff_net)
|
| 120 |
+
|
| 121 |
+
def _build_dropout_distribution(self, device):
|
| 122 |
+
"""
|
| 123 |
+
Build the probability distribution for dropout configurations.
|
| 124 |
+
|
| 125 |
+
Returns:
|
| 126 |
+
dropout_configs: List of sets containing modalities to drop
|
| 127 |
+
cumsum_weights: Cumulative sum of weights for sampling
|
| 128 |
+
"""
|
| 129 |
+
dropout_configs = []
|
| 130 |
+
weights = []
|
| 131 |
+
|
| 132 |
+
# Add no-dropout configuration with remaining probability
|
| 133 |
+
dropout_configs.append(set())
|
| 134 |
+
weights.append(1.0 - self.dropout_prob)
|
| 135 |
+
|
| 136 |
+
# Add configured dropout patterns
|
| 137 |
+
total_dropout_weight = sum(w for _, w in self.drop_modalities_weight)
|
| 138 |
+
assert total_dropout_weight > 0, "Total dropout weight must be positive when drop_modalities_weight is provided"
|
| 139 |
+
for modality_list, weight in self.drop_modalities_weight:
|
| 140 |
+
dropout_configs.append(set(modality_list))
|
| 141 |
+
# Scale weight by dropout_prob to ensure total probability sums to 1
|
| 142 |
+
weights.append(self.dropout_prob * weight / total_dropout_weight)
|
| 143 |
+
|
| 144 |
+
# Convert weights to cumulative distribution
|
| 145 |
+
weights_tensor = torch.tensor(weights, device=device)
|
| 146 |
+
|
| 147 |
+
was_deterministic = torch.are_deterministic_algorithms_enabled()
|
| 148 |
+
torch.use_deterministic_algorithms(False)
|
| 149 |
+
cumsum_weights = torch.cumsum(weights_tensor, dim=0)
|
| 150 |
+
torch.use_deterministic_algorithms(was_deterministic)
|
| 151 |
+
|
| 152 |
+
return dropout_configs, cumsum_weights
|
| 153 |
+
|
| 154 |
+
def _apply_force_drop(self, kwarg_names: List[str], tokens: List[torch.Tensor]):
|
| 155 |
+
if not self.force_drop_modalities:
|
| 156 |
+
return tokens
|
| 157 |
+
|
| 158 |
+
force_drop_set = set(self.force_drop_modalities)
|
| 159 |
+
result_tokens = []
|
| 160 |
+
|
| 161 |
+
for kwarg_name, token_tensor in zip(kwarg_names, tokens):
|
| 162 |
+
# Create mask: 0 for forced drop, 1 otherwise
|
| 163 |
+
mask = 0.0 if kwarg_name in force_drop_set else 1.0
|
| 164 |
+
result_tokens.append(token_tensor * mask)
|
| 165 |
+
|
| 166 |
+
return result_tokens
|
| 167 |
+
|
| 168 |
+
def _dropout_modalities(self, kwarg_names: List[str], tokens: List[torch.Tensor]):
|
| 169 |
+
# First apply forced drops (deterministic, always applied)
|
| 170 |
+
tokens = self._apply_force_drop(kwarg_names, tokens)
|
| 171 |
+
|
| 172 |
+
# Then apply probabilistic dropout (only in training)
|
| 173 |
+
if not self.training or self.dropout_prob <= 0 or not self.drop_modalities_weight:
|
| 174 |
+
return tokens
|
| 175 |
+
|
| 176 |
+
batch_size = tokens[0].shape[0]
|
| 177 |
+
device = tokens[0].device
|
| 178 |
+
|
| 179 |
+
# Build dropout configurations and sample which to use per batch element
|
| 180 |
+
dropout_configs, cumsum_weights = self._build_dropout_distribution(device)
|
| 181 |
+
rand_vals = torch.rand(batch_size, device=device)
|
| 182 |
+
# Clamp indices to valid range (handle edge case where rand_val == 1.0)
|
| 183 |
+
config_indices = torch.searchsorted(cumsum_weights, rand_vals).clamp(max=len(dropout_configs) - 1)
|
| 184 |
+
|
| 185 |
+
# Apply dropout masks with vectorized operations
|
| 186 |
+
result_tokens = []
|
| 187 |
+
for kwarg_name, token_tensor in zip(kwarg_names, tokens):
|
| 188 |
+
# Start with all ones (no dropout)
|
| 189 |
+
mask = torch.ones(batch_size, dtype=token_tensor.dtype, device=device)
|
| 190 |
+
|
| 191 |
+
# Vectorized mask creation: check all configurations at once
|
| 192 |
+
for config_idx, modalities_to_drop in enumerate(dropout_configs):
|
| 193 |
+
if kwarg_name in modalities_to_drop:
|
| 194 |
+
# Set mask to 0 for all batch elements using this configuration
|
| 195 |
+
mask[config_indices == config_idx] = 0.0
|
| 196 |
+
|
| 197 |
+
# Reshape mask to match token dimensions
|
| 198 |
+
mask = mask.view([batch_size] + [1] * (token_tensor.ndim - 1))
|
| 199 |
+
result_tokens.append(token_tensor * mask)
|
| 200 |
+
|
| 201 |
+
return result_tokens
|
| 202 |
+
|
| 203 |
+
def forward(self, *args, **kwargs):
|
| 204 |
+
tokens = []
|
| 205 |
+
kwarg_names = []
|
| 206 |
+
|
| 207 |
+
for i, (condition_embedder, kwargs_info) in enumerate(self.embedder_list):
|
| 208 |
+
# Ideally, we would batch the inputs; but that assumes same-sized inputs
|
| 209 |
+
for kwarg_name, pos_group in kwargs_info:
|
| 210 |
+
if kwarg_name not in kwargs:
|
| 211 |
+
logger.warning(f"{kwarg_name} not in kwargs to condition embedder!")
|
| 212 |
+
input_cond = kwargs[kwarg_name]
|
| 213 |
+
cond_token = condition_embedder(input_cond)
|
| 214 |
+
if self.projection_net_hidden_dim_multiplier > 0:
|
| 215 |
+
cond_token = self.projection_nets[i](cond_token)
|
| 216 |
+
if pos_group is not None:
|
| 217 |
+
pos_idx = self.positional_embed_map[pos_group]
|
| 218 |
+
if self.use_pos_embedding == "random":
|
| 219 |
+
cond_token += self.idx_emb[pos_idx : pos_idx + 1]
|
| 220 |
+
elif self.use_pos_embedding == "learned":
|
| 221 |
+
cond_token += self.idx_emb[pos_idx : pos_idx + 1, None]
|
| 222 |
+
else:
|
| 223 |
+
raise NotImplementedError(
|
| 224 |
+
f"Unknown pos embedding {self.use_pos_embedding}"
|
| 225 |
+
)
|
| 226 |
+
tokens.append(cond_token)
|
| 227 |
+
kwarg_names.append(kwarg_name)
|
| 228 |
+
|
| 229 |
+
# Apply dropout modalities with preserved order
|
| 230 |
+
tokens = self._dropout_modalities(kwarg_names, tokens)
|
| 231 |
+
|
| 232 |
+
if self.compression_projection_multiplier > 0:
|
| 233 |
+
tokens = torch.cat(tokens, dim=-1)
|
| 234 |
+
tokens = self.compression_projector(tokens)
|
| 235 |
+
else:
|
| 236 |
+
tokens = torch.cat(tokens, dim=1)
|
| 237 |
+
|
| 238 |
+
return tokens
|
thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/dit/embedder/point_remapper.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class PointRemapper(nn.Module):
|
| 7 |
+
"""Handles remapping of 3D point coordinates and their inverse transformations."""
|
| 8 |
+
|
| 9 |
+
VALID_TYPES = ["linear", "sinh", "exp", "sinh_exp", "exp_disparity"]
|
| 10 |
+
|
| 11 |
+
def __init__(self, remap_type: str = "exp"):
|
| 12 |
+
super().__init__()
|
| 13 |
+
self.remap_type = remap_type
|
| 14 |
+
|
| 15 |
+
if remap_type not in self.VALID_TYPES:
|
| 16 |
+
raise ValueError(
|
| 17 |
+
f"Invalid remap type: {remap_type}. Must be one of {self.VALID_TYPES}"
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
def forward(self, points: torch.Tensor) -> torch.Tensor:
|
| 21 |
+
"""Apply remapping to point coordinates."""
|
| 22 |
+
if self.remap_type == "linear":
|
| 23 |
+
return points
|
| 24 |
+
|
| 25 |
+
elif self.remap_type == "sinh":
|
| 26 |
+
return torch.asinh(points)
|
| 27 |
+
|
| 28 |
+
elif self.remap_type == "exp":
|
| 29 |
+
xy_scaled, z_exp = points.split([2, 1], dim=-1)
|
| 30 |
+
# Use log1p for better numerical stability near zero
|
| 31 |
+
z = torch.log1p(z_exp)
|
| 32 |
+
xy = xy_scaled / (1 + z_exp)
|
| 33 |
+
return torch.cat([xy, z], dim=-1)
|
| 34 |
+
|
| 35 |
+
elif self.remap_type == "exp_disparity":
|
| 36 |
+
xy_scaled, z_exp = points.split([2, 1], dim=-1)
|
| 37 |
+
xy = xy_scaled / z_exp
|
| 38 |
+
z = torch.log(z_exp)
|
| 39 |
+
return torch.cat([xy, z], dim=-1)
|
| 40 |
+
|
| 41 |
+
elif self.remap_type == "sinh_exp":
|
| 42 |
+
xy_sinh, z_exp = points.split([2, 1], dim=-1)
|
| 43 |
+
xy = torch.asinh(xy_sinh)
|
| 44 |
+
z = torch.log(z_exp.clamp(min=1e-8))
|
| 45 |
+
return torch.cat([xy, z], dim=-1)
|
| 46 |
+
|
| 47 |
+
else:
|
| 48 |
+
raise ValueError(f"Unknown remap type: {self.remap_type}")
|
| 49 |
+
|
| 50 |
+
def inverse(self, points: torch.Tensor) -> torch.Tensor:
|
| 51 |
+
"""Apply inverse remapping to recover original point coordinates."""
|
| 52 |
+
if self.remap_type == "linear":
|
| 53 |
+
return points
|
| 54 |
+
|
| 55 |
+
elif self.remap_type == "sinh":
|
| 56 |
+
return torch.sinh(points)
|
| 57 |
+
|
| 58 |
+
elif self.remap_type == "exp":
|
| 59 |
+
xy, z = points.split([2, 1], dim=-1)
|
| 60 |
+
# Inverse of log1p is expm1(z) = exp(z) - 1
|
| 61 |
+
z_exp = torch.expm1(z)
|
| 62 |
+
# Inverse of xy/(1+z_exp) is xy*(1+z_exp)
|
| 63 |
+
return torch.cat([xy * (1 + z_exp), z_exp], dim=-1)
|
| 64 |
+
|
| 65 |
+
elif self.remap_type == "exp_disparity":
|
| 66 |
+
xy, z = points.split([2, 1], dim=-1)
|
| 67 |
+
z_exp = torch.exp(z)
|
| 68 |
+
return torch.cat([xy * z_exp, z_exp], dim=-1)
|
| 69 |
+
|
| 70 |
+
elif self.remap_type == "sinh_exp":
|
| 71 |
+
xy, z = points.split([2, 1], dim=-1)
|
| 72 |
+
return torch.cat([torch.sinh(xy), torch.exp(z)], dim=-1)
|
| 73 |
+
|
| 74 |
+
else:
|
| 75 |
+
raise ValueError(f"Unknown remap type: {self.remap_type}")
|
| 76 |
+
|
| 77 |
+
def extra_repr(self) -> str:
|
| 78 |
+
return f"remap_type='{self.remap_type}'"
|
thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/dit/embedder/pointmap.py
ADDED
|
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
from timm.models.vision_transformer import Block
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from functools import partial
|
| 7 |
+
from loguru import logger
|
| 8 |
+
|
| 9 |
+
from .point_remapper import PointRemapper
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class PointPatchEmbed(nn.Module):
|
| 13 |
+
"""
|
| 14 |
+
Projects (x,y,z) → D
|
| 15 |
+
Splits into patches (patch_size x patch_size)
|
| 16 |
+
Runs a tiny self-attention block inside each window
|
| 17 |
+
Returns one token per window.
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
def __init__(
|
| 21 |
+
self,
|
| 22 |
+
input_size: int = 256,
|
| 23 |
+
patch_size: int = 8,
|
| 24 |
+
embed_dim: int = 768,
|
| 25 |
+
remap_output: str = "exp", # Add remap_output parameter
|
| 26 |
+
dropout_prob: float = 0.0, # Dropout probability for pointmap
|
| 27 |
+
force_dropout_always: bool = False, # Force dropout during validation/inference
|
| 28 |
+
):
|
| 29 |
+
super().__init__()
|
| 30 |
+
self.input_size = input_size
|
| 31 |
+
self.patch_size = patch_size
|
| 32 |
+
self.embed_dim = embed_dim
|
| 33 |
+
self.dropout_prob = dropout_prob
|
| 34 |
+
self.force_dropout_always = force_dropout_always
|
| 35 |
+
|
| 36 |
+
# Point remapper
|
| 37 |
+
self.point_remapper = PointRemapper(remap_output)
|
| 38 |
+
|
| 39 |
+
# (1) point embedding
|
| 40 |
+
self.point_proj = nn.Linear(3, embed_dim)
|
| 41 |
+
self.invalid_xyz_token = nn.Parameter(torch.zeros(embed_dim))
|
| 42 |
+
|
| 43 |
+
# Special embedding for dropped patches (used during dropout)
|
| 44 |
+
# Alternative dropout strategies to consider:
|
| 45 |
+
# 1. Drop all tokens entirely or use a single token only
|
| 46 |
+
# 2. Different dropout patterns per window
|
| 47 |
+
# 3. Use dropped_xyz_token/invalid_xyz_token per pixel
|
| 48 |
+
if dropout_prob > 0:
|
| 49 |
+
self.dropped_xyz_token = nn.Parameter(torch.zeros(embed_dim))
|
| 50 |
+
|
| 51 |
+
# (2) positional embedding
|
| 52 |
+
num_patches = input_size // patch_size
|
| 53 |
+
# For patches
|
| 54 |
+
self.pos_embed = nn.Parameter(
|
| 55 |
+
torch.zeros(1, embed_dim, num_patches, num_patches)
|
| 56 |
+
)
|
| 57 |
+
# For points in a patch
|
| 58 |
+
self.pos_embed_window = nn.Parameter(
|
| 59 |
+
torch.zeros(1, 1 + patch_size * patch_size, embed_dim)
|
| 60 |
+
)
|
| 61 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
| 62 |
+
|
| 63 |
+
# (3) within-patch transformer block(s)
|
| 64 |
+
# From MCC: https://github.com/facebookresearch/MCC/blob/b04c97518360e4fdedfb6f090db7e90d0c2f8ae6/mcc_model.py#L97
|
| 65 |
+
self.blocks = nn.ModuleList(
|
| 66 |
+
[
|
| 67 |
+
Block(
|
| 68 |
+
embed_dim,
|
| 69 |
+
num_heads=16,
|
| 70 |
+
mlp_ratio=2.0,
|
| 71 |
+
qkv_bias=True,
|
| 72 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
| 73 |
+
)
|
| 74 |
+
]
|
| 75 |
+
)
|
| 76 |
+
self.initialize_weights()
|
| 77 |
+
|
| 78 |
+
def initialize_weights(self):
|
| 79 |
+
# Initialize positional embeddings with small std
|
| 80 |
+
nn.init.normal_(self.pos_embed, std=0.02)
|
| 81 |
+
nn.init.normal_(self.pos_embed_window, std=0.02)
|
| 82 |
+
nn.init.normal_(self.cls_token, std=0.02)
|
| 83 |
+
nn.init.normal_(self.invalid_xyz_token, std=0.02)
|
| 84 |
+
|
| 85 |
+
# Initialize dropped pointmap token if dropout is enabled
|
| 86 |
+
if self.dropout_prob > 0:
|
| 87 |
+
nn.init.normal_(self.dropped_xyz_token, std=0.02)
|
| 88 |
+
|
| 89 |
+
# Initialize point projection with xavier uniform for better gradient flow
|
| 90 |
+
# This is crucial since pointmaps can have large value ranges
|
| 91 |
+
nn.init.xavier_uniform_(self.point_proj.weight, gain=0.02)
|
| 92 |
+
if self.point_proj.bias is not None:
|
| 93 |
+
nn.init.constant_(self.point_proj.bias, 0)
|
| 94 |
+
|
| 95 |
+
def _get_pos_embed(self, hw):
|
| 96 |
+
h, w = hw
|
| 97 |
+
pos_embed = F.interpolate(
|
| 98 |
+
self.pos_embed, size=(h, w), mode="bilinear", align_corners=False
|
| 99 |
+
)
|
| 100 |
+
pos_embed = pos_embed.permute(0, 2, 3, 1) # (B, H, W, C)
|
| 101 |
+
return pos_embed
|
| 102 |
+
|
| 103 |
+
def resize_input(self, xyz: torch.Tensor) -> torch.Tensor:
|
| 104 |
+
resized_xyz = F.interpolate(xyz, size=self.input_size, mode="nearest")
|
| 105 |
+
resized_xyz = resized_xyz.permute(0, 2, 3, 1) # (B, H, W, C)
|
| 106 |
+
return resized_xyz
|
| 107 |
+
|
| 108 |
+
def apply_pointmap_dropout(self, embeddings: torch.Tensor) -> torch.Tensor:
|
| 109 |
+
"""
|
| 110 |
+
Apply dropout to pointmap embeddings.
|
| 111 |
+
Drops entire pointmap for selected samples during training or when forced.
|
| 112 |
+
|
| 113 |
+
When force_dropout_always is True, always drops pointmap regardless of training mode.
|
| 114 |
+
"""
|
| 115 |
+
# Check if we should apply dropout
|
| 116 |
+
should_apply_dropout = (self.training or self.force_dropout_always) and self.dropout_prob > 0
|
| 117 |
+
|
| 118 |
+
if not should_apply_dropout:
|
| 119 |
+
return embeddings
|
| 120 |
+
|
| 121 |
+
# Check if dropout infrastructure exists
|
| 122 |
+
if not hasattr(self, 'dropped_xyz_token'):
|
| 123 |
+
if self.force_dropout_always:
|
| 124 |
+
raise RuntimeError(
|
| 125 |
+
"Cannot force dropout: model was initialized with dropout_prob=0. "
|
| 126 |
+
"Re-initialize with dropout_prob > 0 to enable forced dropout."
|
| 127 |
+
)
|
| 128 |
+
return embeddings
|
| 129 |
+
|
| 130 |
+
batch_size, n_windows, embed_dim = embeddings.shape
|
| 131 |
+
|
| 132 |
+
# Decide dropout behavior
|
| 133 |
+
if self.force_dropout_always and not self.training:
|
| 134 |
+
# When forced during inference, always drop (100% dropout)
|
| 135 |
+
drop_mask = torch.ones(batch_size, device=embeddings.device, dtype=torch.bool)
|
| 136 |
+
else:
|
| 137 |
+
# Normal training dropout - use configured probability
|
| 138 |
+
drop_mask = torch.rand(batch_size, device=embeddings.device) < self.dropout_prob
|
| 139 |
+
|
| 140 |
+
# Create dropped embedding for all windows - use same token for all patches
|
| 141 |
+
# Shape: (batch_size, n_windows, embed_dim)
|
| 142 |
+
dropped_embedding = self.dropped_xyz_token.view(1, 1, embed_dim).expand(batch_size, n_windows, embed_dim)
|
| 143 |
+
|
| 144 |
+
# Add positional embeddings to dropped tokens (same as regular embeddings get)
|
| 145 |
+
n_windows_h = n_windows_w = int(n_windows ** 0.5)
|
| 146 |
+
pos_embed_patch = self._get_pos_embed((n_windows_h, n_windows_w)).reshape(
|
| 147 |
+
1, n_windows, embed_dim
|
| 148 |
+
)
|
| 149 |
+
dropped_embedding = dropped_embedding + pos_embed_patch
|
| 150 |
+
drop_mask_expanded = drop_mask.view(batch_size, 1, 1).expand_as(embeddings)
|
| 151 |
+
embeddings = torch.where(drop_mask_expanded, dropped_embedding, embeddings)
|
| 152 |
+
|
| 153 |
+
return embeddings
|
| 154 |
+
|
| 155 |
+
@torch._dynamo.disable()
|
| 156 |
+
def embed_pointmap_windows(
|
| 157 |
+
self, xyz: torch.Tensor, valid_mask: torch.Tensor = None
|
| 158 |
+
) -> torch.Tensor:
|
| 159 |
+
"""Process pointmap into window embeddings without positional encoding"""
|
| 160 |
+
with torch.no_grad():
|
| 161 |
+
xyz = self.resize_input(xyz)
|
| 162 |
+
if valid_mask is None:
|
| 163 |
+
valid_mask = xyz.isfinite().all(dim=-1)
|
| 164 |
+
|
| 165 |
+
B, H, W, _ = xyz.shape
|
| 166 |
+
assert (
|
| 167 |
+
H % self.patch_size == 0 and W % self.patch_size == 0
|
| 168 |
+
), "image must be divisible by patch_size"
|
| 169 |
+
|
| 170 |
+
# (1) Handle NaN values before remapping to prevent propagation
|
| 171 |
+
xyz_safe = xyz.clone()
|
| 172 |
+
xyz_safe[~valid_mask] = 0.0 # Set invalid points to 0 before remapping
|
| 173 |
+
|
| 174 |
+
# (1b) remap points to normalize their range
|
| 175 |
+
xyz_remapped = self.point_remapper(xyz_safe)
|
| 176 |
+
|
| 177 |
+
# (2) project + invalid token
|
| 178 |
+
x = self.point_proj(xyz_remapped) # (B,H,W,D)
|
| 179 |
+
|
| 180 |
+
x[~valid_mask] = 0.0 # Stop gradient for invalid points
|
| 181 |
+
x[~valid_mask] += self.invalid_xyz_token
|
| 182 |
+
|
| 183 |
+
return x, B, H, W
|
| 184 |
+
|
| 185 |
+
def inner_forward(
|
| 186 |
+
self, x: torch.Tensor, B: int, H: int, W: int
|
| 187 |
+
) -> torch.Tensor:
|
| 188 |
+
x = x.view(
|
| 189 |
+
B,
|
| 190 |
+
H // self.patch_size,
|
| 191 |
+
self.patch_size,
|
| 192 |
+
W // self.patch_size,
|
| 193 |
+
self.patch_size,
|
| 194 |
+
self.embed_dim,
|
| 195 |
+
) # (B, hW, wW, ws, ws, D)
|
| 196 |
+
x = x.permute(0, 1, 3, 2, 4, 5).contiguous() # (B, hW, wW, ws, ws, D)
|
| 197 |
+
x = x.view(-1, self.patch_size * self.patch_size, self.embed_dim)
|
| 198 |
+
|
| 199 |
+
# (4) CLS token that contains the patch information
|
| 200 |
+
cls_tok = self.cls_token.expand(x.shape[0], -1, -1)
|
| 201 |
+
toks = torch.cat([cls_tok, x], dim=1)
|
| 202 |
+
|
| 203 |
+
# (5) add positional embedding for window
|
| 204 |
+
toks = toks + self.pos_embed_window
|
| 205 |
+
|
| 206 |
+
# (6) intra-window attention
|
| 207 |
+
for blk in self.blocks:
|
| 208 |
+
toks = blk(toks)
|
| 209 |
+
|
| 210 |
+
# (7) Extract CLS tokens and reshape to (B, n_windows, embed_dim)
|
| 211 |
+
n_windows_h = H // self.patch_size
|
| 212 |
+
n_windows_w = W // self.patch_size
|
| 213 |
+
window_embeddings = toks[:, 0].view(B, n_windows_h * n_windows_w, self.embed_dim)
|
| 214 |
+
|
| 215 |
+
# Add positional embeddings
|
| 216 |
+
pos_embed_patch = self._get_pos_embed((n_windows_h, n_windows_w)).reshape(
|
| 217 |
+
1, n_windows_h * n_windows_w, self.embed_dim
|
| 218 |
+
)
|
| 219 |
+
out = window_embeddings + pos_embed_patch
|
| 220 |
+
|
| 221 |
+
# Apply dropout if enabled (during training OR when forced)
|
| 222 |
+
if (self.training or self.force_dropout_always) and self.dropout_prob > 0:
|
| 223 |
+
out = self.apply_pointmap_dropout(out)
|
| 224 |
+
|
| 225 |
+
return out
|
| 226 |
+
|
| 227 |
+
def forward(
|
| 228 |
+
self, xyz: torch.Tensor, valid_mask: torch.Tensor = None
|
| 229 |
+
) -> torch.Tensor:
|
| 230 |
+
"""
|
| 231 |
+
xyz : (B, 3, H, W) map of (x,y,z) coordinates
|
| 232 |
+
valid_mask : (B, H, W) boolean - True for valid points (optional)
|
| 233 |
+
|
| 234 |
+
returns: (B, num_windows, D)
|
| 235 |
+
"""
|
| 236 |
+
# Get window embeddings
|
| 237 |
+
x, B, H, W = self.embed_pointmap_windows(xyz, valid_mask)
|
| 238 |
+
return self.inner_forward(x, B, H, W)
|
thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/generator/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/generator/base.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
import torch
|
| 3 |
+
from typing import Optional, Union
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class Base(torch.nn.Module):
|
| 7 |
+
def __init__(self, seed_or_generator: Optional[Union[int, torch.Generator]] = None):
|
| 8 |
+
super().__init__()
|
| 9 |
+
|
| 10 |
+
if isinstance(seed_or_generator, torch.Generator):
|
| 11 |
+
self.random_generator = seed_or_generator
|
| 12 |
+
elif isinstance(seed_or_generator, int):
|
| 13 |
+
self.seed = seed_or_generator
|
| 14 |
+
elif seed_or_generator is None:
|
| 15 |
+
self.random_generator = torch.default_generator
|
| 16 |
+
else:
|
| 17 |
+
raise RuntimeError(
|
| 18 |
+
f"cannot use argument of type {type(seed_or_generator)} to set random generator"
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
@property
|
| 22 |
+
def seed(self):
|
| 23 |
+
raise AttributeError(f"Cannot read attribute 'seed'.")
|
| 24 |
+
|
| 25 |
+
@seed.setter
|
| 26 |
+
def seed(self, value: int):
|
| 27 |
+
self._random_generator = torch.Generator().manual_seed(value)
|
| 28 |
+
|
| 29 |
+
@property
|
| 30 |
+
def random_generator(self):
|
| 31 |
+
return self._random_generator
|
| 32 |
+
|
| 33 |
+
@random_generator.setter
|
| 34 |
+
def random_generator(self, generator: torch.Generator):
|
| 35 |
+
self._random_generator = generator
|
| 36 |
+
|
| 37 |
+
def forward(self, x_shape, x_device, *args_conditionals, **kwargs_conditionals):
|
| 38 |
+
return self.generate(
|
| 39 |
+
x_shape,
|
| 40 |
+
x_device,
|
| 41 |
+
*args_conditionals,
|
| 42 |
+
**kwargs_conditionals,
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
def generate(self, x_shape, x_device, *args_conditionals, **kwargs_conditionals):
|
| 46 |
+
for _, xt, _ in self.generate_iter(
|
| 47 |
+
x_shape,
|
| 48 |
+
x_device,
|
| 49 |
+
*args_conditionals,
|
| 50 |
+
**kwargs_conditionals,
|
| 51 |
+
):
|
| 52 |
+
pass
|
| 53 |
+
return xt
|
| 54 |
+
|
| 55 |
+
def generate_iter(
|
| 56 |
+
self,
|
| 57 |
+
x_shape,
|
| 58 |
+
x_device,
|
| 59 |
+
*args_conditionals,
|
| 60 |
+
**kwargs_conditionals,
|
| 61 |
+
):
|
| 62 |
+
raise NotImplementedError
|
| 63 |
+
|
| 64 |
+
def loss(self, x, *args_conditionals, **kwargs_conditionals):
|
| 65 |
+
raise NotImplementedError
|
thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/generator/classifier_free_guidance.py
ADDED
|
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
from functools import partial
|
| 3 |
+
from numbers import Number
|
| 4 |
+
import torch
|
| 5 |
+
import random
|
| 6 |
+
from torch.utils import _pytree
|
| 7 |
+
from torch.utils._pytree import tree_map_only
|
| 8 |
+
from loguru import logger
|
| 9 |
+
|
| 10 |
+
def _zeros_like(struct):
|
| 11 |
+
def make_zeros(x):
|
| 12 |
+
if isinstance(x, torch.Tensor):
|
| 13 |
+
return torch.zeros_like(x)
|
| 14 |
+
return x
|
| 15 |
+
|
| 16 |
+
return _pytree.tree_map(make_zeros, struct)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def zero_out(args, kwargs):
|
| 20 |
+
args = _zeros_like(args)
|
| 21 |
+
kwargs = _zeros_like(kwargs)
|
| 22 |
+
return args, kwargs
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def discard(args, kwargs):
|
| 26 |
+
return (), {}
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _drop_tensors(struct):
|
| 30 |
+
"""
|
| 31 |
+
Drop any conditioning that are tensors
|
| 32 |
+
Not using _pytree since we actually want to throw them instead of keeping them.
|
| 33 |
+
"""
|
| 34 |
+
if isinstance(struct, dict):
|
| 35 |
+
return {
|
| 36 |
+
k: _drop_tensors(v)
|
| 37 |
+
for k, v in struct.items()
|
| 38 |
+
if not isinstance(v, torch.Tensor)
|
| 39 |
+
}
|
| 40 |
+
elif isinstance(struct, (list, tuple)):
|
| 41 |
+
filtered = [_drop_tensors(x) for x in struct if not isinstance(x, torch.Tensor)]
|
| 42 |
+
return tuple(filtered) if isinstance(struct, tuple) else filtered
|
| 43 |
+
else:
|
| 44 |
+
return struct
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def drop_tensors(args, kwargs):
|
| 48 |
+
args = _drop_tensors(args)
|
| 49 |
+
kwargs = _drop_tensors(kwargs)
|
| 50 |
+
return args, kwargs
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def add_flag(args, kwargs):
|
| 54 |
+
kwargs["cfg"] = True
|
| 55 |
+
return args, kwargs
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class ClassifierFreeGuidance(torch.nn.Module):
|
| 59 |
+
UNCONDITIONAL_HANDLING_TYPES = {
|
| 60 |
+
"zeros": zero_out,
|
| 61 |
+
"discard": discard,
|
| 62 |
+
"drop_tensors": drop_tensors,
|
| 63 |
+
"add_flag": add_flag,
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
def __init__(
|
| 67 |
+
self,
|
| 68 |
+
backbone, # backbone should be a backbone/generator (e.g. DDPM/DDIM/FlowMatching)
|
| 69 |
+
p_unconditional=0.1,
|
| 70 |
+
strength=3.0,
|
| 71 |
+
# "zeros" = set cond tensors to 0,
|
| 72 |
+
# "discard" = remove cond arguments and let underlying model handle it
|
| 73 |
+
# "drop_tensors" = drop all tensors but leave non-tensors
|
| 74 |
+
# "add_flag" = add an argument in kwargs as "cfg" and defer the handling to generator backbone
|
| 75 |
+
unconditional_handling="zeros",
|
| 76 |
+
interval=None, # only perform cfg if t within interval
|
| 77 |
+
):
|
| 78 |
+
super().__init__()
|
| 79 |
+
|
| 80 |
+
if not (
|
| 81 |
+
unconditional_handling
|
| 82 |
+
in ClassifierFreeGuidance.UNCONDITIONAL_HANDLING_TYPES
|
| 83 |
+
):
|
| 84 |
+
raise RuntimeError(
|
| 85 |
+
f"'{unconditional_handling}' is not valid for `unconditional_handling`, should be in {ClassifierFreeGuidance.UNCONDITIONAL_HANDLING_TYPES}"
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
self.backbone = backbone
|
| 89 |
+
self.p_unconditional = p_unconditional
|
| 90 |
+
self.strength = strength
|
| 91 |
+
self.unconditional_handling = unconditional_handling
|
| 92 |
+
self.interval = interval
|
| 93 |
+
self._make_unconditional_args = (
|
| 94 |
+
ClassifierFreeGuidance.UNCONDITIONAL_HANDLING_TYPES[
|
| 95 |
+
self.unconditional_handling
|
| 96 |
+
]
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
def _cfg_step_tensor(self, y_cond, y_uncond, strength):
|
| 100 |
+
return (1 + strength) * y_cond - strength * y_uncond
|
| 101 |
+
|
| 102 |
+
def _cfg_step(self, y_cond, y_uncond, strength):
|
| 103 |
+
if isinstance(strength, dict):
|
| 104 |
+
return _pytree.tree_map(self._cfg_step_tensor, y_cond, y_uncond, strength)
|
| 105 |
+
else:
|
| 106 |
+
return _pytree.tree_map(partial(self._cfg_step_tensor, strength=strength), y_cond, y_uncond)
|
| 107 |
+
|
| 108 |
+
def inner_forward(self, x, t, is_cond, strength, *args_cond, **kwargs_cond):
|
| 109 |
+
y_cond = self.backbone(x, t, *args_cond, **kwargs_cond)
|
| 110 |
+
if is_cond:
|
| 111 |
+
return y_cond
|
| 112 |
+
else:
|
| 113 |
+
args_cond, kwargs_cond = self._make_unconditional_args(
|
| 114 |
+
args_cond,
|
| 115 |
+
kwargs_cond,
|
| 116 |
+
)
|
| 117 |
+
y_uncond = self.backbone(x, t, *args_cond, **kwargs_cond)
|
| 118 |
+
return self._cfg_step(y_cond, y_uncond, strength)
|
| 119 |
+
|
| 120 |
+
def forward(self, x, t, *args_cond, **kwargs_cond):
|
| 121 |
+
# handle case when no conditional arguments are provided
|
| 122 |
+
if len(args_cond) + len(kwargs_cond) == 0: # unconditional
|
| 123 |
+
if self.unconditional_handling != "discard":
|
| 124 |
+
raise RuntimeError(
|
| 125 |
+
f"cannot call `ClassifierFreeGuidance` module without condition"
|
| 126 |
+
)
|
| 127 |
+
return self.backbone(x, t)
|
| 128 |
+
else: # conditional arguments are provided
|
| 129 |
+
# training mode
|
| 130 |
+
if self.training:
|
| 131 |
+
coin_flip = random.random() < self.p_unconditional
|
| 132 |
+
if coin_flip: # unconditional
|
| 133 |
+
args_cond, kwargs_cond = self._make_unconditional_args(
|
| 134 |
+
args_cond,
|
| 135 |
+
kwargs_cond,
|
| 136 |
+
)
|
| 137 |
+
return self.backbone(x, t, *args_cond, **kwargs_cond)
|
| 138 |
+
else: # inference mode
|
| 139 |
+
strength = get_strength(self.strength, self.interval, t)
|
| 140 |
+
is_cond = not any(x > 0.0 for x in _pytree.tree_flatten(strength)[0])
|
| 141 |
+
return self.inner_forward(
|
| 142 |
+
x, t, is_cond, strength, *args_cond, **kwargs_cond
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
def get_strength(strength, interval, t):
|
| 146 |
+
if interval is None:
|
| 147 |
+
return _pytree.tree_map(lambda x: 0.0, strength)
|
| 148 |
+
|
| 149 |
+
# If interval is not a dict (single tuple), broadcast it
|
| 150 |
+
if not isinstance(interval, dict):
|
| 151 |
+
return _pytree.tree_map(
|
| 152 |
+
lambda x: x if interval[0] <= t <= interval[1] else 0.0,
|
| 153 |
+
strength
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
return _pytree.tree_map(
|
| 157 |
+
lambda x, iv: x if iv[0] <= t <= iv[1] else 0.0,
|
| 158 |
+
strength,
|
| 159 |
+
interval
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
class PointmapCFG(ClassifierFreeGuidance):
|
| 163 |
+
|
| 164 |
+
def __init__(self, *args, strength_pm=0.0, **kwargs):
|
| 165 |
+
super().__init__(*args, **kwargs)
|
| 166 |
+
self.strength_pm = strength_pm
|
| 167 |
+
|
| 168 |
+
def _cfg_step_tensor(self, y_cond, y_uncond, y_unpm, strength, strength_pm):
|
| 169 |
+
# https://arxiv.org/abs/2411.18613
|
| 170 |
+
return y_cond \
|
| 171 |
+
+ strength_pm * (y_cond - y_unpm) \
|
| 172 |
+
+ strength * (y_unpm - y_uncond)
|
| 173 |
+
|
| 174 |
+
def _cfg_step(self, y_cond, y_uncond, y_pm, strength, strength_pm):
|
| 175 |
+
if isinstance(strength, dict):
|
| 176 |
+
return _pytree.tree_map(self._cfg_step_tensor, y_cond, y_uncond, y_pm, strength, strength_pm)
|
| 177 |
+
else:
|
| 178 |
+
return _pytree.tree_map(partial(self._cfg_step_tensor, strength=strength, strength_pm=strength_pm), y_cond, y_uncond, y_pm)
|
| 179 |
+
|
| 180 |
+
def inner_forward(self, x, t, is_cond, strength, strength_pm, *args_cond, **kwargs_cond):
|
| 181 |
+
y_cond = self.backbone(x, t, *args_cond, **kwargs_cond)
|
| 182 |
+
|
| 183 |
+
if is_cond:
|
| 184 |
+
return y_cond
|
| 185 |
+
else:
|
| 186 |
+
force_drop_modalities = self.backbone.condition_embedder.force_drop_modalities
|
| 187 |
+
self.backbone.condition_embedder.force_drop_modalities = ['pointmap', 'rgb_pointmap']
|
| 188 |
+
y_pm = self.backbone(x, t, *args_cond, **kwargs_cond)
|
| 189 |
+
self.backbone.condition_embedder.force_drop_modalities = force_drop_modalities
|
| 190 |
+
|
| 191 |
+
args_cond, kwargs_cond = self._make_unconditional_args(
|
| 192 |
+
args_cond,
|
| 193 |
+
kwargs_cond,
|
| 194 |
+
)
|
| 195 |
+
y_uncond = self.backbone(x, t, *args_cond, **kwargs_cond)
|
| 196 |
+
return self._cfg_step(y_cond, y_uncond, y_pm, strength, strength_pm)
|
| 197 |
+
|
| 198 |
+
def forward(self, x, t, *args_cond, **kwargs_cond):
|
| 199 |
+
# handle case when no conditional arguments are provided
|
| 200 |
+
if len(args_cond) + len(kwargs_cond) == 0: # unconditional
|
| 201 |
+
if self.unconditional_handling != "discard":
|
| 202 |
+
raise RuntimeError(
|
| 203 |
+
f"cannot call `ClassifierFreeGuidance` module without condition"
|
| 204 |
+
)
|
| 205 |
+
return self.backbone(x, t)
|
| 206 |
+
else: # conditional arguments are provided
|
| 207 |
+
# training mode
|
| 208 |
+
if self.training:
|
| 209 |
+
coin_flip = random.random() < self.p_unconditional
|
| 210 |
+
if coin_flip: # unconditional
|
| 211 |
+
args_cond, kwargs_cond = self._make_unconditional_args(
|
| 212 |
+
args_cond,
|
| 213 |
+
kwargs_cond,
|
| 214 |
+
)
|
| 215 |
+
return self.backbone(x, t, *args_cond, **kwargs_cond)
|
| 216 |
+
else: # inference mode
|
| 217 |
+
strength = get_strength(self.strength, self.interval, t)
|
| 218 |
+
is_cond = not any(x > 0.0 for x in _pytree.tree_flatten(strength)[0])
|
| 219 |
+
strength_pm = get_strength(self.strength_pm, self.interval, t)
|
| 220 |
+
return self.inner_forward(
|
| 221 |
+
x, t, is_cond, strength, strength_pm, *args_cond, **kwargs_cond
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
class ClassifierFreeGuidanceWithExternalUnconditionalProbability(ClassifierFreeGuidance):
|
| 225 |
+
|
| 226 |
+
def __init__(self, *args, use_unconditional_from_flow_matching=False, **kwargs):
|
| 227 |
+
super().__init__(*args, **kwargs)
|
| 228 |
+
self.use_unconditional_from_flow_matching = use_unconditional_from_flow_matching
|
| 229 |
+
|
| 230 |
+
def forward(self, x, t, *args_cond, p_unconditional=None, **kwargs_cond):
|
| 231 |
+
# p_unconditional should be a value in [0, 1], indicating the probability of unconditional
|
| 232 |
+
|
| 233 |
+
if p_unconditional is None:
|
| 234 |
+
coin_flip = random.random() < self.p_unconditional
|
| 235 |
+
else:
|
| 236 |
+
coin_flip = random.random() < p_unconditional
|
| 237 |
+
|
| 238 |
+
# handle case when no conditional arguments are provided
|
| 239 |
+
if len(args_cond) + len(kwargs_cond) == 0: # unconditional
|
| 240 |
+
if self.unconditional_handling != "discard":
|
| 241 |
+
raise RuntimeError(
|
| 242 |
+
f"cannot call `ClassifierFreeGuidance` module without condition"
|
| 243 |
+
)
|
| 244 |
+
return self.backbone(x, t)
|
| 245 |
+
else: # conditional arguments are provided
|
| 246 |
+
# training mode
|
| 247 |
+
if self.training:
|
| 248 |
+
if coin_flip: # unconditional
|
| 249 |
+
args_cond, kwargs_cond = self._make_unconditional_args(
|
| 250 |
+
args_cond,
|
| 251 |
+
kwargs_cond,
|
| 252 |
+
)
|
| 253 |
+
return self.backbone(x, t, *args_cond, **kwargs_cond)
|
| 254 |
+
else: # inference mode
|
| 255 |
+
strength = get_strength(self.strength, self.interval, t)
|
| 256 |
+
is_cond = not any(x > 0.0 for x in _pytree.tree_flatten(strength)[0])
|
| 257 |
+
return self.inner_forward(
|
| 258 |
+
x, t, is_cond, strength, *args_cond, **kwargs_cond
|
| 259 |
+
)
|
thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/generator/flow_matching/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/generator/flow_matching/model.py
ADDED
|
@@ -0,0 +1,363 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
from typing import Callable, Sequence, Union
|
| 3 |
+
import torch
|
| 4 |
+
import numpy as np
|
| 5 |
+
from functools import partial
|
| 6 |
+
import optree
|
| 7 |
+
import math
|
| 8 |
+
|
| 9 |
+
from sam3d_objects.model.backbone.generator.base import Base
|
| 10 |
+
from sam3d_objects.data.utils import right_broadcasting
|
| 11 |
+
from sam3d_objects.data.utils import tree_tensor_map, tree_reduce_unique
|
| 12 |
+
from sam3d_objects.model.backbone.generator.flow_matching.solver import (
|
| 13 |
+
ODESolver,
|
| 14 |
+
Euler,
|
| 15 |
+
Midpoint,
|
| 16 |
+
RungeKutta4,
|
| 17 |
+
gradient,
|
| 18 |
+
SDE,
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
# default sampler in flow matching
|
| 22 |
+
uniform_sampler = torch.rand
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# https://arxiv.org/pdf/2403.03206
|
| 26 |
+
def lognorm_sampler(mean=0.0, std=1.0, **kwargs):
|
| 27 |
+
logit = torch.randn(**kwargs) * std + mean
|
| 28 |
+
return torch.nn.functional.sigmoid(logit)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# for backwards compatibility; please do not use this
|
| 32 |
+
def rev_lognorm_sampler(mean=0.0, std=1.0, **kwargs):
|
| 33 |
+
logit = torch.randn(**kwargs) * std + mean
|
| 34 |
+
return 1 - torch.nn.functional.sigmoid(logit)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# https://arxiv.org/pdf/2210.02747
|
| 38 |
+
class FlowMatching(Base):
|
| 39 |
+
SOLVER_METHODS = {
|
| 40 |
+
"euler": Euler,
|
| 41 |
+
"midpoint": Midpoint,
|
| 42 |
+
"rk4": RungeKutta4,
|
| 43 |
+
"sde": SDE,
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
def __init__(
|
| 47 |
+
self,
|
| 48 |
+
reverse_fn: Callable,
|
| 49 |
+
sigma_min: float = 0.0, # 0. = rectifier flow
|
| 50 |
+
inference_steps: int = 100,
|
| 51 |
+
time_scale: float = 1000.0, # scale [0,1]-time before passing to `reverse_fn`
|
| 52 |
+
training_time_sampler_fn: Callable = partial(
|
| 53 |
+
lognorm_sampler,
|
| 54 |
+
mean=0,
|
| 55 |
+
std=1,
|
| 56 |
+
),
|
| 57 |
+
reversed_timestamp=False,
|
| 58 |
+
rescale_t=1.0,
|
| 59 |
+
loss_fn=partial(torch.nn.functional.mse_loss, reduction="mean"),
|
| 60 |
+
loss_weights=1.0,
|
| 61 |
+
solver_method: Union[str, ODESolver] = "euler",
|
| 62 |
+
solver_kwargs: dict = {},
|
| 63 |
+
**kwargs,
|
| 64 |
+
):
|
| 65 |
+
super().__init__(**kwargs)
|
| 66 |
+
|
| 67 |
+
self.reverse_fn = reverse_fn
|
| 68 |
+
self.sigma_min = sigma_min
|
| 69 |
+
self.inference_steps = inference_steps
|
| 70 |
+
self.time_scale = time_scale
|
| 71 |
+
self.training_time_sampler_fn = training_time_sampler_fn
|
| 72 |
+
self.reversed_timestamp = reversed_timestamp
|
| 73 |
+
self.rescale_t = rescale_t
|
| 74 |
+
self.loss_fn = loss_fn
|
| 75 |
+
self.loss_weights = loss_weights
|
| 76 |
+
self._solver_method, self._solver = self._get_solver(
|
| 77 |
+
solver_method, solver_kwargs
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
def _get_solver(self, solver_method, solver_kwargs):
|
| 81 |
+
if solver_method in FlowMatching.SOLVER_METHODS:
|
| 82 |
+
solver = FlowMatching.SOLVER_METHODS[solver_method](**solver_kwargs)
|
| 83 |
+
elif isinstance(solver_method, ODESolver):
|
| 84 |
+
solver_method = f"custom[{solver_method.__class__.__name__}]"
|
| 85 |
+
solver = solver_method
|
| 86 |
+
else:
|
| 87 |
+
raise ValueError(
|
| 88 |
+
f"Invalid solver `{solver_method}`, should be in {set(self.SOLVER_METHODS.keys())} or an ODESolver instance"
|
| 89 |
+
)
|
| 90 |
+
return solver_method, solver
|
| 91 |
+
|
| 92 |
+
def _generate_noise_tensor(self, x_shape, x_device):
|
| 93 |
+
return torch.randn(
|
| 94 |
+
x_shape,
|
| 95 |
+
# generator=self.random_generator,
|
| 96 |
+
device=x_device,
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
def _generate_noise(self, x_shape, x_device):
|
| 100 |
+
def is_shape(maybe_shape):
|
| 101 |
+
return isinstance(maybe_shape, Sequence) and all(
|
| 102 |
+
(isinstance(s, int) and s >= 0) for s in maybe_shape
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
return optree.tree_map(
|
| 106 |
+
partial(self._generate_noise_tensor, x_device=x_device),
|
| 107 |
+
x_shape,
|
| 108 |
+
is_leaf=is_shape,
|
| 109 |
+
none_is_leaf=False,
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
def _generate_x0_tensor(self, x1: torch.Tensor):
|
| 113 |
+
x0 = self._generate_noise_tensor(x1.shape, x1.device)
|
| 114 |
+
return x0
|
| 115 |
+
|
| 116 |
+
def _generate_xt_tensor(self, x0: torch.Tensor, x1: torch.Tensor, t: torch.Tensor):
|
| 117 |
+
# equation (22)
|
| 118 |
+
tb = right_broadcasting(t.to(x1.device), x1)
|
| 119 |
+
x_t = (1 - (1 - self.sigma_min) * tb) * x0 + tb * x1
|
| 120 |
+
|
| 121 |
+
return x_t
|
| 122 |
+
|
| 123 |
+
def _generate_target_tensor(self, x0: torch.Tensor, x1: torch.Tensor):
|
| 124 |
+
# equation (23)
|
| 125 |
+
target = x1 - (1 - self.sigma_min) * x0
|
| 126 |
+
|
| 127 |
+
return target
|
| 128 |
+
|
| 129 |
+
def _generate_x0(self, x1):
|
| 130 |
+
return tree_tensor_map(self._generate_x0_tensor, x1)
|
| 131 |
+
|
| 132 |
+
def _generate_xt(self, x0, x1, t):
|
| 133 |
+
return tree_tensor_map(
|
| 134 |
+
partial(self._generate_xt_tensor, t=t),
|
| 135 |
+
x0,
|
| 136 |
+
x1,
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
def _generate_target(self, x0, x1):
|
| 140 |
+
return tree_tensor_map(
|
| 141 |
+
self._generate_target_tensor,
|
| 142 |
+
x0,
|
| 143 |
+
x1,
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
def _generate_t(self, x1):
|
| 147 |
+
first_tensor = optree.tree_flatten(x1)[0][0]
|
| 148 |
+
batch_size = first_tensor.shape[0]
|
| 149 |
+
device = first_tensor.device
|
| 150 |
+
|
| 151 |
+
t = self.training_time_sampler_fn(
|
| 152 |
+
size=(batch_size,),
|
| 153 |
+
generator=self.random_generator,
|
| 154 |
+
).to(device)
|
| 155 |
+
|
| 156 |
+
return t
|
| 157 |
+
|
| 158 |
+
def loss(self, x1: torch.Tensor, *args_conditionals, **kwargs_conditionals):
|
| 159 |
+
t = self._generate_t(x1)
|
| 160 |
+
x0 = self._generate_x0(x1)
|
| 161 |
+
x_t = self._generate_xt(x0, x1, t)
|
| 162 |
+
target = self._generate_target(x0, x1)
|
| 163 |
+
|
| 164 |
+
prediction = self.reverse_fn(
|
| 165 |
+
x_t,
|
| 166 |
+
t * self.time_scale,
|
| 167 |
+
*args_conditionals,
|
| 168 |
+
**kwargs_conditionals,
|
| 169 |
+
)
|
| 170 |
+
|
| 171 |
+
# broadcast & and compute loss
|
| 172 |
+
loss = optree.tree_broadcast_map(
|
| 173 |
+
lambda fn, weight, pred, targ: weight * fn(pred, targ),
|
| 174 |
+
self.loss_fn,
|
| 175 |
+
self.loss_weights,
|
| 176 |
+
prediction,
|
| 177 |
+
target,
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
total_loss = sum(optree.tree_flatten(loss)[0])
|
| 181 |
+
|
| 182 |
+
# Create detailed loss breakdown
|
| 183 |
+
detail_losses = {
|
| 184 |
+
"flow_matching_loss": total_loss,
|
| 185 |
+
}
|
| 186 |
+
if isinstance(loss, dict):
|
| 187 |
+
detail_losses.update(loss)
|
| 188 |
+
return total_loss, detail_losses
|
| 189 |
+
|
| 190 |
+
def _prepare_t(self, steps=None):
|
| 191 |
+
steps = self.inference_steps if steps is None else steps
|
| 192 |
+
t_seq = torch.linspace(0, 1, steps + 1)
|
| 193 |
+
|
| 194 |
+
if self.rescale_t:
|
| 195 |
+
t_seq = t_seq / (1 + (self.rescale_t - 1) * (1 - t_seq))
|
| 196 |
+
|
| 197 |
+
if self.reversed_timestamp:
|
| 198 |
+
t_seq = 1 - t_seq
|
| 199 |
+
|
| 200 |
+
return t_seq
|
| 201 |
+
|
| 202 |
+
def generate_iter(
|
| 203 |
+
self,
|
| 204 |
+
x_shape,
|
| 205 |
+
x_device,
|
| 206 |
+
*args_conditionals,
|
| 207 |
+
**kwargs_conditionals,
|
| 208 |
+
):
|
| 209 |
+
x_0 = self._generate_noise(x_shape, x_device)
|
| 210 |
+
t_seq = self._prepare_t().to(x_device)
|
| 211 |
+
|
| 212 |
+
for x_t, t in self._solver.solve_iter(
|
| 213 |
+
self._generate_dynamics,
|
| 214 |
+
x_0,
|
| 215 |
+
t_seq,
|
| 216 |
+
*args_conditionals,
|
| 217 |
+
**kwargs_conditionals,
|
| 218 |
+
):
|
| 219 |
+
yield t, x_t, ()
|
| 220 |
+
|
| 221 |
+
def _generate_dynamics(
|
| 222 |
+
self,
|
| 223 |
+
x_t,
|
| 224 |
+
t,
|
| 225 |
+
*args_conditionals,
|
| 226 |
+
**kwargs_conditionals,
|
| 227 |
+
):
|
| 228 |
+
return self.reverse_fn(x_t, t * self.time_scale, *args_conditionals, **kwargs_conditionals)
|
| 229 |
+
|
| 230 |
+
def _log_p0(self, x0):
|
| 231 |
+
x0 = self._tree_flatten(x0)
|
| 232 |
+
inside_exp = -(x0**2).sum(dim=1) / 2
|
| 233 |
+
return inside_exp - math.log(2 * math.pi) / 2 * x0.shape[1]
|
| 234 |
+
|
| 235 |
+
def log_likelihood(
|
| 236 |
+
self,
|
| 237 |
+
x1,
|
| 238 |
+
solver=None,
|
| 239 |
+
steps=None,
|
| 240 |
+
z_samples=1,
|
| 241 |
+
*args_conditionals,
|
| 242 |
+
**kwargs_conditionals,
|
| 243 |
+
):
|
| 244 |
+
device = tree_reduce_unique(lambda tensor: tensor.device, x1)
|
| 245 |
+
# device = "cuda"
|
| 246 |
+
t_seq = self._prepare_t(steps).to(device)
|
| 247 |
+
t_seq = 1 - t_seq # from x1 to x0
|
| 248 |
+
solver = self._solver if solver is None else self._get_solver(solver)[1]
|
| 249 |
+
|
| 250 |
+
x_0 = solver.solve(
|
| 251 |
+
partial(self._log_likelihood_dynamics, device=device, z_samples=z_samples),
|
| 252 |
+
{"x": x1, "log_p": 0.0},
|
| 253 |
+
t_seq,
|
| 254 |
+
*args_conditionals,
|
| 255 |
+
**kwargs_conditionals,
|
| 256 |
+
)
|
| 257 |
+
|
| 258 |
+
log_p1 = x_0["log_p"] + self._log_p0(x_0["x"])
|
| 259 |
+
|
| 260 |
+
return log_p1
|
| 261 |
+
|
| 262 |
+
def _log_likelihood_dynamics(
|
| 263 |
+
self,
|
| 264 |
+
state,
|
| 265 |
+
t,
|
| 266 |
+
device,
|
| 267 |
+
z_samples,
|
| 268 |
+
*args_conditionals,
|
| 269 |
+
**kwargs_conditionals,
|
| 270 |
+
):
|
| 271 |
+
t = torch.tensor([t * self.time_scale], device=device, dtype=torch.float32)
|
| 272 |
+
x_t = state["x"]
|
| 273 |
+
|
| 274 |
+
with torch.set_grad_enabled(True):
|
| 275 |
+
tree_tensor_map(lambda x,: x.requires_grad_(True), x_t)
|
| 276 |
+
velocity = self.reverse_fn(
|
| 277 |
+
x_t,
|
| 278 |
+
t,
|
| 279 |
+
*args_conditionals,
|
| 280 |
+
**kwargs_conditionals,
|
| 281 |
+
)
|
| 282 |
+
|
| 283 |
+
# compute the divergence estimate
|
| 284 |
+
div = self._compute_hutchinson_divergence(velocity, x_t, z_samples)
|
| 285 |
+
|
| 286 |
+
tree_tensor_map(lambda x,: x.requires_grad_(False), x_t)
|
| 287 |
+
velocity = tree_tensor_map(lambda x: x.detach(), velocity)
|
| 288 |
+
|
| 289 |
+
return {"x": velocity, "log_p": div.detach()}
|
| 290 |
+
|
| 291 |
+
def _tree_flatten(self, tree):
|
| 292 |
+
flat_x = tree_tensor_map(lambda x: x.flatten(start_dim=1), tree)
|
| 293 |
+
flat_x, _ = optree.tree_flatten(
|
| 294 |
+
flat_x,
|
| 295 |
+
is_leaf=lambda x: isinstance(x, torch.Tensor),
|
| 296 |
+
)
|
| 297 |
+
flat_x = torch.cat(flat_x, dim=1)
|
| 298 |
+
return flat_x
|
| 299 |
+
|
| 300 |
+
def _compute_hutchinson_divergence(self, velocity, x_t, z_samples):
|
| 301 |
+
flat_velocity = self._tree_flatten(velocity)
|
| 302 |
+
flat_velocity = flat_velocity.unsqueeze(-1)
|
| 303 |
+
|
| 304 |
+
z = torch.randn(
|
| 305 |
+
flat_velocity.shape[:-1] + (z_samples,),
|
| 306 |
+
dtype=flat_velocity.dtype,
|
| 307 |
+
device=flat_velocity.device,
|
| 308 |
+
)
|
| 309 |
+
z = z < 0
|
| 310 |
+
z = z * 2.0 - 1.0
|
| 311 |
+
z = z / math.sqrt(z_samples)
|
| 312 |
+
|
| 313 |
+
# compute Hutchinson divergence estimator E[z^T D_x(vt) z] = E[D_x(z^T vt) z)]
|
| 314 |
+
vt_dot_z = torch.einsum("ijk,ijk->ik", flat_velocity, z)
|
| 315 |
+
grad_vt_dot_z = [
|
| 316 |
+
gradient(vt_dot_z[..., i], x_t, create_graph=(z_samples > 1))
|
| 317 |
+
for i in range(z_samples)
|
| 318 |
+
]
|
| 319 |
+
grad_vt_dot_z = [self._tree_flatten(g) for g in grad_vt_dot_z]
|
| 320 |
+
grad_vt_dot_z = torch.stack(grad_vt_dot_z, dim=-1)
|
| 321 |
+
div = torch.einsum("ijk,ijk->i", grad_vt_dot_z, z)
|
| 322 |
+
return div
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
def _get_device(x):
|
| 326 |
+
device = tree_reduce_unique(lambda tensor: tensor.device, x)
|
| 327 |
+
return device
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
class ConditionalFlowMatching(FlowMatching):
|
| 331 |
+
def generate_iter(
|
| 332 |
+
self,
|
| 333 |
+
x_shape,
|
| 334 |
+
x_device,
|
| 335 |
+
*args_conditionals,
|
| 336 |
+
**kwargs_conditionals,
|
| 337 |
+
):
|
| 338 |
+
x_0 = self._generate_noise(x_shape, x_device)
|
| 339 |
+
t_seq = self._prepare_t().to(x_device)
|
| 340 |
+
|
| 341 |
+
noise_override = None
|
| 342 |
+
if "noise_override" in kwargs_conditionals:
|
| 343 |
+
noise_override = kwargs_conditionals["noise_override"]
|
| 344 |
+
del kwargs_conditionals["noise_override"]
|
| 345 |
+
if noise_override is not None:
|
| 346 |
+
if type(x_0) == dict:
|
| 347 |
+
x_0.update(noise_override)
|
| 348 |
+
else:
|
| 349 |
+
x_0 = noise_override
|
| 350 |
+
|
| 351 |
+
for x_t, t in self._solver.solve_iter(
|
| 352 |
+
self._generate_dynamics,
|
| 353 |
+
x_0,
|
| 354 |
+
t_seq,
|
| 355 |
+
*args_conditionals,
|
| 356 |
+
**kwargs_conditionals,
|
| 357 |
+
):
|
| 358 |
+
if noise_override is not None:
|
| 359 |
+
if type(noise_override) == dict:
|
| 360 |
+
x_t.update(noise_override)
|
| 361 |
+
else:
|
| 362 |
+
x_t = noise_override
|
| 363 |
+
yield t, x_t, ()
|
thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/generator/flow_matching/solver.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
import optree
|
| 3 |
+
import torch
|
| 4 |
+
from functools import partial
|
| 5 |
+
|
| 6 |
+
from sam3d_objects.data.utils import tree_tensor_map
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def linear_approximation_step(x_t, dt, velocity):
|
| 10 |
+
# x_tp1 = x_t + velocity * dt
|
| 11 |
+
x_tp1 = tree_tensor_map(lambda x, v: x + v * dt, x_t, velocity)
|
| 12 |
+
return x_tp1
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def gradient(output, x, create_graph: bool = False):
|
| 16 |
+
tensors, pyspec = optree.tree_flatten(
|
| 17 |
+
x, is_leaf=lambda x: isinstance(x, torch.Tensor)
|
| 18 |
+
)
|
| 19 |
+
grad_outputs = [torch.ones_like(output).detach() for _ in tensors]
|
| 20 |
+
grads = torch.autograd.grad(
|
| 21 |
+
output,
|
| 22 |
+
tensors,
|
| 23 |
+
grad_outputs=grad_outputs,
|
| 24 |
+
create_graph=create_graph,
|
| 25 |
+
)
|
| 26 |
+
return optree.tree_unflatten(pyspec, grads)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class ODESolver:
|
| 30 |
+
def step(self, dynamics_fn, x_t, t, dt, *args, **kwargs):
|
| 31 |
+
raise NotImplementedError
|
| 32 |
+
|
| 33 |
+
def solve_iter(self, dynamics_fn, x_init, times, *args, **kwargs):
|
| 34 |
+
x_t = x_init
|
| 35 |
+
for t0, t1 in zip(times[:-1], times[1:]):
|
| 36 |
+
dt = t1 - t0
|
| 37 |
+
x_t = self.step(dynamics_fn, x_t, t0, dt, *args, **kwargs)
|
| 38 |
+
yield x_t, t0
|
| 39 |
+
|
| 40 |
+
def solve(self, dynamics_fn, x_init, times, *args, **kwargs):
|
| 41 |
+
for x_t, _ in self.solve_iter(dynamics_fn, x_init, times, *args, **kwargs):
|
| 42 |
+
pass
|
| 43 |
+
return x_t
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# https://en.wikipedia.org/wiki/Euler_method
|
| 47 |
+
class Euler(ODESolver):
|
| 48 |
+
def step(self, dynamics_fn, x_t, t, dt, *args, **kwargs):
|
| 49 |
+
velocity = dynamics_fn(x_t, t, *args, **kwargs)
|
| 50 |
+
x_tp1 = linear_approximation_step(x_t, dt, velocity)
|
| 51 |
+
return x_tp1
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
# https://arxiv.org/abs/2505.05470
|
| 55 |
+
class SDE(ODESolver):
|
| 56 |
+
def __init__(self, **kwargs):
|
| 57 |
+
super().__init__()
|
| 58 |
+
self.sde_strength = kwargs.get("sde_strength", 0.1)
|
| 59 |
+
|
| 60 |
+
def step(self, dynamics_fn, x_t, t, dt, *args, **kwargs):
|
| 61 |
+
velocity = dynamics_fn(x_t, t, *args, **kwargs)
|
| 62 |
+
sigma = 1 - t
|
| 63 |
+
var_t = sigma / (1 - torch.tensor(sigma).clamp(min=dt))
|
| 64 |
+
std_dev_t = (
|
| 65 |
+
torch.sqrt(variance) * self.sde_strength
|
| 66 |
+
) # self.sde_strength = alpha
|
| 67 |
+
|
| 68 |
+
def compute_mean(x, v):
|
| 69 |
+
drift_term = x * (std_dev_t**2 / (2 * sigma) * dt)
|
| 70 |
+
velocity_term = v * (1 + std_dev_t**2 * (1 - sigma) / (2 * sigma)) * dt
|
| 71 |
+
return x + drift_term + velocity_term
|
| 72 |
+
|
| 73 |
+
prev_sample_mean = tree_tensor_map(compute_mean, x_t, velocity)
|
| 74 |
+
|
| 75 |
+
# Generate noise and compute final sample using tree_tensor_map
|
| 76 |
+
def add_noise(mean_val):
|
| 77 |
+
variance_noise = torch.randn_like(mean_val)
|
| 78 |
+
return mean_val + std_dev_t * torch.sqrt(torch.tensor(dt)) * variance_noise
|
| 79 |
+
|
| 80 |
+
prev_sample = tree_tensor_map(add_noise, prev_sample_mean)
|
| 81 |
+
|
| 82 |
+
return prev_sample
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
# https://en.wikipedia.org/wiki/Midpoint_method
|
| 86 |
+
class Midpoint(ODESolver):
|
| 87 |
+
def step(self, dynamics_fn, x_t, t, dt, *args, **kwargs):
|
| 88 |
+
half_dt = 0.5 * dt
|
| 89 |
+
|
| 90 |
+
x_mid = Euler.step(self, dynamics_fn, x_t, t, half_dt, *args, **kwargs)
|
| 91 |
+
|
| 92 |
+
velocity_mid = dynamics_fn(x_mid, t + half_dt, *args, **kwargs)
|
| 93 |
+
x_tp1 = linear_approximation_step(x_t, dt, velocity_mid)
|
| 94 |
+
return x_tp1
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
# https://en.wikipedia.org/wiki/Runge%E2%80%93Kutta_methods
|
| 98 |
+
class RungeKutta4(ODESolver):
|
| 99 |
+
|
| 100 |
+
def k1(self, dynamics_fn, x_t, t, dt, *args, **kwargs):
|
| 101 |
+
return dynamics_fn(x_t, t, *args, **kwargs)
|
| 102 |
+
|
| 103 |
+
def k2(self, dynamics_fn, x_t, t, dt, k1, *args, **kwargs):
|
| 104 |
+
x_k1 = linear_approximation_step(x_t, dt * 0.5, k1)
|
| 105 |
+
return dynamics_fn(x_k1, t + dt * 0.5, *args, **kwargs)
|
| 106 |
+
|
| 107 |
+
def k3(self, dynamics_fn, x_t, t, dt, k2, *args, **kwargs):
|
| 108 |
+
x_k2 = linear_approximation_step(x_t, dt * 0.5, k2)
|
| 109 |
+
return dynamics_fn(x_k2, t + dt * 0.5, *args, **kwargs)
|
| 110 |
+
|
| 111 |
+
def k4(self, dynamics_fn, x_t, t, dt, k3, *args, **kwargs):
|
| 112 |
+
x_k3 = linear_approximation_step(x_t, dt, k3)
|
| 113 |
+
return dynamics_fn(x_k3, t + dt, *args, **kwargs)
|
| 114 |
+
|
| 115 |
+
def step(self, dynamics_fn, x_t, t, dt, *args, **kwargs):
|
| 116 |
+
k1 = self.k1(dynamics_fn, x_t, t, dt, *args, **kwargs)
|
| 117 |
+
k2 = self.k2(dynamics_fn, x_t, t, dt, k1, *args, **kwargs)
|
| 118 |
+
k3 = self.k3(dynamics_fn, x_t, t, dt, k2, *args, **kwargs)
|
| 119 |
+
k4 = self.k4(dynamics_fn, x_t, t, dt, k3, *args, **kwargs)
|
| 120 |
+
|
| 121 |
+
def compute_velocity(k1, k2, k3, k4):
|
| 122 |
+
return (k1 + 2 * k2 + 2 * k3 + k4) / 6
|
| 123 |
+
|
| 124 |
+
velocity_k = tree_tensor_map(compute_velocity, k1, k2, k3, k4)
|
| 125 |
+
x_tp1 = linear_approximation_step(x_t, dt, velocity_k)
|
| 126 |
+
return x_tp1
|
thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/generator/shortcut/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|