diff --git a/app.py b/app.py index d591e19f89989b51b433b23e9af68a7ca1e8e518..9ebe3b75ea0fc27ee9765c46f970812128c1250d 100644 --- a/app.py +++ b/app.py @@ -115,11 +115,11 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo: ) rmbg_tag = gr.Radio( choices=["rembg", "rmbg14"], - value="rembg", + value="rmbg14", label="Background Removal Model", ) ip_adapt_scale = gr.Slider( - 0, 1, label="IP-adapter Scale", value=0.3, step=0.05 + 0, 1, label="IP-adapter Scale", value=0.7, step=0.05 ) img_guidance_scale = gr.Slider( 1, 30, label="Text Guidance Scale", value=12, step=0.2 @@ -287,7 +287,7 @@ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo: visible=False, ) gr.Markdown( - "Generated image may be poor quality due to auto seg." + "Generated image may be poor quality due to auto seg. " "Retry by adjusting text prompt, seed or switch seg model in `Image Gen Settings`." ) with gr.Row(): diff --git a/embodied_gen/models/sam3d.py b/embodied_gen/models/sam3d.py index 3296e411c45e67622d77a710b0f31382ca6c4b46..8601edf9fd48398b9240414e4d5c45af5e165cab 100644 --- a/embodied_gen/models/sam3d.py +++ b/embodied_gen/models/sam3d.py @@ -94,9 +94,10 @@ class Sam3dInference: ) -> dict: if isinstance(image, Image.Image): image = np.array(image) + image = self.merge_mask_to_rgba(image, mask) return self.pipeline.run( image, - mask, + None, seed, stage1_only=False, with_mesh_postprocess=False, @@ -132,7 +133,7 @@ if __name__ == "__main__": start = time() - output = pipeline(image, mask, seed=42) + output = pipeline.run(image, mask, seed=42) print(f"Running cost: {round(time()-start, 1)}") if torch.cuda.is_available(): diff --git a/embodied_gen/utils/monkey_patches.py b/embodied_gen/utils/monkey_patches.py index ecf635d24f9c36b4c65388a6d0e366de0c8b76fc..3dfea6a6d9d54d87e8c8f96b785899743b8ec2e5 100644 --- a/embodied_gen/utils/monkey_patches.py +++ b/embodied_gen/utils/monkey_patches.py @@ -397,17 +397,13 @@ def monkey_patch_sam3d(): exc_info=True, ) - # glb.export("sample.glb") - logger.info("Finished!") - - return { + result = { **ss_return_dict, **outputs, - "pointmap": pts.cpu().permute((1, 2, 0)), # HxWx3 - "pointmap_colors": pts_colors.cpu().permute( - (1, 2, 0) - ), # HxWx3 + "pointmap": pts.cpu().permute((1, 2, 0)), + "pointmap_colors": pts_colors.cpu().permute((1, 2, 0)), } + return result InferencePipelinePointMap.run = patch_run diff --git a/thirdparty/sam3d/sam3d/.gitignore b/thirdparty/sam3d/sam3d/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..ed8ebf583f771da9150c35db3955987b7d757904 --- /dev/null +++ b/thirdparty/sam3d/sam3d/.gitignore @@ -0,0 +1 @@ +__pycache__ \ No newline at end of file diff --git a/thirdparty/sam3d/sam3d/CODE_OF_CONDUCT.md b/thirdparty/sam3d/sam3d/CODE_OF_CONDUCT.md new file mode 100644 index 0000000000000000000000000000000000000000..cf9dc244896688c330dba81d06f2cfc5568d7aea --- /dev/null +++ b/thirdparty/sam3d/sam3d/CODE_OF_CONDUCT.md @@ -0,0 +1,80 @@ +# Code of Conduct + +## Our Pledge + +In the interest of fostering an open and welcoming environment, we as +contributors and maintainers pledge to make participation in our project and +our community a harassment-free experience for everyone, regardless of age, body +size, disability, ethnicity, sex characteristics, gender identity and expression, +level of experience, education, socio-economic status, nationality, personal +appearance, race, religion, or sexual identity and orientation. + +## Our Standards + +Examples of behavior that contributes to creating a positive environment +include: + +* Using welcoming and inclusive language +* Being respectful of differing viewpoints and experiences +* Gracefully accepting constructive criticism +* Focusing on what is best for the community +* Showing empathy towards other community members + +Examples of unacceptable behavior by participants include: + +* The use of sexualized language or imagery and unwelcome sexual attention or +advances +* Trolling, insulting/derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or electronic +address, without explicit permission +* Other conduct which could reasonably be considered inappropriate in a +professional setting + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable +behavior and are expected to take appropriate and fair corrective action in +response to any instances of unacceptable behavior. + +Project maintainers have the right and responsibility to remove, edit, or +reject comments, commits, code, wiki edits, issues, and other contributions +that are not aligned to this Code of Conduct, or to ban temporarily or +permanently any contributor for other behaviors that they deem inappropriate, +threatening, offensive, or harmful. + +## Scope + +This Code of Conduct applies within all project spaces, and it also applies when +an individual is representing the project or its community in public spaces. +Examples of representing a project or community include using an official +project e-mail address, posting via an official social media account, or acting +as an appointed representative at an online or offline event. Representation of +a project may be further defined and clarified by project maintainers. + +This Code of Conduct also applies outside the project spaces when there is a +reasonable belief that an individual's behavior may have a negative impact on +the project or its community. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported by contacting the project team at . All +complaints will be reviewed and investigated and will result in a response that +is deemed necessary and appropriate to the circumstances. The project team is +obligated to maintain confidentiality with regard to the reporter of an incident. +Further details of specific enforcement policies may be posted separately. + +Project maintainers who do not follow or enforce the Code of Conduct in good +faith may face temporary or permanent repercussions as determined by other +members of the project's leadership. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, +available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see +https://www.contributor-covenant.org/faq \ No newline at end of file diff --git a/thirdparty/sam3d/sam3d/CONTRIBUTING.md b/thirdparty/sam3d/sam3d/CONTRIBUTING.md new file mode 100644 index 0000000000000000000000000000000000000000..f853448d128c8e64ab1446451f92c014e5a4ba3d --- /dev/null +++ b/thirdparty/sam3d/sam3d/CONTRIBUTING.md @@ -0,0 +1,39 @@ +# Contributing to sam-3d-objects +We want to make contributing to this project as easy and transparent as +possible. + +## Our Development Process +... (in particular how this is synced with internal changes to the project) + +## Pull Requests +We actively welcome your pull requests. + +1. Fork the repo and create your branch from `main`. +2. If you've added code that should be tested, add tests. +3. If you've changed APIs, update the documentation. +4. Ensure the test suite passes. +5. Make sure your code lints. +6. If you haven't already, complete the Contributor License Agreement ("CLA"). + +## Contributor License Agreement ("CLA") +In order to accept your pull request, we need you to submit a CLA. You only need +to do this once to work on any of Meta's open source projects. + +Complete your CLA here: + +## Issues +We use GitHub issues to track public bugs. Please ensure your description is +clear and has sufficient instructions to be able to reproduce the issue. + +Meta has a [bounty program](https://bugbounty.meta.com/) for the safe +disclosure of security bugs. In those cases, please go through the process +outlined on that page and do not file a public issue. + +## Coding Style +* 2 spaces for indentation rather than tabs +* 80 character line length +* ... + +## License +By contributing to sam-3d-objects, you agree that your contributions will be licensed +under the LICENSE file in the root directory of this source tree. \ No newline at end of file diff --git a/thirdparty/sam3d/sam3d/LICENSE b/thirdparty/sam3d/sam3d/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..e58d758cc3d92e6d1a57b52f7e4d203d02a67e98 --- /dev/null +++ b/thirdparty/sam3d/sam3d/LICENSE @@ -0,0 +1,52 @@ +SAM License +Last Updated: November 19, 2025 + +“Agreement” means the terms and conditions for use, reproduction, distribution and modification of the SAM Materials set forth herein. + +“SAM Materials” means, collectively, Documentation and the models, software and algorithms, including machine-learning model code, trained model weights, inference-enabling code, training-enabling code, fine-tuning enabling code, and other elements of the foregoing distributed by Meta and made available under this Agreement. + +“Documentation” means the specifications, manuals and documentation accompanying +SAM Materials distributed by Meta. + +“Licensee” or “you” means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity’s behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf. + +“Meta” or “we” means Meta Platforms Ireland Limited (if you are located in or, if you are an entity, your principal place of business is in the EEA or Switzerland) or Meta Platforms, Inc. (if you are located outside of the EEA or Switzerland). + +“Sanctions” means any economic or trade sanctions or restrictions administered or enforced by the United States (including the Office of Foreign Assets Control of the U.S. Department of the Treasury (“OFAC”), the U.S. Department of State and the U.S. Department of Commerce), the United Nations, the European Union, or the United Kingdom. + +“Trade Controls” means any of the following: Sanctions and applicable export and import controls. + +By using or distributing any portion or element of the SAM Materials, you agree to be bound by this Agreement. + +1. License Rights and Redistribution. + +a. Grant of Rights. You are granted a non-exclusive, worldwide, non-transferable and royalty-free limited license under Meta’s intellectual property or other rights owned by Meta embodied in the SAM Materials to use, reproduce, distribute, copy, create derivative works of, and make modifications to the SAM Materials. + +i. Grant of Patent License. Subject to the terms and conditions of this License, you are granted a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by Meta that are necessarily infringed alone or by combination of their contribution(s) with the SAM 3 Materials. If you institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the SAM 3 Materials incorporated within the work constitutes direct or contributory patent infringement, then any patent licenses granted to you under this License for that work shall terminate as of the date such litigation is filed. + +b. Redistribution and Use. + +i. Distribution of SAM Materials, and any derivative works thereof, are subject to the terms of this Agreement. If you distribute or make the SAM Materials, or any derivative works thereof, available to a third party, you may only do so under the terms of this Agreement and you shall provide a copy of this Agreement with any such SAM Materials. + +ii. If you submit for publication the results of research you perform on, using, or otherwise in connection with SAM Materials, you must acknowledge the use of SAM Materials in your publication. + +iii. Your use of the SAM Materials must comply with applicable laws and regulations, including Trade Control Laws and applicable privacy and data protection laws. +iv. Your use of the SAM Materials will not involve or encourage others to reverse engineer, decompile or discover the underlying components of the SAM Materials. +v. You are not the target of Trade Controls and your use of SAM Materials must comply with Trade Controls. You agree not to use, or permit others to use, SAM Materials for any activities subject to the International Traffic in Arms Regulations (ITAR) or end uses prohibited by Trade Controls, including those related to military or warfare purposes, nuclear industries or applications, espionage, or the development or use of guns or illegal weapons. +2. User Support. Your use of the SAM Materials is done at your own discretion; Meta does not process any information nor provide any service in relation to such use. Meta is under no obligation to provide any support services for the SAM Materials. Any support provided is “as is”, “with all faults”, and without warranty of any kind. + +3. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE SAM MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN “AS IS” BASIS, WITHOUT WARRANTIES OF ANY KIND, AND META DISCLAIMS ALL WARRANTIES OF ANY KIND, BOTH EXPRESS AND IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE SAM MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE SAM MATERIALS AND ANY OUTPUT AND RESULTS. + +4. Limitation of Liability. IN NO EVENT WILL META OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT OR INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING. + +5. Intellectual Property. + +a. Subject to Meta’s ownership of SAM Materials and derivatives made by or for Meta, with respect to any derivative works and modifications of the SAM Materials that are made by you, as between you and Meta, you are and will be the owner of such derivative works and modifications. + +b. If you institute litigation or other proceedings against Meta or any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the SAM Materials, outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Meta from and against any claim by any third party arising out of or related to your use or distribution of the SAM Materials. + +6. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the SAM Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Meta may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of the SAM Materials. Sections 3, 4 and 7 shall survive the termination of this Agreement. + +7. Governing Law and Jurisdiction. This Agreement will be governed and construed under the laws of the State of California without regard to choice of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement. The courts of California shall have exclusive jurisdiction of any dispute arising out of this Agreement. + +8. Modifications and Amendments. Meta may modify this Agreement from time to time; provided that they are similar in spirit to the current version of the Agreement, but may differ in detail to address new problems or concerns. All such changes will be effective immediately. Your continued use of the SAM Materials after any modification to this Agreement constitutes your agreement to such modification. Except as provided in this Agreement, no modification or addition to any provision of this Agreement will be binding unless it is in writing and signed by an authorized representative of both you and Meta. diff --git a/thirdparty/sam3d/sam3d/README.md b/thirdparty/sam3d/sam3d/README.md new file mode 100644 index 0000000000000000000000000000000000000000..3bcc4e6d8255efc7601406ce1471ef9a531ab411 --- /dev/null +++ b/thirdparty/sam3d/sam3d/README.md @@ -0,0 +1,152 @@ +# SAM 3D + +SAM 3D Objects is one part of SAM 3D, a pair of models for object and human mesh reconstruction. If you’re looking for SAM 3D Body, [click here](https://github.com/facebookresearch/sam-3d-body). + +# SAM 3D Objects + +**SAM 3D Team**, [Xingyu Chen](https://scholar.google.com/citations?user=gjSHr6YAAAAJ&hl=en&oi=sra)\*, [Fu-Jen Chu](https://fujenchu.github.io/)\*, [Pierre Gleize](https://scholar.google.com/citations?user=4imOcw4AAAAJ&hl=en&oi=ao)\*, [Kevin J Liang](https://kevinjliang.github.io/)\*, [Alexander Sax](https://alexsax.github.io/)\*, [Hao Tang](https://scholar.google.com/citations?user=XY6Nh9YAAAAJ&hl=en&oi=sra)\*, [Weiyao Wang](https://sites.google.com/view/weiyaowang/home)\*, [Michelle Guo](https://scholar.google.com/citations?user=lyjjpNMAAAAJ&hl=en&oi=ao), [Thibaut Hardin](https://github.com/Thibaut-H), [Xiang Li](https://ryanxli.github.io/)⚬, [Aohan Lin](https://github.com/linaohan), [Jia-Wei Liu](https://jia-wei-liu.github.io/), [Ziqi Ma](https://ziqi-ma.github.io/)⚬, [Anushka Sagar](https://www.linkedin.com/in/anushkasagar/), [Bowen Song](https://scholar.google.com/citations?user=QQKVkfcAAAAJ&hl=en&oi=sra)⚬, [Xiaodong Wang](https://scholar.google.com/citations?authuser=2&user=rMpcFYgAAAAJ), [Jianing Yang](https://jedyang.com/)⚬, [Bowen Zhang](http://home.ustc.edu.cn/~zhangbowen/)⚬, [Piotr Dollár](https://pdollar.github.io/)†, [Georgia Gkioxari](https://georgiagkioxari.com/)†, [Matt Feiszli](https://scholar.google.com/citations?user=A-wA73gAAAAJ&hl=en&oi=ao)†§, [Jitendra Malik](https://people.eecs.berkeley.edu/~malik/)†§ + +***Meta Superintelligence Labs*** + +*Core contributor (Alphabetical, Equal Contribution), ⚬Intern, †Project leads, §Equal Contribution + +[[`Paper`](https://ai.meta.com/research/publications/sam-3d-3dfy-anything-in-images/)] [[`Code`](https://github.com/facebookresearch/sam-3d-objects)] [[`Website`](https://ai.meta.com/sam3d/)] [[`Demo`](https://www.aidemos.meta.com/segment-anything/editor/convert-image-to-3d)] [[`Blog`](https://ai.meta.com/blog/sam-3d/)] [[`BibTeX`](#citing-sam-3d-objects)] [[`Roboflow`](https://blog.roboflow.com/sam-3d/)] + +**SAM 3D Objects** is a foundation model that reconstructs full 3D shape geometry, texture, and layout from a single image, excelling in real-world scenarios with occlusion and clutter by using progressive training and a data engine with human feedback. It outperforms prior 3D generation models in human preference tests on real-world objects and scenes. We released code, weights, online demo, and a new challenging benchmark. + + +

+ +----- + +

+ +## Latest updates + +**11/19/2025** - Checkpoints Launched, Web Demo and Paper are out. + +## Installation + +Follow the [setup](doc/setup.md) steps before running the following. + +## Single or Multi-Object 3D Generation + +SAM 3D Objects can convert masked objects in an image, into 3D models with pose, shape, texture, and layout. SAM 3D is designed to be robust in challenging natural images, handling small objects and occlusions, unusual poses, and difficult situations encountered in uncurated natural scenes like this kidsroom: + +

+ + +

+ +For a quick start, run `python demo.py` or use the the following lines of code: + +```python +import sys + +# import inference code +sys.path.append("notebook") +from inference import Inference, load_image, load_single_mask + +# load model +tag = "hf" +config_path = f"checkpoints/{tag}/pipeline.yaml" +inference = Inference(config_path, compile=False) + +# load image and mask +image = load_image("notebook/images/shutterstock_stylish_kidsroom_1640806567/image.png") +mask = load_single_mask("notebook/images/shutterstock_stylish_kidsroom_1640806567", index=14) + +# run model +output = inference(image, mask, seed=42) + +# export gaussian splat +output["gs"].save_ply(f"splat.ply") +``` + +For more details and multi-object reconstruction, please take a look at out two jupyter notebooks: +* [single object](notebook/demo_single_object.ipynb) +* [multi object](notebook/demo_multi_object.ipynb) + + +## SAM 3D Body + +[SAM 3D Body (3DB)](https://github.com/facebookresearch/sam-3d-body) is a robust promptable foundation model for single-image 3D human mesh recovery (HMR). + +As a way to combine the strengths of both **SAM 3D Objects** and **SAM 3D Body**, we provide an example notebook that demonstrates how to combine the results of both models such that they are aligned in the same frame of reference. Check it out [here](notebook/demo_3db_mesh_alignment.ipynb). + +## License + +The SAM 3D Objects model checkpoints and code are licensed under [SAM License](./LICENSE). + +## Contributing + +See [contributing](CONTRIBUTING.md) and the [code of conduct](CODE_OF_CONDUCT.md). + +## Contributors + +The SAM 3D Objects project was made possible with the help of many contributors. + +Robbie Adkins, +Paris Baptiste, +Karen Bergan, +Kai Brown, +Michelle Chan, +Ida Cheng, +Khadijat Durojaiye, +Patrick Edwards, +Daniella Factor, +Facundo Figueroa, +Rene de la Fuente, +Eva Galper, +Cem Gokmen, +Alex He, +Enmanuel Hernandez, +Dex Honsa, +Leonna Jones, +Arpit Kalla, +Kris Kitani, +Helen Klein, +Kei Koyama, +Robert Kuo, +Vivian Lee, +Alex Lende, +Jonny Li, +Kehan Lyu, +Faye Ma, +Mallika Malhotra, +Sasha Mitts, +William Ngan, +George Orlin, +Peter Park, +Don Pinkus, +Roman Radle, +Nikhila Ravi, +Azita Shokrpour, +Jasmine Shone, +Zayida Suber, +Phillip Thomas, +Tatum Turner, +Joseph Walker, +Meng Wang, +Claudette Ward, +Andrew Westbury, +Lea Wilken, +Nan Yang, +Yael Yungster + + +## Citing SAM 3D Objects + +If you use SAM 3D Objects in your research, please use the following BibTeX entry. + +``` +@article{sam3dteam2025sam3d3dfyimages, + title={SAM 3D: 3Dfy Anything in Images}, + author={SAM 3D Team and Xingyu Chen and Fu-Jen Chu and Pierre Gleize and Kevin J Liang and Alexander Sax and Hao Tang and Weiyao Wang and Michelle Guo and Thibaut Hardin and Xiang Li and Aohan Lin and Jiawei Liu and Ziqi Ma and Anushka Sagar and Bowen Song and Xiaodong Wang and Jianing Yang and Bowen Zhang and Piotr Dollár and Georgia Gkioxari and Matt Feiszli and Jitendra Malik}, + year={2025}, + eprint={2511.16624}, + archivePrefix={arXiv}, + primaryClass={cs.CV}, + url={https://arxiv.org/abs/2511.16624}, +} +``` diff --git a/thirdparty/sam3d/sam3d/checkpoints/.gitignore b/thirdparty/sam3d/sam3d/checkpoints/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..c96a04f008ee21e260b28f7701595ed59e2839e3 --- /dev/null +++ b/thirdparty/sam3d/sam3d/checkpoints/.gitignore @@ -0,0 +1,2 @@ +* +!.gitignore \ No newline at end of file diff --git a/thirdparty/sam3d/sam3d/demo.py b/thirdparty/sam3d/sam3d/demo.py new file mode 100644 index 0000000000000000000000000000000000000000..befdd845c3fedc9039e906918de7b8e8b12d96d7 --- /dev/null +++ b/thirdparty/sam3d/sam3d/demo.py @@ -0,0 +1,21 @@ +import sys + +# import inference code +sys.path.append("notebook") +from inference import Inference, load_image, load_single_mask + +# load model +tag = "hf" +config_path = f"checkpoints/{tag}/pipeline.yaml" +inference = Inference(config_path, compile=False) + +# load image (RGBA only, mask is embedded in the alpha channel) +image = load_image("notebook/images/shutterstock_stylish_kidsroom_1640806567/image.png") +mask = load_single_mask("notebook/images/shutterstock_stylish_kidsroom_1640806567", index=14) + +# run model +output = inference(image, mask, seed=42) + +# export gaussian splat +output["gs"].save_ply(f"splat.ply") +print("Your reconstruction has been saved to splat.ply") diff --git a/thirdparty/sam3d/sam3d/doc/setup.md b/thirdparty/sam3d/sam3d/doc/setup.md new file mode 100644 index 0000000000000000000000000000000000000000..1da6cd77a185dee7cde52ee3f8bed679028a963c --- /dev/null +++ b/thirdparty/sam3d/sam3d/doc/setup.md @@ -0,0 +1,58 @@ +# Setup + +## Prerequisites + +* A linux 64-bits architecture (i.e. `linux-64` platform in `mamba info`). +* A NVIDIA GPU with at least 32 Gb of VRAM. + +## 1. Setup Python Environment + +The following will install the default environment. If you use `conda` instead of `mamba`, replace its name in the first two lines. Note that you may have to build the environment on a compute node with GPU (e.g., you may get a `RuntimeError: Not compiled with GPU support` error when running certain parts of the code that use Pytorch3D). + +```bash +# create sam3d-objects environment +mamba env create -f environments/default.yml +mamba activate sam3d-objects + +# for pytorch/cuda dependencies +export PIP_EXTRA_INDEX_URL="https://pypi.ngc.nvidia.com https://download.pytorch.org/whl/cu121" + +# install sam3d-objects and core dependencies +pip install -e '.[dev]' +pip install -e '.[p3d]' # pytorch3d dependency on pytorch is broken, this 2-step approach solves it + +# for inference +export PIP_FIND_LINKS="https://nvidia-kaolin.s3.us-east-2.amazonaws.com/torch-2.5.1_cu121.html" +pip install -e '.[inference]' + +# patch things that aren't yet in official pip packages +./patching/hydra # https://github.com/facebookresearch/hydra/pull/2863 +``` + +## 2. Getting Checkpoints + +### From HuggingFace + +⚠️ Before using SAM 3D Objects, please request access to the checkpoints on the SAM 3D Objects +Hugging Face [repo](https://huggingface.co/facebook/sam-3d-objects). Once accepted, you +need to be authenticated to download the checkpoints. You can do this by running +the following [steps](https://huggingface.co/docs/huggingface_hub/en/quick-start#authentication) +(e.g. `hf auth login` after generating an access token). + +⚠️ SAM 3D Objects is available via HuggingFace globally, **except** in comprehensively sanctioned jurisdictions. +Sanctioned jurisdiction will result in requests being **rejected**. + +```bash +pip install 'huggingface-hub[cli]<1.0' + +TAG=hf +hf download \ + --repo-type model \ + --local-dir checkpoints/${TAG}-download \ + --max-workers 1 \ + facebook/sam-3d-objects +mv checkpoints/${TAG}-download/checkpoints checkpoints/${TAG} +rm -rf checkpoints/${TAG}-download +``` + + diff --git a/thirdparty/sam3d/sam3d/environments/default.yml b/thirdparty/sam3d/sam3d/environments/default.yml new file mode 100644 index 0000000000000000000000000000000000000000..091170ea69a1a81555c0b1b7add5666711030aff --- /dev/null +++ b/thirdparty/sam3d/sam3d/environments/default.yml @@ -0,0 +1,216 @@ +name: sam3d-objects +channels: + - conda-forge +dependencies: + - _libgcc_mutex=0.1=conda_forge + - _openmp_mutex=4.5=2_gnu + - alsa-lib=1.2.13=hb9d3cd8_0 + - attr=2.5.1=h166bdaf_1 + - binutils=2.43=h4852527_4 + - binutils_impl_linux-64=2.43=h4bf12b8_4 + - binutils_linux-64=2.43=h4852527_4 + - bzip2=1.0.8=h4bc722e_7 + - c-compiler=1.7.0=hd590300_1 + - ca-certificates=2025.1.31=hbcca054_0 + - cairo=1.18.0=h3faef2a_0 + - cuda-cccl=12.1.109=ha770c72_0 + - cuda-cccl-impl=2.0.1=ha770c72_1 + - cuda-cccl_linux-64=12.1.109=ha770c72_0 + - cuda-command-line-tools=12.1.1=ha770c72_0 + - cuda-compiler=12.1.1=hbad6d8a_0 + - cuda-cudart=12.1.105=hd3aeb46_0 + - cuda-cudart-dev=12.1.105=hd3aeb46_0 + - cuda-cudart-dev_linux-64=12.1.105=h59595ed_0 + - cuda-cudart-static=12.1.105=hd3aeb46_0 + - cuda-cudart-static_linux-64=12.1.105=h59595ed_0 + - cuda-cudart_linux-64=12.1.105=h59595ed_0 + - cuda-cuobjdump=12.1.111=h59595ed_0 + - cuda-cupti=12.1.105=h59595ed_0 + - cuda-cupti-dev=12.1.105=h59595ed_0 + - cuda-cuxxfilt=12.1.105=h59595ed_0 + - cuda-driver-dev=12.1.105=hd3aeb46_0 + - cuda-driver-dev_linux-64=12.1.105=h59595ed_0 + - cuda-gdb=12.1.105=hd47b8d6_0 + - cuda-libraries=12.1.1=ha770c72_0 + - cuda-libraries-dev=12.1.1=ha770c72_0 + - cuda-nsight=12.1.105=ha770c72_0 + - cuda-nvcc=12.1.105=hcdd1206_1 + - cuda-nvcc-dev_linux-64=12.1.105=ha770c72_0 + - cuda-nvcc-impl=12.1.105=hd3aeb46_0 + - cuda-nvcc-tools=12.1.105=hd3aeb46_0 + - cuda-nvcc_linux-64=12.1.105=h8a487aa_1 + - cuda-nvdisasm=12.1.105=h59595ed_0 + - cuda-nvml-dev=12.1.105=h59595ed_0 + - cuda-nvprof=12.1.105=h59595ed_0 + - cuda-nvprune=12.1.105=h59595ed_0 + - cuda-nvrtc=12.1.105=hd3aeb46_0 + - cuda-nvrtc-dev=12.1.105=hd3aeb46_0 + - cuda-nvtx=12.1.105=h59595ed_0 + - cuda-nvvp=12.1.105=h59595ed_0 + - cuda-opencl=12.1.105=h59595ed_0 + - cuda-opencl-dev=12.1.105=h59595ed_0 + - cuda-profiler-api=12.1.105=ha770c72_0 + - cuda-sanitizer-api=12.1.105=h59595ed_0 + - cuda-toolkit=12.1.1=ha804496_0 + - cuda-tools=12.1.1=ha770c72_0 + - cuda-version=12.1=h1d6eff3_3 + - cuda-visual-tools=12.1.1=ha770c72_0 + - cxx-compiler=1.7.0=h00ab1b0_1 + - dbus=1.13.6=h5008d03_3 + - expat=2.6.4=h5888daf_0 + - font-ttf-dejavu-sans-mono=2.37=hab24e00_0 + - font-ttf-inconsolata=3.000=h77eed37_0 + - font-ttf-source-code-pro=2.038=h77eed37_0 + - font-ttf-ubuntu=0.83=h77eed37_3 + - fontconfig=2.15.0=h7e30c49_1 + - fonts-conda-ecosystem=1=0 + - fonts-conda-forge=1=0 + - freetype=2.13.3=h48d6fc4_0 + - gcc=12.4.0=h236703b_2 + - gcc_impl_linux-64=12.4.0=h26ba24d_2 + - gcc_linux-64=12.4.0=h6b7512a_8 + - gds-tools=1.6.1.9=hd3aeb46_0 + - gettext=0.23.1=h5888daf_0 + - gettext-tools=0.23.1=h5888daf_0 + - glib=2.82.2=h07242d1_1 + - glib-tools=2.82.2=h4833e2c_1 + - gmp=6.3.0=hac33072_2 + - graphite2=1.3.13=h59595ed_1003 + - gst-plugins-base=1.24.4=h9ad1361_0 + - gstreamer=1.24.4=haf2f30d_0 + - gxx=12.4.0=h236703b_2 + - gxx_impl_linux-64=12.4.0=h3ff227c_2 + - gxx_linux-64=12.4.0=h8489865_8 + - harfbuzz=8.5.0=hfac3d4d_0 + - icu=73.2=h59595ed_0 + - kernel-headers_linux-64=3.10.0=he073ed8_18 + - keyutils=1.6.1=h166bdaf_0 + - krb5=1.21.3=h659f571_0 + - lame=3.100=h166bdaf_1003 + - ld_impl_linux-64=2.43=h712a8e2_4 + - libasprintf=0.23.1=h8e693c7_0 + - libasprintf-devel=0.23.1=h8e693c7_0 + - libcap=2.75=h39aace5_0 + - libclang-cpp15=15.0.7=default_h127d8a8_5 + - libclang13=19.1.2=default_h9c6a7e4_1 + - libcublas=12.1.3.1=hd3aeb46_0 + - libcublas-dev=12.1.3.1=hd3aeb46_0 + - libcufft=11.0.2.54=hd3aeb46_0 + - libcufft-dev=11.0.2.54=hd3aeb46_0 + - libcufile=1.6.1.9=hd3aeb46_0 + - libcufile-dev=1.6.1.9=hd3aeb46_0 + - libcups=2.3.3=h4637d8d_4 + - libcurand=10.3.2.106=hd3aeb46_0 + - libcurand-dev=10.3.2.106=hd3aeb46_0 + - libcusolver=11.4.5.107=hd3aeb46_0 + - libcusolver-dev=11.4.5.107=hd3aeb46_0 + - libcusparse=12.1.0.106=hd3aeb46_0 + - libcusparse-dev=12.1.0.106=hd3aeb46_0 + - libedit=3.1.20250104=pl5321h7949ede_0 + - libevent=2.1.12=hf998b51_1 + - libexpat=2.6.4=h5888daf_0 + - libffi=3.4.6=h2dba641_0 + - libflac=1.4.3=h59595ed_0 + - libgcc=14.2.0=h767d61c_2 + - libgcc-devel_linux-64=12.4.0=h1762d19_102 + - libgcc-ng=14.2.0=h69a702a_2 + - libgcrypt-lib=1.11.0=hb9d3cd8_2 + - libgettextpo=0.23.1=h5888daf_0 + - libgettextpo-devel=0.23.1=h5888daf_0 + - libglib=2.82.2=h2ff4ddf_1 + - libgomp=14.2.0=h767d61c_2 + - libgpg-error=1.51=hbd13f7d_1 + - libiconv=1.18=h4ce23a2_1 + - libjpeg-turbo=3.0.0=hd590300_1 + - libllvm15=15.0.7=hb3ce162_4 + - libllvm19=19.1.2=ha7bfdaf_0 + - liblzma=5.6.4=hb9d3cd8_0 + - liblzma-devel=5.6.4=hb9d3cd8_0 + - libnpp=12.1.0.40=hd3aeb46_0 + - libnpp-dev=12.1.0.40=hd3aeb46_0 + - libnsl=2.0.1=hd590300_0 + - libnuma=2.0.18=h4ab18f5_2 + - libnvjitlink=12.1.105=hd3aeb46_0 + - libnvjitlink-dev=12.1.105=hd3aeb46_0 + - libnvjpeg=12.2.0.2=h59595ed_0 + - libnvjpeg-dev=12.2.0.2=ha770c72_0 + - libogg=1.3.5=h4ab18f5_0 + - libopus=1.3.1=h7f98852_1 + - libpng=1.6.47=h943b412_0 + - libpq=16.8=h87c4ccc_0 + - libsanitizer=12.4.0=ha732cd4_2 + - libsndfile=1.2.2=hc60ed4a_1 + - libsqlite=3.49.1=hee588c1_2 + - libstdcxx=14.2.0=h8f9b012_2 + - libstdcxx-devel_linux-64=12.4.0=h1762d19_102 + - libstdcxx-ng=14.2.0=h4852527_2 + - libsystemd0=257.4=h4e0b6ca_1 + - libuuid=2.38.1=h0b41bf4_0 + - libvorbis=1.3.7=h9c3ff4c_0 + - libxcb=1.15=h0b41bf4_0 + - libxkbcommon=1.7.0=h662e7e4_0 + - libxkbfile=1.1.0=h166bdaf_1 + - libxml2=2.12.7=h4c95cb1_3 + - libzlib=1.3.1=hb9d3cd8_2 + - lz4-c=1.10.0=h5888daf_1 + - mpg123=1.32.9=hc50e24c_0 + - mysql-common=8.3.0=h70512c7_5 + - mysql-libs=8.3.0=ha479ceb_5 + - ncurses=6.5=h2d0b736_3 + - nsight-compute=2023.1.1.4=h3718151_0 + - nspr=4.36=h5888daf_0 + - nss=3.108=h159eef7_0 + - ocl-icd=2.3.2=hb9d3cd8_2 + - opencl-headers=2024.10.24=h5888daf_0 + - openssl=3.4.1=h7b32b05_0 + - packaging=24.2=pyhd8ed1ab_2 + - pcre2=10.44=hba22ea6_2 + - pip=25.0.1=pyh8b19718_0 + - pixman=0.44.2=h29eaf8c_0 + - pthread-stubs=0.4=hb9d3cd8_1002 + - pulseaudio-client=17.0=hb77b528_0 + - python=3.11.0=he550d4f_1_cpython + - qt-main=5.15.8=hc9dc06e_21 + - readline=8.2=h8c095d6_2 + - setuptools=75.8.2=pyhff2d567_0 + - sysroot_linux-64=2.17=h0157908_18 + - tk=8.6.13=noxft_h4845f30_101 + - tzdata=2025b=h78e105d_0 + - wayland=1.23.1=h3e06ad9_0 + - wheel=0.45.1=pyhd8ed1ab_1 + - xcb-util=0.4.0=hd590300_1 + - xcb-util-image=0.4.0=h8ee46fc_1 + - xcb-util-keysyms=0.4.0=h8ee46fc_1 + - xcb-util-renderutil=0.3.9=hd590300_1 + - xcb-util-wm=0.4.1=h8ee46fc_1 + - xkeyboard-config=2.42=h4ab18f5_0 + - xorg-compositeproto=0.4.2=hb9d3cd8_1002 + - xorg-damageproto=1.2.1=hb9d3cd8_1003 + - xorg-fixesproto=5.0=hb9d3cd8_1003 + - xorg-inputproto=2.3.2=hb9d3cd8_1003 + - xorg-kbproto=1.0.7=hb9d3cd8_1003 + - xorg-libice=1.1.2=hb9d3cd8_0 + - xorg-libsm=1.2.6=he73a12e_0 + - xorg-libx11=1.8.9=h8ee46fc_0 + - xorg-libxau=1.0.12=hb9d3cd8_0 + - xorg-libxcomposite=0.4.6=h0b41bf4_1 + - xorg-libxdamage=1.1.5=h7f98852_1 + - xorg-libxdmcp=1.1.5=hb9d3cd8_0 + - xorg-libxext=1.3.4=h0b41bf4_2 + - xorg-libxfixes=5.0.3=h7f98852_1004 + - xorg-libxi=1.7.10=h4bc722e_1 + - xorg-libxrandr=1.5.2=h7f98852_1 + - xorg-libxrender=0.9.11=hd590300_0 + - xorg-libxtst=1.2.5=h4bc722e_0 + - xorg-randrproto=1.5.0=hb9d3cd8_1002 + - xorg-recordproto=1.14.2=hb9d3cd8_1003 + - xorg-renderproto=0.11.1=hb9d3cd8_1003 + - xorg-util-macros=1.20.2=hb9d3cd8_0 + - xorg-xextproto=7.3.0=hb9d3cd8_1004 + - xorg-xf86vidmodeproto=2.3.1=hb9d3cd8_1005 + - xorg-xproto=7.0.31=hb9d3cd8_1008 + - xz=5.6.4=hbcc6ac9_0 + - xz-gpl-tools=5.6.4=hbcc6ac9_0 + - xz-tools=5.6.4=hb9d3cd8_0 + - zlib=1.3.1=hb9d3cd8_2 + - zstd=1.5.7=hb8e6e7a_2 diff --git a/thirdparty/sam3d/sam3d/notebook/demo_3db_mesh_alignment.ipynb b/thirdparty/sam3d/sam3d/notebook/demo_3db_mesh_alignment.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..e99c69f3364f84b61bd960657371c993cdabf50a --- /dev/null +++ b/thirdparty/sam3d/sam3d/notebook/demo_3db_mesh_alignment.ipynb @@ -0,0 +1,149 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# SAM 3D Body (3DB) Mesh Alignment to SAM 3D Object Scale\n", + "\n", + "This notebook processes a single 3DB mesh and aligns it to the SAM 3D Objects scale.\n", + "\n", + "**Input Data:**\n", + "- `images/human_object/image.jpg` - Input image for MoGe\n", + "- `meshes/human_object/3DB_results/mask_human.png` - Human mask\n", + "- `meshes/human_object/3DB_results/human.ply` - Single 3DB mesh in OpenGL coordinates\n", + "- `meshes/human_object/3DB_results/focal_length.json` - 3DB focal length\n", + "\n", + "**Output:**\n", + "- `meshes/human_object/aligned_meshes/human_aligned.ply` - Aligned 3DB mesh in OpenGL coordinates" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import torch\n", + "import matplotlib.pyplot as plt\n", + "from PIL import Image\n", + "from mesh_alignment import process_and_save_alignment\n", + "\n", + "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", + "print(f\"Using device: {device}\")\n", + "PATH = os.getcwd()\n", + "print(f\"Current working directory: {PATH}\")\n", + "\n", + "# Please inference the SAM 3D Body (3DB) Repo (https://github.com/facebookresearch/sam-3d-body) to get the 3DB Results\n", + "image_path = f\"{PATH}/images/human_object/image.png\"\n", + "mask_path = f\"{PATH}/meshes/human_object/3DB_results/mask_human.png\"\n", + "mesh_path = f\"{PATH}/meshes/human_object/3DB_results/human.ply\"\n", + "focal_length_json_path = f\"{PATH}/meshes/human_object/3DB_results/focal_length.json\"\n", + "output_dir = f\"{PATH}/meshes/human_object/aligned_meshes\"\n", + "os.makedirs(output_dir, exist_ok=True)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Load and Display Input Data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "input_image = Image.open(image_path)\n", + "mask = Image.open(mask_path).convert('L')\n", + "fig, axes = plt.subplots(1, 2, figsize=(10, 5))\n", + "axes[0].imshow(input_image)\n", + "axes[0].set_title('Input Image')\n", + "axes[0].axis('off')\n", + "axes[1].imshow(mask, cmap='gray')\n", + "axes[1].set_title('Mask')\n", + "axes[1].axis('off')\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Process and Save Aligned Mesh" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "success, output_mesh_path, result = process_and_save_alignment(\n", + " mesh_path=mesh_path,\n", + " mask_path=mask_path,\n", + " image_path=image_path,\n", + " output_dir=output_dir,\n", + " device=device,\n", + " focal_length_json_path=focal_length_json_path\n", + ")\n", + "\n", + "if success:\n", + " print(f\"Alignment completed successfully! Output: {output_mesh_path}\")\n", + "else:\n", + " print(\"Alignment failed!\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. Interactive 3D Visualization\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from mesh_alignment import visualize_meshes_interactive\n", + "\n", + "aligned_mesh_path = f\"{PATH}/meshes/human_object/aligned_meshes/human_aligned.ply\"\n", + "dfy_mesh_path = f\"{PATH}/meshes/human_object/3Dfy_results/0.glb\"\n", + "\n", + "demo, combined_glb_path = visualize_meshes_interactive(\n", + " aligned_mesh_path=aligned_mesh_path,\n", + " dfy_mesh_path=dfy_mesh_path,\n", + " share=True\n", + ")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "sam3d_objects-3dfy", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/thirdparty/sam3d/sam3d/notebook/demo_multi_object.ipynb b/thirdparty/sam3d/sam3d/notebook/demo_multi_object.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..ecd817f44ffec1a7e9a73e78eee778ffa7691cb7 --- /dev/null +++ b/thirdparty/sam3d/sam3d/notebook/demo_multi_object.ipynb @@ -0,0 +1,162 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Copyright (c) Meta Platforms, Inc. and affiliates." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Imports and Model Loading" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import uuid\n", + "import imageio\n", + "import numpy as np\n", + "from IPython.display import Image as ImageDisplay\n", + "\n", + "from inference import Inference, ready_gaussian_for_video_rendering, load_image, load_masks, display_image, make_scene, render_video, interactive_visualizer" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "PATH = os.getcwd()\n", + "TAG = \"hf\"\n", + "config_path = f\"{PATH}/../checkpoints/{TAG}/pipeline.yaml\"\n", + "inference = Inference(config_path, compile=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Load input image to lift to 3D (multiple objects)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "IMAGE_PATH = f\"{PATH}/images/shutterstock_stylish_kidsroom_1640806567/image.png\"\n", + "IMAGE_NAME = os.path.basename(os.path.dirname(IMAGE_PATH))\n", + "\n", + "image = load_image(IMAGE_PATH)\n", + "masks = load_masks(os.path.dirname(IMAGE_PATH), extension=\".png\")\n", + "display_image(image, masks)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. Generate Gaussian Splats" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "outputs = [inference(image, mask, seed=42) for mask in masks]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. Visualize Gaussian Splat of the Scene\n", + "### a. Animated Gif" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "scene_gs = make_scene(*outputs)\n", + "scene_gs = ready_gaussian_for_video_rendering(scene_gs)\n", + "\n", + "# export gaussian splatting (as point cloud)\n", + "scene_gs.save_ply(f\"{PATH}/gaussians/multi/{IMAGE_NAME}.ply\")\n", + "\n", + "video = render_video(\n", + " scene_gs,\n", + " r=1,\n", + " fov=60,\n", + " resolution=512,\n", + ")[\"color\"]\n", + "\n", + "# save video as gif\n", + "imageio.mimsave(\n", + " os.path.join(f\"{PATH}/gaussians/multi/{IMAGE_NAME}.gif\"),\n", + " video,\n", + " format=\"GIF\",\n", + " duration=1000 / 30, # default assuming 30fps from the input MP4\n", + " loop=0, # 0 means loop indefinitely\n", + ")\n", + "\n", + "# notebook display\n", + "ImageDisplay(url=f\"gaussians/multi/{IMAGE_NAME}.gif?cache_invalidator={uuid.uuid4()}\",)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### b. Interactive Visualizer" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# might take a while to load (black screen)\n", + "interactive_visualizer(f\"{PATH}/gaussians/multi/{IMAGE_NAME}.ply\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "sam3d-objects", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.0" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/thirdparty/sam3d/sam3d/notebook/demo_single_object.ipynb b/thirdparty/sam3d/sam3d/notebook/demo_single_object.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..4ce07ca05646cd8ce6f1367a575b205c27679d61 --- /dev/null +++ b/thirdparty/sam3d/sam3d/notebook/demo_single_object.ipynb @@ -0,0 +1,164 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Copyright (c) Meta Platforms, Inc. and affiliates." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Imports and Model Loading" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import imageio\n", + "import uuid\n", + "from IPython.display import Image as ImageDisplay\n", + "from inference import Inference, ready_gaussian_for_video_rendering, render_video, load_image, load_single_mask, display_image, make_scene, interactive_visualizer" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "PATH = os.getcwd()\n", + "TAG = \"hf\"\n", + "config_path = f\"{PATH}/../checkpoints/{TAG}/pipeline.yaml\"\n", + "inference = Inference(config_path, compile=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Load input image to lift to 3D (single object)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "IMAGE_PATH = f\"{PATH}/images/shutterstock_stylish_kidsroom_1640806567/image.png\"\n", + "IMAGE_NAME = os.path.basename(os.path.dirname(IMAGE_PATH))\n", + "\n", + "image = load_image(IMAGE_PATH)\n", + "mask = load_single_mask(os.path.dirname(IMAGE_PATH), index=14)\n", + "display_image(image, masks=[mask])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. Generate Gaussian Splat" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# run model\n", + "output = inference(image, mask, seed=42)\n", + "\n", + "# export gaussian splat (as point cloud)\n", + "output[\"gs\"].save_ply(f\"{PATH}/gaussians/single/{IMAGE_NAME}.ply\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. Visualize Gaussian Splat\n", + "### a. Animated Gif" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# render gaussian splat\n", + "scene_gs = make_scene(output)\n", + "scene_gs = ready_gaussian_for_video_rendering(scene_gs)\n", + "\n", + "video = render_video(\n", + " scene_gs,\n", + " r=1,\n", + " fov=60,\n", + " pitch_deg=15,\n", + " yaw_start_deg=-45,\n", + " resolution=512,\n", + ")[\"color\"]\n", + "\n", + "# save video as gif\n", + "imageio.mimsave(\n", + " os.path.join(f\"{PATH}/gaussians/single/{IMAGE_NAME}.gif\"),\n", + " video,\n", + " format=\"GIF\",\n", + " duration=1000 / 30, # default assuming 30fps from the input MP4\n", + " loop=0, # 0 means loop indefinitely\n", + ")\n", + "\n", + "# notebook display\n", + "ImageDisplay(url=f\"gaussians/single/{IMAGE_NAME}.gif?cache_invalidator={uuid.uuid4()}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### b. Interactive Visualizer" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# might take a while to load (black screen)\n", + "interactive_visualizer(f\"{PATH}/gaussians/single/{IMAGE_NAME}.ply\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "sam3d_objects-3dfy", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.0" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/thirdparty/sam3d/sam3d/notebook/inference.py b/thirdparty/sam3d/sam3d/notebook/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..2fb55534273bb41a8df72fcadd52e30c5a967c89 --- /dev/null +++ b/thirdparty/sam3d/sam3d/notebook/inference.py @@ -0,0 +1,414 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import os + +# not ideal to put that here +os.environ["CUDA_HOME"] = os.environ["CONDA_PREFIX"] +os.environ["LIDRA_SKIP_INIT"] = "true" + +import sys +from typing import Union, Optional, List, Callable +import numpy as np +from PIL import Image +from omegaconf import OmegaConf, DictConfig, ListConfig +from hydra.utils import instantiate, get_method +import torch +import math +import utils3d +import shutil +import subprocess +import seaborn as sns +from PIL import Image +import numpy as np +import gradio as gr +import matplotlib.pyplot as plt +from copy import deepcopy +from kaolin.visualize import IpyTurntableVisualizer +from kaolin.render.camera import Camera, CameraExtrinsics, PinholeIntrinsics +import builtins +from pytorch3d.transforms import quaternion_multiply, quaternion_invert + +import sam3d_objects # REMARK(Pierre) : do not remove this import +from sam3d_objects.pipeline.inference_pipeline_pointmap import InferencePipelinePointMap +from sam3d_objects.model.backbone.tdfy_dit.utils import render_utils + +from sam3d_objects.utils.visualization import SceneVisualizer + +__all__ = ["Inference"] + +WHITELIST_FILTERS = [ + lambda target: target.split(".", 1)[0] in {"sam3d_objects", "torch", "torchvision", "moge"}, +] + +BLACKLIST_FILTERS = [ + lambda target: get_method(target) + in { + builtins.exec, + builtins.eval, + builtins.__import__, + os.kill, + os.system, + os.putenv, + os.remove, + os.removedirs, + os.rmdir, + os.fchdir, + os.setuid, + os.fork, + os.forkpty, + os.killpg, + os.rename, + os.renames, + os.truncate, + os.replace, + os.unlink, + os.fchmod, + os.fchown, + os.chmod, + os.chown, + os.chroot, + os.fchdir, + os.lchown, + os.getcwd, + os.chdir, + shutil.rmtree, + shutil.move, + shutil.chown, + subprocess.Popen, + builtins.help, + }, +] + + +class Inference: + # public facing inference API + # only put publicly exposed arguments here + def __init__(self, config_file: str, compile: bool = False): + # load inference pipeline + config = OmegaConf.load(config_file) + config.rendering_engine = "pytorch3d" # overwrite to disable nvdiffrast + config.compile_model = compile + config.workspace_dir = os.path.dirname(config_file) + check_hydra_safety(config, WHITELIST_FILTERS, BLACKLIST_FILTERS) + self._pipeline: InferencePipelinePointMap = instantiate(config) + + def merge_mask_to_rgba(self, image, mask): + mask = mask.astype(np.uint8) * 255 + mask = mask[..., None] + # embed mask in alpha channel + rgba_image = np.concatenate([image[..., :3], mask], axis=-1) + return rgba_image + + def __call__( + self, + image: Union[Image.Image, np.ndarray], + mask: Optional[Union[None, Image.Image, np.ndarray]], + seed: Optional[int] = None, + pointmap=None, + ) -> dict: + image = self.merge_mask_to_rgba(image, mask) + return self._pipeline.run( + image, + None, + seed, + stage1_only=False, + with_mesh_postprocess=False, + with_texture_baking=False, + with_layout_postprocess=True, + use_vertex_color=True, + stage1_inference_steps=None, + pointmap=pointmap, + ) + + +def _yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitchs, rs, fovs): + is_list = isinstance(yaws, list) + if not is_list: + yaws = [yaws] + pitchs = [pitchs] + if not isinstance(rs, list): + rs = [rs] * len(yaws) + if not isinstance(fovs, list): + fovs = [fovs] * len(yaws) + extrinsics = [] + intrinsics = [] + for yaw, pitch, r, fov in zip(yaws, pitchs, rs, fovs): + fov = torch.deg2rad(torch.tensor(float(fov))).cuda() + yaw = torch.tensor(float(yaw)).cuda() + pitch = torch.tensor(float(pitch)).cuda() + orig = ( + torch.tensor( + [ + torch.sin(yaw) * torch.cos(pitch), + torch.sin(pitch), + torch.cos(yaw) * torch.cos(pitch), + ] + ).cuda() + * r + ) + extr = utils3d.torch.extrinsics_look_at( + orig, + torch.tensor([0, 0, 0]).float().cuda(), + torch.tensor([0, 1, 0]).float().cuda(), + ) + intr = utils3d.torch.intrinsics_from_fov_xy(fov, fov) + extrinsics.append(extr) + intrinsics.append(intr) + if not is_list: + extrinsics = extrinsics[0] + intrinsics = intrinsics[0] + return extrinsics, intrinsics + + +def render_video( + sample, + resolution=512, + bg_color=(0, 0, 0), + num_frames=300, + r=2.0, + fov=40, + pitch_deg=0, + yaw_start_deg=-90, + **kwargs, +): + + yaws = ( + torch.linspace(0, 2 * torch.pi, num_frames) + math.radians(yaw_start_deg) + ).tolist() + pitch = [math.radians(pitch_deg)] * num_frames + + extr, intr = _yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitch, r, fov) + + return render_utils.render_frames( + sample, + extr, + intr, + {"resolution": resolution, "bg_color": bg_color, "backend": "gsplat"}, + **kwargs, + ) + + +def ready_gaussian_for_video_rendering(scene_gs, in_place=False, fix_alignment=False): + if fix_alignment: + scene_gs = _fix_gaussian_alignment(scene_gs, in_place=in_place) + scene_gs = normalized_gaussian(scene_gs, in_place=fix_alignment) + return scene_gs + + +def _fix_gaussian_alignment(scene_gs, in_place=False): + if not in_place: + scene_gs = deepcopy(scene_gs) + + device = scene_gs._xyz.device + dtype = scene_gs._xyz.dtype + scene_gs._xyz = ( + scene_gs._xyz + @ torch.tensor( + [ + [-1, 0, 0], + [0, 0, 1], + [0, 1, 0], + ], + device=device, + dtype=dtype, + ).T + ) + return scene_gs + + +def normalized_gaussian(scene_gs, in_place=False, outlier_percentile=None): + if not in_place: + scene_gs = deepcopy(scene_gs) + + orig_xyz = scene_gs.get_xyz + orig_scale = scene_gs.get_scaling + + active_mask = (scene_gs.get_opacity > 0.9).squeeze() + inv_scale = ( + orig_xyz[active_mask].max(dim=0)[0] - orig_xyz[active_mask].min(dim=0)[0] + ).max() + norm_scale = orig_scale / inv_scale + norm_xyz = orig_xyz / inv_scale + + if outlier_percentile is None: + lower_bound_xyz = torch.min(norm_xyz[active_mask], dim=0)[0] + upper_bound_xyz = torch.max(norm_xyz[active_mask], dim=0)[0] + else: + lower_bound_xyz = torch.quantile( + norm_xyz[active_mask], + outlier_percentile, + dim=0, + ) + upper_bound_xyz = torch.quantile( + norm_xyz[active_mask], + 1.0 - outlier_percentile, + dim=0, + ) + + center = (lower_bound_xyz + upper_bound_xyz) / 2 + norm_xyz = norm_xyz - center + scene_gs.from_xyz(norm_xyz) + scene_gs.mininum_kernel_size /= inv_scale.item() + scene_gs.from_scaling(norm_scale) + return scene_gs + + +def make_scene(*outputs, in_place=False): + if not in_place: + outputs = [deepcopy(output) for output in outputs] + + all_outs = [] + minimum_kernel_size = float("inf") + for output in outputs: + # move gaussians to scene frame of reference + PC = SceneVisualizer.object_pointcloud( + points_local=output["gaussian"][0].get_xyz.unsqueeze(0), + quat_l2c=output["rotation"], + trans_l2c=output["translation"], + scale_l2c=output["scale"], + ) + output["gaussian"][0].from_xyz(PC.points_list()[0]) + # must ... ROTATE + output["gaussian"][0].from_rotation( + quaternion_multiply( + quaternion_invert(output["rotation"]), + output["gaussian"][0].get_rotation, + ) + ) + scale = output["gaussian"][0].get_scaling + adjusted_scale = scale * output["scale"] + assert ( + output["scale"][0, 0].item() + == output["scale"][0, 1].item() + == output["scale"][0, 2].item() + ) + output["gaussian"][0].mininum_kernel_size *= output["scale"][0, 0].item() + adjusted_scale = torch.maximum( + adjusted_scale, + torch.tensor( + output["gaussian"][0].mininum_kernel_size * 1.1, + device=adjusted_scale.device, + ), + ) + output["gaussian"][0].from_scaling(adjusted_scale) + minimum_kernel_size = min( + minimum_kernel_size, + output["gaussian"][0].mininum_kernel_size, + ) + all_outs.append(output) + + # merge gaussians + scene_gs = all_outs[0]["gaussian"][0] + scene_gs.mininum_kernel_size = minimum_kernel_size + for out in all_outs[1:]: + out_gs = out["gaussian"][0] + scene_gs._xyz = torch.cat([scene_gs._xyz, out_gs._xyz], dim=0) + scene_gs._features_dc = torch.cat( + [scene_gs._features_dc, out_gs._features_dc], dim=0 + ) + scene_gs._scaling = torch.cat([scene_gs._scaling, out_gs._scaling], dim=0) + scene_gs._rotation = torch.cat([scene_gs._rotation, out_gs._rotation], dim=0) + scene_gs._opacity = torch.cat([scene_gs._opacity, out_gs._opacity], dim=0) + + return scene_gs + + +def check_target( + target: str, + whitelist_filters: List[Callable], + blacklist_filters: List[Callable], +): + if any(filt(target) for filt in whitelist_filters): + if not any(filt(target) for filt in blacklist_filters): + return + raise RuntimeError( + f"target '{target}' is not allowed to be hydra instantiated, if this is a mistake, please do modify the whitelist_filters / blacklist_filters" + ) + + +def check_hydra_safety( + config: DictConfig, + whitelist_filters: List[Callable], + blacklist_filters: List[Callable], +): + to_check = [config] + while len(to_check) > 0: + node = to_check.pop() + if isinstance(node, DictConfig): + to_check.extend(list(node.values())) + if "_target_" in node: + check_target(node["_target_"], whitelist_filters, blacklist_filters) + elif isinstance(node, ListConfig): + to_check.extend(list(node)) + + +def load_image(path): + image = Image.open(path) + image = np.array(image) + image = image.astype(np.uint8) + return image + + +def load_mask(path): + mask = load_image(path) + mask = mask > 0 + if mask.ndim == 3: + mask = mask[..., -1] + return mask + + +def load_single_mask(folder_path, index=0, extension=".png"): + masks = load_masks(folder_path, [index], extension) + return masks[0] + + +def load_masks(folder_path, indices_list=None, extension=".png"): + masks = [] + indices_list = [] if indices_list is None else list(indices_list) + if not len(indices_list) > 0: # get all all masks if not provided + idx = 0 + while os.path.exists(os.path.join(folder_path, f"{idx}{extension}")): + indices_list.append(idx) + idx += 1 + + for idx in indices_list: + mask_path = os.path.join(folder_path, f"{idx}{extension}") + assert os.path.exists(mask_path), f"Mask path {mask_path} does not exist" + mask = load_mask(mask_path) + masks.append(mask) + return masks + + +def display_image(image, masks=None): + def imshow(image, ax): + ax.axis("off") + ax.imshow(image) + + grid = (1, 1) if masks is None else (2, 2) + fig, axes = plt.subplots(*grid) + if masks is not None: + mask_colors = sns.color_palette("husl", len(masks)) + black_image = np.zeros_like(image[..., :3], dtype=float) # background + mask_display = np.copy(black_image) + mask_union = np.zeros_like(image[..., :3]) + for i, mask in enumerate(masks): + mask_display[mask] = mask_colors[i] + mask_union |= mask[..., None] if mask.ndim == 2 else mask + imshow(black_image, axes[0, 1]) + imshow(mask_display, axes[1, 0]) + imshow(image * mask_union, axes[1, 1]) + + image_axe = axes if masks is None else axes[0, 0] + imshow(image, image_axe) + + fig.tight_layout(pad=0) + fig.show() + + +def interactive_visualizer(ply_path): + with gr.Blocks() as demo: + gr.Markdown("# 3D Gaussian Splatting (black-screen loading might take a while)") + gr.Model3D( + value=ply_path, # splat file + label="3D Scene", + ) + demo.launch(share=True) diff --git a/thirdparty/sam3d/sam3d/notebook/mesh_alignment.py b/thirdparty/sam3d/sam3d/notebook/mesh_alignment.py new file mode 100644 index 0000000000000000000000000000000000000000..d211e39ab17670d8a2cf0e0760a401f723c8f506 --- /dev/null +++ b/thirdparty/sam3d/sam3d/notebook/mesh_alignment.py @@ -0,0 +1,469 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +""" +SAM 3D Body (3DB) Mesh Alignment Utilities +Handles alignment of 3DB meshes to SAM 3D Object, same as MoGe point cloud scale. +""" + +import os +import math +import json +import numpy as np +import torch +import trimesh +from PIL import Image +import torch.nn.functional as F +from pytorch3d.structures import Meshes +from pytorch3d.renderer import PerspectiveCameras, RasterizationSettings, MeshRasterizer, TexturesVertex +from moge.model.v1 import MoGeModel + + +def load_3db_mesh(mesh_path, device='cuda'): + """Load 3DB mesh and convert from OpenGL to PyTorch3D coordinates.""" + mesh = trimesh.load(mesh_path) + vertices = np.array(mesh.vertices) + faces = np.array(mesh.faces) + + # Convert from OpenGL to PyTorch3D coordinates + vertices[:, 0] *= -1 # Flip X + vertices[:, 2] *= -1 # Flip Z + + vertices = torch.from_numpy(vertices).float().to(device) + faces = torch.from_numpy(faces).long().to(device) + return vertices, faces + + +def get_moge_pointcloud(image_tensor, device='cuda'): + """Generate MoGe point cloud from image tensor.""" + moge_model = MoGeModel.from_pretrained("Ruicheng/moge-vitl").to(device) + moge_model.eval() + with torch.no_grad(): + moge_output = moge_model.infer(image_tensor) + return moge_output + + +def denormalize_intrinsics(norm_K, height, width): + """Convert normalized intrinsics to absolute pixel coordinates.""" + cx_norm, cy_norm = norm_K[0, 2], norm_K[1, 2] + fx_norm, fy_norm = norm_K[0, 0], norm_K[1, 1] + + fx_abs = fx_norm * width + fy_abs = fy_norm * height + cx_abs = cx_norm * width + cy_abs = cy_norm * height + fx_abs = fy_abs + + return np.array([ + [fx_abs, 0.0, cx_abs], + [0.0, fy_abs, cy_abs], + [0.0, 0.0, 1.0] + ]) + + +def crop_mesh_with_mask(vertices, faces, focal_length, mask, device='cuda'): + """Crop mesh vertices to only those visible in the mask.""" + textures = TexturesVertex(verts_features=torch.ones_like(vertices)[None]) + mesh = Meshes(verts=[vertices], faces=[faces], textures=textures) + + H, W = mask.shape[-2:] + fx = fy = focal_length + cx, cy = W / 2.0, H / 2.0 + + camera = PerspectiveCameras( + focal_length=((fx, fy),), + principal_point=((cx, cy),), + image_size=((H, W),), + in_ndc=False, device=device + ) + + raster_settings = RasterizationSettings( + image_size=(H, W), blur_radius=0.0, faces_per_pixel=1, + cull_backfaces=False, bin_size=0, + ) + + rasterizer = MeshRasterizer(cameras=camera, raster_settings=raster_settings) + fragments = rasterizer(mesh) + + face_indices = fragments.pix_to_face[0, ..., 0] # (H, W) + visible_mask = (mask > 0) & (face_indices >= 0) + visible_face_ids = face_indices[visible_mask] + + visible_faces = faces[visible_face_ids] + visible_vert_ids = torch.unique(visible_faces) + verts_cropped = vertices[visible_vert_ids] + + return verts_cropped, visible_mask + + +def extract_target_points(pointmap, visible_mask): + """Extract target points from MoGe pointmap using visible mask.""" + target_points = pointmap[visible_mask.bool()] + + # Convert from MoGe coordinates to PyTorch3D coordinates + target_points[:, 0] *= -1 + target_points[:, 1] *= -1 + + # Remove flying points using adaptive quantile filtering + z_range = torch.max(target_points[:, 2]) - torch.min(target_points[:, 2]) + if z_range > 6.0: + thresh = 0.90 + elif z_range > 2.0: + thresh = 0.93 + else: + thresh = 0.95 + + depth_quantile = torch.quantile(target_points[:, 2], thresh) + target_points = target_points[target_points[:, 2] <= depth_quantile] + + # Remove infinite values + finite_mask = torch.isfinite(target_points).all(dim=1) + target_points = target_points[finite_mask] + + return target_points + + +def align_mesh_to_pointcloud(vertices, target_points): + """Align mesh vertices to target point cloud using scale and translation.""" + if target_points.shape[0] == 0: + print("[WARNING] No target points for alignment!") + return vertices, torch.tensor(1.0), torch.zeros(3) + + # Scale alignment based on height + height_src = torch.max(vertices[:, 1]) - torch.min(vertices[:, 1]) + height_tgt = torch.max(target_points[:, 1]) - torch.min(target_points[:, 1]) + scale_factor = height_tgt / height_src + + vertices_scaled = vertices * scale_factor + + # Translation alignment based on centers + center_src = torch.mean(vertices_scaled, dim=0) + center_tgt = torch.mean(target_points, dim=0) + translation = center_tgt - center_src + + vertices_aligned = vertices_scaled + translation + return vertices_aligned, scale_factor, translation + + +def load_mask_for_alignment(mask_path): + """Load mask image as numpy array.""" + mask = Image.open(mask_path).convert('L') + mask_array = np.array(mask) / 255.0 + return mask_array + + +def load_focal_length_from_json(json_path): + """Load focal length from JSON file.""" + try: + with open(json_path, 'r') as f: + data = json.load(f) + focal_length = data.get('focal_length') + if focal_length is None: + raise ValueError("'focal_length' key not found in JSON file") + print(f"[INFO] Loaded focal length from {json_path}: {focal_length}") + return focal_length + except Exception as e: + print(f"[ERROR] Failed to load focal length from {json_path}: {e}") + raise + + +def process_3db_alignment(mesh_path, mask_path, image_path, device='cuda', focal_length_json_path=None): + """Complete pipeline for aligning 3DB mesh to MoGe scale.""" + print(f"[INFO] Processing alignment...") + + # Load input data + vertices, faces = load_3db_mesh(mesh_path, device) + + # Load and preprocess image + image = Image.open(image_path).convert('RGB') + image_tensor = torch.from_numpy(np.array(image)).float().permute(2, 0, 1) / 255.0 + image_tensor = image_tensor.to(device) + + # Load mask and resize to match image + H, W = image_tensor.shape[1:] + mask = load_mask_for_alignment(mask_path) + if mask.shape != (H, W): + mask = Image.fromarray((mask * 255).astype(np.uint8)) + mask = mask.resize((W, H), Image.NEAREST) + mask = np.array(mask) / 255.0 + mask = torch.from_numpy(mask).float().to(device) + + # Generate MoGe point cloud + print("[INFO] Generating MoGe point cloud...") + moge_output = get_moge_pointcloud(image_tensor, device) + + # Load focal length from JSON if provided, otherwise compute from MoGe intrinsics + if focal_length_json_path is not None: + focal_length = load_focal_length_from_json(focal_length_json_path) + else: + # Compute camera parameters from MoGe intrinsics (fallback) + intrinsics = denormalize_intrinsics(moge_output['intrinsics'].cpu().numpy(), H, W) + focal_length = intrinsics[1, 1] # Use fy + print(f"[INFO] Using computed focal length from MoGe: {focal_length}") + + # Crop mesh using mask + print("[INFO] Cropping mesh with mask...") + verts_cropped, visible_mask = crop_mesh_with_mask(vertices, faces, focal_length, mask, device) + + # Extract target points from MoGe + print("[INFO] Extracting target points...") + target_points = extract_target_points(moge_output['points'], visible_mask) + + if target_points.shape[0] == 0: + print("[ERROR] No valid target points found!") + return None + + # Perform alignment + print("[INFO] Aligning mesh to point cloud...") + aligned_vertices, scale_factor, translation = align_mesh_to_pointcloud(verts_cropped, target_points) + + # Apply alignment to full mesh + full_aligned_vertices = (vertices * scale_factor) + translation + + # Convert back to OpenGL coordinates for final output + final_vertices_opengl = full_aligned_vertices.cpu().numpy() + final_vertices_opengl[:, 0] *= -1 + final_vertices_opengl[:, 2] *= -1 + + results = { + 'aligned_vertices_opengl': final_vertices_opengl, + 'faces': faces.cpu().numpy(), + 'scale_factor': scale_factor.item(), + 'translation': translation.cpu().numpy(), + 'focal_length': focal_length, + 'target_points_count': target_points.shape[0], + 'cropped_vertices_count': verts_cropped.shape[0] + } + + print(f"[INFO] Alignment completed - Scale: {scale_factor.item():.4f}, Target points: {target_points.shape[0]}") + return results + + +def process_and_save_alignment(mesh_path, mask_path, image_path, output_dir, device='cuda', focal_length_json_path=None): + """ + Complete pipeline for processing 3DB alignment and saving the result. + + Args: + mesh_path: Path to input 3DB mesh (.ply) + mask_path: Path to mask image (.png) + image_path: Path to input image (.jpg) + output_dir: Directory to save aligned mesh + device: Device to use ('cuda' or 'cpu') + focal_length_json_path: Optional path to focal length JSON file + + Returns: + tuple: (success: bool, output_mesh_path: str or None, result_info: dict or None) + """ + try: + print("[INFO] Starting 3DB mesh alignment pipeline...") + + # Ensure output directory exists + os.makedirs(output_dir, exist_ok=True) + + # Process alignment + result = process_3db_alignment( + mesh_path=mesh_path, + mask_path=mask_path, + image_path=image_path, + device=device, + focal_length_json_path=focal_length_json_path + ) + + if result is not None: + # Save aligned mesh + output_mesh_path = os.path.join(output_dir, 'human_aligned.ply') + aligned_mesh = trimesh.Trimesh( + vertices=result['aligned_vertices_opengl'], + faces=result['faces'] + ) + aligned_mesh.export(output_mesh_path) + + print(f" SUCCESS! Saved aligned mesh to: {output_mesh_path}") + return True, output_mesh_path, result + else: + print(" ERROR: Failed to process mesh alignment") + return False, None, None + + except Exception as e: + print(f" ERROR: Exception during processing: {e}") + import traceback + traceback.print_exc() + return False, None, None + + finally: + print(" Processing complete!") + + +def visualize_meshes_interactive(aligned_mesh_path, dfy_mesh_path, output_dir=None, share=True, height=600): + """ + Interactive Gradio-based 3D visualization of aligned human and object meshes. + + Args: + aligned_mesh_path: Path to aligned mesh PLY file + dfy_mesh_path: Path to 3Dfy GLB file + output_dir: Directory to save combined GLB file (defaults to same dir as aligned_mesh_path) + share: Whether to create a public shareable link (default: True) + height: Height of the 3D viewer in pixels (default: 600) + + Returns: + tuple: (demo, combined_glb_path) - Gradio demo object and path to combined GLB file + """ + import gradio as gr + + print("Loading meshes for interactive visualization...") + + try: + # Load aligned mesh (PLY) + aligned_mesh = trimesh.load(aligned_mesh_path) + print(f"Loaded aligned mesh: {len(aligned_mesh.vertices)} vertices") + + # Load 3Dfy mesh (GLB - handle scene structure) + dfy_scene = trimesh.load(dfy_mesh_path) + + if hasattr(dfy_scene, 'dump'): # It's a scene + dfy_meshes = [geom for geom in dfy_scene.geometry.values() if hasattr(geom, 'vertices')] + if len(dfy_meshes) == 1: + dfy_mesh = dfy_meshes[0] + elif len(dfy_meshes) > 1: + dfy_mesh = trimesh.util.concatenate(dfy_meshes) + else: + raise ValueError("No valid meshes in GLB file") + else: + dfy_mesh = dfy_scene + + print(f"Loaded 3Dfy mesh: {len(dfy_mesh.vertices)} vertices") + + # Create combined scene + scene = trimesh.Scene() + + # Add both meshes with different colors + aligned_copy = aligned_mesh.copy() + aligned_copy.visual.vertex_colors = [255, 0, 0, 200] # Red for aligned human + scene.add_geometry(aligned_copy, node_name="sam3d_aligned_human") + + dfy_copy = dfy_mesh.copy() + dfy_copy.visual.vertex_colors = [0, 0, 255, 200] # Blue for 3Dfy object + scene.add_geometry(dfy_copy, node_name="dfy_object") + + # Determine output path + if output_dir is None: + output_dir = os.path.dirname(aligned_mesh_path) + os.makedirs(output_dir, exist_ok=True) + + combined_glb_path = os.path.join(output_dir, 'combined_scene.glb') + scene.export(combined_glb_path) + print(f"Exported combined scene to: {combined_glb_path}") + + # Create interactive Gradio viewer + with gr.Blocks() as demo: + gr.Markdown("# 3D Mesh Alignment Visualization") + gr.Markdown("**Red**: SAM 3D Body Aligned Human | **Blue**: 3Dfy Object") + gr.Model3D( + value=combined_glb_path, + label="Combined 3D Scene (Interactive)", + height=height + ) + + # Launch the viewer + print("Launching interactive 3D viewer...") + demo.launch(share=share) + + return demo, combined_glb_path + + except Exception as e: + print(f"ERROR in visualization: {e}") + import traceback + traceback.print_exc() + return None, None + + +def visualize_meshes_comparison(aligned_mesh_path, dfy_mesh_path, use_interactive=False): + """ + Simple visualization of both meshes in a single 3D plot. + + DEPRECATED: Use visualize_meshes_interactive() for better interactive visualization. + + Args: + aligned_mesh_path: Path to aligned mesh PLY file + dfy_mesh_path: Path to 3Dfy GLB file + use_interactive: Whether to attempt trimesh scene viewer (default: False) + + Returns: + tuple: (aligned_mesh, dfy_mesh) trimesh objects or (None, None) if failed + """ + import matplotlib.pyplot as plt + + print("Loading meshes for visualization...") + + try: + # Load aligned mesh (PLY) + aligned_mesh = trimesh.load(aligned_mesh_path) + print(f"Loaded aligned mesh: {len(aligned_mesh.vertices)} vertices") + + # Load 3Dfy mesh (GLB - handle scene structure) + dfy_scene = trimesh.load(dfy_mesh_path) + + if hasattr(dfy_scene, 'dump'): # It's a scene + dfy_meshes = [geom for geom in dfy_scene.geometry.values() if hasattr(geom, 'vertices')] + if len(dfy_meshes) == 1: + dfy_mesh = dfy_meshes[0] + elif len(dfy_meshes) > 1: + dfy_mesh = trimesh.util.concatenate(dfy_meshes) + else: + raise ValueError("No valid meshes in GLB file") + else: + dfy_mesh = dfy_scene + + print(f"Loaded 3Dfy mesh: {len(dfy_mesh.vertices)} vertices") + + # Create single 3D plot with both meshes + fig = plt.figure(figsize=(12, 10)) + ax = fig.add_subplot(111, projection='3d') + + # Plot both meshes in the same space + ax.scatter(dfy_mesh.vertices[:, 0], + dfy_mesh.vertices[:, 1], + dfy_mesh.vertices[:, 2], + c='blue', s=0.1, alpha=0.6, label='3Dfy Original') + + ax.scatter(aligned_mesh.vertices[:, 0], + aligned_mesh.vertices[:, 1], + aligned_mesh.vertices[:, 2], + c='red', s=0.1, alpha=0.6, label='SAM 3D Body Aligned') + + ax.set_title('Mesh Comparison: 3Dfy vs SAM 3D Body Aligned', fontsize=16, fontweight='bold') + ax.set_xlabel('X') + ax.set_ylabel('Y') + ax.set_zlabel('Z') + ax.legend() + + plt.tight_layout() + plt.show() + + # Optional trimesh scene viewer + if use_interactive: + try: + print("Creating trimesh scene...") + scene = trimesh.Scene() + + # Add both meshes with different colors + aligned_copy = aligned_mesh.copy() + aligned_copy.visual.vertex_colors = [255, 0, 0, 200] # Red + scene.add_geometry(aligned_copy, node_name="sam3d_aligned") + + dfy_copy = dfy_mesh.copy() + dfy_copy.visual.vertex_colors = [0, 0, 255, 200] # Blue + scene.add_geometry(dfy_copy, node_name="dfy_original") + + print("Opening interactive trimesh viewer...") + scene.show() + + except Exception as e: + print(f"Trimesh viewer not available: {e}") + + print("Visualization complete") + return aligned_mesh, dfy_mesh + + except Exception as e: + print(f"ERROR in visualization: {e}") + import traceback + traceback.print_exc() + return None, None \ No newline at end of file diff --git a/thirdparty/sam3d/sam3d/patching/hydra b/thirdparty/sam3d/sam3d/patching/hydra new file mode 100755 index 0000000000000000000000000000000000000000..bb9c4c2669132141d23cae6632ced41813b55e45 --- /dev/null +++ b/thirdparty/sam3d/sam3d/patching/hydra @@ -0,0 +1,16 @@ +#!/usr/bin/env python + +import os +import hydra +import urllib.request + +if hydra.__version__ != "1.3.2": + raise RuntimeError("different hydra version has been found, cannot patch") + +hydra_root = os.path.dirname(hydra.__file__) +utils_path = os.path.join(hydra_root, "core", "utils.py") + +urllib.request.urlretrieve( + "https://raw.githubusercontent.com/gleize/hydra/78f00766b5f37672aa7232ebbf01bdd74246bd60/hydra/core/utils.py", + utils_path, +) diff --git a/thirdparty/sam3d/sam3d/pyproject.toml b/thirdparty/sam3d/sam3d/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..f9421022fdf45615053f421a496dde0892a518cb --- /dev/null +++ b/thirdparty/sam3d/sam3d/pyproject.toml @@ -0,0 +1,30 @@ +[build-system] +requires = ["hatchling", "hatch-requirements-txt"] +build-backend = "hatchling.build" + +[tool.hatch.envs.default.env-vars] +PIP_EXTRA_INDEX_URL = "https://pypi.ngc.nvidia.com https://download.pytorch.org/whl/cu121" + +[tool.hatch.metadata] +# for git-referenced dependencies +allow-direct-references = true + +[project] +name = "sam3d_objects" +version = "0.0.1" +# required for "hatch-requirements-txt" to work +dynamic = ["dependencies", "optional-dependencies"] + +[tool.hatch.build] +ignore-vcs = true +include = ["**/*.py"] +exclude = ["conftest.py", "*_test.py"] +packages = ["sam3d_objects"] + +[tool.hatch.metadata.hooks.requirements_txt] +files = ["requirements.txt"] + +[tool.hatch.metadata.hooks.requirements_txt.optional-dependencies] +p3d = ["requirements.p3d.txt"] +inference = ["requirements.inference.txt"] +dev = ["requirements.dev.txt"] diff --git a/thirdparty/sam3d/sam3d/requirements.dev.txt b/thirdparty/sam3d/sam3d/requirements.dev.txt new file mode 100644 index 0000000000000000000000000000000000000000..b00f80e73db54fe890a92f0adc12d8b9382a2398 --- /dev/null +++ b/thirdparty/sam3d/sam3d/requirements.dev.txt @@ -0,0 +1,4 @@ +pytest +findpydeps +pipdeptree +lovely_tensors diff --git a/thirdparty/sam3d/sam3d/requirements.inference.txt b/thirdparty/sam3d/sam3d/requirements.inference.txt new file mode 100644 index 0000000000000000000000000000000000000000..ca5784151bc3b242fc380c05c4e8758444bd17a5 --- /dev/null +++ b/thirdparty/sam3d/sam3d/requirements.inference.txt @@ -0,0 +1,4 @@ +kaolin==0.17.0 +gsplat @ git+https://github.com/nerfstudio-project/gsplat.git@2323de5905d5e90e035f792fe65bad0fedd413e7 +seaborn==0.13.2 +gradio==5.49.0 \ No newline at end of file diff --git a/thirdparty/sam3d/sam3d/requirements.p3d.txt b/thirdparty/sam3d/sam3d/requirements.p3d.txt new file mode 100644 index 0000000000000000000000000000000000000000..e91c068ade19578ba31d688ca6786aff59c8490c --- /dev/null +++ b/thirdparty/sam3d/sam3d/requirements.p3d.txt @@ -0,0 +1,2 @@ +pytorch3d @ git+https://github.com/facebookresearch/pytorch3d.git@75ebeeaea0908c5527e7b1e305fbc7681382db47 +flash_attn==2.8.3 \ No newline at end of file diff --git a/thirdparty/sam3d/sam3d/requirements.txt b/thirdparty/sam3d/sam3d/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..9b4d16d98b5ca3932e662f47cf1904db254dcc9d --- /dev/null +++ b/thirdparty/sam3d/sam3d/requirements.txt @@ -0,0 +1,88 @@ +astor==0.8.1 +async-timeout==4.0.3 +auto_gptq==0.7.1 +autoflake==2.3.1 +av==12.0.0 +bitsandbytes==0.43.0 +black==24.3.0 +bpy==4.3.0 +colorama==0.4.6 +conda-pack==0.7.1 +crcmod==1.7 +cuda-python==12.1.0 +dataclasses==0.6 +decord==0.6.0 +deprecation==2.1.0 +easydict==1.13 +einops-exts==0.0.4 +exceptiongroup==1.2.0 +fastavro==1.9.4 +fasteners==0.19 +flake8==7.0.0 +Flask==3.0.3 +fqdn==1.5.1 +ftfy==6.2.0 +fvcore==0.1.5.post20221221 +gdown==5.2.0 +h5py==3.12.1 +hdfs==2.7.3 +httplib2==0.22.0 +hydra-core==1.3.2 +hydra-submitit-launcher==1.2.0 +igraph==0.11.8 +imath==0.0.2 +isoduration==20.11.0 +jsonlines==4.0.0 +jsonpickle==3.0.4 +jsonpointer==2.4 +jupyter==1.1.1 +librosa==0.10.1 +lightning==2.3.3 +loguru==0.7.2 +mosaicml-streaming==0.7.5 +nvidia-cuda-nvcc-cu12==12.1.105 +nvidia-pyindex==1.0.9 +objsize==0.7.0 +open3d==0.18.0 +opencv-python==4.9.0.80 +OpenEXR==3.3.3 +optimum==1.18.1 +optree==0.14.1 +orjson==3.10.0 +panda3d-gltf==1.2.1 +pdoc3==0.10.0 +peft==0.10.0 +pip-system-certs==4.0 +point-cloud-utils==0.29.5 +polyscope==2.3.0 +pycocotools==2.0.7 +pydot==1.4.2 +pymeshfix==0.17.0 +pymongo==4.6.3 +pyrender==0.1.45 +PySocks==1.7.1 +pytest==8.1.1 +python-pycg==0.9.2 +randomname==0.2.1 +roma==1.5.1 +rootutils==1.0.7 +Rtree==1.3.0 +sagemaker==2.242.0 +scikit-image==0.23.1 +sentence-transformers==2.6.1 +simplejson==3.19.2 +smplx==0.1.28 +spconv-cu121==2.3.8 +tensorboard==2.16.2 +timm==0.9.16 +tomli==2.0.1 +torchaudio==2.5.1+cu121 +uri-template==1.3.0 +usort==1.0.8.post1 +wandb==0.20.0 +webcolors==1.13 +webdataset==0.2.86 +Werkzeug==3.0.6 +xatlas==0.0.9 +xformers==0.0.28.post3 +MoGe @ git+https://github.com/microsoft/MoGe.git@a8c37341bc0325ca99b9d57981cc3bb2bd3e255b diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/__init__.py b/thirdparty/sam3d/sam3d/sam3d_objects/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a700544f5c8388939eff112e419a851dbabf7671 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import os + +# Allow skipping initialization for lightweight tools +if not os.environ.get('LIDRA_SKIP_INIT'): + import sam3d_objects.init diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/config/__init__.py b/thirdparty/sam3d/sam3d/sam3d_objects/config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..71ca4b12c770afea62d06f97064cdf0c97d40ed7 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/config/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/config/utils.py b/thirdparty/sam3d/sam3d/sam3d_objects/config/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6635d978e6d1c3a76ee67932e5b905c99d5575a0 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/config/utils.py @@ -0,0 +1,174 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import functools +from typing import Any, Callable, Union + +from omegaconf import DictConfig, ListConfig, OmegaConf +from hydra.utils import instantiate + +TargetType = Union[str, type, Callable[..., Any]] +ClassOrCallableType = Union[type, Callable[..., Any]] + + +def dump_config(config: DictConfig, path: str = "./config.yaml"): + txt = OmegaConf.to_yaml(config, sort_keys=True) + with open(path, "w") as f: + f.write(txt) + + +def locate(path: str) -> Any: + if path == "": + raise ImportError("Empty path") + + import builtins + from importlib import import_module + + parts = [part for part in path.split(".") if part] + + # load module part + module = None + for n in reversed(range(len(parts))): + try: + mod = ".".join(parts[:n]) + module = import_module(mod) + except Exception as e: + if n == 0: + raise ImportError(f"Error loading module '{path}'") from e + continue + if module: + break + + if module: + obj = module + else: + obj = builtins + + # load object path in module + for part in parts[n:]: + mod = mod + "." + part + if not hasattr(obj, part): + try: + import_module(mod) + except Exception as e: + raise ImportError( + f"Encountered error: `{e}` when loading module '{path}'" + ) from e + obj = getattr(obj, part) + + return obj + + +def full_instance_name(instance: Any) -> str: + return full_class_name(instance.__class__) + + +def full_class_name(klass: Any) -> str: + module = klass.__module__ + if module == "builtins": + return klass.__qualname__ # avoid outputs like 'builtins.str' + return module + "." + klass.__qualname__ + + +def ensure_is_subclass(child_class: type, parent_class: type) -> None: + if not issubclass(child_class, parent_class): + raise RuntimeError( + f"class {full_class_name(child_class)} should be a subclass of {full_class_name(parent_class)}" + ) + + +def find_class_or_callable_from_target( + target: TargetType, +) -> ClassOrCallableType: + if isinstance(target, str): + obj = locate(target) + else: + obj = target + + if (not isinstance(obj, type)) and (not callable(obj)): + raise ValueError(f"Invalid type ({type(obj)}) found for {target}") + + return obj + + +def find_and_ensure_is_subclass(target: TargetType, type_: type) -> ClassOrCallableType: + klass = find_class_or_callable_from_target(target) + ensure_is_subclass(klass, type_) + return klass + + +class StrictPartial: + # remark : the `/` will handle the `path` argument name conflict (e.g. calling StrictPartial("a.b.c", ..., path="/a/b/c")) + def __init__(self, path, /, *args, **kwargs): + class_or_callable = find_class_or_callable_from_target(path) + self._partial = functools.partial(class_or_callable, *args, **kwargs) + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + return self._partial(*args, **kwargs) + + +class RecursivePartial: + @staticmethod + def replace_keys(config, key_mapping): + def recurse(data): + if isinstance(data, DictConfig): + new_data = { + key_mapping[k] if k in key_mapping else k: recurse(v) + for k, v in data.items() + } + new_data = DictConfig(new_data) + elif isinstance(data, ListConfig): + new_data = ListConfig([recurse(item) for item in data]) + elif type(data) in {bool, str, int, float, type(None)}: + new_data = data + else: + raise RuntimeError(f"unknow type found : {type(data)}") + + return new_data + + return recurse(config) + + def __init__(self, config): + self.config = RecursivePartial.replace_keys( + config, {"_rpartial_target_": "_target_"} + ) + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + return instantiate(self.config) + + +class Partial(StrictPartial): + # remark : allow `path` argument to be exposed for easier use + def __init__(self, path, *args, **kwargs): + super().__init__(path, *args, **kwargs) + + +def subkey(mapping, key): + return mapping[key] + + +def make_set(*args): + return set(args) + + +def make_tuple(*args): + return tuple(args) + + +def make_list_from_kwargs(**kwargs): + # Filter out None/null values to avoid issues with callbacks + return [v for v in kwargs.values() if v is not None] + + +def make_string(value): + return str(value) + + +def make_dict(**kwargs): + return dict(kwargs) + + +def get_item(data, key: str): + return data[key] + + +def get_attr(data, key: str): + return getattr(data, key) diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/data/__init__.py b/thirdparty/sam3d/sam3d/sam3d_objects/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..71ca4b12c770afea62d06f97064cdf0c97d40ed7 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/data/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/data/dataset/__init__.py b/thirdparty/sam3d/sam3d/sam3d_objects/data/dataset/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..71ca4b12c770afea62d06f97064cdf0c97d40ed7 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/data/dataset/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/data/dataset/tdfy/__init__.py b/thirdparty/sam3d/sam3d/sam3d_objects/data/dataset/tdfy/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..71ca4b12c770afea62d06f97064cdf0c97d40ed7 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/data/dataset/tdfy/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/data/dataset/tdfy/img_and_mask_transforms.py b/thirdparty/sam3d/sam3d/sam3d_objects/data/dataset/tdfy/img_and_mask_transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..91b8b1290de0121a0d3e268b0948296a23143800 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/data/dataset/tdfy/img_and_mask_transforms.py @@ -0,0 +1,986 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +from collections import namedtuple +import random +from typing import Optional, Dict + +import numpy as np +import matplotlib.pyplot as plt +import torchvision.transforms.functional +from sam3d_objects.data.dataset.tdfy.img_processing import pad_to_square_centered +from sam3d_objects.model.backbone.dit.embedder.point_remapper import PointRemapper +from typing import Optional, Dict +from loguru import logger +import torch +import torch.nn.functional as F +import torchvision +import torchvision.transforms as tv_transforms +import torchvision.transforms.functional +import torchvision.transforms.functional as TF + +from sam3d_objects.data.dataset.tdfy.img_processing import pad_to_square_centered + + +def UNNORMALIZE(mean, std): + mean = torch.tensor(mean).reshape((3, 1, 1)) + std = torch.tensor(std).reshape((3, 1, 1)) + + def unnormalize_img(img): + assert img.ndim == 3 and img.shape[0] == 3 + + return img * std.to(img.device) + mean.to(img.device) + + return unnormalize_img + + +IMAGENET_MEAN = (0.485, 0.456, 0.406) +IMAGENET_STD = (0.229, 0.224, 0.225) + + +IMAGENET_NORMALIZATION = tv_transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD) +IMAGENET_UNNORMALIZATION = UNNORMALIZE(IMAGENET_MEAN, IMAGENET_STD) + + +class BoundingBoxError(Exception): + pass + + +def check_bounding_box(bbox_w, bbox_h): + if bbox_w < 2 or bbox_h < 2: + raise BoundingBoxError("Bounding box dimensions must be at least 2x2.") + + +class RGBAImageProcessor: + def __init__( + self, + resize_and_make_square_kwargs: Optional[Dict] = None, + object_crop_kwargs: Optional[Dict] = None, + remove_background: bool = False, + imagenet_normalization: bool = False, + ): + self.remove_background = remove_background + self.resize_and_pad_kwargs = resize_and_make_square_kwargs + self.object_crop_kwargs = object_crop_kwargs + self.imagenet_normalization = imagenet_normalization + if resize_and_make_square_kwargs is not None: + self.transforms = resize_and_make_square(**resize_and_make_square_kwargs) + + def __call__( + self, image: torch.Tensor, mask: Optional[torch.Tensor] = None + ) -> tuple[torch.Tensor, torch.Tensor]: + if mask is None: + assert ( + image.shape[0] == 4 + ), f"Requires 4 channels (RGB + alpha), got {image.shape[0]=}" + image, mask = split_rgba(image) + else: + assert ( + image.shape[0] == 3 + ), f"Requires 3 channels (RGB), got {image.shape[0]=}" + assert mask.dim() == 2, f"Requires 2D mask, got {mask.dim()=}" + + if not self.object_crop_kwargs in [None, False]: + image, mask = crop_around_mask_with_padding( + image, mask, **self.object_crop_kwargs + ) + + if self.remove_background: + image, mask = rembg(image, mask) + + image = self.transforms["img_transform"](image) + mask = self.transforms["mask_transform"](mask.unsqueeze(0)) + + if self.imagenet_normalization: + image = IMAGENET_NORMALIZATION(image) + return image, mask + + +def load_rgb(fpath: str) -> torch.Tensor: + """ + Load a RGB(A) image from a file path. + """ + image = plt.imread(fpath) # Why use matplotlib? + if image.dtype == "uint8": + image = image / 255 + image = image.astype(np.float32) + image = torch.from_numpy(image) + image = image.permute(2, 0, 1).contiguous() + return image + + +def concat_rgba( + rgb_image: torch.Tensor, + mask: torch.Tensor, +) -> torch.Tensor: + """ + Create a 4-channel RGBA image from a 3-channel RGB image and a mask. + """ + assert rgb_image.dim() == 3, f"{rgb_image.shape=}" + assert mask.dim() == 2, f"{mask.shape=}" + assert rgb_image.shape[0] == 3, f"{rgb_image.shape[0]=}" + assert rgb_image.shape[1:] == mask.shape, f"{rgb_image.shape[1:]=} != {mask.shape=}" + return torch.cat((rgb_image, mask[None, ...]), dim=0) + + +def split_rgba(rgba_image: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """ + Split a 4-channel RGBA image into a 3-channel RGB image and a 1-channel mask. + + Args: + rgba_image: A 4-channel RGBA image. + + Returns: + A tuple of (rgb_image, mask). + """ + assert rgba_image.dim() == 3, f"{rgba_image.shape=}" + assert rgba_image.shape[0] == 4, f"{rgba_image.shape[0]=}" + return rgba_image[:3], rgba_image[3] + + +def get_mask( + rgb_image: torch.Tensor, + depth_image: torch.Tensor, + mask_source: str, +) -> torch.Tensor: + """ + Extract a mask from either the alpha channel of an RGB image or a depth image. + + Args: + rgb_image: Tensor of shape (B, C, H, W) or (C, H, W) where C >= 4 if using alpha channel + depth_image: Tensor of shape (B, 1, H, W) or (1, H, W) containing depth information + mask_source: Source of the mask, either "ALPHA_CHANNEL" or "DEPTH" + + Returns: + mask: Tensor of shape (B, 1, H, W) or (1, H, W) containing the extracted mask + """ + # Handle unbatched inputs (add batch dimension if needed) + is_batched = len(rgb_image.shape) == 4 + + if not is_batched: + rgb_image = rgb_image.unsqueeze(0) + if depth_image is not None: + depth_image = depth_image.unsqueeze(0) + + if mask_source == "ALPHA_CHANNEL": + if rgb_image.shape[1] != 4: + logger.warning(f"No ALPHA CHANNEL for the image, cannot read mask.") + mask = None + else: + mask = rgb_image[:, 3:4, :, :] + elif mask_source == "DEPTH": + mask = depth_image + else: + raise ValueError(f"Invalid mask source: {mask_source}") + + # Remove batch dimension if input was unbatched + if not is_batched: + mask = mask.squeeze(0) + + return mask + + +def rembg(image, mask, pointmap=None): + """ + Remove the background from an image using a mask. + For pointmaps, sets background regions to NaN. + + This function follows the standard transform pattern: + - If called with (image, mask), returns (image, mask) + - If called with (image, mask, pointmap), returns (image, mask, pointmap) + """ + masked_image = image * mask + + if pointmap is not None: + masked_pointmap = torch.where(mask > 0, pointmap, torch.nan) + return masked_image, mask, masked_pointmap + + return masked_image, mask + + +def resize_and_make_square( + img_size: int, + make_square: bool | str = False, +): + """ + Create image and mask transforms based on configuration. + + Returns: + dict: {"img_transform": img_transform, "mask_transform": mask_transform} + """ + if isinstance(make_square, str): + make_square = make_square.lower() + assert make_square in ["pad", "crop", False] + pre_resize_transform = tv_transforms.Lambda(lambda x: x) + post_resize_transform = tv_transforms.Lambda(lambda x: x) + if make_square == "pad": + pre_resize_transform = pad_to_square_centered + elif make_square == "crop": + post_resize_transform = tv_transforms.CenterCrop(img_size) + + img_resize = tv_transforms.Resize(img_size) + mask_resize = tv_transforms.Resize( + img_size, + interpolation=tv_transforms.InterpolationMode.BILINEAR, + ) + + img_transform = tv_transforms.Compose( + [ + pre_resize_transform, + img_resize, + post_resize_transform, + ] + ) + + mask_transform = tv_transforms.Compose( + [ + pre_resize_transform, + mask_resize, + post_resize_transform, + ] + ) + + return { + "img_transform": img_transform, + "mask_transform": mask_transform, + } + + +def crop_around_mask_with_random_box_size_factor( + loaded_image: torch.Tensor, + mask: torch.Tensor, + random_box_size_factor: float = 1.0, + pointmap: Optional[torch.Tensor] = None, +) -> np.ndarray: + return crop_around_mask_with_padding( + loaded_image, + mask, + box_size_factor=1.0 + random.uniform(0, 1) * random_box_size_factor, + padding_factor=0.0, + pointmap=pointmap, + ) + + +def crop_around_mask_with_padding( + loaded_image: torch.Tensor, + mask: torch.Tensor, + box_size_factor: float = 1.6, + padding_factor: float = 0.1, + pointmap: Optional[torch.Tensor] = None, +) -> np.ndarray: + # cast to ensure the function can be called normally + cast_mask = False + if mask.dim() == 3: + assert mask.shape[0] == 1, "cannot take mask with channel dimension not 1" + mask = mask[0] + cast_mask = True + loaded_image = concat_rgba(loaded_image, mask) + + bbox = compute_mask_bbox(mask, box_size_factor) + loaded_image = torchvision.transforms.functional.crop( + loaded_image, bbox[1], bbox[0], bbox[3] - bbox[1], bbox[2] - bbox[0] + ) + + # Crop pointmap if provided + if pointmap is not None: + pointmap = torchvision.transforms.functional.crop( + pointmap, bbox[1], bbox[0], bbox[3] - bbox[1], bbox[2] - bbox[0] + ) + + C, H, W = loaded_image.shape + max_dim = max(H, W) # Get the larger dimension + + # Step 1: Pad to square shape + pad_h = (max_dim - H) // 2 + pad_w = (max_dim - W) // 2 + pad_h_extra = (max_dim - H) - pad_h # To ensure even padding + pad_w_extra = (max_dim - W) - pad_w + + loaded_image = torch.nn.functional.pad( + loaded_image, (pad_w, pad_w_extra, pad_h, pad_h_extra), mode="constant", value=0 + ) + if pointmap is not None: + pointmap = torch.nn.functional.pad( + pointmap, + (pad_w, pad_w_extra, pad_h, pad_h_extra), + mode="constant", + value=float("nan"), + ) + + # Step 2: Extend by 10% on each side; idk but this seems to have better results overall + if padding_factor > 0: + extend_size = int(max_dim * padding_factor) # 10% extension on each side + loaded_image = torch.nn.functional.pad( + loaded_image, + (extend_size, extend_size, extend_size, extend_size), + mode="constant", + value=0, + ) + + if pointmap is not None: + pointmap = torch.nn.functional.pad( + pointmap, + (extend_size, extend_size, extend_size, extend_size), + mode="constant", + value=float("nan"), + ) + + rgb_image, mask = split_rgba(loaded_image) + if cast_mask: + mask = mask[None] + + if pointmap is not None: + return rgb_image, mask, pointmap + return rgb_image, mask + + +def compute_mask_bbox( + mask: torch.Tensor, box_size_factor: float = 1.0 +) -> tuple[float, float, float, float]: + """ + Compute a bounding box around a binary mask with optional size adjustment. + + Args: + mask: A 2D binary tensor where non-zero values represent the object of interest. + box_size_factor: Factor to scale the bounding box size. Values > 1.0 create a larger box. + Default is 1.0 (tight bounding box). + + Returns: + A tuple of (x1, y1, x2, y2) coordinates representing the bounding box, + where (x1, y1) is the top-left corner and (x2, y2) is the bottom-right corner. + + Raises: + ValueError: If mask is not a torch.Tensor or not a 2D tensor. + """ + if not isinstance(mask, torch.Tensor): + raise ValueError("Mask must be a torch.Tensor") + if not mask.dim() == 2: + raise ValueError("Mask must be a 2D tensor") + bbox_indices = torch.nonzero(mask) + if bbox_indices.numel() == 0: + # Handle empty mask case + return (0, 0, 0, 0) + + y_indices = bbox_indices[:, 0] + x_indices = bbox_indices[:, 1] + + min_x = torch.min(x_indices).item() + min_y = torch.min(y_indices).item() + max_x = torch.max(x_indices).item() + max_y = torch.max(y_indices).item() + + bbox = (min_x, min_y, max_x, max_y) + + center_x = (bbox[0] + bbox[2]) / 2 + center_y = (bbox[1] + bbox[3]) / 2 + + bbox_w, bbox_h = bbox[2] - bbox[0], bbox[3] - bbox[1] + + check_bounding_box(bbox_w, bbox_h) + + size = max(bbox_w, bbox_h, 2) + size = int(size * box_size_factor) + + bbox = ( + int(center_x - size // 2), + int(center_y - size // 2), + int(center_x + size // 2), + int(center_y + size // 2), + ) + # bbox = tuple(map(int, bbox)) + return bbox + + +def crop_and_pad(image, bbox): + """ + Crop an image using a bounding box and pad with zeros if out of bounds. + + Args: + image (torch.Tensor): CxHxW image. + bbox (tuple): (x1, y1, x2, y2) bounding box. + + Returns: + torch.Tensor: Cropped and zero-padded image. + """ + C, H, W = image.shape + x1, y1, x2, y2 = bbox + + # Ensure coordinates are integers + x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2) + + # Compute cropping coordinates + x1_pad, y1_pad = max(0, -x1), max(0, -y1) + x2_pad, y2_pad = max(0, x2 - W), max(0, y2 - H) + + # Compute valid region in the original image + x1_crop, y1_crop = max(0, x1), max(0, y1) + x2_crop, y2_crop = min(W, x2), min(H, y2) + + # Extract the valid part + cropped = image[:, y1_crop:y2_crop, x1_crop:x2_crop] + + # Create a zero-padded output + padded = torch.zeros((C, y2 - y1, x2 - x1), dtype=image.dtype) + + # Place the cropped image into the zero-padded array + padded[ + :, y1_pad : y1_pad + cropped.shape[1], x1_pad : x1_pad + cropped.shape[2] + ] = cropped + + return padded + + +def resize_all_to_same_size( + rgb_image: torch.Tensor, + mask: torch.Tensor, + pointmap: Optional[torch.Tensor] = None, + target_size: Optional[tuple[int, int]] = None, +) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """ + Resize RGB image, mask, and pointmap to the same size. + + This is crucial when pointmaps have different resolution than RGB images, + which must be done BEFORE any cropping operations. + + Args: + rgb_image: RGB image tensor of shape (C, H, W) + mask: Mask tensor of shape (H, W) or (1, H, W) + pointmap: Optional pointmap tensor of shape (C_p, H_p, W_p) + target_size: Target size as (H, W). If None, uses RGB image size. + + Returns: + Tuple of (resized_rgb, resized_mask, resized_pointmap) + """ + squeeze_mask = (mask.dim() == 2) + if squeeze_mask: + mask = mask.unsqueeze(0) + + if target_size is None: + target_size = (rgb_image.shape[1], rgb_image.shape[2]) # (H, W) + + rgb_needs_resize = (rgb_image.shape[1], rgb_image.shape[2]) != target_size + if rgb_needs_resize: + rgb_image = torchvision.transforms.functional.resize( + rgb_image, target_size, interpolation=torchvision.transforms.InterpolationMode.BILINEAR + ) + mask = torchvision.transforms.functional.resize( + mask, target_size, interpolation=torchvision.transforms.InterpolationMode.NEAREST + ) + + if pointmap is not None: + pointmap_size = (pointmap.shape[1], pointmap.shape[2]) + if pointmap_size != target_size: + # Handle NaN values in pointmap during resizing + # Direct resize would propagate NaN values, so we need special handling + nan_mask = torch.isnan(pointmap).any(dim=0) + pointmap_clean = torch.where(torch.isnan(pointmap), torch.zeros_like(pointmap), pointmap) + pointmap_resized = torchvision.transforms.functional.resize( + pointmap_clean, target_size, interpolation=torchvision.transforms.InterpolationMode.BILINEAR + ) + + # Resize the nan mask to identify which regions should remain invalid + nan_mask_resized = torchvision.transforms.functional.resize( + nan_mask.unsqueeze(0).float(), target_size, + interpolation=torchvision.transforms.InterpolationMode.NEAREST + ).squeeze(0) > 0.5 + + # Restore NaN values in regions that were originally invalid + pointmap = torch.where( + nan_mask_resized.unsqueeze(0).expand_as(pointmap_resized), + torch.full_like(pointmap_resized, float('nan')), + pointmap_resized + ) + + if squeeze_mask: + mask = mask.squeeze(0) + + if pointmap is not None: + return rgb_image, mask, pointmap + return rgb_image, mask + + +SSINormalizedPointmap = namedtuple("SSINormalizedPointmap", ["pointmap", "scale", "shift"]) +class SSIPointmapNormalizer: + + def normalize(self, pointmap: torch.Tensor, mask: torch.Tensor, + scale: Optional[torch.Tensor] = None, shift: Optional[torch.Tensor] = None, + ) -> SSINormalizedPointmap: + if scale is None or shift is None: + normalized_pointmap, scale, shift = normalize_pointmap_ssi(pointmap) + else: + assert scale.shape == (3,) and shift.shape == (3,), "scale and shift must be in (3,) format" + normalized_pointmap = _apply_metric_to_ssi(pointmap, scale, shift) + return SSINormalizedPointmap(normalized_pointmap, scale, shift) + + def denormalize(self, pointmap: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor) -> torch.Tensor: + pointmap = _apply_metric_to_ssi(pointmap, scale, shift, apply_inverse=True) + return pointmap + + + +class ObjectCentricSSI(SSIPointmapNormalizer): + def __init__(self, + use_scene_scale: bool = True, + quantile_drop_threshold: float = 0.1, + clip_beyond_scale: Optional[float] = None, + # scale_factor: float = 3.8076, # e^(1.337); empirical mean of R3+Artist train + scale_factor: float = 1.0, # e^(1.337); empirical mean of R3+Artist train + allow_scale_and_shift_override: bool = False, + raise_on_no_valid_points: bool = False, + ): + self.use_scene_scale = use_scene_scale + self.quantile_drop_threshold = quantile_drop_threshold + self.clip_beyond_scale = clip_beyond_scale + self.scale_factor = scale_factor + self.allow_scale_and_shift_override = allow_scale_and_shift_override + self.raise_on_no_valid_points = raise_on_no_valid_points + + def _compute_scale_and_shift(self, pointmap: torch.Tensor, mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + pointmap_size = (pointmap.shape[1], pointmap.shape[2]) + + + mask_resized = torchvision.transforms.functional.resize( + mask, pointmap_size, + interpolation=torchvision.transforms.InterpolationMode.NEAREST + ).squeeze(0) + + pointmap_flat = pointmap.reshape(3, -1) + # Get valid points from the mask + mask_bool = mask_resized.reshape(-1) > 0.5 + mask_points = pointmap_flat[:, mask_bool] + + if mask_points.isfinite().max() == 0: + if self.raise_on_no_valid_points: + raise ValueError(f"No valid points found in mask") + logger.warning(f"No valid points found in mask; setting scale to {self.scale_factor} and shift to 0") + return torch.ones_like(pointmap_flat[:,0]) * self.scale_factor, torch.zeros_like(pointmap_flat[:,0]) + + # Compute median for shift + shift = mask_points.nanmedian(dim=-1).values + # logger.info(f"{pointmap.shape=} {mask_resized.shape=} {shift.shape=}") + + + if self.use_scene_scale == True: + # Normalize by the scene scale + points_centered = pointmap_flat - shift.unsqueeze(-1) + max_dims = points_centered.abs().max(dim=0).values + scale = max_dims.nanmedian(dim=-1).values + elif self.use_scene_scale == False: + # Normalize by the object scale + shifted_mask_points = mask_points - shift.unsqueeze(-1) + norm = shifted_mask_points.norm(dim=0) + quantiles = torch.nanquantile(norm, + torch.tensor([self.quantile_drop_threshold, 1. - self.quantile_drop_threshold], + device=shifted_mask_points.device), + dim=-1) + scale = (quantiles[1] - quantiles[0]).max(dim=-1).values * 2.0 + elif self.use_scene_scale.upper() == "OBJECT_NORM_MEDIAN": + # Normalize by the object scale + shifted_mask_points = mask_points - shift.unsqueeze(-1) + norm = shifted_mask_points.norm(dim=0) + scale = norm.nanmedian(dim=-1).values + else: + raise ValueError(f"Invalid use_scene_scale: {self.use_scene_scale}") + scale = scale.expand_as(shift) # per-dim scaling + scale = scale * self.scale_factor + return scale, shift + + def normalize(self, pointmap: torch.Tensor, mask: torch.Tensor, + scale: Optional[torch.Tensor] = None, shift: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # 1. resize mask to size of pointmap using nearest interpolation + # 2. get mask points: pointmap[mask > 0.5] + # 3. shift = mask_points.median() # xyz + # 4. scale = # filter. If no points, then + # logger.info(f"{pointmap.shape=} {mask.shape=}") + assert pointmap.shape[0] == 3, "pointmap must be in (3, H, W) format" + pointmap_size = (pointmap.shape[1], pointmap.shape[2]) + + _scale, _shift = self._compute_scale_and_shift(pointmap, mask) + if scale is not None and self.allow_scale_and_shift_override: + _scale = scale + if shift is not None and self.allow_scale_and_shift_override: + _shift = shift + return_scale, return_shift = _scale, _shift + + # Apply normalization + pointmap_normalized = _apply_metric_to_ssi(pointmap, return_scale, return_shift) + + if self.clip_beyond_scale is not None and self.clip_beyond_scale > 0: + new_norm = pointmap_normalized.norm(dim=0) + pointmap_normalized = torch.where( + new_norm > self.clip_beyond_scale, + torch.full_like(pointmap_normalized, float('nan')), + pointmap_normalized + ) + + return SSINormalizedPointmap(pointmap_normalized, return_scale, return_shift) + + +class ObjectApparentSizeSSI(SSIPointmapNormalizer): + def __init__(self, + clip_beyond_scale: Optional[float] = None, + use_scene_scale: bool = True, + scale_factor: float = 1.0, # e^(1.337); empirical mean of R3+Artist train + ): + self.clip_beyond_scale = clip_beyond_scale + self.use_scene_scale = use_scene_scale + self.scale_factor = scale_factor + + def _get_scale_and_shift(self, pointmap: torch.Tensor, mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + pointmap_size = (pointmap.shape[1], pointmap.shape[2]) + pointmap_flat = pointmap.reshape(3, -1) + + if not self.use_scene_scale: + # Get valid points from the mask + mask_resized = torchvision.transforms.functional.resize( + mask, pointmap_size, + interpolation=torchvision.transforms.InterpolationMode.NEAREST + ).squeeze(0) + mask_bool = mask_resized.reshape(-1) > 0.5 + pointmap_flat = pointmap_flat[:, mask_bool] + + # Median z-distance + median_z = pointmap_flat[-1, ...].nanmedian().unsqueeze(0) + scale = median_z.expand(3) * self.scale_factor + shift = torch.zeros_like(scale) + # logger.info(f'median z = {median_z}') + return scale, shift + + def normalize(self, + pointmap: torch.Tensor, + mask: torch.Tensor, + scale: Optional[torch.Tensor] = None, + shift: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + assert pointmap.shape[0] == 3, "pointmap must be in (3, H, W) format" + pointmap_size = (pointmap.shape[1], pointmap.shape[2]) + + if scale is None or shift is None: + scale, shift = self._get_scale_and_shift(pointmap, mask) + else: + assert scale.shape == (3,) and shift.shape == (3,), "scale and shift must be in (3,) format" + + # Apply normalization and clip + pointmap_normalized = _apply_metric_to_ssi(pointmap, scale, shift) + # logger.info(f"{pointmap_normalized.shape=}") + + if self.clip_beyond_scale is not None and self.clip_beyond_scale > 0: + pointmap_normalized = torch.where( + pointmap_normalized[-1, ...] > self.clip_beyond_scale, + torch.full_like(pointmap_normalized, float('nan')), + pointmap_normalized + ) + + # return pointmap_normalized, scale, shift + return SSINormalizedPointmap(pointmap_normalized, scale, shift) + + +class NormalizedDisparitySpaceSSI(SSIPointmapNormalizer): + def __init__(self, + clip_beyond_scale: Optional[float] = None, + use_scene_scale: bool = True, + log_disparity_shift: float = 0.0, + ): + self.clip_beyond_scale = clip_beyond_scale + self.use_scene_scale = use_scene_scale + self.point_remapper = PointRemapper(remap_type="exp_disparity") + self.log_disparity_shift = log_disparity_shift + + def normalize(self, pointmap: torch.Tensor, mask: torch.Tensor, + scale: Optional[torch.Tensor] = None, shift: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + assert pointmap.shape[0] == 3, "pointmap must be in (3, H, W) format" + + + disparity_space_pointmap = self.point_remapper.forward(pointmap.permute(1, 2, 0)).permute(2, 0, 1) + if scale is None or shift is None: + scale, shift = self._get_scale_and_shift(disparity_space_pointmap, mask) + else: + assert scale.shape == (3,) and shift.shape == (3,), "scale and shift must be in (3,) format" + + # pointmap_normalized = pointmap.clone().detach() + pointmap_normalized = _apply_metric_to_ssi(disparity_space_pointmap, scale, shift) + # logger.info(f"{pointmap_normalized.shape=}") + + if self.clip_beyond_scale is not None and self.clip_beyond_scale > 0: + pointmap_normalized = torch.where( + pointmap_normalized[2, ...].abs() > self.clip_beyond_scale, + torch.full_like(pointmap_normalized, float('nan')), + pointmap_normalized + ) + + # return pointmap_normalized, scale, shift + return SSINormalizedPointmap(pointmap_normalized, scale, shift) + + def denormalize(self, pointmap: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor) -> torch.Tensor: + pointmap = _apply_metric_to_ssi(pointmap, scale, shift, apply_inverse=True) + pointmap = self.point_remapper.inverse(pointmap.permute(1, 2, 0)).permute(2, 0, 1) + return pointmap + + def _get_scale_and_shift(self, pointmap: torch.Tensor, mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + pointmap_size = (pointmap.shape[1], pointmap.shape[2]) + mask_resized = torchvision.transforms.functional.resize( + mask, pointmap_size, + interpolation=torchvision.transforms.InterpolationMode.NEAREST + ).squeeze(0) + + pointmap_flat = pointmap.reshape(3, -1) + if self.use_scene_scale: + median_z = pointmap_flat[-1, ...].nanmedian().unsqueeze(0) + shift = torch.zeros_like(median_z.expand(3)) + shift[-1, ...] = median_z[0] + self.log_disparity_shift + else: + # Get valid points from the mask (shift, x/z, y/z, log(z)) + mask_bool = mask_resized.reshape(-1) > 0.5 + pointmap_flat = pointmap_flat[:, mask_bool] + shift = pointmap_flat.nanmedian(dim=-1).values + + scale = torch.ones_like(shift) + # logger.info(f'median z = {median_z}') + return scale, shift + +def normalize_pointmap_ssi(pointmap: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Normalize pointmap using Scale-Shift Invariant (SSI) normalization. + + Args: + pointmap: Pointmap tensor of shape (H, W, 3) or (3, H, W) + + Returns: + Tuple of (normalized_pointmap, scale, shift) + """ + from sam3d_objects.data.dataset.tdfy.pose_target import ScaleShiftInvariant + + # Convert to (H, W, 3) if needed for get_scale_and_shift + if pointmap.shape[0] == 3: + pointmap_hw3 = pointmap.permute(1, 2, 0) + original_format = 'chw' + else: + pointmap_hw3 = pointmap + original_format = 'hwc' + + # Get scale and shift using existing method + scale, shift = ScaleShiftInvariant.get_scale_and_shift(pointmap_hw3) + + pointmap_normalized = _apply_metric_to_ssi(pointmap, scale, shift) + return pointmap_normalized, scale, shift + +def _apply_metric_to_ssi(pointmap: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor, apply_inverse: bool = False) -> torch.Tensor: + """ + Normalize pointmap using Scale-Shift Invariant (SSI) normalization. + + Args: + pointmap: Pointmap tensor of shape (H, W, 3) or (3, H, W) + + Returns: + Tuple of (normalized_pointmap, scale, shift) + """ + from sam3d_objects.data.dataset.tdfy.pose_target import ScaleShiftInvariant + + # Convert to (H, W, 3) if needed for get_scale_and_shift + if pointmap.shape[0] == 3: + pointmap_hw3 = pointmap.permute(1, 2, 0) + original_format = 'chw' + else: + pointmap_hw3 = pointmap + original_format = 'hwc' + + # Apply normalization + ssi_to_metric = ScaleShiftInvariant.ssi_to_metric(scale, shift) + metric_to_ssi = ssi_to_metric.inverse() + transform_to_apply = metric_to_ssi + + if apply_inverse: + transform_to_apply = ssi_to_metric + + pointmap_flat = pointmap_hw3.reshape(-1, 3) + pointmap_normalized = transform_to_apply.transform_points(pointmap_flat) + + # Reshape back to original format + if original_format == 'chw': + pointmap_normalized = pointmap_normalized.reshape(pointmap.shape[1], pointmap.shape[2], 3).permute(2, 0, 1) + else: + pointmap_normalized = pointmap_normalized.reshape(pointmap_hw3.shape) + + return pointmap_normalized + + +def perturb_mask_translation( + image: torch.Tensor, + mask: torch.Tensor, + max_px_delta: int = 5, +): + """ + Applies data augmentation to the mask by randomly translating the mask. + + Args: + image: (C, H, W) float32 [0, 1] tensor. + mask: (1, H, W) float32 [0, 1] tensor. + max_px_delta: The maximum number of pixels we will randomly shift by in each 2D direction. + """ + dx = random.randint(-max_px_delta, max_px_delta) + dy = random.randint(-max_px_delta, max_px_delta) + + mask = mask.squeeze(0) + mask = torch.roll(mask, shifts=(dy, dx), dims=(0, 1)) + + # Zero out wrapped regions + if dy > 0: + mask[:dy, :] = 0 + elif dy < 0: + mask[dy:, :] = 0 + if dx > 0: + mask[:, :dx] = 0 + elif dx < 0: + mask[:, dx:] = 0 + + mask = mask.unsqueeze(0) + return image, mask + + +def perturb_mask_boundary( + image: torch.Tensor, + mask: torch.Tensor, + kernel_range: tuple[int, int] = (2, 5), + p_erode: float = 0.1, + p_dilate: float = 0.8, + **kwargs, +): + """ + Applies data augmentation to the mask by randomly eroding or dilating the mask. + + Args: + image: (C, H, W) float32 [0, 1] tensor. + mask: (1, H, W) float32 [0, 1] tensor. + kernel_range: Range of kernel sizes to sample from. + p_erode: Probability of erosion. + p_dilate: Probability of dilation. + kwargs: Kwargs for the cv2 erode/dilate function. + """ + import cv2 + + C, H, W = image.shape + assert mask.shape == (1, H, W) + assert mask.dtype == torch.float32 + assert torch.all((mask == 0) | (mask == 1)), "Mask must be binary (0 or 1)" + + p_none = 1.0 - p_erode - p_dilate + assert 0 <= p_none <= 1, "Probabilities must sum to 1 and be valid." + + # Sample operation. + op = random.choices(["erode", "dilate", "none"], weights=[p_erode, p_dilate, p_none], k=1)[0] + + if op == "none": + pass + else: + # Sample kernel size + ksize = random.randint(*kernel_range) + kernel = np.ones((ksize, ksize), np.uint8) + + mask = mask.squeeze().cpu().numpy().astype(np.uint8) # (H, W) + + if op == "erode": + mask = cv2.erode(mask, kernel, **kwargs) + elif op == "dilate": + mask = cv2.dilate(mask, kernel, **kwargs) + else: + raise NotImplementedError + + mask = torch.from_numpy(mask).float()[None] # (1, H, W) + + return image, mask + + +def resolution_blur( + image: torch.Tensor, + mask: torch.Tensor, + scale_range=(0.05, 0.95), + interpolation_down=tv_transforms.InterpolationMode.BICUBIC, + interpolation_up=tv_transforms.InterpolationMode.BICUBIC, +): + """ + Blur the input image by applying upsample(downsample(x)). + + Args: + image (torch.Tensor): Image tensor of shape (C, H, W), float32, with values in [0, 1]. + mask (torch.Tensor): Mask tensor of shape (1, H, W), float32, with values in [0, 1]. The mask is returned unchanged. + scale_range: Tuple of (min_scale, max_scale) for downsampling. + interpolation_down: Interpolation mode for downsampling. + interpolation_up: Interpolation mode for upsampling. + """ + C, H, W = image.shape + scale = random.uniform(*scale_range) + new_H, new_W = max(1, int(H * scale)), max(1, int(W * scale)) + + # Downsample + image = TF.resize(image, size=[new_H, new_W], interpolation=interpolation_down) + + # Upsample back to original size + image = TF.resize(image, size=[H, W], interpolation=interpolation_up) + + return image, mask + + +def gaussian_blur( + image: torch.Tensor, + mask: torch.Tensor, + kernel_range: tuple[int, int] = (3, 15), + sigma_range: tuple[int, int] = (0.1, 4.0), +): + """ + Apply gaussian blur to the input image. + + Args: + image (torch.Tensor): Image tensor of shape (C, H, W), float32, with values in [0, 1]. + mask (torch.Tensor): Mask tensor of shape (1, H, W), float32, with values in [0, 1]. The mask is returned unchanged. + kernel_range (tuple): Range of odd kernel sizes to sample from for the Gaussian blur (min, max). + sigma_range (tuple): Range of sigma values (standard deviation) to sample from for the Gaussian kernel (min, max). + """ + kernel_size = random.choice([k for k in range(kernel_range[0], kernel_range[1]+1) if k % 2 == 1]) + sigma = random.uniform(*sigma_range) + pad = kernel_size // 2 + + # Step 1: Pad the image + image = F.pad(image.unsqueeze(0), (pad, pad, pad, pad), mode='replicate') + + # Step 2: Apply gaussian blur + image = TF.gaussian_blur(image, kernel_size=[kernel_size, kernel_size], sigma=sigma) + + # Step 3: Unpad to get back to original size + image = image[:, :, pad:-pad, pad:-pad] + + return image.squeeze(0), mask + + +def apply_blur_augmentation( + image: torch.Tensor, + mask: torch.Tensor, + p_resolution: float = 0.33, + p_gaussian: float = 0.33, + gaussian_kwargs: dict = None, + resolution_kwargs: dict = None, +): + """Apply blur augmentation with configurable parameters""" + + # Handle None defaults BEFORE unpacking + if gaussian_kwargs is None: + gaussian_kwargs = {} + if resolution_kwargs is None: + resolution_kwargs = {} + + p_none = 1.0 - p_gaussian - p_resolution + assert 0 <= p_none <= 1, "Probabilities must sum to 1 and be valid." + + operation = random.choices( + ["gaussian", "resolution", "none"], + weights=[p_gaussian, p_resolution, p_none], + k=1 + )[0] + + if operation == "gaussian": + return gaussian_blur(image, mask, **gaussian_kwargs) + elif operation == "resolution": + return resolution_blur(image, mask, **resolution_kwargs) + elif operation == "none": + return image, mask + else: + raise NotImplementedError diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/data/dataset/tdfy/img_processing.py b/thirdparty/sam3d/sam3d/sam3d_objects/data/dataset/tdfy/img_processing.py new file mode 100644 index 0000000000000000000000000000000000000000..e4e7fa4b9c2c5f833c1ea68d5f21953e6e243485 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/data/dataset/tdfy/img_processing.py @@ -0,0 +1,189 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import math + +import random + +import torch +import torch.nn.functional as F + +from torchvision import transforms +from torchvision.transforms import functional as tv_F + + +class RandomResizedCrop(transforms.RandomResizedCrop): + """ + RandomResizedCrop for matching TF/TPU implementation: no for-loop is used. + This may lead to results different with torchvision's version. + Following BYOL's TF code: + https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206 + """ + + @staticmethod + def get_params(img, scale, ratio): + width, height = tv_F._get_image_size(img) + area = height * width + + target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item() + log_ratio = torch.log(torch.tensor(ratio)) + aspect_ratio = torch.exp( + torch.empty(1).uniform_(log_ratio[0], log_ratio[1]) + ).item() + + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = int(round(math.sqrt(target_area / aspect_ratio))) + + w = min(w, width) + h = min(h, height) + + i = torch.randint(0, height - h + 1, size=(1,)).item() + j = torch.randint(0, width - w + 1, size=(1,)).item() + + return i, j, h, w + + +# following PT3D CO3D data to pad image +def pad_to_square(image, value=0): + _, _, h, w = image.shape # Assuming image is in (B, C, H, W) format + if h == w: + return image # The image is already square + + # Calculate the padding + diff = abs(h - w) + pad2 = diff + + # Pad the image to make it square + if h > w: + padding = (0, pad2, 0, 0) # Pad width (left, right, top, bottom) + else: + padding = (0, 0, 0, pad2) # Pad height + # Apply padding + padded_image = torch.nn.functional.pad(image, padding, mode="constant", value=value) + return padded_image + + +def preprocess_img( + x, + mask=None, + img_target_shape=224, + mask_target_shape=256, + normalize=False, +): + if x.shape[1] != x.shape[2]: + x = pad_to_square(x) + if mask is not None and mask.shape[1] != mask.shape[2]: + mask = pad_to_square(mask) + if x.shape[2] != img_target_shape: + x = F.interpolate( + x, + size=(img_target_shape, img_target_shape), + # scale_factor=float(img_target_shape)/x.shape[2], + mode="bilinear", + ) + if mask is not None and mask.shape[2] != mask_target_shape: + if mask is not None: + mask = F.interpolate( + mask, + size=(mask_target_shape, mask_target_shape), + # scale_factor=float(mask_target_shape)/mask.shape[2], + mode="nearest", + ) + if normalize: + imgs_normed = resnet_img_normalization(x) + else: + imgs_normed = x + return imgs_normed, mask + + +def resnet_img_normalization(x): + resnet_mean = torch.tensor([0.485, 0.456, 0.406], device=x.device).reshape( + (3, 1, 1) + ) + resnet_std = torch.tensor([0.229, 0.224, 0.225], device=x.device).reshape((3, 1, 1)) + if x.ndim == 4: + resnet_mean = resnet_mean[None] + resnet_std = resnet_std[None] + x = (x - resnet_mean) / resnet_std + return x + + +# pad image to be centered for unprojecting depth +def pad_to_square_centered(image, value=0, pointmap=None): + h, w = image.shape[-2], image.shape[-1] # Assuming image is in (B, C, H, W) format + if h == w: + if pointmap is not None: + return image, pointmap + return image # The image is already square + + # Calculate the padding + diff = abs(h - w) + pad1 = diff // 2 + pad2 = diff - pad1 + + # Pad the image to make it square + if h > w: + padding = (pad1, pad2, 0, 0) # Pad width (left, right, top, bottom) + else: + padding = (0, 0, pad1, pad2) # Pad height + # Apply padding to image + padded_image = F.pad(image, padding, mode="constant", value=value) + + # Apply padding to pointmap if provided + if pointmap is not None: + # Pad pointmap using torch functional with NaN fill value + padded_pointmap = F.pad(pointmap, padding, mode="constant", value=float("nan")) + + return padded_image, padded_pointmap + return padded_image + + +def crop_img_to_obj(mask, context_size): + nonzeros = torch.nonzero(mask) + if len(nonzeros) > 0: + r_max, c_max = nonzeros.max(dim=0)[0] + r_min, c_min = nonzeros.min(dim=0)[0] + box_h = max(1, r_max - r_min) + box_w = max(1, c_max - c_min) + left = max(0, c_min - int(box_w * context_size)) + right = min(mask.shape[-1], c_max + int(box_w * context_size)) + top = max(0, r_min - int(box_h * context_size)) + bot = min(mask.shape[-2], r_max + int(box_h * context_size)) + return left, right, top, bot + return None, None, None, None + + +def random_pad(img, mask=None, max_ratio=0.0, pointmap=None): + max_size = int(max(img.shape) * max_ratio) + padding = tuple([random.randint(0, max_size) for _ in range(4)]) + img = F.pad(img, padding) + if mask is not None: + mask = F.pad(mask, padding) + + if pointmap is not None: + pointmap = F.pad(pointmap, padding, mode="constant", value=float("nan")) + return img, mask, pointmap + return img, mask + + +def get_img_color_augmentation( + color_jit_prob=0.5, + gaussian_blur_prob=0.1, +): + transform = transforms.Compose( + [ + # (a) Random Color Jitter + transforms.RandomApply( + [ + transforms.ColorJitter( + brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1 + ) + ], + p=color_jit_prob, + ), + # (b) Randomly apply GaussianBlur + transforms.RandomApply( + [transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0))], + p=gaussian_blur_prob, + ), + ] + ) + return transform diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/data/dataset/tdfy/pose_target.py b/thirdparty/sam3d/sam3d/sam3d_objects/data/dataset/tdfy/pose_target.py new file mode 100644 index 0000000000000000000000000000000000000000..f6e221f25c4b437f415fb49725a9a5322df14c98 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/data/dataset/tdfy/pose_target.py @@ -0,0 +1,784 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import torch +from typing import Dict, Optional, Tuple, Any +from dataclasses import dataclass, asdict, field +from loguru import logger + +from sam3d_objects.data.utils import expand_as_right, tree_tensor_map +from sam3d_objects.data.dataset.tdfy.transforms_3d import compose_transform, decompose_transform +from pytorch3d.transforms import Transform3d, quaternion_to_matrix, matrix_to_quaternion + + +@dataclass +class InstancePose: + """ + Stores the pose of an object. + Also, stores some information about the scene that was used to normalize the pose. + """ + + instance_scale_l2c: torch.Tensor + instance_position_l2c: torch.Tensor + instance_quaternion_l2c: torch.Tensor + scene_scale: torch.Tensor + scene_shift: torch.Tensor + + @classmethod + def _broadcast_postcompose( + cls, + scale: torch.Tensor, + rotation: torch.Tensor, + translation: torch.Tensor, + transform_to_postcompose: Transform3d, + ) -> Transform3d: + """ + Assumes scale, rotation, translation are of shape: + B, K, C + --- + B: batch size + K: number of objects + C: number of channels + + Takes a transform where + get_matrix() has shape (B, 3, 3) + + Returns pose.compose(transform_to_postcompose) + """ + scale_c = scale.shape[-1] + ndim_orig = scale.ndim + if ndim_orig == 3: + b, k, _ = scale.shape + elif ndim_orig == 2: + b = scale.shape[0] + k = 1 + elif ndim_orig == 1: + b = 1 + k = 1 + else: + raise ValueError(f"Invalid scale shape: {scale.shape}") + + # Create transform of shape (B * K) + wide = {"scale": scale, "rotation": rotation, "translation": translation} + shapes_orig = {k: v.shape for k, v in wide.items()} + long = tree_tensor_map(lambda x: x.reshape(b * k, x.shape[-1]), wide) + long["rotation"] = quaternion_to_matrix(long["rotation"]) + if scale_c == 1: + long["scale"] = long["scale"].expand(b * k, 3) + + composed = compose_transform(**long) + + # Apply transform to shape (B * K) + pc_transform = transform_to_postcompose.get_matrix() + pc_transform = pc_transform.repeat(k, 1, 1) + stacked_pc_transform = Transform3d(matrix=pc_transform) + assert stacked_pc_transform.get_matrix().shape == composed.get_matrix().shape + postcomposed = composed.compose(stacked_pc_transform) + + # Decompose transform to shape (B, K, C) + scale, rotation, translation = decompose_transform(postcomposed) + rotation = matrix_to_quaternion(rotation) + pc_long = {"scale": scale, "rotation": rotation, "translation": translation} + pc_wide = tree_tensor_map(lambda x: x.reshape(b, k, x.shape[-1]), pc_long) + if scale_c == 1: + pc_wide["scale"] = pc_wide["scale"][..., 0].unsqueeze(-1) + for k, shape in shapes_orig.items(): + pc_wide[k] = pc_wide[k].reshape(*shape) + return pc_wide["scale"], pc_wide["rotation"], pc_wide["translation"] + + +@dataclass +class PoseTarget: + x_instance_scale: torch.Tensor + x_instance_rotation: torch.Tensor + x_instance_translation: torch.Tensor + x_scene_scale: torch.Tensor + x_scene_center: torch.Tensor + x_translation_scale: torch.Tensor + pose_target_convention: str = field(default="unknown") + + +@dataclass +class InvariantPoseTarget: + """ + This is the canonical representation of pose targets, used for computing metrics. + instance_pose <-> invariant_pose_targets <-> all other pose_target_conventions + + Background: + --- + We want to estimate a transformation T: R³ → R³ despite scene scale ambiguity. + + The transformation taking object points to scene points is defined as + T(x) = s · R(q) · x + t + where: + - x is a point in the object coordinate frame, + - q is a unit quaternion representing rotation, + - s is the object-to-scene scale, and + - t is the translation. + + However, there is an inherent scale ambiguity in the scene, denoted as s_scene; + This ambiguity introduces irreducible error that complicates both evaluation and training. + + To decouple the scene scale from the invariant quantities, we define: + T(x) = s_scene · |t_rel| [ s_tilde · R(q) · x + t_unit ] + where we define + t_rel = t / s_scene + s_rel = s / s_scene + s_tilde = s_rel / |t_rel| + t_unit = t_rel / |t_rel| + + During training, you would predict (q, s_tilde, t_unit), leaving s_scene separate. + + + Hand-wavy error analysis: + --- + 1. Naive (coupled) estimate: + T(x) = s_scene [ s_rel · R(q) · x + t_rel ] + + We can define: + U = ln(s_rel) + V = ln(|t_rel|) + so that the error is governed by Var(U + V). + + 2. In the decoupled case, we have: + T(x) = s_scene · |t_rel| [ s_tilde · R(q) · x + t_unit ] + = s_scene · |t_rel| [ (s_rel / |t_rel|) R(q) · x + t_unit ] + Then ln(s_tilde) = ln(s_rel) - ln(|t_rel|) = U - V, and the error is + Var(U - V) = Var(U) + Var(V) - 2Cov(U, V). + + """ + + # These are invariant + q: torch.Tensor + t_unit: torch.Tensor + s_scene: torch.Tensor + t_scene_center: Optional[torch.Tensor] = None + t_rel_norm: Optional[torch.Tensor] = None + s_tilde: Optional[torch.Tensor] = None + s_rel: Optional[torch.Tensor] = None + + def __post_init__(self): + # Check that fields that are required always have values. + if self.q is None: + raise ValueError("Field 'q' (quaternion) must be provided.") + if self.s_scene is None: + raise ValueError("Field 's_scene' must be provided.") + if self.s_rel is None: + if self.s_tilde is not None: + self.s_rel = self.s_tilde * self.t_rel_norm + else: + raise ValueError("Field 's_rel' or 's_tilde' must be provided.") + if self.t_unit is None: + raise ValueError("Field 't_unit' must be provided.") + + if self.t_scene_center is None: + self.t_scene_center = torch.zeros_like(self.t_unit) + + # There is a simple relationship between s_tilde and t_rel_norm: + # s_tilde = s_rel / t_rel_norm + # + # If one of these is missing and the other is provided, we can compute the missing field. + if self.s_tilde is None and self.t_rel_norm is not None: + self.s_tilde = self.s_rel / self.t_rel_norm + elif self.t_rel_norm is None and self.s_tilde is not None: + self.t_rel_norm = self.s_rel / self.s_tilde + + # If both are provided, we check for consistency. + if self.s_tilde is not None and self.t_rel_norm is not None: + computed_s_tilde = self.s_rel / self.t_rel_norm + # If the provided s_tilde deviates from what is computed, update it. + if not torch.allclose(self.s_tilde, computed_s_tilde, atol=1e-6): + logger.warning( + f"s_tilde and t_rel_norm are provided, but they are not consistent. " + f"Updating s_tilde to {computed_s_tilde}." + ) + self.s_tilde = computed_s_tilde + + self._validate_fields() + + def _validate_fields(self): + for field in self.__dict__: + if self.__dict__[field] is None: + raise ValueError(f"Field '{field}' must be provided.") + + + @staticmethod + def from_instance_pose(instance_pose: InstancePose) -> "InvariantPoseTarget": + q = instance_pose.instance_quaternion_l2c + s_obj_to_scene = instance_pose.instance_scale_l2c # (..., 1) or (..., 3) + t_obj_to_scene = instance_pose.instance_position_l2c # (..., 3) + s_scene = instance_pose.scene_scale # (..., 1) or scalar-broadcastable + t_scene_center = instance_pose.scene_shift # (..., 3) + + # Normalize to scene scale (per the derivation) + if not ( s_obj_to_scene.ndim == (s_scene.ndim + 1)): + raise ValueError(f"s_scene should be ND [...,3] and s_obj_to_scene should be (N+1)D [...,K,3], but got {s_scene.shape=} {s_obj_to_scene.shape=}") + if not (t_obj_to_scene.ndim == (s_scene.ndim + 1)): + raise ValueError(f"t_scene_center should be ND [B,3] and t_obj_to_scene should be (N+1)D [B,K,3], but got {t_scene_center.shape=} {t_obj_to_scene.shape=}") + s_scene_exp = s_scene.unsqueeze(-2) + + s_rel = s_obj_to_scene / s_scene_exp + t_rel = t_obj_to_scene / s_scene_exp + + # Robust norms + eps = 1e-8 + t_rel_norm = t_rel.norm(dim=-1, keepdim=True).clamp_min(eps) + + s_tilde = s_rel / t_rel_norm + t_unit = t_rel / t_rel_norm + + return InvariantPoseTarget( + q=q, + s_scene=s_scene, + t_scene_center=t_scene_center, + s_rel=s_rel, + s_tilde=s_tilde, + t_unit=t_unit, + t_rel_norm=t_rel_norm, + ) + + + @staticmethod + def to_instance_pose(invariant_targets: "InvariantPoseTarget") -> InstancePose: + # scale factor per the derivation: s_scene * |t_rel| + # Normalize to scene scale (per the derivation) + t_rel_norm_ndim = invariant_targets.t_rel_norm.ndim + if not (invariant_targets.s_scene.ndim == (t_rel_norm_ndim - 1)) : + raise ValueError(f"s_scene should be ND [...,3] and t_rel_norm should be (N+1)D [...,K,3], but got {invariant_targets.s_scene.shape=} {invariant_targets.t_rel_norm.shape=}") + + scale = invariant_targets.s_scene.unsqueeze(-2) * invariant_targets.t_rel_norm + return InstancePose( + instance_scale_l2c=invariant_targets.s_tilde * scale, + instance_position_l2c=invariant_targets.t_unit * scale, + instance_quaternion_l2c=invariant_targets.q, + scene_scale=invariant_targets.s_scene, + scene_shift=invariant_targets.t_scene_center, + ) + + +class PoseTargetConvention: + """ + Converts pose_targets <-> instance_pose <-> invariant_pose_targets + """ + + pose_target_convention: str + + @classmethod + def from_invariant(cls, invariant_targets: InvariantPoseTarget) -> PoseTarget: + raise NotImplementedError("Implement this in a subclass") + + @classmethod + def to_invariant(cls, instance_pose: InstancePose) -> InvariantPoseTarget: + raise NotImplementedError("Implement this in a subclass") + + @classmethod + def from_instance_pose(cls, instance_pose: InstancePose) -> PoseTarget: + invariant_targets = InvariantPoseTarget.from_instance_pose(instance_pose) + return cls.from_invariant(invariant_targets) + + @classmethod + def to_instance_pose(cls, pose_target: PoseTarget) -> InstancePose: + invariant_targets = cls.to_invariant(pose_target) + return InvariantPoseTarget.to_instance_pose(invariant_targets) + + +class ScaleShiftInvariant(PoseTargetConvention): + """ + + Midas eq. (6): https://arxiv.org/pdf/1907.01341v3 + But for pointmaps (see MoGe): https://arxiv.org/pdf/2410.19115 + """ + + pose_target_convention: str = "ScaleShiftInvariant" + scale_mean = torch.tensor([1.0232692956924438, 1.0232691764831543, 1.0232692956924438]).to(torch.float32) + scale_std = torch.tensor([1.3773751258850098, 1.3773752450942993, 1.3773750066757202]).to(torch.float32) + translation_mean = torch.tensor([0.003191213821992278, 0.017236359417438507, 0.9401122331619263]).to(torch.float32) + translation_std = torch.tensor([1.341888666152954, 0.7665449380874634, 3.175130605697632]).to(torch.float32) + + @classmethod + def from_instance_pose(cls, instance_pose: InstancePose, normalize: bool = False) -> PoseTarget: + metric_to_ssi = cls.ssi_to_metric( + instance_pose.scene_scale, instance_pose.scene_shift + ).inverse() + + ssi_scale, ssi_rotation, ssi_translation = InstancePose._broadcast_postcompose( + scale=instance_pose.instance_scale_l2c, + rotation=instance_pose.instance_quaternion_l2c, + translation=instance_pose.instance_position_l2c, + transform_to_postcompose=metric_to_ssi, + ) + # logger.info(f"{normalize=} {ssi_scale.shape=} {ssi_rotation.shape=} {ssi_translation.shape=}") + if normalize: + device = ssi_scale.device + ssi_scale = (ssi_scale - cls.scale_mean.to(device)) / cls.scale_std.to(device) + ssi_translation = (ssi_translation - cls.translation_mean.to(device)) / cls.translation_std.to(device) + + return PoseTarget( + x_instance_scale=ssi_scale, + x_instance_rotation=ssi_rotation, + x_instance_translation=ssi_translation, + x_scene_scale=instance_pose.scene_scale, + x_scene_center=instance_pose.scene_shift, + x_translation_scale=torch.ones_like(ssi_scale)[..., 0].unsqueeze(-1), + pose_target_convention=cls.pose_target_convention, + ) + + @classmethod + def to_instance_pose(cls, pose_target: PoseTarget, normalize: bool = False) -> InstancePose: + scene_scale = pose_target.x_scene_scale + scene_shift = pose_target.x_scene_center + ssi_to_metric = cls.ssi_to_metric(scene_scale, scene_shift) + + if normalize: + device = pose_target.x_instance_scale.device + pose_target.x_instance_scale = pose_target.x_instance_scale * cls.scale_std.to(device) + cls.scale_mean.to(device) + pose_target.x_instance_translation = pose_target.x_instance_translation * cls.translation_std.to(device) + cls.translation_mean.to(device) + + ins_scale, ins_rotation, ins_translation = InstancePose._broadcast_postcompose( + scale=pose_target.x_instance_scale, + rotation=pose_target.x_instance_rotation, + translation=pose_target.x_instance_translation, + transform_to_postcompose=ssi_to_metric, + ) + + return InstancePose( + instance_scale_l2c=ins_scale, + instance_position_l2c=ins_translation, + instance_quaternion_l2c=ins_rotation, + scene_scale=scene_scale, + scene_shift=scene_shift, + ) + + @classmethod + def to_invariant(cls, pose_target: PoseTarget, normalize: bool = False) -> InvariantPoseTarget: + instance_pose = cls.to_instance_pose(pose_target, normalize=normalize) + return InvariantPoseTarget.from_instance_pose(instance_pose) + + @classmethod + def from_invariant(cls, invariant_targets: InvariantPoseTarget, normalize: bool = False) -> PoseTarget: + instance_pose = InvariantPoseTarget.to_instance_pose(invariant_targets) + return cls.from_instance_pose(instance_pose, normalize=normalize) + + @classmethod + def get_scale_and_shift(cls, pointmap): + shift_z = pointmap[..., -1].nanmedian().unsqueeze(0) + shift = torch.zeros_like(shift_z.expand(1, 3)) + shift[..., -1] = shift_z + + shifted_pointmap = pointmap - shift + scale = shifted_pointmap.abs().nanmean().to(shift.device) + + shift = shift.reshape(3) + scale = scale.expand(3) + + return scale, shift + + @staticmethod + def ssi_to_metric(scale: torch.Tensor, shift: torch.Tensor): + if scale.ndim == 1: + scale = scale.unsqueeze(0) + if shift.ndim == 1: + shift = shift.unsqueeze(0) + return Transform3d().scale(scale).translate(shift).to(shift.device) + + +class ScaleShiftInvariantWTranslationScale(PoseTargetConvention): + """ + + Midas eq. (6): https://arxiv.org/pdf/1907.01341v3 + But for pointmaps (see MoGe): https://arxiv.org/pdf/2410.19115 + """ + + pose_target_convention: str = "ScaleShiftInvariantWTranslationScale" + scale_mean = torch.tensor([1.0232692956924438, 1.0232691764831543, 1.0232692956924438]).to(torch.float32) + scale_std = torch.tensor([1.3773751258850098, 1.3773752450942993, 1.3773750066757202]).to(torch.float32) + translation_mean = torch.tensor([0.003191213821992278, 0.017236359417438507, 0.9401122331619263]).to(torch.float32) + translation_std = torch.tensor([1.341888666152954, 0.7665449380874634, 3.175130605697632]).to(torch.float32) + + @classmethod + def from_instance_pose(cls, instance_pose: InstancePose, normalize: bool = False) -> PoseTarget: + metric_to_ssi = cls.ssi_to_metric( + instance_pose.scene_scale, instance_pose.scene_shift + ).inverse() + + ssi_scale, ssi_rotation, ssi_translation = InstancePose._broadcast_postcompose( + scale=instance_pose.instance_scale_l2c, + rotation=instance_pose.instance_quaternion_l2c, + translation=instance_pose.instance_position_l2c, + transform_to_postcompose=metric_to_ssi, + ) + + ssi_translation_scale = ssi_translation.norm(dim=-1, keepdim=True) + ssi_translation_unit = ssi_translation / ssi_translation_scale.clamp_min(1e-7) + + return PoseTarget( + x_instance_scale=ssi_scale, + x_instance_rotation=ssi_rotation, + x_instance_translation=ssi_translation_unit, + x_scene_scale=instance_pose.scene_scale, + x_scene_center=instance_pose.scene_shift, + x_translation_scale=ssi_translation_scale, + pose_target_convention=cls.pose_target_convention, + ) + + @classmethod + def to_instance_pose(cls, pose_target: PoseTarget, normalize: bool = False) -> InstancePose: + scene_scale = pose_target.x_scene_scale + scene_shift = pose_target.x_scene_center + ssi_to_metric = cls.ssi_to_metric(scene_scale, scene_shift) + + ins_translation_unit = pose_target.x_instance_translation / pose_target.x_instance_translation.norm(dim=-1, keepdim=True) + ins_translation = ins_translation_unit * pose_target.x_translation_scale + + + ins_scale, ins_rotation, ins_translation = InstancePose._broadcast_postcompose( + scale=pose_target.x_instance_scale, + rotation=pose_target.x_instance_rotation, + translation=ins_translation, + transform_to_postcompose=ssi_to_metric, + ) + + + return InstancePose( + instance_scale_l2c=ins_scale, + instance_position_l2c=ins_translation, + instance_quaternion_l2c=ins_rotation, + scene_scale=scene_scale, + scene_shift=scene_shift, + ) + + @classmethod + def to_invariant(cls, pose_target: PoseTarget) -> InvariantPoseTarget: + instance_pose = cls.to_instance_pose(pose_target) + return InvariantPoseTarget.from_instance_pose(instance_pose) + + @classmethod + def from_invariant(cls, invariant_targets: InvariantPoseTarget) -> PoseTarget: + instance_pose = InvariantPoseTarget.to_instance_pose(invariant_targets) + return cls.from_instance_pose(instance_pose) + + @classmethod + def get_scale_and_shift(cls, pointmap): + shift_z = pointmap[..., -1].nanmedian().unsqueeze(0) + shift = torch.zeros_like(shift_z.expand(1, 3)) + shift[..., -1] = shift_z + + shifted_pointmap = pointmap - shift + scale = shifted_pointmap.abs().nanmean().to(shift.device) + + shift = shift.reshape(3) + scale = scale.expand(3) + + return scale, shift + + @staticmethod + def ssi_to_metric(scale: torch.Tensor, shift: torch.Tensor): + if scale.ndim == 1: + scale = scale.unsqueeze(0) + if shift.ndim == 1: + shift = shift.unsqueeze(0) + return Transform3d().scale(scale).translate(shift).to(shift.device) + + +class DisparitySpace(PoseTargetConvention): + pose_target_convention: str = "DisparitySpace" + + @classmethod + def from_instance_pose(cls, instance_pose: InstancePose, normalize: bool = False) -> PoseTarget: + + # x_instance_scale = orig_scale / scene_scale + # x_instance_translation = [x/z, y/z, 0] / scene_scale + # x_translation_scale = z / scene_scale + assert torch.allclose(instance_pose.scene_scale, torch.ones_like(instance_pose.scene_scale)) + + if not instance_pose.scene_shift.ndim == instance_pose.instance_position_l2c.ndim - 1: + raise ValueError(f"scene_shift must be (N+1)D and instance_position_l2c must be (N+1)D, but got {instance_pose.scene_shift.ndim} and {instance_pose.instance_position_l2c.ndim}") + shift_xy, shift_z_log = instance_pose.scene_shift.unsqueeze(-2).split([2, 1], dim=-1) + + + pose_xy, pose_z = instance_pose.instance_position_l2c.split([2, 1], dim=-1) + # Handle batch dimensions properly + if shift_xy.ndim < pose_xy.ndim: + shift_xy = shift_xy.unsqueeze(-2) + pose_xy_scaled = pose_xy / pose_z - shift_xy + + pose_z_scaled_log = torch.log(pose_z) - shift_z_log + x_instance_scale_log = torch.log(instance_pose.instance_scale_l2c) - torch.log(pose_z) + + x_instance_translation = torch.cat([pose_xy_scaled, torch.zeros_like(pose_z)], dim=-1) + x_translation_scale = torch.exp(pose_z_scaled_log) + x_instance_scale = torch.exp(x_instance_scale_log) + + + + return PoseTarget( + x_instance_scale=x_instance_scale, + x_instance_translation=x_instance_translation, + x_instance_rotation=instance_pose.instance_quaternion_l2c, + x_scene_scale=instance_pose.scene_scale, + x_scene_center=instance_pose.scene_shift, + x_translation_scale=x_translation_scale, + pose_target_convention=cls.pose_target_convention, + ) + + @classmethod + def to_instance_pose(cls, pose_target: PoseTarget, normalize: bool = False) -> InstancePose: + scene_scale = pose_target.x_scene_scale + scene_shift = pose_target.x_scene_center + + if not pose_target.x_scene_center.ndim == pose_target.x_instance_translation.ndim - 1: + raise ValueError(f"x_scene_center must be (N+1)D and x_instance_translation must be (N+1)D, but got {pose_target.x_scene_center.ndim} and {pose_target.x_instance_translation.ndim}") + shift_xy, shift_z_log = pose_target.x_scene_center.unsqueeze(-2).split([2, 1], dim=-1) + scene_z_scale = torch.exp(shift_z_log) + + z = pose_target.x_translation_scale + ins_translation = pose_target.x_instance_translation.clone() + ins_translation[...,2] = 1.0 + ins_translation[...,:2] = ins_translation[...,:2] + shift_xy + ins_translation = ins_translation * z * scene_z_scale + + ins_scale = pose_target.x_instance_scale * z * scene_z_scale + + return InstancePose( + instance_scale_l2c=ins_scale * scene_scale, + instance_position_l2c=ins_translation * scene_scale, + instance_quaternion_l2c=pose_target.x_instance_rotation, + scene_scale=scene_scale, + scene_shift=scene_shift, + ) + + @classmethod + def to_invariant(cls, pose_target: PoseTarget, normalize: bool = False) -> InvariantPoseTarget: + instance_pose = cls.to_instance_pose(pose_target, normalize=normalize) + return InvariantPoseTarget.from_instance_pose(instance_pose) + + @classmethod + def from_invariant(cls, invariant_targets: InvariantPoseTarget, normalize: bool = False) -> PoseTarget: + instance_pose = InvariantPoseTarget.to_instance_pose(invariant_targets) + return cls.from_instance_pose(instance_pose, normalize=normalize) + + + +class NormalizedSceneScale(PoseTargetConvention): + """ + x_instance_scale and x_translation_scale are normalized to x_scene_scale + """ + + pose_target_convention: str = "NormalizedSceneScale" + + @classmethod + def from_invariant(cls, invariant_targets: InvariantPoseTarget): + translation = invariant_targets.t_unit * invariant_targets.t_rel_norm + return PoseTarget( + x_instance_scale=invariant_targets.s_rel, + x_instance_rotation=invariant_targets.q, + x_instance_translation=translation, + x_scene_scale=invariant_targets.s_scene, + x_scene_center=invariant_targets.t_scene_center, + x_translation_scale=torch.ones_like(invariant_targets.t_rel_norm), + pose_target_convention=cls.pose_target_convention, + ) + + @classmethod + def to_invariant(cls, pose_target: PoseTarget): + t_rel_norm = torch.norm( + pose_target.x_instance_translation, dim=-1, keepdim=True + ) + return InvariantPoseTarget( + s_scene=pose_target.x_scene_scale, + s_rel=pose_target.x_instance_scale, + q=pose_target.x_instance_rotation, + t_unit=pose_target.x_instance_translation / t_rel_norm, + t_rel_norm=t_rel_norm, + t_scene_center=pose_target.x_scene_center, + ) + + +class Naive(PoseTargetConvention): + pose_target_convention: str = "Naive" + + @classmethod + def from_invariant(cls, invariant_targets: InvariantPoseTarget): + s_scene = invariant_targets.s_rel * invariant_targets.s_scene + t_scene = invariant_targets.t_unit * invariant_targets.t_rel_norm + return PoseTarget( + x_instance_scale=s_scene, + x_instance_rotation=invariant_targets.q, + x_instance_translation=t_scene, + x_scene_scale=invariant_targets.s_scene, + x_scene_center=invariant_targets.t_scene_center, + x_translation_scale=torch.ones_like(invariant_targets.t_rel_norm), + pose_target_convention=cls.pose_target_convention, + ) + + @classmethod + def to_invariant(cls, pose_target: PoseTarget): + s_scene = pose_target.x_scene_scale + t_rel_norm = torch.norm( + pose_target.x_instance_translation, dim=-1, keepdim=True + ) + return InvariantPoseTarget( + s_scene=s_scene, + t_scene_center=pose_target.x_scene_center, + s_rel=pose_target.x_instance_scale / s_scene, + q=pose_target.x_instance_rotation, + t_unit=pose_target.x_instance_translation / t_rel_norm, + t_rel_norm=t_rel_norm, + ) + + +class NormalizedSceneScaleAndTranslation(PoseTargetConvention): + """ + x_instance_scale and x_translation_scale are normalized to x_scene_scale + x_instance_translation is unit + """ + + pose_target_convention: str = "NormalizedSceneScaleAndTranslation" + + @classmethod + def from_invariant(cls, invariant_targets: InvariantPoseTarget): + return PoseTarget( + x_instance_scale=invariant_targets.s_rel, + x_instance_rotation=invariant_targets.q, + x_instance_translation=invariant_targets.t_unit, + x_scene_scale=invariant_targets.s_scene, + x_scene_center=invariant_targets.t_scene_center, + x_translation_scale=invariant_targets.t_rel_norm, + pose_target_convention=cls.pose_target_convention, + ) + + @classmethod + def to_invariant(cls, pose_target: PoseTarget): + return InvariantPoseTarget( + s_scene=pose_target.x_scene_scale, + t_scene_center=pose_target.x_scene_center, + s_rel=pose_target.x_instance_scale, + q=pose_target.x_instance_rotation, + t_unit=pose_target.x_instance_translation, + t_rel_norm=pose_target.x_translation_scale, + ) + + +class ApparentSize(PoseTargetConvention): + pose_target_convention: str = "ApparentSize" + + @classmethod + def from_invariant(cls, invariant_targets: InvariantPoseTarget): + return PoseTarget( + x_instance_scale=invariant_targets.s_tilde, + x_instance_rotation=invariant_targets.q, + x_instance_translation=invariant_targets.t_unit, + x_scene_scale=invariant_targets.s_scene, + x_scene_center=invariant_targets.t_scene_center, + x_translation_scale=invariant_targets.t_rel_norm, + pose_target_convention=cls.pose_target_convention, + ) + + @classmethod + def to_invariant(cls, pose_target: PoseTarget): + return InvariantPoseTarget( + s_scene=pose_target.x_scene_scale, + t_scene_center=pose_target.x_scene_center, + s_tilde=pose_target.x_instance_scale, + q=pose_target.x_instance_rotation, + t_unit=pose_target.x_instance_translation, + t_rel_norm=pose_target.x_translation_scale, + ) + + +class Identity(PoseTargetConvention): + """ + Identity convention - no transformation applied. + Direct passthrough mapping between instance pose and pose target values. + This preserves all values including scene_scale and scene_shift. + """ + + pose_target_convention: str = "Identity" + + @classmethod + def from_instance_pose(cls, instance_pose: InstancePose) -> PoseTarget: + return PoseTarget( + x_instance_scale=instance_pose.instance_scale_l2c, + x_instance_rotation=instance_pose.instance_quaternion_l2c, + x_instance_translation=instance_pose.instance_position_l2c, + x_scene_scale=instance_pose.scene_scale, + x_scene_center=instance_pose.scene_shift, + x_translation_scale=torch.ones_like(instance_pose.instance_scale_l2c)[..., 0].unsqueeze(-1), + pose_target_convention=cls.pose_target_convention, + ) + + @classmethod + def to_instance_pose(cls, pose_target: PoseTarget) -> InstancePose: + return InstancePose( + instance_scale_l2c=pose_target.x_instance_scale, + instance_position_l2c=pose_target.x_instance_translation, + instance_quaternion_l2c=pose_target.x_instance_rotation, + scene_scale=pose_target.x_scene_scale, + scene_shift=pose_target.x_scene_center, + ) + + @classmethod + def to_invariant(cls, pose_target: PoseTarget) -> InvariantPoseTarget: + instance_pose = cls.to_instance_pose(pose_target) + return InvariantPoseTarget.from_instance_pose(instance_pose) + + @classmethod + def from_invariant(cls, invariant_targets: InvariantPoseTarget) -> PoseTarget: + instance_pose = InvariantPoseTarget.to_instance_pose(invariant_targets) + return cls.from_instance_pose(instance_pose) + + +class PoseTargetConverter: + @staticmethod + def pose_target_to_instance_pose(pose_target: PoseTarget, normalize: bool = False) -> InstancePose: + _convention_class = globals()[pose_target.pose_target_convention] + if _convention_class == ScaleShiftInvariant: + return _convention_class.to_instance_pose(pose_target, normalize=normalize) + else: + return _convention_class.to_instance_pose(pose_target) + + @staticmethod + def instance_pose_to_pose_target( + instance_pose: InstancePose, pose_target_convention: str, normalize: bool = False + ) -> PoseTarget: + _convention_class = globals()[pose_target_convention] + if _convention_class == ScaleShiftInvariant: + return _convention_class.from_instance_pose(instance_pose, normalize=normalize) + else: + return _convention_class.from_instance_pose(instance_pose) + + @staticmethod + def dicts_instance_pose_to_pose_target( + pose_target_convention: str, + **kwargs, + ): + instance_pose = InstancePose(**kwargs) + pose_target = PoseTargetConverter.instance_pose_to_pose_target( + instance_pose, pose_target_convention + ) + return asdict(pose_target) + + @staticmethod + def dicts_pose_target_to_instance_pose( + **kwargs, + ): + pose_target_convention = kwargs.get("pose_target_convention") + _convention_class = globals()[pose_target_convention] + assert ( + _convention_class.pose_target_convention == pose_target_convention + ), f"Normalization name mismatch: {_convention_class.pose_target_convention} != {pose_target_convention}" + + normalize = kwargs.pop("normalize", False) + pose_target = PoseTarget(**kwargs) + instance_pose = PoseTargetConverter.pose_target_to_instance_pose(pose_target, normalize) + return asdict(instance_pose) + + +class LogScaleShiftNormalizer: + def __init__(self, shift_log: torch.Tensor = 0.0, scale_log: torch.Tensor = 1.0): + self.shift_log = shift_log + self.scale_log = scale_log + + def normalize(self, value: torch.Tensor): + return torch.log(value) - self.shift_log / self.scale_log + + def denormalize(self, value: torch.Tensor): + return torch.exp(value * self.scale_log + self.shift_log) \ No newline at end of file diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/data/dataset/tdfy/preprocessor.py b/thirdparty/sam3d/sam3d/sam3d_objects/data/dataset/tdfy/preprocessor.py new file mode 100644 index 0000000000000000000000000000000000000000..862bd77ae28e18001e88b6701f3361e17045ebd6 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/data/dataset/tdfy/preprocessor.py @@ -0,0 +1,203 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import warnings +import torch +from loguru import logger +from dataclasses import dataclass +from typing import Callable, Optional +import warnings + +from .img_and_mask_transforms import ( + SSIPointmapNormalizer, +) + + +# Load and process data +@dataclass +class PreProcessor: + """ + Preprocessor configuration for image, mask, and pointmap transforms. + + Transform application order: + 1. Pointmap normalization (if normalize_pointmap=True) + 2. Joint transforms (img_mask_pointmap_joint_transform or img_mask_joint_transform) + 3. Individual transforms (img_transform, mask_transform, pointmap_transform) + + For backward compatibility, img_mask_joint_transform is preserved. When both + img_mask_pointmap_joint_transform and img_mask_joint_transform are present, + img_mask_pointmap_joint_transform takes priority. + """ + + img_transform: Callable = (None,) + mask_transform: Callable = (None,) + img_mask_joint_transform: list[Callable] = (None,) + rgb_img_mask_joint_transform: list[Callable] = (None,) + + # New fields for pointmap support + pointmap_transform: Callable = (None,) + img_mask_pointmap_joint_transform: list[Callable] = (None,) + + # Pointmap normalization option + normalize_pointmap: bool = False + pointmap_normalizer: Optional[Callable] = None + rgb_pointmap_normalizer: Optional[Callable] = None + + def __post_init__(self): + if self.pointmap_normalizer is None: + self.pointmap_normalizer = SSIPointmapNormalizer() + if self.normalize_pointmap == False: + warnings.warn("normalize_pointmap is also set to False, which means we will return the moments but not normalize the pointmap. This supports old unnormalized pointmap models, but this is dangerous behavior.", DeprecationWarning, stacklevel=2) + + if self.rgb_pointmap_normalizer is None: + logger.warning("No rgb pointmap normalizer provided, using scale + shift ") + self.rgb_pointmap_normalizer = self.pointmap_normalizer + + + def _normalize_pointmap( + self, pointmap: torch.Tensor, + mask: torch.Tensor, + pointmap_normalizer: Callable, + scale: Optional[torch.Tensor] = None, + shift: Optional[torch.Tensor] = None, + ): + if pointmap is None: + return pointmap, None, None + + if self.normalize_pointmap == False: + # old behavior: Pose is normalized to the pointmap center, but pointmap is not + _, pointmap_scale, pointmap_shift = pointmap_normalizer.normalize(pointmap, mask) + return pointmap, pointmap_scale, pointmap_shift + + if scale is not None or shift is not None: + return pointmap_normalizer.normalize(pointmap, mask, scale, shift) + + return pointmap_normalizer.normalize(pointmap, mask) + + def _process_image_mask_pointmap_mess( + self, rgb_image, rgb_image_mask, pointmap=None + ): + """Extended version that handles pointmaps""" + + # Apply pointmap normalization if enabled + pointmap_for_crop, pointmap_scale, pointmap_shift = self._normalize_pointmap( + pointmap, rgb_image_mask, self.pointmap_normalizer + ) + + # Apply transforms to the original full rgb image and mask. + rgb_image, rgb_image_mask = self._preprocess_rgb_image_mask(rgb_image, rgb_image_mask) + + # These two are typically used for getting cropped images of the object + # : first apply joint transforms + processed_rgb_image, processed_mask, processed_pointmap = ( + self._preprocess_image_mask_pointmap(rgb_image, rgb_image_mask, pointmap_for_crop) + ) + # : then apply individual transforms on top of the joint transforms + processed_rgb_image = self._apply_transform( + processed_rgb_image, self.img_transform + ) + processed_mask = self._apply_transform(processed_mask, self.mask_transform) + if processed_pointmap is not None: + processed_pointmap = self._apply_transform( + processed_pointmap, self.pointmap_transform + ) + + # This version is typically the full version of the image + # : apply individual transforms only + rgb_image = self._apply_transform(rgb_image, self.img_transform) + rgb_image_mask = self._apply_transform(rgb_image_mask, self.mask_transform) + + rgb_pointmap, rgb_pointmap_scale, rgb_pointmap_shift = self._normalize_pointmap( + pointmap, rgb_image_mask, self.rgb_pointmap_normalizer, pointmap_scale, pointmap_shift + ) + + if rgb_pointmap is not None: + rgb_pointmap = self._apply_transform(rgb_pointmap, self.pointmap_transform) + + result = { + "mask": processed_mask, + "image": processed_rgb_image, + "rgb_image": rgb_image, + "rgb_image_mask": rgb_image_mask, + } + + # Add pointmap results if available + if processed_pointmap is not None: + result.update( + { + "pointmap": processed_pointmap, + "rgb_pointmap": rgb_pointmap, + } + ) + + # Add normalization parameters if normalization was applied + if pointmap_scale is not None and pointmap_shift is not None: + result.update( + { + "pointmap_scale": pointmap_scale, + "pointmap_shift": pointmap_shift, + "rgb_pointmap_scale": rgb_pointmap_scale, + "rgb_pointmap_shift": rgb_pointmap_shift, + } + ) + + return result + + def _process_image_and_mask_mess(self, rgb_image, rgb_image_mask): + """Original method - calls extended version without pointmap""" + return self._process_image_mask_pointmap_mess(rgb_image, rgb_image_mask, None) + + def _preprocess_rgb_image_mask(self, rgb_image: torch.Tensor, rgb_image_mask: torch.Tensor): + """Apply joint transforms to rgb_image and rgb_image_mask.""" + if ( + self.rgb_img_mask_joint_transform != (None,) + and self.rgb_img_mask_joint_transform is not None + ): + for trans in self.rgb_img_mask_joint_transform: + rgb_image, rgb_image_mask = trans(rgb_image, rgb_image_mask) + return rgb_image, rgb_image_mask + + def _preprocess_image_mask_pointmap(self, rgb_image, mask_image, pointmap=None): + """Apply joint transforms with priority: triple transforms > dual transforms.""" + # Priority: img_mask_pointmap_joint_transform when pointmap is provided + if ( + self.img_mask_pointmap_joint_transform != (None,) + and self.img_mask_pointmap_joint_transform is not None + and pointmap is not None + ): + for trans in self.img_mask_pointmap_joint_transform: + rgb_image, mask_image, pointmap = trans( + rgb_image, mask_image, pointmap=pointmap + ) + return rgb_image, mask_image, pointmap + + # Fallback: img_mask_joint_transform (existing behavior) + elif ( + self.img_mask_joint_transform != (None,) + and self.img_mask_joint_transform is not None + ): + for trans in self.img_mask_joint_transform: + rgb_image, mask_image = trans(rgb_image, mask_image) + return rgb_image, mask_image, pointmap + + return rgb_image, mask_image, pointmap + + def _preprocess_image_and_mask(self, rgb_image, mask_image): + """Backward compatibility wrapper - only applies dual transforms""" + rgb_image, mask_image, _ = self._preprocess_image_mask_pointmap( + rgb_image, mask_image, None + ) + return rgb_image, mask_image + + # keep here for backward compatibility + def _preprocess_image_and_mask_inference(self, rgb_image, mask_image): + warnings.warn( + "The _preprocess_image_and_mask_inference is deprecated! Please use _preprocess_image_and_mask", + category=DeprecationWarning, + stacklevel=2, + ) + return self._preprocess_image_and_mask(rgb_image, mask_image) + + def _apply_transform(self, input: torch.Tensor, transform): + if input is not None and transform is not None and transform != (None,): + input = transform(input) + + return input \ No newline at end of file diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/data/dataset/tdfy/transforms_3d.py b/thirdparty/sam3d/sam3d/sam3d_objects/data/dataset/tdfy/transforms_3d.py new file mode 100644 index 0000000000000000000000000000000000000000..e9f417c2be8ec6259ccf2f403adb3bad43c54c87 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/data/dataset/tdfy/transforms_3d.py @@ -0,0 +1,50 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +from collections import namedtuple +import math +import torch + +from pytorch3d.transforms import ( + Rotate, + Translate, + Scale, + Transform3d, + quaternion_to_matrix, + axis_angle_to_quaternion, +) + +DecomposedTransform = namedtuple( + "DecomposedTransform", ["scale", "rotation", "translation"] +) + + +def compose_transform( + scale: torch.Tensor, rotation: torch.Tensor, translation: torch.Tensor +) -> Transform3d: + """ + Args: + scale: (..., 3) tensor of scale factors + rotation: (..., 3, 3) tensor of rotation matrices + translation: (..., 3) tensor of translation vectors + """ + tfm = Transform3d(dtype=scale.dtype, device=scale.device) + return tfm.scale(scale).rotate(rotation).translate(translation) + + +def decompose_transform(transform: Transform3d) -> DecomposedTransform: + """ + Returns: + scale: (..., 3) tensor of scale factors + rotation: (..., 3, 3) tensor of rotation matrices + translation: (..., 3) tensor of translation vectors + """ + matrices = transform.get_matrix() + scale = torch.norm(matrices[:, :3, :3], dim=-1) + rotation = matrices[:, :3, :3] / scale.unsqueeze(-1) # Normalize rotation matrix + translation = matrices[:, 3, :3] # Extract translation vector + return DecomposedTransform(scale, rotation, translation) + + +def get_rotation_about_x_axis(angle: float = math.pi / 2) -> torch.Tensor: + axis = torch.tensor([1.0, 0.0, 0.0]) + axis_angle = axis * angle + return axis_angle_to_quaternion(axis_angle) diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/data/utils.py b/thirdparty/sam3d/sam3d/sam3d_objects/data/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..aaa096923b1f6721099aa5f754123a5f5d13678b --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/data/utils.py @@ -0,0 +1,243 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +from typing import Any, Iterable, Tuple, Union, Dict, Sequence, Mapping, Container +import optree +import torch +from collections.abc import Iterable +import inspect +import ast +import astor +from torch.utils import _pytree + +# None = root, Iterable[Any] = path, Any = path of one +ChildPathType = Union[None, Iterable[Any], Any] +ArgsType = Iterable[ChildPathType] +KwargsType = Mapping[str, ChildPathType] +ArgsKwargsType = Tuple[ArgsType, KwargsType] +MappingType = Union[None, ArgsKwargsType, ArgsType, KwargsType] + + +def tree_transpose_level_one( + structure, + check_children=False, + map_fn=None, + is_leaf=None, +): + _, outer_spec = optree.tree_flatten( + structure, + is_leaf=lambda x: x is not structure, + none_is_leaf=True, + ) + + spec = optree.tree_structure(structure, none_is_leaf=True, is_leaf=is_leaf) + children_spec = spec.children() + if len(children_spec) > 0: + inner_spec = children_spec[0] + if check_children: + for child_spec in children_spec[1:]: + assert ( + inner_spec == child_spec + ), f"one child was found having a different tree structure ({inner_spec} != {child_spec})" + + structure = optree.tree_transpose(outer_spec, inner_spec, structure) + + if map_fn is not None: + structure = optree.tree_map( + map_fn, + structure, + is_leaf=lambda x: optree.tree_structure( + x, is_leaf=is_leaf, none_is_leaf=True + ) + == outer_spec, + none_is_leaf=True, + ) + + return structure + + +@staticmethod +def tree_tensor_map(fn, tree, *rest): + return optree.tree_map( + fn, + tree, + *rest, + is_leaf=lambda x: isinstance(x, torch.Tensor), + none_is_leaf=False, + ) + + +def to_device(obj, device): + """Recursively moves all tensors in obj to the specified device. + + Args: + obj: Object to move to device - can be a tensor, list, tuple, dict or any nested combination + device: Target device (e.g. 'cuda', 'cpu', torch.device('cuda:0') etc.) + + Returns: + Same object structure with all contained tensors moved to specified device + """ + to_fn = lambda x: x.to(device) + return optree.tree_map(to_fn, obj, is_leaf=torch.is_tensor, none_is_leaf=False) + + +def expand_right(tensor, target_shape): + """ + e.g. Takes tensor of (a, b, c) and returns a tensor of (a, b, c, 1, 1, ...) + """ + current_shape = tensor.shape + dims_to_add = len(target_shape) - len(current_shape) + result = tensor + for _ in range(dims_to_add): + result = result.unsqueeze(-1) + expand_shape = list(current_shape) + [-1] * dims_to_add + for i in range(len(target_shape)): + if i < len(expand_shape) and expand_shape[i] == -1: + expand_shape[i] = target_shape[i] + return result.expand(*expand_shape) + + +def expand_as_right(tensor, target): + return expand_right(tensor, target.shape) + + +def as_keys(path: ChildPathType): + if isinstance(path, Iterable) and (not isinstance(path, str)): + return tuple(path) + elif path is None: + return () + return (path,) + + +def get_child(obj: Any, *keys: Iterable[Any]): + for key in keys: + obj = obj[key] + return obj + + +def set_child(obj: Any, value: Any, *keys: Iterable[Any]): + parent = None + for key in keys: + parent = obj + obj = obj[key] + if parent is None: + obj = value + else: + parent[key] = value + return obj + + +def build_args_batch_extractor(args_mapping: ArgsType): + def extract_fn(batch): + return tuple(get_child(batch, *as_keys(path)) for path in args_mapping) + + return extract_fn + + +def build_kwargs_batch_extractor(kwargs_mapping: KwargsType): + def extract_fn(batch): + return { + name: get_child(batch, *as_keys(path)) + for name, path in kwargs_mapping.items() + } + + return extract_fn + + +empty_mapping = object() +kwargs_identity_mapping = object() + + +def build_batch_extractor(mapping: MappingType): + extract_args_fn = lambda x: () + extract_kwargs_fn = lambda x: {} + + if mapping is None: + + def extract_args_fn(batch): + return (batch,) + + elif mapping is empty_mapping: + pass + elif mapping is kwargs_identity_mapping: + extract_kwargs_fn = lambda x: x + elif isinstance(mapping, Sequence) and (not isinstance(mapping, str)): + if ( + len(mapping) == 2 + and isinstance(mapping[0], Sequence) + and isinstance(mapping[1], Dict) + ): + extract_args_fn = build_args_batch_extractor(mapping[0]) + extract_kwargs_fn = build_kwargs_batch_extractor(mapping[1]) + else: + extract_args_fn = build_args_batch_extractor(mapping) + elif isinstance(mapping, Mapping): + extract_kwargs_fn = build_kwargs_batch_extractor(mapping) + else: + + def extract_args_fn(batch): + return (get_child(batch, *as_keys(mapping)),) + + def extract_fn(batch): + return extract_args_fn(batch), extract_kwargs_fn(batch) + + return extract_fn + + +# > + + +def right_broadcasting(arr, target): + return arr.reshape(arr.shape + (1,) * (target.ndim - arr.ndim)) + + +def get_stats(tensor: torch.Tensor): + float_tensor = tensor.float() + return { + "shape": tuple(tensor.shape), + "min": tensor.min().item(), + "max": tensor.max().item(), + "mean": float_tensor.mean().item(), + "median": tensor.median().item(), + "std": float_tensor.std().item(), + } + + +def _get_caller_arg_name(argnum=0, parent_frame=1): + try: + frame = inspect.currentframe() # current frame + frame = inspect.getouterframes(frame)[1 + parent_frame] # parent frame + code = inspect.getframeinfo(frame[0]).code_context[0].strip() # get code line + + tree = ast.parse(code) + + for node in ast.walk(tree): + if isinstance(node, ast.Call): + args = node.args + break # only get the first parent call + + # get first argument string (do not handle '=') + label = astor.to_source(args[argnum]).strip() + except: + # TODO(Pierre) log exception + label = "{label}" + return label + + +def print_stats(tensor, label=None): + if label is None: + label = _get_caller_arg_name(argnum=0) + stats = get_stats(tensor) + string = f"{label}:\n" + "\n".join(f"- {k}: {v}" for k, v in stats.items()) + print(string) + + +def tree_reduce_unique(fn, tree, ensure_unique=True, **kwargs): + values = _pytree.tree_flatten(tree, **kwargs)[0] + values = tuple(map(fn, values)) + first = values[0] + if ensure_unique: + for value in values[1:]: + if value != first: + raise RuntimeError( + f"different values found, {value} and {first} should be the same" + ) + return first diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/__init__.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..71ca4b12c770afea62d06f97064cdf0c97d40ed7 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/__init__.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..71ca4b12c770afea62d06f97064cdf0c97d40ed7 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/dit/__init__.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/dit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..71ca4b12c770afea62d06f97064cdf0c97d40ed7 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/dit/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/dit/embedder/__init__.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/dit/embedder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..71ca4b12c770afea62d06f97064cdf0c97d40ed7 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/dit/embedder/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/dit/embedder/dino.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/dit/embedder/dino.py new file mode 100644 index 0000000000000000000000000000000000000000..c489fe7f9a67ddd1eaa484a0fdbffb1b75c63c06 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/dit/embedder/dino.py @@ -0,0 +1,142 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import torch +from typing import Optional, Dict, Any +import warnings +from torchvision.transforms import Normalize +import torch.nn.functional as F +from loguru import logger + + +class Dino(torch.nn.Module): + def __init__( + self, + input_size: int = 224, + repo_or_dir: str = "facebookresearch/dinov2", + dino_model: str = "dinov2_vitb14", + source: str = "github", + backbone_kwargs: Optional[Dict[str, Any]] = None, + normalize_images: bool = True, + # for backward compatible + prenorm_features: bool = False, + freeze_backbone: bool = True, + prune_network: bool = False, # False for backward compatible + ): + super().__init__() + if backbone_kwargs is None: + backbone_kwargs = {} + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + + logger.info(f"Loading DINO model: {dino_model} from {repo_or_dir} (source: {source})") + if backbone_kwargs: + logger.info(f"DINO backbone kwargs: {backbone_kwargs}") + + self.backbone = torch.hub.load( + repo_or_dir=repo_or_dir, + model=dino_model, + source=source, + verbose=False, + **backbone_kwargs, + ) + + # Log model properties after loading + logger.info(f"Loaded DINO model - type: {type(self.backbone)}, " + f"embed_dim: {self.backbone.embed_dim}, " + f"patch_size: {getattr(self.backbone.patch_embed, 'patch_size', 'N/A')}") + + + self.resize_input_size = (input_size, input_size) + self.embed_dim = self.backbone.embed_dim + self.input_size = input_size + self.input_channels = 3 + self.normalize_images = normalize_images + self.prenorm_features = prenorm_features + self.register_buffer('mean', torch.as_tensor([[0.485, 0.456, 0.406]]).view(-1, 1, 1), persistent=False) + self.register_buffer('std', torch.as_tensor([[0.229, 0.224, 0.225]]).view(-1, 1, 1), persistent=False) + + # freeze + if freeze_backbone: + self.requires_grad_(False) + self.eval() + elif not prune_network: + logger.warning( + "Unfreeze encoder w/o prune parameter may lead to error in ddp/fp16 training" + ) + + if prune_network: + self._prune_network() + + def _preprocess_input(self, x): + _resized_images = torch.nn.functional.interpolate( + x, + size=self.resize_input_size, + mode="bilinear", + align_corners=False, + ) + + if x.shape[1] == 1: + _resized_images = _resized_images.repeat(1, 3, 1, 1) + + if self.normalize_images: + _resized_images = _resized_images.sub_(self.mean).div_(self.std) + + return _resized_images + + def _forward_intermediate_layers( + self, input_img, intermediate_layers, cls_token=True + ): + return self.backbone.get_intermediate_layers( + input_img, + intermediate_layers, + return_class_token=cls_token, + ) + + def _forward_last_layer(self, input_img): + output = self.backbone.forward_features(input_img) + if self.prenorm_features: + features = output["x_prenorm"] + tokens = F.layer_norm(features, features.shape[-1:]) + else: + tokens = torch.cat( + [ + output["x_norm_clstoken"].unsqueeze(1), + output["x_norm_patchtokens"], + ], + dim=1, + ) + return tokens + + def forward(self, x, **kwargs): + _resized_images = self._preprocess_input(x) + tokens = self._forward_last_layer(_resized_images) + return tokens.to(x.dtype) + + def _prune_network(self): + """ + Ran this script: + out = model(input) + loss = out.sum() + loss.backward() + + for name, p in dino_model.named_parameters(): + if p.grad is None: + print(name) + model.zero_grad() + """ + self.backbone.mask_token = None + if self.prenorm_features: + self.backbone.norm = torch.nn.Identity() + + +class DinoForMasks(torch.nn.Module): + def __init__( + self, + backbone: Dino, + ): + super().__init__() + self.backbone = backbone + self.embed_dim = self.backbone.embed_dim + + def forward(self, image, mask): + return self.backbone.forward(mask) diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/dit/embedder/embedder_fuser.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/dit/embedder/embedder_fuser.py new file mode 100644 index 0000000000000000000000000000000000000000..a8b6130420f0716ff864e9e78fd9b164ba1e9da4 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/dit/embedder/embedder_fuser.py @@ -0,0 +1,238 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import math +import torch +from loguru import logger +from torch import nn +from typing import Optional, Tuple, List, Literal, Dict +from sam3d_objects.model.layers.llama3.ff import FeedForward +from omegaconf import OmegaConf + +class EmbedderFuser(torch.nn.Module): + """ + Fusing individual condition embedder. Require kwargs for the forward! + Args: + embedder_list: List of Tuples. Each Tuple consists of a condition_embedder + and a list of tuple. In the list, each tuple consists of a string, indicating + a kward, and astring, indicating the group of positional encoding to be used. + use_pos_embedding: whether to add positional embedding. If add, follow the index in + embedder_list. Choices of None (no pos emb), random, and learned. + projection_pre_norm: pre-normalize features before feeding into projector layers. + projection_net_hidden_dim_multiplier: hidden dimension for projection layer. If 0, don't use. + """ + + def __init__( + self, + embedder_list: List[Tuple[nn.Module, List[Tuple[str, Optional[str]]]]], + use_pos_embedding: Optional[Literal["random", "learned"]] = "learned", + projection_pre_norm: bool = True, + projection_net_hidden_dim_multiplier: float = 4.0, + compression_projection_multiplier: float = 0, + freeze: bool = False, + drop_modalities_weight: Dict[List[str], float] = None, + dropout_prob: float = 0.0, + force_drop_modalities: List[str] = None, + ): + super().__init__() + # torch.compile does not support OmegaConf.ListConfig, so we convert to a list + if not isinstance(embedder_list, List): + self.embedder_list = OmegaConf.to_container(embedder_list) + else: + self.embedder_list = embedder_list + + self.embed_dims = 0 + self.compression_projection_multiplier = compression_projection_multiplier + self.concate_embed_dims = 0 + # keep moduleList to be compatible with nn module + self.module_list = [] + max_positional_embed_idx = 0 + self.positional_embed_map = {} + for condition_embedder, kwargs_info in self.embedder_list: + self.embed_dims = max(self.embed_dims, condition_embedder.embed_dim) + self.module_list.append(condition_embedder) + for _, pos_group in kwargs_info: + self.concate_embed_dims += condition_embedder.embed_dim + if pos_group is not None: + if pos_group not in self.positional_embed_map: + self.positional_embed_map[pos_group] = max_positional_embed_idx + max_positional_embed_idx += 1 + self.module_list = nn.ModuleList(self.module_list) + self.use_pos_embedding = use_pos_embedding + if self.use_pos_embedding == "random": + idx_emb = torch.randn(max_positional_embed_idx + 1, 1, self.embed_dims) + self.register_buffer("idx_emb", idx_emb) + elif self.use_pos_embedding == "learned": + self.idx_emb = nn.Parameter( + torch.empty(max_positional_embed_idx + 1, self.embed_dims) + ) + nn.init.normal_( + self.idx_emb, mean=0.0, std=1.0 / math.sqrt(self.embed_dims) + ) + else: + raise NotImplementedError(f"Unknown pos embedding {self.use_pos_embedding}") + + self.projection_pre_norm = projection_pre_norm + self.projection_net_hidden_dim_multiplier = projection_net_hidden_dim_multiplier + if projection_net_hidden_dim_multiplier > 0: + self.projection_nets = [] + for condition_embedder, _ in self.embedder_list: + self.projection_nets.append( + self._make_projection_net( + condition_embedder.embed_dim, + self.embed_dims, + self.projection_net_hidden_dim_multiplier, + ) + ) + self.projection_nets = nn.ModuleList(self.projection_nets) + + if compression_projection_multiplier > 0: + self.compression_projector = self._make_projection_net( + self.concate_embed_dims, + self.embed_dims, + self.compression_projection_multiplier, + ) + + self.drop_modalities_weight = drop_modalities_weight if drop_modalities_weight is not None else [] + self.dropout_prob = dropout_prob + self.force_drop_modalities = force_drop_modalities + + if freeze: + self.requires_grad_(False) + self.eval() + + def _make_projection_net( + self, + input_embed_dim, + output_embed_dim: int, + multiplier: int, + ): + if self.projection_pre_norm: + pre_norm = nn.LayerNorm(input_embed_dim) + else: + pre_norm = nn.Identity() + + # Per-token projection + gated activation + ff_net = FeedForward( + dim=input_embed_dim, + hidden_dim=int(multiplier * output_embed_dim), + output_dim=output_embed_dim, + ) + return nn.Sequential(pre_norm, ff_net) + + def _build_dropout_distribution(self, device): + """ + Build the probability distribution for dropout configurations. + + Returns: + dropout_configs: List of sets containing modalities to drop + cumsum_weights: Cumulative sum of weights for sampling + """ + dropout_configs = [] + weights = [] + + # Add no-dropout configuration with remaining probability + dropout_configs.append(set()) + weights.append(1.0 - self.dropout_prob) + + # Add configured dropout patterns + total_dropout_weight = sum(w for _, w in self.drop_modalities_weight) + assert total_dropout_weight > 0, "Total dropout weight must be positive when drop_modalities_weight is provided" + for modality_list, weight in self.drop_modalities_weight: + dropout_configs.append(set(modality_list)) + # Scale weight by dropout_prob to ensure total probability sums to 1 + weights.append(self.dropout_prob * weight / total_dropout_weight) + + # Convert weights to cumulative distribution + weights_tensor = torch.tensor(weights, device=device) + + was_deterministic = torch.are_deterministic_algorithms_enabled() + torch.use_deterministic_algorithms(False) + cumsum_weights = torch.cumsum(weights_tensor, dim=0) + torch.use_deterministic_algorithms(was_deterministic) + + return dropout_configs, cumsum_weights + + def _apply_force_drop(self, kwarg_names: List[str], tokens: List[torch.Tensor]): + if not self.force_drop_modalities: + return tokens + + force_drop_set = set(self.force_drop_modalities) + result_tokens = [] + + for kwarg_name, token_tensor in zip(kwarg_names, tokens): + # Create mask: 0 for forced drop, 1 otherwise + mask = 0.0 if kwarg_name in force_drop_set else 1.0 + result_tokens.append(token_tensor * mask) + + return result_tokens + + def _dropout_modalities(self, kwarg_names: List[str], tokens: List[torch.Tensor]): + # First apply forced drops (deterministic, always applied) + tokens = self._apply_force_drop(kwarg_names, tokens) + + # Then apply probabilistic dropout (only in training) + if not self.training or self.dropout_prob <= 0 or not self.drop_modalities_weight: + return tokens + + batch_size = tokens[0].shape[0] + device = tokens[0].device + + # Build dropout configurations and sample which to use per batch element + dropout_configs, cumsum_weights = self._build_dropout_distribution(device) + rand_vals = torch.rand(batch_size, device=device) + # Clamp indices to valid range (handle edge case where rand_val == 1.0) + config_indices = torch.searchsorted(cumsum_weights, rand_vals).clamp(max=len(dropout_configs) - 1) + + # Apply dropout masks with vectorized operations + result_tokens = [] + for kwarg_name, token_tensor in zip(kwarg_names, tokens): + # Start with all ones (no dropout) + mask = torch.ones(batch_size, dtype=token_tensor.dtype, device=device) + + # Vectorized mask creation: check all configurations at once + for config_idx, modalities_to_drop in enumerate(dropout_configs): + if kwarg_name in modalities_to_drop: + # Set mask to 0 for all batch elements using this configuration + mask[config_indices == config_idx] = 0.0 + + # Reshape mask to match token dimensions + mask = mask.view([batch_size] + [1] * (token_tensor.ndim - 1)) + result_tokens.append(token_tensor * mask) + + return result_tokens + + def forward(self, *args, **kwargs): + tokens = [] + kwarg_names = [] + + for i, (condition_embedder, kwargs_info) in enumerate(self.embedder_list): + # Ideally, we would batch the inputs; but that assumes same-sized inputs + for kwarg_name, pos_group in kwargs_info: + if kwarg_name not in kwargs: + logger.warning(f"{kwarg_name} not in kwargs to condition embedder!") + input_cond = kwargs[kwarg_name] + cond_token = condition_embedder(input_cond) + if self.projection_net_hidden_dim_multiplier > 0: + cond_token = self.projection_nets[i](cond_token) + if pos_group is not None: + pos_idx = self.positional_embed_map[pos_group] + if self.use_pos_embedding == "random": + cond_token += self.idx_emb[pos_idx : pos_idx + 1] + elif self.use_pos_embedding == "learned": + cond_token += self.idx_emb[pos_idx : pos_idx + 1, None] + else: + raise NotImplementedError( + f"Unknown pos embedding {self.use_pos_embedding}" + ) + tokens.append(cond_token) + kwarg_names.append(kwarg_name) + + # Apply dropout modalities with preserved order + tokens = self._dropout_modalities(kwarg_names, tokens) + + if self.compression_projection_multiplier > 0: + tokens = torch.cat(tokens, dim=-1) + tokens = self.compression_projector(tokens) + else: + tokens = torch.cat(tokens, dim=1) + + return tokens diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/dit/embedder/point_remapper.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/dit/embedder/point_remapper.py new file mode 100644 index 0000000000000000000000000000000000000000..f971de1afc47eb3d726dff870c78a9809c17540c --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/dit/embedder/point_remapper.py @@ -0,0 +1,78 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import torch +import torch.nn as nn + + +class PointRemapper(nn.Module): + """Handles remapping of 3D point coordinates and their inverse transformations.""" + + VALID_TYPES = ["linear", "sinh", "exp", "sinh_exp", "exp_disparity"] + + def __init__(self, remap_type: str = "exp"): + super().__init__() + self.remap_type = remap_type + + if remap_type not in self.VALID_TYPES: + raise ValueError( + f"Invalid remap type: {remap_type}. Must be one of {self.VALID_TYPES}" + ) + + def forward(self, points: torch.Tensor) -> torch.Tensor: + """Apply remapping to point coordinates.""" + if self.remap_type == "linear": + return points + + elif self.remap_type == "sinh": + return torch.asinh(points) + + elif self.remap_type == "exp": + xy_scaled, z_exp = points.split([2, 1], dim=-1) + # Use log1p for better numerical stability near zero + z = torch.log1p(z_exp) + xy = xy_scaled / (1 + z_exp) + return torch.cat([xy, z], dim=-1) + + elif self.remap_type == "exp_disparity": + xy_scaled, z_exp = points.split([2, 1], dim=-1) + xy = xy_scaled / z_exp + z = torch.log(z_exp) + return torch.cat([xy, z], dim=-1) + + elif self.remap_type == "sinh_exp": + xy_sinh, z_exp = points.split([2, 1], dim=-1) + xy = torch.asinh(xy_sinh) + z = torch.log(z_exp.clamp(min=1e-8)) + return torch.cat([xy, z], dim=-1) + + else: + raise ValueError(f"Unknown remap type: {self.remap_type}") + + def inverse(self, points: torch.Tensor) -> torch.Tensor: + """Apply inverse remapping to recover original point coordinates.""" + if self.remap_type == "linear": + return points + + elif self.remap_type == "sinh": + return torch.sinh(points) + + elif self.remap_type == "exp": + xy, z = points.split([2, 1], dim=-1) + # Inverse of log1p is expm1(z) = exp(z) - 1 + z_exp = torch.expm1(z) + # Inverse of xy/(1+z_exp) is xy*(1+z_exp) + return torch.cat([xy * (1 + z_exp), z_exp], dim=-1) + + elif self.remap_type == "exp_disparity": + xy, z = points.split([2, 1], dim=-1) + z_exp = torch.exp(z) + return torch.cat([xy * z_exp, z_exp], dim=-1) + + elif self.remap_type == "sinh_exp": + xy, z = points.split([2, 1], dim=-1) + return torch.cat([torch.sinh(xy), torch.exp(z)], dim=-1) + + else: + raise ValueError(f"Unknown remap type: {self.remap_type}") + + def extra_repr(self) -> str: + return f"remap_type='{self.remap_type}'" diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/dit/embedder/pointmap.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/dit/embedder/pointmap.py new file mode 100644 index 0000000000000000000000000000000000000000..62328b5d9c708ba58d48993eb96eeeac804dc32c --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/dit/embedder/pointmap.py @@ -0,0 +1,238 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +from timm.models.vision_transformer import Block +import torch +from torch import nn +import torch.nn.functional as F +from functools import partial +from loguru import logger + +from .point_remapper import PointRemapper + + +class PointPatchEmbed(nn.Module): + """ + Projects (x,y,z) → D + Splits into patches (patch_size x patch_size) + Runs a tiny self-attention block inside each window + Returns one token per window. + """ + + def __init__( + self, + input_size: int = 256, + patch_size: int = 8, + embed_dim: int = 768, + remap_output: str = "exp", # Add remap_output parameter + dropout_prob: float = 0.0, # Dropout probability for pointmap + force_dropout_always: bool = False, # Force dropout during validation/inference + ): + super().__init__() + self.input_size = input_size + self.patch_size = patch_size + self.embed_dim = embed_dim + self.dropout_prob = dropout_prob + self.force_dropout_always = force_dropout_always + + # Point remapper + self.point_remapper = PointRemapper(remap_output) + + # (1) point embedding + self.point_proj = nn.Linear(3, embed_dim) + self.invalid_xyz_token = nn.Parameter(torch.zeros(embed_dim)) + + # Special embedding for dropped patches (used during dropout) + # Alternative dropout strategies to consider: + # 1. Drop all tokens entirely or use a single token only + # 2. Different dropout patterns per window + # 3. Use dropped_xyz_token/invalid_xyz_token per pixel + if dropout_prob > 0: + self.dropped_xyz_token = nn.Parameter(torch.zeros(embed_dim)) + + # (2) positional embedding + num_patches = input_size // patch_size + # For patches + self.pos_embed = nn.Parameter( + torch.zeros(1, embed_dim, num_patches, num_patches) + ) + # For points in a patch + self.pos_embed_window = nn.Parameter( + torch.zeros(1, 1 + patch_size * patch_size, embed_dim) + ) + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + + # (3) within-patch transformer block(s) + # From MCC: https://github.com/facebookresearch/MCC/blob/b04c97518360e4fdedfb6f090db7e90d0c2f8ae6/mcc_model.py#L97 + self.blocks = nn.ModuleList( + [ + Block( + embed_dim, + num_heads=16, + mlp_ratio=2.0, + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + ) + ] + ) + self.initialize_weights() + + def initialize_weights(self): + # Initialize positional embeddings with small std + nn.init.normal_(self.pos_embed, std=0.02) + nn.init.normal_(self.pos_embed_window, std=0.02) + nn.init.normal_(self.cls_token, std=0.02) + nn.init.normal_(self.invalid_xyz_token, std=0.02) + + # Initialize dropped pointmap token if dropout is enabled + if self.dropout_prob > 0: + nn.init.normal_(self.dropped_xyz_token, std=0.02) + + # Initialize point projection with xavier uniform for better gradient flow + # This is crucial since pointmaps can have large value ranges + nn.init.xavier_uniform_(self.point_proj.weight, gain=0.02) + if self.point_proj.bias is not None: + nn.init.constant_(self.point_proj.bias, 0) + + def _get_pos_embed(self, hw): + h, w = hw + pos_embed = F.interpolate( + self.pos_embed, size=(h, w), mode="bilinear", align_corners=False + ) + pos_embed = pos_embed.permute(0, 2, 3, 1) # (B, H, W, C) + return pos_embed + + def resize_input(self, xyz: torch.Tensor) -> torch.Tensor: + resized_xyz = F.interpolate(xyz, size=self.input_size, mode="nearest") + resized_xyz = resized_xyz.permute(0, 2, 3, 1) # (B, H, W, C) + return resized_xyz + + def apply_pointmap_dropout(self, embeddings: torch.Tensor) -> torch.Tensor: + """ + Apply dropout to pointmap embeddings. + Drops entire pointmap for selected samples during training or when forced. + + When force_dropout_always is True, always drops pointmap regardless of training mode. + """ + # Check if we should apply dropout + should_apply_dropout = (self.training or self.force_dropout_always) and self.dropout_prob > 0 + + if not should_apply_dropout: + return embeddings + + # Check if dropout infrastructure exists + if not hasattr(self, 'dropped_xyz_token'): + if self.force_dropout_always: + raise RuntimeError( + "Cannot force dropout: model was initialized with dropout_prob=0. " + "Re-initialize with dropout_prob > 0 to enable forced dropout." + ) + return embeddings + + batch_size, n_windows, embed_dim = embeddings.shape + + # Decide dropout behavior + if self.force_dropout_always and not self.training: + # When forced during inference, always drop (100% dropout) + drop_mask = torch.ones(batch_size, device=embeddings.device, dtype=torch.bool) + else: + # Normal training dropout - use configured probability + drop_mask = torch.rand(batch_size, device=embeddings.device) < self.dropout_prob + + # Create dropped embedding for all windows - use same token for all patches + # Shape: (batch_size, n_windows, embed_dim) + dropped_embedding = self.dropped_xyz_token.view(1, 1, embed_dim).expand(batch_size, n_windows, embed_dim) + + # Add positional embeddings to dropped tokens (same as regular embeddings get) + n_windows_h = n_windows_w = int(n_windows ** 0.5) + pos_embed_patch = self._get_pos_embed((n_windows_h, n_windows_w)).reshape( + 1, n_windows, embed_dim + ) + dropped_embedding = dropped_embedding + pos_embed_patch + drop_mask_expanded = drop_mask.view(batch_size, 1, 1).expand_as(embeddings) + embeddings = torch.where(drop_mask_expanded, dropped_embedding, embeddings) + + return embeddings + + @torch._dynamo.disable() + def embed_pointmap_windows( + self, xyz: torch.Tensor, valid_mask: torch.Tensor = None + ) -> torch.Tensor: + """Process pointmap into window embeddings without positional encoding""" + with torch.no_grad(): + xyz = self.resize_input(xyz) + if valid_mask is None: + valid_mask = xyz.isfinite().all(dim=-1) + + B, H, W, _ = xyz.shape + assert ( + H % self.patch_size == 0 and W % self.patch_size == 0 + ), "image must be divisible by patch_size" + + # (1) Handle NaN values before remapping to prevent propagation + xyz_safe = xyz.clone() + xyz_safe[~valid_mask] = 0.0 # Set invalid points to 0 before remapping + + # (1b) remap points to normalize their range + xyz_remapped = self.point_remapper(xyz_safe) + + # (2) project + invalid token + x = self.point_proj(xyz_remapped) # (B,H,W,D) + + x[~valid_mask] = 0.0 # Stop gradient for invalid points + x[~valid_mask] += self.invalid_xyz_token + + return x, B, H, W + + def inner_forward( + self, x: torch.Tensor, B: int, H: int, W: int + ) -> torch.Tensor: + x = x.view( + B, + H // self.patch_size, + self.patch_size, + W // self.patch_size, + self.patch_size, + self.embed_dim, + ) # (B, hW, wW, ws, ws, D) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous() # (B, hW, wW, ws, ws, D) + x = x.view(-1, self.patch_size * self.patch_size, self.embed_dim) + + # (4) CLS token that contains the patch information + cls_tok = self.cls_token.expand(x.shape[0], -1, -1) + toks = torch.cat([cls_tok, x], dim=1) + + # (5) add positional embedding for window + toks = toks + self.pos_embed_window + + # (6) intra-window attention + for blk in self.blocks: + toks = blk(toks) + + # (7) Extract CLS tokens and reshape to (B, n_windows, embed_dim) + n_windows_h = H // self.patch_size + n_windows_w = W // self.patch_size + window_embeddings = toks[:, 0].view(B, n_windows_h * n_windows_w, self.embed_dim) + + # Add positional embeddings + pos_embed_patch = self._get_pos_embed((n_windows_h, n_windows_w)).reshape( + 1, n_windows_h * n_windows_w, self.embed_dim + ) + out = window_embeddings + pos_embed_patch + + # Apply dropout if enabled (during training OR when forced) + if (self.training or self.force_dropout_always) and self.dropout_prob > 0: + out = self.apply_pointmap_dropout(out) + + return out + + def forward( + self, xyz: torch.Tensor, valid_mask: torch.Tensor = None + ) -> torch.Tensor: + """ + xyz : (B, 3, H, W) map of (x,y,z) coordinates + valid_mask : (B, H, W) boolean - True for valid points (optional) + + returns: (B, num_windows, D) + """ + # Get window embeddings + x, B, H, W = self.embed_pointmap_windows(xyz, valid_mask) + return self.inner_forward(x, B, H, W) diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/generator/__init__.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/generator/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..71ca4b12c770afea62d06f97064cdf0c97d40ed7 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/generator/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/generator/base.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/generator/base.py new file mode 100644 index 0000000000000000000000000000000000000000..ee7f99c46c64718e35a04a8db2ad6db5c45ab085 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/generator/base.py @@ -0,0 +1,65 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import torch +from typing import Optional, Union + + +class Base(torch.nn.Module): + def __init__(self, seed_or_generator: Optional[Union[int, torch.Generator]] = None): + super().__init__() + + if isinstance(seed_or_generator, torch.Generator): + self.random_generator = seed_or_generator + elif isinstance(seed_or_generator, int): + self.seed = seed_or_generator + elif seed_or_generator is None: + self.random_generator = torch.default_generator + else: + raise RuntimeError( + f"cannot use argument of type {type(seed_or_generator)} to set random generator" + ) + + @property + def seed(self): + raise AttributeError(f"Cannot read attribute 'seed'.") + + @seed.setter + def seed(self, value: int): + self._random_generator = torch.Generator().manual_seed(value) + + @property + def random_generator(self): + return self._random_generator + + @random_generator.setter + def random_generator(self, generator: torch.Generator): + self._random_generator = generator + + def forward(self, x_shape, x_device, *args_conditionals, **kwargs_conditionals): + return self.generate( + x_shape, + x_device, + *args_conditionals, + **kwargs_conditionals, + ) + + def generate(self, x_shape, x_device, *args_conditionals, **kwargs_conditionals): + for _, xt, _ in self.generate_iter( + x_shape, + x_device, + *args_conditionals, + **kwargs_conditionals, + ): + pass + return xt + + def generate_iter( + self, + x_shape, + x_device, + *args_conditionals, + **kwargs_conditionals, + ): + raise NotImplementedError + + def loss(self, x, *args_conditionals, **kwargs_conditionals): + raise NotImplementedError diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/generator/classifier_free_guidance.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/generator/classifier_free_guidance.py new file mode 100644 index 0000000000000000000000000000000000000000..804ae18046ecbf3721e0f45ef15788026f220e34 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/generator/classifier_free_guidance.py @@ -0,0 +1,259 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +from functools import partial +from numbers import Number +import torch +import random +from torch.utils import _pytree +from torch.utils._pytree import tree_map_only +from loguru import logger + +def _zeros_like(struct): + def make_zeros(x): + if isinstance(x, torch.Tensor): + return torch.zeros_like(x) + return x + + return _pytree.tree_map(make_zeros, struct) + + +def zero_out(args, kwargs): + args = _zeros_like(args) + kwargs = _zeros_like(kwargs) + return args, kwargs + + +def discard(args, kwargs): + return (), {} + + +def _drop_tensors(struct): + """ + Drop any conditioning that are tensors + Not using _pytree since we actually want to throw them instead of keeping them. + """ + if isinstance(struct, dict): + return { + k: _drop_tensors(v) + for k, v in struct.items() + if not isinstance(v, torch.Tensor) + } + elif isinstance(struct, (list, tuple)): + filtered = [_drop_tensors(x) for x in struct if not isinstance(x, torch.Tensor)] + return tuple(filtered) if isinstance(struct, tuple) else filtered + else: + return struct + + +def drop_tensors(args, kwargs): + args = _drop_tensors(args) + kwargs = _drop_tensors(kwargs) + return args, kwargs + + +def add_flag(args, kwargs): + kwargs["cfg"] = True + return args, kwargs + + +class ClassifierFreeGuidance(torch.nn.Module): + UNCONDITIONAL_HANDLING_TYPES = { + "zeros": zero_out, + "discard": discard, + "drop_tensors": drop_tensors, + "add_flag": add_flag, + } + + def __init__( + self, + backbone, # backbone should be a backbone/generator (e.g. DDPM/DDIM/FlowMatching) + p_unconditional=0.1, + strength=3.0, + # "zeros" = set cond tensors to 0, + # "discard" = remove cond arguments and let underlying model handle it + # "drop_tensors" = drop all tensors but leave non-tensors + # "add_flag" = add an argument in kwargs as "cfg" and defer the handling to generator backbone + unconditional_handling="zeros", + interval=None, # only perform cfg if t within interval + ): + super().__init__() + + if not ( + unconditional_handling + in ClassifierFreeGuidance.UNCONDITIONAL_HANDLING_TYPES + ): + raise RuntimeError( + f"'{unconditional_handling}' is not valid for `unconditional_handling`, should be in {ClassifierFreeGuidance.UNCONDITIONAL_HANDLING_TYPES}" + ) + + self.backbone = backbone + self.p_unconditional = p_unconditional + self.strength = strength + self.unconditional_handling = unconditional_handling + self.interval = interval + self._make_unconditional_args = ( + ClassifierFreeGuidance.UNCONDITIONAL_HANDLING_TYPES[ + self.unconditional_handling + ] + ) + + def _cfg_step_tensor(self, y_cond, y_uncond, strength): + return (1 + strength) * y_cond - strength * y_uncond + + def _cfg_step(self, y_cond, y_uncond, strength): + if isinstance(strength, dict): + return _pytree.tree_map(self._cfg_step_tensor, y_cond, y_uncond, strength) + else: + return _pytree.tree_map(partial(self._cfg_step_tensor, strength=strength), y_cond, y_uncond) + + def inner_forward(self, x, t, is_cond, strength, *args_cond, **kwargs_cond): + y_cond = self.backbone(x, t, *args_cond, **kwargs_cond) + if is_cond: + return y_cond + else: + args_cond, kwargs_cond = self._make_unconditional_args( + args_cond, + kwargs_cond, + ) + y_uncond = self.backbone(x, t, *args_cond, **kwargs_cond) + return self._cfg_step(y_cond, y_uncond, strength) + + def forward(self, x, t, *args_cond, **kwargs_cond): + # handle case when no conditional arguments are provided + if len(args_cond) + len(kwargs_cond) == 0: # unconditional + if self.unconditional_handling != "discard": + raise RuntimeError( + f"cannot call `ClassifierFreeGuidance` module without condition" + ) + return self.backbone(x, t) + else: # conditional arguments are provided + # training mode + if self.training: + coin_flip = random.random() < self.p_unconditional + if coin_flip: # unconditional + args_cond, kwargs_cond = self._make_unconditional_args( + args_cond, + kwargs_cond, + ) + return self.backbone(x, t, *args_cond, **kwargs_cond) + else: # inference mode + strength = get_strength(self.strength, self.interval, t) + is_cond = not any(x > 0.0 for x in _pytree.tree_flatten(strength)[0]) + return self.inner_forward( + x, t, is_cond, strength, *args_cond, **kwargs_cond + ) + +def get_strength(strength, interval, t): + if interval is None: + return _pytree.tree_map(lambda x: 0.0, strength) + + # If interval is not a dict (single tuple), broadcast it + if not isinstance(interval, dict): + return _pytree.tree_map( + lambda x: x if interval[0] <= t <= interval[1] else 0.0, + strength + ) + + return _pytree.tree_map( + lambda x, iv: x if iv[0] <= t <= iv[1] else 0.0, + strength, + interval + ) + +class PointmapCFG(ClassifierFreeGuidance): + + def __init__(self, *args, strength_pm=0.0, **kwargs): + super().__init__(*args, **kwargs) + self.strength_pm = strength_pm + + def _cfg_step_tensor(self, y_cond, y_uncond, y_unpm, strength, strength_pm): + # https://arxiv.org/abs/2411.18613 + return y_cond \ + + strength_pm * (y_cond - y_unpm) \ + + strength * (y_unpm - y_uncond) + + def _cfg_step(self, y_cond, y_uncond, y_pm, strength, strength_pm): + if isinstance(strength, dict): + return _pytree.tree_map(self._cfg_step_tensor, y_cond, y_uncond, y_pm, strength, strength_pm) + else: + return _pytree.tree_map(partial(self._cfg_step_tensor, strength=strength, strength_pm=strength_pm), y_cond, y_uncond, y_pm) + + def inner_forward(self, x, t, is_cond, strength, strength_pm, *args_cond, **kwargs_cond): + y_cond = self.backbone(x, t, *args_cond, **kwargs_cond) + + if is_cond: + return y_cond + else: + force_drop_modalities = self.backbone.condition_embedder.force_drop_modalities + self.backbone.condition_embedder.force_drop_modalities = ['pointmap', 'rgb_pointmap'] + y_pm = self.backbone(x, t, *args_cond, **kwargs_cond) + self.backbone.condition_embedder.force_drop_modalities = force_drop_modalities + + args_cond, kwargs_cond = self._make_unconditional_args( + args_cond, + kwargs_cond, + ) + y_uncond = self.backbone(x, t, *args_cond, **kwargs_cond) + return self._cfg_step(y_cond, y_uncond, y_pm, strength, strength_pm) + + def forward(self, x, t, *args_cond, **kwargs_cond): + # handle case when no conditional arguments are provided + if len(args_cond) + len(kwargs_cond) == 0: # unconditional + if self.unconditional_handling != "discard": + raise RuntimeError( + f"cannot call `ClassifierFreeGuidance` module without condition" + ) + return self.backbone(x, t) + else: # conditional arguments are provided + # training mode + if self.training: + coin_flip = random.random() < self.p_unconditional + if coin_flip: # unconditional + args_cond, kwargs_cond = self._make_unconditional_args( + args_cond, + kwargs_cond, + ) + return self.backbone(x, t, *args_cond, **kwargs_cond) + else: # inference mode + strength = get_strength(self.strength, self.interval, t) + is_cond = not any(x > 0.0 for x in _pytree.tree_flatten(strength)[0]) + strength_pm = get_strength(self.strength_pm, self.interval, t) + return self.inner_forward( + x, t, is_cond, strength, strength_pm, *args_cond, **kwargs_cond + ) + +class ClassifierFreeGuidanceWithExternalUnconditionalProbability(ClassifierFreeGuidance): + + def __init__(self, *args, use_unconditional_from_flow_matching=False, **kwargs): + super().__init__(*args, **kwargs) + self.use_unconditional_from_flow_matching = use_unconditional_from_flow_matching + + def forward(self, x, t, *args_cond, p_unconditional=None, **kwargs_cond): + # p_unconditional should be a value in [0, 1], indicating the probability of unconditional + + if p_unconditional is None: + coin_flip = random.random() < self.p_unconditional + else: + coin_flip = random.random() < p_unconditional + + # handle case when no conditional arguments are provided + if len(args_cond) + len(kwargs_cond) == 0: # unconditional + if self.unconditional_handling != "discard": + raise RuntimeError( + f"cannot call `ClassifierFreeGuidance` module without condition" + ) + return self.backbone(x, t) + else: # conditional arguments are provided + # training mode + if self.training: + if coin_flip: # unconditional + args_cond, kwargs_cond = self._make_unconditional_args( + args_cond, + kwargs_cond, + ) + return self.backbone(x, t, *args_cond, **kwargs_cond) + else: # inference mode + strength = get_strength(self.strength, self.interval, t) + is_cond = not any(x > 0.0 for x in _pytree.tree_flatten(strength)[0]) + return self.inner_forward( + x, t, is_cond, strength, *args_cond, **kwargs_cond + ) diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/generator/flow_matching/__init__.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/generator/flow_matching/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..71ca4b12c770afea62d06f97064cdf0c97d40ed7 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/generator/flow_matching/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/generator/flow_matching/model.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/generator/flow_matching/model.py new file mode 100644 index 0000000000000000000000000000000000000000..c76b7be7c3fa2917061000d34642dc42c28e5ae9 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/generator/flow_matching/model.py @@ -0,0 +1,363 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +from typing import Callable, Sequence, Union +import torch +import numpy as np +from functools import partial +import optree +import math + +from sam3d_objects.model.backbone.generator.base import Base +from sam3d_objects.data.utils import right_broadcasting +from sam3d_objects.data.utils import tree_tensor_map, tree_reduce_unique +from sam3d_objects.model.backbone.generator.flow_matching.solver import ( + ODESolver, + Euler, + Midpoint, + RungeKutta4, + gradient, + SDE, +) + +# default sampler in flow matching +uniform_sampler = torch.rand + + +# https://arxiv.org/pdf/2403.03206 +def lognorm_sampler(mean=0.0, std=1.0, **kwargs): + logit = torch.randn(**kwargs) * std + mean + return torch.nn.functional.sigmoid(logit) + + +# for backwards compatibility; please do not use this +def rev_lognorm_sampler(mean=0.0, std=1.0, **kwargs): + logit = torch.randn(**kwargs) * std + mean + return 1 - torch.nn.functional.sigmoid(logit) + + +# https://arxiv.org/pdf/2210.02747 +class FlowMatching(Base): + SOLVER_METHODS = { + "euler": Euler, + "midpoint": Midpoint, + "rk4": RungeKutta4, + "sde": SDE, + } + + def __init__( + self, + reverse_fn: Callable, + sigma_min: float = 0.0, # 0. = rectifier flow + inference_steps: int = 100, + time_scale: float = 1000.0, # scale [0,1]-time before passing to `reverse_fn` + training_time_sampler_fn: Callable = partial( + lognorm_sampler, + mean=0, + std=1, + ), + reversed_timestamp=False, + rescale_t=1.0, + loss_fn=partial(torch.nn.functional.mse_loss, reduction="mean"), + loss_weights=1.0, + solver_method: Union[str, ODESolver] = "euler", + solver_kwargs: dict = {}, + **kwargs, + ): + super().__init__(**kwargs) + + self.reverse_fn = reverse_fn + self.sigma_min = sigma_min + self.inference_steps = inference_steps + self.time_scale = time_scale + self.training_time_sampler_fn = training_time_sampler_fn + self.reversed_timestamp = reversed_timestamp + self.rescale_t = rescale_t + self.loss_fn = loss_fn + self.loss_weights = loss_weights + self._solver_method, self._solver = self._get_solver( + solver_method, solver_kwargs + ) + + def _get_solver(self, solver_method, solver_kwargs): + if solver_method in FlowMatching.SOLVER_METHODS: + solver = FlowMatching.SOLVER_METHODS[solver_method](**solver_kwargs) + elif isinstance(solver_method, ODESolver): + solver_method = f"custom[{solver_method.__class__.__name__}]" + solver = solver_method + else: + raise ValueError( + f"Invalid solver `{solver_method}`, should be in {set(self.SOLVER_METHODS.keys())} or an ODESolver instance" + ) + return solver_method, solver + + def _generate_noise_tensor(self, x_shape, x_device): + return torch.randn( + x_shape, + # generator=self.random_generator, + device=x_device, + ) + + def _generate_noise(self, x_shape, x_device): + def is_shape(maybe_shape): + return isinstance(maybe_shape, Sequence) and all( + (isinstance(s, int) and s >= 0) for s in maybe_shape + ) + + return optree.tree_map( + partial(self._generate_noise_tensor, x_device=x_device), + x_shape, + is_leaf=is_shape, + none_is_leaf=False, + ) + + def _generate_x0_tensor(self, x1: torch.Tensor): + x0 = self._generate_noise_tensor(x1.shape, x1.device) + return x0 + + def _generate_xt_tensor(self, x0: torch.Tensor, x1: torch.Tensor, t: torch.Tensor): + # equation (22) + tb = right_broadcasting(t.to(x1.device), x1) + x_t = (1 - (1 - self.sigma_min) * tb) * x0 + tb * x1 + + return x_t + + def _generate_target_tensor(self, x0: torch.Tensor, x1: torch.Tensor): + # equation (23) + target = x1 - (1 - self.sigma_min) * x0 + + return target + + def _generate_x0(self, x1): + return tree_tensor_map(self._generate_x0_tensor, x1) + + def _generate_xt(self, x0, x1, t): + return tree_tensor_map( + partial(self._generate_xt_tensor, t=t), + x0, + x1, + ) + + def _generate_target(self, x0, x1): + return tree_tensor_map( + self._generate_target_tensor, + x0, + x1, + ) + + def _generate_t(self, x1): + first_tensor = optree.tree_flatten(x1)[0][0] + batch_size = first_tensor.shape[0] + device = first_tensor.device + + t = self.training_time_sampler_fn( + size=(batch_size,), + generator=self.random_generator, + ).to(device) + + return t + + def loss(self, x1: torch.Tensor, *args_conditionals, **kwargs_conditionals): + t = self._generate_t(x1) + x0 = self._generate_x0(x1) + x_t = self._generate_xt(x0, x1, t) + target = self._generate_target(x0, x1) + + prediction = self.reverse_fn( + x_t, + t * self.time_scale, + *args_conditionals, + **kwargs_conditionals, + ) + + # broadcast & and compute loss + loss = optree.tree_broadcast_map( + lambda fn, weight, pred, targ: weight * fn(pred, targ), + self.loss_fn, + self.loss_weights, + prediction, + target, + ) + + total_loss = sum(optree.tree_flatten(loss)[0]) + + # Create detailed loss breakdown + detail_losses = { + "flow_matching_loss": total_loss, + } + if isinstance(loss, dict): + detail_losses.update(loss) + return total_loss, detail_losses + + def _prepare_t(self, steps=None): + steps = self.inference_steps if steps is None else steps + t_seq = torch.linspace(0, 1, steps + 1) + + if self.rescale_t: + t_seq = t_seq / (1 + (self.rescale_t - 1) * (1 - t_seq)) + + if self.reversed_timestamp: + t_seq = 1 - t_seq + + return t_seq + + def generate_iter( + self, + x_shape, + x_device, + *args_conditionals, + **kwargs_conditionals, + ): + x_0 = self._generate_noise(x_shape, x_device) + t_seq = self._prepare_t().to(x_device) + + for x_t, t in self._solver.solve_iter( + self._generate_dynamics, + x_0, + t_seq, + *args_conditionals, + **kwargs_conditionals, + ): + yield t, x_t, () + + def _generate_dynamics( + self, + x_t, + t, + *args_conditionals, + **kwargs_conditionals, + ): + return self.reverse_fn(x_t, t * self.time_scale, *args_conditionals, **kwargs_conditionals) + + def _log_p0(self, x0): + x0 = self._tree_flatten(x0) + inside_exp = -(x0**2).sum(dim=1) / 2 + return inside_exp - math.log(2 * math.pi) / 2 * x0.shape[1] + + def log_likelihood( + self, + x1, + solver=None, + steps=None, + z_samples=1, + *args_conditionals, + **kwargs_conditionals, + ): + device = tree_reduce_unique(lambda tensor: tensor.device, x1) + # device = "cuda" + t_seq = self._prepare_t(steps).to(device) + t_seq = 1 - t_seq # from x1 to x0 + solver = self._solver if solver is None else self._get_solver(solver)[1] + + x_0 = solver.solve( + partial(self._log_likelihood_dynamics, device=device, z_samples=z_samples), + {"x": x1, "log_p": 0.0}, + t_seq, + *args_conditionals, + **kwargs_conditionals, + ) + + log_p1 = x_0["log_p"] + self._log_p0(x_0["x"]) + + return log_p1 + + def _log_likelihood_dynamics( + self, + state, + t, + device, + z_samples, + *args_conditionals, + **kwargs_conditionals, + ): + t = torch.tensor([t * self.time_scale], device=device, dtype=torch.float32) + x_t = state["x"] + + with torch.set_grad_enabled(True): + tree_tensor_map(lambda x,: x.requires_grad_(True), x_t) + velocity = self.reverse_fn( + x_t, + t, + *args_conditionals, + **kwargs_conditionals, + ) + + # compute the divergence estimate + div = self._compute_hutchinson_divergence(velocity, x_t, z_samples) + + tree_tensor_map(lambda x,: x.requires_grad_(False), x_t) + velocity = tree_tensor_map(lambda x: x.detach(), velocity) + + return {"x": velocity, "log_p": div.detach()} + + def _tree_flatten(self, tree): + flat_x = tree_tensor_map(lambda x: x.flatten(start_dim=1), tree) + flat_x, _ = optree.tree_flatten( + flat_x, + is_leaf=lambda x: isinstance(x, torch.Tensor), + ) + flat_x = torch.cat(flat_x, dim=1) + return flat_x + + def _compute_hutchinson_divergence(self, velocity, x_t, z_samples): + flat_velocity = self._tree_flatten(velocity) + flat_velocity = flat_velocity.unsqueeze(-1) + + z = torch.randn( + flat_velocity.shape[:-1] + (z_samples,), + dtype=flat_velocity.dtype, + device=flat_velocity.device, + ) + z = z < 0 + z = z * 2.0 - 1.0 + z = z / math.sqrt(z_samples) + + # compute Hutchinson divergence estimator E[z^T D_x(vt) z] = E[D_x(z^T vt) z)] + vt_dot_z = torch.einsum("ijk,ijk->ik", flat_velocity, z) + grad_vt_dot_z = [ + gradient(vt_dot_z[..., i], x_t, create_graph=(z_samples > 1)) + for i in range(z_samples) + ] + grad_vt_dot_z = [self._tree_flatten(g) for g in grad_vt_dot_z] + grad_vt_dot_z = torch.stack(grad_vt_dot_z, dim=-1) + div = torch.einsum("ijk,ijk->i", grad_vt_dot_z, z) + return div + + +def _get_device(x): + device = tree_reduce_unique(lambda tensor: tensor.device, x) + return device + + +class ConditionalFlowMatching(FlowMatching): + def generate_iter( + self, + x_shape, + x_device, + *args_conditionals, + **kwargs_conditionals, + ): + x_0 = self._generate_noise(x_shape, x_device) + t_seq = self._prepare_t().to(x_device) + + noise_override = None + if "noise_override" in kwargs_conditionals: + noise_override = kwargs_conditionals["noise_override"] + del kwargs_conditionals["noise_override"] + if noise_override is not None: + if type(x_0) == dict: + x_0.update(noise_override) + else: + x_0 = noise_override + + for x_t, t in self._solver.solve_iter( + self._generate_dynamics, + x_0, + t_seq, + *args_conditionals, + **kwargs_conditionals, + ): + if noise_override is not None: + if type(noise_override) == dict: + x_t.update(noise_override) + else: + x_t = noise_override + yield t, x_t, () diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/generator/flow_matching/solver.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/generator/flow_matching/solver.py new file mode 100644 index 0000000000000000000000000000000000000000..ead1b58c921faebe0451227200bc93d16bda52aa --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/generator/flow_matching/solver.py @@ -0,0 +1,126 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import optree +import torch +from functools import partial + +from sam3d_objects.data.utils import tree_tensor_map + + +def linear_approximation_step(x_t, dt, velocity): + # x_tp1 = x_t + velocity * dt + x_tp1 = tree_tensor_map(lambda x, v: x + v * dt, x_t, velocity) + return x_tp1 + + +def gradient(output, x, create_graph: bool = False): + tensors, pyspec = optree.tree_flatten( + x, is_leaf=lambda x: isinstance(x, torch.Tensor) + ) + grad_outputs = [torch.ones_like(output).detach() for _ in tensors] + grads = torch.autograd.grad( + output, + tensors, + grad_outputs=grad_outputs, + create_graph=create_graph, + ) + return optree.tree_unflatten(pyspec, grads) + + +class ODESolver: + def step(self, dynamics_fn, x_t, t, dt, *args, **kwargs): + raise NotImplementedError + + def solve_iter(self, dynamics_fn, x_init, times, *args, **kwargs): + x_t = x_init + for t0, t1 in zip(times[:-1], times[1:]): + dt = t1 - t0 + x_t = self.step(dynamics_fn, x_t, t0, dt, *args, **kwargs) + yield x_t, t0 + + def solve(self, dynamics_fn, x_init, times, *args, **kwargs): + for x_t, _ in self.solve_iter(dynamics_fn, x_init, times, *args, **kwargs): + pass + return x_t + + +# https://en.wikipedia.org/wiki/Euler_method +class Euler(ODESolver): + def step(self, dynamics_fn, x_t, t, dt, *args, **kwargs): + velocity = dynamics_fn(x_t, t, *args, **kwargs) + x_tp1 = linear_approximation_step(x_t, dt, velocity) + return x_tp1 + + +# https://arxiv.org/abs/2505.05470 +class SDE(ODESolver): + def __init__(self, **kwargs): + super().__init__() + self.sde_strength = kwargs.get("sde_strength", 0.1) + + def step(self, dynamics_fn, x_t, t, dt, *args, **kwargs): + velocity = dynamics_fn(x_t, t, *args, **kwargs) + sigma = 1 - t + var_t = sigma / (1 - torch.tensor(sigma).clamp(min=dt)) + std_dev_t = ( + torch.sqrt(variance) * self.sde_strength + ) # self.sde_strength = alpha + + def compute_mean(x, v): + drift_term = x * (std_dev_t**2 / (2 * sigma) * dt) + velocity_term = v * (1 + std_dev_t**2 * (1 - sigma) / (2 * sigma)) * dt + return x + drift_term + velocity_term + + prev_sample_mean = tree_tensor_map(compute_mean, x_t, velocity) + + # Generate noise and compute final sample using tree_tensor_map + def add_noise(mean_val): + variance_noise = torch.randn_like(mean_val) + return mean_val + std_dev_t * torch.sqrt(torch.tensor(dt)) * variance_noise + + prev_sample = tree_tensor_map(add_noise, prev_sample_mean) + + return prev_sample + + +# https://en.wikipedia.org/wiki/Midpoint_method +class Midpoint(ODESolver): + def step(self, dynamics_fn, x_t, t, dt, *args, **kwargs): + half_dt = 0.5 * dt + + x_mid = Euler.step(self, dynamics_fn, x_t, t, half_dt, *args, **kwargs) + + velocity_mid = dynamics_fn(x_mid, t + half_dt, *args, **kwargs) + x_tp1 = linear_approximation_step(x_t, dt, velocity_mid) + return x_tp1 + + +# https://en.wikipedia.org/wiki/Runge%E2%80%93Kutta_methods +class RungeKutta4(ODESolver): + + def k1(self, dynamics_fn, x_t, t, dt, *args, **kwargs): + return dynamics_fn(x_t, t, *args, **kwargs) + + def k2(self, dynamics_fn, x_t, t, dt, k1, *args, **kwargs): + x_k1 = linear_approximation_step(x_t, dt * 0.5, k1) + return dynamics_fn(x_k1, t + dt * 0.5, *args, **kwargs) + + def k3(self, dynamics_fn, x_t, t, dt, k2, *args, **kwargs): + x_k2 = linear_approximation_step(x_t, dt * 0.5, k2) + return dynamics_fn(x_k2, t + dt * 0.5, *args, **kwargs) + + def k4(self, dynamics_fn, x_t, t, dt, k3, *args, **kwargs): + x_k3 = linear_approximation_step(x_t, dt, k3) + return dynamics_fn(x_k3, t + dt, *args, **kwargs) + + def step(self, dynamics_fn, x_t, t, dt, *args, **kwargs): + k1 = self.k1(dynamics_fn, x_t, t, dt, *args, **kwargs) + k2 = self.k2(dynamics_fn, x_t, t, dt, k1, *args, **kwargs) + k3 = self.k3(dynamics_fn, x_t, t, dt, k2, *args, **kwargs) + k4 = self.k4(dynamics_fn, x_t, t, dt, k3, *args, **kwargs) + + def compute_velocity(k1, k2, k3, k4): + return (k1 + 2 * k2 + 2 * k3 + k4) / 6 + + velocity_k = tree_tensor_map(compute_velocity, k1, k2, k3, k4) + x_tp1 = linear_approximation_step(x_t, dt, velocity_k) + return x_tp1 diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/generator/shortcut/__init__.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/generator/shortcut/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..71ca4b12c770afea62d06f97064cdf0c97d40ed7 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/generator/shortcut/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/generator/shortcut/model.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/generator/shortcut/model.py new file mode 100644 index 0000000000000000000000000000000000000000..4d93d62f66e2275f68822dabc7bd485074369924 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/generator/shortcut/model.py @@ -0,0 +1,450 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import random +from typing import Callable, Sequence, Union +import torch +import numpy as np +from functools import partial +import optree +import math + +from sam3d_objects.model.backbone.generator.base import Base +from sam3d_objects.data.utils import right_broadcasting +from sam3d_objects.data.utils import tree_tensor_map, tree_reduce_unique +from sam3d_objects.model.backbone.generator.flow_matching.model import FlowMatching, _get_device +from torch.nn.attention import SDPBackend, sdpa_kernel +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors +import copy + + +# https://arxiv.org/pdf/2410.12557 +class ShortCut(FlowMatching): + def __init__( + self, + no_shortcut=False, + self_consistency_prob=0.25, + shortcut_loss_weight=1.0, + self_consistency_cfg_strength=3.0, + ratio_cfg_samples_in_self_consistency_target=0.5, + fm_in_shortcut_target_prob=0.0, + fm_eps_max=0, + batch_mode=False, + cfg_modalities=["shape"], + **kwargs, + ): + super().__init__(**kwargs) + self.no_shortcut = no_shortcut + self.self_consistency_prob = self_consistency_prob + self.shortcut_loss_weight = shortcut_loss_weight + self.self_consistency_cfg_strength = self_consistency_cfg_strength + self.ratio_cfg_samples_in_self_consistency_target = ratio_cfg_samples_in_self_consistency_target + self.fm_in_shortcut_target_prob = fm_in_shortcut_target_prob + self.fm_eps_max = fm_eps_max + self.batch_mode = batch_mode + self.cfg_modalities = cfg_modalities + + def _generate_d(self, x1): + """ + Generate shortcut step sizes d with binary-time schedule. + + This method ensures deterministic behavior for distributed training: + - Exactly self_consistency_prob fraction of samples will have d > 0 (self-consistency) + - Remaining samples will have d = 0 (flow matching) + - All distributed ranks will have consistent counts, preventing deadlocks + + Args: + x1: Input tensor or tree of tensors + + Returns: + d: Tensor of step sizes with shape [batch_size] + """ + first_tensor = optree.tree_flatten(x1)[0][0] + batch_size = first_tensor.shape[0] + device = first_tensor.device + + # Use binary-time schedule: d ∈ {1/2^i for i in range(8)} + base = [1 / 2**i for i in range(8)] + + # Deterministic approach: exactly self_consistency_prob fraction will have d>0 + # This ensures all distributed ranks have consistent behavior + if self.batch_mode: + num_self_consistency_samples = int(random.random() < self.self_consistency_prob) * batch_size + else: + num_self_consistency_samples = int(batch_size * self.self_consistency_prob) + num_flow_matching_samples = batch_size - num_self_consistency_samples + + # Create deterministic d values + d = torch.zeros(batch_size, device=device) + + if num_self_consistency_samples > 0: + # Randomly select d values for self-consistency samples + selected_elements = random.choices(base, k=num_self_consistency_samples) + d[:num_self_consistency_samples] = torch.FloatTensor(selected_elements).to(device) + + # Shuffle the d values to randomize which samples get which d values + # This maintains the deterministic count while randomizing positions + shuffle_indices = torch.randperm(batch_size, device=device) + d = d[shuffle_indices] + + return d + + @torch.no_grad() + def compute_self_consistency_target(self, x_t, t, d, *args_conditionals, **kwargs_conditionals): + """ + Compute self-consistency target for shortcut model's self-consistency objective. + + This method uses a mixed approach where: + - First 25% of samples (num_cfg_samples) use CFG blending with strength 7.0 + - Remaining 75% of samples use conditional-only targets + + Safety guarantees: + - Ensures at least 1 sample in CFG part (num_cfg_samples >= 1 for batch_size >= 2) + - For batch_size < 2, falls back to all conditional-only (no CFG) + - Handles edge cases where batch size is too small for mixed approach + + The process involves: + 1. Forward all samples through conditional model to get s_t_cond and s_td_cond + 2. Forward first num_cfg_samples through unconditional model to get s_t_uncond and s_td_uncond + 3. Apply CFG blending: (1 + strength) * cond - strength * uncond for first num_cfg_samples + 4. Concatenate CFG results with conditional-only results for remaining samples + 5. Average the two velocities to get final self-consistency target + """ + # CFG strength for self-consistency target computation + self_consistency_cfg_strength = self.self_consistency_cfg_strength + + # Mixed approach: configurable ratio of CFG:conditional-only samples + batch_size = x_t.shape[0] if not isinstance(x_t, dict) else next(iter(x_t.values())).shape[0] + if self.batch_mode: + num_cfg_samples = int(random.random() < self.ratio_cfg_samples_in_self_consistency_target) * batch_size + else: + num_cfg_samples = int(batch_size * self.ratio_cfg_samples_in_self_consistency_target) # Configurable ratio for CFG + num_cond_only_samples = batch_size - num_cfg_samples # Remaining for conditional-only + use_fm_in_shortcut_target = random.random() < self.fm_in_shortcut_target_prob + + # Handle edge case where batch_size < 2 (fallback to all conditional-only) + # if batch_size < 2: + # num_cfg_samples = 0 + # num_cond_only_samples = batch_size + + + # ### DEBUG ############################### + # num_cfg_samples = 0 + # num_cond_only_samples = batch_size + # # ### DEBUG ############################### + + # Step 1: Get velocity predictions at current time t + # Forward all samples through conditional model + s_t_cond = self.reverse_fn( + x_t, + t * self.time_scale, + *args_conditionals, + d=d * self.time_scale if not use_fm_in_shortcut_target else d * self.time_scale * 0, + p_unconditional=0.0, + **kwargs_conditionals, + ) + + # Handle CFG and conditional-only parts + if num_cfg_samples > 0: + # Forward first num_cfg_samples through unconditional model + if isinstance(x_t, dict): + x_t_cfg = {k: v[:num_cfg_samples] for k, v in x_t.items()} + else: + x_t_cfg = x_t[:num_cfg_samples] + + s_t_uncond = self.reverse_fn( + x_t_cfg, + t[:num_cfg_samples] * self.time_scale, + *(arg[:num_cfg_samples] if not self.batch_mode and torch.is_tensor(arg) else arg for arg in args_conditionals), + d=d[:num_cfg_samples] * self.time_scale if not use_fm_in_shortcut_target else d[:num_cfg_samples] * self.time_scale * 0, + p_unconditional=1.0, + **{k: v[:num_cfg_samples] if not self.batch_mode and torch.is_tensor(v) else v for k, v in kwargs_conditionals.items()}, + ) + + # Apply CFG blending for first num_cfg_samples using our standard formula + s_t_cfg = tree_tensor_map( + lambda cond, uncond: (1 + self_consistency_cfg_strength) * cond - self_consistency_cfg_strength * uncond, + tree_tensor_map(lambda x: x[:num_cfg_samples], s_t_cond), s_t_uncond + ) + + # Combine CFG results with conditional-only results for remaining samples + if num_cond_only_samples > 0: + s_t = tree_tensor_map( + lambda cfg, cond: torch.cat([cfg, cond[num_cfg_samples:]], dim=0), + s_t_cfg, s_t_cond + ) + else: + # All samples use CFG + s_t = s_t_cond + if isinstance(s_t_cond, dict): + for modality in self.cfg_modalities: + s_t[modality] = s_t_cfg[modality] + else: + s_t = s_t_cfg + else: + # All samples use conditional-only (fallback for very small batches) + s_t = s_t_cond + + # Step 2: Take a step of size d using current velocity + x_td = tree_tensor_map(lambda x, v: x + v * d[..., None, None], x_t, s_t) + + # Step 3: Get velocity predictions at time t+d + # Forward all samples through conditional model at t+d + s_td_cond = self.reverse_fn( + x_td, + (t + d) * self.time_scale, + *args_conditionals, + d=d * self.time_scale if not use_fm_in_shortcut_target else d * self.time_scale * 0, + p_unconditional=0.0, + **kwargs_conditionals, + ) + + # Handle CFG and conditional-only parts at t+d + if num_cfg_samples > 0: + # Forward first num_cfg_samples through unconditional model at t+d + if isinstance(x_td, dict): + x_td_cfg = {k: v[:num_cfg_samples] for k, v in x_td.items()} + else: + x_td_cfg = x_td[:num_cfg_samples] + + s_td_uncond = self.reverse_fn( + x_td_cfg, + (t + d)[:num_cfg_samples] * self.time_scale, + *(arg[:num_cfg_samples] if not self.batch_mode and torch.is_tensor(arg) else arg for arg in args_conditionals), + d=d[:num_cfg_samples] * self.time_scale if not use_fm_in_shortcut_target else d[:num_cfg_samples] * self.time_scale * 0, + p_unconditional=1.0, + **{k: v[:num_cfg_samples] if not self.batch_mode and torch.is_tensor(v) else v for k, v in kwargs_conditionals.items()}, + ) + + # Apply CFG blending for first num_cfg_samples at t+d using our standard formula + s_td_cfg = tree_tensor_map( + lambda cond, uncond: (1 + self_consistency_cfg_strength) * cond - self_consistency_cfg_strength * uncond, + tree_tensor_map(lambda x: x[:num_cfg_samples], s_td_cond), s_td_uncond + ) + + # Combine CFG results with conditional-only results for remaining samples at t+d + if num_cond_only_samples > 0: + s_td = tree_tensor_map( + lambda cfg, cond: torch.cat([cfg, cond[num_cfg_samples:]], dim=0), + s_td_cfg, s_td_cond + ) + else: + # All samples use CFG + s_td = s_td_cond + if isinstance(s_td_cond, dict): + for modality in self.cfg_modalities: + s_td[modality] = s_td_cfg[modality] + else: + s_td = s_td_cfg + else: + # All samples use conditional-only (fallback for very small batches) + s_td = s_td_cond + + # Step 4: Compute self-consistency target as average of two velocities + s_target = tree_tensor_map(lambda a, b: (a + b).detach() / 2, s_t, s_td) + + return s_target + + def _generate_t_and_d(self, x1): + """ + Generate t and d together according to shortcut models paper. + + According to the paper: "During training, we first sample d, then sample t only at the discrete + points for which the shortcut model will be queried, i.e. multiples of d. We train the + self-consistency objective only at these timesteps." + + This ensures that when d > 0 (self-consistency samples), t is sampled at multiples of d. + When d = 0 (flow matching samples), t can be sampled normally. + """ + first_tensor = optree.tree_flatten(x1)[0][0] + batch_size = first_tensor.shape[0] + device = first_tensor.device + + # First sample d + d = self._generate_d(x1) + + # Then sample t based on d + t = torch.zeros(batch_size, device=device) + + # For flow matching samples (d = 0), sample t normally + flow_matching_mask = (d == 0) + if flow_matching_mask.any(): + num_flow_samples = flow_matching_mask.sum().item() + t_flow = self.training_time_sampler_fn( + size=(num_flow_samples,), + generator=self.random_generator, + ).to(device) + t[flow_matching_mask] = t_flow + + # For self-consistency samples (d > 0), sample t at multiples of d + self_consistency_mask = (d > 0) + if self_consistency_mask.any(): + d_nonzero = d[self_consistency_mask] + # Sample how many multiples of d to use for each sample + # We want t to be k*d where k is a random integer such that t ∈ [0, 1-d] + # This ensures t + d ≤ 1 + max_multiples = torch.floor((1.0 - d_nonzero) / d_nonzero).long() + # Ensure max_multiples is at least 0 to avoid empty range + max_multiples = torch.clamp(max_multiples, min=0) + + # For each sample, randomly choose k from [0, max_multiples] - vectorized + # Generate random values [0, 1) for all samples + random_vals = torch.rand_like(d_nonzero) + # Scale to [0, max_multiples + 1) and floor to get integers [0, max_multiples] + k_values = torch.floor(random_vals * (max_multiples.float() + 1)) + # Compute t = k * d for all samples + t_self_consistency = k_values * d_nonzero + + t[self_consistency_mask] = t_self_consistency + + return t, d + + def loss(self, x1: torch.Tensor, *args_conditionals, **kwargs_conditionals): + """Compute shortcut model loss with mixed flow matching and self-consistency objectives""" + # t, d = self._generate_t_and_d(x1) + t = self._generate_t(x1) + d = self._generate_d(x1) + x0 = self._generate_x0(x1) + x_t = self._generate_xt(x0, x1, t) + + # Determine which samples use flow matching vs self-consistency + flow_matching_indices = (d == 0).nonzero(as_tuple=False).squeeze(-1) # 75% of the time use d=0 (flow matching), 25% use self-consistency + self_consistency_indices = (d > 0).nonzero(as_tuple=False).squeeze(-1) + d[d == 0] = torch.rand_like(d[d == 0]) * self.fm_eps_max + + # Clear autocast cache for gradient computation + torch.clear_autocast_cache() + + # Get model prediction + s = self.reverse_fn( + x_t, + t * self.time_scale, + *args_conditionals, + d=2 * d * self.time_scale, + **kwargs_conditionals, + ) + + # Compute component losses separately by selecting relevant indices + flow_matching_loss_val = torch.tensor(0.0, device=d.device, dtype=torch.float32) + self_consistency_loss_val = torch.tensor(0.0, device=d.device, dtype=torch.float32) + + # Flow matching component (for d=0 samples) + if len(flow_matching_indices) > 0: + # Select samples where d=0 and compute flow matching target only for these samples + x0_flow = tree_tensor_map(lambda x: x[flow_matching_indices], x0) + x1_flow = tree_tensor_map(lambda x: x[flow_matching_indices], x1) + s_flow = tree_tensor_map(lambda x: x[flow_matching_indices], s) + + # Compute flow matching target only for selected samples + flow_matching_target = self._generate_target(x0_flow, x1_flow) + + flow_matching_loss = optree.tree_broadcast_map( + lambda fn, weight, pred, targ: weight * fn(pred, targ), + self.loss_fn, + self.loss_weights, + s_flow, + flow_matching_target, + ) + flow_matching_loss_val = sum(optree.tree_flatten(flow_matching_loss)[0]) + + # Shortcut self-consistency component (for d>0 samples) + if len(self_consistency_indices) > 0: + # Select samples where d>0 and compute self-consistency target only for these samples + x_t_shortcut = tree_tensor_map(lambda x: x[self_consistency_indices], x_t) + t_shortcut = t[self_consistency_indices] + d_shortcut = d[self_consistency_indices] + s_shortcut = tree_tensor_map(lambda x: x[self_consistency_indices], s) + + # Create conditional arguments for selected samples + if self.batch_mode: + args_conditionals_shortcut = args_conditionals + kwargs_conditionals_shortcut = kwargs_conditionals + else: + args_conditionals_shortcut = tuple( + tree_tensor_map(lambda x: x[self_consistency_indices], arg) if torch.is_tensor(arg) else arg + for arg in args_conditionals + ) + kwargs_conditionals_shortcut = { + k: (tree_tensor_map(lambda x: x[self_consistency_indices], v) if torch.is_tensor(v) else v) + for k, v in kwargs_conditionals.items() + } + + # Compute self-consistency target only for selected samples + self_consistency_target = self.compute_self_consistency_target( + x_t_shortcut, t_shortcut, d_shortcut, + *args_conditionals_shortcut, **kwargs_conditionals_shortcut + ) + + self_consistency_loss = optree.tree_broadcast_map( + lambda fn, weight, pred, targ: weight * fn(pred, targ), + self.loss_fn, + self.loss_weights, + s_shortcut, + self_consistency_target, + ) + self_consistency_loss_val = sum(optree.tree_flatten(self_consistency_loss)[0]) + + # Total loss is the sum of both components (linear combination) + total_loss = flow_matching_loss_val + self.shortcut_loss_weight * self_consistency_loss_val + + # Create detailed loss breakdown + detail_losses = { + "flow_matching_loss": flow_matching_loss_val, + "self_consistency_loss": self_consistency_loss_val, + } + return total_loss, detail_losses + + def _prepare_t_and_d(self, steps=None): + """Prepare time sequence and step size for inference""" + steps = self.inference_steps if steps is None else steps + t_seq = np.linspace(0, 1, steps + 1) + + if self.no_shortcut: + d = 0 + else: + # Use uniform step size for inference + d = 1 / steps + + if self.rescale_t: + t_seq = t_seq / (1 + (self.rescale_t - 1) * (1 - t_seq)) + + if self.reversed_timestamp: + t_seq = 1 - t_seq + + return t_seq, d + + def generate_iter( + self, + x_shape, + x_device, + *args_conditionals, + **kwargs_conditionals, + ): + """Generate samples using shortcut model""" + x_0 = self._generate_noise(x_shape, x_device) + t_seq, d = self._prepare_t_and_d() + + for x_t, t in self._solver.solve_iter( + self._generate_dynamics, + x_0, + t_seq, + d, + *args_conditionals, + **kwargs_conditionals, + ): + yield t, x_t, () + + def _generate_dynamics( + self, + x_t, + t, + d, + *args_conditionals, + **kwargs_conditionals, + ): + """Generate dynamics for ODE solver""" + t = torch.tensor( + [t * self.time_scale], device=_get_device(x_t), dtype=torch.float32 + ) + d = torch.tensor( + [d * self.time_scale], device=_get_device(x_t), dtype=torch.float32 + ) + return self.reverse_fn(x_t, t, *args_conditionals, d=d, **kwargs_conditionals) diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/__init__.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..71ca4b12c770afea62d06f97064cdf0c97d40ed7 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/models/__init__.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ad8496fbc6d95191e1c63409794b4453f9a2c69a --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/models/__init__.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +from .sparse_structure_flow import SparseStructureFlowModel, SparseStructureFlowTdfyWrapper +from .sparse_structure_vae import SparseStructureEncoder, SparseStructureDecoder +from .structured_latent_vae import SLatGaussianDecoder +from .structured_latent_flow import SLatFlowModel, SLatFlowModelTdfyWrapper + +def from_pretrained(path: str, **kwargs): + """ + Load a model from a pretrained checkpoint. + + Args: + path: The path to the checkpoint. Can be either local path or a Hugging Face model name. + NOTE: config file and model file should take the name f'{path}.json' and f'{path}.safetensors' respectively. + **kwargs: Additional arguments for the model constructor. + """ + import os + import json + from safetensors.torch import load_file + + is_local = os.path.exists(f"{path}.json") and os.path.exists(f"{path}.safetensors") + + if is_local: + config_file = f"{path}.json" + model_file = f"{path}.safetensors" + else: + from huggingface_hub import hf_hub_download + + path_parts = path.split("/") + repo_id = f"{path_parts[0]}/{path_parts[1]}" + model_name = "/".join(path_parts[2:]) + config_file = hf_hub_download(repo_id, f"{model_name}.json") + model_file = hf_hub_download(repo_id, f"{model_name}.safetensors") + + with open(config_file, "r") as f: + config = json.load(f) + model = __getattr__(config["name"])(**config["args"], **kwargs) + model.load_state_dict(load_file(model_file)) + + return model diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/models/mm_latent.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/models/mm_latent.py new file mode 100644 index 0000000000000000000000000000000000000000..4abacfea9298097430c3a51846f383dd998324c8 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/models/mm_latent.py @@ -0,0 +1,90 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import torch +from torch import nn +from ..modules.transformer import ( + AbsolutePositionEmbedder, +) +from ..modules import spatial +import torch.nn.functional as F +from typing import Callable + + +class Latent(nn.Module): + def __init__( + self, in_channels, model_channels: int, pos_embedder: Callable[[], torch.Tensor] + ): + super().__init__() + self.input_layer = nn.Linear(in_channels, model_channels) + self.out_layer = nn.Linear(model_channels, in_channels) + + pos_emb = pos_embedder() + if isinstance(pos_emb, torch.nn.Parameter): + # learnt position embedding + self.register_parameter("pos_emb", pos_emb) + elif isinstance(pos_emb, torch.Tensor): + # fixed position embedding + self.register_buffer("pos_emb", pos_emb) + else: + raise NotImplementedError + + self.initialize_weights() + + def initialize_weights(self) -> None: + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + # Zero-out output layers: + # nn.init.constant_(self.out_layer.weight, 0) + nn.init.xavier_uniform_(self.out_layer.weight) + nn.init.constant_(self.out_layer.bias, 0) + + def to_input(self, x): + x = self.input_layer(x) + x = x + self.pos_emb[None] + + return x + + def to_output(self, h): + h = F.layer_norm(h, h.shape[-1:]) + h = self.out_layer(h) + + return h + + +def ShapePositionEmbedder(model_channels, resolution, patch_size): + def embedder(): + pos_embedder = AbsolutePositionEmbedder(model_channels, 3) + coords = torch.meshgrid( + *[torch.arange(res) for res in [resolution // patch_size] * 3], + indexing="ij", + ) + coords = torch.stack(coords, dim=-1).reshape(-1, 3) + pos_emb = pos_embedder(coords) + + return pos_emb + + return embedder + + +def RandomPositionEmbedder(model_channels, token_len): + def embedder(): + pos_emb = torch.randn(token_len, model_channels) + + return pos_emb + + return embedder + + +def LearntPositionEmbedder(model_channels, token_len): + def embedder(): + pos_emb = torch.nn.Parameter(torch.randn(token_len, model_channels)) + + return pos_emb + + return embedder diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/models/mot_sparse_structure_flow.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/models/mot_sparse_structure_flow.py new file mode 100644 index 0000000000000000000000000000000000000000..b2abb0ed8ef941dc46977cca70bc617c9f5f07b9 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/models/mot_sparse_structure_flow.py @@ -0,0 +1,290 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +from functools import partial +from typing import * +from torch.utils import _pytree +import torch +import torch.nn as nn +from ..modules.utils import convert_module_to_f16, convert_module_to_f32 +from collections import namedtuple +from ..modules.utils import FP16_TYPE +from ..modules.transformer import ( + MOTModulatedTransformerCrossBlock, +) +from sam3d_objects.data.utils import ( + tree_reduce_unique, +) +from .timestep_embedder import TimestepEmbedder +from omegaconf import OmegaConf + +class SparseStructureFlowModel(nn.Module): + def __init__( + self, + in_channels: int, + model_channels: int, + cond_channels: int, + out_channels: int, + num_blocks: int, + num_heads: Optional[int] = None, + num_head_channels: Optional[int] = 64, + mlp_ratio: float = 4, + pe_mode: Literal["ape", "rope"] = "ape", + use_fp16: bool = False, + use_checkpoint: bool = False, + share_mod: bool = False, + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + freeze_shared_parameters: bool = False, + is_shortcut_model: bool = False, + *args, + **kwargs, + ): + super().__init__() + self.in_channels = in_channels + self.model_channels = model_channels + self.cond_channels = cond_channels + self.out_channels = out_channels + self.num_blocks = num_blocks + self.num_heads = num_heads or model_channels // num_head_channels + self.mlp_ratio = mlp_ratio + self.pe_mode = pe_mode + self.use_fp16 = use_fp16 + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.qk_rms_norm = qk_rms_norm + self.qk_rms_norm_cross = qk_rms_norm_cross + self.dtype = FP16_TYPE if use_fp16 else torch.float32 + self.is_shortcut_model = is_shortcut_model + if is_shortcut_model: + self.d_embedder = TimestepEmbedder(model_channels) # for shortcut model + + self.t_embedder = TimestepEmbedder(model_channels, freeze=freeze_shared_parameters) + if share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), nn.Linear(model_channels, 6 * model_channels, bias=True) + ) + + self.blocks = nn.ModuleList( + [ + MOTModulatedTransformerCrossBlock( + model_channels, + cond_channels, + num_heads=self.num_heads, + mlp_ratio=self.mlp_ratio, + attn_mode="full", + use_checkpoint=self.use_checkpoint, + use_rope=(pe_mode == "rope"), + share_mod=share_mod, + qk_rms_norm=self.qk_rms_norm, + qk_rms_norm_cross=self.qk_rms_norm_cross, + latent_names=self.latent_names, + freeze_shared_parameters=freeze_shared_parameters + ) + for _ in range(num_blocks) + ] + ) + + self.initialize_weights() + if use_fp16: + self.convert_to_fp16() + + @property + def device(self) -> torch.device: + """ + Return the device of the model. + """ + return next(self.parameters()).device + + def convert_to_fp16(self) -> None: + """ + Convert the torso of the model to float16. + """ + self.blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self) -> None: + """ + Convert the torso of the model to float32. + """ + self.blocks.apply(convert_module_to_f32) + + def initialize_weights(self) -> None: + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # zero init like controlnet, for MLP should only zero + # the weight of the last layer only + if self.is_shortcut_model: + nn.init.constant_(self.d_embedder.mlp[2].weight, 0) + + # Zero-out adaLN modulation layers in DiT blocks: + if self.share_mod: + nn.init.constant_(self.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.adaLN_modulation[-1].bias, 0) + else: + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + def _cast_type(self, x, dtype): + return x.type(dtype) + + def forward( + self, + h: Dict, + t: torch.Tensor, + cond: torch.Tensor, + d: torch.Tensor = None, + ) -> torch.Tensor: + t_emb = self.t_embedder(t) + if d is not None: + d_emb = self.d_embedder(d) + t_emb = t_emb + d_emb + if self.share_mod: + t_emb = self.adaLN_modulation(t_emb) + + input_dtype = tree_reduce_unique(lambda tensor: tensor.dtype, h) + t_emb = t_emb.type(self.dtype) + h = _pytree.tree_map( + partial(self._cast_type, dtype=self.dtype), + h, + ) + cond = cond.type(self.dtype) + + for block in self.blocks: + h = block(h, t_emb, cond) + + h = _pytree.tree_map( + partial(self._cast_type, dtype=input_dtype), + h, + ) + + return h + + +class SparseStructureFlowTdfyWrapper(SparseStructureFlowModel): + def __init__( + self, + latent_mapping: dict, + latent_share_transformer: dict = {}, + *args, + **kwargs, + ): + condition_embedder = kwargs.pop("condition_embedder", None) + # if enabled, model will record the condition_shape in one run and uses zeros for all that afterwards + force_zeros_cond = kwargs.pop("force_zeros_cond", False) + # backward compatible to models trained before PR #87 + kwargs.pop("shape_attend_pose", None) + merge_latent_names = [i for _, v in latent_share_transformer.items() for i in v] + self.latent_names = [ + latent_name + for latent_name in list(latent_mapping.keys()) + if latent_name not in merge_latent_names + ] + list(latent_share_transformer.keys()) + super().__init__(*args, **kwargs) + if condition_embedder is not None: + self.condition_embedder = condition_embedder + else: + self.condition_embedder = lambda x: x + self.force_zeros_cond = force_zeros_cond + self.latent_mapping = nn.ModuleDict(latent_mapping) + if not isinstance(latent_share_transformer, dict): + self.latent_share_transformer = OmegaConf.to_container(latent_share_transformer) + else: + self.latent_share_transformer = latent_share_transformer + self.input_latent_mappings = list(self.latent_mapping.keys()) + + def forward( + self, + latents_dict: dict, + t: torch.Tensor, + *condition_args, + **condition_kwargs, + ) -> dict: + d = condition_kwargs.pop("d", None) + + cfg_activate = condition_kwargs.pop("cfg", False) + if self.force_zeros_cond and cfg_activate: + cond = self.condition_embedder(*condition_args, **condition_kwargs) + cond = cond * 0 + else: + cond = self.condition_embedder(*condition_args, **condition_kwargs) + + # concatenate input + latent_dict = self.project_input(latents_dict) + output = super().forward(latent_dict, t, cond, d) + + # split input to multiple output modalities + output_latents = self.project_output(output) + + return output_latents + + def project_input( + self, + latents_dict: Dict, + ) -> Dict: + # concatenate input from multiple modalities + latent_dict = {} + for latent_name in self.input_latent_mappings: + assert ( + latent_name in latents_dict + ), f"'{latent_name}' not found in latents_dict" + latent_input = latents_dict[latent_name] + x = self.latent_mapping[latent_name].to_input(latent_input) + latent_dict[latent_name] = x + + latent_dict = self.merge_latent_share_transformer(latent_dict) + return latent_dict + + def project_output(self, output: Dict) -> Dict: + output = self.split_latent_share_transformer(output) + output_latents = {} + for latent_name in self.input_latent_mappings: + latent = self.latent_mapping[latent_name].to_output(output[latent_name]) + output_latents[latent_name] = latent + + return output_latents + + def merge_latent_share_transformer(self, latent_dict): + visited_latent_names = set() + return_dict = {} + for merged_name, latent_names in self.latent_share_transformer.items(): + tensors = [] + for latent_name in latent_names: + visited_latent_names.add(latent_name) + tensors.append(latent_dict[latent_name]) + tensors = torch.cat(tensors, dim=1) + return_dict[merged_name] = tensors + + for latent_name in latent_dict: + if latent_name not in visited_latent_names: + return_dict[latent_name] = latent_dict[latent_name] + + return return_dict + + def split_latent_share_transformer(self, output_latents): + return_dict = {} + visited_latent_names = set() + for merged_name, latent_names in self.latent_share_transformer.items(): + start = 0 + visited_latent_names.add(merged_name) + tensors = output_latents[merged_name] + for latent_name in latent_names: + token_len = self.latent_mapping[latent_name].pos_emb.shape[0] + latent = tensors[:, start : start + token_len] + return_dict[latent_name] = latent + start += token_len + + for latent_name in output_latents: + if latent_name not in visited_latent_names: + return_dict[latent_name] = output_latents[latent_name] + + return return_dict diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/models/sparse_structure_flow.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/models/sparse_structure_flow.py new file mode 100644 index 0000000000000000000000000000000000000000..ae29759f6364295aa100af74f4bbece9ebbf5c4b --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/models/sparse_structure_flow.py @@ -0,0 +1,303 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from ..modules.utils import convert_module_to_f16, convert_module_to_f32 +from collections import namedtuple +from ..modules.utils import FP16_TYPE +from ..modules.transformer import ( + AbsolutePositionEmbedder, + ModulatedTransformerCrossBlock, +) +from ..modules.spatial import patchify, unpatchify +import warnings + + +DataType = namedtuple("DataType", ["shape", "pose"]) + + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + + Args: + t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + dim: the dimension of the output. + max_period: controls the minimum frequency of the embeddings. + + Returns: + an (N, D) Tensor of positional embeddings. + """ + # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + half = dim // 2 + freqs = torch.exp( + -np.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat( + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 + ) + return embedding + + def forward(self, t): + if t.ndim == 0: + t = t.unsqueeze(0) + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + + +class SparseStructureFlowModel(nn.Module): + def __init__( + self, + resolution: int, + in_channels: int, + model_channels: int, + cond_channels: int, + out_channels: int, + num_blocks: int, + num_heads: Optional[int] = None, + num_head_channels: Optional[int] = 64, + mlp_ratio: float = 4, + patch_size: int = 2, + pe_mode: Literal["ape", "rope"] = "ape", + use_fp16: bool = False, + use_checkpoint: bool = False, + share_mod: bool = False, + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + include_pose: bool = False, + pose_weight: float = 1.0, + ): + warnings.warn( + "The old SparseStructureFlowModel is deprecated! Please upgrade to use the one in mm_sparse_structure_flow.", + category=DeprecationWarning, + ) + super().__init__() + self.resolution = resolution + self.in_channels = in_channels + self.model_channels = model_channels + self.cond_channels = cond_channels + self.out_channels = out_channels + self.num_blocks = num_blocks + self.num_heads = num_heads or model_channels // num_head_channels + self.mlp_ratio = mlp_ratio + self.patch_size = patch_size + self.pe_mode = pe_mode + self.use_fp16 = use_fp16 + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.qk_rms_norm = qk_rms_norm + self.qk_rms_norm_cross = qk_rms_norm_cross + self.dtype = FP16_TYPE if use_fp16 else torch.float32 + self.include_pose = include_pose + self.pose_weight = pose_weight + + self.t_embedder = TimestepEmbedder(model_channels) + if share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), nn.Linear(model_channels, 6 * model_channels, bias=True) + ) + + if pe_mode == "ape": + pos_embedder = AbsolutePositionEmbedder(model_channels, 3) + coords = torch.meshgrid( + *[ + torch.arange(res, device=self.device) + for res in [resolution // patch_size] * 3 + ], + indexing="ij", + ) + coords = torch.stack(coords, dim=-1).reshape(-1, 3) + pos_emb = pos_embedder(coords) + if include_pose: + pose_pos_emb = torch.ones(1, model_channels) * 0.5 + pos_emb = torch.cat([pos_emb, pose_pos_emb], dim=0) + self.register_buffer("pos_emb", pos_emb) + + self.input_layer = nn.Linear(in_channels * patch_size**3, model_channels) + + self.blocks = nn.ModuleList( + [ + ModulatedTransformerCrossBlock( + model_channels, + cond_channels, + num_heads=self.num_heads, + mlp_ratio=self.mlp_ratio, + attn_mode="full", + use_checkpoint=self.use_checkpoint, + use_rope=(pe_mode == "rope"), + share_mod=share_mod, + qk_rms_norm=self.qk_rms_norm, + qk_rms_norm_cross=self.qk_rms_norm_cross, + ) + for _ in range(num_blocks) + ] + ) + + self.out_layer = nn.Linear(model_channels, out_channels * patch_size**3) + + self.initialize_weights() + if use_fp16: + self.convert_to_fp16() + + @property + def device(self) -> torch.device: + """ + Return the device of the model. + """ + return next(self.parameters()).device + + def convert_to_fp16(self) -> None: + """ + Convert the torso of the model to float16. + """ + self.blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self) -> None: + """ + Convert the torso of the model to float32. + """ + self.blocks.apply(convert_module_to_f32) + + def initialize_weights(self) -> None: + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers in DiT blocks: + if self.share_mod: + nn.init.constant_(self.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.adaLN_modulation[-1].bias, 0) + else: + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + # nn.init.constant_(self.out_layer.weight, 0) + nn.init.xavier_uniform_(self.out_layer.weight) + nn.init.constant_(self.out_layer.bias, 0) + + def forward( + self, + x: NamedTuple, + t: torch.Tensor, + cond: torch.Tensor, + ) -> NamedTuple: + pose = x.pose + x = x.shape + + assert [*x.shape] == [ + x.shape[0], + self.in_channels, + *[self.resolution] * 3, + ], f"Input shape mismatch, got {x.shape}, expected {[x.shape[0], self.in_channels, *[self.resolution] * 3]}" + + h = patchify(x, self.patch_size) + h = h.view(*h.shape[:2], -1).permute(0, 2, 1).contiguous() + if self.include_pose: + h = torch.cat([h, pose], dim=1) + + h = self.input_layer(h) + h = h + self.pos_emb[None] + t_emb = self.t_embedder(t) + if self.share_mod: + t_emb = self.adaLN_modulation(t_emb) + t_emb = t_emb.type(self.dtype) + h = h.type(self.dtype) + cond = cond.type(self.dtype) + for block in self.blocks: + h = block(h, t_emb, cond) + h = h.type(x.dtype) + h = F.layer_norm(h, h.shape[-1:]) + h = self.out_layer(h) + if self.include_pose: + pose = h[:, -1:] + h = h[:, :-1] + + h = h.permute(0, 2, 1).view( + h.shape[0], h.shape[2], *[self.resolution // self.patch_size] * 3 + ) + h = unpatchify(h, self.patch_size).contiguous() + return DataType(h, pose) + + +class SparseStructureFlowTdfyWrapper(SparseStructureFlowModel): + def __init__(self, *args, **kwargs): + condition_embedder = kwargs.pop("condition_embedder", None) + # if enabled, model will record the condition_shape in one run and uses zeros for all that afterwards + force_zeros_cond = kwargs.pop("force_zeros_cond", False) + # backward compatible to models trained before PR #87 + kwargs.pop("shape_attend_pose", None) + super().__init__(*args, **kwargs) + if condition_embedder is not None: + self.condition_embedder = condition_embedder + else: + self.condition_embedder = lambda x: x + self.force_zeros_cond = force_zeros_cond + + def forward( + self, + x: torch.Tensor, + t: torch.Tensor, + *condition_args, + **condition_kwargs, + ) -> torch.Tensor: + cfg_activate = condition_kwargs.pop("cfg", False) + if self.force_zeros_cond and cfg_activate: + # TODO: @weiyaowang, refactor to read directly from embedder + cond = self.condition_embedder(*condition_args, **condition_kwargs) + cond = cond * 0 + else: + cond = self.condition_embedder(*condition_args, **condition_kwargs) + if self.include_pose: + pose = x[:, -1:] + x = x[:, :-1] + else: + pose = None + x = x.permute(0, 2, 1).contiguous() + n_voxels_cubed = x.shape[-1] + cube_root = n_voxels_cubed ** (1 / 3) + n_voxels = round(cube_root) + assert n_voxels - cube_root < 1e-6 + x = x.view(x.shape[0], x.shape[1], n_voxels, n_voxels, n_voxels) + input = DataType(x, pose) + output = super().forward(input, t, cond) + h = output.shape + pose = output.pose + h = h.view(h.shape[0], h.shape[1], -1).permute(0, 2, 1).contiguous() + if self.include_pose: + h = torch.cat([h, pose], dim=1) + return h diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/models/sparse_structure_vae.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/models/sparse_structure_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..0db7f5790f6698199cc9873ab032e23a0d85a98c --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/models/sparse_structure_vae.py @@ -0,0 +1,402 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import os +import math +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +from ..modules.norm import GroupNorm32, ChannelLayerNorm32 +from ..modules.spatial import pixel_shuffle_3d +from ..modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32 +from safetensors.torch import load_file +from loguru import logger + + +def norm_layer(norm_type: str, *args, **kwargs) -> nn.Module: + """ + Return a normalization layer. + """ + if norm_type == "group": + return GroupNorm32(32, *args, **kwargs) + elif norm_type == "layer": + return ChannelLayerNorm32(*args, **kwargs) + else: + raise ValueError(f"Invalid norm type {norm_type}") + + +class ResBlock3d(nn.Module): + def __init__( + self, + channels: int, + out_channels: Optional[int] = None, + norm_type: Literal["group", "layer"] = "layer", + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.norm1 = norm_layer(norm_type, channels) + self.norm2 = norm_layer(norm_type, self.out_channels) + self.conv1 = nn.Conv3d(channels, self.out_channels, 3, padding=1) + self.conv2 = zero_module( + nn.Conv3d(self.out_channels, self.out_channels, 3, padding=1) + ) + self.skip_connection = ( + nn.Conv3d(channels, self.out_channels, 1) + if channels != self.out_channels + else nn.Identity() + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = self.norm1(x) + h = F.silu(h) + h = self.conv1(h) + h = self.norm2(h) + h = F.silu(h) + h = self.conv2(h) + h = h + self.skip_connection(x) + return h + + +class DownsampleBlock3d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + mode: Literal["conv", "avgpool"] = "conv", + ): + assert mode in ["conv", "avgpool"], f"Invalid mode {mode}" + + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + + if mode == "conv": + self.conv = nn.Conv3d(in_channels, out_channels, 2, stride=2) + elif mode == "avgpool": + assert ( + in_channels == out_channels + ), "Pooling mode requires in_channels to be equal to out_channels" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if hasattr(self, "conv"): + return self.conv(x) + else: + return F.avg_pool3d(x, 2) + + +class UpsampleBlock3d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + mode: Literal["conv", "nearest"] = "conv", + ): + assert mode in ["conv", "nearest"], f"Invalid mode {mode}" + + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + + if mode == "conv": + self.conv = nn.Conv3d(in_channels, out_channels * 8, 3, padding=1) + elif mode == "nearest": + assert ( + in_channels == out_channels + ), "Nearest mode requires in_channels to be equal to out_channels" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if hasattr(self, "conv"): + x = self.conv(x) + return pixel_shuffle_3d(x, 2) + else: + return F.interpolate(x, scale_factor=2, mode="nearest") + + +class SparseStructureEncoder(nn.Module): + """ + Encoder for Sparse Structure (\mathcal{E}_S in the paper Sec. 3.3). + + Args: + in_channels (int): Channels of the input. + latent_channels (int): Channels of the latent representation. + num_res_blocks (int): Number of residual blocks at each resolution. + channels (List[int]): Channels of the encoder blocks. + num_res_blocks_middle (int): Number of residual blocks in the middle. + norm_type (Literal["group", "layer"]): Type of normalization layer. + use_fp16 (bool): Whether to use FP16. + """ + + def __init__( + self, + in_channels: int, + latent_channels: int, + num_res_blocks: int, + channels: List[int], + num_res_blocks_middle: int = 2, + norm_type: Literal["group", "layer"] = "layer", + use_fp16: bool = False, + ): + super().__init__() + self.in_channels = in_channels + self.latent_channels = latent_channels + self.num_res_blocks = num_res_blocks + self.channels = channels + self.num_res_blocks_middle = num_res_blocks_middle + self.norm_type = norm_type + self.use_fp16 = use_fp16 + self.dtype = torch.float16 if use_fp16 else torch.float32 + + self.input_layer = nn.Conv3d(in_channels, channels[0], 3, padding=1) + + self.blocks = nn.ModuleList([]) + for i, ch in enumerate(channels): + self.blocks.extend([ResBlock3d(ch, ch) for _ in range(num_res_blocks)]) + if i < len(channels) - 1: + self.blocks.append(DownsampleBlock3d(ch, channels[i + 1])) + + self.middle_block = nn.Sequential( + *[ + ResBlock3d(channels[-1], channels[-1]) + for _ in range(num_res_blocks_middle) + ] + ) + + self.out_layer = nn.Sequential( + norm_layer(norm_type, channels[-1]), + nn.SiLU(), + nn.Conv3d(channels[-1], latent_channels * 2, 3, padding=1), + ) + + if use_fp16: + self.convert_to_fp16() + + @property + def device(self) -> torch.device: + """ + Return the device of the model. + """ + return next(self.parameters()).device + + def convert_to_fp16(self) -> None: + """ + Convert the torso of the model to float16. + """ + self.use_fp16 = True + self.dtype = torch.float16 + self.blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + + def convert_to_fp32(self) -> None: + """ + Convert the torso of the model to float32. + """ + self.use_fp16 = False + self.dtype = torch.float32 + self.blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + + def forward( + self, x: torch.Tensor, sample_posterior: bool = False, return_raw: bool = False + ) -> torch.Tensor: + x = x.float() + h = self.input_layer(x) + h = h.type(self.dtype) + + for block in self.blocks: + h = block(h) + h = self.middle_block(h) + + h = h.type(x.dtype) + h = self.out_layer(h) + + mean, logvar = h.chunk(2, dim=1) + + if sample_posterior: + std = torch.exp(0.5 * logvar) + z = mean + std * torch.randn_like(std) + else: + z = mean + + if return_raw: + return z, mean, logvar + return z + + +class SparseStructureDecoder(nn.Module): + """ + Decoder for Sparse Structure (\mathcal{D}_S in the paper Sec. 3.3). + + Args: + out_channels (int): Channels of the output. + latent_channels (int): Channels of the latent representation. + num_res_blocks (int): Number of residual blocks at each resolution. + channels (List[int]): Channels of the decoder blocks. + num_res_blocks_middle (int): Number of residual blocks in the middle. + norm_type (Literal["group", "layer"]): Type of normalization layer. + use_fp16 (bool): Whether to use FP16. + """ + + def __init__( + self, + out_channels: int, + latent_channels: int, + num_res_blocks: int, + channels: List[int], + num_res_blocks_middle: int = 2, + norm_type: Literal["group", "layer"] = "layer", + reshape_input_to_cube: bool = False, + use_fp16: bool = False, + ): + super().__init__() + self.out_channels = out_channels + self.latent_channels = latent_channels + self.num_res_blocks = num_res_blocks + self.channels = channels + self.num_res_blocks_middle = num_res_blocks_middle + self.norm_type = norm_type + self.use_fp16 = use_fp16 + self.dtype = torch.float16 if use_fp16 else torch.float32 + self.reshape_input_to_cube = reshape_input_to_cube + self.input_layer = nn.Conv3d(latent_channels, channels[0], 3, padding=1) + + self.middle_block = nn.Sequential( + *[ + ResBlock3d(channels[0], channels[0]) + for _ in range(num_res_blocks_middle) + ] + ) + + self.blocks = nn.ModuleList([]) + for i, ch in enumerate(channels): + self.blocks.extend([ResBlock3d(ch, ch) for _ in range(num_res_blocks)]) + if i < len(channels) - 1: + self.blocks.append(UpsampleBlock3d(ch, channels[i + 1])) + + self.out_layer = nn.Sequential( + norm_layer(norm_type, channels[-1]), + nn.SiLU(), + nn.Conv3d(channels[-1], out_channels, 3, padding=1), + ) + + if use_fp16: + self.convert_to_fp16() + + @property + def device(self) -> torch.device: + """ + Return the device of the model. + """ + return next(self.parameters()).device + + def convert_to_fp16(self) -> None: + """ + Convert the torso of the model to float16. + """ + self.use_fp16 = True + self.dtype = torch.float16 + self.blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + + def convert_to_fp32(self) -> None: + """ + Convert the torso of the model to float32. + """ + self.use_fp16 = False + self.dtype = torch.float32 + self.blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.reshape_input_to_cube: + x = self.flat_to_cube(x) + + h = self.input_layer(x) + + h = h.type(self.dtype) + + h = self.middle_block(h) + for block in self.blocks: + h = block(h) + + h = h.type(x.dtype) + h = self.out_layer(h) + return h + + @staticmethod + def flat_to_cube(flat_latent: torch.Tensor) -> torch.Tensor: + """ + For converting latent tokens from generator to cube + + Args: + flat_latent: (B, T, C) + Returns: + cube: (B, C, D, H, W) + """ + k = round(math.pow(flat_latent.shape[1], 1 / 3)) + assert ( + k**3 == flat_latent.shape[1] + ), f"Flat latent must be a cube {k**3} != {flat_latent.shape[1]}" + latent = flat_latent.view( + flat_latent.shape[0], k, k, k, flat_latent.shape[2] + ).permute(0, 4, 1, 2, 3) + return latent + + +class SparseStructureDecoderTdfyWrapper(SparseStructureDecoder): + def __init__(self, *args, **kwargs): + pretrained_ckpt_path = kwargs.pop("pretrained_ckpt_path", None) + super().__init__(*args, **kwargs) + if pretrained_ckpt_path is not None: + if os.path.exists(pretrained_ckpt_path): + logger.info( + f"Loading pretrained ss decoder from {pretrained_ckpt_path}" + ) + file_type = os.path.splitext(pretrained_ckpt_path)[1] + if file_type == ".safetensors": + self.load_state_dict(load_file(pretrained_ckpt_path)) + else: + self.load_state_dict( + torch.load(pretrained_ckpt_path, weights_only=True) + ) + else: + raise FileNotFoundError( + f"The path for the SS decoder does not exist: {pretrained_ckpt_path}" + ) + + +class SparseStructureEncoderTdfyWrapper(SparseStructureEncoder): + def __init__(self, sample_posterior=True, return_raw=True, *args, **kwargs): + pretrained_ckpt_path = kwargs.pop("pretrained_ckpt_path", None) + super().__init__(*args, **kwargs) + if pretrained_ckpt_path is not None: + if os.path.exists(pretrained_ckpt_path): + logger.info( + f"Loading pretrained ss encoder from {pretrained_ckpt_path}" + ) + file_type = os.path.splitext(pretrained_ckpt_path)[1] + if file_type == ".safetensors": + self.load_state_dict(load_file(pretrained_ckpt_path)) + else: + self.load_state_dict( + torch.load(pretrained_ckpt_path, weights_only=True) + ) + else: + raise FileNotFoundError( + f"The path for the SS encoder does not exist: {pretrained_ckpt_path}" + ) + self.sample_posterior = sample_posterior + self.return_raw = return_raw + + def forward(self, x: torch.Tensor) -> torch.Tensor: + z, mean, logvar = super().forward( + x, sample_posterior=self.sample_posterior, return_raw=True + ) + if self.return_raw: + return { + "z": z, + "mean": mean, + "logvar": logvar, + } + else: + return {"z": z} diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/models/structured_latent_flow.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/models/structured_latent_flow.py new file mode 100644 index 0000000000000000000000000000000000000000..76d48d3006958b508add001ee62be57194ea3b14 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/models/structured_latent_flow.py @@ -0,0 +1,354 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from ..modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32 +from ..modules.transformer import AbsolutePositionEmbedder +from ..modules.norm import LayerNorm32 +from ..modules import sparse as sp +from ..modules.sparse.transformer import ModulatedSparseTransformerCrossBlock +from .sparse_structure_flow import TimestepEmbedder + + +class SparseResBlock3d(nn.Module): + def __init__( + self, + channels: int, + emb_channels: int, + out_channels: Optional[int] = None, + downsample: bool = False, + upsample: bool = False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.out_channels = out_channels or channels + self.downsample = downsample + self.upsample = upsample + + assert not ( + downsample and upsample + ), "Cannot downsample and upsample at the same time" + + self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) + self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6) + self.conv1 = sp.SparseConv3d(channels, self.out_channels, 3) + self.conv2 = zero_module( + sp.SparseConv3d(self.out_channels, self.out_channels, 3) + ) + self.emb_layers = nn.Sequential( + nn.SiLU(), + nn.Linear(emb_channels, 2 * self.out_channels, bias=True), + ) + self.skip_connection = ( + sp.SparseLinear(channels, self.out_channels) + if channels != self.out_channels + else nn.Identity() + ) + self.updown = None + if self.downsample: + self.updown = sp.SparseDownsample(2) + elif self.upsample: + self.updown = sp.SparseUpsample(2) + + def _updown(self, x: sp.SparseTensor) -> sp.SparseTensor: + if self.updown is not None: + x = self.updown(x) + return x + + def forward(self, x: sp.SparseTensor, emb: torch.Tensor) -> sp.SparseTensor: + emb_out = self.emb_layers(emb).type(x.dtype) + scale, shift = torch.chunk(emb_out, 2, dim=1) + + x = self._updown(x) + h = x.replace(self.norm1(x.feats)) + h = h.replace(F.silu(h.feats)) + h = self.conv1(h) + h = h.replace(self.norm2(h.feats)) * (1 + scale) + shift + h = h.replace(F.silu(h.feats)) + h = self.conv2(h) + h = h + self.skip_connection(x) + + return h + + +class SLatFlowModel(nn.Module): + def __init__( + self, + resolution: int, + in_channels: int, + model_channels: int, + cond_channels: int, + out_channels: int, + num_blocks: int, + num_heads: Optional[int] = None, + num_head_channels: Optional[int] = 64, + mlp_ratio: float = 4, + patch_size: int = 2, + num_io_res_blocks: int = 2, + io_block_channels: List[int] = None, + pe_mode: Literal["ape", "rope"] = "ape", + use_fp16: bool = False, + use_checkpoint: bool = False, + use_skip_connection: bool = True, + share_mod: bool = False, + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + is_shortcut_model: bool = False, + ): + super().__init__() + self.resolution = resolution + self.in_channels = in_channels + self.model_channels = model_channels + self.cond_channels = cond_channels + self.out_channels = out_channels + self.num_blocks = num_blocks + self.num_heads = num_heads or model_channels // num_head_channels + self.mlp_ratio = mlp_ratio + self.patch_size = patch_size + self.num_io_res_blocks = num_io_res_blocks + self.io_block_channels = io_block_channels + self.pe_mode = pe_mode + self.use_fp16 = use_fp16 + self.use_checkpoint = use_checkpoint + self.use_skip_connection = use_skip_connection + self.share_mod = share_mod + self.qk_rms_norm = qk_rms_norm + self.qk_rms_norm_cross = qk_rms_norm_cross + self.is_shortcut_model = is_shortcut_model + self.dtype = torch.float16 if use_fp16 else torch.float32 + if is_shortcut_model: + self.d_embedder = TimestepEmbedder(model_channels) # for shortcut model + + + assert int(np.log2(patch_size)) == np.log2( + patch_size + ), "Patch size must be a power of 2" + assert np.log2(patch_size) == len( + io_block_channels + ), "Number of IO ResBlocks must match the number of stages" + + self.t_embedder = TimestepEmbedder(model_channels) + if share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), nn.Linear(model_channels, 6 * model_channels, bias=True) + ) + + if pe_mode == "ape": + self.pos_embedder = AbsolutePositionEmbedder(model_channels) + + self.input_layer = sp.SparseLinear(in_channels, io_block_channels[0]) + self.input_blocks = nn.ModuleList([]) + for chs, next_chs in zip( + io_block_channels, io_block_channels[1:] + [model_channels] + ): + self.input_blocks.extend( + [ + SparseResBlock3d( + chs, + model_channels, + out_channels=chs, + ) + for _ in range(num_io_res_blocks - 1) + ] + ) + self.input_blocks.append( + SparseResBlock3d( + chs, + model_channels, + out_channels=next_chs, + downsample=True, + ) + ) + + self.blocks = nn.ModuleList( + [ + ModulatedSparseTransformerCrossBlock( + model_channels, + cond_channels, + num_heads=self.num_heads, + mlp_ratio=self.mlp_ratio, + attn_mode="full", + use_checkpoint=self.use_checkpoint, + use_rope=(pe_mode == "rope"), + share_mod=self.share_mod, + qk_rms_norm=self.qk_rms_norm, + qk_rms_norm_cross=self.qk_rms_norm_cross, + ) + for _ in range(num_blocks) + ] + ) + + self.out_blocks = nn.ModuleList([]) + for chs, prev_chs in zip( + reversed(io_block_channels), + [model_channels] + list(reversed(io_block_channels[1:])), + ): + self.out_blocks.append( + SparseResBlock3d( + prev_chs * 2 if self.use_skip_connection else prev_chs, + model_channels, + out_channels=chs, + upsample=True, + ) + ) + self.out_blocks.extend( + [ + SparseResBlock3d( + chs * 2 if self.use_skip_connection else chs, + model_channels, + out_channels=chs, + ) + for _ in range(num_io_res_blocks - 1) + ] + ) + self.out_layer = sp.SparseLinear(io_block_channels[0], out_channels) + + self.initialize_weights() + if use_fp16: + self.convert_to_fp16() + + @property + def device(self) -> torch.device: + """ + Return the device of the model. + """ + return next(self.parameters()).device + + def convert_to_fp16(self) -> None: + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.blocks.apply(convert_module_to_f16) + self.out_blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self) -> None: + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.blocks.apply(convert_module_to_f32) + self.out_blocks.apply(convert_module_to_f32) + + def initialize_weights(self) -> None: + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # # zero init like controlnet, for MLP should only zero + # # the weight of the last layer only + # if self.is_shortcut_model: + # nn.init.constant_(self.d_embedder.mlp[2].weight, 0) + + # Zero-out adaLN modulation layers in DiT blocks: + if self.share_mod: + nn.init.constant_(self.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.adaLN_modulation[-1].bias, 0) + else: + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.out_layer.weight, 0) + nn.init.constant_(self.out_layer.bias, 0) + + def forward( + self, x: sp.SparseTensor, t: torch.Tensor, cond: torch.Tensor, d: torch.Tensor = None + ) -> sp.SparseTensor: + h = self.input_layer(x).type(self.dtype) + t_emb = self.t_embedder(t) + if d is not None: + d_emb = self.d_embedder(d) + t_emb = t_emb + d_emb + if self.share_mod: + t_emb = self.adaLN_modulation(t_emb) + t_emb = t_emb.type(self.dtype) + cond = cond.type(self.dtype) + + skips = [] + # pack with input blocks + for block in self.input_blocks: + h = block(h, t_emb) + skips.append(h.feats) + + if self.pe_mode == "ape": + h = h + self.pos_embedder(h.coords[:, 1:]).type(self.dtype) + for block in self.blocks: + h = block(h, t_emb, cond) + + # unpack with output blocks + for block, skip in zip(self.out_blocks, reversed(skips)): + if self.use_skip_connection: + h = block(h.replace(torch.cat([h.feats, skip], dim=1)), t_emb) + else: + h = block(h, t_emb) + + h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) + h = self.out_layer(h.type(x.dtype)) + return h + + +class SLatFlowModelTdfyWrapper(SLatFlowModel): + def __init__(self, *args, **kwargs): + condition_embedder = kwargs.pop("condition_embedder", None) + # if enabled, model will record the condition_shape in one run and uses zeros for all that afterwards + force_zeros_cond = kwargs.pop("force_zeros_cond", False) + # backward compatible to models trained before PR #87 + # kwargs.pop("use_fp16", None) + super().__init__(*args, **kwargs) + if condition_embedder is not None: + self.condition_embedder = condition_embedder + else: + self.condition_embedder = lambda x: x + self.force_zeros_cond = force_zeros_cond + # self.null_condition = None + + def forward( + self, + x: torch.Tensor, + t: torch.Tensor, + *condition_args, + **condition_kwargs, + ) -> torch.Tensor: + # Extract d from kwargs_conditionals if present, for shortcut model + if not torch.compiler.is_compiling(): + d = condition_kwargs.pop("d", None) + if "coords" in condition_kwargs: + coords = condition_kwargs["coords"] + del condition_kwargs["coords"] + else: + coords = condition_args[-1] + condition_args = condition_args[:-1] + else: + coords = condition_args[-1] + condition_args = condition_args[:-1] + d = condition_kwargs.pop("d", None) + + coords = torch.tensor(coords).to(x.device) + x = sp.SparseTensor( + feats=x[0], + coords=coords, + ) + cfg_activate = condition_kwargs.pop("cfg", False) + if self.force_zeros_cond and cfg_activate: + # TODO: @weiyaowang, refactor to read directly from embedder + cond = self.condition_embedder(*condition_args, **condition_kwargs) + cond = cond * 0 + else: + cond = self.condition_embedder(*condition_args, **condition_kwargs) + h = super().forward(x, t, cond, d) + h = h.feats[None] + return h diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/models/structured_latent_vae/__init__.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/models/structured_latent_vae/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..15ebf58f34a933c4f69f797c53f069407f9e4b57 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/models/structured_latent_vae/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +from .encoder import SLatEncoder +from .decoder_gs import SLatGaussianDecoder +from .decoder_rf import SLatRadianceFieldDecoder diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/models/structured_latent_vae/base.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/models/structured_latent_vae/base.py new file mode 100644 index 0000000000000000000000000000000000000000..683bd67eacfe9abb1e0ca8934d0faf855f44672c --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/models/structured_latent_vae/base.py @@ -0,0 +1,134 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +from typing import * +import torch +import torch.nn as nn +from ...modules.utils import convert_module_to_f16, convert_module_to_f32 +from ...modules import sparse as sp +from ...modules.transformer import AbsolutePositionEmbedder +from ...modules.sparse.transformer import SparseTransformerBlock + + +def block_attn_config(self): + """ + Return the attention configuration of the model. + """ + for i in range(self.num_blocks): + if self.attn_mode == "shift_window": + yield "serialized", self.window_size, 0, ( + 16 * (i % 2), + ) * 3, sp.SerializeMode.Z_ORDER + elif self.attn_mode == "shift_sequence": + yield "serialized", self.window_size, self.window_size // 2 * (i % 2), ( + 0, + 0, + 0, + ), sp.SerializeMode.Z_ORDER + elif self.attn_mode == "shift_order": + yield "serialized", self.window_size, 0, (0, 0, 0), sp.SerializeModes[i % 4] + elif self.attn_mode == "full": + yield "full", None, None, None, None + elif self.attn_mode == "swin": + yield "windowed", self.window_size, None, self.window_size // 2 * ( + i % 2 + ), None + + +class SparseTransformerBase(nn.Module): + """ + Sparse Transformer without output layers. + Serve as the base class for encoder and decoder. + """ + + def __init__( + self, + in_channels: int, + model_channels: int, + num_blocks: int, + num_heads: Optional[int] = None, + num_head_channels: Optional[int] = 64, + mlp_ratio: float = 4.0, + attn_mode: Literal[ + "full", "shift_window", "shift_sequence", "shift_order", "swin" + ] = "full", + window_size: Optional[int] = None, + pe_mode: Literal["ape", "rope"] = "ape", + use_fp16: bool = False, + use_checkpoint: bool = False, + qk_rms_norm: bool = False, + ): + super().__init__() + self.in_channels = in_channels + self.model_channels = model_channels + self.num_blocks = num_blocks + self.window_size = window_size + self.num_heads = num_heads or model_channels // num_head_channels + self.mlp_ratio = mlp_ratio + self.attn_mode = attn_mode + self.pe_mode = pe_mode + self.use_fp16 = use_fp16 + self.use_checkpoint = use_checkpoint + self.qk_rms_norm = qk_rms_norm + self.dtype = torch.float16 if use_fp16 else torch.float32 + + if pe_mode == "ape": + self.pos_embedder = AbsolutePositionEmbedder(model_channels) + + self.input_layer = sp.SparseLinear(in_channels, model_channels) + self.blocks = nn.ModuleList( + [ + SparseTransformerBlock( + model_channels, + num_heads=self.num_heads, + mlp_ratio=self.mlp_ratio, + attn_mode=attn_mode, + window_size=window_size, + shift_sequence=shift_sequence, + shift_window=shift_window, + serialize_mode=serialize_mode, + use_checkpoint=self.use_checkpoint, + use_rope=(pe_mode == "rope"), + qk_rms_norm=self.qk_rms_norm, + ) + for attn_mode, window_size, shift_sequence, shift_window, serialize_mode in block_attn_config( + self + ) + ] + ) + + @property + def device(self) -> torch.device: + """ + Return the device of the model. + """ + return next(self.parameters()).device + + def convert_to_fp16(self) -> None: + """ + Convert the torso of the model to float16. + """ + self.blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self) -> None: + """ + Convert the torso of the model to float32. + """ + self.blocks.apply(convert_module_to_f32) + + def initialize_weights(self) -> None: + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + def forward(self, x: sp.SparseTensor) -> sp.SparseTensor: + h = self.input_layer(x) + if self.pe_mode == "ape": + h = h + self.pos_embedder(x.coords[:, 1:]) + h = h.type(self.dtype) + for block in self.blocks: + h = block(h) + return h diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/models/structured_latent_vae/decoder_gs.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/models/structured_latent_vae/decoder_gs.py new file mode 100644 index 0000000000000000000000000000000000000000..6b192a45528dae9ce2debb5aea630d7778b7497f --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/models/structured_latent_vae/decoder_gs.py @@ -0,0 +1,179 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import os +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +from ...modules import sparse as sp +from ...utils.random_utils import hammersley_sequence +from .base import SparseTransformerBase +from ...representations import Gaussian +from safetensors.torch import load_file +from loguru import logger + + +class SLatGaussianDecoder(SparseTransformerBase): + def __init__( + self, + resolution: int, + model_channels: int, + latent_channels: int, + num_blocks: int, + num_heads: Optional[int] = None, + num_head_channels: Optional[int] = 64, + mlp_ratio: float = 4, + attn_mode: Literal[ + "full", "shift_window", "shift_sequence", "shift_order", "swin" + ] = "swin", + window_size: int = 8, + pe_mode: Literal["ape", "rope"] = "ape", + use_fp16: bool = False, + use_checkpoint: bool = False, + qk_rms_norm: bool = False, + representation_config: dict = None, + ): + super().__init__( + in_channels=latent_channels, + model_channels=model_channels, + num_blocks=num_blocks, + num_heads=num_heads, + num_head_channels=num_head_channels, + mlp_ratio=mlp_ratio, + attn_mode=attn_mode, + window_size=window_size, + pe_mode=pe_mode, + use_fp16=use_fp16, + use_checkpoint=use_checkpoint, + qk_rms_norm=qk_rms_norm, + ) + self.resolution = resolution + self.rep_config = representation_config + self._calc_layout() + self.out_layer = sp.SparseLinear(model_channels, self.out_channels) + self._build_perturbation() + + self.initialize_weights() + if use_fp16: + self.convert_to_fp16() + + def initialize_weights(self) -> None: + super().initialize_weights() + # Zero-out output layers: + nn.init.constant_(self.out_layer.weight, 0) + nn.init.constant_(self.out_layer.bias, 0) + + def _build_perturbation(self) -> None: + perturbation = [ + hammersley_sequence(3, i, self.rep_config["num_gaussians"]) + for i in range(self.rep_config["num_gaussians"]) + ] + perturbation = torch.tensor(perturbation).float() * 2 - 1 + perturbation = perturbation / self.rep_config["voxel_size"] + perturbation = torch.atanh(perturbation).to(self.device) + self.register_buffer("offset_perturbation", perturbation) + + def _calc_layout(self) -> None: + self.layout = { + "_xyz": { + "shape": (self.rep_config["num_gaussians"], 3), + "size": self.rep_config["num_gaussians"] * 3, + }, + "_features_dc": { + "shape": (self.rep_config["num_gaussians"], 1, 3), + "size": self.rep_config["num_gaussians"] * 3, + }, + "_scaling": { + "shape": (self.rep_config["num_gaussians"], 3), + "size": self.rep_config["num_gaussians"] * 3, + }, + "_rotation": { + "shape": (self.rep_config["num_gaussians"], 4), + "size": self.rep_config["num_gaussians"] * 4, + }, + "_opacity": { + "shape": (self.rep_config["num_gaussians"], 1), + "size": self.rep_config["num_gaussians"], + }, + } + start = 0 + for k, v in self.layout.items(): + v["range"] = (start, start + v["size"]) + start += v["size"] + self.out_channels = start + + def to_representation(self, x: sp.SparseTensor) -> List[Gaussian]: + """ + Convert a batch of network outputs to 3D representations. + + Args: + x: The [N x * x C] sparse tensor output by the network. + + Returns: + list of representations + """ + ret = [] + for i in range(x.shape[0]): + representation = Gaussian( + sh_degree=0, + aabb=[-0.5, -0.5, -0.5, 1.0, 1.0, 1.0], + mininum_kernel_size=self.rep_config["3d_filter_kernel_size"], + scaling_bias=self.rep_config["scaling_bias"], + opacity_bias=self.rep_config["opacity_bias"], + scaling_activation=self.rep_config["scaling_activation"], + ) + xyz = (x.coords[x.layout[i]][:, 1:].float() + 0.5) / self.resolution + for k, v in self.layout.items(): + if k == "_xyz": + offset = x.feats[x.layout[i]][ + :, v["range"][0] : v["range"][1] + ].reshape(-1, *v["shape"]) + offset = offset * self.rep_config["lr"][k] + if self.rep_config["perturb_offset"]: + offset = offset + self.offset_perturbation + offset = ( + torch.tanh(offset) + / self.resolution + * 0.5 + * self.rep_config["voxel_size"] + ) + _xyz = xyz.unsqueeze(1) + offset + setattr(representation, k, _xyz.flatten(0, 1)) + else: + feats = ( + x.feats[x.layout[i]][:, v["range"][0] : v["range"][1]] + .reshape(-1, *v["shape"]) + .flatten(0, 1) + ) + feats = feats * self.rep_config["lr"][k] + setattr(representation, k, feats) + ret.append(representation) + return ret + + def forward(self, x: sp.SparseTensor) -> List[Gaussian]: + h = super().forward(x) + h = h.type(x.dtype) + h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) + h = self.out_layer(h) + return self.to_representation(h) + + +class SLatGaussianDecoderTdfyWrapper(SLatGaussianDecoder): + def __init__(self, *args, **kwargs): + pretrained_ckpt_path = kwargs.pop("pretrained_ckpt_path", None) + super().__init__(*args, **kwargs) + if pretrained_ckpt_path is not None: + if os.path.exists(pretrained_ckpt_path): + logger.info( + f"Loading pretrained slat decoder gs from {pretrained_ckpt_path}" + ) + file_type = os.path.splitext(pretrained_ckpt_path)[1] + if file_type == ".safetensors": + self.load_state_dict(load_file(pretrained_ckpt_path)) + else: + self.load_state_dict( + torch.load(pretrained_ckpt_path, weights_only=True) + ) + else: + raise FileNotFoundError( + f"The path for slat decoder gs does not exist: {pretrained_ckpt_path}" + ) \ No newline at end of file diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/models/structured_latent_vae/decoder_mesh.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/models/structured_latent_vae/decoder_mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..5a0ff36e2867e27572ff7d740a1603911ed61873 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/models/structured_latent_vae/decoder_mesh.py @@ -0,0 +1,200 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from ...modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32 +from ...modules import sparse as sp +from .base import SparseTransformerBase +from ...representations import MeshExtractResult +from ...representations.mesh import SparseFeatures2Mesh +import os +from safetensors.torch import load_file +from loguru import logger + + +class SparseSubdivideBlock3d(nn.Module): + """ + A 3D subdivide block that can subdivide the sparse tensor. + + Args: + channels: channels in the inputs and outputs. + out_channels: if specified, the number of output channels. + num_groups: the number of groups for the group norm. + """ + + def __init__( + self, + channels: int, + resolution: int, + out_channels: Optional[int] = None, + num_groups: int = 32, + ): + super().__init__() + self.channels = channels + self.resolution = resolution + self.out_resolution = resolution * 2 + self.out_channels = out_channels or channels + + self.act_layers = nn.Sequential( + sp.SparseGroupNorm32(num_groups, channels), sp.SparseSiLU() + ) + + self.sub = sp.SparseSubdivide() + + self.out_layers = nn.Sequential( + sp.SparseConv3d( + channels, self.out_channels, 3, indice_key=f"res_{self.out_resolution}" + ), + sp.SparseGroupNorm32(num_groups, self.out_channels), + sp.SparseSiLU(), + zero_module( + sp.SparseConv3d( + self.out_channels, + self.out_channels, + 3, + indice_key=f"res_{self.out_resolution}", + ) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + else: + self.skip_connection = sp.SparseConv3d( + channels, self.out_channels, 1, indice_key=f"res_{self.out_resolution}" + ) + + def forward(self, x: sp.SparseTensor) -> sp.SparseTensor: + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + + Args: + x: an [N x C x ...] Tensor of features. + Returns: + an [N x C x ...] Tensor of outputs. + """ + h = self.act_layers(x) + h = self.sub(h) + x = self.sub(x) + h = self.out_layers(h) + h = h + self.skip_connection(x) + return h + + +class SLatMeshDecoder(SparseTransformerBase): + def __init__( + self, + resolution: int, + model_channels: int, + latent_channels: int, + num_blocks: int, + num_heads: Optional[int] = None, + num_head_channels: Optional[int] = 64, + mlp_ratio: float = 4, + attn_mode: Literal[ + "full", "shift_window", "shift_sequence", "shift_order", "swin" + ] = "swin", + window_size: int = 8, + pe_mode: Literal["ape", "rope"] = "ape", + use_fp16: bool = False, + use_checkpoint: bool = False, + qk_rms_norm: bool = False, + representation_config: dict = None, + device="cuda" + ): + super().__init__( + in_channels=latent_channels, + model_channels=model_channels, + num_blocks=num_blocks, + num_heads=num_heads, + num_head_channels=num_head_channels, + mlp_ratio=mlp_ratio, + attn_mode=attn_mode, + window_size=window_size, + pe_mode=pe_mode, + use_fp16=use_fp16, + use_checkpoint=use_checkpoint, + qk_rms_norm=qk_rms_norm, + ) + self.resolution = resolution + self.rep_config = representation_config + self.mesh_extractor = SparseFeatures2Mesh( + res=self.resolution * 4, use_color=self.rep_config.get("use_color", False), device=device + ) + self.out_channels = self.mesh_extractor.feats_channels + self.upsample = nn.ModuleList( + [ + SparseSubdivideBlock3d( + channels=model_channels, + resolution=resolution, + out_channels=model_channels // 4, + ), + SparseSubdivideBlock3d( + channels=model_channels // 4, + resolution=resolution * 2, + out_channels=model_channels // 8, + ), + ] + ) + self.out_layer = sp.SparseLinear(model_channels // 8, self.out_channels) + + self.initialize_weights() + if use_fp16: + self.convert_to_fp16() + + def initialize_weights(self) -> None: + super().initialize_weights() + # Zero-out output layers: + nn.init.constant_(self.out_layer.weight, 0) + nn.init.constant_(self.out_layer.bias, 0) + + def convert_to_fp16(self) -> None: + """ + Convert the torso of the model to float16. + """ + super().convert_to_fp16() + self.upsample.apply(convert_module_to_f16) + + def convert_to_fp32(self) -> None: + """ + Convert the torso of the model to float32. + """ + super().convert_to_fp32() + self.upsample.apply(convert_module_to_f32) + + def to_representation(self, x: sp.SparseTensor) -> List[MeshExtractResult]: + """ + Convert a batch of network outputs to 3D representations. + + Args: + x: The [N x * x C] sparse tensor output by the network. + + Returns: + list of representations + """ + ret = [] + for i in range(x.shape[0]): + mesh = self.mesh_extractor(x[i], training=self.training) + ret.append(mesh) + return ret + + def forward(self, x: sp.SparseTensor) -> List[MeshExtractResult]: + h = super().forward(x) + for block in self.upsample: + h = block(h) + h = h.type(x.dtype) + h = self.out_layer(h) + return self.to_representation(h) + + +class SLatMeshDecoderTdfyWrapper(SLatMeshDecoder): + def __init__(self, *args, **kwargs): + pretrained_ckpt_path = kwargs.pop("pretrained_ckpt_path", None) + super().__init__(*args, **kwargs) + if pretrained_ckpt_path is not None and os.path.exists(pretrained_ckpt_path): + logger.info( + f"Loading pretrained slat decoder gs from {pretrained_ckpt_path}" + ) + self.load_state_dict(load_file(pretrained_ckpt_path)) diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/models/structured_latent_vae/decoder_rf.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/models/structured_latent_vae/decoder_rf.py new file mode 100644 index 0000000000000000000000000000000000000000..68d0e550c86a7f3faeaf0a6fe9dc5cabac68f526 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/models/structured_latent_vae/decoder_rf.py @@ -0,0 +1,129 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from ...modules import sparse as sp +from .base import SparseTransformerBase +from ...representations import Strivec + + +class SLatRadianceFieldDecoder(SparseTransformerBase): + def __init__( + self, + resolution: int, + model_channels: int, + latent_channels: int, + num_blocks: int, + num_heads: Optional[int] = None, + num_head_channels: Optional[int] = 64, + mlp_ratio: float = 4, + attn_mode: Literal[ + "full", "shift_window", "shift_sequence", "shift_order", "swin" + ] = "swin", + window_size: int = 8, + pe_mode: Literal["ape", "rope"] = "ape", + use_fp16: bool = False, + use_checkpoint: bool = False, + qk_rms_norm: bool = False, + representation_config: dict = None, + ): + super().__init__( + in_channels=latent_channels, + model_channels=model_channels, + num_blocks=num_blocks, + num_heads=num_heads, + num_head_channels=num_head_channels, + mlp_ratio=mlp_ratio, + attn_mode=attn_mode, + window_size=window_size, + pe_mode=pe_mode, + use_fp16=use_fp16, + use_checkpoint=use_checkpoint, + qk_rms_norm=qk_rms_norm, + ) + self.resolution = resolution + self.rep_config = representation_config + self._calc_layout() + self.out_layer = sp.SparseLinear(model_channels, self.out_channels) + + self.initialize_weights() + if use_fp16: + self.convert_to_fp16() + + def initialize_weights(self) -> None: + super().initialize_weights() + # Zero-out output layers: + nn.init.constant_(self.out_layer.weight, 0) + nn.init.constant_(self.out_layer.bias, 0) + + def _calc_layout(self) -> None: + self.layout = { + "trivec": { + "shape": (self.rep_config["rank"], 3, self.rep_config["dim"]), + "size": self.rep_config["rank"] * 3 * self.rep_config["dim"], + }, + "density": { + "shape": (self.rep_config["rank"],), + "size": self.rep_config["rank"], + }, + "features_dc": { + "shape": (self.rep_config["rank"], 1, 3), + "size": self.rep_config["rank"] * 3, + }, + } + start = 0 + for k, v in self.layout.items(): + v["range"] = (start, start + v["size"]) + start += v["size"] + self.out_channels = start + + def to_representation(self, x: sp.SparseTensor) -> List[Strivec]: + """ + Convert a batch of network outputs to 3D representations. + + Args: + x: The [N x * x C] sparse tensor output by the network. + + Returns: + list of representations + """ + ret = [] + for i in range(x.shape[0]): + representation = Strivec( + sh_degree=0, + resolution=self.resolution, + aabb=[-0.5, -0.5, -0.5, 1, 1, 1], + rank=self.rep_config["rank"], + dim=self.rep_config["dim"], + device="cuda", + ) + representation.density_shift = 0.0 + representation.position = ( + x.coords[x.layout[i]][:, 1:].float() + 0.5 + ) / self.resolution + representation.depth = torch.full( + (representation.position.shape[0], 1), + int(np.log2(self.resolution)), + dtype=torch.uint8, + device="cuda", + ) + for k, v in self.layout.items(): + setattr( + representation, + k, + x.feats[x.layout[i]][:, v["range"][0] : v["range"][1]].reshape( + -1, *v["shape"] + ), + ) + representation.trivec = representation.trivec + 1 + ret.append(representation) + return ret + + def forward(self, x: sp.SparseTensor) -> List[Strivec]: + h = super().forward(x) + h = h.type(x.dtype) + h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) + h = self.out_layer(h) + return self.to_representation(h) diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/models/structured_latent_vae/encoder.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/models/structured_latent_vae/encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..b93b4d32602cd41fb9dfba19ed2fbbc35d87dbcb --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/models/structured_latent_vae/encoder.py @@ -0,0 +1,93 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +from ...modules import sparse as sp +from .base import SparseTransformerBase +from safetensors.torch import load_file +from loguru import logger +import os + + +class SLatEncoder(SparseTransformerBase): + def __init__( + self, + resolution: int, + in_channels: int, + model_channels: int, + latent_channels: int, + num_blocks: int, + num_heads: Optional[int] = None, + num_head_channels: Optional[int] = 64, + mlp_ratio: float = 4, + attn_mode: Literal[ + "full", "shift_window", "shift_sequence", "shift_order", "swin" + ] = "swin", + window_size: int = 8, + pe_mode: Literal["ape", "rope"] = "ape", + use_fp16: bool = False, + use_checkpoint: bool = False, + qk_rms_norm: bool = False, + ): + super().__init__( + in_channels=in_channels, + model_channels=model_channels, + num_blocks=num_blocks, + num_heads=num_heads, + num_head_channels=num_head_channels, + mlp_ratio=mlp_ratio, + attn_mode=attn_mode, + window_size=window_size, + pe_mode=pe_mode, + use_fp16=use_fp16, + use_checkpoint=use_checkpoint, + qk_rms_norm=qk_rms_norm, + ) + self.resolution = resolution + self.out_layer = sp.SparseLinear(model_channels, 2 * latent_channels) + + self.initialize_weights() + if use_fp16: + self.convert_to_fp16() + + def initialize_weights(self) -> None: + super().initialize_weights() + # Zero-out output layers: + nn.init.constant_(self.out_layer.weight, 0) + nn.init.constant_(self.out_layer.bias, 0) + + def forward(self, x: sp.SparseTensor, sample_posterior=True, return_raw=False): + h = super().forward(x) + h = h.type(x.dtype) + h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) + h = self.out_layer(h) + + # Sample from the posterior distribution + mean, logvar = h.feats.chunk(2, dim=-1) + if sample_posterior: + std = torch.exp(0.5 * logvar) + z = mean + std * torch.randn_like(std) + else: + z = mean + z = h.replace(z) + + if return_raw: + return z, mean, logvar + else: + return z + + +class SLatEncoderTdfyWrapper(SLatEncoder): + def __init__(self, *args, **kwargs): + pretrained_ckpt_path = kwargs.pop("pretrained_ckpt_path", None) + super().__init__(*args, **kwargs) + if pretrained_ckpt_path is not None and os.path.exists(pretrained_ckpt_path): + logger.info( + f"Loading pretrained slat decoder gs from {pretrained_ckpt_path}" + ) + file_type = os.path.splitext(pretrained_ckpt_path)[1] + if file_type == '.safetensors': + self.load_state_dict(load_file(pretrained_ckpt_path)) + else: + self.load_state_dict(torch.load(pretrained_ckpt_path, weights_only=True)) diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/models/timestep_embedder.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/models/timestep_embedder.py new file mode 100644 index 0000000000000000000000000000000000000000..fa8b5293028c76d7c8470e933df9ba98410353f5 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/models/timestep_embedder.py @@ -0,0 +1,59 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import torch.nn as nn +import torch +import numpy as np + + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + + def __init__(self, hidden_size, frequency_embedding_size=256, freeze=False): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + if freeze: + self.requires_grad_(False) + self.eval() + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + + Args: + t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + dim: the dimension of the output. + max_period: controls the minimum frequency of the embeddings. + + Returns: + an (N, D) Tensor of positional embeddings. + """ + # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + half = dim // 2 + freqs = torch.exp( + -np.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to(device=t.device) + t = t[:, None].float() + args = t * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat( + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 + ) + return embedding + + def forward(self, t): + if t.ndim == 0: + t = t.unsqueeze(0) + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/attention/__init__.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/attention/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..14053c17d579ddee4b6224c0b8a1be87cd7d23f4 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/attention/__init__.py @@ -0,0 +1,47 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +from typing import * +from loguru import logger + +BACKEND = "sdpa" +DEBUG = False + + +def __from_env(): + import os + + global BACKEND + global DEBUG + + env_attn_backend = os.environ.get("ATTN_BACKEND") + env_sttn_debug = os.environ.get("ATTN_DEBUG") + + if env_attn_backend is not None and env_attn_backend in [ + "xformers", + "flash_attn", + "torch_flash_attn", + "sdpa", + "naive", + ]: + BACKEND = env_attn_backend + # BACKEND = "sdpa" + if env_sttn_debug is not None: + DEBUG = env_sttn_debug == "1" + + logger.info(f"[ATTENTION] Using backend: {BACKEND}") + + +__from_env() + + +def set_backend(backend: Literal["xformers", "flash_attn", "torch_flash_attn"]): + global BACKEND + BACKEND = backend + + +def set_debug(debug: bool): + global DEBUG + DEBUG = debug + + +from .full_attn import * +from .modules import * diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/attention/full_attn.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/attention/full_attn.py new file mode 100755 index 0000000000000000000000000000000000000000..a9631af6164bd0686548dbff0b901aed6999c0b4 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/attention/full_attn.py @@ -0,0 +1,176 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +from typing import * +import torch +import math +from . import DEBUG, BACKEND + +if BACKEND == "xformers": + import xformers.ops as xops +elif BACKEND == "flash_attn": + import flash_attn +elif BACKEND == "torch_flash_attn": + pass +elif BACKEND == "sdpa": + from torch.nn.functional import scaled_dot_product_attention as sdpa +elif BACKEND == "naive": + pass +else: + raise ValueError(f"Unknown attention backend: {BACKEND}") + + +__all__ = [ + "scaled_dot_product_attention", +] + + +def _naive_sdpa(q, k, v): + """ + Naive implementation of scaled dot product attention. + """ + q = q.permute(0, 2, 1, 3) # [N, H, L, C] + k = k.permute(0, 2, 1, 3) # [N, H, L, C] + v = v.permute(0, 2, 1, 3) # [N, H, L, C] + scale_factor = 1 / math.sqrt(q.size(-1)) + attn_weight = q @ k.transpose(-2, -1) * scale_factor + attn_weight = torch.softmax(attn_weight, dim=-1) + out = attn_weight @ v + out = out.permute(0, 2, 1, 3) # [N, L, H, C] + return out + + +@overload +def scaled_dot_product_attention(qkv: torch.Tensor) -> torch.Tensor: + """ + Apply scaled dot product attention. + + Args: + qkv (torch.Tensor): A [N, L, 3, H, C] tensor containing Qs, Ks, and Vs. + """ + ... + + +@overload +def scaled_dot_product_attention(q: torch.Tensor, kv: torch.Tensor) -> torch.Tensor: + """ + Apply scaled dot product attention. + + Args: + q (torch.Tensor): A [N, L, H, C] tensor containing Qs. + kv (torch.Tensor): A [N, L, 2, H, C] tensor containing Ks and Vs. + """ + ... + + +@overload +def scaled_dot_product_attention( + q: torch.Tensor, k: torch.Tensor, v: torch.Tensor +) -> torch.Tensor: + """ + Apply scaled dot product attention. + + Args: + q (torch.Tensor): A [N, L, H, Ci] tensor containing Qs. + k (torch.Tensor): A [N, L, H, Ci] tensor containing Ks. + v (torch.Tensor): A [N, L, H, Co] tensor containing Vs. + + Note: + k and v are assumed to have the same coordinate map. + """ + ... + + +def scaled_dot_product_attention(*args, **kwargs): + arg_names_dict = {1: ["qkv"], 2: ["q", "kv"], 3: ["q", "k", "v"]} + num_all_args = len(args) + len(kwargs) + assert ( + num_all_args in arg_names_dict + ), f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3" + for key in arg_names_dict[num_all_args][len(args) :]: + assert key in kwargs, f"Missing argument {key}" + + if num_all_args == 1: + qkv = args[0] if len(args) > 0 else kwargs["qkv"] + assert ( + len(qkv.shape) == 5 and qkv.shape[2] == 3 + ), f"Invalid shape for qkv, got {qkv.shape}, expected [N, L, 3, H, C]" + device = qkv.device + + elif num_all_args == 2: + q = args[0] if len(args) > 0 else kwargs["q"] + kv = args[1] if len(args) > 1 else kwargs["kv"] + assert ( + q.shape[0] == kv.shape[0] + ), f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}" + assert ( + len(q.shape) == 4 + ), f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]" + assert ( + len(kv.shape) == 5 + ), f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]" + device = q.device + + elif num_all_args == 3: + q = args[0] if len(args) > 0 else kwargs["q"] + k = args[1] if len(args) > 1 else kwargs["k"] + v = args[2] if len(args) > 2 else kwargs["v"] + assert ( + q.shape[0] == k.shape[0] == v.shape[0] + ), f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}" + assert ( + len(q.shape) == 4 + ), f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]" + assert ( + len(k.shape) == 4 + ), f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]" + assert ( + len(v.shape) == 4 + ), f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]" + device = q.device + + if BACKEND == "xformers": + if num_all_args == 1: + q, k, v = qkv.unbind(dim=2) + elif num_all_args == 2: + k, v = kv.unbind(dim=2) + out = xops.memory_efficient_attention(q, k, v) + elif BACKEND == "flash_attn": + if num_all_args == 1: + out = flash_attn.flash_attn_qkvpacked_func(qkv) + elif num_all_args == 2: + out = flash_attn.flash_attn_kvpacked_func(q, kv) + elif num_all_args == 3: + out = flash_attn.flash_attn_func(q, k, v) + elif BACKEND == "sdpa": + if num_all_args == 1: + q, k, v = qkv.unbind(dim=2) + elif num_all_args == 2: + k, v = kv.unbind(dim=2) + q = q.permute(0, 2, 1, 3) # [N, H, L, C] + k = k.permute(0, 2, 1, 3) # [N, H, L, C] + v = v.permute(0, 2, 1, 3) # [N, H, L, C] + out = sdpa(q, k, v) # [N, H, L, C] + out = out.permute(0, 2, 1, 3) # [N, L, H, C] + elif BACKEND == "torch_flash_attn": + if num_all_args == 1: + q, k, v = qkv.unbind(dim=2) + elif num_all_args == 2: + k, v = kv.unbind(dim=2) + q = q.permute(0, 2, 1, 3) # [N, H, L, C] + k = k.permute(0, 2, 1, 3) # [N, H, L, C] + v = v.permute(0, 2, 1, 3) # [N, H, L, C] + original_dtype = q.dtype + with torch.nn.attention.sdpa_kernel(backends=[torch.nn.attention.SDPBackend.FLASH_ATTENTION]): + with torch.autocast(device_type=device.type, dtype=torch.bfloat16): + out = torch.nn.functional.scaled_dot_product_attention(q, k, v) # [N, H, L, C] + out = out.permute(0, 2, 1, 3) # [N, L, H, C] + out = out.to(original_dtype) + elif BACKEND == "naive": + if num_all_args == 1: + q, k, v = qkv.unbind(dim=2) + elif num_all_args == 2: + k, v = kv.unbind(dim=2) + out = _naive_sdpa(q, k, v) + else: + raise ValueError(f"Unknown attention module: {BACKEND}") + + return out diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/attention/modules.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/attention/modules.py new file mode 100755 index 0000000000000000000000000000000000000000..757b12981f0205618e9ff8e05cdf2581cd5954b3 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/attention/modules.py @@ -0,0 +1,374 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +from functools import partial +from typing import * +from torch.utils import _pytree +import torch +import torch.nn as nn +import torch.nn.functional as F +from .full_attn import scaled_dot_product_attention +from sam3d_objects.data.utils import ( + tree_reduce_unique, +) + + +class MultiHeadRMSNorm(nn.Module): + def __init__(self, dim: int, heads: int): + super().__init__() + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(heads, dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return (F.normalize(x, dim=-1) * self.gamma * self.scale).to(x.dtype) + + +class RotaryPositionEmbedder(nn.Module): + def __init__(self, hidden_size: int, in_channels: int = 3): + super().__init__() + assert hidden_size % 2 == 0, "Hidden size must be divisible by 2" + self.hidden_size = hidden_size + self.in_channels = in_channels + self.freq_dim = hidden_size // in_channels // 2 + self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim + self.freqs = 1.0 / (10000**self.freqs) + + def _get_phases(self, indices: torch.Tensor) -> torch.Tensor: + self.freqs = self.freqs.to(indices.device) + phases = torch.outer(indices, self.freqs) + phases = torch.polar(torch.ones_like(phases), phases) + return phases + + def _rotary_embedding(self, x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor: + x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + x_rotated = x_complex * phases + x_embed = ( + torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype) + ) + return x_embed + + def forward( + self, q: torch.Tensor, k: torch.Tensor, indices: Optional[torch.Tensor] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + q (sp.SparseTensor): [..., N, D] tensor of queries + k (sp.SparseTensor): [..., N, D] tensor of keys + indices (torch.Tensor): [..., N, C] tensor of spatial positions + """ + if indices is None: + indices = torch.arange(q.shape[-2], device=q.device) + if len(q.shape) > 2: + indices = indices.unsqueeze(0).expand(q.shape[:-2] + (-1,)) + + phases = self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1) + if phases.shape[1] < self.hidden_size // 2: + phases = torch.cat( + [ + phases, + torch.polar( + torch.ones( + *phases.shape[:-1], + self.hidden_size // 2 - phases.shape[1], + device=phases.device, + ), + torch.zeros( + *phases.shape[:-1], + self.hidden_size // 2 - phases.shape[1], + device=phases.device, + ), + ), + ], + dim=-1, + ) + q_embed = self._rotary_embedding(q, phases) + k_embed = self._rotary_embedding(k, phases) + return q_embed, k_embed + + +class MultiHeadAttention(nn.Module): + def __init__( + self, + channels: int, + num_heads: int, + ctx_channels: Optional[int] = None, + type: Literal["self", "cross"] = "self", + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + qkv_bias: bool = True, + use_rope: bool = False, + qk_rms_norm: bool = False, + ): + super().__init__() + assert channels % num_heads == 0 + assert type in ["self", "cross"], f"Invalid attention type: {type}" + assert attn_mode in ["full", "windowed"], f"Invalid attention mode: {attn_mode}" + assert ( + type == "self" or attn_mode == "full" + ), "Cross-attention only supports full attention" + + if attn_mode == "windowed": + raise NotImplementedError("Windowed attention is not yet implemented") + + self.channels = channels + self.head_dim = channels // num_heads + self.ctx_channels = ctx_channels if ctx_channels is not None else channels + self.num_heads = num_heads + self._type = type + self.attn_mode = attn_mode + self.window_size = window_size + self.shift_window = shift_window + self.use_rope = use_rope + self.qk_rms_norm = qk_rms_norm + + if self._type == "self": + self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias) + else: + self.to_q = nn.Linear(channels, channels, bias=qkv_bias) + self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias) + + if self.qk_rms_norm: + self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads) + self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads) + + self.to_out = nn.Linear(channels, channels) + + if use_rope: + self.rope = RotaryPositionEmbedder(channels) + + def forward( + self, + x: torch.Tensor, + context: Optional[torch.Tensor] = None, + indices: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + B, L, C = x.shape + if self._type == "self": + qkv = self.to_qkv(x) + qkv = qkv.reshape(B, L, 3, self.num_heads, -1) + if self.use_rope: + q, k, v = qkv.unbind(dim=2) + q, k = self.rope(q, k, indices) + qkv = torch.stack([q, k, v], dim=2) + if self.attn_mode == "full": + if self.qk_rms_norm: + q, k, v = qkv.unbind(dim=2) + q = self.q_rms_norm(q) + k = self.k_rms_norm(k) + h = scaled_dot_product_attention(q, k, v) + else: + h = scaled_dot_product_attention(qkv) + elif self.attn_mode == "windowed": + raise NotImplementedError("Windowed attention is not yet implemented") + else: + Lkv = context.shape[1] + q = self.to_q(x) + kv = self.to_kv(context) + q = q.reshape(B, L, self.num_heads, -1) + kv = kv.reshape(B, Lkv, 2, self.num_heads, -1) + if self.qk_rms_norm: + q = self.q_rms_norm(q) + k, v = kv.unbind(dim=2) + k = self.k_rms_norm(k) + h = scaled_dot_product_attention(q, k, v) + else: + h = scaled_dot_product_attention(q, kv) + h = h.reshape(B, L, -1) + h = self.to_out(h) + return h + + +class MOTMultiHeadSelfAttention(nn.Module): + def __init__( + self, + channels: int, + num_heads: int, + ctx_channels: Optional[int] = None, + type: Literal["self", "cross"] = "self", + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + qkv_bias: bool = True, + use_rope: bool = False, + qk_rms_norm: bool = False, + latent_names: List = None, + protect_modality_list: List = ["shape"], + ): + super().__init__() + assert channels % num_heads == 0 + assert type in ["self", "cross"], f"Invalid attention type: {type}" + assert attn_mode in ["full", "windowed"], f"Invalid attention mode: {attn_mode}" + assert ( + type == "self" or attn_mode == "full" + ), "Cross-attention only supports full attention" + + if attn_mode == "windowed": + raise NotImplementedError("Windowed attention is not yet implemented") + + self.channels = channels + self.head_dim = channels // num_heads + self.ctx_channels = ctx_channels if ctx_channels is not None else channels + self.num_heads = num_heads + self._type = type + self.attn_mode = attn_mode + self.window_size = window_size + self.shift_window = shift_window + self.use_rope = use_rope + self.qk_rms_norm = qk_rms_norm + self.protect_modality_list = protect_modality_list + + if self._type == "self": + self.to_qkv = torch.nn.ModuleDict( + { + latent_name: nn.Linear(channels, channels * 3, bias=qkv_bias) + for latent_name in latent_names + } + ) + else: + self.to_q = torch.nn.ModuleDict( + { + latent_name: nn.Linear(channels, channels, bias=qkv_bias) + for latent_name in latent_names + } + ) + self.to_kv = torch.nn.ModuleDict( + { + latent_name: nn.Linear( + self.ctx_channels, channels * 2, bias=qkv_bias + ) + for latent_name in latent_names + } + ) + + if self.qk_rms_norm: + self.q_rms_norm = torch.nn.ModuleDict( + { + latent_name: MultiHeadRMSNorm(self.head_dim, num_heads) + for latent_name in latent_names + } + ) + self.k_rms_norm = torch.nn.ModuleDict( + { + latent_name: MultiHeadRMSNorm(self.head_dim, num_heads) + for latent_name in latent_names + } + ) + + self.to_out = torch.nn.ModuleDict( + {latent_name: nn.Linear(channels, channels) for latent_name in latent_names} + ) + + if use_rope: + self.rope = RotaryPositionEmbedder(channels) + + def _reshape(self, qkv, tensor_shape, num_heads): + B, L, _ = tensor_shape + return qkv.reshape(B, L, 3, num_heads, -1) + + def _reshape_back(self, qkv, tensor_shape): + B, L, _ = tensor_shape + return qkv.reshape(B, L, -1) + + def _apply_module(self, x, module): + return module(x) + + # This is stupid, _pytree does not support ModuleDict + def _moduledict_to_dict(self, module): + return {key: module for key, module in module.items()} + + def unbind_qkv(self, qkv): + q, k, v = {}, {}, {} + for latent_name, _qkv in qkv.items(): + _q, _k, _v = _qkv.unbind(dim=2) + q[latent_name] = _q + k[latent_name] = _k + v[latent_name] = _v + + return q, k, v + + def _get_shape(self, x): + return x.shape + + def concatenate_tensor(self, tensor_dict, latent_names): + merged = [] + indicies_mapping = {} + total_tokens = 0 + for latent_name in latent_names: + merged.append(tensor_dict[latent_name]) + cur_token_len = tensor_dict[latent_name].shape[1] + indicies_mapping[latent_name] = [total_tokens, cur_token_len] + total_tokens += cur_token_len + # merge along token dimension + return torch.cat(merged, dim=1), indicies_mapping + + def unpack_tensors(self, h_others, indicies_mapping): + h = {} + for latent_name, (start, cur_token_len) in indicies_mapping.items(): + h[latent_name] = h_others[:, start : start + cur_token_len] + + return h + + def mm_scale_dot_product_attention(self, q, k, v): + h = {} + latent_names = list(q.keys()) + # for protected modality, it only attends itself + for protect_modality in self.protect_modality_list: + _q = q[protect_modality] + _k = k[protect_modality] + _v = v[protect_modality] + h[protect_modality] = scaled_dot_product_attention(_q, _k, _v) + + # for the rest it is ok to attend each other and allow gradient + other_modalities = [ + n for n in latent_names if n not in self.protect_modality_list + ] + _q, indicies_mapping = self.concatenate_tensor(q, other_modalities) + o_k, _ = self.concatenate_tensor(k, other_modalities) + o_v, _ = self.concatenate_tensor(v, other_modalities) + # no gradiant flow back to protected modality (e.g. shape) + _k, _ = self.concatenate_tensor(k, self.protect_modality_list) + _v, _ = self.concatenate_tensor(v, self.protect_modality_list) + _k = _k.detach() + _v = _v.detach() + _k = torch.cat([o_k, _k], dim=1) + _v = torch.cat([o_v, _v], dim=1) + h_others = scaled_dot_product_attention(_q, _k, _v) + h.update(self.unpack_tensors(h_others, indicies_mapping)) + + return h + + def forward( + self, + x: Dict, + ) -> torch.Tensor: + shapes = _pytree.tree_map(self._get_shape, x) + if self._type == "self": + qkv = _pytree.tree_map( + self._apply_module, x, self._moduledict_to_dict(self.to_qkv) + ) + qkv = _pytree.tree_map( + partial(self._reshape, num_heads=self.num_heads), qkv, shapes + ) + if self.use_rope: + raise NotImplementedError + if self.attn_mode == "full": + if self.qk_rms_norm: + q, k, v = self.unbind_qkv(qkv) + q = _pytree.tree_map( + self._apply_module, q, self._moduledict_to_dict(self.q_rms_norm) + ) + k = _pytree.tree_map( + self._apply_module, k, self._moduledict_to_dict(self.k_rms_norm) + ) + h = self.mm_scale_dot_product_attention(q, k, v) + else: + raise NotImplementedError + elif self.attn_mode == "windowed": + raise NotImplementedError("Windowed attention is not yet implemented") + else: + raise NotImplementedError + + h = _pytree.tree_map(self._reshape_back, h, shapes) + h = _pytree.tree_map( + self._apply_module, h, self._moduledict_to_dict(self.to_out) + ) + return h diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/norm.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/norm.py new file mode 100644 index 0000000000000000000000000000000000000000..5e9c505fd7b2dc33dac3f9b9ce6df60ec241d385 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/norm.py @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import torch +import torch.nn as nn + + +class LayerNorm32(nn.LayerNorm): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return super().forward(x.float()).type(x.dtype) + + +class GroupNorm32(nn.GroupNorm): + """ + A GroupNorm layer that converts to float32 before the forward pass. + """ + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return super().forward(x.float()).type(x.dtype) + + +class ChannelLayerNorm32(LayerNorm32): + def forward(self, x: torch.Tensor) -> torch.Tensor: + DIM = x.dim() + x = x.permute(0, *range(2, DIM), 1).contiguous() + x = super().forward(x) + x = x.permute(0, DIM - 1, *range(1, DIM - 1)).contiguous() + return x diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/sparse/__init__.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/sparse/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..1abbeeb1a45c11839091f7fe80b1e6c62db9bfea --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/sparse/__init__.py @@ -0,0 +1,67 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +from typing import * +from loguru import logger + +BACKEND = "spconv" +# BACKEND = "torchsparse" +DEBUG = False +ATTN = "sdpa" + + +def __from_env(): + import os + + global BACKEND + global DEBUG + global ATTN + + env_sparse_backend = os.environ.get("SPARSE_BACKEND") + env_sparse_debug = os.environ.get("SPARSE_DEBUG") + env_sparse_attn = os.environ.get("SPARSE_ATTN_BACKEND") + if env_sparse_attn is None: + env_sparse_attn = os.environ.get("ATTN_BACKEND") + + if env_sparse_backend is not None and env_sparse_backend in [ + "spconv", + "torchsparse", + ]: + BACKEND = env_sparse_backend + if env_sparse_debug is not None: + DEBUG = env_sparse_debug == "1" + # env_sparse_attn = "sdpa" + if env_sparse_attn is not None and env_sparse_attn in [ + "xformers", + "flash_attn", + "sdpa", + ]: + ATTN = env_sparse_attn + + logger.info(f"[SPARSE] Backend: {BACKEND}, Attention: {ATTN}") + + +__from_env() + + +def set_backend(backend: Literal["spconv", "torchsparse"]): + global BACKEND + BACKEND = backend + + +def set_debug(debug: bool): + global DEBUG + DEBUG = debug + + +def set_attn(attn: Literal["xformers", "flash_attn"]): + global ATTN + ATTN = attn + + +from .basic import * +from .norm import * +from .nonlinearity import * +from .linear import * +from .attention import * +from .conv import * +from .spatial import * +from . import transformer diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/sparse/attention/__init__.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/sparse/attention/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..79f739096b80636dc214e444a2acf239aa204401 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/sparse/attention/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +from .full_attn import * +from .serialized_attn import * +from .windowed_attn import * +from .modules import * +from .masked_sdpa import * diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/sparse/attention/full_attn.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/sparse/attention/full_attn.py new file mode 100755 index 0000000000000000000000000000000000000000..1149b1ec6c6742566ff9f4119d0ab2eb13e6feef --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/sparse/attention/full_attn.py @@ -0,0 +1,310 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +from typing import * +import torch +from .. import SparseTensor +from .. import DEBUG, ATTN + +if ATTN == "xformers": + import xformers.ops as xops +elif ATTN == "flash_attn": + import flash_attn +elif ATTN == "sdpa": + from .masked_sdpa import masked_sdpa +else: + raise ValueError(f"Unknown attention module: {ATTN}") + + +__all__ = [ + "sparse_scaled_dot_product_attention", +] + + +@overload +def sparse_scaled_dot_product_attention(qkv: SparseTensor) -> SparseTensor: + """ + Apply scaled dot product attention to a sparse tensor. + + Args: + qkv (SparseTensor): A [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs. + """ + ... + + +@overload +def sparse_scaled_dot_product_attention( + q: SparseTensor, kv: Union[SparseTensor, torch.Tensor] +) -> SparseTensor: + """ + Apply scaled dot product attention to a sparse tensor. + + Args: + q (SparseTensor): A [N, *, H, C] sparse tensor containing Qs. + kv (SparseTensor or torch.Tensor): A [N, *, 2, H, C] sparse tensor or a [N, L, 2, H, C] dense tensor containing Ks and Vs. + """ + ... + + +@overload +def sparse_scaled_dot_product_attention( + q: torch.Tensor, kv: SparseTensor +) -> torch.Tensor: + """ + Apply scaled dot product attention to a sparse tensor. + + Args: + q (SparseTensor): A [N, L, H, C] dense tensor containing Qs. + kv (SparseTensor or torch.Tensor): A [N, *, 2, H, C] sparse tensor containing Ks and Vs. + """ + ... + + +@overload +def sparse_scaled_dot_product_attention( + q: SparseTensor, k: SparseTensor, v: SparseTensor +) -> SparseTensor: + """ + Apply scaled dot product attention to a sparse tensor. + + Args: + q (SparseTensor): A [N, *, H, Ci] sparse tensor containing Qs. + k (SparseTensor): A [N, *, H, Ci] sparse tensor containing Ks. + v (SparseTensor): A [N, *, H, Co] sparse tensor containing Vs. + + Note: + k and v are assumed to have the same coordinate map. + """ + ... + + +@overload +def sparse_scaled_dot_product_attention( + q: SparseTensor, k: torch.Tensor, v: torch.Tensor +) -> SparseTensor: + """ + Apply scaled dot product attention to a sparse tensor. + + Args: + q (SparseTensor): A [N, *, H, Ci] sparse tensor containing Qs. + k (torch.Tensor): A [N, L, H, Ci] dense tensor containing Ks. + v (torch.Tensor): A [N, L, H, Co] dense tensor containing Vs. + """ + ... + + +@overload +def sparse_scaled_dot_product_attention( + q: torch.Tensor, k: SparseTensor, v: SparseTensor +) -> torch.Tensor: + """ + Apply scaled dot product attention to a sparse tensor. + + Args: + q (torch.Tensor): A [N, L, H, Ci] dense tensor containing Qs. + k (SparseTensor): A [N, *, H, Ci] sparse tensor containing Ks. + v (SparseTensor): A [N, *, H, Co] sparse tensor containing Vs. + """ + ... + + +def sparse_scaled_dot_product_attention(*args, **kwargs): + arg_names_dict = {1: ["qkv"], 2: ["q", "kv"], 3: ["q", "k", "v"]} + num_all_args = len(args) + len(kwargs) + assert ( + num_all_args in arg_names_dict + ), f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3" + for key in arg_names_dict[num_all_args][len(args) :]: + assert key in kwargs, f"Missing argument {key}" + + if num_all_args == 1: + qkv = args[0] if len(args) > 0 else kwargs["qkv"] + assert isinstance( + qkv, SparseTensor + ), f"qkv must be a SparseTensor, got {type(qkv)}" + assert ( + len(qkv.shape) == 4 and qkv.shape[1] == 3 + ), f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]" + device = qkv.device + + s = qkv + q_seqlen = [ + qkv.layout[i].stop - qkv.layout[i].start for i in range(qkv.shape[0]) + ] + kv_seqlen = q_seqlen + qkv = qkv.feats # [T, 3, H, C] + + elif num_all_args == 2: + q = args[0] if len(args) > 0 else kwargs["q"] + kv = args[1] if len(args) > 1 else kwargs["kv"] + assert ( + isinstance(q, SparseTensor) + and isinstance(kv, (SparseTensor, torch.Tensor)) + or isinstance(q, torch.Tensor) + and isinstance(kv, SparseTensor) + ), f"Invalid types, got {type(q)} and {type(kv)}" + assert ( + q.shape[0] == kv.shape[0] + ), f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}" + device = q.device + + if isinstance(q, SparseTensor): + assert ( + len(q.shape) == 3 + ), f"Invalid shape for q, got {q.shape}, expected [N, *, H, C]" + s = q + q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])] + q = q.feats # [T_Q, H, C] + else: + assert ( + len(q.shape) == 4 + ), f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]" + s = None + N, L, H, C = q.shape + q_seqlen = [L] * N + q = q.reshape(N * L, H, C) # [T_Q, H, C] + + if isinstance(kv, SparseTensor): + assert ( + len(kv.shape) == 4 and kv.shape[1] == 2 + ), f"Invalid shape for kv, got {kv.shape}, expected [N, *, 2, H, C]" + kv_seqlen = [ + kv.layout[i].stop - kv.layout[i].start for i in range(kv.shape[0]) + ] + kv = kv.feats # [T_KV, 2, H, C] + else: + assert ( + len(kv.shape) == 5 + ), f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]" + N, L, _, H, C = kv.shape + kv_seqlen = [L] * N + kv = kv.reshape(N * L, 2, H, C) # [T_KV, 2, H, C] + + elif num_all_args == 3: + q = args[0] if len(args) > 0 else kwargs["q"] + k = args[1] if len(args) > 1 else kwargs["k"] + v = args[2] if len(args) > 2 else kwargs["v"] + assert ( + isinstance(q, SparseTensor) + and isinstance(k, (SparseTensor, torch.Tensor)) + and type(k) == type(v) + or isinstance(q, torch.Tensor) + and isinstance(k, SparseTensor) + and isinstance(v, SparseTensor) + ), f"Invalid types, got {type(q)}, {type(k)}, and {type(v)}" + assert ( + q.shape[0] == k.shape[0] == v.shape[0] + ), f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}" + device = q.device + + if isinstance(q, SparseTensor): + assert ( + len(q.shape) == 3 + ), f"Invalid shape for q, got {q.shape}, expected [N, *, H, Ci]" + s = q + q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])] + q = q.feats # [T_Q, H, Ci] + else: + assert ( + len(q.shape) == 4 + ), f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]" + s = None + N, L, H, CI = q.shape + q_seqlen = [L] * N + q = q.reshape(N * L, H, CI) # [T_Q, H, Ci] + + if isinstance(k, SparseTensor): + assert ( + len(k.shape) == 3 + ), f"Invalid shape for k, got {k.shape}, expected [N, *, H, Ci]" + assert ( + len(v.shape) == 3 + ), f"Invalid shape for v, got {v.shape}, expected [N, *, H, Co]" + kv_seqlen = [ + k.layout[i].stop - k.layout[i].start for i in range(k.shape[0]) + ] + k = k.feats # [T_KV, H, Ci] + v = v.feats # [T_KV, H, Co] + else: + assert ( + len(k.shape) == 4 + ), f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]" + assert ( + len(v.shape) == 4 + ), f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]" + N, L, H, CI, CO = *k.shape, v.shape[-1] + kv_seqlen = [L] * N + k = k.reshape(N * L, H, CI) # [T_KV, H, Ci] + v = v.reshape(N * L, H, CO) # [T_KV, H, Co] + + if DEBUG: + if s is not None: + for i in range(s.shape[0]): + assert ( + s.coords[s.layout[i]] == i + ).all(), f"SparseScaledDotProductSelfAttention: batch index mismatch" + if num_all_args in [2, 3]: + assert q.shape[:2] == [ + 1, + sum(q_seqlen), + ], f"SparseScaledDotProductSelfAttention: q shape mismatch" + if num_all_args == 3: + assert k.shape[:2] == [ + 1, + sum(kv_seqlen), + ], f"SparseScaledDotProductSelfAttention: k shape mismatch" + assert v.shape[:2] == [ + 1, + sum(kv_seqlen), + ], f"SparseScaledDotProductSelfAttention: v shape mismatch" + + if ATTN == "xformers": + if num_all_args == 1: + q, k, v = qkv.unbind(dim=1) + elif num_all_args == 2: + k, v = kv.unbind(dim=1) + q = q.unsqueeze(0) + k = k.unsqueeze(0) + v = v.unsqueeze(0) + mask = xops.fmha.BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen) + out = xops.memory_efficient_attention(q, k, v, mask)[0] + elif ATTN == "flash_attn": + cu_seqlens_q = ( + torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]) + .int() + .to(device) + ) + if num_all_args in [2, 3]: + cu_seqlens_kv = ( + torch.cat( + [torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)] + ) + .int() + .to(device) + ) + if num_all_args == 1: + out = flash_attn.flash_attn_varlen_qkvpacked_func( + qkv, cu_seqlens_q, max(q_seqlen) + ) + elif num_all_args == 2: + out = flash_attn.flash_attn_varlen_kvpacked_func( + q, kv, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen) + ) + elif num_all_args == 3: + out = flash_attn.flash_attn_varlen_func( + q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen) + ) + elif ATTN == "sdpa": + if num_all_args == 1: + q, k, v = qkv.unbind(dim=1) + elif num_all_args == 2: + k, v = kv.unbind(dim=1) + q = q.unsqueeze(0) + k = k.unsqueeze(0) + v = v.unsqueeze(0) + out = masked_sdpa(q, k, v, q_seqlen, kv_seqlen) + else: + raise ValueError(f"Unknown attention module: {ATTN}") + + if s is not None: + return s.replace(out) + else: + return out.reshape(N, L, H, -1) diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/sparse/attention/masked_sdpa.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/sparse/attention/masked_sdpa.py new file mode 100644 index 0000000000000000000000000000000000000000..23489705e9a09ac6f57572afce3dfd7af572819a --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/sparse/attention/masked_sdpa.py @@ -0,0 +1,60 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import torch +import torch.nn.functional as F + + +def block_diag_attn_mask(q_seqlens, kv_seqlens, device=None, dtype=torch.float32): + """ + Create an additive attention mask for block-diagonal attention. + The result is shape [sum_q, sum_kv], with 0.0 in the valid + region(s) and -inf elsewhere. + """ + total_q = sum(q_seqlens) + total_kv = sum(kv_seqlens) + + # Start with everything "masked out" + attn_mask = torch.full( + (total_q, total_kv), float("-inf"), device=device, dtype=dtype + ) + + q_start = 0 + kv_start = 0 + for q_len, kv_len in zip(q_seqlens, kv_seqlens): + attn_mask[q_start : q_start + q_len, kv_start : kv_start + kv_len] = 0 + q_start += q_len + kv_start += kv_len + + return attn_mask + + +def masked_sdpa(q, k, v, q_seqlen, kv_seqlen): + """ + Mimic xFormers' memory_efficient_attention using PyTorch 2.0 scaled_dot_product_attention. + """ + # Build the block-diagonal additive mask + # shape: [sum_q_len, sum_kv_len] with 0 where allowed, -inf where masked + attn_mask_2d = block_diag_attn_mask( + q_seqlen, kv_seqlen, device=q.device, dtype=q.dtype + ) + + # PyTorch’s scaled_dot_product_attention expects a mask broadcastable to + # [batch_size, n_heads, q_len, kv_len]. For a single batch, single head: + attn_mask_4d = attn_mask_2d.unsqueeze(0).unsqueeze(0) + q = q.permute(0, 2, 1, 3) # [N, H, L, C] + k = k.permute(0, 2, 1, 3) # [N, H, L, C] + v = v.permute(0, 2, 1, 3) # [N, H, L, C] + + # Now call PyTorch 2.0’s built-in SDPA + # By default, it will automatically apply the "1/sqrt(dim)" scaling internally. + out = F.scaled_dot_product_attention( + query=q, + key=k, + value=v, + attn_mask=attn_mask_4d, # Additive mask + dropout_p=0.0, # or whatever dropout you need + is_causal=False, # True if you want a causal (triangular) mask + ) + # out is shape [1, sum_q_len, dim] + out = out.permute(0, 2, 1, 3) + + return out[0] diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/sparse/attention/modules.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/sparse/attention/modules.py new file mode 100755 index 0000000000000000000000000000000000000000..175e12fe6e85f2a2d19243d21416c581ac5d2d3d --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/sparse/attention/modules.py @@ -0,0 +1,167 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +from .. import SparseTensor +from .full_attn import sparse_scaled_dot_product_attention +from .serialized_attn import ( + SerializeMode, + sparse_serialized_scaled_dot_product_self_attention, +) +from .windowed_attn import sparse_windowed_scaled_dot_product_self_attention +from ...attention import RotaryPositionEmbedder + + +class SparseMultiHeadRMSNorm(nn.Module): + def __init__(self, dim: int, heads: int): + super().__init__() + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(heads, dim)) + + def forward( + self, x: Union[SparseTensor, torch.Tensor] + ) -> Union[SparseTensor, torch.Tensor]: + x_type = x.dtype + x = x.float() + if isinstance(x, SparseTensor): + x = x.replace(F.normalize(x.feats, dim=-1)) + else: + x = F.normalize(x, dim=-1) + return (x * self.gamma * self.scale).to(x_type) + + +class SparseMultiHeadAttention(nn.Module): + def __init__( + self, + channels: int, + num_heads: int, + ctx_channels: Optional[int] = None, + type: Literal["self", "cross"] = "self", + attn_mode: Literal["full", "serialized", "windowed"] = "full", + window_size: Optional[int] = None, + shift_sequence: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + serialize_mode: Optional[SerializeMode] = None, + qkv_bias: bool = True, + use_rope: bool = False, + qk_rms_norm: bool = False, + ): + super().__init__() + assert channels % num_heads == 0 + assert type in ["self", "cross"], f"Invalid attention type: {type}" + assert attn_mode in [ + "full", + "serialized", + "windowed", + ], f"Invalid attention mode: {attn_mode}" + assert ( + type == "self" or attn_mode == "full" + ), "Cross-attention only supports full attention" + assert ( + type == "self" or use_rope is False + ), "Rotary position embeddings only supported for self-attention" + self.channels = channels + self.ctx_channels = ctx_channels if ctx_channels is not None else channels + self.num_heads = num_heads + self._type = type + self.attn_mode = attn_mode + self.window_size = window_size + self.shift_sequence = shift_sequence + self.shift_window = shift_window + self.serialize_mode = serialize_mode + self.use_rope = use_rope + self.qk_rms_norm = qk_rms_norm + + if self._type == "self": + self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias) + else: + self.to_q = nn.Linear(channels, channels, bias=qkv_bias) + self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias) + + if self.qk_rms_norm: + self.q_rms_norm = SparseMultiHeadRMSNorm(channels // num_heads, num_heads) + self.k_rms_norm = SparseMultiHeadRMSNorm(channels // num_heads, num_heads) + + self.to_out = nn.Linear(channels, channels) + + if use_rope: + self.rope = RotaryPositionEmbedder(channels) + + @staticmethod + def _linear( + module: nn.Linear, x: Union[SparseTensor, torch.Tensor] + ) -> Union[SparseTensor, torch.Tensor]: + if isinstance(x, SparseTensor): + return x.replace(module(x.feats)) + else: + return module(x) + + @staticmethod + def _reshape_chs( + x: Union[SparseTensor, torch.Tensor], shape: Tuple[int, ...] + ) -> Union[SparseTensor, torch.Tensor]: + if isinstance(x, SparseTensor): + return x.reshape(*shape) + else: + return x.reshape(*x.shape[:2], *shape) + + def _fused_pre( + self, x: Union[SparseTensor, torch.Tensor], num_fused: int + ) -> Union[SparseTensor, torch.Tensor]: + if isinstance(x, SparseTensor): + x_feats = x.feats.unsqueeze(0) + else: + x_feats = x + x_feats = x_feats.reshape(*x_feats.shape[:2], num_fused, self.num_heads, -1) + return x.replace(x_feats.squeeze(0)) if isinstance(x, SparseTensor) else x_feats + + def _rope(self, qkv: SparseTensor) -> SparseTensor: + q, k, v = qkv.feats.unbind(dim=1) # [T, H, C] + q, k = self.rope(q, k, qkv.coords[:, 1:]) + qkv = qkv.replace(torch.stack([q, k, v], dim=1)) + return qkv + + def forward( + self, + x: Union[SparseTensor, torch.Tensor], + context: Optional[Union[SparseTensor, torch.Tensor]] = None, + ) -> Union[SparseTensor, torch.Tensor]: + if self._type == "self": + qkv = self._linear(self.to_qkv, x) + qkv = self._fused_pre(qkv, num_fused=3) + if self.use_rope: + qkv = self._rope(qkv) + if self.qk_rms_norm: + q, k, v = qkv.unbind(dim=1) + q = self.q_rms_norm(q) + k = self.k_rms_norm(k) + qkv = qkv.replace(torch.stack([q.feats, k.feats, v.feats], dim=1)) + if self.attn_mode == "full": + h = sparse_scaled_dot_product_attention(qkv) + elif self.attn_mode == "serialized": + h = sparse_serialized_scaled_dot_product_self_attention( + qkv, + self.window_size, + serialize_mode=self.serialize_mode, + shift_sequence=self.shift_sequence, + shift_window=self.shift_window, + ) + elif self.attn_mode == "windowed": + h = sparse_windowed_scaled_dot_product_self_attention( + qkv, self.window_size, shift_window=self.shift_window + ) + else: + q = self._linear(self.to_q, x) + q = self._reshape_chs(q, (self.num_heads, -1)) + kv = self._linear(self.to_kv, context) + kv = self._fused_pre(kv, num_fused=2) + if self.qk_rms_norm: + q = self.q_rms_norm(q) + k, v = kv.unbind(dim=1) + k = self.k_rms_norm(k) + kv = kv.replace(torch.stack([k.feats, v.feats], dim=1)) + h = sparse_scaled_dot_product_attention(q, kv) + h = self._reshape_chs(h, (-1,)) + h = self._linear(self.to_out, h) + return h diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/sparse/attention/serialized_attn.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/sparse/attention/serialized_attn.py new file mode 100755 index 0000000000000000000000000000000000000000..f74fb000c63e1bba4e5fca07727c07b19d3146db --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/sparse/attention/serialized_attn.py @@ -0,0 +1,264 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +from typing import * +from enum import Enum +import torch +import math +from .. import SparseTensor +from .. import DEBUG, ATTN + +if ATTN == "xformers": + import xformers.ops as xops +elif ATTN == "flash_attn": + import flash_attn +elif ATTN == "sdpa": + from torch.nn.functional import scaled_dot_product_attention as sdpa + from .masked_sdpa import masked_sdpa +else: + raise ValueError(f"Unknown attention module: {ATTN}") + + +__all__ = [ + "sparse_serialized_scaled_dot_product_self_attention", +] + + +class SerializeMode(Enum): + Z_ORDER = 0 + Z_ORDER_TRANSPOSED = 1 + HILBERT = 2 + HILBERT_TRANSPOSED = 3 + + +SerializeModes = [ + SerializeMode.Z_ORDER, + SerializeMode.Z_ORDER_TRANSPOSED, + SerializeMode.HILBERT, + SerializeMode.HILBERT_TRANSPOSED, +] + + +def calc_serialization( + tensor: SparseTensor, + window_size: int, + serialize_mode: SerializeMode = SerializeMode.Z_ORDER, + shift_sequence: int = 0, + shift_window: Tuple[int, int, int] = (0, 0, 0), +) -> Tuple[torch.Tensor, torch.Tensor, List[int]]: + """ + Calculate serialization and partitioning for a set of coordinates. + + Args: + tensor (SparseTensor): The input tensor. + window_size (int): The window size to use. + serialize_mode (SerializeMode): The serialization mode to use. + shift_sequence (int): The shift of serialized sequence. + shift_window (Tuple[int, int, int]): The shift of serialized coordinates. + + Returns: + (torch.Tensor, torch.Tensor): Forwards and backwards indices. + """ + fwd_indices = [] + bwd_indices = [] + seq_lens = [] + seq_batch_indices = [] + offsets = [0] + + if "vox2seq" not in globals(): + import vox2seq + + # Serialize the input + serialize_coords = tensor.coords[:, 1:].clone() + serialize_coords += torch.tensor( + shift_window, dtype=torch.int32, device=tensor.device + ).reshape(1, 3) + if serialize_mode == SerializeMode.Z_ORDER: + code = vox2seq.encode(serialize_coords, mode="z_order", permute=[0, 1, 2]) + elif serialize_mode == SerializeMode.Z_ORDER_TRANSPOSED: + code = vox2seq.encode(serialize_coords, mode="z_order", permute=[1, 0, 2]) + elif serialize_mode == SerializeMode.HILBERT: + code = vox2seq.encode(serialize_coords, mode="hilbert", permute=[0, 1, 2]) + elif serialize_mode == SerializeMode.HILBERT_TRANSPOSED: + code = vox2seq.encode(serialize_coords, mode="hilbert", permute=[1, 0, 2]) + else: + raise ValueError(f"Unknown serialize mode: {serialize_mode}") + + for bi, s in enumerate(tensor.layout): + num_points = s.stop - s.start + num_windows = (num_points + window_size - 1) // window_size + valid_window_size = num_points / num_windows + to_ordered = torch.argsort(code[s.start : s.stop]) + if num_windows == 1: + fwd_indices.append(to_ordered) + bwd_indices.append( + torch.zeros_like(to_ordered).scatter_( + 0, to_ordered, torch.arange(num_points, device=tensor.device) + ) + ) + fwd_indices[-1] += s.start + bwd_indices[-1] += offsets[-1] + seq_lens.append(num_points) + seq_batch_indices.append(bi) + offsets.append(offsets[-1] + seq_lens[-1]) + else: + # Partition the input + offset = 0 + mids = [ + (i + 0.5) * valid_window_size + shift_sequence + for i in range(num_windows) + ] + split = [ + math.floor(i * valid_window_size + shift_sequence) + for i in range(num_windows + 1) + ] + bwd_index = torch.zeros( + (num_points,), dtype=torch.int64, device=tensor.device + ) + for i in range(num_windows): + mid = mids[i] + valid_start = split[i] + valid_end = split[i + 1] + padded_start = math.floor(mid - 0.5 * window_size) + padded_end = padded_start + window_size + fwd_indices.append( + to_ordered[ + torch.arange(padded_start, padded_end, device=tensor.device) + % num_points + ] + ) + offset += valid_start - padded_start + bwd_index.scatter_( + 0, + fwd_indices[-1][ + valid_start - padded_start : valid_end - padded_start + ], + torch.arange( + offset, offset + valid_end - valid_start, device=tensor.device + ), + ) + offset += padded_end - valid_start + fwd_indices[-1] += s.start + seq_lens.extend([window_size] * num_windows) + seq_batch_indices.extend([bi] * num_windows) + bwd_indices.append(bwd_index + offsets[-1]) + offsets.append(offsets[-1] + num_windows * window_size) + + fwd_indices = torch.cat(fwd_indices) + bwd_indices = torch.cat(bwd_indices) + + return fwd_indices, bwd_indices, seq_lens, seq_batch_indices + + +def sparse_serialized_scaled_dot_product_self_attention( + qkv: SparseTensor, + window_size: int, + serialize_mode: SerializeMode = SerializeMode.Z_ORDER, + shift_sequence: int = 0, + shift_window: Tuple[int, int, int] = (0, 0, 0), +) -> SparseTensor: + """ + Apply serialized scaled dot product self attention to a sparse tensor. + + Args: + qkv (SparseTensor): [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs. + window_size (int): The window size to use. + serialize_mode (SerializeMode): The serialization mode to use. + shift_sequence (int): The shift of serialized sequence. + shift_window (Tuple[int, int, int]): The shift of serialized coordinates. + shift (int): The shift to use. + """ + assert ( + len(qkv.shape) == 4 and qkv.shape[1] == 3 + ), f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]" + + serialization_spatial_cache_name = ( + f"serialization_{serialize_mode}_{window_size}_{shift_sequence}_{shift_window}" + ) + serialization_spatial_cache = qkv.get_spatial_cache( + serialization_spatial_cache_name + ) + if serialization_spatial_cache is None: + fwd_indices, bwd_indices, seq_lens, seq_batch_indices = calc_serialization( + qkv, window_size, serialize_mode, shift_sequence, shift_window + ) + qkv.register_spatial_cache( + serialization_spatial_cache_name, + (fwd_indices, bwd_indices, seq_lens, seq_batch_indices), + ) + else: + fwd_indices, bwd_indices, seq_lens, seq_batch_indices = ( + serialization_spatial_cache + ) + + M = fwd_indices.shape[0] + T = qkv.feats.shape[0] + H = qkv.feats.shape[2] + C = qkv.feats.shape[3] + + qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C] + + if DEBUG: + start = 0 + qkv_coords = qkv.coords[fwd_indices] + for i in range(len(seq_lens)): + assert ( + qkv_coords[start : start + seq_lens[i], 0] == seq_batch_indices[i] + ).all(), ( + f"SparseWindowedScaledDotProductSelfAttention: batch index mismatch" + ) + start += seq_lens[i] + + if all([seq_len == window_size for seq_len in seq_lens]): + B = len(seq_lens) + N = window_size + qkv_feats = qkv_feats.reshape(B, N, 3, H, C) + if ATTN == "xformers": + q, k, v = qkv_feats.unbind(dim=2) # [B, N, H, C] + out = xops.memory_efficient_attention(q, k, v) # [B, N, H, C] + elif ATTN == "flash_attn": + out = flash_attn.flash_attn_qkvpacked_func(qkv_feats) # [B, N, H, C] + elif ATTN == "sdpa": + q, k, v = qkv_feats.unbind(dim=2) + q = q.permute(0, 2, 1, 3) # [N, H, L, C] + k = k.permute(0, 2, 1, 3) # [N, H, L, C] + v = v.permute(0, 2, 1, 3) # [N, H, L, C] + out = sdpa(q, k, v) # [N, H, L, C] + out = out.permute(0, 2, 1, 3) # [N, L, H, C] + else: + raise ValueError(f"Unknown attention module: {ATTN}") + out = out.reshape(B * N, H, C) # [M, H, C] + else: + if ATTN == "xformers": + q, k, v = qkv_feats.unbind(dim=1) # [M, H, C] + q = q.unsqueeze(0) # [1, M, H, C] + k = k.unsqueeze(0) # [1, M, H, C] + v = v.unsqueeze(0) # [1, M, H, C] + mask = xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens) + out = xops.memory_efficient_attention(q, k, v, mask)[0] # [M, H, C] + elif ATTN == "flash_attn": + cu_seqlens = ( + torch.cat( + [torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], + dim=0, + ) + .to(qkv.device) + .int() + ) + out = flash_attn.flash_attn_varlen_qkvpacked_func( + qkv_feats, cu_seqlens, max(seq_lens) + ) # [M, H, C] + elif ATTN == "sdpa": + q, k, v = qkv_feats.unbind(dim=1) + q = q.unsqueeze(0) + k = k.unsqueeze(0) + v = v.unsqueeze(0) + out = masked_sdpa(q, k, v, seq_lens, seq_lens) + + out = out[bwd_indices] # [T, H, C] + + if DEBUG: + qkv_coords = qkv_coords[bwd_indices] + assert torch.equal( + qkv_coords, qkv.coords + ), "SparseWindowedScaledDotProductSelfAttention: coordinate mismatch" + + return qkv.replace(out) diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/sparse/attention/windowed_attn.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/sparse/attention/windowed_attn.py new file mode 100755 index 0000000000000000000000000000000000000000..16268071a9dc0429201006c34cd8956779a2dee5 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/sparse/attention/windowed_attn.py @@ -0,0 +1,192 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +from typing import * +import torch +import math +from .. import SparseTensor +from .. import DEBUG, ATTN + +if ATTN == "xformers": + import xformers.ops as xops +elif ATTN == "flash_attn": + import flash_attn +elif ATTN == "sdpa": + from torch.nn.functional import scaled_dot_product_attention as sdpa + from .masked_sdpa import masked_sdpa +else: + raise ValueError(f"Unknown attention module: {ATTN}") + + +__all__ = [ + "sparse_windowed_scaled_dot_product_self_attention", +] + + +def calc_window_partition( + tensor: SparseTensor, + window_size: Union[int, Tuple[int, ...]], + shift_window: Union[int, Tuple[int, ...]] = 0, +) -> Tuple[torch.Tensor, torch.Tensor, List[int], List[int]]: + """ + Calculate serialization and partitioning for a set of coordinates. + + Args: + tensor (SparseTensor): The input tensor. + window_size (int): The window size to use. + shift_window (Tuple[int, ...]): The shift of serialized coordinates. + + Returns: + (torch.Tensor): Forwards indices. + (torch.Tensor): Backwards indices. + (List[int]): Sequence lengths. + (List[int]): Sequence batch indices. + """ + DIM = tensor.coords.shape[1] - 1 + shift_window = ( + (shift_window,) * DIM if isinstance(shift_window, int) else shift_window + ) + window_size = (window_size,) * DIM if isinstance(window_size, int) else window_size + shifted_coords = tensor.coords.clone().detach() + shifted_coords[:, 1:] += torch.tensor( + shift_window, device=tensor.device, dtype=torch.int32 + ).unsqueeze(0) + + MAX_COORDS = shifted_coords[:, 1:].max(dim=0).values.tolist() + NUM_WINDOWS = [math.ceil((mc + 1) / ws) for mc, ws in zip(MAX_COORDS, window_size)] + OFFSET = torch.cumprod(torch.tensor([1] + NUM_WINDOWS[::-1]), dim=0).tolist()[::-1] + + shifted_coords[:, 1:] //= torch.tensor( + window_size, device=tensor.device, dtype=torch.int32 + ).unsqueeze(0) + shifted_indices = ( + shifted_coords + * torch.tensor(OFFSET, device=tensor.device, dtype=torch.int32).unsqueeze(0) + ).sum(dim=1) + fwd_indices = torch.argsort(shifted_indices) + bwd_indices = torch.empty_like(fwd_indices) + bwd_indices[fwd_indices] = torch.arange(fwd_indices.shape[0], device=tensor.device) + seq_lens = torch.bincount(shifted_indices) + seq_batch_indices = ( + torch.arange(seq_lens.shape[0], device=tensor.device, dtype=torch.int32) + // OFFSET[0] + ) + mask = seq_lens != 0 + seq_lens = seq_lens[mask].tolist() + seq_batch_indices = seq_batch_indices[mask].tolist() + + return fwd_indices, bwd_indices, seq_lens, seq_batch_indices + + +def sparse_windowed_scaled_dot_product_self_attention( + qkv: SparseTensor, window_size: int, shift_window: Tuple[int, int, int] = (0, 0, 0) +) -> SparseTensor: + """ + Apply windowed scaled dot product self attention to a sparse tensor. + + Args: + qkv (SparseTensor): [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs. + window_size (int): The window size to use. + shift_window (Tuple[int, int, int]): The shift of serialized coordinates. + shift (int): The shift to use. + """ + assert ( + len(qkv.shape) == 4 and qkv.shape[1] == 3 + ), f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]" + + serialization_spatial_cache_name = f"window_partition_{window_size}_{shift_window}" + serialization_spatial_cache = qkv.get_spatial_cache( + serialization_spatial_cache_name + ) + if serialization_spatial_cache is None: + fwd_indices, bwd_indices, seq_lens, seq_batch_indices = calc_window_partition( + qkv, window_size, shift_window + ) + qkv.register_spatial_cache( + serialization_spatial_cache_name, + (fwd_indices, bwd_indices, seq_lens, seq_batch_indices), + ) + else: + fwd_indices, bwd_indices, seq_lens, seq_batch_indices = ( + serialization_spatial_cache + ) + + M = fwd_indices.shape[0] + T = qkv.feats.shape[0] + H = qkv.feats.shape[2] + C = qkv.feats.shape[3] + + qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C] + + if DEBUG: + start = 0 + qkv_coords = qkv.coords[fwd_indices] + for i in range(len(seq_lens)): + seq_coords = qkv_coords[start : start + seq_lens[i]] + assert ( + seq_coords[:, 0] == seq_batch_indices[i] + ).all(), ( + f"SparseWindowedScaledDotProductSelfAttention: batch index mismatch" + ) + assert ( + seq_coords[:, 1:].max(dim=0).values + - seq_coords[:, 1:].min(dim=0).values + < window_size + ).all(), ( + f"SparseWindowedScaledDotProductSelfAttention: window size exceeded" + ) + start += seq_lens[i] + + if all([seq_len == window_size for seq_len in seq_lens]): + B = len(seq_lens) + N = window_size + qkv_feats = qkv_feats.reshape(B, N, 3, H, C) + if ATTN == "xformers": + q, k, v = qkv_feats.unbind(dim=2) # [B, N, H, C] + out = xops.memory_efficient_attention(q, k, v) # [B, N, H, C] + elif ATTN == "flash_attn": + out = flash_attn.flash_attn_qkvpacked_func(qkv_feats) # [B, N, H, C] + elif ATTN == "sdpa": + q, k, v = qkv_feats.unbind(dim=2) + q = q.permute(0, 2, 1, 3) # [N, H, L, C] + k = k.permute(0, 2, 1, 3) # [N, H, L, C] + v = v.permute(0, 2, 1, 3) # [N, H, L, C] + out = sdpa(q, k, v) # [N, H, L, C] + out = out.permute(0, 2, 1, 3) # [N, L, H, C] + else: + raise ValueError(f"Unknown attention module: {ATTN}") + out = out.reshape(B * N, H, C) # [M, H, C] + else: + if ATTN == "xformers": + q, k, v = qkv_feats.unbind(dim=1) # [M, H, C] + q = q.unsqueeze(0) # [1, M, H, C] + k = k.unsqueeze(0) # [1, M, H, C] + v = v.unsqueeze(0) # [1, M, H, C] + mask = xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens) + out = xops.memory_efficient_attention(q, k, v, mask)[0] # [M, H, C] + elif ATTN == "flash_attn": + cu_seqlens = ( + torch.cat( + [torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], + dim=0, + ) + .to(qkv.device) + .int() + ) + out = flash_attn.flash_attn_varlen_qkvpacked_func( + qkv_feats, cu_seqlens, max(seq_lens) + ) # [M, H, C] + elif ATTN == "sdpa": + q, k, v = qkv_feats.unbind(dim=1) + q = q.unsqueeze(0) + k = k.unsqueeze(0) + v = v.unsqueeze(0) + out = masked_sdpa(q, k, v, seq_lens, seq_lens) + + out = out[bwd_indices] # [T, H, C] + + if DEBUG: + qkv_coords = qkv_coords[bwd_indices] + assert torch.equal( + qkv_coords, qkv.coords + ), "SparseWindowedScaledDotProductSelfAttention: coordinate mismatch" + + return qkv.replace(out) diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/sparse/basic.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/sparse/basic.py new file mode 100755 index 0000000000000000000000000000000000000000..b4709c98221f61972f68706e3131d32745591459 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/sparse/basic.py @@ -0,0 +1,525 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +from typing import * +import torch +from . import BACKEND, DEBUG + +if BACKEND == "torchsparse": + from torchsparse import SparseTensor as SparseTensorData +elif BACKEND == "spconv": + from spconv.pytorch import SparseConvTensor as SparseTensorData + +__all__ = [ + "SparseTensor", + "sparse_batch_broadcast", + "sparse_batch_op", + "sparse_cat", + "sparse_unbind", +] + + +class SparseTensor: + """ + Sparse tensor with support for both torchsparse and spconv backends. + + Parameters: + - feats (torch.Tensor): Features of the sparse tensor. + - coords (torch.Tensor): Coordinates of the sparse tensor. + - shape (torch.Size): Shape of the sparse tensor. + - layout (List[slice]): Layout of the sparse tensor for each batch + - data (SparseTensorData): Sparse tensor data used for convolusion + + NOTE: + - Data corresponding to a same batch should be contiguous. + - Coords should be in [0, 1023] + """ + + @overload + def __init__( + self, + feats: torch.Tensor, + coords: torch.Tensor, + shape: Optional[torch.Size] = None, + layout: Optional[List[slice]] = None, + **kwargs, + ): ... + + @overload + def __init__( + self, + data, + shape: Optional[torch.Size] = None, + layout: Optional[List[slice]] = None, + **kwargs, + ): ... + + def __init__(self, *args, **kwargs): + method_id = 0 + if len(args) != 0: + method_id = 0 if isinstance(args[0], torch.Tensor) else 1 + else: + method_id = 1 if "data" in kwargs else 0 + + if method_id == 0: + feats, coords, shape, layout = args + (None,) * (4 - len(args)) + if "feats" in kwargs: + feats = kwargs["feats"] + del kwargs["feats"] + if "coords" in kwargs: + coords = kwargs["coords"] + del kwargs["coords"] + if "shape" in kwargs: + shape = kwargs["shape"] + del kwargs["shape"] + if "layout" in kwargs: + layout = kwargs["layout"] + del kwargs["layout"] + + if shape is None: + shape = self.__cal_shape(feats, coords) + if layout is None: + layout = self.__cal_layout(coords, shape[0]) + if BACKEND == "torchsparse": + self.data = SparseTensorData(feats, coords, **kwargs) + elif BACKEND == "spconv": + spatial_shape = list(coords.max(0)[0] + 1)[1:] + self.data = SparseTensorData( + feats.reshape(feats.shape[0], -1), + coords, + spatial_shape, + shape[0], + **kwargs, + ) + self.data._features = feats + elif method_id == 1: + data, shape, layout = args + (None,) * (3 - len(args)) + if "data" in kwargs: + data = kwargs["data"] + del kwargs["data"] + if "shape" in kwargs: + shape = kwargs["shape"] + del kwargs["shape"] + if "layout" in kwargs: + layout = kwargs["layout"] + del kwargs["layout"] + + self.data = data + if shape is None: + shape = self.__cal_shape(self.feats, self.coords) + if layout is None: + layout = self.__cal_layout(self.coords, shape[0]) + + self._shape = shape + self._layout = layout + self._scale = kwargs.get("scale", (1, 1, 1)) + self._spatial_cache = kwargs.get("spatial_cache", {}) + + if DEBUG: + try: + assert ( + self.feats.shape[0] == self.coords.shape[0] + ), f"Invalid feats shape: {self.feats.shape}, coords shape: {self.coords.shape}" + assert self.shape == self.__cal_shape( + self.feats, self.coords + ), f"Invalid shape: {self.shape}" + assert self.layout == self.__cal_layout( + self.coords, self.shape[0] + ), f"Invalid layout: {self.layout}" + for i in range(self.shape[0]): + assert torch.all( + self.coords[self.layout[i], 0] == i + ), f"The data of batch {i} is not contiguous" + except Exception as e: + print("Debugging information:") + print(f"- Shape: {self.shape}") + print(f"- Layout: {self.layout}") + print(f"- Scale: {self._scale}") + print(f"- Coords: {self.coords}") + raise e + + def __cal_shape(self, feats, coords): + shape = [] + shape.append(coords[:, 0].max().item() + 1) + shape.extend([*feats.shape[1:]]) + return torch.Size(shape) + + def __cal_layout(self, coords, batch_size): + seq_len = torch.bincount(coords[:, 0], minlength=batch_size) + offset = torch.cumsum(seq_len, dim=0) + layout = [ + slice((offset[i] - seq_len[i]).item(), offset[i].item()) + for i in range(batch_size) + ] + return layout + + @property + def shape(self) -> torch.Size: + return self._shape + + def dim(self) -> int: + return len(self.shape) + + @property + def layout(self) -> List[slice]: + return self._layout + + @property + def feats(self) -> torch.Tensor: + if BACKEND == "torchsparse": + return self.data.F + elif BACKEND == "spconv": + return self.data.features + + @feats.setter + def feats(self, value: torch.Tensor): + if BACKEND == "torchsparse": + self.data.F = value + elif BACKEND == "spconv": + self.data.features = value + + @property + def coords(self) -> torch.Tensor: + if BACKEND == "torchsparse": + return self.data.C + elif BACKEND == "spconv": + return self.data.indices + + @coords.setter + def coords(self, value: torch.Tensor): + if BACKEND == "torchsparse": + self.data.C = value + elif BACKEND == "spconv": + self.data.indices = value + + @property + def dtype(self): + return self.feats.dtype + + @property + def device(self): + return self.feats.device + + @overload + def to(self, dtype: torch.dtype) -> "SparseTensor": ... + + @overload + def to( + self, + device: Optional[Union[str, torch.device]] = None, + dtype: Optional[torch.dtype] = None, + ) -> "SparseTensor": ... + + def to(self, *args, **kwargs) -> "SparseTensor": + device = None + dtype = None + if len(args) == 2: + device, dtype = args + elif len(args) == 1: + if isinstance(args[0], torch.dtype): + dtype = args[0] + else: + device = args[0] + if "dtype" in kwargs: + assert dtype is None, "to() received multiple values for argument 'dtype'" + dtype = kwargs["dtype"] + if "device" in kwargs: + assert device is None, "to() received multiple values for argument 'device'" + device = kwargs["device"] + + new_feats = self.feats.to(device=device, dtype=dtype) + new_coords = self.coords.to(device=device) + return self.replace(new_feats, new_coords) + + def type(self, dtype): + new_feats = self.feats.type(dtype) + return self.replace(new_feats) + + def cpu(self) -> "SparseTensor": + new_feats = self.feats.cpu() + new_coords = self.coords.cpu() + return self.replace(new_feats, new_coords) + + def cuda(self) -> "SparseTensor": + new_feats = self.feats.cuda() + new_coords = self.coords.cuda() + return self.replace(new_feats, new_coords) + + def half(self) -> "SparseTensor": + new_feats = self.feats.half() + return self.replace(new_feats) + + def float(self) -> "SparseTensor": + new_feats = self.feats.float() + return self.replace(new_feats) + + def detach(self) -> "SparseTensor": + new_coords = self.coords.detach() + new_feats = self.feats.detach() + return self.replace(new_feats, new_coords) + + def dense(self) -> torch.Tensor: + if BACKEND == "torchsparse": + return self.data.dense() + elif BACKEND == "spconv": + return self.data.dense() + + def reshape(self, *shape) -> "SparseTensor": + new_feats = self.feats.reshape(self.feats.shape[0], *shape) + return self.replace(new_feats) + + def unbind(self, dim: int) -> List["SparseTensor"]: + return sparse_unbind(self, dim) + + def replace( + self, feats: torch.Tensor, coords: Optional[torch.Tensor] = None + ) -> "SparseTensor": + new_shape = [self.shape[0]] + new_shape.extend(feats.shape[1:]) + if BACKEND == "torchsparse": + new_data = SparseTensorData( + feats=feats, + coords=self.data.coords if coords is None else coords, + stride=self.data.stride, + spatial_range=self.data.spatial_range, + ) + new_data._caches = self.data._caches + elif BACKEND == "spconv": + new_data = SparseTensorData( + self.data.features.reshape(self.data.features.shape[0], -1), + self.data.indices, + self.data.spatial_shape, + self.data.batch_size, + self.data.grid, + self.data.voxel_num, + self.data.indice_dict, + ) + new_data._features = feats + new_data.benchmark = self.data.benchmark + new_data.benchmark_record = self.data.benchmark_record + new_data.thrust_allocator = self.data.thrust_allocator + new_data._timer = self.data._timer + new_data.force_algo = self.data.force_algo + new_data.int8_scale = self.data.int8_scale + if coords is not None: + new_data.indices = coords + new_tensor = SparseTensor( + new_data, + shape=torch.Size(new_shape), + layout=self.layout, + scale=self._scale, + spatial_cache=self._spatial_cache, + ) + return new_tensor + + @staticmethod + def full(aabb, dim, value, dtype=torch.float32, device=None) -> "SparseTensor": + N, C = dim + x = torch.arange(aabb[0], aabb[3] + 1) + y = torch.arange(aabb[1], aabb[4] + 1) + z = torch.arange(aabb[2], aabb[5] + 1) + coords = torch.stack(torch.meshgrid(x, y, z, indexing="ij"), dim=-1).reshape( + -1, 3 + ) + coords = torch.cat( + [ + torch.arange(N).view(-1, 1).repeat(1, coords.shape[0]).view(-1, 1), + coords.repeat(N, 1), + ], + dim=1, + ).to(dtype=torch.int32, device=device) + feats = torch.full((coords.shape[0], C), value, dtype=dtype, device=device) + return SparseTensor(feats=feats, coords=coords) + + def __merge_sparse_cache(self, other: "SparseTensor") -> dict: + new_cache = {} + for k in set( + list(self._spatial_cache.keys()) + list(other._spatial_cache.keys()) + ): + if k in self._spatial_cache: + new_cache[k] = self._spatial_cache[k] + if k in other._spatial_cache: + if k not in new_cache: + new_cache[k] = other._spatial_cache[k] + else: + new_cache[k].update(other._spatial_cache[k]) + return new_cache + + def __neg__(self) -> "SparseTensor": + return self.replace(-self.feats) + + def __elemwise__( + self, other: Union[torch.Tensor, "SparseTensor"], op: callable + ) -> "SparseTensor": + if isinstance(other, torch.Tensor): + try: + other = torch.broadcast_to(other, self.shape) + other = sparse_batch_broadcast(self, other) + except: + pass + if isinstance(other, SparseTensor): + other = other.feats + new_feats = op(self.feats, other) + new_tensor = self.replace(new_feats) + if isinstance(other, SparseTensor): + new_tensor._spatial_cache = self.__merge_sparse_cache(other) + return new_tensor + + def __add__( + self, other: Union[torch.Tensor, "SparseTensor", float] + ) -> "SparseTensor": + return self.__elemwise__(other, torch.add) + + def __radd__( + self, other: Union[torch.Tensor, "SparseTensor", float] + ) -> "SparseTensor": + return self.__elemwise__(other, torch.add) + + def __sub__( + self, other: Union[torch.Tensor, "SparseTensor", float] + ) -> "SparseTensor": + return self.__elemwise__(other, torch.sub) + + def __rsub__( + self, other: Union[torch.Tensor, "SparseTensor", float] + ) -> "SparseTensor": + return self.__elemwise__(other, lambda x, y: torch.sub(y, x)) + + def __mul__( + self, other: Union[torch.Tensor, "SparseTensor", float] + ) -> "SparseTensor": + return self.__elemwise__(other, torch.mul) + + def __rmul__( + self, other: Union[torch.Tensor, "SparseTensor", float] + ) -> "SparseTensor": + return self.__elemwise__(other, torch.mul) + + def __truediv__( + self, other: Union[torch.Tensor, "SparseTensor", float] + ) -> "SparseTensor": + return self.__elemwise__(other, torch.div) + + def __rtruediv__( + self, other: Union[torch.Tensor, "SparseTensor", float] + ) -> "SparseTensor": + return self.__elemwise__(other, lambda x, y: torch.div(y, x)) + + def __getitem__(self, idx): + if isinstance(idx, int): + idx = [idx] + elif isinstance(idx, slice): + idx = range(*idx.indices(self.shape[0])) + elif isinstance(idx, torch.Tensor): + if idx.dtype == torch.bool: + assert idx.shape == ( + self.shape[0], + ), f"Invalid index shape: {idx.shape}" + idx = idx.nonzero().squeeze(1) + elif idx.dtype in [torch.int32, torch.int64]: + assert len(idx.shape) == 1, f"Invalid index shape: {idx.shape}" + else: + raise ValueError(f"Unknown index type: {idx.dtype}") + else: + raise ValueError(f"Unknown index type: {type(idx)}") + + coords = [] + feats = [] + for new_idx, old_idx in enumerate(idx): + coords.append(self.coords[self.layout[old_idx]].clone()) + coords[-1][:, 0] = new_idx + feats.append(self.feats[self.layout[old_idx]]) + coords = torch.cat(coords, dim=0).contiguous() + feats = torch.cat(feats, dim=0).contiguous() + return SparseTensor(feats=feats, coords=coords) + + def register_spatial_cache(self, key, value) -> None: + """ + Register a spatial cache. + The spatial cache can be any thing you want to cache. + The registery and retrieval of the cache is based on current scale. + """ + scale_key = str(self._scale) + if scale_key not in self._spatial_cache: + self._spatial_cache[scale_key] = {} + self._spatial_cache[scale_key][key] = value + + def get_spatial_cache(self, key=None): + """ + Get a spatial cache. + """ + scale_key = str(self._scale) + cur_scale_cache = self._spatial_cache.get(scale_key, {}) + if key is None: + return cur_scale_cache + return cur_scale_cache.get(key, None) + + +def sparse_batch_broadcast(input: SparseTensor, other: torch.Tensor) -> torch.Tensor: + """ + Broadcast a 1D tensor to a sparse tensor along the batch dimension then perform an operation. + + Args: + input (torch.Tensor): 1D tensor to broadcast. + target (SparseTensor): Sparse tensor to broadcast to. + op (callable): Operation to perform after broadcasting. Defaults to torch.add. + """ + coords, feats = input.coords, input.feats + broadcasted = torch.zeros_like(feats) + for k in range(input.shape[0]): + broadcasted[input.layout[k]] = other[k] + return broadcasted + + +def sparse_batch_op( + input: SparseTensor, other: torch.Tensor, op: callable = torch.add +) -> SparseTensor: + """ + Broadcast a 1D tensor to a sparse tensor along the batch dimension then perform an operation. + + Args: + input (torch.Tensor): 1D tensor to broadcast. + target (SparseTensor): Sparse tensor to broadcast to. + op (callable): Operation to perform after broadcasting. Defaults to torch.add. + """ + return input.replace(op(input.feats, sparse_batch_broadcast(input, other))) + + +def sparse_cat(inputs: List[SparseTensor], dim: int = 0) -> SparseTensor: + """ + Concatenate a list of sparse tensors. + + Args: + inputs (List[SparseTensor]): List of sparse tensors to concatenate. + """ + if dim == 0: + start = 0 + coords = [] + for input in inputs: + coords.append(input.coords.clone()) + coords[-1][:, 0] += start + start += input.shape[0] + coords = torch.cat(coords, dim=0) + feats = torch.cat([input.feats for input in inputs], dim=0) + output = SparseTensor( + coords=coords, + feats=feats, + ) + else: + feats = torch.cat([input.feats for input in inputs], dim=dim) + output = inputs[0].replace(feats) + + return output + + +def sparse_unbind(input: SparseTensor, dim: int) -> List[SparseTensor]: + """ + Unbind a sparse tensor along a dimension. + + Args: + input (SparseTensor): Sparse tensor to unbind. + dim (int): Dimension to unbind. + """ + if dim == 0: + return [input[i] for i in range(input.shape[0])] + else: + feats = input.feats.unbind(dim) + return [input.replace(f) for f in feats] diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/sparse/conv/__init__.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/sparse/conv/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..5c31b4f76d772b001bf5e084a63efb5084801cea --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/sparse/conv/__init__.py @@ -0,0 +1,27 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +from .. import BACKEND + + +SPCONV_ALGO = "auto" # 'auto', 'implicit_gemm', 'native' + + +def __from_env(): + import os + + global SPCONV_ALGO + env_spconv_algo = os.environ.get("SPCONV_ALGO") + if env_spconv_algo is not None and env_spconv_algo in [ + "auto", + "implicit_gemm", + "native", + ]: + SPCONV_ALGO = env_spconv_algo + print(f"[SPARSE][CONV] spconv algo: {SPCONV_ALGO}") + + +__from_env() + +if BACKEND == "torchsparse": + from .conv_torchsparse import * +elif BACKEND == "spconv": + from .conv_spconv import * diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/sparse/conv/conv_spconv.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/sparse/conv/conv_spconv.py new file mode 100755 index 0000000000000000000000000000000000000000..f54a6c4672a8c207fc825723cdb2bf52cf64c8be --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/sparse/conv/conv_spconv.py @@ -0,0 +1,139 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import torch +import torch.nn as nn +from .. import SparseTensor +from .. import DEBUG +from . import SPCONV_ALGO + + +class SparseConv3d(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + dilation=1, + padding=None, + bias=True, + indice_key=None, + ): + super(SparseConv3d, self).__init__() + if "spconv" not in globals(): + import spconv.pytorch as spconv + algo = None + if SPCONV_ALGO == "native": + algo = spconv.ConvAlgo.Native + elif SPCONV_ALGO == "implicit_gemm": + algo = spconv.ConvAlgo.MaskImplicitGemm + if stride == 1 and (padding is None): + self.conv = spconv.SubMConv3d( + in_channels, + out_channels, + kernel_size, + dilation=dilation, + bias=bias, + indice_key=indice_key, + algo=algo, + ) + else: + self.conv = spconv.SparseConv3d( + in_channels, + out_channels, + kernel_size, + stride=stride, + dilation=dilation, + padding=padding, + bias=bias, + indice_key=indice_key, + algo=algo, + ) + self.stride = ( + tuple(stride) + if isinstance(stride, (list, tuple)) + else (stride, stride, stride) + ) + self.padding = padding + + def forward(self, x: SparseTensor) -> SparseTensor: + spatial_changed = any(s != 1 for s in self.stride) or (self.padding is not None) + new_data = self.conv(x.data) + new_shape = [x.shape[0], self.conv.out_channels] + new_layout = None if spatial_changed else x.layout + + if spatial_changed and (x.shape[0] != 1): + # spconv was non-1 stride will break the contiguous of the output tensor, sort by the coords + fwd = new_data.indices[:, 0].argsort() + bwd = torch.zeros_like(fwd).scatter_( + 0, fwd, torch.arange(fwd.shape[0], device=fwd.device) + ) + sorted_feats = new_data.features[fwd] + sorted_coords = new_data.indices[fwd] + unsorted_data = new_data + new_data = spconv.SparseConvTensor(sorted_feats, sorted_coords, unsorted_data.spatial_shape, unsorted_data.batch_size) # type: ignore + + out = SparseTensor( + new_data, + shape=torch.Size(new_shape), + layout=new_layout, + scale=tuple([s * stride for s, stride in zip(x._scale, self.stride)]), + spatial_cache=x._spatial_cache, + ) + + if spatial_changed and (x.shape[0] != 1): + out.register_spatial_cache( + f"conv_{self.stride}_unsorted_data", unsorted_data + ) + out.register_spatial_cache(f"conv_{self.stride}_sort_bwd", bwd) + + return out + + +class SparseInverseConv3d(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + dilation=1, + bias=True, + indice_key=None, + ): + super(SparseInverseConv3d, self).__init__() + if "spconv" not in globals(): + import spconv.pytorch as spconv + self.conv = spconv.SparseInverseConv3d( + in_channels, out_channels, kernel_size, bias=bias, indice_key=indice_key + ) + self.stride = ( + tuple(stride) + if isinstance(stride, (list, tuple)) + else (stride, stride, stride) + ) + + def forward(self, x: SparseTensor) -> SparseTensor: + spatial_changed = any(s != 1 for s in self.stride) + if spatial_changed: + # recover the original spconv order + data = x.get_spatial_cache(f"conv_{self.stride}_unsorted_data") + bwd = x.get_spatial_cache(f"conv_{self.stride}_sort_bwd") + data = data.replace_feature(x.feats[bwd]) + if DEBUG: + assert torch.equal( + data.indices, x.coords[bwd] + ), "Recover the original order failed" + else: + data = x.data + + new_data = self.conv(data) + new_shape = [x.shape[0], self.conv.out_channels] + new_layout = None if spatial_changed else x.layout + out = SparseTensor( + new_data, + shape=torch.Size(new_shape), + layout=new_layout, + scale=tuple([s // stride for s, stride in zip(x._scale, self.stride)]), + spatial_cache=x._spatial_cache, + ) + return out diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/sparse/linear.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/sparse/linear.py new file mode 100755 index 0000000000000000000000000000000000000000..d3767a4756d29fc21646de0c27b462a135fd9480 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/sparse/linear.py @@ -0,0 +1,14 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import torch +import torch.nn as nn +from . import SparseTensor + +__all__ = ["SparseLinear"] + + +class SparseLinear(nn.Linear): + def __init__(self, in_features, out_features, bias=True): + super(SparseLinear, self).__init__(in_features, out_features, bias) + + def forward(self, input: SparseTensor) -> SparseTensor: + return input.replace(super().forward(input.feats)) diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/sparse/nonlinearity.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/sparse/nonlinearity.py new file mode 100755 index 0000000000000000000000000000000000000000..37db1b8386791543e7385f281722e4536c530d59 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/sparse/nonlinearity.py @@ -0,0 +1,30 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import torch +import torch.nn as nn +from . import SparseTensor + +__all__ = ["SparseReLU", "SparseSiLU", "SparseGELU", "SparseActivation"] + + +class SparseReLU(nn.ReLU): + def forward(self, input: SparseTensor) -> SparseTensor: + return input.replace(super().forward(input.feats)) + + +class SparseSiLU(nn.SiLU): + def forward(self, input: SparseTensor) -> SparseTensor: + return input.replace(super().forward(input.feats)) + + +class SparseGELU(nn.GELU): + def forward(self, input: SparseTensor) -> SparseTensor: + return input.replace(super().forward(input.feats)) + + +class SparseActivation(nn.Module): + def __init__(self, activation: nn.Module): + super().__init__() + self.activation = activation + + def forward(self, input: SparseTensor) -> SparseTensor: + return input.replace(self.activation(input.feats)) diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/sparse/norm.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/sparse/norm.py new file mode 100755 index 0000000000000000000000000000000000000000..01d1e62f79bb14de50c7ca070201d970ab798d7f --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/sparse/norm.py @@ -0,0 +1,64 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import torch +import torch.nn as nn +from . import SparseTensor +from . import DEBUG + +__all__ = [ + "SparseGroupNorm", + "SparseLayerNorm", + "SparseGroupNorm32", + "SparseLayerNorm32", +] + + +class SparseGroupNorm(nn.GroupNorm): + def __init__(self, num_groups, num_channels, eps=1e-5, affine=True): + super(SparseGroupNorm, self).__init__(num_groups, num_channels, eps, affine) + + def forward(self, input: SparseTensor) -> SparseTensor: + nfeats = torch.zeros_like(input.feats) + for k in range(input.shape[0]): + if DEBUG: + assert ( + input.coords[input.layout[k], 0] == k + ).all(), f"SparseGroupNorm: batch index mismatch" + bfeats = input.feats[input.layout[k]] + bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1) + bfeats = super().forward(bfeats) + bfeats = bfeats.reshape(input.shape[1], -1).permute(1, 0) + nfeats[input.layout[k]] = bfeats + return input.replace(nfeats) + + +class SparseLayerNorm(nn.LayerNorm): + def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True): + super(SparseLayerNorm, self).__init__(normalized_shape, eps, elementwise_affine) + + def forward(self, input: SparseTensor) -> SparseTensor: + nfeats = torch.zeros_like(input.feats) + for k in range(input.shape[0]): + bfeats = input.feats[input.layout[k]] + bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1) + bfeats = super().forward(bfeats) + bfeats = bfeats.reshape(input.shape[1], -1).permute(1, 0) + nfeats[input.layout[k]] = bfeats + return input.replace(nfeats) + + +class SparseGroupNorm32(SparseGroupNorm): + """ + A GroupNorm layer that converts to float32 before the forward pass. + """ + + def forward(self, x: SparseTensor) -> SparseTensor: + return super().forward(x.float()).type(x.dtype) + + +class SparseLayerNorm32(SparseLayerNorm): + """ + A LayerNorm layer that converts to float32 before the forward pass. + """ + + def forward(self, x: SparseTensor) -> SparseTensor: + return super().forward(x.float()).type(x.dtype) diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/sparse/spatial.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/sparse/spatial.py new file mode 100755 index 0000000000000000000000000000000000000000..2c65d0f43b3fb8f0ee41fb2a31ee4a3ec6cc374d --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/sparse/spatial.py @@ -0,0 +1,131 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +from typing import * +import torch +import torch.nn as nn +from . import SparseTensor + +__all__ = ["SparseDownsample", "SparseUpsample", "SparseSubdivide"] + + +class SparseDownsample(nn.Module): + """ + Downsample a sparse tensor by a factor of `factor`. + Implemented as average pooling. + """ + + def __init__(self, factor: Union[int, Tuple[int, ...], List[int]]): + super(SparseDownsample, self).__init__() + self.factor = tuple(factor) if isinstance(factor, (list, tuple)) else factor + + def forward(self, input: SparseTensor) -> SparseTensor: + DIM = input.coords.shape[-1] - 1 + factor = self.factor if isinstance(self.factor, tuple) else (self.factor,) * DIM + assert DIM == len( + factor + ), "Input coordinates must have the same dimension as the downsample factor." + + coord = list(input.coords.unbind(dim=-1)) + for i, f in enumerate(factor): + coord[i + 1] = coord[i + 1] // f + + MAX = [coord[i + 1].max().item() + 1 for i in range(DIM)] + OFFSET = torch.cumprod(torch.tensor(MAX[::-1]), 0).tolist()[::-1] + [1] + code = sum([c * o for c, o in zip(coord, OFFSET)]) + code, idx = code.unique(return_inverse=True) + + new_feats = torch.scatter_reduce( + torch.zeros( + code.shape[0], + input.feats.shape[1], + device=input.feats.device, + dtype=input.feats.dtype, + ), + dim=0, + index=idx.unsqueeze(1).expand(-1, input.feats.shape[1]), + src=input.feats, + reduce="mean", + ) + new_coords = torch.stack( + [code // OFFSET[0]] + + [(code // OFFSET[i + 1]) % MAX[i] for i in range(DIM)], + dim=-1, + ) + out = SparseTensor( + new_feats, + new_coords, + input.shape, + ) + out._scale = tuple([s // f for s, f in zip(input._scale, factor)]) + out._spatial_cache = input._spatial_cache + + out.register_spatial_cache(f"upsample_{factor}_coords", input.coords) + out.register_spatial_cache(f"upsample_{factor}_layout", input.layout) + out.register_spatial_cache(f"upsample_{factor}_idx", idx) + + return out + + +class SparseUpsample(nn.Module): + """ + Upsample a sparse tensor by a factor of `factor`. + Implemented as nearest neighbor interpolation. + """ + + def __init__(self, factor: Union[int, Tuple[int, int, int], List[int]]): + super(SparseUpsample, self).__init__() + self.factor = tuple(factor) if isinstance(factor, (list, tuple)) else factor + + def forward(self, input: SparseTensor) -> SparseTensor: + DIM = input.coords.shape[-1] - 1 + factor = self.factor if isinstance(self.factor, tuple) else (self.factor,) * DIM + assert DIM == len( + factor + ), "Input coordinates must have the same dimension as the upsample factor." + + new_coords = input.get_spatial_cache(f"upsample_{factor}_coords") + new_layout = input.get_spatial_cache(f"upsample_{factor}_layout") + idx = input.get_spatial_cache(f"upsample_{factor}_idx") + if any([x is None for x in [new_coords, new_layout, idx]]): + raise ValueError( + "Upsample cache not found. SparseUpsample must be paired with SparseDownsample." + ) + new_feats = input.feats[idx] + out = SparseTensor(new_feats, new_coords, input.shape, new_layout) + out._scale = tuple([s * f for s, f in zip(input._scale, factor)]) + out._spatial_cache = input._spatial_cache + return out + + +class SparseSubdivide(nn.Module): + """ + Upsample a sparse tensor by a factor of `factor`. + Implemented as nearest neighbor interpolation. + """ + + def __init__(self): + super(SparseSubdivide, self).__init__() + + def forward(self, input: SparseTensor) -> SparseTensor: + DIM = input.coords.shape[-1] - 1 + # upsample scale=2^DIM + n_cube = torch.ones([2] * DIM, device=input.device, dtype=torch.int) + n_coords = torch.nonzero(n_cube) + n_coords = torch.cat([torch.zeros_like(n_coords[:, :1]), n_coords], dim=-1) + factor = n_coords.shape[0] + assert factor == 2**DIM + # print(n_coords.shape) + new_coords = input.coords.clone() + new_coords[:, 1:] *= 2 + new_coords = new_coords.unsqueeze(1) + n_coords.unsqueeze(0).to( + new_coords.dtype + ) + + new_feats = input.feats.unsqueeze(1).expand( + input.feats.shape[0], factor, *input.feats.shape[1:] + ) + out = SparseTensor( + new_feats.flatten(0, 1), new_coords.flatten(0, 1), input.shape + ) + out._scale = input._scale * 2 + out._spatial_cache = input._spatial_cache + return out diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/sparse/transformer/__init__.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/sparse/transformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e94bfeea8b366b91cc24c284dd3e75b4aa7da4f8 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/sparse/transformer/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +from .blocks import * +from .modulated import * diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/sparse/transformer/blocks.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/sparse/transformer/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..2e33cd30dd676c60e5e17ec0b650fba1b48b8728 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/sparse/transformer/blocks.py @@ -0,0 +1,162 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +from typing import * +import torch +import torch.nn as nn +from ..basic import SparseTensor +from ..linear import SparseLinear +from ..nonlinearity import SparseGELU +from ..attention import SparseMultiHeadAttention, SerializeMode +from ...norm import LayerNorm32 + + +class SparseFeedForwardNet(nn.Module): + def __init__(self, channels: int, mlp_ratio: float = 4.0): + super().__init__() + self.mlp = nn.Sequential( + SparseLinear(channels, int(channels * mlp_ratio)), + SparseGELU(approximate="tanh"), + SparseLinear(int(channels * mlp_ratio), channels), + ) + + def forward(self, x: SparseTensor) -> SparseTensor: + return self.mlp(x) + + +class SparseTransformerBlock(nn.Module): + """ + Sparse Transformer block (MSA + FFN). + """ + + def __init__( + self, + channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal[ + "full", "shift_window", "shift_sequence", "shift_order", "swin" + ] = "full", + window_size: Optional[int] = None, + shift_sequence: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + serialize_mode: Optional[SerializeMode] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + qk_rms_norm: bool = False, + qkv_bias: bool = True, + ln_affine: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.attn = SparseMultiHeadAttention( + channels, + num_heads=num_heads, + attn_mode=attn_mode, + window_size=window_size, + shift_sequence=shift_sequence, + shift_window=shift_window, + serialize_mode=serialize_mode, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + ) + self.mlp = SparseFeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + + def _forward(self, x: SparseTensor) -> SparseTensor: + h = x.replace(self.norm1(x.feats)) + h = self.attn(h) + x = x + h + h = x.replace(self.norm2(x.feats)) + h = self.mlp(h) + x = x + h + return x + + def forward(self, x: SparseTensor) -> SparseTensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint( + self._forward, x, use_reentrant=False + ) + else: + return self._forward(x) + + +class SparseTransformerCrossBlock(nn.Module): + """ + Sparse Transformer cross-attention block (MSA + MCA + FFN). + """ + + def __init__( + self, + channels: int, + ctx_channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal[ + "full", "shift_window", "shift_sequence", "shift_order", "swin" + ] = "full", + window_size: Optional[int] = None, + shift_sequence: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + serialize_mode: Optional[SerializeMode] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + qkv_bias: bool = True, + ln_affine: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.self_attn = SparseMultiHeadAttention( + channels, + num_heads=num_heads, + type="self", + attn_mode=attn_mode, + window_size=window_size, + shift_sequence=shift_sequence, + shift_window=shift_window, + serialize_mode=serialize_mode, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + ) + self.cross_attn = SparseMultiHeadAttention( + channels, + ctx_channels=ctx_channels, + num_heads=num_heads, + type="cross", + attn_mode="full", + qkv_bias=qkv_bias, + qk_rms_norm=qk_rms_norm_cross, + ) + self.mlp = SparseFeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + + def _forward(self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor): + h = x.replace(self.norm1(x.feats)) + h = self.self_attn(h) + x = x + h + h = x.replace(self.norm2(x.feats)) + h = self.cross_attn(h, context) + x = x + h + h = x.replace(self.norm3(x.feats)) + h = self.mlp(h) + x = x + h + return x + + def forward(self, x: SparseTensor, context: torch.Tensor): + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint( + self._forward, x, context, use_reentrant=False + ) + else: + return self._forward(x, context) diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/sparse/transformer/modulated.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/sparse/transformer/modulated.py new file mode 100644 index 0000000000000000000000000000000000000000..d8277be02f24ba939a939070e83d66c7e3506cec --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/sparse/transformer/modulated.py @@ -0,0 +1,186 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +from typing import * +import torch +import torch.nn as nn +from ..basic import SparseTensor +from ..attention import SparseMultiHeadAttention, SerializeMode +from ...norm import LayerNorm32 +from .blocks import SparseFeedForwardNet + + +class ModulatedSparseTransformerBlock(nn.Module): + """ + Sparse Transformer block (MSA + FFN) with adaptive layer norm conditioning. + """ + + def __init__( + self, + channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal[ + "full", "shift_window", "shift_sequence", "shift_order", "swin" + ] = "full", + window_size: Optional[int] = None, + shift_sequence: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + serialize_mode: Optional[SerializeMode] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + qk_rms_norm: bool = False, + qkv_bias: bool = True, + share_mod: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.attn = SparseMultiHeadAttention( + channels, + num_heads=num_heads, + attn_mode=attn_mode, + window_size=window_size, + shift_sequence=shift_sequence, + shift_window=shift_window, + serialize_mode=serialize_mode, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + ) + self.mlp = SparseFeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + if not share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), nn.Linear(channels, 6 * channels, bias=True) + ) + + def _forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor: + if self.share_mod: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk( + 6, dim=1 + ) + else: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.adaLN_modulation(mod).chunk(6, dim=1) + ) + h = x.replace(self.norm1(x.feats)) + h = h * (1 + scale_msa) + shift_msa + h = self.attn(h) + h = h * gate_msa + x = x + h + h = x.replace(self.norm2(x.feats)) + h = h * (1 + scale_mlp) + shift_mlp + h = self.mlp(h) + h = h * gate_mlp + x = x + h + return x + + def forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint( + self._forward, x, mod, use_reentrant=False + ) + else: + return self._forward(x, mod) + + +class ModulatedSparseTransformerCrossBlock(nn.Module): + """ + Sparse Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning. + """ + + def __init__( + self, + channels: int, + ctx_channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal[ + "full", "shift_window", "shift_sequence", "shift_order", "swin" + ] = "full", + window_size: Optional[int] = None, + shift_sequence: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + serialize_mode: Optional[SerializeMode] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + qkv_bias: bool = True, + share_mod: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) + self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.self_attn = SparseMultiHeadAttention( + channels, + num_heads=num_heads, + type="self", + attn_mode=attn_mode, + window_size=window_size, + shift_sequence=shift_sequence, + shift_window=shift_window, + serialize_mode=serialize_mode, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + ) + self.cross_attn = SparseMultiHeadAttention( + channels, + ctx_channels=ctx_channels, + num_heads=num_heads, + type="cross", + attn_mode="full", + qkv_bias=qkv_bias, + qk_rms_norm=qk_rms_norm_cross, + ) + self.mlp = SparseFeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + if not share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), nn.Linear(channels, 6 * channels, bias=True) + ) + + def _forward( + self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor + ) -> SparseTensor: + if self.share_mod: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk( + 6, dim=1 + ) + else: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.adaLN_modulation(mod).chunk(6, dim=1) + ) + h = x.replace(self.norm1(x.feats)) + h = h * (1 + scale_msa) + shift_msa + h = self.self_attn(h) + h = h * gate_msa + x = x + h + h = x.replace(self.norm2(x.feats)) + h = self.cross_attn(h, context) + x = x + h + h = x.replace(self.norm3(x.feats)) + h = h * (1 + scale_mlp) + shift_mlp + h = self.mlp(h) + h = h * gate_mlp + x = x + h + return x + + def forward( + self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor + ) -> SparseTensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint( + self._forward, x, mod, context, use_reentrant=False + ) + else: + return self._forward(x, mod, context) diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/spatial.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/spatial.py new file mode 100644 index 0000000000000000000000000000000000000000..2a042b12a85cb077c4cc649d9671c1699bab9451 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/spatial.py @@ -0,0 +1,65 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import torch + + +def pixel_shuffle_3d(x: torch.Tensor, scale_factor: int) -> torch.Tensor: + """ + 3D pixel shuffle. + """ + B, C, H, W, D = x.shape + C_ = C // scale_factor**3 + x = x.reshape(B, C_, scale_factor, scale_factor, scale_factor, H, W, D) + x = x.permute(0, 1, 5, 2, 6, 3, 7, 4) + x = x.reshape(B, C_, H * scale_factor, W * scale_factor, D * scale_factor) + return x + + +def patchify(x: torch.Tensor, patch_size: int): + """ + Patchify a tensor. + + Args: + x (torch.Tensor): (N, C, *spatial) tensor + patch_size (int): Patch size + """ + DIM = x.dim() - 2 + for d in range(2, DIM + 2): + assert ( + x.shape[d] % patch_size == 0 + ), f"Dimension {d} of input tensor must be divisible by patch size, got {x.shape[d]} and {patch_size}" + + x = x.reshape( + *x.shape[:2], + *sum([[x.shape[d] // patch_size, patch_size] for d in range(2, DIM + 2)], []), + ) + x = x.permute( + 0, 1, *([2 * i + 3 for i in range(DIM)] + [2 * i + 2 for i in range(DIM)]) + ) + x = x.reshape(x.shape[0], x.shape[1] * (patch_size**DIM), *(x.shape[-DIM:])) + return x + + +def unpatchify(x: torch.Tensor, patch_size: int): + """ + Unpatchify a tensor. + + Args: + x (torch.Tensor): (N, C, *spatial) tensor + patch_size (int): Patch size + """ + DIM = x.dim() - 2 + assert ( + x.shape[1] % (patch_size**DIM) == 0 + ), f"Second dimension of input tensor must be divisible by patch size to unpatchify, got {x.shape[1]} and {patch_size ** DIM}" + + x = x.reshape( + x.shape[0], + x.shape[1] // (patch_size**DIM), + *([patch_size] * DIM), + *(x.shape[-DIM:]), + ) + x = x.permute(0, 1, *(sum([[2 + DIM + i, 2 + i] for i in range(DIM)], []))) + x = x.reshape( + x.shape[0], x.shape[1], *[x.shape[2 + 2 * i] * patch_size for i in range(DIM)] + ) + return x diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/transformer/__init__.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/transformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e94bfeea8b366b91cc24c284dd3e75b4aa7da4f8 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/transformer/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +from .blocks import * +from .modulated import * diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/transformer/blocks.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/transformer/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..bb8566c065f59fc454b87118cec1f98221ed1609 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/transformer/blocks.py @@ -0,0 +1,197 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +from typing import * +import torch +import torch.nn as nn +from ..attention import MultiHeadAttention +from ..norm import LayerNorm32 + + +class AbsolutePositionEmbedder(nn.Module): + """ + Embeds spatial positions into vector representations. + """ + + def __init__(self, channels: int, in_channels: int = 3): + super().__init__() + self.channels = channels + self.in_channels = in_channels + self.freq_dim = channels // in_channels // 2 + self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim + self.freqs = 1.0 / (10000**self.freqs) + + def _sin_cos_embedding(self, x: torch.Tensor) -> torch.Tensor: + """ + Create sinusoidal position embeddings. + + Args: + x: a 1-D Tensor of N indices + + Returns: + an (N, D) Tensor of positional embeddings. + """ + self.freqs = self.freqs.to(x.device) + out = torch.outer(x, self.freqs) + out = torch.cat([torch.sin(out), torch.cos(out)], dim=-1) + return out + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x (torch.Tensor): (N, D) tensor of spatial positions + """ + N, D = x.shape + assert ( + D == self.in_channels + ), "Input dimension must match number of input channels" + embed = self._sin_cos_embedding(x.reshape(-1)) + embed = embed.reshape(N, -1) + if embed.shape[1] < self.channels: + embed = torch.cat( + [ + embed, + torch.zeros(N, self.channels - embed.shape[1], device=embed.device), + ], + dim=-1, + ) + return embed + + +class FeedForwardNet(nn.Module): + def __init__(self, channels: int, mlp_ratio: float = 4.0): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(channels, int(channels * mlp_ratio)), + nn.GELU(approximate="tanh"), + nn.Linear(int(channels * mlp_ratio), channels), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.mlp(x) + + +class TransformerBlock(nn.Module): + """ + Transformer block (MSA + FFN). + """ + + def __init__( + self, + channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[int] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + qk_rms_norm: bool = False, + qkv_bias: bool = True, + ln_affine: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.attn = MultiHeadAttention( + channels, + num_heads=num_heads, + attn_mode=attn_mode, + window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + ) + self.mlp = FeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + + def _forward(self, x: torch.Tensor) -> torch.Tensor: + h = self.norm1(x) + h = self.attn(h) + x = x + h + h = self.norm2(x) + h = self.mlp(h) + x = x + h + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint( + self._forward, x, use_reentrant=False + ) + else: + return self._forward(x) + + +class TransformerCrossBlock(nn.Module): + """ + Transformer cross-attention block (MSA + MCA + FFN). + """ + + def __init__( + self, + channels: int, + ctx_channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + qkv_bias: bool = True, + ln_affine: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) + self.self_attn = MultiHeadAttention( + channels, + num_heads=num_heads, + type="self", + attn_mode=attn_mode, + window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + ) + self.cross_attn = MultiHeadAttention( + channels, + ctx_channels=ctx_channels, + num_heads=num_heads, + type="cross", + attn_mode="full", + qkv_bias=qkv_bias, + qk_rms_norm=qk_rms_norm_cross, + ) + self.mlp = FeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + + def _forward(self, x: torch.Tensor, context: torch.Tensor): + h = self.norm1(x) + h = self.self_attn(h) + x = x + h + h = self.norm2(x) + h = self.cross_attn(h, context) + x = x + h + h = self.norm3(x) + h = self.mlp(h) + x = x + h + return x + + def forward(self, x: torch.Tensor, context: torch.Tensor): + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint( + self._forward, x, context, use_reentrant=False + ) + else: + return self._forward(x, context) diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/transformer/modulated.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/transformer/modulated.py new file mode 100644 index 0000000000000000000000000000000000000000..f4a26247255aa96d034447083631bb775bf33ddf --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/transformer/modulated.py @@ -0,0 +1,341 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +from functools import partial +from typing import * +from torch.utils import _pytree +import torch +import torch.nn as nn +from ..attention import MultiHeadAttention, MOTMultiHeadSelfAttention +from ..norm import LayerNorm32 +from .blocks import FeedForwardNet + + +class ModulatedTransformerBlock(nn.Module): + """ + Transformer block (MSA + FFN) with adaptive layer norm conditioning. + """ + + def __init__( + self, + channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + qk_rms_norm: bool = False, + qkv_bias: bool = True, + share_mod: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.attn = MultiHeadAttention( + channels, + num_heads=num_heads, + attn_mode=attn_mode, + window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + ) + self.mlp = FeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + if not share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), nn.Linear(channels, 6 * channels, bias=True) + ) + + def _forward(self, x: torch.Tensor, mod: torch.Tensor) -> torch.Tensor: + if self.share_mod: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk( + 6, dim=1 + ) + else: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.adaLN_modulation(mod).chunk(6, dim=1) + ) + h = self.norm1(x) + h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1) + h = self.attn(h) + h = h * gate_msa.unsqueeze(1) + x = x + h + h = self.norm2(x) + h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1) + h = self.mlp(h) + h = h * gate_mlp.unsqueeze(1) + x = x + h + return x + + def forward(self, x: torch.Tensor, mod: torch.Tensor) -> torch.Tensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint( + self._forward, x, mod, use_reentrant=False + ) + else: + return self._forward(x, mod) + + +class ModulatedTransformerCrossBlock(nn.Module): + """ + Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning. + """ + + def __init__( + self, + channels: int, + ctx_channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + qkv_bias: bool = True, + share_mod: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) + self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.self_attn = MultiHeadAttention( + channels, + num_heads=num_heads, + type="self", + attn_mode=attn_mode, + window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + ) + self.cross_attn = MultiHeadAttention( + channels, + ctx_channels=ctx_channels, + num_heads=num_heads, + type="cross", + attn_mode="full", + qkv_bias=qkv_bias, + qk_rms_norm=qk_rms_norm_cross, + ) + self.mlp = FeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + if not share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), nn.Linear(channels, 6 * channels, bias=True) + ) + + def _forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor): + if self.share_mod: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk( + 6, dim=1 + ) + else: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.adaLN_modulation(mod).chunk(6, dim=1) + ) + h = self.norm1(x) + h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1) + h = self.self_attn(h) + h = h * gate_msa.unsqueeze(1) + x = x + h + h = self.norm2(x) + h = self.cross_attn(h, context) + x = x + h + h = self.norm3(x) + h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1) + h = self.mlp(h) + h = h * gate_mlp.unsqueeze(1) + x = x + h + return x + + def forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor): + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint( + self._forward, x, mod, context, use_reentrant=False + ) + else: + return self._forward(x, mod, context) + + +class MOTModulatedTransformerCrossBlock(nn.Module): + """ + Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning. + """ + + def __init__( + self, + channels: int, + ctx_channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + qkv_bias: bool = True, + share_mod: bool = False, + latent_names: List = None, + freeze_shared_parameters: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.norm1 = torch.nn.ModuleDict( + { + latent_name: LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + for latent_name in latent_names + } + ) + self.norm2 = torch.nn.ModuleDict( + { + latent_name: LayerNorm32(channels, elementwise_affine=True, eps=1e-6) + for latent_name in latent_names + } + ) + self.norm3 = torch.nn.ModuleDict( + { + latent_name: LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + for latent_name in latent_names + } + ) + self.self_attn = MOTMultiHeadSelfAttention( + channels, + num_heads=num_heads, + type="self", + attn_mode=attn_mode, + window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + qk_rms_norm=qk_rms_norm, + latent_names=latent_names, + ) + self.cross_attn = torch.nn.ModuleDict( + { + latent_name: MultiHeadAttention( + channels, + ctx_channels=ctx_channels, + num_heads=num_heads, + type="cross", + attn_mode="full", + qkv_bias=qkv_bias, + qk_rms_norm=qk_rms_norm_cross, + ) + for latent_name in latent_names + } + ) + self.mlp = torch.nn.ModuleDict( + { + latent_name: FeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + for latent_name in latent_names + } + ) + if not share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), nn.Linear(channels, 6 * channels, bias=True) + ) + if freeze_shared_parameters: + self.adaLN_modulation.eval() + self.adaLN_modulation.requires_grad_(False) + + def _apply_module(self, h, module): + return module(h) + + def _apply_cross_attn(self, h, cross_attn, context): + return cross_attn(h, context) + + def _apply_msa(self, h, scale_msa, shift_msa): + return h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1) + + def _apply_mlp(self, h, scale_mlp, shift_mlp): + return h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1) + + def _apply_add(self, x, h): + return x + h + + def _apply_multiplication(self, h, multiplier): + return h * multiplier.unsqueeze(1) + + # This is stupid, _pytree does not support ModuleDict + def _moduledict_to_dict(self, module): + return {key: module for key, module in module.items()} + + def _forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor): + if self.share_mod: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk( + 6, dim=1 + ) + else: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.adaLN_modulation(mod).chunk(6, dim=1) + ) + h = _pytree.tree_map(self._apply_module, x, self._moduledict_to_dict(self.norm1)) + h = _pytree.tree_map( + partial(self._apply_msa, scale_msa=scale_msa, shift_msa=shift_msa), + h + ) + h = self.self_attn(h) + h = _pytree.tree_map( + partial(self._apply_multiplication, multiplier=gate_msa), + h + ) + x = _pytree.tree_map( + self._apply_add, + x, + h + ) + h = _pytree.tree_map(self._apply_module, x, self._moduledict_to_dict(self.norm2)) + h = _pytree.tree_map( + partial(self._apply_cross_attn, context=context), + h, + self._moduledict_to_dict(self.cross_attn), + ) + x = _pytree.tree_map( + self._apply_add, + x, + h + ) + h = _pytree.tree_map(self._apply_module, x, self._moduledict_to_dict(self.norm3)) + h = _pytree.tree_map( + partial(self._apply_mlp, scale_mlp=scale_mlp, shift_mlp=shift_mlp), + h + ) + h = _pytree.tree_map(self._apply_module, h, self._moduledict_to_dict(self.mlp)) + h = _pytree.tree_map( + partial(self._apply_multiplication, multiplier=gate_mlp), + h + ) + x = _pytree.tree_map( + self._apply_add, + x, + h + ) + return x + + def forward(self, x: Dict, mod: torch.Tensor, context: torch.Tensor): + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint( + self._forward, x, mod, context, use_reentrant=False + ) + else: + return self._forward(x, mod, context) diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/utils.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/utils.py new file mode 100755 index 0000000000000000000000000000000000000000..35f436a32e8259fde9a08d6b89a77527f2575617 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/modules/utils.py @@ -0,0 +1,72 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import torch.nn as nn +import torch + +FP16_TYPE = torch.float16 + +FP16_MODULES = [ + nn.Conv1d, + nn.Conv2d, + nn.Conv3d, + nn.ConvTranspose1d, + nn.ConvTranspose2d, + nn.ConvTranspose3d, + nn.Linear, +] + +# If we add sparse modules back in, they are compatible with FP16. +# But for now we don't have them and avoid the dependency on FlashAttention +# Instead using the torch implementation of FlashAttention in SDPA. +try: + from ..modules import sparse as sp + + FP16_MODULES += [ + sp.SparseConv3d, + sp.SparseInverseConv3d, + sp.SparseLinear, + ] +except ImportError: + pass + +FP16_MODULES = tuple(FP16_MODULES) + + +def convert_module_to_f16(l): + """ + Convert primitive modules to float16. + """ + if isinstance(l, FP16_MODULES): + for p in l.parameters(): + # p.data = p.data.half() + p.data = p.data.to(FP16_TYPE) + + +def convert_module_to_f32(l): + """ + Convert primitive modules to float32, undoing convert_module_to_f16(). + """ + if isinstance(l, FP16_MODULES): + for p in l.parameters(): + p.data = p.data.float() + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/renderers/__init__.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/renderers/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..31308014907033b19b968268936a4b9773ea800d --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/renderers/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +from .octree_renderer import OctreeRenderer +from .gaussian_render import GaussianRenderer + +# handle case when nvdiffrast is not present on the machine +try: + from .mesh_renderer import MeshRenderer +except ImportError: + MeshRenderer = None diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/renderers/gaussian_render.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/renderers/gaussian_render.py new file mode 100755 index 0000000000000000000000000000000000000000..c0417142140d7057b66fb649846f013f93f85947 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/renderers/gaussian_render.py @@ -0,0 +1,318 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import torch +import math +from easydict import EasyDict as edict +import numpy as np +from ..representations.gaussian import Gaussian +from .sh_utils import eval_sh +import torch.nn.functional as F +from easydict import EasyDict as edict +import warnings + +try: + from diff_gaussian_rasterization import ( + GaussianRasterizer, + GaussianRasterizationSettings, + ) +except ImportError: + warnings.warn( + "'diff_gaussian_rasterization' module cannot be imported, backend 'inria' won't be available", + ImportWarning, + ) + +from gsplat import rasterization + + +def intrinsics_to_projection( + intrinsics: torch.Tensor, + near: float, + far: float, +) -> torch.Tensor: + """ + OpenCV intrinsics to OpenGL perspective matrix + + Args: + intrinsics (torch.Tensor): [3, 3] OpenCV intrinsics matrix + near (float): near plane to clip + far (float): far plane to clip + Returns: + (torch.Tensor): [4, 4] OpenGL perspective matrix + """ + fx, fy = intrinsics[0, 0], intrinsics[1, 1] + cx, cy = intrinsics[0, 2], intrinsics[1, 2] + ret = torch.zeros((4, 4), dtype=intrinsics.dtype, device=intrinsics.device) + ret[0, 0] = 2 * fx + ret[1, 1] = 2 * fy + ret[0, 2] = 2 * cx - 1 + ret[1, 2] = -2 * cy + 1 + ret[2, 2] = far / (far - near) + ret[2, 3] = near * far / (near - far) + ret[3, 2] = 1.0 + return ret + + +def render( + viewpoint_camera, + pc: Gaussian, + pipe, + bg_color: torch.Tensor, + scaling_modifier=1.0, + override_color=None, + backend="inria", +): + """ + Render the scene. + + Background tensor (bg_color) must be on GPU! + """ + means3D = pc.get_xyz + opacity = pc.get_opacity + scales = pc.get_scaling + rotations = pc.get_rotation + + # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors + # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer. + shs = None + colors_precomp = None + if override_color is None: + if pipe.convert_SHs_python: + shs_view = pc.get_features.transpose(1, 2).view( + -1, 3, (pc.max_sh_degree + 1) ** 2 + ) + dir_pp = pc.get_xyz - viewpoint_camera.camera_center.repeat( + pc.get_features.shape[0], 1 + ) + dir_pp_normalized = dir_pp / dir_pp.norm(dim=1, keepdim=True) + sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized) + colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0) + else: + shs = pc.get_features + else: + colors_precomp = override_color + + # Backend-specific rasterization setup and execution + if backend == "inria": + kernel_size = pipe.kernel_size + subpixel_offset = torch.zeros( + (int(viewpoint_camera.image_height), int(viewpoint_camera.image_width), 2), + dtype=torch.float32, + device="cuda", + ) + + # Set up rasterization configuration + tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) + tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) + + # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means + screenspace_points = ( + torch.zeros_like( + pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda" + ) + + 0 + ) + try: + screenspace_points.retain_grad() + except: + pass + means2D = screenspace_points + + # If precomputed 3d covariance is provided, use it. If not, then it will be computed from + # scaling / rotation by the rasterizer. + cov3D_precomp = None + if pipe.compute_cov3D_python: + cov3D_precomp = pc.get_covariance(scaling_modifier) + + raster_settings = GaussianRasterizationSettings( + image_height=int(viewpoint_camera.image_height), + image_width=int(viewpoint_camera.image_width), + tanfovx=tanfovx, + tanfovy=tanfovy, + kernel_size=kernel_size, + subpixel_offset=subpixel_offset, + bg=bg_color, + scale_modifier=scaling_modifier, + viewmatrix=viewpoint_camera.world_view_transform, + projmatrix=viewpoint_camera.full_proj_transform, + sh_degree=pc.active_sh_degree, + campos=viewpoint_camera.camera_center, + prefiltered=False, + debug=pipe.debug, + ) + rasterizer = GaussianRasterizer(raster_settings=raster_settings) + + # Rasterize visible Gaussians to image, obtain their radii (on screen). + rendered_image, radii = rasterizer( + means3D=means3D, + means2D=means2D, + shs=shs, + colors_precomp=colors_precomp, + opacities=opacity, + scales=scales, + rotations=rotations, + cov3D_precomp=cov3D_precomp, + ) + elif backend == "gsplat": + """ + See reference code to convert from gsplat to inria: + https://github.com/nerfstudio-project/gsplat/blob/2323de5905d5e90e035f792fe65bad0fedd413e7/gsplat/rendering.py#L1108 + """ + # Unnormalize the intrinsics matrix to get pixel coordinates + Ks = viewpoint_camera.intrinsics.clone().unsqueeze(0) # Add batch dimension + Ks[0, 0, 0] *= viewpoint_camera.image_width # fx + Ks[0, 1, 1] *= viewpoint_camera.image_height # fy + Ks[0, 0, 2] *= viewpoint_camera.image_width # cx + Ks[0, 1, 2] *= viewpoint_camera.image_height # cy + + # For gsplat, when using SH coefficients, pass them as colors with sh_degree set + gsplat_colors = colors_precomp if colors_precomp is not None else shs + gsplat_sh_degree = pc.active_sh_degree if shs is not None else None + + render_colors, render_alphas, meta = rasterization( + means=means3D, + quats=rotations, + scales=scales, + opacities=opacity.squeeze(-1), + colors=gsplat_colors, + sh_degree=gsplat_sh_degree, + viewmats=viewpoint_camera.world_view_transform.T.contiguous().unsqueeze(0), + Ks=Ks, + width=int(viewpoint_camera.image_width), + height=int(viewpoint_camera.image_height), + backgrounds=bg_color, + ) + rendered_image = render_colors.squeeze(0).permute( + 2, 0, 1 + ) # Convert to (C, H, W) + + # Those Gaussians that were frustum culled or had a radius of 0 were not visible. + # They will be excluded from value updates used in the splitting criteria. + return edict( + { + "render": rendered_image, + } + ) + + +class GaussianRenderer: + """ + Renderer for the Voxel representation. + + Args: + rendering_options (dict): Rendering options. + """ + + def __init__(self, rendering_options={}) -> None: + self.pipe = edict( + { + "kernel_size": 0.1, + "convert_SHs_python": False, + "compute_cov3D_python": False, + "scale_modifier": 1.0, + "debug": False, + } + ) + + self.rendering_options = edict( + { + "resolution": None, + "near": None, + "far": None, + "ssaa": 1, + "bg_color": "random", + "backend": "inria", + } + ) + self.rendering_options.update(rendering_options) + self.bg_color = None + + def render( + self, + gausssian: Gaussian, + extrinsics: torch.Tensor, + intrinsics: torch.Tensor, + colors_overwrite: torch.Tensor = None, + ) -> edict: + """ + Render the gausssian. + + Args: + gaussian : gaussianmodule + extrinsics (torch.Tensor): (4, 4) camera extrinsics + intrinsics (torch.Tensor): (3, 3) camera intrinsics + colors_overwrite (torch.Tensor): (N, 3) override color + + Returns: + edict containing: + color (torch.Tensor): (3, H, W) rendered color image + """ + resolution = self.rendering_options["resolution"] + near = self.rendering_options["near"] + far = self.rendering_options["far"] + ssaa = self.rendering_options["ssaa"] + + if self.rendering_options["bg_color"] == "random": + self.bg_color = torch.zeros(3, dtype=torch.float32, device="cuda") + if np.random.rand() < 0.5: + self.bg_color += 1 + else: + self.bg_color = torch.tensor( + self.rendering_options["bg_color"], dtype=torch.float32, device="cuda" + ) + + view = extrinsics + perspective = intrinsics_to_projection(intrinsics, near, far) + camera = torch.inverse(view)[:3, 3] + focalx = intrinsics[0, 0] + focaly = intrinsics[1, 1] + fovx = 2 * torch.atan(0.5 / focalx) + fovy = 2 * torch.atan(0.5 / focaly) + + camera_dict = edict( + { + "image_height": resolution * ssaa, + "image_width": resolution * ssaa, + "FoVx": fovx, + "FoVy": fovy, + "znear": near, + "zfar": far, + "world_view_transform": view.T.contiguous(), + "projection_matrix": perspective.T.contiguous(), + "full_proj_transform": (perspective @ view).T.contiguous(), + "camera_center": camera, + "intrinsics": intrinsics, + } + ) + + # Render + render_ret = render( + camera_dict, + gausssian, + self.pipe, + self.bg_color, + override_color=colors_overwrite, + scaling_modifier=self.pipe.scale_modifier, + backend=self.rendering_options["backend"], + ) + + if ssaa > 1: + render_ret.render = F.interpolate( + render_ret.render[None], + size=(resolution, resolution), + mode="bilinear", + align_corners=False, + antialias=True, + ).squeeze() + + ret = edict({"color": render_ret["render"]}) + return ret diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/renderers/octree_renderer.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/renderers/octree_renderer.py new file mode 100755 index 0000000000000000000000000000000000000000..511c0757d5c6e6f0dabb2c5537b244322acf4a68 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/renderers/octree_renderer.py @@ -0,0 +1,390 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import numpy as np +import torch +import torch.nn.functional as F +import math +import cv2 +from scipy.stats import qmc +from easydict import EasyDict as edict +from ..representations.octree import DfsOctree + + +def intrinsics_to_projection( + intrinsics: torch.Tensor, + near: float, + far: float, +) -> torch.Tensor: + """ + OpenCV intrinsics to OpenGL perspective matrix + + Args: + intrinsics (torch.Tensor): [3, 3] OpenCV intrinsics matrix + near (float): near plane to clip + far (float): far plane to clip + Returns: + (torch.Tensor): [4, 4] OpenGL perspective matrix + """ + fx, fy = intrinsics[0, 0], intrinsics[1, 1] + cx, cy = intrinsics[0, 2], intrinsics[1, 2] + ret = torch.zeros((4, 4), dtype=intrinsics.dtype, device=intrinsics.device) + ret[0, 0] = 2 * fx + ret[1, 1] = 2 * fy + ret[0, 2] = 2 * cx - 1 + ret[1, 2] = -2 * cy + 1 + ret[2, 2] = far / (far - near) + ret[2, 3] = near * far / (near - far) + ret[3, 2] = 1.0 + return ret + + +def render( + viewpoint_camera, + octree: DfsOctree, + pipe, + bg_color: torch.Tensor, + scaling_modifier=1.0, + used_rank=None, + colors_overwrite=None, + aux=None, + halton_sampler=None, +): + """ + Render the scene. + + Background tensor (bg_color) must be on GPU! + """ + # lazy import + if "OctreeTrivecRasterizer" not in globals(): + from diffoctreerast import ( + OctreeVoxelRasterizer, + OctreeGaussianRasterizer, + OctreeTrivecRasterizer, + OctreeDecoupolyRasterizer, + ) + + # Set up rasterization configuration + tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) + tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) + + raster_settings = edict( + image_height=int(viewpoint_camera.image_height), + image_width=int(viewpoint_camera.image_width), + tanfovx=tanfovx, + tanfovy=tanfovy, + bg=bg_color, + scale_modifier=scaling_modifier, + viewmatrix=viewpoint_camera.world_view_transform, + projmatrix=viewpoint_camera.full_proj_transform, + sh_degree=octree.active_sh_degree, + campos=viewpoint_camera.camera_center, + with_distloss=pipe.with_distloss, + jitter=pipe.jitter, + debug=pipe.debug, + ) + + positions = octree.get_xyz + if octree.primitive == "voxel": + densities = octree.get_density + elif octree.primitive == "gaussian": + opacities = octree.get_opacity + elif octree.primitive == "trivec": + trivecs = octree.get_trivec + densities = octree.get_density + raster_settings.density_shift = octree.density_shift + elif octree.primitive == "decoupoly": + decoupolys_V, decoupolys_g = octree.get_decoupoly + densities = octree.get_density + raster_settings.density_shift = octree.density_shift + else: + raise ValueError(f"Unknown primitive {octree.primitive}") + depths = octree.get_depth + + # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors + # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer. + colors_precomp = None + shs = octree.get_features + if octree.primitive in ["voxel", "gaussian"] and colors_overwrite is not None: + colors_precomp = colors_overwrite + shs = None + + ret = edict() + + if octree.primitive == "voxel": + renderer = OctreeVoxelRasterizer(raster_settings=raster_settings) + rgb, depth, alpha, distloss = renderer( + positions=positions, + densities=densities, + shs=shs, + colors_precomp=colors_precomp, + depths=depths, + aabb=octree.aabb, + aux=aux, + ) + ret["rgb"] = rgb + ret["depth"] = depth + ret["alpha"] = alpha + ret["distloss"] = distloss + elif octree.primitive == "gaussian": + renderer = OctreeGaussianRasterizer(raster_settings=raster_settings) + rgb, depth, alpha = renderer( + positions=positions, + opacities=opacities, + shs=shs, + colors_precomp=colors_precomp, + depths=depths, + aabb=octree.aabb, + aux=aux, + ) + ret["rgb"] = rgb + ret["depth"] = depth + ret["alpha"] = alpha + elif octree.primitive == "trivec": + raster_settings.used_rank = ( + used_rank if used_rank is not None else trivecs.shape[1] + ) + renderer = OctreeTrivecRasterizer(raster_settings=raster_settings) + rgb, depth, alpha, percent_depth = renderer( + positions=positions, + trivecs=trivecs, + densities=densities, + shs=shs, + colors_precomp=colors_precomp, + colors_overwrite=colors_overwrite, + depths=depths, + aabb=octree.aabb, + aux=aux, + halton_sampler=halton_sampler, + ) + ret["percent_depth"] = percent_depth + ret["rgb"] = rgb + ret["depth"] = depth + ret["alpha"] = alpha + elif octree.primitive == "decoupoly": + raster_settings.used_rank = ( + used_rank if used_rank is not None else decoupolys_V.shape[1] + ) + renderer = OctreeDecoupolyRasterizer(raster_settings=raster_settings) + rgb, depth, alpha = renderer( + positions=positions, + decoupolys_V=decoupolys_V, + decoupolys_g=decoupolys_g, + densities=densities, + shs=shs, + colors_precomp=colors_precomp, + depths=depths, + aabb=octree.aabb, + aux=aux, + ) + ret["rgb"] = rgb + ret["depth"] = depth + ret["alpha"] = alpha + + return ret + + +class OctreeRenderer: + """ + Renderer for the Voxel representation. + + Args: + rendering_options (dict): Rendering options. + """ + + def __init__(self, rendering_options={}) -> None: + try: + import diffoctreerast + except ImportError: + print( + "\033[93m[WARNING] diffoctreerast is not installed. The renderer will be disabled.\033[0m" + ) + self.unsupported = True + else: + self.unsupported = False + + self.pipe = edict( + { + "with_distloss": False, + "with_aux": False, + "scale_modifier": 1.0, + "used_rank": None, + "jitter": False, + "debug": False, + } + ) + self.rendering_options = edict( + { + "resolution": None, + "near": None, + "far": None, + "ssaa": 1, + "bg_color": "random", + } + ) + self.halton_sampler = qmc.Halton(2, scramble=False) + self.rendering_options.update(rendering_options) + self.bg_color = None + + def render( + self, + octree: DfsOctree, + extrinsics: torch.Tensor, + intrinsics: torch.Tensor, + colors_overwrite: torch.Tensor = None, + ) -> edict: + """ + Render the octree. + + Args: + octree (Octree): octree + extrinsics (torch.Tensor): (4, 4) camera extrinsics + intrinsics (torch.Tensor): (3, 3) camera intrinsics + colors_overwrite (torch.Tensor): (N, 3) override color + + Returns: + edict containing: + color (torch.Tensor): (3, H, W) rendered color + depth (torch.Tensor): (H, W) rendered depth + alpha (torch.Tensor): (H, W) rendered alpha + distloss (Optional[torch.Tensor]): (H, W) rendered distance loss + percent_depth (Optional[torch.Tensor]): (H, W) rendered percent depth + aux (Optional[edict]): auxiliary tensors + """ + resolution = self.rendering_options["resolution"] + near = self.rendering_options["near"] + far = self.rendering_options["far"] + ssaa = self.rendering_options["ssaa"] + + if self.unsupported: + image = np.zeros((512, 512, 3), dtype=np.uint8) + text_bbox = cv2.getTextSize("Unsupported", cv2.FONT_HERSHEY_SIMPLEX, 2, 3)[ + 0 + ] + origin = (512 - text_bbox[0]) // 2, (512 - text_bbox[1]) // 2 + image = cv2.putText( + image, + "Unsupported", + origin, + cv2.FONT_HERSHEY_SIMPLEX, + 2, + (255, 255, 255), + 3, + cv2.LINE_AA, + ) + return { + "color": torch.tensor(image, dtype=torch.float32).permute(2, 0, 1) + / 255, + } + + if self.rendering_options["bg_color"] == "random": + self.bg_color = torch.zeros(3, dtype=torch.float32, device="cuda") + if np.random.rand() < 0.5: + self.bg_color += 1 + else: + self.bg_color = torch.tensor( + self.rendering_options["bg_color"], dtype=torch.float32, device="cuda" + ) + + if self.pipe["with_aux"]: + aux = { + "grad_color2": torch.zeros( + (octree.num_leaf_nodes, 3), + dtype=torch.float32, + requires_grad=True, + device="cuda", + ) + + 0, + "contributions": torch.zeros( + (octree.num_leaf_nodes, 1), + dtype=torch.float32, + requires_grad=True, + device="cuda", + ) + + 0, + } + for k in aux.keys(): + aux[k].requires_grad_() + aux[k].retain_grad() + else: + aux = None + + view = extrinsics + perspective = intrinsics_to_projection(intrinsics, near, far) + camera = torch.inverse(view)[:3, 3] + focalx = intrinsics[0, 0] + focaly = intrinsics[1, 1] + fovx = 2 * torch.atan(0.5 / focalx) + fovy = 2 * torch.atan(0.5 / focaly) + + camera_dict = edict( + { + "image_height": resolution * ssaa, + "image_width": resolution * ssaa, + "FoVx": fovx, + "FoVy": fovy, + "znear": near, + "zfar": far, + "world_view_transform": view.T.contiguous(), + "projection_matrix": perspective.T.contiguous(), + "full_proj_transform": (perspective @ view).T.contiguous(), + "camera_center": camera, + } + ) + + # Render + render_ret = render( + camera_dict, + octree, + self.pipe, + self.bg_color, + aux=aux, + colors_overwrite=colors_overwrite, + scaling_modifier=self.pipe.scale_modifier, + used_rank=self.pipe.used_rank, + halton_sampler=self.halton_sampler, + ) + + if ssaa > 1: + render_ret.rgb = F.interpolate( + render_ret.rgb[None], + size=(resolution, resolution), + mode="bilinear", + align_corners=False, + antialias=True, + ).squeeze() + render_ret.depth = F.interpolate( + render_ret.depth[None, None], + size=(resolution, resolution), + mode="bilinear", + align_corners=False, + antialias=True, + ).squeeze() + render_ret.alpha = F.interpolate( + render_ret.alpha[None, None], + size=(resolution, resolution), + mode="bilinear", + align_corners=False, + antialias=True, + ).squeeze() + if hasattr(render_ret, "percent_depth"): + render_ret.percent_depth = F.interpolate( + render_ret.percent_depth[None, None], + size=(resolution, resolution), + mode="bilinear", + align_corners=False, + antialias=True, + ).squeeze() + + ret = edict( + { + "color": render_ret.rgb, + "depth": render_ret.depth, + "alpha": render_ret.alpha, + } + ) + if self.pipe["with_distloss"] and "distloss" in render_ret: + ret["distloss"] = render_ret.distloss + if self.pipe["with_aux"]: + ret["aux"] = aux + if hasattr(render_ret, "percent_depth"): + ret["percent_depth"] = render_ret.percent_depth + return ret diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/renderers/sh_utils.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/renderers/sh_utils.py new file mode 100755 index 0000000000000000000000000000000000000000..16139dfd882bdde1cdbebcdf3031fb14ac566b01 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/renderers/sh_utils.py @@ -0,0 +1,129 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright 2021 The PlenOctree Authors. +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. + +import torch + +C0 = 0.28209479177387814 +C1 = 0.4886025119029199 +C2 = [ + 1.0925484305920792, + -1.0925484305920792, + 0.31539156525252005, + -1.0925484305920792, + 0.5462742152960396, +] +C3 = [ + -0.5900435899266435, + 2.890611442640554, + -0.4570457994644658, + 0.3731763325901154, + -0.4570457994644658, + 1.445305721320277, + -0.5900435899266435, +] +C4 = [ + 2.5033429417967046, + -1.7701307697799304, + 0.9461746957575601, + -0.6690465435572892, + 0.10578554691520431, + -0.6690465435572892, + 0.47308734787878004, + -1.7701307697799304, + 0.6258357354491761, +] + + +def eval_sh(deg, sh, dirs): + """ + Evaluate spherical harmonics at unit directions + using hardcoded SH polynomials. + Works with torch/np/jnp. + ... Can be 0 or more batch dimensions. + Args: + deg: int SH deg. Currently, 0-3 supported + sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2] + dirs: jnp.ndarray unit directions [..., 3] + Returns: + [..., C] + """ + assert deg <= 4 and deg >= 0 + coeff = (deg + 1) ** 2 + assert sh.shape[-1] >= coeff + + result = C0 * sh[..., 0] + if deg > 0: + x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] + result = ( + result - C1 * y * sh[..., 1] + C1 * z * sh[..., 2] - C1 * x * sh[..., 3] + ) + + if deg > 1: + xx, yy, zz = x * x, y * y, z * z + xy, yz, xz = x * y, y * z, x * z + result = ( + result + + C2[0] * xy * sh[..., 4] + + C2[1] * yz * sh[..., 5] + + C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + + C2[3] * xz * sh[..., 7] + + C2[4] * (xx - yy) * sh[..., 8] + ) + + if deg > 2: + result = ( + result + + C3[0] * y * (3 * xx - yy) * sh[..., 9] + + C3[1] * xy * z * sh[..., 10] + + C3[2] * y * (4 * zz - xx - yy) * sh[..., 11] + + C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + + C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] + + C3[5] * z * (xx - yy) * sh[..., 14] + + C3[6] * x * (xx - 3 * yy) * sh[..., 15] + ) + + if deg > 3: + result = ( + result + + C4[0] * xy * (xx - yy) * sh[..., 16] + + C4[1] * yz * (3 * xx - yy) * sh[..., 17] + + C4[2] * xy * (7 * zz - 1) * sh[..., 18] + + C4[3] * yz * (7 * zz - 3) * sh[..., 19] + + C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + + C4[5] * xz * (7 * zz - 3) * sh[..., 21] + + C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] + + C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + + C4[8] + * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) + * sh[..., 24] + ) + return result + + +def RGB2SH(rgb): + return (rgb - 0.5) / C0 + + +def SH2RGB(sh): + return sh * C0 + 0.5 diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/representations/__init__.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/representations/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..c1332334d96e6eac0285c549a3890a77e745abeb --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/representations/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +from .radiance_field import Strivec +from .octree import DfsOctree as Octree +from .gaussian import Gaussian +from .mesh import MeshExtractResult diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/representations/gaussian/__init__.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/representations/gaussian/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..682744a25c9ee20cee38a3d7415cfe880ae83eac --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/representations/gaussian/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +from .gaussian_model import Gaussian diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/representations/gaussian/gaussian_model.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/representations/gaussian/gaussian_model.py new file mode 100755 index 0000000000000000000000000000000000000000..e288730f26a9d49441ab3800728b5a52b657a2f8 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/representations/gaussian/gaussian_model.py @@ -0,0 +1,251 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import torch +import numpy as np +from plyfile import PlyData, PlyElement +from .general_utils import inverse_sigmoid, strip_symmetric, build_scaling_rotation + + +class Gaussian: + def __init__( + self, + aabb: list, + sh_degree: int = 0, + mininum_kernel_size: float = 0.0, + scaling_bias: float = 0.01, + opacity_bias: float = 0.1, + scaling_activation: str = "exp", + device="cuda", + ): + self.init_params = { + "aabb": aabb, + "sh_degree": sh_degree, + "mininum_kernel_size": mininum_kernel_size, + "scaling_bias": scaling_bias, + "opacity_bias": opacity_bias, + "scaling_activation": scaling_activation, + } + + self.sh_degree = sh_degree + self.active_sh_degree = sh_degree + self.mininum_kernel_size = mininum_kernel_size + self.scaling_bias = scaling_bias + self.opacity_bias = opacity_bias + self.scaling_activation_type = scaling_activation + self.device = device + self.aabb = torch.tensor(aabb, dtype=torch.float32, device=device) + self.setup_functions() + + self._xyz = None + self._features_dc = None + self._features_rest = None + self._scaling = None + self._rotation = None + self._opacity = None + + def setup_functions(self): + if self.scaling_activation_type == "exp": + self.scaling_activation = torch.exp + self.inverse_scaling_activation = torch.log + elif self.scaling_activation_type == "softplus": + self.scaling_activation = torch.nn.functional.softplus + self.inverse_scaling_activation = softplus_inverse_scaling_activation + + self.covariance_activation = self.build_covariance_from_scaling_rotation + + self.opacity_activation = torch.sigmoid + self.inverse_opacity_activation = inverse_sigmoid + + self.rotation_activation = torch.nn.functional.normalize + + self.scale_bias = self.inverse_scaling_activation( + torch.tensor(self.scaling_bias) + ).cuda() + self.rots_bias = torch.zeros((4)).cuda() + self.rots_bias[0] = 1 + self.opacity_bias = self.inverse_opacity_activation( + torch.tensor(self.opacity_bias) + ).cuda() + + @staticmethod + def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation): + L = build_scaling_rotation(scaling_modifier * scaling, rotation) + actual_covariance = L @ L.transpose(1, 2) + symm = strip_symmetric(actual_covariance) + return symm + + @property + def get_scaling(self): + scales = self.scaling_activation(self._scaling + self.scale_bias) + scales = torch.square(scales) + self.mininum_kernel_size**2 + scales = torch.sqrt(scales) + return scales + + @property + def get_rotation(self): + return self.rotation_activation(self._rotation + self.rots_bias[None, :]) + + @property + def get_xyz(self): + return self._xyz * self.aabb[None, 3:] + self.aabb[None, :3] + + @property + def get_features(self): + return ( + torch.cat((self._features_dc, self._features_rest), dim=2) + if self._features_rest is not None + else self._features_dc + ) + + @property + def get_opacity(self): + return self.opacity_activation(self._opacity + self.opacity_bias) + + def get_covariance(self, scaling_modifier=1): + return self.covariance_activation( + self.get_scaling, scaling_modifier, self._rotation + self.rots_bias[None, :] + ) + + def from_scaling(self, scales): + scales = torch.sqrt(torch.square(scales) - self.mininum_kernel_size**2) + self._scaling = self.inverse_scaling_activation(scales) - self.scale_bias + + def from_rotation(self, rots): + self._rotation = rots - self.rots_bias[None, :] + + def from_xyz(self, xyz): + self._xyz = (xyz - self.aabb[None, :3]) / self.aabb[None, 3:] + + def from_features(self, features): + self._features_dc = features + + def from_opacity(self, opacities): + self._opacity = self.inverse_opacity_activation(opacities) - self.opacity_bias + + def construct_list_of_attributes(self): + l = ["x", "y", "z", "nx", "ny", "nz"] + # All channels except the 3 DC + for i in range(self._features_dc.shape[1] * self._features_dc.shape[2]): + l.append("f_dc_{}".format(i)) + l.append("opacity") + for i in range(self._scaling.shape[1]): + l.append("scale_{}".format(i)) + for i in range(self._rotation.shape[1]): + l.append("rot_{}".format(i)) + return l + + def save_ply(self, path): + xyz = self.get_xyz.detach().cpu().numpy() + normals = np.zeros_like(xyz) + f_dc = ( + self._features_dc.detach() + .transpose(1, 2) + .flatten(start_dim=1) + .contiguous() + .cpu() + .numpy() + ) + opacities = inverse_sigmoid(self.get_opacity).detach().cpu().numpy() + scale = torch.log(self.get_scaling).detach().cpu().numpy() + rotation = (self._rotation + self.rots_bias[None, :]).detach().cpu().numpy() + + dtype_full = [ + (attribute, "f4") for attribute in self.construct_list_of_attributes() + ] + + elements = np.empty(xyz.shape[0], dtype=dtype_full) + attributes = np.concatenate( + (xyz, normals, f_dc, opacities, scale, rotation), axis=1 + ) + elements[:] = list(map(tuple, attributes)) + el = PlyElement.describe(elements, "vertex") + PlyData([el]).write(path) + + def load_ply(self, path): + plydata = PlyData.read(path) + + xyz = np.stack( + ( + np.asarray(plydata.elements[0]["x"]), + np.asarray(plydata.elements[0]["y"]), + np.asarray(plydata.elements[0]["z"]), + ), + axis=1, + ) + opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis] + + features_dc = np.zeros((xyz.shape[0], 3, 1)) + features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"]) + features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"]) + features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"]) + + if self.sh_degree > 0: + extra_f_names = [ + p.name + for p in plydata.elements[0].properties + if p.name.startswith("f_rest_") + ] + extra_f_names = sorted(extra_f_names, key=lambda x: int(x.split("_")[-1])) + assert len(extra_f_names) == 3 * (self.sh_degree + 1) ** 2 - 3 + features_extra = np.zeros((xyz.shape[0], len(extra_f_names))) + for idx, attr_name in enumerate(extra_f_names): + features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name]) + # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC) + features_extra = features_extra.reshape( + (features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1) + ) + + scale_names = [ + p.name + for p in plydata.elements[0].properties + if p.name.startswith("scale_") + ] + scale_names = sorted(scale_names, key=lambda x: int(x.split("_")[-1])) + scales = np.zeros((xyz.shape[0], len(scale_names))) + for idx, attr_name in enumerate(scale_names): + scales[:, idx] = np.asarray(plydata.elements[0][attr_name]) + + rot_names = [ + p.name for p in plydata.elements[0].properties if p.name.startswith("rot") + ] + rot_names = sorted(rot_names, key=lambda x: int(x.split("_")[-1])) + rots = np.zeros((xyz.shape[0], len(rot_names))) + for idx, attr_name in enumerate(rot_names): + rots[:, idx] = np.asarray(plydata.elements[0][attr_name]) + + # convert to actual gaussian attributes + xyz = torch.tensor(xyz, dtype=torch.float, device=self.device) + features_dc = ( + torch.tensor(features_dc, dtype=torch.float, device=self.device) + .transpose(1, 2) + .contiguous() + ) + if self.sh_degree > 0: + features_extra = ( + torch.tensor(features_extra, dtype=torch.float, device=self.device) + .transpose(1, 2) + .contiguous() + ) + opacities = torch.sigmoid( + torch.tensor(opacities, dtype=torch.float, device=self.device) + ) + scales = torch.exp(torch.tensor(scales, dtype=torch.float, device=self.device)) + rots = torch.tensor(rots, dtype=torch.float, device=self.device) + + # convert to _hidden attributes + self._xyz = (xyz - self.aabb[None, :3]) / self.aabb[None, 3:] + self._features_dc = features_dc + if self.sh_degree > 0: + self._features_rest = features_extra + else: + self._features_rest = None + self._opacity = self.inverse_opacity_activation(opacities) - self.opacity_bias + self._scaling = ( + self.inverse_scaling_activation( + torch.sqrt(torch.square(scales) - self.mininum_kernel_size**2) + ) + - self.scale_bias + ) + self._rotation = rots - self.rots_bias[None, :] + +def softplus_inverse_scaling_activation(x): + return x + torch.log(-torch.expm1(-x)) \ No newline at end of file diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/representations/gaussian/general_utils.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/representations/gaussian/general_utils.py new file mode 100755 index 0000000000000000000000000000000000000000..58e9fb8c295abcf39bcf404a96994de32925c971 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/representations/gaussian/general_utils.py @@ -0,0 +1,152 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import torch +import sys +from datetime import datetime +import numpy as np +import random + + +def inverse_sigmoid(x): + return torch.log(x / (1 - x)) + + +def PILtoTorch(pil_image, resolution): + resized_image_PIL = pil_image.resize(resolution) + resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0 + if len(resized_image.shape) == 3: + return resized_image.permute(2, 0, 1) + else: + return resized_image.unsqueeze(dim=-1).permute(2, 0, 1) + + +def get_expon_lr_func( + lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000 +): + """ + Copied from Plenoxels + + Continuous learning rate decay function. Adapted from JaxNeRF + The returned rate is lr_init when step=0 and lr_final when step=max_steps, and + is log-linearly interpolated elsewhere (equivalent to exponential decay). + If lr_delay_steps>0 then the learning rate will be scaled by some smooth + function of lr_delay_mult, such that the initial learning rate is + lr_init*lr_delay_mult at the beginning of optimization but will be eased back + to the normal learning rate when steps>lr_delay_steps. + :param conf: config subtree 'lr' or similar + :param max_steps: int, the number of steps during optimization. + :return HoF which takes step as input + """ + + def helper(step): + if step < 0 or (lr_init == 0.0 and lr_final == 0.0): + # Disable this parameter + return 0.0 + if lr_delay_steps > 0: + # A kind of reverse cosine decay. + delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin( + 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1) + ) + else: + delay_rate = 1.0 + t = np.clip(step / max_steps, 0, 1) + log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t) + return delay_rate * log_lerp + + return helper + + +def strip_lowerdiag(L): + uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda") + + uncertainty[:, 0] = L[:, 0, 0] + uncertainty[:, 1] = L[:, 0, 1] + uncertainty[:, 2] = L[:, 0, 2] + uncertainty[:, 3] = L[:, 1, 1] + uncertainty[:, 4] = L[:, 1, 2] + uncertainty[:, 5] = L[:, 2, 2] + return uncertainty + + +def strip_symmetric(sym): + return strip_lowerdiag(sym) + + +def build_rotation(r): + norm = torch.sqrt( + r[:, 0] * r[:, 0] + r[:, 1] * r[:, 1] + r[:, 2] * r[:, 2] + r[:, 3] * r[:, 3] + ) + + q = r / norm[:, None] + + R = torch.zeros((q.size(0), 3, 3), device="cuda") + + r = q[:, 0] + x = q[:, 1] + y = q[:, 2] + z = q[:, 3] + + R[:, 0, 0] = 1 - 2 * (y * y + z * z) + R[:, 0, 1] = 2 * (x * y - r * z) + R[:, 0, 2] = 2 * (x * z + r * y) + R[:, 1, 0] = 2 * (x * y + r * z) + R[:, 1, 1] = 1 - 2 * (x * x + z * z) + R[:, 1, 2] = 2 * (y * z - r * x) + R[:, 2, 0] = 2 * (x * z - r * y) + R[:, 2, 1] = 2 * (y * z + r * x) + R[:, 2, 2] = 1 - 2 * (x * x + y * y) + return R + + +def build_scaling_rotation(s, r): + L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda") + R = build_rotation(r) + + L[:, 0, 0] = s[:, 0] + L[:, 1, 1] = s[:, 1] + L[:, 2, 2] = s[:, 2] + + L = R @ L + return L + + +def safe_state(silent): + old_f = sys.stdout + + class F: + def __init__(self, silent): + self.silent = silent + + def write(self, x): + if not self.silent: + if x.endswith("\n"): + old_f.write( + x.replace( + "\n", + " [{}]\n".format( + str(datetime.now().strftime("%d/%m %H:%M:%S")) + ), + ) + ) + else: + old_f.write(x) + + def flush(self): + old_f.flush() + + sys.stdout = F(silent) + + random.seed(0) + np.random.seed(0) + torch.manual_seed(0) + torch.cuda.set_device(torch.device("cuda:0")) diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/representations/mesh/__init__.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/representations/mesh/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..da08f14e64fee73412ae45a064b948f344324ab9 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/representations/mesh/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +from .cube2mesh import SparseFeatures2Mesh, MeshExtractResult diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/representations/mesh/cube2mesh.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/representations/mesh/cube2mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..f9580ca69a316394ad28f402d5c268363b3a1f94 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/representations/mesh/cube2mesh.py @@ -0,0 +1,159 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import torch +from ...modules.sparse import SparseTensor +from easydict import EasyDict as edict +from .utils_cube import * +from .flexicubes.flexicubes import FlexiCubes + + +class MeshExtractResult: + def __init__(self, vertices, faces, vertex_attrs=None, res=64): + self.vertices = vertices + self.faces = faces.long() + self.vertex_attrs = vertex_attrs + self.face_normal = self.comput_face_normals(vertices, faces) + self.res = res + self.success = vertices.shape[0] != 0 and faces.shape[0] != 0 + + # training only + self.tsdf_v = None + self.tsdf_s = None + self.reg_loss = None + + def comput_face_normals(self, verts, faces): + i0 = faces[..., 0].long() + i1 = faces[..., 1].long() + i2 = faces[..., 2].long() + + v0 = verts[i0, :] + v1 = verts[i1, :] + v2 = verts[i2, :] + face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1) + face_normals = torch.nn.functional.normalize(face_normals, dim=1) + # print(face_normals.min(), face_normals.max(), face_normals.shape) + return face_normals[:, None, :].repeat(1, 3, 1) + + def comput_v_normals(self, verts, faces): + i0 = faces[..., 0].long() + i1 = faces[..., 1].long() + i2 = faces[..., 2].long() + + v0 = verts[i0, :] + v1 = verts[i1, :] + v2 = verts[i2, :] + face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1) + v_normals = torch.zeros_like(verts) + v_normals.scatter_add_(0, i0[..., None].repeat(1, 3), face_normals) + v_normals.scatter_add_(0, i1[..., None].repeat(1, 3), face_normals) + v_normals.scatter_add_(0, i2[..., None].repeat(1, 3), face_normals) + + v_normals = torch.nn.functional.normalize(v_normals, dim=1) + return v_normals + + +class SparseFeatures2Mesh: + def __init__(self, device="cuda", res=64, use_color=True): + """ + a model to generate a mesh from sparse features structures using flexicube + """ + super().__init__() + self.device = device + self.res = res + self.mesh_extractor = FlexiCubes(device=device) + self.sdf_bias = -1.0 / res + verts, cube = construct_dense_grid(self.res, self.device) + self.reg_c = cube.to(self.device) + self.reg_v = verts.to(self.device) + self.use_color = use_color + self._calc_layout() + + def _calc_layout(self): + LAYOUTS = { + "sdf": {"shape": (8, 1), "size": 8}, + "deform": {"shape": (8, 3), "size": 8 * 3}, + "weights": {"shape": (21,), "size": 21}, + } + if self.use_color: + """ + 6 channel color including normal map + """ + LAYOUTS["color"] = { + "shape": ( + 8, + 6, + ), + "size": 8 * 6, + } + self.layouts = edict(LAYOUTS) + start = 0 + for k, v in self.layouts.items(): + v["range"] = (start, start + v["size"]) + start += v["size"] + self.feats_channels = start + + def get_layout(self, feats: torch.Tensor, name: str): + if name not in self.layouts: + return None + return feats[ + :, self.layouts[name]["range"][0] : self.layouts[name]["range"][1] + ].reshape(-1, *self.layouts[name]["shape"]) + + def __call__(self, cubefeats: SparseTensor, training=False): + """ + Generates a mesh based on the specified sparse voxel structures. + Args: + cube_attrs [Nx21] : Sparse Tensor attrs about cube weights + verts_attrs [Nx10] : [0:1] SDF [1:4] deform [4:7] color [7:10] normal + Returns: + return the success tag and ni you loss, + """ + # add sdf bias to verts_attrs + coords = cubefeats.coords[:, 1:] + feats = cubefeats.feats + + sdf, deform, color, weights = [ + self.get_layout(feats, name) + for name in ["sdf", "deform", "color", "weights"] + ] + sdf += self.sdf_bias + v_attrs = [sdf, deform, color] if self.use_color else [sdf, deform] + v_pos, v_attrs, reg_loss = sparse_cube2verts( + coords, torch.cat(v_attrs, dim=-1), training=training + ) + v_attrs_d = get_dense_attrs(v_pos, v_attrs, res=self.res + 1, sdf_init=True) + weights_d = get_dense_attrs(coords, weights, res=self.res, sdf_init=False) + if self.use_color: + sdf_d, deform_d, colors_d = ( + v_attrs_d[..., 0], + v_attrs_d[..., 1:4], + v_attrs_d[..., 4:], + ) + else: + sdf_d, deform_d = v_attrs_d[..., 0], v_attrs_d[..., 1:4] + colors_d = None + + x_nx3 = get_defomed_verts(self.reg_v, deform_d, self.res) + + vertices, faces, L_dev, colors = self.mesh_extractor( + voxelgrid_vertices=x_nx3, + scalar_field=sdf_d, + cube_idx=self.reg_c, + resolution=self.res, + beta=weights_d[:, :12], + alpha=weights_d[:, 12:20], + gamma_f=weights_d[:, 20], + voxelgrid_colors=colors_d, + training=training, + ) + + mesh = MeshExtractResult( + vertices=vertices, faces=faces, vertex_attrs=colors, res=self.res + ) + if training: + if mesh.success: + reg_loss += L_dev.mean() * 0.5 + reg_loss += (weights[:, :20]).abs().mean() * 0.2 + mesh.reg_loss = reg_loss + mesh.tsdf_v = get_defomed_verts(v_pos, v_attrs[:, 1:4], self.res) + mesh.tsdf_s = v_attrs[:, 0] + return mesh diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/representations/mesh/flexicubes/flexicubes.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/representations/mesh/flexicubes/flexicubes.py new file mode 100644 index 0000000000000000000000000000000000000000..583b5e3f5af844cd0416cbf27065ccb6c4377792 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/representations/mesh/flexicubes/flexicubes.py @@ -0,0 +1,391 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +from .tables import * +from kaolin.utils.testing import check_tensor + +__all__ = [ + 'FlexiCubes' +] + + +class FlexiCubes: + def __init__(self, device="cuda"): + + self.device = device + self.dmc_table = torch.tensor(dmc_table, dtype=torch.long, device=device, requires_grad=False) + self.num_vd_table = torch.tensor(num_vd_table, + dtype=torch.long, device=device, requires_grad=False) + self.check_table = torch.tensor( + check_table, + dtype=torch.long, device=device, requires_grad=False) + + self.tet_table = torch.tensor(tet_table, dtype=torch.long, device=device, requires_grad=False) + self.quad_split_1 = torch.tensor([0, 1, 2, 0, 2, 3], dtype=torch.long, device=device, requires_grad=False) + self.quad_split_2 = torch.tensor([0, 1, 3, 3, 1, 2], dtype=torch.long, device=device, requires_grad=False) + self.quad_split_train = torch.tensor( + [0, 1, 1, 2, 2, 3, 3, 0], dtype=torch.long, device=device, requires_grad=False) + + self.cube_corners = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1], [ + 1, 0, 1], [0, 1, 1], [1, 1, 1]], dtype=torch.float, device=device) + self.cube_corners_idx = torch.pow(2, torch.arange(8, requires_grad=False)) + self.cube_edges = torch.tensor([0, 1, 1, 5, 4, 5, 0, 4, 2, 3, 3, 7, 6, 7, 2, 6, + 2, 0, 3, 1, 7, 5, 6, 4], dtype=torch.long, device=device, requires_grad=False) + + self.edge_dir_table = torch.tensor([0, 2, 0, 2, 0, 2, 0, 2, 1, 1, 1, 1], + dtype=torch.long, device=device) + self.dir_faces_table = torch.tensor([ + [[5, 4], [3, 2], [4, 5], [2, 3]], + [[5, 4], [1, 0], [4, 5], [0, 1]], + [[3, 2], [1, 0], [2, 3], [0, 1]] + ], dtype=torch.long, device=device) + self.adj_pairs = torch.tensor([0, 1, 1, 3, 3, 2, 2, 0], dtype=torch.long, device=device) + + def __call__(self, voxelgrid_vertices, scalar_field, cube_idx, resolution, qef_reg_scale=1e-3, + weight_scale=0.99, beta=None, alpha=None, gamma_f=None, voxelgrid_colors=None, training=False): + assert torch.is_tensor(voxelgrid_vertices) and \ + check_tensor(voxelgrid_vertices, (None, 3), throw=False), \ + "'voxelgrid_vertices' should be a tensor of shape (num_vertices, 3)" + num_vertices = voxelgrid_vertices.shape[0] + assert torch.is_tensor(scalar_field) and \ + check_tensor(scalar_field, (num_vertices,), throw=False), \ + "'scalar_field' should be a tensor of shape (num_vertices,)" + assert torch.is_tensor(cube_idx) and \ + check_tensor(cube_idx, (None, 8), throw=False), \ + "'cube_idx' should be a tensor of shape (num_cubes, 8)" + num_cubes = cube_idx.shape[0] + assert beta is None or ( + torch.is_tensor(beta) and + check_tensor(beta, (num_cubes, 12), throw=False) + ), "'beta' should be a tensor of shape (num_cubes, 12)" + assert alpha is None or ( + torch.is_tensor(alpha) and + check_tensor(alpha, (num_cubes, 8), throw=False) + ), "'alpha' should be a tensor of shape (num_cubes, 8)" + assert gamma_f is None or ( + torch.is_tensor(gamma_f) and + check_tensor(gamma_f, (num_cubes,), throw=False) + ), "'gamma_f' should be a tensor of shape (num_cubes,)" + + surf_cubes, occ_fx8 = self._identify_surf_cubes(scalar_field, cube_idx) + if surf_cubes.sum() == 0: + return ( + torch.zeros((0, 3), device=self.device), + torch.zeros((0, 3), dtype=torch.long, device=self.device), + torch.zeros((0), device=self.device), + torch.zeros((0, voxelgrid_colors.shape[-1]), device=self.device) if voxelgrid_colors is not None else None + ) + beta, alpha, gamma_f = self._normalize_weights( + beta, alpha, gamma_f, surf_cubes, weight_scale) + + if voxelgrid_colors is not None: + voxelgrid_colors = torch.sigmoid(voxelgrid_colors) + + case_ids = self._get_case_id(occ_fx8, surf_cubes, resolution) + + surf_edges, idx_map, edge_counts, surf_edges_mask = self._identify_surf_edges( + scalar_field, cube_idx, surf_cubes + ) + + vd, L_dev, vd_gamma, vd_idx_map, vd_color = self._compute_vd( + voxelgrid_vertices, cube_idx[surf_cubes], surf_edges, scalar_field, + case_ids, beta, alpha, gamma_f, idx_map, qef_reg_scale, voxelgrid_colors) + vertices, faces, s_edges, edge_indices, vertices_color = self._triangulate( + scalar_field, surf_edges, vd, vd_gamma, edge_counts, idx_map, + vd_idx_map, surf_edges_mask, training, vd_color) + return vertices, faces, L_dev, vertices_color + + def _compute_reg_loss(self, vd, ue, edge_group_to_vd, vd_num_edges): + """ + Regularizer L_dev as in Equation 8 + """ + dist = torch.norm(ue - torch.index_select(input=vd, index=edge_group_to_vd, dim=0), dim=-1) + mean_l2 = torch.zeros_like(vd[:, 0]) + mean_l2 = (mean_l2).index_add_(0, edge_group_to_vd, dist) / vd_num_edges.squeeze(1).float() + mad = (dist - torch.index_select(input=mean_l2, index=edge_group_to_vd, dim=0)).abs() + return mad + + def _normalize_weights(self, beta, alpha, gamma_f, surf_cubes, weight_scale): + """ + Normalizes the given weights to be non-negative. If input weights are None, it creates and returns a set of weights of ones. + """ + n_cubes = surf_cubes.shape[0] + + if beta is not None: + beta = (torch.tanh(beta) * weight_scale + 1) + else: + beta = torch.ones((n_cubes, 12), dtype=torch.float, device=self.device) + + if alpha is not None: + alpha = (torch.tanh(alpha) * weight_scale + 1) + else: + alpha = torch.ones((n_cubes, 8), dtype=torch.float, device=self.device) + + if gamma_f is not None: + gamma_f = torch.sigmoid(gamma_f) * weight_scale + (1 - weight_scale) / 2 + else: + gamma_f = torch.ones((n_cubes), dtype=torch.float, device=self.device) + + return beta[surf_cubes], alpha[surf_cubes], gamma_f[surf_cubes] + + @torch.no_grad() + def _get_case_id(self, occ_fx8, surf_cubes, res): + """ + Obtains the ID of topology cases based on cell corner occupancy. This function resolves the + ambiguity in the Dual Marching Cubes (DMC) configurations as described in Section 1.3 of the + supplementary material. It should be noted that this function assumes a regular grid. + """ + case_ids = (occ_fx8[surf_cubes] * self.cube_corners_idx.to(self.device).unsqueeze(0)).sum(-1) + + problem_config = self.check_table.to(self.device)[case_ids] + to_check = problem_config[..., 0] == 1 + problem_config = problem_config[to_check] + if not isinstance(res, (list, tuple)): + res = [res, res, res] + + # The 'problematic_configs' only contain configurations for surface cubes. Next, we construct a 3D array, + # 'problem_config_full', to store configurations for all cubes (with default config for non-surface cubes). + # This allows efficient checking on adjacent cubes. + problem_config_full = torch.zeros(list(res) + [5], device=self.device, dtype=torch.long) + vol_idx = torch.nonzero(problem_config_full[..., 0] == 0) # N, 3 + vol_idx_problem = vol_idx[surf_cubes][to_check] + problem_config_full[vol_idx_problem[..., 0], vol_idx_problem[..., 1], vol_idx_problem[..., 2]] = problem_config + vol_idx_problem_adj = vol_idx_problem + problem_config[..., 1:4] + + within_range = ( + vol_idx_problem_adj[..., 0] >= 0) & ( + vol_idx_problem_adj[..., 0] < res[0]) & ( + vol_idx_problem_adj[..., 1] >= 0) & ( + vol_idx_problem_adj[..., 1] < res[1]) & ( + vol_idx_problem_adj[..., 2] >= 0) & ( + vol_idx_problem_adj[..., 2] < res[2]) + + vol_idx_problem = vol_idx_problem[within_range] + vol_idx_problem_adj = vol_idx_problem_adj[within_range] + problem_config = problem_config[within_range] + problem_config_adj = problem_config_full[vol_idx_problem_adj[..., 0], + vol_idx_problem_adj[..., 1], vol_idx_problem_adj[..., 2]] + # If two cubes with cases C16 and C19 share an ambiguous face, both cases are inverted. + to_invert = (problem_config_adj[..., 0] == 1) + idx = torch.arange(case_ids.shape[0], device=self.device)[to_check][within_range][to_invert] + case_ids.index_put_((idx,), problem_config[to_invert][..., -1]) + return case_ids + + @torch.no_grad() + def _identify_surf_edges(self, scalar_field, cube_idx, surf_cubes): + """ + Identifies grid edges that intersect with the underlying surface by checking for opposite signs. As each edge + can be shared by multiple cubes, this function also assigns a unique index to each surface-intersecting edge + and marks the cube edges with this index. + """ + occ_n = scalar_field < 0 + all_edges = cube_idx[surf_cubes][:, self.cube_edges].reshape(-1, 2) + unique_edges, _idx_map, counts = torch.unique(all_edges, dim=0, return_inverse=True, return_counts=True) + + unique_edges = unique_edges.long() + mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1 + + surf_edges_mask = mask_edges[_idx_map] + counts = counts[_idx_map] + + mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=cube_idx.device) * -1 + mapping[mask_edges] = torch.arange(mask_edges.sum(), device=cube_idx.device) + # Shaped as [number of cubes x 12 edges per cube]. This is later used to map a cube edge to the unique index + # for a surface-intersecting edge. Non-surface-intersecting edges are marked with -1. + idx_map = mapping[_idx_map] + surf_edges = unique_edges[mask_edges] + return surf_edges, idx_map, counts, surf_edges_mask + + @torch.no_grad() + def _identify_surf_cubes(self, scalar_field, cube_idx): + """ + Identifies grid cubes that intersect with the underlying surface by checking if the signs at + all corners are not identical. + """ + occ_n = scalar_field < 0 + occ_fx8 = occ_n[cube_idx.reshape(-1)].reshape(-1, 8) + _occ_sum = torch.sum(occ_fx8, -1) + surf_cubes = (_occ_sum > 0) & (_occ_sum < 8) + return surf_cubes, occ_fx8 + + def _linear_interp(self, edges_weight, edges_x): + """ + Computes the location of zero-crossings on 'edges_x' using linear interpolation with 'edges_weight'. + """ + edge_dim = edges_weight.dim() - 2 + assert edges_weight.shape[edge_dim] == 2 + edges_weight = torch.cat([torch.index_select(input=edges_weight, index=torch.tensor(1, device=self.device), dim=edge_dim), - + torch.index_select(input=edges_weight, index=torch.tensor(0, device=self.device), dim=edge_dim)] + , edge_dim) + denominator = edges_weight.sum(edge_dim) + ue = (edges_x * edges_weight).sum(edge_dim) / denominator + return ue + + def _solve_vd_QEF(self, p_bxnx3, norm_bxnx3, c_bx3, qef_reg_scale): + p_bxnx3 = p_bxnx3.reshape(-1, 7, 3) + norm_bxnx3 = norm_bxnx3.reshape(-1, 7, 3) + c_bx3 = c_bx3.reshape(-1, 3) + A = norm_bxnx3 + B = ((p_bxnx3) * norm_bxnx3).sum(-1, keepdims=True) + + A_reg = (torch.eye(3, device=p_bxnx3.device) * qef_reg_scale).unsqueeze(0).repeat(p_bxnx3.shape[0], 1, 1) + B_reg = (qef_reg_scale * c_bx3).unsqueeze(-1) + A = torch.cat([A, A_reg], 1) + B = torch.cat([B, B_reg], 1) + dual_verts = torch.linalg.lstsq(A, B).solution.squeeze(-1) + return dual_verts + + def _compute_vd(self, voxelgrid_vertices, surf_cubes_fx8, surf_edges, scalar_field, + case_ids, beta, alpha, gamma_f, idx_map, qef_reg_scale, voxelgrid_colors): + """ + Computes the location of dual vertices as described in Section 4.2 + """ + alpha_nx12x2 = torch.index_select(input=alpha, index=self.cube_edges, dim=1).reshape(-1, 12, 2) + surf_edges_x = torch.index_select(input=voxelgrid_vertices, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 3) + surf_edges_s = torch.index_select(input=scalar_field, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 1) + zero_crossing = self._linear_interp(surf_edges_s, surf_edges_x) + + if voxelgrid_colors is not None: + C = voxelgrid_colors.shape[-1] + surf_edges_c = torch.index_select(input=voxelgrid_colors, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, C) + + idx_map = idx_map.reshape(-1, 12) + num_vd = torch.index_select(input=self.num_vd_table, index=case_ids, dim=0) + edge_group, edge_group_to_vd, edge_group_to_cube, vd_num_edges, vd_gamma = [], [], [], [], [] + + # if color is not None: + # vd_color = [] + + total_num_vd = 0 + vd_idx_map = torch.zeros((case_ids.shape[0], 12), dtype=torch.long, device=self.device, requires_grad=False) + + for num in torch.unique(num_vd): + cur_cubes = (num_vd == num) # consider cubes with the same numbers of vd emitted (for batching) + curr_num_vd = cur_cubes.sum() * num + curr_edge_group = self.dmc_table[case_ids[cur_cubes], :num].reshape(-1, num * 7) + curr_edge_group_to_vd = torch.arange( + curr_num_vd, device=self.device).unsqueeze(-1).repeat(1, 7) + total_num_vd + total_num_vd += curr_num_vd + curr_edge_group_to_cube = torch.arange(idx_map.shape[0], device=self.device)[ + cur_cubes].unsqueeze(-1).repeat(1, num * 7).reshape_as(curr_edge_group) + + curr_mask = (curr_edge_group != -1) + edge_group.append(torch.masked_select(curr_edge_group, curr_mask)) + edge_group_to_vd.append(torch.masked_select(curr_edge_group_to_vd.reshape_as(curr_edge_group), curr_mask)) + edge_group_to_cube.append(torch.masked_select(curr_edge_group_to_cube, curr_mask)) + vd_num_edges.append(curr_mask.reshape(-1, 7).sum(-1, keepdims=True)) + vd_gamma.append(torch.masked_select(gamma_f, cur_cubes).unsqueeze(-1).repeat(1, num).reshape(-1)) + # if color is not None: + # vd_color.append(color[cur_cubes].unsqueeze(1).repeat(1, num, 1).reshape(-1, 3)) + + edge_group = torch.cat(edge_group) + edge_group_to_vd = torch.cat(edge_group_to_vd) + edge_group_to_cube = torch.cat(edge_group_to_cube) + vd_num_edges = torch.cat(vd_num_edges) + vd_gamma = torch.cat(vd_gamma) + # if color is not None: + # vd_color = torch.cat(vd_color) + # else: + # vd_color = None + + vd = torch.zeros((total_num_vd, 3), device=self.device) + beta_sum = torch.zeros((total_num_vd, 1), device=self.device) + + idx_group = torch.gather(input=idx_map.reshape(-1), dim=0, index=edge_group_to_cube * 12 + edge_group) + + x_group = torch.index_select(input=surf_edges_x, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 3) + s_group = torch.index_select(input=surf_edges_s, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 1) + + + zero_crossing_group = torch.index_select( + input=zero_crossing, index=idx_group.reshape(-1), dim=0).reshape(-1, 3) + + alpha_group = torch.index_select(input=alpha_nx12x2.reshape(-1, 2), dim=0, + index=edge_group_to_cube * 12 + edge_group).reshape(-1, 2, 1) + ue_group = self._linear_interp(s_group * alpha_group, x_group) + + beta_group = torch.gather(input=beta.reshape(-1), dim=0, + index=edge_group_to_cube * 12 + edge_group).reshape(-1, 1) + beta_sum = beta_sum.index_add_(0, index=edge_group_to_vd, source=beta_group) + vd = vd.index_add_(0, index=edge_group_to_vd, source=ue_group * beta_group) / beta_sum + + ''' + interpolate colors use the same method as dual vertices + ''' + if voxelgrid_colors is not None: + vd_color = torch.zeros((total_num_vd, C), device=self.device) + c_group = torch.index_select(input=surf_edges_c, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, C) + uc_group = self._linear_interp(s_group * alpha_group, c_group) + vd_color = vd_color.index_add_(0, index=edge_group_to_vd, source=uc_group * beta_group) / beta_sum + else: + vd_color = None + + L_dev = self._compute_reg_loss(vd, zero_crossing_group, edge_group_to_vd, vd_num_edges) + + v_idx = torch.arange(vd.shape[0], device=self.device) # + total_num_vd + + vd_idx_map = (vd_idx_map.reshape(-1)).scatter(dim=0, index=edge_group_to_cube * + 12 + edge_group, src=v_idx[edge_group_to_vd]) + + return vd, L_dev, vd_gamma, vd_idx_map, vd_color + + def _triangulate(self, scalar_field, surf_edges, vd, vd_gamma, edge_counts, idx_map, vd_idx_map, surf_edges_mask, training, vd_color): + """ + Connects four neighboring dual vertices to form a quadrilateral. The quadrilaterals are then split into + triangles based on the gamma parameter, as described in Section 4.3. + """ + with torch.no_grad(): + group_mask = (edge_counts == 4) & surf_edges_mask # surface edges shared by 4 cubes. + group = idx_map.reshape(-1)[group_mask] + vd_idx = vd_idx_map[group_mask] + edge_indices, indices = torch.sort(group, stable=True) + quad_vd_idx = vd_idx[indices].reshape(-1, 4) + + # Ensure all face directions point towards the positive SDF to maintain consistent winding. + s_edges = scalar_field[surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1)].reshape(-1, 2) + flip_mask = s_edges[:, 0] > 0 + quad_vd_idx = torch.cat((quad_vd_idx[flip_mask][:, [0, 1, 3, 2]], + quad_vd_idx[~flip_mask][:, [2, 3, 1, 0]])) + + quad_gamma = torch.index_select(input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4) + gamma_02 = quad_gamma[:, 0] * quad_gamma[:, 2] + gamma_13 = quad_gamma[:, 1] * quad_gamma[:, 3] + if not training: + mask = (gamma_02 > gamma_13) + faces = torch.zeros((quad_gamma.shape[0], 6), dtype=torch.long, device=quad_vd_idx.device) + faces[mask] = quad_vd_idx[mask][:, self.quad_split_1] + faces[~mask] = quad_vd_idx[~mask][:, self.quad_split_2] + faces = faces.reshape(-1, 3) + else: + vd_quad = torch.index_select(input=vd, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 3) + vd_02 = (vd_quad[:, 0] + vd_quad[:, 2]) / 2 + vd_13 = (vd_quad[:, 1] + vd_quad[:, 3]) / 2 + weight_sum = (gamma_02 + gamma_13) + 1e-8 + vd_center = (vd_02 * gamma_02.unsqueeze(-1) + vd_13 * gamma_13.unsqueeze(-1)) / weight_sum.unsqueeze(-1) + + if vd_color is not None: + color_quad = torch.index_select(input=vd_color, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, vd_color.shape[-1]) + color_02 = (color_quad[:, 0] + color_quad[:, 2]) / 2 + color_13 = (color_quad[:, 1] + color_quad[:, 3]) / 2 + color_center = (color_02 * gamma_02.unsqueeze(-1) + color_13 * gamma_13.unsqueeze(-1)) / weight_sum.unsqueeze(-1) + vd_color = torch.cat([vd_color, color_center]) + + + vd_center_idx = torch.arange(vd_center.shape[0], device=self.device) + vd.shape[0] + vd = torch.cat([vd, vd_center]) + faces = quad_vd_idx[:, self.quad_split_train].reshape(-1, 4, 2) + faces = torch.cat([faces, vd_center_idx.reshape(-1, 1, 1).repeat(1, 4, 1)], -1).reshape(-1, 3) + return vd, faces, s_edges, edge_indices, vd_color \ No newline at end of file diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/representations/mesh/flexicubes/tables.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/representations/mesh/flexicubes/tables.py new file mode 100644 index 0000000000000000000000000000000000000000..9a8ef58994f90715e214c5529b2c41dffe13fc6b --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/representations/mesh/flexicubes/tables.py @@ -0,0 +1,799 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +dmc_table = [ +[[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 5, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 5, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 8, 11, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 5, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 8, 9, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 7, 8, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 7, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 9, 10, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 5, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 8, 9, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 7, 9, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[8, 9, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 9, 10, 11, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 8, 10, 11, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 5, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 8, 9, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 8, 9, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 5, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 5, 8, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 6, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 6, 7, 8, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 5, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 5, 6, 7, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 5, 6, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 2, 3, 5, 6, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 9, 10, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 8, 9, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 6, 8, 11, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 6, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 9, 10, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 6, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1]], +[[0, 2, 4, 5, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 5, 8, 10, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 6, 8, 9, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 6, 9, 11, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 5, 6, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[6, 7, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 6, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 6, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 6, 7, 8, 10, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 5, 6, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 5, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 6, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 8, 9, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 7, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 7, 9, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 6, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 6, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[6, 7, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 8, 11, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 8, 9, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 7, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1]], +[[1, 2, 4, 7, 9, 11, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 6, 9, 10, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 8, 11, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 6, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[6, 7, 8, 9, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 6, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 6, 7, 8, 10, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 6, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 5, 6, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 7, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 6, 9, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 5, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 2, 3, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 6, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 5, 6, 7, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 6, 9, 11, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 6, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 6, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 6, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 6, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 8, 9, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 5, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 5, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[8, 9, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 5, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 5, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 5, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 2, 3, 4, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 2, 3, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 5, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 2, 3, 4, 5, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 5, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 2, 3, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 5, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 5, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 5, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]] +] +num_vd_table = [0, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 3, 1, 2, 2, +2, 1, 2, 1, 2, 1, 1, 2, 1, 1, 2, 2, 2, 1, 2, 3, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 2, +1, 2, 1, 2, 2, 1, 1, 2, 1, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 2, 3, 2, 2, 1, 1, 1, 1, +1, 1, 2, 1, 1, 1, 2, 1, 2, 2, 2, 1, 1, 1, 1, 1, 2, 3, 2, 2, 2, 2, 2, 1, 3, 4, 2, +2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1, 1, 2, 2, 2, 2, 2, +3, 2, 1, 2, 1, 1, 1, 1, 1, 1, 2, 2, 3, 2, 3, 2, 4, 2, 2, 2, 2, 1, 2, 1, 2, 1, 1, +2, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, +1, 2, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2, +1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, +1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0] +check_table = [ +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 1, 0, 0, 194], +[1, -1, 0, 0, 193], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 1, 0, 164], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, -1, 0, 161], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 0, 1, 152], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 0, 1, 145], +[1, 0, 0, 1, 144], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 0, -1, 137], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 1, 0, 133], +[1, 0, 1, 0, 132], +[1, 1, 0, 0, 131], +[1, 1, 0, 0, 130], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 0, 1, 100], +[0, 0, 0, 0, 0], +[1, 0, 0, 1, 98], +[0, 0, 0, 0, 0], +[1, 0, 0, 1, 96], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 1, 0, 88], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, -1, 0, 82], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 1, 0, 74], +[0, 0, 0, 0, 0], +[1, 0, 1, 0, 72], +[0, 0, 0, 0, 0], +[1, 0, 0, -1, 70], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, -1, 0, 0, 67], +[0, 0, 0, 0, 0], +[1, -1, 0, 0, 65], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 1, 0, 0, 56], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, -1, 0, 0, 52], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 1, 0, 0, 44], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 1, 0, 0, 40], +[0, 0, 0, 0, 0], +[1, 0, 0, -1, 38], +[1, 0, -1, 0, 37], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, -1, 0, 33], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, -1, 0, 0, 28], +[0, 0, 0, 0, 0], +[1, 0, -1, 0, 26], +[1, 0, 0, -1, 25], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, -1, 0, 0, 20], +[0, 0, 0, 0, 0], +[1, 0, -1, 0, 18], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 0, -1, 9], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 0, -1, 6], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0] +] +tet_table = [ +[-1, -1, -1, -1, -1, -1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[4, 4, 4, 4, 4, 4], +[0, 0, 0, 0, 0, 0], +[4, 0, 0, 4, 4, -1], +[1, 1, 1, 1, 1, 1], +[4, 4, 4, 4, 4, 4], +[0, 4, 0, 4, 4, -1], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[5, 5, 5, 5, 5, 5], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[2, 0, 2, -1, 0, 2], +[1, 1, 1, 1, 1, 1], +[2, -1, 2, 4, 4, 2], +[0, 0, 0, 0, 0, 0], +[2, 0, 2, 4, 4, 2], +[1, 1, 1, 1, 1, 1], +[2, 4, 2, 4, 4, 2], +[0, 4, 0, 4, 4, 0], +[2, 0, 2, 0, 0, 2], +[1, 1, 1, 1, 1, 1], +[2, 5, 2, 5, 5, 2], +[0, 0, 0, 0, 0, 0], +[2, 0, 2, 0, 0, 2], +[1, 1, 1, 1, 1, 1], +[1, 1, 1, 1, 1, 1], +[0, 1, 1, -1, 0, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[4, 1, 1, 4, 4, 1], +[0, 1, 1, 0, 0, 1], +[4, 0, 0, 4, 4, 0], +[2, 2, 2, 2, 2, 2], +[-1, 1, 1, 4, 4, 1], +[0, 1, 1, 4, 4, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[5, 1, 1, 5, 5, 1], +[0, 1, 1, 0, 0, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[8, 8, 8, 8, 8, 8], +[1, 1, 1, 4, 4, 1], +[0, 0, 0, 0, 0, 0], +[4, 0, 0, 4, 4, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 4, 4, 1], +[0, 4, 0, 4, 4, 0], +[0, 0, 0, 0, 0, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 5, 5, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[5, 5, 5, 5, 5, 5], +[6, 6, 6, 6, 6, 6], +[6, -1, 0, 6, 0, 6], +[6, 0, 0, 6, 0, 6], +[6, 1, 1, 6, 1, 6], +[4, 4, 4, 4, 4, 4], +[0, 0, 0, 0, 0, 0], +[4, 0, 0, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[6, 4, -1, 6, 4, 6], +[6, 4, 0, 6, 4, 6], +[6, 0, 0, 6, 0, 6], +[6, 1, 1, 6, 1, 6], +[5, 5, 5, 5, 5, 5], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[2, 0, 2, 2, 0, 2], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[2, 0, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[2, 4, 2, 2, 4, 2], +[0, 4, 0, 4, 4, 0], +[2, 0, 2, 2, 0, 2], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[6, 1, 1, 6, -1, 6], +[6, 1, 1, 6, 0, 6], +[6, 0, 0, 6, 0, 6], +[6, 2, 2, 6, 2, 6], +[4, 1, 1, 4, 4, 1], +[0, 1, 1, 0, 0, 1], +[4, 0, 0, 4, 4, 4], +[2, 2, 2, 2, 2, 2], +[6, 1, 1, 6, 4, 6], +[6, 1, 1, 6, 4, 6], +[6, 0, 0, 6, 0, 6], +[6, 2, 2, 6, 2, 6], +[5, 1, 1, 5, 5, 1], +[0, 1, 1, 0, 0, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[6, 6, 6, 6, 6, 6], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 1, 4, 1], +[0, 4, 0, 4, 4, 0], +[0, 0, 0, 0, 0, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 5, 0, 5, 0, 5], +[5, 5, 5, 5, 5, 5], +[5, 5, 5, 5, 5, 5], +[0, 5, 0, 5, 0, 5], +[-1, 5, 0, 5, 0, 5], +[1, 5, 1, 5, 1, 5], +[4, 5, -1, 5, 4, 5], +[0, 5, 0, 5, 0, 5], +[4, 5, 0, 5, 4, 5], +[1, 5, 1, 5, 1, 5], +[4, 4, 4, 4, 4, 4], +[0, 4, 0, 4, 4, 4], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[6, 6, 6, 6, 6, 6], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[2, 5, 2, 5, -1, 5], +[0, 5, 0, 5, 0, 5], +[2, 5, 2, 5, 0, 5], +[1, 5, 1, 5, 1, 5], +[2, 5, 2, 5, 4, 5], +[0, 5, 0, 5, 0, 5], +[2, 5, 2, 5, 4, 5], +[1, 5, 1, 5, 1, 5], +[2, 4, 2, 4, 4, 2], +[0, 4, 0, 4, 4, 4], +[2, 0, 2, 0, 0, 2], +[1, 1, 1, 1, 1, 1], +[2, 6, 2, 6, 6, 2], +[0, 0, 0, 0, 0, 0], +[2, 0, 2, 0, 0, 2], +[1, 1, 1, 1, 1, 1], +[1, 1, 1, 1, 1, 1], +[0, 1, 1, 1, 0, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[4, 1, 1, 1, 4, 1], +[0, 1, 1, 1, 0, 1], +[4, 0, 0, 4, 4, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[5, 5, 5, 5, 5, 5], +[1, 1, 1, 1, 4, 1], +[0, 0, 0, 0, 0, 0], +[4, 0, 0, 4, 4, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[6, 0, 0, 6, 0, 6], +[0, 0, 0, 0, 0, 0], +[6, 6, 6, 6, 6, 6], +[5, 5, 5, 5, 5, 5], +[5, 5, 0, 5, 0, 5], +[5, 5, 0, 5, 0, 5], +[5, 5, 1, 5, 1, 5], +[4, 4, 4, 4, 4, 4], +[0, 0, 0, 0, 0, 0], +[4, 4, 0, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[4, 4, 4, 4, 4, 4], +[4, 4, 0, 4, 4, 4], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[8, 8, 8, 8, 8, 8], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 0, 2], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[4, 1, 1, 4, 4, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[1, 1, 1, 1, 1, 1], +[1, 1, 1, 1, 0, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[2, 4, 2, 4, 4, 2], +[1, 1, 1, 1, 1, 1], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[5, 5, 5, 5, 5, 5], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[12, 12, 12, 12, 12, 12] +] diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/representations/mesh/utils_cube.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/representations/mesh/utils_cube.py new file mode 100644 index 0000000000000000000000000000000000000000..d84d7c4de1d2f1f9b5917d47f3f8904d50ef59f7 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/representations/mesh/utils_cube.py @@ -0,0 +1,91 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import torch + +cube_corners = torch.tensor( + [ + [0, 0, 0], + [1, 0, 0], + [0, 1, 0], + [1, 1, 0], + [0, 0, 1], + [1, 0, 1], + [0, 1, 1], + [1, 1, 1], + ], + dtype=torch.int, +) +cube_neighbor = torch.tensor( + [[1, 0, 0], [-1, 0, 0], [0, 1, 0], [0, -1, 0], [0, 0, 1], [0, 0, -1]] +) +cube_edges = torch.tensor( + [0, 1, 1, 5, 4, 5, 0, 4, 2, 3, 3, 7, 6, 7, 2, 6, 2, 0, 3, 1, 7, 5, 6, 4], + dtype=torch.long, + requires_grad=False, +) + + +def construct_dense_grid(res, device="cuda"): + """construct a dense grid based on resolution""" + res_v = res + 1 + vertsid = torch.arange(res_v**3, device=device) + coordsid = vertsid.reshape(res_v, res_v, res_v)[:res, :res, :res].flatten() + cube_corners_bias = ( + cube_corners[:, 0] * res_v + cube_corners[:, 1] + ) * res_v + cube_corners[:, 2] + cube_fx8 = coordsid.unsqueeze(1) + cube_corners_bias.unsqueeze(0).to(device) + verts = torch.stack( + [vertsid // (res_v**2), (vertsid // res_v) % res_v, vertsid % res_v], dim=1 + ) + return verts, cube_fx8 + + +def construct_voxel_grid(coords): + verts = (cube_corners.unsqueeze(0).to(coords) + coords.unsqueeze(1)).reshape(-1, 3) + verts_unique, inverse_indices = torch.unique(verts, dim=0, return_inverse=True) + cubes = inverse_indices.reshape(-1, 8) + return verts_unique, cubes + + +def cubes_to_verts(num_verts, cubes, value, reduce="mean"): + """ + Args: + cubes [Vx8] verts index for each cube + value [Vx8xM] value to be scattered + Operation: + reduced[cubes[i][j]][k] += value[i][k] + """ + M = value.shape[2] # number of channels + reduced = torch.zeros(num_verts, M, device=cubes.device, dtype=value.dtype) + return torch.scatter_reduce( + reduced, + 0, + cubes.unsqueeze(-1).expand(-1, -1, M).flatten(0, 1), + value.flatten(0, 1), + reduce=reduce, + include_self=False, + ) + + +def sparse_cube2verts(coords, feats, training=True): + new_coords, cubes = construct_voxel_grid(coords) + new_feats = cubes_to_verts(new_coords.shape[0], cubes, feats) + if training: + con_loss = torch.mean((feats - new_feats[cubes]) ** 2) + else: + con_loss = 0.0 + return new_coords, new_feats, con_loss + + +def get_dense_attrs(coords: torch.Tensor, feats: torch.Tensor, res: int, sdf_init=True): + F = feats.shape[-1] + dense_attrs = torch.zeros([res] * 3 + [F], device=feats.device, dtype=feats.dtype) + if sdf_init: + dense_attrs[..., 0] = 1 # initial outside sdf value + dense_attrs[coords[:, 0], coords[:, 1], coords[:, 2], :] = feats + return dense_attrs.reshape(-1, F) + + +def get_defomed_verts(v_pos: torch.Tensor, deform: torch.Tensor, res): + return (v_pos / res - 0.5 + (1 - 1e-8) / (res * 2) * torch.tanh(deform)).to( + deform.dtype + ) diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/representations/octree/__init__.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/representations/octree/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..afba2daa30062b6a6522e7c0e6b54551876daf2e --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/representations/octree/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +from .octree_dfs import DfsOctree diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/representations/octree/octree_dfs.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/representations/octree/octree_dfs.py new file mode 100755 index 0000000000000000000000000000000000000000..de3b583828de3dc02498aa938e4799a0c77cbe4e --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/representations/octree/octree_dfs.py @@ -0,0 +1,583 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import torch +import torch.nn as nn +import torch.nn.functional as F + + +DEFAULT_TRIVEC_CONFIG = { + "dim": 8, + "rank": 8, +} + +DEFAULT_VOXEL_CONFIG = { + "solid": False, +} + +DEFAULT_DECOPOLY_CONFIG = { + "degree": 8, + "rank": 16, +} + + +class DfsOctree: + """ + Sparse Voxel Octree (SVO) implementation for PyTorch. + Using Depth-First Search (DFS) order to store the octree. + DFS order suits rendering and ray tracing. + + The structure and data are separatedly stored. + Structure is stored as a continuous array, each element is a 3*32 bits descriptor. + |-----------------------------------------| + | 0:3 bits | 4:31 bits | + | leaf num | unused | + |-----------------------------------------| + | 0:31 bits | + | child ptr | + |-----------------------------------------| + | 0:31 bits | + | data ptr | + |-----------------------------------------| + Each element represents a non-leaf node in the octree. + The valid mask is used to indicate whether the children are valid. + The leaf mask is used to indicate whether the children are leaf nodes. + The child ptr is used to point to the first non-leaf child. Non-leaf children descriptors are stored continuously from the child ptr. + The data ptr is used to point to the data of leaf children. Leaf children data are stored continuously from the data ptr. + + There are also auxiliary arrays to store the additional structural information to facilitate parallel processing. + - Position: the position of the octree nodes. + - Depth: the depth of the octree nodes. + + Args: + depth (int): the depth of the octree. + """ + + def __init__( + self, + depth, + aabb=[0, 0, 0, 1, 1, 1], + sh_degree=2, + primitive="voxel", + primitive_config={}, + device="cuda", + ): + self.max_depth = depth + self.aabb = torch.tensor(aabb, dtype=torch.float32, device=device) + self.device = device + self.sh_degree = sh_degree + self.active_sh_degree = sh_degree + self.primitive = primitive + self.primitive_config = primitive_config + + self.structure = torch.tensor( + [[8, 1, 0]], dtype=torch.int32, device=self.device + ) + self.position = torch.zeros((8, 3), dtype=torch.float32, device=self.device) + self.depth = torch.zeros((8, 1), dtype=torch.uint8, device=self.device) + self.position[:, 0] = torch.tensor( + [0.25, 0.75, 0.25, 0.75, 0.25, 0.75, 0.25, 0.75], device=self.device + ) + self.position[:, 1] = torch.tensor( + [0.25, 0.25, 0.75, 0.75, 0.25, 0.25, 0.75, 0.75], device=self.device + ) + self.position[:, 2] = torch.tensor( + [0.25, 0.25, 0.25, 0.25, 0.75, 0.75, 0.75, 0.75], device=self.device + ) + self.depth[:, 0] = 1 + + self.data = ["position", "depth"] + self.param_names = [] + + if primitive == "voxel": + self.features_dc = torch.zeros( + (8, 1, 3), dtype=torch.float32, device=self.device + ) + self.features_ac = torch.zeros( + (8, (sh_degree + 1) ** 2 - 1, 3), + dtype=torch.float32, + device=self.device, + ) + self.data += ["features_dc", "features_ac"] + self.param_names += ["features_dc", "features_ac"] + if not primitive_config.get("solid", False): + self.density = torch.zeros( + (8, 1), dtype=torch.float32, device=self.device + ) + self.data.append("density") + self.param_names.append("density") + elif primitive == "gaussian": + self.features_dc = torch.zeros( + (8, 1, 3), dtype=torch.float32, device=self.device + ) + self.features_ac = torch.zeros( + (8, (sh_degree + 1) ** 2 - 1, 3), + dtype=torch.float32, + device=self.device, + ) + self.opacity = torch.zeros((8, 1), dtype=torch.float32, device=self.device) + self.data += ["features_dc", "features_ac", "opacity"] + self.param_names += ["features_dc", "features_ac", "opacity"] + elif primitive == "trivec": + self.trivec = torch.zeros( + (8, primitive_config["rank"], 3, primitive_config["dim"]), + dtype=torch.float32, + device=self.device, + ) + self.density = torch.zeros( + (8, primitive_config["rank"]), dtype=torch.float32, device=self.device + ) + self.features_dc = torch.zeros( + (8, primitive_config["rank"], 1, 3), + dtype=torch.float32, + device=self.device, + ) + self.features_ac = torch.zeros( + (8, primitive_config["rank"], (sh_degree + 1) ** 2 - 1, 3), + dtype=torch.float32, + device=self.device, + ) + self.density_shift = 0 + self.data += ["trivec", "density", "features_dc", "features_ac"] + self.param_names += ["trivec", "density", "features_dc", "features_ac"] + elif primitive == "decoupoly": + self.decoupoly_V = torch.zeros( + (8, primitive_config["rank"], 3), + dtype=torch.float32, + device=self.device, + ) + self.decoupoly_g = torch.zeros( + (8, primitive_config["rank"], primitive_config["degree"]), + dtype=torch.float32, + device=self.device, + ) + self.density = torch.zeros( + (8, primitive_config["rank"]), dtype=torch.float32, device=self.device + ) + self.features_dc = torch.zeros( + (8, primitive_config["rank"], 1, 3), + dtype=torch.float32, + device=self.device, + ) + self.features_ac = torch.zeros( + (8, primitive_config["rank"], (sh_degree + 1) ** 2 - 1, 3), + dtype=torch.float32, + device=self.device, + ) + self.density_shift = 0 + self.data += [ + "decoupoly_V", + "decoupoly_g", + "density", + "features_dc", + "features_ac", + ] + self.param_names += [ + "decoupoly_V", + "decoupoly_g", + "density", + "features_dc", + "features_ac", + ] + + self.setup_functions() + + def setup_functions(self): + self.density_activation = ( + (lambda x: torch.exp(x - 2)) + if self.primitive != "trivec" + else (lambda x: x) + ) + self.opacity_activation = lambda x: torch.sigmoid(x - 6) + self.inverse_opacity_activation = lambda x: torch.log(x / (1 - x)) + 6 + self.color_activation = lambda x: torch.sigmoid(x) + + @property + def num_non_leaf_nodes(self): + return self.structure.shape[0] + + @property + def num_leaf_nodes(self): + return self.depth.shape[0] + + @property + def cur_depth(self): + return self.depth.max().item() + + @property + def occupancy(self): + return self.num_leaf_nodes / 8**self.cur_depth + + @property + def get_xyz(self): + return self.position + + @property + def get_depth(self): + return self.depth + + @property + def get_density(self): + if self.primitive == "voxel" and self.voxel_config["solid"]: + return torch.full( + (self.position.shape[0], 1), + 1000, + dtype=torch.float32, + device=self.device, + ) + return self.density_activation(self.density) + + @property + def get_opacity(self): + return self.opacity_activation(self.density) + + @property + def get_trivec(self): + return self.trivec + + @property + def get_decoupoly(self): + return F.normalize(self.decoupoly_V, dim=-1), self.decoupoly_g + + @property + def get_color(self): + return self.color_activation(self.colors) + + @property + def get_features(self): + if self.sh_degree == 0: + return self.features_dc + return torch.cat([self.features_dc, self.features_ac], dim=-2) + + def state_dict(self): + ret = { + "structure": self.structure, + "position": self.position, + "depth": self.depth, + "sh_degree": self.sh_degree, + "active_sh_degree": self.active_sh_degree, + "trivec_config": self.trivec_config, + "voxel_config": self.voxel_config, + "primitive": self.primitive, + } + if hasattr(self, "density_shift"): + ret["density_shift"] = self.density_shift + for data in set(self.data + self.param_names): + if not isinstance(getattr(self, data), nn.Module): + ret[data] = getattr(self, data) + else: + ret[data] = getattr(self, data).state_dict() + return ret + + def load_state_dict(self, state_dict): + keys = list( + set( + self.data + + self.param_names + + list(state_dict.keys()) + + ["structure", "position", "depth"] + ) + ) + for key in keys: + if key not in state_dict: + print(f"Warning: key {key} not found in the state_dict.") + continue + try: + if not isinstance(getattr(self, key), nn.Module): + setattr(self, key, state_dict[key]) + else: + getattr(self, key).load_state_dict(state_dict[key]) + except Exception as e: + print(e) + raise ValueError(f"Error loading key {key}.") + + def gather_from_leaf_children(self, data): + """ + Gather the data from the leaf children. + + Args: + data (torch.Tensor): the data to gather. The first dimension should be the number of leaf nodes. + """ + leaf_cnt = self.structure[:, 0] + leaf_cnt_masks = [leaf_cnt == i for i in range(1, 9)] + ret = torch.zeros( + (self.num_non_leaf_nodes,), dtype=data.dtype, device=self.device + ) + for i in range(8): + if leaf_cnt_masks[i].sum() == 0: + continue + start = self.structure[leaf_cnt_masks[i], 2] + for j in range(i + 1): + ret[leaf_cnt_masks[i]] += data[start + j] + return ret + + def gather_from_non_leaf_children(self, data): + """ + Gather the data from the non-leaf children. + + Args: + data (torch.Tensor): the data to gather. The first dimension should be the number of leaf nodes. + """ + non_leaf_cnt = 8 - self.structure[:, 0] + non_leaf_cnt_masks = [non_leaf_cnt == i for i in range(1, 9)] + ret = torch.zeros_like(data, device=self.device) + for i in range(8): + if non_leaf_cnt_masks[i].sum() == 0: + continue + start = self.structure[non_leaf_cnt_masks[i], 1] + for j in range(i + 1): + ret[non_leaf_cnt_masks[i]] += data[start + j] + return ret + + def structure_control(self, mask): + """ + Control the structure of the octree. + + Args: + mask (torch.Tensor): the mask to control the structure. 1 for subdivide, -1 for merge, 0 for keep. + """ + # Dont subdivide when the depth is the maximum. + mask[self.depth.squeeze() == self.max_depth] = torch.clamp_max( + mask[self.depth.squeeze() == self.max_depth], 0 + ) + # Dont merge when the depth is the minimum. + mask[self.depth.squeeze() == 1] = torch.clamp_min( + mask[self.depth.squeeze() == 1], 0 + ) + + # Gather control mask + structre_ctrl = self.gather_from_leaf_children(mask) + structre_ctrl[structre_ctrl == -8] = -1 + + new_leaf_num = self.structure[:, 0].clone() + # Modify the leaf num. + structre_valid = structre_ctrl >= 0 + new_leaf_num[structre_valid] -= structre_ctrl[ + structre_valid + ] # Add the new nodes. + structre_delete = structre_ctrl < 0 + merged_nodes = self.gather_from_non_leaf_children(structre_delete.int()) + new_leaf_num += merged_nodes # Delete the merged nodes. + + # Update the structure array to allocate new nodes. + mem_offset = torch.zeros( + (self.num_non_leaf_nodes + 1,), dtype=torch.int32, device=self.device + ) + mem_offset.index_add_( + 0, self.structure[structre_valid, 1], structre_ctrl[structre_valid] + ) # Add the new nodes. + mem_offset[:-1] -= structre_delete.int() # Delete the merged nodes. + new_structre_idx = torch.arange( + 0, self.num_non_leaf_nodes + 1, dtype=torch.int32, device=self.device + ) + mem_offset.cumsum(0) + new_structure_length = new_structre_idx[-1].item() + new_structre_idx = new_structre_idx[:-1] + new_structure = torch.empty( + (new_structure_length, 3), dtype=torch.int32, device=self.device + ) + new_structure[new_structre_idx[structre_valid], 0] = new_leaf_num[ + structre_valid + ] + + # Initialize the new nodes. + new_node_mask = torch.ones( + (new_structure_length,), dtype=torch.bool, device=self.device + ) + new_node_mask[new_structre_idx[structre_valid]] = False + new_structure[new_node_mask, 0] = 8 # Initialize to all leaf nodes. + new_node_num = new_node_mask.sum().item() + + # Rebuild child ptr. + non_leaf_cnt = 8 - new_structure[:, 0] + new_child_ptr = torch.cat( + [ + torch.zeros((1,), dtype=torch.int32, device=self.device), + non_leaf_cnt.cumsum(0)[:-1], + ] + ) + new_structure[:, 1] = new_child_ptr + 1 + + # Rebuild data ptr with old data. + leaf_cnt = torch.zeros( + (new_structure_length,), dtype=torch.int32, device=self.device + ) + leaf_cnt.index_add_(0, new_structre_idx, self.structure[:, 0]) + old_data_ptr = torch.cat( + [ + torch.zeros((1,), dtype=torch.int32, device=self.device), + leaf_cnt.cumsum(0)[:-1], + ] + ) + + # Update the data array + subdivide_mask = mask == 1 + merge_mask = mask == -1 + data_valid = ~(subdivide_mask | merge_mask) + mem_offset = torch.zeros( + (self.num_leaf_nodes + 1,), dtype=torch.int32, device=self.device + ) + mem_offset.index_add_( + 0, + old_data_ptr[new_node_mask], + torch.full((new_node_num,), 8, dtype=torch.int32, device=self.device), + ) # Add data array for new nodes + mem_offset[ + :-1 + ] -= subdivide_mask.int() # Delete data elements for subdivide nodes + mem_offset[:-1] -= merge_mask.int() # Delete data elements for merge nodes + mem_offset.index_add_( + 0, self.structure[structre_valid, 2], merged_nodes[structre_valid] + ) # Add data elements for merge nodes + new_data_idx = torch.arange( + 0, self.num_leaf_nodes + 1, dtype=torch.int32, device=self.device + ) + mem_offset.cumsum(0) + new_data_length = new_data_idx[-1].item() + new_data_idx = new_data_idx[:-1] + new_data = { + data: torch.empty( + (new_data_length,) + getattr(self, data).shape[1:], + dtype=getattr(self, data).dtype, + device=self.device, + ) + for data in self.data + } + for data in self.data: + new_data[data][new_data_idx[data_valid]] = getattr(self, data)[data_valid] + + # Rebuild data ptr + leaf_cnt = new_structure[:, 0] + new_data_ptr = torch.cat( + [ + torch.zeros((1,), dtype=torch.int32, device=self.device), + leaf_cnt.cumsum(0)[:-1], + ] + ) + new_structure[:, 2] = new_data_ptr + + # Initialize the new data array + ## For subdivide nodes + if subdivide_mask.sum() > 0: + subdivide_data_ptr = new_structure[new_node_mask, 2] + for data in self.data: + for i in range(8): + if data == "position": + offset = ( + torch.tensor( + [i // 4, (i // 2) % 2, i % 2], + dtype=torch.float32, + device=self.device, + ) + - 0.5 + ) + scale = 2 ** (-1.0 - self.depth[subdivide_mask]) + new_data["position"][subdivide_data_ptr + i] = ( + self.position[subdivide_mask] + offset * scale + ) + elif data == "depth": + new_data["depth"][subdivide_data_ptr + i] = ( + self.depth[subdivide_mask] + 1 + ) + elif data == "opacity": + new_data["opacity"][subdivide_data_ptr + i] = ( + self.inverse_opacity_activation( + torch.sqrt( + self.opacity_activation( + self.opacity[subdivide_mask] + ) + ) + ) + ) + elif data == "trivec": + offset = ( + torch.tensor( + [i // 4, (i // 2) % 2, i % 2], + dtype=torch.float32, + device=self.device, + ) + * 0.5 + ) + coord = ( + torch.linspace( + 0, + 0.5, + self.trivec.shape[-1], + dtype=torch.float32, + device=self.device, + )[None] + + offset[:, None] + ).reshape(1, 3, self.trivec.shape[-1], 1) + axis = ( + torch.linspace( + 0, 1, 3, dtype=torch.float32, device=self.device + ) + .reshape(1, 3, 1, 1) + .repeat(1, 1, self.trivec.shape[-1], 1) + ) + coord = ( + torch.stack([coord, axis], dim=3) + .reshape(1, 3, self.trivec.shape[-1], 2) + .expand(self.trivec[subdivide_mask].shape[0], -1, -1, -1) + * 2 + - 1 + ) + new_data["trivec"][subdivide_data_ptr + i] = F.grid_sample( + self.trivec[subdivide_mask], coord, align_corners=True + ) + else: + new_data[data][subdivide_data_ptr + i] = getattr(self, data)[ + subdivide_mask + ] + ## For merge nodes + if merge_mask.sum() > 0: + merge_data_ptr = torch.empty( + (merged_nodes.sum().item(),), dtype=torch.int32, device=self.device + ) + merge_nodes_cumsum = torch.cat( + [ + torch.zeros((1,), dtype=torch.int32, device=self.device), + merged_nodes.cumsum(0)[:-1], + ] + ) + for i in range(8): + merge_data_ptr[merge_nodes_cumsum[merged_nodes > i] + i] = ( + new_structure[new_structre_idx[merged_nodes > i], 2] + i + ) + old_merge_data_ptr = self.structure[structre_delete, 2] + for data in self.data: + if data == "position": + scale = 2 ** (1.0 - self.depth[old_merge_data_ptr]) + new_data["position"][merge_data_ptr] = ( + ((self.position[old_merge_data_ptr] + 0.5) / scale).floor() + * scale + + 0.5 * scale + - 0.5 + ) + elif data == "depth": + new_data["depth"][merge_data_ptr] = ( + self.depth[old_merge_data_ptr] - 1 + ) + elif data == "opacity": + new_data["opacity"][subdivide_data_ptr + i] = ( + self.inverse_opacity_activation( + self.opacity_activation(self.opacity[subdivide_mask]) ** 2 + ) + ) + elif data == "trivec": + new_data["trivec"][merge_data_ptr] = self.trivec[old_merge_data_ptr] + else: + new_data[data][merge_data_ptr] = getattr(self, data)[ + old_merge_data_ptr + ] + + # Update the structure and data array + self.structure = new_structure + for data in self.data: + setattr(self, data, new_data[data]) + + # Save data array control temp variables + self.data_rearrange_buffer = { + "subdivide_mask": subdivide_mask, + "merge_mask": merge_mask, + "data_valid": data_valid, + "new_data_idx": new_data_idx, + "new_data_length": new_data_length, + "new_data": new_data, + } diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/representations/radiance_field/__init__.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/representations/radiance_field/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..62ea036a70ecb3e359f86f2af55e50f6467dfc7c --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/representations/radiance_field/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +from .strivec import Strivec diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/representations/radiance_field/strivec.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/representations/radiance_field/strivec.py new file mode 100644 index 0000000000000000000000000000000000000000..f753b2e724ce2cd257dd9ce0130c9454ed4fc74b --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/representations/radiance_field/strivec.py @@ -0,0 +1,29 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from ..octree import DfsOctree as Octree + + +class Strivec(Octree): + def __init__( + self, + resolution: int, + aabb: list, + sh_degree: int = 0, + rank: int = 8, + dim: int = 8, + device: str = "cuda", + ): + assert np.log2(resolution) % 1 == 0, "Resolution must be a power of 2" + self.resolution = resolution + depth = int(np.round(np.log2(resolution))) + super().__init__( + depth=depth, + aabb=aabb, + sh_degree=sh_degree, + primitive="trivec", + primitive_config={"rank": rank, "dim": dim}, + device=device, + ) diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/utils/__init__.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/utils/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..71ca4b12c770afea62d06f97064cdf0c97d40ed7 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/utils/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/utils/postprocessing_utils.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/utils/postprocessing_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..de7a4da0b4290f883b37e3685aa5aa8780c4f8ce --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/utils/postprocessing_utils.py @@ -0,0 +1,835 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +from typing import * +import numpy as np +import torch +import utils3d +from PIL import Image +from tqdm import tqdm +import trimesh +import trimesh.visual +import xatlas +import pyvista as pv +from pymeshfix import _meshfix +import igraph +import cv2 +from PIL import Image +from .random_utils import sphere_hammersley_sequence +from .render_utils import render_multiview +from ..renderers import GaussianRenderer +from ..representations import Strivec, Gaussian, MeshExtractResult +from loguru import logger + +@torch.no_grad() +def _fill_holes( + verts, + faces, + max_hole_size=0.04, + max_hole_nbe=32, + resolution=128, + num_views=500, + debug=False, + verbose=False, +): + """ + Rasterize a mesh from multiple views and remove invisible faces. + Also includes postprocessing to: + 1. Remove connected components that are have low visibility. + 2. Mincut to remove faces at the inner side of the mesh connected to the outer side with a small hole. + + Args: + verts (torch.Tensor): Vertices of the mesh. Shape (V, 3). + faces (torch.Tensor): Faces of the mesh. Shape (F, 3). + max_hole_size (float): Maximum area of a hole to fill. + resolution (int): Resolution of the rasterization. + num_views (int): Number of views to rasterize the mesh. + verbose (bool): Whether to print progress. + """ + # Construct cameras + yaws = [] + pitchs = [] + for i in range(num_views): + y, p = sphere_hammersley_sequence(i, num_views) + yaws.append(y) + pitchs.append(p) + yaws = torch.tensor(yaws).cuda() + pitchs = torch.tensor(pitchs).cuda() + radius = 2.0 + fov = torch.deg2rad(torch.tensor(40)).cuda() + projection = utils3d.torch.perspective_from_fov_xy(fov, fov, 1, 3) + views = [] + for yaw, pitch in zip(yaws, pitchs): + orig = ( + torch.tensor( + [ + torch.sin(yaw) * torch.cos(pitch), + torch.cos(yaw) * torch.cos(pitch), + torch.sin(pitch), + ] + ) + .cuda() + .float() + * radius + ) + view = utils3d.torch.view_look_at( + orig, + torch.tensor([0, 0, 0]).float().cuda(), + torch.tensor([0, 0, 1]).float().cuda(), + ) + views.append(view) + views = torch.stack(views, dim=0) + + # Rasterize + visblity = torch.zeros(faces.shape[0], dtype=torch.int32, device=verts.device) + rastctx = utils3d.torch.RastContext(backend="cuda") + for i in tqdm( + range(views.shape[0]), + total=views.shape[0], + disable=not verbose, + desc="Rasterizing", + ): + view = views[i] + buffers = utils3d.torch.rasterize_triangle_faces( + rastctx, + verts[None], + faces, + resolution, + resolution, + view=view, + projection=projection, + ) + face_id = buffers["face_id"][0][buffers["mask"][0] > 0.95] - 1 + face_id = torch.unique(face_id).long() + visblity[face_id] += 1 + visblity = visblity.float() / num_views + + # Mincut + ## construct outer faces + edges, face2edge, edge_degrees = utils3d.torch.compute_edges(faces) + boundary_edge_indices = torch.nonzero(edge_degrees == 1).reshape(-1) + connected_components = utils3d.torch.compute_connected_components( + faces, edges, face2edge + ) + outer_face_indices = torch.zeros( + faces.shape[0], dtype=torch.bool, device=faces.device + ) + for i in range(len(connected_components)): + outer_face_indices[connected_components[i]] = visblity[ + connected_components[i] + ] > min(max(visblity[connected_components[i]].quantile(0.75).item(), 0.25), 0.5) + outer_face_indices = outer_face_indices.nonzero().reshape(-1) + + ## construct inner faces + inner_face_indices = torch.nonzero(visblity == 0).reshape(-1) + if verbose: + tqdm.write(f"Found {inner_face_indices.shape[0]} invisible faces") + if inner_face_indices.shape[0] == 0: + return verts, faces + + ## Construct dual graph (faces as nodes, edges as edges) + dual_edges, dual_edge2edge = utils3d.torch.compute_dual_graph(face2edge) + dual_edge2edge = edges[dual_edge2edge] + dual_edges_weights = torch.norm( + verts[dual_edge2edge[:, 0]] - verts[dual_edge2edge[:, 1]], dim=1 + ) + if verbose: + tqdm.write(f"Dual graph: {dual_edges.shape[0]} edges") + + ## solve mincut problem + ### construct main graph + g = igraph.Graph() + g.add_vertices(faces.shape[0]) + g.add_edges(dual_edges.cpu().numpy()) + g.es["weight"] = dual_edges_weights.cpu().numpy() + + ### source and target + g.add_vertex("s") + g.add_vertex("t") + + ### connect invisible faces to source + g.add_edges( + [(f, "s") for f in inner_face_indices], + attributes={ + "weight": torch.ones(inner_face_indices.shape[0], dtype=torch.float32) + .cpu() + .numpy() + }, + ) + + ### connect outer faces to target + g.add_edges( + [(f, "t") for f in outer_face_indices], + attributes={ + "weight": torch.ones(outer_face_indices.shape[0], dtype=torch.float32) + .cpu() + .numpy() + }, + ) + + ### solve mincut + cut = g.mincut("s", "t", (np.array(g.es["weight"]) * 1000).tolist()) + remove_face_indices = torch.tensor( + [v for v in cut.partition[0] if v < faces.shape[0]], + dtype=torch.long, + device=faces.device, + ) + if verbose: + tqdm.write(f"Mincut solved, start checking the cut") + + ### check if the cut is valid with each connected component + to_remove_cc = utils3d.torch.compute_connected_components( + faces[remove_face_indices] + ) + if debug: + tqdm.write(f"Number of connected components of the cut: {len(to_remove_cc)}") + valid_remove_cc = [] + cutting_edges = [] + for cc in to_remove_cc: + #### check if the connected component has low visibility + visblity_median = visblity[remove_face_indices[cc]].median() + if debug: + tqdm.write(f"visblity_median: {visblity_median}") + if visblity_median > 0.25: + continue + + #### check if the cuting loop is small enough + cc_edge_indices, cc_edges_degree = torch.unique( + face2edge[remove_face_indices[cc]], return_counts=True + ) + cc_boundary_edge_indices = cc_edge_indices[cc_edges_degree == 1] + cc_new_boundary_edge_indices = cc_boundary_edge_indices[ + ~torch.isin(cc_boundary_edge_indices, boundary_edge_indices) + ] + if len(cc_new_boundary_edge_indices) > 0: + cc_new_boundary_edge_cc = utils3d.torch.compute_edge_connected_components( + edges[cc_new_boundary_edge_indices] + ) + cc_new_boundary_edges_cc_center = [ + verts[edges[cc_new_boundary_edge_indices[edge_cc]]] + .mean(dim=1) + .mean(dim=0) + for edge_cc in cc_new_boundary_edge_cc + ] + cc_new_boundary_edges_cc_area = [] + for i, edge_cc in enumerate(cc_new_boundary_edge_cc): + _e1 = ( + verts[edges[cc_new_boundary_edge_indices[edge_cc]][:, 0]] + - cc_new_boundary_edges_cc_center[i] + ) + _e2 = ( + verts[edges[cc_new_boundary_edge_indices[edge_cc]][:, 1]] + - cc_new_boundary_edges_cc_center[i] + ) + cc_new_boundary_edges_cc_area.append( + torch.norm(torch.cross(_e1, _e2, dim=-1), dim=1).sum() * 0.5 + ) + if debug: + cutting_edges.append(cc_new_boundary_edge_indices) + tqdm.write(f"Area of the cutting loop: {cc_new_boundary_edges_cc_area}") + if any([l > max_hole_size for l in cc_new_boundary_edges_cc_area]): + continue + + valid_remove_cc.append(cc) + + if debug: + face_v = verts[faces].mean(dim=1).cpu().numpy() + vis_dual_edges = dual_edges.cpu().numpy() + vis_colors = np.zeros((faces.shape[0], 3), dtype=np.uint8) + vis_colors[inner_face_indices.cpu().numpy()] = [0, 0, 255] + vis_colors[outer_face_indices.cpu().numpy()] = [0, 255, 0] + vis_colors[remove_face_indices.cpu().numpy()] = [255, 0, 255] + if len(valid_remove_cc) > 0: + vis_colors[ + remove_face_indices[torch.cat(valid_remove_cc)].cpu().numpy() + ] = [255, 0, 0] + utils3d.io.write_ply( + "dbg_dual.ply", face_v, edges=vis_dual_edges, vertex_colors=vis_colors + ) + + vis_verts = verts.cpu().numpy() + vis_edges = edges[torch.cat(cutting_edges)].cpu().numpy() + utils3d.io.write_ply("dbg_cut.ply", vis_verts, edges=vis_edges) + + if len(valid_remove_cc) > 0: + remove_face_indices = remove_face_indices[torch.cat(valid_remove_cc)] + mask = torch.ones(faces.shape[0], dtype=torch.bool, device=faces.device) + mask[remove_face_indices] = 0 + faces = faces[mask] + faces, verts = utils3d.torch.remove_unreferenced_vertices(faces, verts) + if verbose: + tqdm.write(f"Removed {(~mask).sum()} faces by mincut") + else: + if verbose: + tqdm.write(f"Removed 0 faces by mincut") + + mesh = _meshfix.PyTMesh() + mesh.load_array(verts.cpu().numpy(), faces.cpu().numpy()) + mesh.fill_small_boundaries(nbe=max_hole_nbe, refine=True) + verts, faces = mesh.return_arrays() + verts, faces = torch.tensor( + verts, device="cuda", dtype=torch.float32 + ), torch.tensor(faces, device="cuda", dtype=torch.int32) + + return verts, faces + + +def postprocess_mesh( + vertices: np.array, + faces: np.array, + simplify: bool = True, + simplify_ratio: float = 0.9, + fill_holes: bool = True, + fill_holes_max_hole_size: float = 0.04, + fill_holes_max_hole_nbe: int = 32, + fill_holes_resolution: int = 1024, + fill_holes_num_views: int = 1000, + debug: bool = False, + verbose: bool = False, +): + """ + Postprocess a mesh by simplifying, removing invisible faces, and removing isolated pieces. + + Args: + vertices (np.array): Vertices of the mesh. Shape (V, 3). + faces (np.array): Faces of the mesh. Shape (F, 3). + simplify (bool): Whether to simplify the mesh, using quadric edge collapse. + simplify_ratio (float): Ratio of faces to keep after simplification. + fill_holes (bool): Whether to fill holes in the mesh. + fill_holes_max_hole_size (float): Maximum area of a hole to fill. + fill_holes_max_hole_nbe (int): Maximum number of boundary edges of a hole to fill. + fill_holes_resolution (int): Resolution of the rasterization. + fill_holes_num_views (int): Number of views to rasterize the mesh. + verbose (bool): Whether to print progress. + """ + + if verbose: + tqdm.write( + f"Before postprocess: {vertices.shape[0]} vertices, {faces.shape[0]} faces" + ) + + # Simplify + if simplify and simplify_ratio > 0: + mesh = pv.PolyData( + vertices, np.concatenate([np.full((faces.shape[0], 1), 3), faces], axis=1) + ) + mesh = mesh.decimate(simplify_ratio, progress_bar=verbose) + vertices, faces = mesh.points, mesh.faces.reshape(-1, 4)[:, 1:] + if verbose: + tqdm.write( + f"After decimate: {vertices.shape[0]} vertices, {faces.shape[0]} faces" + ) + + # Remove invisible faces + if fill_holes: + vertices, faces = ( + torch.tensor(vertices).cuda(), + torch.tensor(faces.astype(np.int32)).cuda(), + ) + vertices, faces = _fill_holes( + vertices, + faces, + max_hole_size=fill_holes_max_hole_size, + max_hole_nbe=fill_holes_max_hole_nbe, + resolution=fill_holes_resolution, + num_views=fill_holes_num_views, + debug=debug, + verbose=verbose, + ) + vertices, faces = vertices.cpu().numpy(), faces.cpu().numpy() + if verbose: + tqdm.write( + f"After remove invisible faces: {vertices.shape[0]} vertices, {faces.shape[0]} faces" + ) + + return vertices, faces + + +def parametrize_mesh(vertices: np.array, faces: np.array): + """ + Parametrize a mesh to a texture space, using xatlas. + + Args: + vertices (np.array): Vertices of the mesh. Shape (V, 3). + faces (np.array): Faces of the mesh. Shape (F, 3). + """ + + vmapping, indices, uvs = xatlas.parametrize(vertices, faces) + + vertices = vertices[vmapping] + faces = indices + + return vertices, faces, uvs + +@torch.inference_mode(False) +@torch.enable_grad() +def bake_texture( + vertices: np.array, + faces: np.array, + uvs: np.array, + observations: List[np.array], + masks: List[np.array], + extrinsics: List[np.array], + intrinsics: List[np.array], + texture_size: int = 2048, + near: float = 0.1, + far: float = 10.0, + mode: Literal["fast", "opt"] = "opt", + lambda_tv: float = 1e-2, + verbose: bool = False, + rendering_engine: str = "nvdiffrast", # nvdiffrast OR "pytorch3d" + device: str = "cuda", + +): + """ + Bake texture to a mesh from multiple observations. + + Args: + vertices (np.array): Vertices of the mesh. Shape (V, 3). + faces (np.array): Faces of the mesh. Shape (F, 3). + uvs (np.array): UV coordinates of the mesh. Shape (V, 2). + observations (List[np.array]): List of observations. Each observation is a 2D image. Shape (H, W, 3). + masks (List[np.array]): List of masks. Each mask is a 2D image. Shape (H, W). + extrinsics (List[np.array]): List of extrinsics. Shape (4, 4). + intrinsics (List[np.array]): List of intrinsics. Shape (3, 3). + texture_size (int): Size of the texture. + near (float): Near plane of the camera. + far (float): Far plane of the camera. + mode (Literal['fast', 'opt']): Mode of texture baking. + lambda_tv (float): Weight of total variation loss in optimization. + verbose (bool): Whether to print progress. + """ + + + vertices = torch.tensor(vertices).to(device) + faces = torch.tensor(faces.astype(np.int32)).to(device) + uvs = torch.tensor(uvs).to(device) + observations = [torch.tensor(obs / 255.0).float().to(device) for obs in observations] + masks = [torch.tensor(m > 0).bool().to(device) for m in masks] + views = [ + utils3d.torch.extrinsics_to_view(torch.tensor(extr).to(device)) + for extr in extrinsics + ] + projections = [ + utils3d.torch.intrinsics_to_perspective(torch.tensor(intr).to(device), near, far) + for intr in intrinsics + ] + + if mode == "fast": + texture = torch.zeros( + (texture_size * texture_size, 3), dtype=torch.float32 + ).to(device) + texture_weights = torch.zeros( + (texture_size * texture_size), dtype=torch.float32 + ).to(device) + rastctx = utils3d.torch.RastContext(backend=device if device.startswith("cuda") else "cuda") + for observation, view, projection in tqdm( + zip(observations, views, projections), + total=len(observations), + disable=not verbose, + desc="Texture baking (fast)", + ): + with torch.no_grad(): + rast = utils3d.torch.rasterize_triangle_faces( + rastctx, + vertices[None], + faces, + observation.shape[1], + observation.shape[0], + uv=uvs[None], + view=view, + projection=projection, + ) + uv_map = rast["uv"][0].detach().flip(0) + mask = rast["mask"][0].detach().bool() & masks[0] + + # nearest neighbor interpolation + uv_map = (uv_map * texture_size).floor().long() + obs = observation[mask] + uv_map = uv_map[mask] + idx = uv_map[:, 0] + (texture_size - uv_map[:, 1] - 1) * texture_size + texture = texture.scatter_add(0, idx.view(-1, 1).expand(-1, 3), obs) + texture_weights = texture_weights.scatter_add( + 0, + idx, + torch.ones((obs.shape[0]), dtype=torch.float32, device=texture.device), + ) + + mask = texture_weights > 0 + texture[mask] /= texture_weights[mask][:, None] + texture = np.clip( + texture.reshape(texture_size, texture_size, 3).cpu().numpy() * 255, 0, 255 + ).astype(np.uint8) + + # inpaint + mask = ( + (texture_weights == 0) + .cpu() + .numpy() + .astype(np.uint8) + .reshape(texture_size, texture_size) + ) + texture = cv2.inpaint(texture, mask, 3, cv2.INPAINT_TELEA) + + elif mode == "opt": + rastctx = utils3d.torch.RastContext(backend=device if device.startswith("cuda") else "cuda") + observations = [observations.flip(0) for observations in observations] + masks = [m.flip(0) for m in masks] + _uv = [] + _uv_dr = [] + for observation, view, projection in tqdm( + zip(observations, views, projections), + total=len(views), + disable=not verbose, + desc="Texture baking (opt): UV", + ): + with torch.no_grad(): + rast = utils3d.torch.rasterize_triangle_faces( + rastctx, + vertices[None], + faces, + observation.shape[1], + observation.shape[0], + uv=uvs[None], + view=view, + projection=projection, + ) + _uv.append(rast["uv"].detach()) + _uv_dr.append(rast["uv_dr"].detach()) + + texture = torch.nn.Parameter( + torch.zeros((1, texture_size, texture_size, 3), dtype=torch.float32).to(device) + ) + optimizer = torch.optim.Adam([texture], betas=(0.5, 0.9), lr=1e-2) + + def exp_anealing(optimizer, step, total_steps, start_lr, end_lr): + return start_lr * (end_lr / start_lr) ** (step / total_steps) + + def cosine_anealing(optimizer, step, total_steps, start_lr, end_lr): + return end_lr + 0.5 * (start_lr - end_lr) * ( + 1 + np.cos(np.pi * step / total_steps) + ) + + def tv_loss(texture): + return torch.nn.functional.l1_loss( + texture[:, :-1, :, :], texture[:, 1:, :, :] + ) + torch.nn.functional.l1_loss(texture[:, :, :-1, :], texture[:, :, 1:, :]) + + + + def render_pt3d_texture(texture, uv, uv_dr=None): + import torch.nn.functional as F + texture_perm = texture.permute(0, 3, 1, 2) + grid = uv * 2 - 1 + if grid.dim() == 3: + grid = grid.unsqueeze(0) # (1, H, W, 2) + elif grid.dim() == 4 and grid.shape[0] == 1: + pass + elif grid.dim() == 4 and grid.shape[1] == 1: + grid = grid.squeeze(1) # remove extra batch dimension if necessary + else: + raise ValueError(f"Unexpected grid shape: {grid.shape}") + render = F.grid_sample( + texture_perm, grid, mode='bilinear', padding_mode='border', align_corners=True + ) + render = render.permute(0, 2, 3, 1)[0] # (H_out, W_out, 3) + return render + + + total_steps = 2500 + + with tqdm( + total=total_steps, + disable=not verbose, + desc="Texture baking (opt): optimizing", + ) as pbar: + for step in range(total_steps): + optimizer.zero_grad() + selected = np.random.randint(0, len(views)) + uv, uv_dr, observation, mask = ( + _uv[selected], + _uv_dr[selected], + observations[selected], + masks[selected], + ) + + if rendering_engine == "nvdiffrast": + import nvdiffrast.torch as dr + render = dr.texture(texture, uv, uv_dr)[0] + + if rendering_engine == "pytorch3d": + render = render_pt3d_texture(texture, uv) + + loss = torch.nn.functional.l1_loss(render[mask], observation[mask]) + if lambda_tv > 0: + loss += lambda_tv * tv_loss(texture) + loss.backward() + optimizer.step() + # annealing + optimizer.param_groups[0]["lr"] = cosine_anealing( + optimizer, step, total_steps, 1e-2, 1e-5 + ) + pbar.set_postfix({"loss": loss.item()}) + pbar.update() + texture = np.clip( + texture[0].flip(0).detach().cpu().numpy() * 255, 0, 255 + ).astype(np.uint8) + mask = 1 - utils3d.torch.rasterize_triangle_faces( + rastctx, (uvs * 2 - 1)[None], faces, texture_size, texture_size + )["mask"][0].detach().cpu().numpy().astype(np.uint8) + texture = cv2.inpaint(texture, mask, 3, cv2.INPAINT_TELEA) + else: + raise ValueError(f"Unknown mode: {mode}") + + return texture + + +def to_glb( + app_rep: Union[Strivec, Gaussian], + mesh: MeshExtractResult, + simplify: float = 0.95, + fill_holes: bool = True, + fill_holes_max_size: float = 0.04, + texture_size: int = 1024, + debug: bool = False, + verbose: bool = True, + with_mesh_postprocess=True, + with_texture_baking=True, + use_vertex_color=False, + rendering_engine: str = "nvdiffrast", # nvdiffrast OR "pytorch3d" +) -> trimesh.Trimesh: + """ + Convert a generated asset to a glb file. + + Args: + app_rep (Union[Strivec, Gaussian]): Appearance representation. + mesh (MeshExtractResult): Extracted mesh. + simplify (float): Ratio of faces to remove in simplification. + fill_holes (bool): Whether to fill holes in the mesh. + fill_holes_max_size (float): Maximum area of a hole to fill. + texture_size (int): Size of the texture. + debug (bool): Whether to print debug information. + verbose (bool): Whether to print progress. + """ + vertices = mesh.vertices.float().cpu().numpy() + faces = mesh.faces.cpu().numpy() + vert_colors = mesh.vertex_attrs[:, :3].cpu().numpy() + + if with_mesh_postprocess: + # mesh postprocess + vertices, faces = postprocess_mesh( + vertices, + faces, + simplify=simplify > 0, + simplify_ratio=simplify, + fill_holes=fill_holes, + fill_holes_max_hole_size=fill_holes_max_size, + fill_holes_max_hole_nbe=int(250 * np.sqrt(1 - simplify)), + fill_holes_resolution=1024, + fill_holes_num_views=1000, + debug=debug, + verbose=verbose, + ) + + if with_texture_baking: + # parametrize mesh + vertices, faces, uvs = parametrize_mesh(vertices, faces) + logger.info("Baking texture ...") + + # bake texture + observations, extrinsics, intrinsics = render_multiview( + app_rep, resolution=1024, nviews=100 + ) + masks = [np.any(observation > 0, axis=-1) for observation in observations] + extrinsics = [extrinsics[i].cpu().numpy() for i in range(len(extrinsics))] + intrinsics = [intrinsics[i].cpu().numpy() for i in range(len(intrinsics))] + texture = bake_texture( + vertices, + faces, + uvs, + observations, + masks, + extrinsics, + intrinsics, + texture_size=texture_size, + mode="opt", + lambda_tv=0.01, + verbose=verbose, + rendering_engine=rendering_engine + ) + texture = Image.fromarray(texture) + material = trimesh.visual.material.PBRMaterial( + roughnessFactor=1.0, + baseColorTexture=texture, + baseColorFactor=np.array([255, 255, 255, 255], dtype=np.uint8), + ) + + # rotate mesh (from z-up to y-up) + vertices = vertices @ np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]]) + + if not with_mesh_postprocess and not with_texture_baking and use_vertex_color: + mesh = trimesh.Trimesh(vertices=vertices, faces=faces, process=False) + mesh.visual.vertex_colors = vert_colors + else: + mesh = trimesh.Trimesh( + vertices, + faces, + visual=( + trimesh.visual.TextureVisuals(uv=uvs, material=material) + if with_texture_baking + else None + ), + ) + + return mesh + + +def simplify_gs( + gs: Gaussian, + simplify: float = 0.95, + verbose: bool = True, +): + """ + Simplify 3D Gaussians + NOTE: this function is not used in the current implementation for the unsatisfactory performance. + + Args: + gs (Gaussian): 3D Gaussian. + simplify (float): Ratio of Gaussians to remove in simplification. + """ + if simplify <= 0: + return gs + + # simplify + observations, extrinsics, intrinsics = render_multiview( + gs, resolution=1024, nviews=100 + ) + observations = [ + torch.tensor(obs / 255.0).float().cuda().permute(2, 0, 1) + for obs in observations + ] + + # Following https://arxiv.org/pdf/2411.06019 + renderer = GaussianRenderer( + { + "resolution": 1024, + "near": 0.8, + "far": 1.6, + "ssaa": 1, + "bg_color": (0, 0, 0), + } + ) + new_gs = Gaussian(**gs.init_params) + new_gs._features_dc = gs._features_dc.clone() + new_gs._features_rest = ( + gs._features_rest.clone() if gs._features_rest is not None else None + ) + new_gs._opacity = torch.nn.Parameter(gs._opacity.clone()) + new_gs._rotation = torch.nn.Parameter(gs._rotation.clone()) + new_gs._scaling = torch.nn.Parameter(gs._scaling.clone()) + new_gs._xyz = torch.nn.Parameter(gs._xyz.clone()) + + start_lr = [1e-4, 1e-3, 5e-3, 0.025] + end_lr = [1e-6, 1e-5, 5e-5, 0.00025] + optimizer = torch.optim.Adam( + [ + {"params": new_gs._xyz, "lr": start_lr[0]}, + {"params": new_gs._rotation, "lr": start_lr[1]}, + {"params": new_gs._scaling, "lr": start_lr[2]}, + {"params": new_gs._opacity, "lr": start_lr[3]}, + ], + lr=start_lr[0], + ) + + def exp_anealing(optimizer, step, total_steps, start_lr, end_lr): + return start_lr * (end_lr / start_lr) ** (step / total_steps) + + def cosine_anealing(optimizer, step, total_steps, start_lr, end_lr): + return end_lr + 0.5 * (start_lr - end_lr) * ( + 1 + np.cos(np.pi * step / total_steps) + ) + + _zeta = new_gs.get_opacity.clone().detach().squeeze() + _lambda = torch.zeros_like(_zeta) + _delta = 1e-7 + _interval = 10 + num_target = int((1 - simplify) * _zeta.shape[0]) + + with tqdm(total=2500, disable=not verbose, desc="Simplifying Gaussian") as pbar: + for i in range(2500): + # prune + if i % 100 == 0: + mask = new_gs.get_opacity.squeeze() > 0.05 + mask = torch.nonzero(mask).squeeze() + new_gs._xyz = torch.nn.Parameter(new_gs._xyz[mask]) + new_gs._rotation = torch.nn.Parameter(new_gs._rotation[mask]) + new_gs._scaling = torch.nn.Parameter(new_gs._scaling[mask]) + new_gs._opacity = torch.nn.Parameter(new_gs._opacity[mask]) + new_gs._features_dc = new_gs._features_dc[mask] + new_gs._features_rest = ( + new_gs._features_rest[mask] + if new_gs._features_rest is not None + else None + ) + _zeta = _zeta[mask] + _lambda = _lambda[mask] + # update optimizer state + for param_group, new_param in zip( + optimizer.param_groups, + [new_gs._xyz, new_gs._rotation, new_gs._scaling, new_gs._opacity], + ): + stored_state = optimizer.state[param_group["params"][0]] + if "exp_avg" in stored_state: + stored_state["exp_avg"] = stored_state["exp_avg"][mask] + stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask] + del optimizer.state[param_group["params"][0]] + param_group["params"][0] = new_param + optimizer.state[param_group["params"][0]] = stored_state + + opacity = new_gs.get_opacity.squeeze() + + # sparisfy + if i % _interval == 0: + _zeta = _lambda + opacity.detach() + if opacity.shape[0] > num_target: + index = _zeta.topk(num_target)[1] + _m = torch.ones_like(_zeta, dtype=torch.bool) + _m[index] = 0 + _zeta[_m] = 0 + _lambda = _lambda + opacity.detach() - _zeta + + # sample a random view + view_idx = np.random.randint(len(observations)) + observation = observations[view_idx] + extrinsic = extrinsics[view_idx] + intrinsic = intrinsics[view_idx] + + color = renderer.render(new_gs, extrinsic, intrinsic)["color"] + rgb_loss = torch.nn.functional.l1_loss(color, observation) + loss = rgb_loss + _delta * torch.sum( + torch.pow(_lambda + opacity - _zeta, 2) + ) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + # update lr + for j in range(len(optimizer.param_groups)): + optimizer.param_groups[j]["lr"] = cosine_anealing( + optimizer, i, 2500, start_lr[j], end_lr[j] + ) + + pbar.set_postfix( + { + "loss": rgb_loss.item(), + "num": opacity.shape[0], + "lambda": _lambda.mean().item(), + } + ) + pbar.update() + + new_gs._xyz = new_gs._xyz.data + new_gs._rotation = new_gs._rotation.data + new_gs._scaling = new_gs._scaling.data + new_gs._opacity = new_gs._opacity.data + + return new_gs diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/utils/random_utils.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/utils/random_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..309419e50cd21dff9f5ca2bda242df95e17097d2 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/utils/random_utils.py @@ -0,0 +1,35 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import numpy as np + +PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53] + + +def radical_inverse(base, n): + val = 0 + inv_base = 1.0 / base + inv_base_n = inv_base + while n > 0: + digit = n % base + val += digit * inv_base_n + n //= base + inv_base_n *= inv_base + return val + + +def halton_sequence(dim, n): + return [radical_inverse(PRIMES[dim], n) for dim in range(dim)] + + +def hammersley_sequence(dim, n, num_samples): + return [n / num_samples] + halton_sequence(dim - 1, n) + + +def sphere_hammersley_sequence(n, num_samples, offset=(0, 0), remap=False): + u, v = hammersley_sequence(2, n, num_samples) + u += offset[0] / num_samples + v += offset[1] + if remap: + u = 2 * u if u < 0.25 else 2 / 3 * u + 1 / 3 + theta = np.arccos(1 - 2 * u) - np.pi / 2 + phi = v * 2 * np.pi + return [phi, theta] diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/utils/render_utils.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/utils/render_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c7d85147a8c25beee904eceb09569531aac7dc72 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/backbone/tdfy_dit/utils/render_utils.py @@ -0,0 +1,221 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import torch +import numpy as np +from tqdm import tqdm +import utils3d +from PIL import Image + +from ..renderers import OctreeRenderer, GaussianRenderer, MeshRenderer +from ..representations import Octree, Gaussian, MeshExtractResult +from ..modules import sparse as sp +from .random_utils import sphere_hammersley_sequence + + +def yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitchs, rs, fovs): + is_list = isinstance(yaws, list) + if not is_list: + yaws = [yaws] + pitchs = [pitchs] + if not isinstance(rs, list): + rs = [rs] * len(yaws) + if not isinstance(fovs, list): + fovs = [fovs] * len(yaws) + extrinsics = [] + intrinsics = [] + for yaw, pitch, r, fov in zip(yaws, pitchs, rs, fovs): + fov = torch.deg2rad(torch.tensor(float(fov))).cuda() + yaw = torch.tensor(float(yaw)).cuda() + pitch = torch.tensor(float(pitch)).cuda() + orig = ( + torch.tensor( + [ + torch.sin(yaw) * torch.cos(pitch), + torch.cos(yaw) * torch.cos(pitch), + torch.sin(pitch), + ] + ).cuda() + * r + ) + extr = utils3d.torch.extrinsics_look_at( + orig, + torch.tensor([0, 0, 0]).float().cuda(), + torch.tensor([0, 0, 1]).float().cuda(), + ) + intr = utils3d.torch.intrinsics_from_fov_xy(fov, fov) + extrinsics.append(extr) + intrinsics.append(intr) + if not is_list: + extrinsics = extrinsics[0] + intrinsics = intrinsics[0] + return extrinsics, intrinsics + + +def render_frames( + sample, + extrinsics, + intrinsics, + options={}, + colors_overwrite=None, + verbose=True, + **kwargs, +): + if isinstance(sample, Octree): + renderer = OctreeRenderer() + renderer.rendering_options.resolution = options.get("resolution", 512) + renderer.rendering_options.near = options.get("near", 0.8) + renderer.rendering_options.far = options.get("far", 1.6) + renderer.rendering_options.bg_color = options.get("bg_color", (0, 0, 0)) + renderer.rendering_options.ssaa = options.get("ssaa", 4) + renderer.pipe.primitive = sample.primitive + elif isinstance(sample, Gaussian): + renderer = GaussianRenderer() + renderer.rendering_options.resolution = options.get("resolution", 512) + renderer.rendering_options.near = options.get("near", 0.8) + renderer.rendering_options.far = options.get("far", 1.6) + renderer.rendering_options.bg_color = options.get("bg_color", (0, 0, 0)) + renderer.rendering_options.ssaa = options.get("ssaa", 1) + renderer.rendering_options.backend = options.get("backend", "inria") + renderer.pipe.kernel_size = kwargs.get("kernel_size", 0.1) + renderer.pipe.use_mip_gaussian = True + elif isinstance(sample, MeshExtractResult): + renderer = MeshRenderer() + renderer.rendering_options.resolution = options.get("resolution", 512) + renderer.rendering_options.near = options.get("near", 1) + renderer.rendering_options.far = options.get("far", 100) + renderer.rendering_options.ssaa = options.get("ssaa", 4) + else: + raise ValueError(f"Unsupported sample type: {type(sample)}") + + rets = {} + for j, (extr, intr) in tqdm( + enumerate(zip(extrinsics, intrinsics)), desc="Rendering", disable=not verbose + ): + if not isinstance(sample, MeshExtractResult): + res = renderer.render(sample, extr, intr, colors_overwrite=colors_overwrite) + if "color" not in rets: + rets["color"] = [] + if "depth" not in rets: + rets["depth"] = [] + rets["color"].append( + np.clip( + res["color"].detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255 + ).astype(np.uint8) + ) + if "percent_depth" in res: + rets["depth"].append(res["percent_depth"].detach().cpu().numpy()) + elif "depth" in res: + rets["depth"].append(res["depth"].detach().cpu().numpy()) + else: + rets["depth"].append(None) + else: + res = renderer.render(sample, extr, intr) + if "normal" not in rets: + rets["normal"] = [] + rets["normal"].append( + np.clip( + res["normal"].detach().cpu().numpy().transpose(1, 2, 0) * 255, + 0, + 255, + ).astype(np.uint8) + ) + return rets + + +def render_gaussian_color_stay_in_device( + sample, + extrinsics, + intrinsics, + options={}, + colors_overwrite=None, + verbose=True, + **kwargs, +): + assert isinstance(sample, Gaussian) + renderer = GaussianRenderer() + renderer.rendering_options.resolution = options.get("resolution", 512) + renderer.rendering_options.near = options.get("near", 0.8) + renderer.rendering_options.far = options.get("far", 1.6) + renderer.rendering_options.bg_color = options.get("bg_color", (0, 0, 0)) + renderer.rendering_options.ssaa = options.get("ssaa", 1) + renderer.rendering_options.backend = options.get("backend", "inria") + renderer.pipe.kernel_size = kwargs.get("kernel_size", 0.1) + renderer.pipe.use_mip_gaussian = True + + rets = {} + for _, (extr, intr) in tqdm( + enumerate(zip(extrinsics, intrinsics)), desc="Rendering", disable=not verbose + ): + res = renderer.render(sample, extr, intr, colors_overwrite=colors_overwrite) + color = (res["color"].permute(1, 2, 0) * 255).to(torch.uint8) + if "color" not in rets: + rets["color"] = [] + rets["color"].append(color) + return rets + +def render_video( + sample, + resolution=512, + bg_color=(0, 0, 0), + num_frames=300, + r=2, + fov=40, + backend="inria", + **kwargs, +): + yaws = torch.linspace(0, 2 * 3.1415, num_frames) + pitch = 0.25 + 0.5 * torch.sin(torch.linspace(0, 2 * 3.1415, num_frames)) + yaws = yaws.tolist() + pitch = pitch.tolist() + extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics( + yaws, pitch, r, fov + ) + return render_frames( + sample, + extrinsics, + intrinsics, + {"resolution": resolution, "bg_color": bg_color, "backend": backend}, + **kwargs, + ) + + +def render_multiview(sample, resolution=512, nviews=30): + r = 2 + fov = 40 + cams = [sphere_hammersley_sequence(i, nviews) for i in range(nviews)] + yaws = [cam[0] for cam in cams] + pitchs = [cam[1] for cam in cams] + extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics( + yaws, pitchs, r, fov + ) + res = render_frames( + sample, + extrinsics, + intrinsics, + {"resolution": resolution, "bg_color": (0, 0, 0)}, + ) + return res["color"], extrinsics, intrinsics + + +def render_snapshot( + samples, + resolution=512, + bg_color=(0, 0, 0), + offset=(-16 / 180 * np.pi, 20 / 180 * np.pi), + r=10, + fov=8, + **kwargs, +): + yaw = [0, np.pi / 2, np.pi, 3 * np.pi / 2] + yaw_offset = offset[0] + yaw = [y + yaw_offset for y in yaw] + pitch = [offset[1] for _ in range(4)] + extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics( + yaw, pitch, r, fov + ) + return render_frames( + samples, + extrinsics, + intrinsics, + {"resolution": resolution, "bg_color": bg_color}, + **kwargs, + ) diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/io.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/io.py new file mode 100644 index 0000000000000000000000000000000000000000..9d05de6260edc2386a6898d20b9f097e3e0a98df --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/io.py @@ -0,0 +1,210 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +from typing import Any, Callable, Dict, List, Optional, Union, Iterable +import lightning.pytorch as pl +import torch +from pathlib import Path +import os +import re +from loguru import logger +from lightning.pytorch.utilities.consolidate_checkpoint import ( + _format_checkpoint, + _load_distributed_checkpoint, +) +from glob import glob + +from sam3d_objects.data.utils import get_child, set_child + + +def rename_checkpoint_weights_using_suffix_matching( + checkpoint_path_in, + checkpoint_path_out, + model: torch.nn.Module, + strict: bool = True, + keys: Optional[List[Any]] = (), +): + # extract model names + param_names = [n for n, _ in model.named_parameters()] + buffer_names = [n for n, _ in model.named_buffers()] + model_names = param_names + buffer_names + + # load stored weights + state = torch.load(checkpoint_path_in, weights_only=False) + + model_state = get_child(state, *keys) + model_state_names = list(model_state.keys()) + + # sort reversed names (sort by suffix) + model_names_rev = sorted([n[::-1] for n in model_names]) + model_state_names_rev = sorted([n[::-1] for n in model_state_names]) + + if strict and len(model_names) != len(model_state_names): + raise RuntimeError( + f"model and state don't have the same number of parameters ({len(model_names)} != {len(model_state_names)}), cannot match them (set strict = False to relax constraint)" + ) + + def common_prefix_length(str_0: str, str_1: str): + for count in range(min(len(str_0), len(str_1))): + if str_0[count] != str_1[count]: + break + return count + + # attempt to match every model names to largest suffic matched weight + name_mapping = {} + i, j = 0, 0 + last_n = 0 + while i < len(model_names_rev): + if j < len(model_state_names_rev): + n = common_prefix_length(model_names_rev[i], model_state_names_rev[j]) + else: + n = 0 + + if n >= last_n: + last_n = n + j += 1 + else: + last_n = 0 + name_mapping[model_names_rev[i][::-1]] = model_state_names_rev[j - 1][::-1] + i += 1 + + if not j < len(model_state_names_rev) + 1: + break + + # not all names might have been matched + if i < len(model_names): + raise RuntimeError("could not suffix match parameter names") + + for k, v in name_mapping.items(): + logger.debug(f"{k} <- {v}") + + # rename weights according to matches and save to disk + model_state_out = {k: model_state[v] for k, v in name_mapping.items()} + set_child(state, model_state_out, *keys) + torch.save(state, checkpoint_path_out) + + +def remove_prefix_state_dict_fn(prefix: str): + n = len(prefix) + + def state_dict_fn(state_dict): + return { + (key[n:] if key.startswith(prefix) else key): value + for key, value in state_dict.items() + } + + return state_dict_fn + + +def add_prefix_state_dict_fn(prefix: str): + def state_dict_fn(state_dict): + return {prefix + key: value for key, value in state_dict.items()} + + return state_dict_fn + + +def filter_and_remove_prefix_state_dict_fn(prefix: str): + n = len(prefix) + + def state_dict_fn(state_dict): + return { + key[n:]: value + for key, value in state_dict.items() + if key.startswith(prefix) + } + + return state_dict_fn + + +def get_last_checkpoint(path: str): + checkpoints = glob(os.path.join(path, "epoch=*-step=*.ckpt")) + prog = re.compile(r"epoch=(\d+)-step=(\d+).ckpt") + + checkpoints_to_sort = [] + for checkpoint in checkpoints: + checkpoint_name = os.path.basename(checkpoint) + match = prog.match(checkpoint_name) + if match is not None: + n_epoch, n_step = prog.match(checkpoint_name).groups() + n_epoch, n_step = int(n_epoch), int(n_step) + checkpoints_to_sort.append((n_epoch, n_step, checkpoint)) + + sorted_checkpoints = sorted(checkpoints_to_sort) + if not len(sorted_checkpoints) > 0: + raise RuntimeError(f"no checkpoint has been found at path : {path}") + return sorted_checkpoints[-1][2] + + +def load_sharded_checkpoint(path: str, device: Optional[str]): + if device != "cpu": + raise RuntimeError( + f'loading sharded weights on device "{device}" is not available, please use the "cpu" device instead' + ) + checkpoint = _load_distributed_checkpoint(Path(path)) + checkpoint = _format_checkpoint(checkpoint) + return checkpoint + + +def load_model_from_checkpoint( + model: Union[pl.LightningModule, torch.nn.Module], + checkpoint_path: str, + strict: bool = True, + device: Optional[str] = None, + freeze: bool = False, + eval: bool = False, + map_name: Union[Dict[str, str], None] = None, + remove_name: Union[List[str], None] = None, + state_dict_key: Union[None, str, Iterable[str]] = "state_dict", + state_dict_fn: Optional[Callable[[Any], Any]] = None, +): + logger.info(f"Loading checkpoint from {checkpoint_path}") + if os.path.isfile(checkpoint_path): + checkpoint = torch.load( + checkpoint_path, + map_location=device, + weights_only=False, + ) + elif os.path.isdir(checkpoint_path): # sharded + checkpoint = load_sharded_checkpoint(checkpoint_path, device=device) + else: # if neither a file nor a directory, path does not exist + raise FileNotFoundError(checkpoint_path) + + if isinstance(model, pl.LightningModule): + model.on_load_checkpoint(checkpoint) + + # get state dictionary + state_dict = checkpoint + if state_dict_key is not None: + if isinstance(state_dict_key, str): + state_dict_key = (state_dict_key,) + state_dict = get_child(state_dict, *state_dict_key) + + # remove names + if remove_name is not None: + for name in remove_name: + del state_dict[name] + + # remap names + if map_name is not None: + for src, dst in map_name.items(): + if src not in state_dict: + continue + state_dict[dst] = state_dict[src] + del state_dict[src] + + # apply custom changes to dict + if state_dict_fn is not None: + state_dict = state_dict_fn(state_dict) + + model.load_state_dict(state_dict, strict=strict) + + if device is not None: + model = model.to(device) + + if freeze: + for param in model.parameters(): + param.requires_grad = False + eval = True + + if eval: + model.eval() + + return model diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/model/layers/llama3/ff.py b/thirdparty/sam3d/sam3d/sam3d_objects/model/layers/llama3/ff.py new file mode 100644 index 0000000000000000000000000000000000000000..ddde75d4ca01fe1aef5750a6ba5ef0ad88e8d5e5 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/model/layers/llama3/ff.py @@ -0,0 +1,41 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +from torch import nn +import torch.nn.functional as F +from typing import Optional + + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int = 256, + ffn_dim_multiplier: Optional[float] = None, + output_dim: Optional[int] = None, + skip_w2: bool = False, + ): + """ + Llama3 FeedForward layer + https://github.com/meta-llama/llama3/blob/a0940f9cf7065d45bb6675660f80d305c041a754/llama/model.py#L193 + """ + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + if output_dim is None: + output_dim = dim + + self.skip_w2 = skip_w2 + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + if not self.skip_w2: + self.w2 = nn.Linear(hidden_dim, output_dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x): + x = F.silu(self.w1(x)) * self.w3(x) + if self.skip_w2: + return x + return self.w2(x) diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/pipeline/__init__.py b/thirdparty/sam3d/sam3d/sam3d_objects/pipeline/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..71ca4b12c770afea62d06f97064cdf0c97d40ed7 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/pipeline/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/pipeline/depth_models/base.py b/thirdparty/sam3d/sam3d/sam3d_objects/pipeline/depth_models/base.py new file mode 100644 index 0000000000000000000000000000000000000000..e6f919b37e326096065306a70ce586425e2bd3ae --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/pipeline/depth_models/base.py @@ -0,0 +1,13 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import torch + + +class DepthModel: + def __init__(self, model, device="cuda"): + self.model = model + self.device = torch.device(device) + self.model.to(self.device) + self.model.eval() + + def __call__(self, image): + pass \ No newline at end of file diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/pipeline/depth_models/moge.py b/thirdparty/sam3d/sam3d/sam3d_objects/pipeline/depth_models/moge.py new file mode 100644 index 0000000000000000000000000000000000000000..0e654a099e6ab0721a48e7ae1769469c46c8db1d --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/pipeline/depth_models/moge.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +from .base import DepthModel + +class MoGe(DepthModel): + def __call__(self, image): + output = self.model.infer( + image.to(self.device), force_projection=False + ) + pointmaps = output["points"] + output["pointmaps"] = pointmaps + return output \ No newline at end of file diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/pipeline/inference_pipeline.py b/thirdparty/sam3d/sam3d/sam3d_objects/pipeline/inference_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..e4b81fecdadfb64aba1b9a70a5731b7def41f28d --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/pipeline/inference_pipeline.py @@ -0,0 +1,844 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import os + +from tqdm import tqdm +import torch +from loguru import logger +from functools import wraps +from torch.utils._pytree import tree_map_only + + +def set_attention_backend(): + if torch.cuda.is_available(): + gpu_name = torch.cuda.get_device_name(0) + + logger.info(f"GPU name is {gpu_name}") + if "A100" in gpu_name or "H100" in gpu_name or "H200" in gpu_name: + # logger.info("Use flash_attn") + os.environ["ATTN_BACKEND"] = "flash_attn" + os.environ["SPARSE_ATTN_BACKEND"] = "flash_attn" + +set_attention_backend() + +from typing import List, Union +from hydra.utils import instantiate +from omegaconf import OmegaConf +import numpy as np + +from PIL import Image +from sam3d_objects.pipeline import preprocess_utils +from sam3d_objects.data.dataset.tdfy.img_and_mask_transforms import ( + get_mask, +) +from sam3d_objects.pipeline.inference_utils import ( + get_pose_decoder, + SLAT_MEAN, + SLAT_STD, + downsample_sparse_structure, + prune_sparse_structure, +) + +from sam3d_objects.model.io import ( + load_model_from_checkpoint, + filter_and_remove_prefix_state_dict_fn, +) + +from sam3d_objects.model.backbone.tdfy_dit.modules import sparse as sp +from sam3d_objects.model.backbone.tdfy_dit.utils import postprocessing_utils +from safetensors.torch import load_file + + +class InferencePipeline: + def __init__( + self, + ss_generator_config_path, + ss_generator_ckpt_path, + slat_generator_config_path, + slat_generator_ckpt_path, + ss_decoder_config_path, + ss_decoder_ckpt_path, + slat_decoder_gs_config_path, + slat_decoder_gs_ckpt_path, + slat_decoder_mesh_config_path, + slat_decoder_mesh_ckpt_path, + slat_decoder_gs_4_config_path=None, + slat_decoder_gs_4_ckpt_path=None, + ss_encoder_config_path=None, + ss_encoder_ckpt_path=None, + decode_formats=["gaussian", "mesh"], + dtype="bfloat16", + pad_size=1.0, + version="v0", + device="cuda", + ss_preprocessor=preprocess_utils.get_default_preprocessor(), + slat_preprocessor=preprocess_utils.get_default_preprocessor(), + ss_condition_input_mapping=["image"], + slat_condition_input_mapping=["image"], + pose_decoder_name="default", + workspace_dir="", + downsample_ss_dist=0, # the distance we use to downsample + ss_inference_steps=25, + ss_rescale_t=3, + ss_cfg_strength=7, + ss_cfg_interval=[0, 500], + ss_cfg_strength_pm=0.0, + slat_inference_steps=25, + slat_rescale_t=3, + slat_cfg_strength=5, + slat_cfg_interval=[0, 500], + rendering_engine: str = "nvdiffrast", # nvdiffrast OR pytorch3d, + shape_model_dtype=None, + compile_model=False, + slat_mean=SLAT_MEAN, + slat_std=SLAT_STD, + ): + self.rendering_engine = rendering_engine + self.device = torch.device(device) + self.compile_model = compile_model + logger.info(f"self.device: {self.device}") + logger.info(f"CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES', None)}") + logger.info(f"Actually using GPU: {torch.cuda.current_device()}") + with self.device: + self.decode_formats = decode_formats + self.pad_size = pad_size + self.version = version + self.ss_condition_input_mapping = ss_condition_input_mapping + self.slat_condition_input_mapping = slat_condition_input_mapping + self.workspace_dir = workspace_dir + self.downsample_ss_dist = downsample_ss_dist + self.ss_inference_steps = ss_inference_steps + self.ss_rescale_t = ss_rescale_t + self.ss_cfg_strength = ss_cfg_strength + self.ss_cfg_interval = ss_cfg_interval + self.ss_cfg_strength_pm = ss_cfg_strength_pm + self.slat_inference_steps = slat_inference_steps + self.slat_rescale_t = slat_rescale_t + self.slat_cfg_strength = slat_cfg_strength + self.slat_cfg_interval = slat_cfg_interval + + self.dtype = self._get_dtype(dtype) + if shape_model_dtype is None: + self.shape_model_dtype = self.dtype + else: + self.shape_model_dtype = self._get_dtype(shape_model_dtype) + + + # Setup preprocessors + self.pose_decoder = self.init_pose_decoder(ss_generator_config_path, pose_decoder_name) + self.ss_preprocessor = self.init_ss_preprocessor(ss_preprocessor, ss_generator_config_path) + self.slat_preprocessor = slat_preprocessor + + logger.info("Loading model weights...") + + ss_generator = self.init_ss_generator( + ss_generator_config_path, ss_generator_ckpt_path + ) + slat_generator = self.init_slat_generator( + slat_generator_config_path, slat_generator_ckpt_path + ) + ss_decoder = self.init_ss_decoder( + ss_decoder_config_path, ss_decoder_ckpt_path + ) + ss_encoder = self.init_ss_encoder( + ss_encoder_config_path, ss_encoder_ckpt_path + ) + slat_decoder_gs = self.init_slat_decoder_gs( + slat_decoder_gs_config_path, slat_decoder_gs_ckpt_path + ) + slat_decoder_gs_4 = self.init_slat_decoder_gs( + slat_decoder_gs_4_config_path, slat_decoder_gs_4_ckpt_path + ) + slat_decoder_mesh = self.init_slat_decoder_mesh( + slat_decoder_mesh_config_path, slat_decoder_mesh_ckpt_path + ) + + # Load conditioner embedder so that we only load it once + ss_condition_embedder = self.init_ss_condition_embedder( + ss_generator_config_path, ss_generator_ckpt_path + ) + slat_condition_embedder = self.init_slat_condition_embedder( + slat_generator_config_path, slat_generator_ckpt_path + ) + + self.condition_embedders = { + "ss_condition_embedder": ss_condition_embedder, + "slat_condition_embedder": slat_condition_embedder, + } + + # override generator and condition embedder setting + self.override_ss_generator_cfg_config( + ss_generator, + cfg_strength=ss_cfg_strength, + inference_steps=ss_inference_steps, + rescale_t=ss_rescale_t, + cfg_interval=ss_cfg_interval, + cfg_strength_pm=ss_cfg_strength_pm, + ) + self.override_slat_generator_cfg_config( + slat_generator, + cfg_strength=slat_cfg_strength, + inference_steps=slat_inference_steps, + rescale_t=slat_rescale_t, + cfg_interval=slat_cfg_interval, + ) + + self.models = torch.nn.ModuleDict( + { + "ss_generator": ss_generator, + "slat_generator": slat_generator, + "ss_encoder": ss_encoder, + "ss_decoder": ss_decoder, + "slat_decoder_gs": slat_decoder_gs, + "slat_decoder_gs_4": slat_decoder_gs_4, + "slat_decoder_mesh": slat_decoder_mesh, + } + ) + logger.info("Loading model weights completed!") + + if self.compile_model: + logger.info("Compiling model...") + self._compile() + logger.info("Model compilation completed!") + self.slat_mean = torch.tensor(slat_mean) + self.slat_std = torch.tensor(slat_std) + + def _compile(self): + torch._dynamo.config.cache_size_limit = 64 + torch._dynamo.config.accumulated_cache_size_limit = 2048 + torch._dynamo.config.capture_scalar_outputs = True + compile_mode = "max-autotune" + logger.info(f"Compile mode {compile_mode}") + + def clone_output_wrapper(f): + @wraps(f) + def wrapped(*args, **kwargs): + outputs = f(*args, **kwargs) + return tree_map_only( + torch.Tensor, lambda t: t.clone() if t.is_cuda else t, outputs + ) + + return wrapped + + self.embed_condition = clone_output_wrapper( + torch.compile( + self.embed_condition, + mode=compile_mode, + fullgraph=True, # _preprocess_input in dino is not compatible with fullgraph + ) + ) + self.models["ss_generator"].reverse_fn.inner_forward = clone_output_wrapper( + torch.compile( + self.models["ss_generator"].reverse_fn.inner_forward, + mode=compile_mode, + fullgraph=True, + ) + ) + + self.models["ss_decoder"].forward = clone_output_wrapper( + torch.compile( + self.models["ss_decoder"].forward, + mode=compile_mode, + fullgraph=True, + ) + ) + + self._warmup() + + def _warmup(self, num_warmup_iters=3): + test_image = np.ones((512, 512, 4), dtype=np.uint8) * 255 + test_image[:, :, :3] = np.random.randint(0, 255, (512, 512, 3), dtype=np.uint8) + image = Image.fromarray(test_image) + mask = None + image = self.merge_image_and_mask(image, mask) + + for _ in tqdm(range(num_warmup_iters)): + ss_input_dict = self.preprocess_image(image, self.ss_preprocessor) + slat_input_dict = self.preprocess_image(image, self.slat_preprocessor) + ss_return_dict = self.sample_sparse_structure(ss_input_dict) + coords = ss_return_dict["coords"] + slat = self.sample_slat(slat_input_dict, coords) + + def instantiate_and_load_from_pretrained( + self, + config, + ckpt_path, + state_dict_fn=None, + state_dict_key="state_dict", + device="cuda", + ): + model = instantiate(config) + + if ckpt_path.endswith(".safetensors"): + state_dict = load_file(ckpt_path, device="cuda") + if state_dict_fn is not None: + state_dict = state_dict_fn(state_dict) + model.load_state_dict(state_dict, strict=False) + model.eval() + else: + model = load_model_from_checkpoint( + model, + ckpt_path, + strict=True, + device="cpu", + freeze=True, + eval=True, + state_dict_key=state_dict_key, + state_dict_fn=state_dict_fn, + ) + model = model.to(device) + + return model + + def init_pose_decoder(self, ss_generator_config_path, pose_decoder_name): + if pose_decoder_name is None: + pose_decoder_name = OmegaConf.load(os.path.join(self.workspace_dir, ss_generator_config_path))["module"]["pose_target_convention"] + logger.info(f"Using pose decoder: {pose_decoder_name}") + return get_pose_decoder(pose_decoder_name) + + def init_ss_preprocessor(self, ss_preprocessor, ss_generator_config_path): + if ss_preprocessor is not None: + return ss_preprocessor + config = OmegaConf.load(os.path.join(self.workspace_dir, ss_generator_config_path))["tdfy"]["val_preprocessor"] + return instantiate(config) + + def init_ss_generator(self, ss_generator_config_path, ss_generator_ckpt_path): + config = OmegaConf.load( + os.path.join(self.workspace_dir, ss_generator_config_path) + )["module"]["generator"]["backbone"] + + state_dict_prefix_func = filter_and_remove_prefix_state_dict_fn( + "_base_models.generator." + ) + + return self.instantiate_and_load_from_pretrained( + config, + os.path.join(self.workspace_dir, ss_generator_ckpt_path), + state_dict_fn=state_dict_prefix_func, + device=self.device, + ) + + def init_slat_generator(self, slat_generator_config_path, slat_generator_ckpt_path): + config = OmegaConf.load( + os.path.join(self.workspace_dir, slat_generator_config_path) + )["module"]["generator"]["backbone"] + state_dict_prefix_func = filter_and_remove_prefix_state_dict_fn( + "_base_models.generator." + ) + return self.instantiate_and_load_from_pretrained( + config, + os.path.join(self.workspace_dir, slat_generator_ckpt_path), + state_dict_fn=state_dict_prefix_func, + device=self.device, + ) + + def init_ss_encoder(self, ss_encoder_config_path, ss_encoder_ckpt_path): + if ss_encoder_ckpt_path is not None: + # override to avoid problem loading + config = OmegaConf.load( + os.path.join(self.workspace_dir, ss_encoder_config_path) + ) + if "pretrained_ckpt_path" in config: + del config["pretrained_ckpt_path"] + return self.instantiate_and_load_from_pretrained( + config, + os.path.join(self.workspace_dir, ss_encoder_ckpt_path), + device=self.device, + state_dict_key=None, + ) + else: + return None + + def init_ss_decoder(self, ss_decoder_config_path, ss_decoder_ckpt_path): + # override to avoid problem loading + config = OmegaConf.load( + os.path.join(self.workspace_dir, ss_decoder_config_path) + ) + if "pretrained_ckpt_path" in config: + del config["pretrained_ckpt_path"] + return self.instantiate_and_load_from_pretrained( + config, + os.path.join(self.workspace_dir, ss_decoder_ckpt_path), + device=self.device, + state_dict_key=None, + ) + + def init_slat_decoder_gs( + self, slat_decoder_gs_config_path, slat_decoder_gs_ckpt_path + ): + if slat_decoder_gs_config_path is None: + return None + else: + return self.instantiate_and_load_from_pretrained( + OmegaConf.load( + os.path.join(self.workspace_dir, slat_decoder_gs_config_path) + ), + os.path.join(self.workspace_dir, slat_decoder_gs_ckpt_path), + device=self.device, + state_dict_key=None, + ) + + def init_slat_decoder_mesh( + self, slat_decoder_mesh_config_path, slat_decoder_mesh_ckpt_path + ): + return self.instantiate_and_load_from_pretrained( + OmegaConf.load( + os.path.join(self.workspace_dir, slat_decoder_mesh_config_path) + ), + os.path.join(self.workspace_dir, slat_decoder_mesh_ckpt_path), + device=self.device, + state_dict_key=None, + ) + + def init_ss_condition_embedder( + self, ss_generator_config_path, ss_generator_ckpt_path + ): + conf = OmegaConf.load( + os.path.join(self.workspace_dir, ss_generator_config_path) + ) + if "condition_embedder" in conf["module"]: + return self.instantiate_and_load_from_pretrained( + conf["module"]["condition_embedder"]["backbone"], + os.path.join(self.workspace_dir, ss_generator_ckpt_path), + state_dict_fn=filter_and_remove_prefix_state_dict_fn( + "_base_models.condition_embedder." + ), + device=self.device, + ) + else: + return None + + def init_slat_condition_embedder( + self, slat_generator_config_path, slat_generator_ckpt_path + ): + return self.init_ss_condition_embedder( + slat_generator_config_path, slat_generator_ckpt_path + ) + + + def override_ss_generator_cfg_config( + self, + ss_generator, + cfg_strength=7, + inference_steps=25, + rescale_t=3, + cfg_interval=[0, 500], + cfg_strength_pm=0.0, + ): + # override generator setting + ss_generator.inference_steps = inference_steps + ss_generator.reverse_fn.strength = cfg_strength + ss_generator.reverse_fn.interval = cfg_interval + ss_generator.rescale_t = rescale_t + ss_generator.reverse_fn.backbone.condition_embedder.normalize_images = True + ss_generator.reverse_fn.unconditional_handling = "add_flag" + ss_generator.reverse_fn.strength_pm = cfg_strength_pm + + logger.info( + "ss_generator parameters: inference_steps={}, cfg_strength={}, cfg_interval={}, rescale_t={}, cfg_strength_pm={}", + inference_steps, + cfg_strength, + cfg_interval, + rescale_t, + cfg_strength_pm, + ) + + def override_slat_generator_cfg_config( + self, + slat_generator, + cfg_strength=5, + inference_steps=25, + rescale_t=3, + cfg_interval=[0, 500], + ): + slat_generator.inference_steps = inference_steps + slat_generator.reverse_fn.strength = cfg_strength + slat_generator.reverse_fn.interval = cfg_interval + slat_generator.rescale_t = rescale_t + + logger.info( + "slat_generator parameters: inference_steps={}, cfg_strength={}, cfg_interval={}, rescale_t={}", + inference_steps, + cfg_strength, + cfg_interval, + rescale_t, + ) + + + def run( + self, + image: Union[None, Image.Image, np.ndarray], + mask: Union[None, Image.Image, np.ndarray] = None, + seed=42, + stage1_only=False, + with_mesh_postprocess=True, + with_texture_baking=True, + use_vertex_color=False, + stage1_inference_steps=None, + stage2_inference_steps=None, + use_stage1_distillation=False, + use_stage2_distillation=False, + decode_formats=None, + ) -> dict: + """ + Parameters: + - image (Image): The input image to be processed. + - seed (int, optional): The random seed for reproducibility. Default is 42. + - stage1_only (bool, optional): If True, only the sparse structure is sampled and returned. Default is False. + - with_mesh_postprocess (bool, optional): If True, performs mesh post-processing. Default is True. + - with_texture_baking (bool, optional): If True, applies texture baking to the 3D model. Default is True. + Returns: + - dict: A dictionary containing the GLB file and additional data from the sparse structure sampling. + """ + # This should only happen if called from demo + image = self.merge_image_and_mask(image, mask) + with self.device: + ss_input_dict = self.preprocess_image(image, self.ss_preprocessor) + slat_input_dict = self.preprocess_image(image, self.slat_preprocessor) + torch.manual_seed(seed) + ss_return_dict = self.sample_sparse_structure( + ss_input_dict, + inference_steps=stage1_inference_steps, + use_distillation=use_stage1_distillation, + ) + + ss_return_dict.update(self.pose_decoder(ss_return_dict)) + + if "scale" in ss_return_dict: + logger.info(f"Rescaling scale by {ss_return_dict['downsample_factor']}") + ss_return_dict["scale"] = ss_return_dict["scale"] * ss_return_dict["downsample_factor"] + if stage1_only: + logger.info("Finished!") + ss_return_dict["voxel"] = ss_return_dict["coords"][:, 1:] / 64 - 0.5 + return ss_return_dict + + coords = ss_return_dict["coords"] + slat = self.sample_slat( + slat_input_dict, + coords, + inference_steps=stage2_inference_steps, + use_distillation=use_stage2_distillation, + ) + outputs = self.decode_slat( + slat, self.decode_formats if decode_formats is None else decode_formats + ) + outputs = self.postprocess_slat_output( + outputs, with_mesh_postprocess, with_texture_baking, use_vertex_color + ) + logger.info("Finished!") + + return { + **ss_return_dict, + **outputs, + } + + def postprocess_slat_output( + self, outputs, with_mesh_postprocess, with_texture_baking, use_vertex_color + ): + # GLB files can be extracted from the outputs + logger.info( + f"Postprocessing mesh with option with_mesh_postprocess {with_mesh_postprocess}, with_texture_baking {with_texture_baking}..." + ) + if "mesh" in outputs: + glb = postprocessing_utils.to_glb( + outputs["gaussian"][0], + outputs["mesh"][0], + # Optional parameters + simplify=0.95, # Ratio of triangles to remove in the simplification process + texture_size=1024, # Size of the texture used for the GLB + verbose=False, + with_mesh_postprocess=with_mesh_postprocess, + with_texture_baking=with_texture_baking, + use_vertex_color=use_vertex_color, + rendering_engine=self.rendering_engine, + ) + + # glb.export("sample.glb") + else: + glb = None + + outputs["glb"] = glb + + if "gaussian" in outputs: + outputs["gs"] = outputs["gaussian"][0] + + if "gaussian_4" in outputs: + outputs["gs_4"] = outputs["gaussian_4"][0] + + return outputs + + def merge_image_and_mask( + self, + image: Union[np.ndarray, Image.Image], + mask: Union[None, np.ndarray, Image.Image], + ): + if mask is not None: + if isinstance(image, Image.Image): + image = np.array(image) + + mask = np.array(mask) + if mask.ndim == 2: + mask = mask[..., None] + + logger.info(f"Replacing alpha channel with the provided mask") + assert mask.shape[:2] == image.shape[:2] + image = np.concatenate([image[..., :3], mask], axis=-1) + + image = np.array(image) + return image + + def decode_slat( + self, + slat: sp.SparseTensor, + formats: List[str] = ["mesh", "gaussian"], + ) -> dict: + """ + Decode the structured latent. + + Args: + slat (sp.SparseTensor): The structured latent. + formats (List[str]): The formats to decode the structured latent to. + + Returns: + dict: The decoded structured latent. + """ + logger.info("Decoding sparse latent...") + ret = {} + with torch.no_grad(): + if "mesh" in formats: + ret["mesh"] = self.models["slat_decoder_mesh"](slat) + if "gaussian" in formats: + ret["gaussian"] = self.models["slat_decoder_gs"](slat) + if "gaussian_4" in formats: + ret["gaussian_4"] = self.models["slat_decoder_gs_4"](slat) + # if "radiance_field" in formats: + # ret["radiance_field"] = self.models["slat_decoder_rf"](slat) + return ret + + def is_mm_dit(self, model_name="ss_generator"): + return hasattr(self.models[model_name].reverse_fn.backbone, "latent_mapping") + + def embed_condition(self, condition_embedder, *args, **kwargs): + if condition_embedder is not None: + tokens = condition_embedder(*args, **kwargs) + return tokens, None, None + return None, args, kwargs + + def get_condition_input(self, condition_embedder, input_dict, input_mapping): + condition_args = self.map_input_keys(input_dict, input_mapping) + condition_kwargs = { + k: v for k, v in input_dict.items() if k not in input_mapping + } + logger.info("Running condition embedder ...") + embedded_cond, condition_args, condition_kwargs = self.embed_condition( + condition_embedder, *condition_args, **condition_kwargs + ) + logger.info("Condition embedder finishes!") + if embedded_cond is not None: + condition_args = (embedded_cond,) + condition_kwargs = {} + + return condition_args, condition_kwargs + + def sample_sparse_structure( + self, ss_input_dict: dict, inference_steps=None, use_distillation=False + ): + ss_generator = self.models["ss_generator"] + ss_decoder = self.models["ss_decoder"] + if use_distillation: + ss_generator.no_shortcut = False + ss_generator.reverse_fn.strength = 0 + ss_generator.reverse_fn.strength_pm = 0 + else: + ss_generator.no_shortcut = True + ss_generator.reverse_fn.strength = self.ss_cfg_strength + ss_generator.reverse_fn.strength_pm = self.ss_cfg_strength_pm + + prev_inference_steps = ss_generator.inference_steps + if inference_steps: + ss_generator.inference_steps = inference_steps + + image = ss_input_dict["image"] + bs = image.shape[0] + logger.info( + "Sampling sparse structure: inference_steps={}, strength={}, interval={}, rescale_t={}, cfg_strength_pm={}", + ss_generator.inference_steps, + ss_generator.reverse_fn.strength, + ss_generator.reverse_fn.interval, + ss_generator.rescale_t, + ss_generator.reverse_fn.strength_pm, + ) + + with torch.no_grad(): + with torch.autocast(device_type="cuda", dtype=self.shape_model_dtype): + if self.is_mm_dit(): + latent_shape_dict = { + k: (bs,) + (v.pos_emb.shape[0], v.input_layer.in_features) + for k, v in ss_generator.reverse_fn.backbone.latent_mapping.items() + } + else: + latent_shape_dict = (bs,) + (4096, 8) + + condition_args, condition_kwargs = self.get_condition_input( + self.condition_embedders["ss_condition_embedder"], + ss_input_dict, + self.ss_condition_input_mapping, + ) + return_dict = ss_generator( + latent_shape_dict, + image.device, + *condition_args, + **condition_kwargs, + ) + if not self.is_mm_dit(): + return_dict = {"shape": return_dict} + + shape_latent = return_dict["shape"] + ss = ss_decoder( + shape_latent.permute(0, 2, 1) + .contiguous() + .view(shape_latent.shape[0], 8, 16, 16, 16) + ) + coords = torch.argwhere(ss > 0)[:, [0, 2, 3, 4]].int() + + # downsample output + return_dict["coords_original"] = coords + original_shape = coords.shape + if self.downsample_ss_dist > 0: + coords = prune_sparse_structure( + coords, + max_neighbor_axes_dist=self.downsample_ss_dist, + ) + coords, downsample_factor = downsample_sparse_structure(coords) + logger.info( + f"Downsampled coords from {original_shape[0]} to {coords.shape[0]}" + ) + return_dict["coords"] = coords + return_dict["downsample_factor"] = downsample_factor + + ss_generator.inference_steps = prev_inference_steps + return return_dict + + def sample_slat( + self, + slat_input: dict, + coords: torch.Tensor, + inference_steps=25, + use_distillation=False, + ) -> sp.SparseTensor: + image = slat_input["image"] + DEVICE = image.device + slat_generator = self.models["slat_generator"] + latent_shape = (image.shape[0],) + (coords.shape[0], 8) + prev_inference_steps = slat_generator.inference_steps + if inference_steps: + slat_generator.inference_steps = inference_steps + if use_distillation: + slat_generator.no_shortcut = False + slat_generator.reverse_fn.strength = 0 + else: + slat_generator.no_shortcut = True + slat_generator.reverse_fn.strength = self.slat_cfg_strength + + logger.info( + "Sampling sparse latent: inference_steps={}, strength={}, interval={}, rescale_t={}", + slat_generator.inference_steps, + slat_generator.reverse_fn.strength, + slat_generator.reverse_fn.interval, + slat_generator.rescale_t, + ) + + with torch.autocast(device_type="cuda", dtype=self.dtype): + with torch.no_grad(): + condition_args, condition_kwargs = self.get_condition_input( + self.condition_embedders["slat_condition_embedder"], + slat_input, + self.slat_condition_input_mapping, + ) + condition_args += (coords.cpu().numpy(),) + slat = slat_generator( + latent_shape, DEVICE, *condition_args, **condition_kwargs + ) + slat = sp.SparseTensor( + coords=coords, + feats=slat[0], + ).to(DEVICE) + slat = slat * self.slat_std.to(DEVICE) + self.slat_mean.to(DEVICE) + + slat_generator.inference_steps = prev_inference_steps + return slat + + def _apply_transform(self, input: torch.Tensor, transform): + if input is not None: + input = transform(input) + + return input + + def _preprocess_image_and_mask( + self, rgb_image, mask_image, img_mask_joint_transform + ): + for trans in img_mask_joint_transform: + rgb_image, mask_image = trans(rgb_image, mask_image) + return rgb_image, mask_image + + def map_input_keys(self, item, condition_input_mapping): + output = [item[k] for k in condition_input_mapping] + + return output + + def image_to_float(self, image): + image = np.array(image) + image = image / 255 + image = image.astype(np.float32) + return image + + def preprocess_image( + self, image: Union[Image.Image, np.ndarray], preprocessor + ) -> torch.Tensor: + # canonical type is numpy + if not isinstance(input, np.ndarray): + image = np.array(image) + + assert image.ndim == 3 # no batch dimension as of now + assert image.shape[-1] == 4 # rgba format + assert image.dtype == np.uint8 # [0,255] range + + rgba_image = torch.from_numpy(self.image_to_float(image)) + rgba_image = rgba_image.permute(2, 0, 1).contiguous() + rgb_image = rgba_image[:3] + rgb_image_mask = (get_mask(rgba_image, None, "ALPHA_CHANNEL") > 0).float() + processed_rgb_image, processed_mask = self._preprocess_image_and_mask( + rgb_image, rgb_image_mask, preprocessor.img_mask_joint_transform + ) + + # transform tensor to model input + processed_rgb_image = self._apply_transform( + processed_rgb_image, preprocessor.img_transform + ) + processed_mask = self._apply_transform( + processed_mask, preprocessor.mask_transform + ) + + # full image, with only processing from the image + rgb_image = self._apply_transform(rgb_image, preprocessor.img_transform) + rgb_image_mask = self._apply_transform( + rgb_image_mask, preprocessor.mask_transform + ) + item = { + "mask": processed_mask[None].to(self.device), + "image": processed_rgb_image[None].to(self.device), + "rgb_image": rgb_image[None].to(self.device), + "rgb_image_mask": rgb_image_mask[None].to(self.device), + } + + return item + + @staticmethod + def _get_dtype(dtype): + if dtype == "bfloat16": + return torch.bfloat16 + elif dtype == "float16": + return torch.float16 + elif dtype == "float32": + return torch.float32 + else: + raise NotImplementedError diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/pipeline/inference_pipeline_pointmap.py b/thirdparty/sam3d/sam3d/sam3d_objects/pipeline/inference_pipeline_pointmap.py new file mode 100644 index 0000000000000000000000000000000000000000..cc5230071d0ed8ef19bf6338c785d8bac7bf858c --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/pipeline/inference_pipeline_pointmap.py @@ -0,0 +1,493 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +from typing import Union, Optional +from copy import deepcopy +import numpy as np +import torch +from tqdm import tqdm +import torchvision +from loguru import logger +from PIL import Image + +from pytorch3d.renderer import look_at_view_transform +from pytorch3d.transforms import Transform3d + +from sam3d_objects.model.backbone.dit.embedder.pointmap import PointPatchEmbed +from sam3d_objects.pipeline.inference_pipeline import InferencePipeline +from sam3d_objects.data.dataset.tdfy.img_and_mask_transforms import ( + get_mask, +) +from sam3d_objects.data.dataset.tdfy.transforms_3d import ( + DecomposedTransform, +) +from sam3d_objects.pipeline.utils.pointmap import infer_intrinsics_from_pointmap +from sam3d_objects.pipeline.inference_utils import o3d_plane_estimation, estimate_plane_area + + +def camera_to_pytorch3d_camera(device="cpu") -> DecomposedTransform: + """ + R3 camera space --> PyTorch3D camera space + Also needed for pointmaps + """ + r3_to_p3d_R, r3_to_p3d_T = look_at_view_transform( + eye=np.array([[0, 0, -1]]), + at=np.array([[0, 0, 0]]), + up=np.array([[0, -1, 0]]), + device=device, + ) + return DecomposedTransform( + rotation=r3_to_p3d_R, + translation=r3_to_p3d_T, + scale=torch.tensor(1.0, dtype=r3_to_p3d_R.dtype, device=device), + ) + + +def recursive_fn_factory(fn): + def recursive_fn(b): + if isinstance(b, dict): + return {k: recursive_fn(b[k]) for k in b} + if isinstance(b, list): + return [recursive_fn(t) for t in b] + if isinstance(b, tuple): + return tuple(recursive_fn(t) for t in b) + if isinstance(b, torch.Tensor): + return fn(b) + # Yes, writing out an explicit white list of + # trivial types is tedious, but so are bugs that + # come from not applying fn, when expected to have + # applied it. + if b is None: + return b + trivial_types = [bool, int, float] + for t in trivial_types: + if isinstance(b, t): + return b + raise TypeError(f"Unexpected type {type(b)}") + + return recursive_fn + + +recursive_contiguous = recursive_fn_factory(lambda x: x.contiguous()) +recursive_clone = recursive_fn_factory(torch.clone) + + +def compile_wrapper( + fn, *, mode="max-autotune", fullgraph=True, dynamic=False, name=None +): + compiled_fn = torch.compile(fn, mode=mode, fullgraph=fullgraph, dynamic=dynamic) + + def compiled_fn_wrapper(*args, **kwargs): + with torch.autograd.profiler.record_function( + f"compiled {fn}" if name is None else name + ): + cont_args = recursive_contiguous(args) + cont_kwargs = recursive_contiguous(kwargs) + result = compiled_fn(*cont_args, **cont_kwargs) + cloned_result = recursive_clone(result) + return cloned_result + + return compiled_fn_wrapper + + +class InferencePipelinePointMap(InferencePipeline): + + def __init__( + self, *args, depth_model, layout_post_optimization_method=None, clip_pointmap_beyond_scale=None, **kwargs + ): + self.depth_model = depth_model + self.layout_post_optimization_method = layout_post_optimization_method + self.clip_pointmap_beyond_scale = clip_pointmap_beyond_scale + super().__init__(*args, **kwargs) + + def _compile(self): + torch._dynamo.config.cache_size_limit = 64 + torch._dynamo.config.accumulated_cache_size_limit = 2048 + torch._dynamo.config.capture_scalar_outputs = True + compile_mode = "max-autotune" + + for embedder, _ in self.condition_embedders[ + "ss_condition_embedder" + ].embedder_list: + if isinstance(embedder, PointPatchEmbed): + logger.info("Found PointPatchEmbed") + embedder.inner_forward = compile_wrapper( + embedder.inner_forward, + mode=compile_mode, + fullgraph=True, + ) + else: + embedder.forward = compile_wrapper( + embedder.forward, + mode=compile_mode, + fullgraph=True, + ) + + self.models["ss_generator"].reverse_fn.inner_forward = compile_wrapper( + self.models["ss_generator"].reverse_fn.inner_forward, + mode=compile_mode, + fullgraph=True, + ) + + self.models["ss_decoder"].forward = compile_wrapper( + self.models["ss_decoder"].forward, + mode=compile_mode, + fullgraph=True, + ) + + self._warmup() + + def _warmup(self, num_warmup_iters=3): + test_image = np.ones((512, 512, 4), dtype=np.uint8) * 255 + test_image[:, :, :3] = np.random.randint(0, 255, (512, 512, 3), dtype=np.uint8) + image = Image.fromarray(test_image) + mask = None + image = self.merge_image_and_mask(image, mask) + with torch.inference_mode(False): + with torch.no_grad(): + for _ in tqdm(range(num_warmup_iters)): + pointmap_dict = recursive_clone(self.compute_pointmap(image)) + pointmap = pointmap_dict["pointmap"] + + ss_input_dict = self.preprocess_image( + image, self.ss_preprocessor, pointmap=pointmap + ) + ss_return_dict = self.sample_sparse_structure( + ss_input_dict, inference_steps=None + ) + + _ = self.run_layout_model( + ss_input_dict, + ss_return_dict, + inference_steps=None, + ) + + def _preprocess_image_and_mask_pointmap( + self, rgb_image, mask_image, pointmap, img_mask_pointmap_joint_transform + ): + for trans in img_mask_pointmap_joint_transform: + rgb_image, mask_image, pointmap = trans( + rgb_image, mask_image, pointmap=pointmap + ) + return rgb_image, mask_image, pointmap + + def preprocess_image( + self, + image: Union[Image.Image, np.ndarray], + preprocessor, + pointmap=None, + ) -> torch.Tensor: + # canonical type is numpy + if not isinstance(image, np.ndarray): + image = np.array(image) + + assert image.ndim == 3 # no batch dimension as of now + assert image.shape[-1] == 4 # rgba format + assert image.dtype == np.uint8 # [0,255] range + + rgba_image = torch.from_numpy(self.image_to_float(image)) + rgba_image = rgba_image.permute(2, 0, 1).contiguous() + rgb_image = rgba_image[:3] + rgb_image_mask = get_mask(rgba_image, None, "ALPHA_CHANNEL") + + preprocessor_return_dict = preprocessor._process_image_mask_pointmap_mess( + rgb_image, rgb_image_mask, pointmap + ) + + # Put in a for loop? + _item = preprocessor_return_dict + item = { + "mask": _item["mask"][None].to(self.device), + "image": _item["image"][None].to(self.device), + "rgb_image": _item["rgb_image"][None].to(self.device), + "rgb_image_mask": _item["rgb_image_mask"][None].to(self.device), + } + + if pointmap is not None and preprocessor.pointmap_transform != (None,): + item["pointmap"] = _item["pointmap"][None].to(self.device) + item["rgb_pointmap"] = _item["rgb_pointmap"][None].to(self.device) + item["pointmap_scale"] = _item["pointmap_scale"][None].to(self.device) + item["pointmap_shift"] = _item["pointmap_shift"][None].to(self.device) + item["rgb_pointmap_scale"] = _item["rgb_pointmap_scale"][None].to(self.device) + item["rgb_pointmap_shift"] = _item["rgb_pointmap_shift"][None].to(self.device) + + return item + + def _clip_pointmap(self, pointmap: torch.Tensor, mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + if self.clip_pointmap_beyond_scale is None: + return pointmap + + pointmap_size = (pointmap.shape[1], pointmap.shape[2]) + if mask.dim() == 2: + mask = mask.unsqueeze(0) + mask_resized = torchvision.transforms.functional.resize( + mask, pointmap_size, + interpolation=torchvision.transforms.InterpolationMode.NEAREST + ).squeeze(0) + + pointmap_flat = pointmap.reshape(3, -1) + # Get valid points from the mask + mask_bool = mask_resized.reshape(-1) > 0.5 + mask_points = pointmap_flat[:, mask_bool] + mask_distance = mask_points.nanmedian(dim=-1).values[-1] + logger.info(f"mask_distance: {mask_distance}") + pointmap_clipped_flat = torch.where( + pointmap_flat[2, ...].abs() > self.clip_pointmap_beyond_scale * mask_distance, + torch.full_like(pointmap_flat, float('nan')), + pointmap_flat + ) + pointmap_clipped = pointmap_clipped_flat.reshape(pointmap.shape) + return pointmap_clipped + + + + def compute_pointmap(self, image, pointmap=None): + loaded_image = self.image_to_float(image) + loaded_image = torch.from_numpy(loaded_image) + loaded_mask = loaded_image[..., -1] + loaded_image = loaded_image.permute(2, 0, 1).contiguous()[:3] + + if pointmap is None: + with torch.no_grad(): + with torch.autocast(device_type="cuda", dtype=self.dtype): + output = self.depth_model(loaded_image) + pointmaps = output["pointmaps"] + camera_convention_transform = ( + Transform3d() + .rotate(camera_to_pytorch3d_camera(device=self.device).rotation) + .to(self.device) + ) + points_tensor = camera_convention_transform.transform_points(pointmaps) + intrinsics = output.get("intrinsics", None) + else: + output = {} + points_tensor = pointmap.to(self.device) + if loaded_image.shape != points_tensor.shape: + # Interpolate points_tensor to match loaded_image size + # loaded_image has shape [3, H, W], we need H and W + points_tensor = torch.nn.functional.interpolate( + points_tensor.permute(2, 0, 1).unsqueeze(0), + size=(loaded_image.shape[1], loaded_image.shape[2]), + mode="nearest", + ).squeeze(0).permute(1, 2, 0) + intrinsics = None + + points_tensor = points_tensor.permute(2, 0, 1) + points_tensor = self._clip_pointmap(points_tensor, loaded_mask) + + # Prepare the point map tensor + point_map_tensor = { + "pointmap": points_tensor, + "pts_color": loaded_image, + } + + # If depth model doesn't provide intrinsics, infer them + if intrinsics is None: + intrinsics_result = infer_intrinsics_from_pointmap( + points_tensor.permute(1, 2, 0), device=self.device + ) + point_map_tensor["intrinsics"] = intrinsics_result["intrinsics"] + + return point_map_tensor + + def run_post_optimization(self, mesh_glb, intrinsics, pose_dict, layout_input_dict): + intrinsics = intrinsics.clone() + fx, fy = intrinsics[0, 0], intrinsics[1, 1] + re_focal = min(fx, fy) + intrinsics[0, 0], intrinsics[1, 1] = re_focal, re_focal + revised_quat, revised_t, revised_scale, final_iou, _, _ = ( + self.layout_post_optimization_method( + mesh_glb, + pose_dict["rotation"], + pose_dict["translation"], + pose_dict["scale"], + layout_input_dict["rgb_image_mask"][0, 0], + layout_input_dict["rgb_pointmap"][0].permute(1, 2, 0), + intrinsics, + min_size=518, + ) + ) + return { + "rotation": revised_quat, + "translation": revised_t, + "scale": revised_scale, + "iou": final_iou, + } + + + def run( + self, + image: Union[None, Image.Image, np.ndarray], + mask: Union[None, Image.Image, np.ndarray] = None, + seed: Optional[int] = None, + stage1_only=False, + with_mesh_postprocess=True, + with_texture_baking=True, + with_layout_postprocess=True, + use_vertex_color=False, + stage1_inference_steps=None, + stage2_inference_steps=None, + use_stage1_distillation=False, + use_stage2_distillation=False, + pointmap=None, + decode_formats=None, + estimate_plane=False, + ) -> dict: + image = self.merge_image_and_mask(image, mask) + with self.device: + pointmap_dict = self.compute_pointmap(image, pointmap) + pointmap = pointmap_dict["pointmap"] + pts = type(self)._down_sample_img(pointmap) + pts_colors = type(self)._down_sample_img(pointmap_dict["pts_color"]) + + if estimate_plane: + return self.estimate_plane(pointmap_dict, image) + + ss_input_dict = self.preprocess_image( + image, self.ss_preprocessor, pointmap=pointmap + ) + + slat_input_dict = self.preprocess_image(image, self.slat_preprocessor) + if seed is not None: + torch.manual_seed(seed) + ss_return_dict = self.sample_sparse_structure( + ss_input_dict, + inference_steps=stage1_inference_steps, + use_distillation=use_stage1_distillation, + ) + + # We could probably use the decoder from the models themselves + pointmap_scale = ss_input_dict.get("pointmap_scale", None) + pointmap_shift = ss_input_dict.get("pointmap_shift", None) + ss_return_dict.update( + self.pose_decoder( + ss_return_dict, + scene_scale=pointmap_scale, + scene_shift=pointmap_shift, + ) + ) + + logger.info(f"Rescaling scale by {ss_return_dict['downsample_factor']} after downsampling") + ss_return_dict["scale"] = ss_return_dict["scale"] * ss_return_dict["downsample_factor"] + + if stage1_only: + logger.info("Finished!") + ss_return_dict["voxel"] = ss_return_dict["coords"][:, 1:] / 64 - 0.5 + return { + **ss_return_dict, + "pointmap": pts.cpu().permute((1, 2, 0)), # HxWx3 + "pointmap_colors": pts_colors.cpu().permute((1, 2, 0)), # HxWx3 + } + # return ss_return_dict + + coords = ss_return_dict["coords"] + slat = self.sample_slat( + slat_input_dict, + coords, + inference_steps=stage2_inference_steps, + use_distillation=use_stage2_distillation, + ) + outputs = self.decode_slat( + slat, self.decode_formats if decode_formats is None else decode_formats + ) + outputs = self.postprocess_slat_output( + outputs, with_mesh_postprocess, with_texture_baking, use_vertex_color + ) + glb = outputs.get("glb", None) + + try: + if ( + with_layout_postprocess + and self.layout_post_optimization_method is not None + ): + assert glb is not None, "require mesh to run postprocessing" + logger.info("Running layout post optimization method...") + postprocessed_pose = self.run_post_optimization( + deepcopy(glb), + pointmap_dict["intrinsics"], + ss_return_dict, + ss_input_dict, + ) + ss_return_dict.update(postprocessed_pose) + except Exception as e: + logger.error( + f"Error during layout post optimization: {e}", exc_info=True + ) + + # glb.export("sample.glb") + logger.info("Finished!") + + return { + **ss_return_dict, + **outputs, + "pointmap": pts.cpu().permute((1, 2, 0)), # HxWx3 + "pointmap_colors": pts_colors.cpu().permute((1, 2, 0)), # HxWx3 + } + + @staticmethod + def _down_sample_img(img_3chw: torch.Tensor): + # img_3chw: (3, H, W) + x = img_3chw.unsqueeze(0) + if x.dtype == torch.uint8: + x = x.float() / 255.0 + max_side = max(x.shape[2], x.shape[3]) + scale_factor = 1.0 + + # heuristics + if max_side > 3800: + scale_factor = 0.125 + if max_side > 1900: + scale_factor = 0.25 + elif max_side > 1200: + scale_factor = 0.5 + + x = torch.nn.functional.interpolate( + x, + scale_factor=(scale_factor, scale_factor), + mode="bilinear", + align_corners=False, + antialias=True, + ) # -> (1, 3, H/4, W/4) + return x.squeeze(0) + + def estimate_plane(self, pointmap_dict, image, ground_area_threshold=0.25, min_points=100): + assert image.shape[-1] == 4 # rgba format + # Extract mask from alpha channel + floor_mask = type(self)._down_sample_img(torch.from_numpy(image[..., -1]).float().unsqueeze(0))[0] > 0.5 + pts = type(self)._down_sample_img(pointmap_dict["pointmap"]) + + # Get all points in 3D space (H, W, 3) + pts_hwc = pts.cpu().permute((1, 2, 0)) + + valid_mask_points = floor_mask.cpu().numpy() + # Extract points that fall within the mask + if valid_mask_points.any(): + # Get points within mask + masked_points = pts_hwc[valid_mask_points] + # Filter out invalid points (zero points from depth estimation failures) + valid_points_mask = torch.norm(masked_points, dim=-1) > 1e-6 + valid_points = masked_points[valid_points_mask] + points = valid_points.numpy() + else: + points = np.array([]).reshape(0, 3) + + # Calculate area coverage and check num of points + overlap_area = estimate_plane_area(floor_mask) + has_enough_points = len(points) >= min_points + + logger.info(f"Plane estimation: {len(points)} points, {overlap_area:.3f} area coverage") + if overlap_area > ground_area_threshold and has_enough_points: + try: + mesh = o3d_plane_estimation(points) + logger.info("Successfully estimated plane mesh") + except Exception as e: + logger.error(f"Failed to estimate plane: {e}") + mesh = None + else: + logger.info(f"Skipping plane estimation: area={overlap_area:.3f}, points={len(points)}") + mesh = None + + return { + "glb": mesh, + "translation": torch.tensor([[0.0, 0.0, 0.0]]), + "scale": torch.tensor([[1.0, 1.0, 1.0]]), + "rotation": torch.tensor([[1.0, 0.0, 0.0, 0.0]]), + } diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/pipeline/inference_utils.py b/thirdparty/sam3d/sam3d/sam3d_objects/pipeline/inference_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c7adcd1e61e3daf3a8a1814acf7988e249b4b9b0 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/pipeline/inference_utils.py @@ -0,0 +1,864 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import torch +import numpy as np +import open3d as o3d +import trimesh +from pytorch3d.structures import Meshes +from pytorch3d.transforms import quaternion_to_matrix, Transform3d, matrix_to_quaternion +from sam3d_objects.data.dataset.tdfy.transforms_3d import compose_transform, decompose_transform +from sam3d_objects.data.dataset.tdfy.pose_target import PoseTargetConverter +from loguru import logger +from sam3d_objects.pipeline.layout_post_optimization_utils import ( + run_ICP, + compute_iou, + set_seed, + apply_transform, + get_mesh, + get_mask_renderer, + run_alignment, + run_render_compare, + check_occlusion, +) + + +SLAT_STD = torch.tensor( + [ + 2.377650737762451, + 2.386378288269043, + 2.124418020248413, + 2.1748552322387695, + 2.663944721221924, + 2.371192216873169, + 2.6217446327209473, + 2.684523105621338, + ] +) +SLAT_MEAN = torch.tensor( + [ + -2.1687545776367188, + -0.004347046371549368, + -0.13352349400520325, + -0.08418072760105133, + -0.5271206498146057, + 0.7238689064979553, + -1.1414450407028198, + 1.2039363384246826, + ] +) + +ROTATION_6D_MEAN = torch.tensor( + [ + -0.06366084883674913, + 0.008438224692279752, + 0.00017084786438302483, + 0.0007126610473540038, + -0.0030916726538816417, + 0.5166093753457688, + ] +) +ROTATION_6D_STD = torch.tensor( + [ + 0.6656971967514863, + 0.6787012271867754, + 0.30345010594844524, + 0.4394504420678794, + 0.39817973931717104, + 0.6176286868761914, + ] +) + +def layout_post_optimization( + Mesh, + Quaternion, + Translation, + Scale, + Mask, + Point_Map, + Intrinsics, + Enable_shape_ICP=True, + Enable_rendering_optimization=True, + min_size=512, + device=None, +): + + set_seed(100) + if device is None: + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + # init transform and process mesh + Rotation = quaternion_to_matrix(Quaternion.squeeze(1)) + center = Translation[0].clone() + tfm_ori = compose_transform(scale=Scale, rotation=Rotation, translation=Translation) + mesh, faces_idx, textures = get_mesh(Mesh, tfm_ori, device) + + # get mask and renderer + mask, renderer = get_mask_renderer(Mask, min_size, Intrinsics, device) + + # check occlusion + if check_occlusion(mask[0, 0].cpu().numpy(), Point_Map.cpu().numpy()): + return ( + Quaternion, + Translation, + Scale, + -1.0, + False, + False, + ) + + # Step 1: Manual Alignment + source_points, target_points, center, tfm1, mesh, ori_iou, final_iou, flag_notgt = ( + run_alignment( + Point_Map, mask, mesh, center, faces_idx, textures, renderer, device + ) + ) + + # return original layout if no target points. + if flag_notgt: + return ( + Quaternion, + Translation, + Scale, + -1.0, + False, + False, + ) + + # Step 2: Shape ICP + if Enable_shape_ICP: + Flag_ICP = True + points_aligned_icp, transformation = run_ICP( + mesh, source_points, target_points, threshold=0.05 + ) + mesh_ICP = Meshes( + verts=[points_aligned_icp], faces=[faces_idx], textures=textures + ) + rendered = renderer(mesh_ICP) + ori_iou_shapeICP = compute_iou( + rendered[..., 3][0][None, None], mask, threshold=0.5 + ) + # determine whether accept ICP + if ori_iou_shapeICP > ori_iou: + mesh = mesh_ICP + final_iou = ori_iou_shapeICP.cpu().item() + T_o3d = torch.tensor(transformation, dtype=torch.float32, device=device) + T_o3d = T_o3d.T + A = T_o3d[:3, :3] + t = T_o3d[3, :3] + scale = A.norm(dim=1) + R = A / scale[:, None] + center = ((center[None] * scale) @ R + t)[0] # transform center + tfm2 = ( + Transform3d(device=device) + .scale(scale[None]) + .rotate(R[None]) + .translate(t[None]) + ) + else: + Flag_ICP = False + scale_2, translation_2 = torch.tensor(1).to(device), torch.zeros([3]).to( + device + ) + tfm2 = ( + Transform3d(device=device) + .scale(scale_2.expand(3)[None]) + .translate(translation_2[None]) + ) + else: + Flag_ICP = False + scale_2, translation_2 = torch.tensor(1).to(device), torch.zeros([3]).to(device) + tfm2 = ( + Transform3d(device=device) + .scale(scale_2.expand(3)[None]) + .translate(translation_2[None]) + ) + + # Step 3: Render-and-Compare + if not Enable_rendering_optimization: + Flag_optim = False + tfm = tfm_ori.compose(tfm1).compose(tfm2) + else: + quat, translation, scale, R = run_render_compare( + mesh, center, renderer, mask, device + ) + with torch.no_grad(): + transformed = apply_transform(mesh, center, quat, translation, scale) + rendered = renderer(transformed) + optimized_iou = compute_iou( + rendered[..., 3][0][None, None], mask, threshold=0.5 + ) + # Criterior to use layout optimization + if optimized_iou < 0.5 or optimized_iou <= ori_iou: + Flag_optim = False + tfm = tfm_ori # reject manual alignment and ICP as well. + # tfm = tfm_ori.compose(tfm1).compose(tfm2) # only reject render-compare but keep manual alignment and ICP. + else: + Flag_optim = True + final_iou = optimized_iou.detach().cpu().item() + tfm3 = ( + Transform3d(device=device) + .translate(-center[None]) # move to center + .scale(scale.expand(3)[None]) + .rotate(R.T[None]) + .translate(center[None]) # move back + .translate(translation[None]) + ) + tfm = tfm_ori.compose(tfm1).compose(tfm2).compose(tfm3) + + M = tfm.get_matrix()[0] + T_final = M[3, :3][None] + A = M[:3, :3] + scale_final = A.norm(dim=1)[None] + R_final = A / scale_final[:, None] + quat_final = matrix_to_quaternion(R_final) + + return ( + quat_final, + T_final, + scale_final, + round(float(final_iou), 4), + Flag_ICP, + Flag_optim, + ) + + +def pose_decoder( + pose_target_convention, +): + def decode(model_output_dict, scene_scale=None, scene_shift=None): + x = model_output_dict + + # BEGIN: copied from generative.py + key_mapping = { + "shape": "x_shape_latent", + "quaternion": "x_instance_rotation", + "6drotation": "x_instance_rotation_6d", + "6drotation_normalized": "x_instance_rotation_6d_normalized", + "translation": "x_instance_translation", + "scale": "x_instance_scale", + "translation_scale": "x_translation_scale", + } + + # Decodes for metrics + pose_target_dict = {} + for k, v in x.items(): + pose_target_dict[key_mapping.get(k, k)] = v + + # TODO: Hao & Bowen please do clean this up! + # Convert 6D rotation to quaternion if needed + if ( + "x_instance_rotation_6d" in pose_target_dict + or "x_instance_rotation_6d_normalized" in pose_target_dict + ): + # Extract the two 3D vectors + if "x_instance_rotation_6d_normalized" in pose_target_dict: + rot_6d = pose_target_dict[ + "x_instance_rotation_6d_normalized" + ] * ROTATION_6D_STD.to( + pose_target_dict["x_instance_rotation_6d_normalized"].device + ) + ROTATION_6D_MEAN.to( + pose_target_dict["x_instance_rotation_6d_normalized"].device + ) + else: + rot_6d = pose_target_dict["x_instance_rotation_6d"] + a1 = rot_6d[..., 0:3] + a2 = rot_6d[..., 3:6] + + # Normalize first vector + b1 = torch.nn.functional.normalize(a1, dim=-1) + + # Make second vector orthogonal to first + b2 = a2 - torch.sum(b1 * a2, dim=-1, keepdim=True) * b1 + b2 = torch.nn.functional.normalize(b2, dim=-1) + + # Compute third vector as cross product + b3 = torch.cross(b1, b2, dim=-1) + + # Stack to create rotation matrix + rotation_matrix = torch.stack([b1, b2, b3], dim=-1) + + # Convert to quaternion + quaternion = matrix_to_quaternion(rotation_matrix) + pose_target_dict["x_instance_rotation"] = quaternion + + if "x_instance_scale" in pose_target_dict: + pose_target_dict["x_instance_scale"] = torch.exp( + pose_target_dict["x_instance_scale"] + ) + + if "x_translation_scale" in pose_target_dict: + pose_target_dict["x_translation_scale"] = torch.exp( + pose_target_dict["x_translation_scale"] + ) + + pose_target_dict["pose_target_convention"] = [pose_target_convention] * x[ + "shape" + ].shape[0] + # END: copied from generative.py + + # Fake pointmap moments + device = x["shape"].device + _scene_scale = ( + scene_scale if scene_scale is not None else torch.tensor(1.0, device=device) + ) + _scene_shift = ( + scene_shift + if scene_shift is not None + else torch.tensor([[0, 0, 0]], device=device) + ) + pose_target_dict["x_scene_scale"] = _scene_scale + pose_target_dict["x_scene_center"] = _scene_shift + + # Convert to instance pose + pose_instance_dict = PoseTargetConverter.dicts_pose_target_to_instance_pose( + pose_target_convention=pose_target_convention, + x_instance_scale=pose_target_dict["x_instance_scale"], + x_instance_translation=pose_target_dict["x_instance_translation"], + x_instance_rotation=pose_target_dict["x_instance_rotation"], + x_translation_scale=pose_target_dict["x_translation_scale"], + x_scene_scale=pose_target_dict["x_scene_scale"], + x_scene_center=pose_target_dict["x_scene_center"], + ) + return { + "translation": pose_instance_dict["instance_position_l2c"].squeeze(0), + "rotation": pose_instance_dict["instance_quaternion_l2c"].squeeze(0), + "scale": pose_instance_dict["instance_scale_l2c"].squeeze(0).mean(-1, keepdim=True).expand(1,3), + } + + return decode + +def zero_prediction_decoder(): + def decode(model_output_dict, scene_scale=None, scene_shift=None): + import copy + from loguru import logger + _pose_decoder = pose_decoder("ScaleShiftInvariant") + model_output_dict = copy.deepcopy(model_output_dict) + logger.warning("Overwriting predictions to zero prediction") + model_output_dict["translation"] = torch.zeros_like(model_output_dict["translation"]) + model_output_dict["translation_scale"] = torch.zeros_like(model_output_dict["translation_scale"]) + model_output_dict["scale"] = torch.zeros_like(model_output_dict["scale"]) + 1.337 # Empirical average on R3 + return _pose_decoder(model_output_dict, scene_scale, scene_shift) + + return decode + + +def get_default_pose_decoder(): + def decode(model_output_dict, **kwargs): + return {} + + return decode + + +POSE_DECODERS = { + "default": get_default_pose_decoder(), + "ApparentSize": pose_decoder("ApparentSize"), + "DisparitySpace": pose_decoder("DisparitySpace"), + "ScaleShiftInvariant": pose_decoder("ScaleShiftInvariant"), + "ZeroPredictionScaleShiftInvariant": zero_prediction_decoder(), +} + + +def get_pose_decoder(name): + if name not in POSE_DECODERS: + raise NotImplementedError + + return POSE_DECODERS[name] + + +def prune_sparse_structure( + coord_batch, + max_neighbor_axes_dist=1, +): + coords, batch = coord_batch[:, 1:], coord_batch[:, 0].unsqueeze(-1) + device = coords.device + # 1) shift coords so minimum is zero + min_xyz = coords.min(0)[0] + coords0 = coords - min_xyz + # 2) build occupancy grid + max_xyz = coords0.max(0)[0] + 1 # size in each dim + D, H, W = max_xyz.tolist() + # shape (1,1,D,H,W) + occ = torch.zeros((1, 1, D, H, W), dtype=torch.uint8, device=device) + x, y, z = coords0.unbind(1) + occ[0, 0, x, y, z] = 1 + # 3) 3×3×3 convolution to count each voxel + neighbors + kernel = torch.ones( + ( + 1, + 1, + 2 * max_neighbor_axes_dist + 1, + 2 * max_neighbor_axes_dist + 1, + 2 * max_neighbor_axes_dist + 1, + ), + dtype=torch.uint8, + device=device, + ) + # pad so output is same size + pad = max_neighbor_axes_dist + counts = torch.nn.functional.conv3d(occ.float(), kernel.float(), padding=pad) + # interior voxels have count == (2*max_neighbor_axes_dist+1)**3 + full_count = (2 * max_neighbor_axes_dist + 1) ** 3 + # 4) lookup counts at each original coord + counts_at_pts = counts[0, 0, x, y, z] # (N,) + is_surface = counts_at_pts < full_count + # 5) return filtered batch+coords (shift back if you want original coords) + kept = is_surface.nonzero(as_tuple=False).squeeze(1) + out_batch = batch[kept] + out_coords = coords[kept] + coords = torch.cat([out_batch, out_coords], dim=1) + + return torch.cat([out_batch, out_coords], dim=1) + + +def downsample_sparse_structure( + coord_batch, + max_coords=42000, + downsample_factor=2, +): + """ + Downsample sparse structure coordinates when there are more than max_coords. + + Downsamples by rescaling coordinates, effectively shrinking the grid while preserving + the structure. The downsampled grid is centered in the original space. + + Args: + coord_batch: tensor of shape (N, 4) where [:, 0] is batch index and [:, 1:] are coords + max_coords: maximum number of coordinates to keep + 42000 should be safe number. Calculation: max(int32) / (64*768) ~= 43691 + Only needed for mesh decoding. + downsample_factor: factor by which to downsample (e.g., 2 means half resolution) + + Returns: + Downsampled coord_batch with coordinates rescaled if downsampling is needed + """ + if coord_batch.shape[0] <= max_coords: + return coord_batch, 1 + + # Extract coordinates and batch indices + coords = coord_batch[:, 1:].float() # Shape: (N, 3), convert to float for scaling + batch_indices = coord_batch[:, 0:1] # Shape: (N, 1) + + # Find the actual coordinate bounds + coords_min = coords.min(dim=0)[0] # Shape: (3,) + coords_max = coords.max(dim=0)[0] # Shape: (3,) + original_size = coords_max - coords_min + 1 # Add 1 since coordinates are discrete + + # Calculate target size after downsampling + target_size = original_size / downsample_factor + + # Calculate the offset to center the downsampled grid + offset = (original_size - target_size) / 2 + target_min = coords_min + offset + target_max = coords_min + offset + target_size - 1 + + # Normalize coordinates to [0, 1] within their actual range + coords_normalized = (coords - coords_min) / (coords_max - coords_min) + + # Scale to the target range + coords_rescaled = coords_normalized * (target_size - 1) + target_min + + # Round to integers to get discrete grid coordinates + coords_rescaled = torch.round(coords_rescaled).int() + + # Clamp to ensure we stay within bounds + coords_rescaled = torch.clamp(coords_rescaled, target_min.int(), target_max.int()) + + # Remove duplicates that may have been created by the downsampling + # Concatenate batch and coords for duplicate removal + combined = torch.cat([batch_indices, coords_rescaled], dim=1) + unique_combined = torch.unique(combined, dim=0) + + # If still too many after deduplication, randomly subsample + if unique_combined.shape[0] > max_coords: + indices = torch.randperm(unique_combined.shape[0], device=coord_batch.device)[ + :max_coords + ] + unique_combined = unique_combined[indices] + + return unique_combined.int(), downsample_factor + + +def normalize_mesh_verts(verts): + vmin = verts.min(axis=0) + vmax = verts.max(axis=0) + center = (vmax + vmin) / 2.0 + extent = vmax - vmin # largest side length + max_extent = np.max(extent) + if max_extent == 0: + vertices = verts - center + scale = 1 + else: + scale = 1.0 / max_extent + vertices = (verts - center) * scale + return vertices, scale, center + + +def voxelize_mesh(mesh, resolution=64): + verts = np.asarray(mesh.vertices) + # rotate mesh (from z-up to y-up) + verts = verts @ np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]]).T + # normalize vertices + # skip vertices to avoid losing points, likely already normalized + if np.abs(verts.min() + 0.5) < 1e-3 and np.abs(verts.max() - 0.5) < 1e-3: + vertices, scale, center = verts, None, None + else: + vertices, scale, center = normalize_mesh_verts(verts) + + vertices = np.clip(vertices, -0.5 + 1e-6, 0.5 - 1e-6) + mesh.vertices = o3d.utility.Vector3dVector(vertices) + voxel_grid = o3d.geometry.VoxelGrid.create_from_triangle_mesh_within_bounds( + mesh, + voxel_size=1 / 64, + min_bound=(-0.5, -0.5, -0.5), + max_bound=(0.5, 0.5, 0.5), + ) + vertices = np.array([voxel.grid_index for voxel in voxel_grid.get_voxels()]) + vertices = (vertices + 0.5) / 64 - 0.5 + coords = ((torch.tensor(vertices) + 0.5) * resolution).int().contiguous() + ss = torch.zeros(1, resolution, resolution, resolution, dtype=torch.long) + ss[:, coords[:, 0], coords[:, 1], coords[:, 2]] = 1 + return ss, scale, center + + +def preprocess_mesh(mesh: trimesh.Trimesh): + verts = mesh.vertices + if np.abs(verts.min() + 0.5) < 1e-3 and np.abs(verts.max() - 0.5) < 1e-3: + return mesh + vertices, _, _ = normalize_mesh_verts(verts) + mesh.vertices = vertices + return mesh + + +def trimesh2o3d_mesh(trimesh_mesh): + verts = np.asarray(trimesh_mesh.vertices) + faces = np.asarray(trimesh_mesh.faces) + return o3d.geometry.TriangleMesh( + o3d.utility.Vector3dVector(verts), o3d.utility.Vector3iVector(faces) + ) + + +def update_layout(pred_t, pred_s, pred_quat, center, scale, to_halo=True): + if center is None and not to_halo: + return pred_t, pred_s, pred_quat + pred_transform = compose_transform( + pred_s, quaternion_to_matrix(pred_quat[0]), pred_t + ) + if center is None: + comb_transform = pred_transform + else: + norm_transform = compose_transform( + scale * torch.ones_like(pred_t), + torch.eye(3, dtype=pred_t.dtype).to(pred_t.device)[None], + scale * -torch.tensor(center, dtype=pred_t.dtype).to(pred_t.device)[None], + ) + comb_transform = norm_transform.compose(pred_transform) + comb_transform = convert_to_halo(comb_transform, pred_t.device, pred_t.dtype) + decomposed = decompose_transform(comb_transform) + quat = matrix_to_quaternion(decomposed.rotation) + return decomposed.translation, decomposed.scale, quat + + +def convert_to_halo(pred_transform, device, dtype): + on_mesh_transform = Transform3d(dtype=dtype, device=device).rotate( + torch.tensor( + [ + [1, 0, 0], + [0, 0, 1], + [0, -1, 0], + ], + dtype=dtype, + ) + ) + on_pm_transform = Transform3d(dtype=dtype, device=device).rotate( + torch.tensor( + [ + [-1, 0, 0], + [0, -1, 0], + [0, 0, 1], + ], + dtype=dtype, + ) + ) + return on_mesh_transform.compose(pred_transform).compose(on_pm_transform) + + +def quat_wxyz_to_euler_XYZ(q: torch.Tensor) -> torch.Tensor: + """ + Convert PyTorch3D quaternions (w,x,y,z) to SciPy-style Euler angles + with sequence 'XYZ' (extrinsic, radians). Works with batch dims. + + Args: + q: (..., 4) tensor in w,x,y,z order. Doesn't need to be normalized. + Returns: + angles: (..., 3) tensor [alpha_X, beta_Y, gamma_Z] in radians. + """ + q = q / q.norm(dim=-1, keepdim=True) # normalize + R = quaternion_to_matrix(q) # (..., 3, 3) + R = R.transpose(-1, -2) + + r00 = R[..., 0, 0] + r10 = R[..., 1, 0] + r20 = R[..., 2, 0] + r21 = R[..., 2, 1] + r22 = R[..., 2, 2] + + # For extrinsic XYZ (R = Rz(gamma) @ Ry(beta) @ Rx(alpha)): + # beta = atan2(-r20, sqrt(r00^2 + r10^2)) + # alpha = atan2(r21, r22) + # gamma = atan2(r10, r00) + eps = torch.finfo(R.dtype).eps + beta = torch.atan2(-r20, torch.clamp((r00 * r00 + r10 * r10).sqrt(), min=eps)) + alpha = torch.atan2(r21, r22) + gamma = torch.atan2(r10, r00) + + return -torch.stack((alpha, beta, gamma), dim=-1) + + +def format_to_halo(layout_output): + json_out = {} + quaternion = layout_output["quaternion"][0, 0] + translation = layout_output["translation"][0] + scale = list(layout_output["scale"][0]) + + euler = quat_wxyz_to_euler_XYZ(quaternion) + json_out["roll"] = float(euler[0]) + json_out["pitch"] = float(euler[1]) + json_out["yaw"] = float(euler[2]) + json_out["pred_scale"] = [float(s) for s in scale] + rot_matrix = quaternion_to_matrix(quaternion) + pred_transform = torch.eye(4, dtype=quaternion.dtype).to(quaternion.device) + pred_transform[:3, :3] = rot_matrix + pred_transform[:3, 3] = translation + pred_transform_list = [ + [float(t) for t in trans_row] for trans_row in pred_transform + ] + json_out["pred_transform"] = pred_transform_list + return json_out + + +def json_to_halo_payloads(target_data): + pred_transform = target_data["pred_transform"] + pred_scale = target_data["pred_scale"] + roll = target_data.get("roll", 0) + pitch = target_data.get("pitch", 0) + yaw = target_data.get("yaw", 0) + # Update positions, rotation, and scale in the payload + item_attachments = {} + item_attachments["positions"] = { + "x": pred_transform[0][3], + "y": pred_transform[1][3], + "z": pred_transform[2][3] - 1, # Adjust for Halo design + } + item_attachments["rotation"] = {"x": roll, "y": pitch, "z": yaw} + item_attachments["scale"] = { + "x": pred_scale[0], + "y": pred_scale[1], + "z": pred_scale[2], + } + return item_attachments + + +def o3d_plane_estimation(points): + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(points) + plane_model, inliers = pcd.segment_plane(0.02, 3, 1000) + + [a, b, c, d] = plane_model + logger.info(f"Plane equation: {a:.2f}x + {b:.2f}y + {c:.2f}z + {d:.2f} = 0") + + # Get the inlier points from RANSAC + inlier_points = np.asarray(pcd.points)[inliers] + + # Adaptive flying point removal based on Z-range + z_range = np.max(inlier_points[:, 2]) - np.min(inlier_points[:, 2]) + if z_range > 6.0: # Large range - likely flying points + thresh = 0.90 # Remove 10% + elif z_range > 2.0: # Moderate range + thresh = 0.93 # Remove 7% + else: # Small range - clean + thresh = 0.95 # Remove 5% + + depth_quantile = np.quantile(inlier_points[:, 2], thresh) + clean_points = inlier_points[inlier_points[:, 2] <= depth_quantile] + + logger.info(f"Flying point removal: {len(inlier_points)} -> {len(clean_points)} points (z_range: {z_range:.2f}m, thresh: {thresh})") + logger.info(f"Clean points Z range: [{clean_points[:, 2].min():.3f}, {clean_points[:, 2].max():.3f}]") + + # Get the normal vector of the plane + normal = np.array([a, b, c]) + normal = normal / np.linalg.norm(normal) + + # Create two orthogonal vectors in the plane using camera-aware approach + # Use Z-axis as primary tangent (depth direction in camera coords) + # This helps align one plane axis with the camera's depth direction + if abs(normal[2]) < 0.9: # Use Z-axis if normal isn't too close to Z + tangent = np.array([0, 0, 1]) + else: + tangent = np.array([1, 0, 0]) # Use X-axis otherwise + + v1 = np.cross(normal, tangent) + v1 = v1 / np.linalg.norm(v1) + v2 = np.cross(normal, v1) + v2 = v2 / np.linalg.norm(v2) # Explicit normalization for numerical stability + + # Ensure consistent right-handed coordinate system + if np.dot(np.cross(v1, v2), normal) < 0: + v2 = -v2 + + logger.info(f"Plane basis vectors - v1: [{v1[0]:.3f}, {v1[1]:.3f}, {v1[2]:.3f}], v2: [{v2[0]:.3f}, {v2[1]:.3f}, {v2[2]:.3f}]") + + # Calculate centroid using bounding box center (more robust to density bias) + min_vals = np.min(clean_points, axis=0) + max_vals = np.max(clean_points, axis=0) + centroid = (min_vals + max_vals) / 2 + logger.info(f"Bbox centroid: [{centroid[0]:.3f}, {centroid[1]:.3f}, {centroid[2]:.3f}]") + + # Project clean points onto the plane's coordinate system + relative_points = clean_points - centroid + u_coords = np.dot(relative_points, v1) # coordinates along v1 direction + v_coords = np.dot(relative_points, v2) # coordinates along v2 direction + + # Since flying points are already removed, use minimal percentile filtering [0, 99] + u_min, u_max = np.percentile(u_coords, [0, 100]) + v_min, v_max = np.percentile(v_coords, [0, 100]) + + # Calculate extents + u_extent = u_max - u_min + v_extent = v_max - v_min + + # Ensure minimum size + u_extent = max(u_extent, 0.1) # minimum 10cm + v_extent = max(v_extent, 0.1) + logger.info(f"Plane size: {u_extent:.3f}m x {v_extent:.3f}m") + + # Calculate direction away from camera center (at origin [0,0,0]) + camera_pos = np.array([0, 0, 0]) # Camera at origin + camera_to_centroid = centroid - camera_pos # Direction from camera to plane center + camera_distance = np.linalg.norm(camera_to_centroid) + away_direction = camera_to_centroid / camera_distance + + # Project away direction onto the plane (remove component normal to plane) + away_in_plane = away_direction - np.dot(away_direction, normal) * normal + away_in_plane_norm = np.linalg.norm(away_in_plane) + + # Create plane coordinate system based on camera direction + if away_in_plane_norm > 1e-6: # Only if there's a meaningful in-plane component + # Define plane axes directly based on camera direction + away_axis = away_in_plane / away_in_plane_norm # Away from camera direction (in plane) + perp_axis = np.cross(normal, away_axis) # Perpendicular to away direction (in plane) + perp_axis = perp_axis / np.linalg.norm(perp_axis) + + logger.info(f"Camera-based plane axes:") + logger.info(f" Away axis: [{away_axis[0]:.3f}, {away_axis[1]:.3f}, {away_axis[2]:.3f}]") + logger.info(f" Perp axis: [{perp_axis[0]:.3f}, {perp_axis[1]:.3f}, {perp_axis[2]:.3f}]") + + # Project all points onto this camera-aligned coordinate system + relative_points = clean_points - centroid + away_coords = np.dot(relative_points, away_axis) # coordinates along away direction + perp_coords = np.dot(relative_points, perp_axis) # coordinates perpendicular to away + + # Calculate extents in camera-aligned system + away_min, away_max = np.percentile(away_coords, [0, 100]) + perp_min, perp_max = np.percentile(perp_coords, [0, 100]) + + away_extent = max(away_max - away_min, 0.1) + perp_extent = max(perp_max - perp_min, 0.1) + + # Asymmetric extension: 10% towards camera, 50% away from camera, 20% perpendicular both sides + away_extent_extended = away_extent * 1.6 # 60% larger in away direction (10% + 50%) + perp_extent_extended = perp_extent * 1.4 # 40% larger in perpendicular direction (20% each side) + + logger.info(f"Original extents: away={away_extent:.3f}m, perp={perp_extent:.3f}m") + logger.info(f"Extended extents: away={away_extent_extended:.3f}m, perp={perp_extent_extended:.3f}m") + + # Extension amounts for each direction + away_extension_near = away_extent * 0.1 # 10% extension towards camera (near side) + away_extension_far = away_extent * 0.5 # 50% extension away from camera (far side) + perp_extension = perp_extent * 0.2 # 20% extension on each perpendicular side + + logger.info(f"Extensions: near={away_extension_near:.3f}m, far={away_extension_far:.3f}m, perp={perp_extension:.3f}m per side") + logger.info(f"Extending plane asymmetrically: 10% towards camera, 50% away from camera, 20% perpendicular both sides") + + corners = [] + for da in [-1, 1]: + for dp in [-1, 1]: + # Asymmetric extension in away direction + if da == 1: # Away from camera side - extend by 50% + away_distance = away_extent/2 + away_extension_far + else: # Near camera side - extend by 10% + away_distance = da * (away_extent/2 + away_extension_near) + + # Extend perpendicular direction by 20% on both sides + perp_distance = dp * (perp_extent/2 + perp_extension) + + corner = (centroid + + away_distance * away_axis + + perp_distance * perp_axis) + corners.append(corner) + else: + # If plane is parallel to camera direction, use original v1/v2 system + logger.info("Plane parallel to camera direction, using original coordinate system") + corners = [] + for dx in [-1, 1]: + for dy in [-1, 1]: + corner = centroid + dx * (u_extent/2) * v1 + dy * (v_extent/2) * v2 + corners.append(corner) + corners = np.array(corners) + # Create a quad mesh using trimesh + # Define vertices (4 corners) + vertices = corners + # Define a single quad face (indices of the 4 vertices) + # Make sure the order is correct for proper orientation + faces = np.array([[0, 1, 3, 2]]) # quad face + # Create trimesh with quad faces + + # rotate mesh (from z-up to y-up) + vertices = vertices @ np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]]) + mesh = trimesh.Trimesh( + vertices=vertices, + faces=faces, + process=False # Important: prevents automatic triangulation + ) + # Optional: set face colors + mesh.visual.face_colors = [128, 128, 128, 255] # gray color (RGBA) + + return mesh + + +def estimate_plane_area(mask): + """ + Calculate the area covered by the mask's 2D bounding box as a fraction of total image area. + """ + if mask.numel() == 0: + return 0.0 + + # Find coordinates where mask > 0.5 (valid mask pixels) + valid_mask = mask > 0.5 + + # If no valid pixels, return 0 + if not torch.any(valid_mask): + return 0.0 + + # Get mask dimensions + H, W = mask.shape + total_area = H * W + + # Find bounding box coordinates + # Get row and column indices of valid pixels + valid_coords = torch.nonzero(valid_mask, as_tuple=False) # Returns [N, 2] array of [row, col] + + if valid_coords.size(0) == 0: + return 0.0 + + # Find min/max coordinates to form bounding box + min_row = torch.min(valid_coords[:, 0]).item() + max_row = torch.max(valid_coords[:, 0]).item() + min_col = torch.min(valid_coords[:, 1]).item() + max_col = torch.max(valid_coords[:, 1]).item() + + # Calculate bounding box dimensions + bbox_height = max_row - min_row + 1 + bbox_width = max_col - min_col + 1 + bbox_area = bbox_height * bbox_width + + # Return ratio of bounding box area to total image area + return bbox_area / total_area \ No newline at end of file diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/pipeline/layout_post_optimization_utils.py b/thirdparty/sam3d/sam3d/sam3d_objects/pipeline/layout_post_optimization_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6d680868461068db7555a3e9f5a8dbb844717d35 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/pipeline/layout_post_optimization_utils.py @@ -0,0 +1,445 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import os +import torch +import torch.nn.functional as F +import numpy as np +import cv2 +from pytorch3d.structures import Meshes +from pytorch3d.transforms import quaternion_to_matrix +from pytorch3d.renderer import ( + PerspectiveCameras, + RasterizationSettings, + MeshRenderer, + MeshRasterizer, + SoftSilhouetteShader, + BlendParams, + TexturesVertex, +) +from pytorch3d.transforms import quaternion_to_matrix, Transform3d +import random +import open3d as o3d +from scipy.ndimage import label, binary_dilation, binary_fill_holes, binary_erosion, minimum_filter + +def remove_small_regions(mask, min_area=100): + """ + Remove small disconnected regions (floating points) from the mask. + Keeps all regions with area >= min_area. + """ + labeled_mask, num_labels = label(mask) + cleaned = np.zeros_like(mask, dtype=bool) + for i in range(1, num_labels + 1): + region = (labeled_mask == i) + if region.sum() >= min_area: + cleaned |= region + return cleaned + +def is_near_image_border(mask, border_thickness=10): + """ + Check if the mask touches the image border within a given thickness. + """ + border_mask = np.zeros_like(mask, dtype=bool) + border_mask[:border_thickness, :] = True + border_mask[-border_thickness:, :] = True + border_mask[:, :border_thickness] = True + border_mask[:, -border_thickness:] = True + return np.any(mask & border_mask) + +def is_occluded_by_others(mask, point_map, dilation_iter=2, z_thresh=0.05, filter_size=3): + """ + Efficient occlusion detection using depth map and internal/external edges. + """ + z_map = point_map[..., 2] + if not np.any(mask): + return False + + # Create internal and external edge masks + eroded = binary_erosion(mask, iterations=dilation_iter) + dilated = binary_dilation(mask, iterations=dilation_iter) + + internal_edge = mask & (~eroded) + external_edge = dilated & (~mask) + + # Set invalid areas to +inf so they don't affect min-pooling + z_ext = np.where(external_edge, z_map, np.inf) + + # Apply minimum filter to get local min depth around internal edges + z_ext_min = minimum_filter(z_ext, size=filter_size, mode='constant', cval=np.inf) + + # Depth values at internal edge + z_int = np.where(internal_edge, z_map, np.nan) + + # Compare depth difference + diff = z_int - z_ext_min + occlusion_mask = (diff > z_thresh) & (~np.isnan(diff)) + + # return np.any(occlusion_mask) + return np.sum(occlusion_mask) > 10 + +def has_internal_occlusion(mask, min_hole_area=20): + """ + Check if the mask has internal holes or has been split into fragments. + This may indicate internal occlusion. + """ + # Check number of connected components + labeled, num_features = label(mask) + if num_features > 1: + return True # Mask is fragmented + + # Check for internal holes + filled = binary_fill_holes(mask) + holes = filled & (~mask) + return np.sum(holes) >= min_hole_area + +def check_occlusion(mask, point_map, + min_region_area=25, + border_thickness=5, + z_thresh=0.3, + min_hole_area=100): + """ + Main function to check different types of occlusion for a given mask and 3D point map. + """ + # clean mask by removing floating points + cleaned_mask = remove_small_regions(mask, min_area=min_region_area) + dilation_iter = 2 + filter_size = 2 * dilation_iter + 1 + + # run occlusion checks + return ( + is_near_image_border(cleaned_mask, border_thickness) + or is_occluded_by_others(cleaned_mask, point_map, dilation_iter, z_thresh, filter_size) + or has_internal_occlusion(cleaned_mask, min_hole_area) + ) + +def get_mesh(Mesh, tfm_ori, device): + mesh_vertices = Mesh.vertices.copy() + # rotate mesh (from z-up to y-up) + mesh_vertices = mesh_vertices @ np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]]).T + mesh_vertices = torch.from_numpy(mesh_vertices).float().cuda() + points_world = tfm_ori.transform_points(mesh_vertices.unsqueeze(0)) + Mesh.vertices = points_world[0].cpu().numpy() # pytorch3d, y-up, x left, z inwards. + verts, faces_idx = load_and_simplify_mesh(Mesh, device) + # === Add dummy white texture === + textures = TexturesVertex(verts_features=torch.ones_like(verts)[None]) # (1, V, 3) + mesh = Meshes(verts=[verts], faces=[faces_idx], textures=textures) + + return mesh, faces_idx, textures + + +def get_mask_renderer(Mask, min_size, Intrinsics, device): + orig_h, orig_w = Mask.shape[-2:] + min_orig_size = min(orig_w, orig_h) + scale_factor = min_size / min_orig_size + mask = F.interpolate( + Mask[None, None], + scale_factor=scale_factor, + mode="bilinear", + align_corners=False, + ) + H, W = mask.shape[-2:] + + intrinsics = denormalize_f(Intrinsics.cpu().numpy(), H, W) + cameras = PerspectiveCameras( + focal_length=torch.tensor( + [[intrinsics[0, 0], intrinsics[1, 1]]], device=device, dtype=torch.float32 + ), + principal_point=torch.tensor( + [[intrinsics[0, 2], intrinsics[1, 2]]], device=device, dtype=torch.float32 + ), + image_size=torch.tensor([[H, W]], device=device, dtype=torch.float32), + in_ndc=False, + device=device, + ) + raster_settings = RasterizationSettings( + image_size=(H, W), + blur_radius=1e-6, + faces_per_pixel=50, + max_faces_per_bin=50000, + ) + blend_params = BlendParams(sigma=1e-4, gamma=1e-4, background_color=(0.0, 0.0, 0.0)) + renderer = MeshRenderer( + rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings), + shader=SoftSilhouetteShader(blend_params=blend_params), + ) + + return mask, renderer + + +def run_alignment( + Point_Map, + mask, + mesh, + center, + faces_idx, + textures, + renderer, + device, + align_pm_coordinate=False, +): + + # from point map coordinate to pytorch3d + target_object_points = Point_Map[mask[0, 0].bool()] + if align_pm_coordinate: + target_object_points[:, 0] *= -1 + target_object_points[:, 1] *= -1 + # Get rid of flying points + thresh = 0.9 + depth_quantile = torch.quantile(target_object_points[:, 2], thresh) + target_object_points = target_object_points[ + target_object_points[:, 2] <= depth_quantile + ] + flag_notgt = False + + if target_object_points.shape[0] == 0: + flag_notgt = True + return None, None, None, None, None, None, None, flag_notgt + + source_points, target_points = mesh.verts_packed(), target_object_points + # align to moge object points. + height_src = torch.max(source_points[:, 1]) - torch.min(source_points[:, 1]) + height_tgt = torch.max(target_points[:, 1]) - torch.min(target_points[:, 1]) + scale_1 = height_tgt / height_src + source_points *= scale_1 + center *= scale_1 + + center_src = torch.mean(source_points, dim=0) + center_tgt = torch.mean(target_points, dim=0) + translation_1 = center_tgt - center_src + + source_points += translation_1 + center += translation_1 + + # manually align based on moge point cloud. + tfm1 = ( + Transform3d(device=device) + .scale(scale_1.expand(3)[None]) + .translate(translation_1[None]) + ) + mesh = Meshes(verts=[source_points], faces=[faces_idx], textures=textures) + rendered = renderer(mesh) + ori_iou = compute_iou(rendered[..., 3][0][None, None], mask, threshold=0.5) + final_iou = ori_iou.cpu().item() + + return source_points, target_points, center, tfm1, mesh, ori_iou, final_iou, flag_notgt + + +def apply_transform(mesh, center, quat, translation, scale): + quat_normalized = quat / quat.norm() + R = quaternion_to_matrix(quat_normalized) + # transform to the world coordinate system center. + verts = mesh.verts_packed() - center + # perform operation + verts = verts * scale + verts = verts @ R.transpose(0, 1) + # transform back to the original position after rotation. + verts += center + verts = verts + translation + + transformed_mesh = Meshes( + verts=[verts], faces=[mesh.faces_packed()], textures=mesh.textures + ) + return transformed_mesh + + +def compute_loss(rendered, mask_gt, loss_weights, quat, translation, scale): + + pred_mask = rendered[..., 3][0] + # === 1. MSE Loss on mask === + loss_mask = F.mse_loss(pred_mask, mask_gt[0, 0]) + + # === 2. Reg Loss on quaternion === + quat_normalized = quat / quat.norm() + loss_reg_q = F.mse_loss( + quat_normalized, torch.tensor([1.0, 0.0, 0.0, 0.0], device=quat.device) + ) + loss_reg_t = torch.norm(translation) ** 2 + loss_reg_s = (scale - 1.0) ** 2 + + # === Total weighted loss === + total_loss = ( + loss_weights["mask"] * loss_mask + + loss_weights["reg_q"] * loss_reg_q + + loss_weights["reg_t"] * loss_reg_t + + loss_weights["reg_s"] * loss_reg_s + ) + + return total_loss + + +def export_transformed_mesh_glb( + verts, mesh_obj, center, quat, translation, scale, output_path +): + quat_normalized = quat / quat.norm() + + R = quaternion_to_matrix(quat_normalized) + # transform to the world coordinate system center. + verts -= center + # perform operations. + verts = verts * scale + verts = verts @ R.transpose(0, 1) + # transform back to the original position after rotation. + verts += center + verts = verts + translation + + mesh_obj.vertices = verts.cpu().numpy() + output_path = os.path.join(output_path, "result.glb") + # import pdb + # pdb.set_trace() + mesh_obj.export(output_path) + return + + +def set_seed(seed=100): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +def load_and_simplify_mesh(Mesh, device, target_triangles=5000): + + vertices = np.asarray(Mesh.vertices) + faces = np.asarray(Mesh.faces) + mesh_o3d = o3d.geometry.TriangleMesh() + mesh_o3d.vertices = o3d.utility.Vector3dVector(vertices) + mesh_o3d.triangles = o3d.utility.Vector3iVector(faces) + + mesh_o3d.remove_duplicated_vertices() + mesh_o3d.remove_degenerate_triangles() + mesh_o3d.remove_duplicated_triangles() + mesh_o3d.remove_non_manifold_edges() + + if len(mesh_o3d.triangles) > target_triangles: + mesh_simplified = mesh_o3d.simplify_quadric_decimation(target_triangles) + else: + mesh_simplified = mesh_o3d + + verts = torch.tensor( + np.asarray(mesh_simplified.vertices), dtype=torch.float32, device=device + ) + faces = torch.tensor( + np.asarray(mesh_simplified.triangles), dtype=torch.int64, device=device + ) + + return verts, faces + + +def compute_iou(render_mask_obj, mask_obj_gt, threshold=0.5): + + # Binarize masks + pred = (render_mask_obj > threshold).float() + gt_obj = (mask_obj_gt > threshold).float() + + # mask = pred[0, 0].cpu().numpy() * 255 + # mask_uint8 = mask.astype(np.uint8) + # cv2.imwrite(path, mask_uint8) + + # Compute intersection and union + intersection = (pred * gt_obj).sum() + union = ((pred + gt_obj) > 0).float().sum() + + if union == 0: + return torch.tensor(1.0 if intersection == 0 else 0.0) # avoid division by zero + + iou = intersection / union + return iou + + +def denormalize_f(norm_K, height, width): + # Extract cx and cy from the normalized K matrix + cx_norm = norm_K[0][2] # c_x is at K[0][2] + cy_norm = norm_K[1][2] # c_y is at K[1][2] + + fx_norm = norm_K[0][0] # Normalized fx + fy_norm = norm_K[1][1] # Normalized fy + s_norm = norm_K[0][1] # Skew (usually 0) + + # Scale to absolute values + fx_abs = fx_norm * width + fy_abs = fy_norm * height + cx_abs = cx_norm * width + cy_abs = cy_norm * height + s_abs = s_norm * width + + # Construct absolute K matrix + abs_K = np.array([[fx_abs, s_abs, cx_abs], [0.0, fy_abs, cy_abs], [0.0, 0.0, 1.0]]) + return abs_K + + +# Convert torch tensors to Open3D point clouds +def tensor_to_o3d_pcd(tensor): + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(tensor.cpu().numpy()) + return pcd + + +# Convert Open3D back to torch tensor +def o3d_to_tensor(pcd): + return torch.tensor(np.asarray(pcd.points), dtype=torch.float32) + + +def run_ICP(source_points_mesh, source_points, target_points, threshold): + # Convert your point clouds + mesh_src_pcd = tensor_to_o3d_pcd(source_points_mesh.verts_padded().squeeze(0)) + src_pcd = tensor_to_o3d_pcd(source_points) + tgt_pcd = tensor_to_o3d_pcd(target_points) + + # Run ICP + trans_init = np.eye(4) + reg_p2p = o3d.pipelines.registration.registration_icp( + src_pcd, + tgt_pcd, + threshold, + trans_init, + o3d.pipelines.registration.TransformationEstimationPointToPoint(), + ) + + # Apply transformation + mesh_src_pcd.transform(reg_p2p.transformation) + points_aligned_icp = o3d_to_tensor(mesh_src_pcd).to(source_points.device) + + return points_aligned_icp, reg_p2p.transformation + + +def run_render_compare(mesh, center, renderer, mask, device): + + quat = torch.nn.Parameter( + torch.tensor([1.0, 0.0, 0.0, 0.0], device=device, requires_grad=True) + ) + translation = torch.nn.Parameter( + torch.tensor([0.0, 0.0, 0.0], device=device, requires_grad=True) + ) + scale = torch.nn.Parameter(torch.tensor(1.0, device=device, requires_grad=True)) + + def get_optimizer(stage): + if stage == 1: + return torch.optim.Adam([translation, scale], lr=1e-2) + elif stage == 2: + return torch.optim.Adam([quat, translation, scale], lr=5e-3) + + loss_weights = {"mask": 200, "reg_q": 0.1, "reg_t": 0.05, "reg_s": 0.05} + prev_loss = None + + global_step = 0 + for stage in [1, 2]: + optimizer = get_optimizer(stage) + iters = [5, 25] + for i in range(iters[stage - 1]): + optimizer.zero_grad() + transformed = apply_transform(mesh, center, quat, translation, scale) + rendered = renderer(transformed) + loss = compute_loss(rendered, mask, loss_weights, quat, translation, scale) + loss.backward() + optimizer.step() + global_step += 1 + if prev_loss is not None and abs(loss.item() - prev_loss) < 1e-5: + break + prev_loss = loss.item() + + quat, translation, scale = quat.detach(), translation.detach(), scale.detach() + quat_normalized = quat / quat.norm() + R = quaternion_to_matrix(quat_normalized) + + return quat, translation, scale, R diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/pipeline/preprocess_utils.py b/thirdparty/sam3d/sam3d/sam3d_objects/pipeline/preprocess_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7fe601558e65bfb3850e800a895d68cbb900a5fa --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/pipeline/preprocess_utils.py @@ -0,0 +1,41 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +from typing import Union +import torch +import numpy as np +from functools import partial + +from PIL import Image + + +from sam3d_objects.data.dataset.tdfy.preprocessor import PreProcessor +from torchvision.transforms import Compose, Resize, InterpolationMode +from sam3d_objects.data.dataset.tdfy.img_processing import pad_to_square_centered +from sam3d_objects.data.dataset.tdfy.img_and_mask_transforms import ( + rembg, + crop_around_mask_with_padding, +) + + +def get_default_preprocessor(): + preprocessor = PreProcessor() + img_transform = Compose( + transforms=[ + partial(pad_to_square_centered), + Resize(size=518, interpolation=InterpolationMode.BICUBIC), + ] + ) + mask_transform = Compose( + transforms=[ + partial(pad_to_square_centered), + Resize(size=518, interpolation=0), + ] + ) + img_mask_joint_transform = [ + partial(crop_around_mask_with_padding, box_size_factor=1.0, padding_factor=0.1), + rembg, + ] + preprocessor.img_transform = img_transform + preprocessor.mask_transform = mask_transform + preprocessor.img_mask_joint_transform = img_mask_joint_transform + + return preprocessor diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/pipeline/utils/pointmap.py b/thirdparty/sam3d/sam3d/sam3d_objects/pipeline/utils/pointmap.py new file mode 100644 index 0000000000000000000000000000000000000000..82a2e5e4be071472a566142be94741f5aefe1f30 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/pipeline/utils/pointmap.py @@ -0,0 +1,108 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +""" +Utility functions for point map processing and intrinsics inference. +Extracted from moge library for use in sam3d_objects pipeline. +""" + +from typing import Optional, Tuple, Union +import torch +import utils3d +# Import directly from moge for exact compatibility +from moge.utils.geometry_torch import ( + normalized_view_plane_uv, + recover_focal_shift, +) +from moge.utils.geometry_numpy import ( + solve_optimal_focal_shift, + solve_optimal_shift, +) + + +def infer_intrinsics_from_pointmap( + points: torch.Tensor, + mask: Optional[torch.Tensor] = None, + fov_x: Optional[Union[float, torch.Tensor]] = None, + mask_threshold: float = 0.5, + force_projection: bool = False, + apply_mask: bool = False, + device: Optional[torch.device] = None +) -> dict: + """ + Infer camera intrinsics from a point map. + + Exact implementation matching moge library's inference logic. + + Args: + points: Point map tensor of shape (B, H, W, 3) or (H, W, 3) + mask: Optional mask tensor of shape (B, H, W) or (H, W) + fov_x: Optional horizontal field of view in degrees. If None, inferred from points + mask_threshold: Threshold for binary mask creation + force_projection: If True, recompute points using depth and intrinsics + apply_mask: If True, apply mask to output points and depth + device: Device for computation. If None, uses points.device + + Returns: + Dictionary containing: + - 'points': Camera-space points + - 'intrinsics': Camera intrinsics matrix + - 'depth': Depth map + - 'mask': Binary mask + """ + if device is None: + device = points.device + + # Handle batch dimension + squeeze_batch = False + if points.dim() == 3: + points = points.unsqueeze(0) + if mask is not None: + mask = mask.unsqueeze(0) + squeeze_batch = True + + height, width = points.shape[1:3] + aspect_ratio = width / height + + # Always process the output in fp32 precision + with torch.autocast(device_type=device.type, dtype=torch.float32): + points, mask, fov_x = map(lambda x: x.float() if isinstance(x, torch.Tensor) else x, [points, mask, fov_x]) + + mask_binary = mask > mask_threshold if mask is not None else torch.ones_like(points[..., 0], dtype=torch.bool) + + # Add finite check to handle NaN and inf values + finite_mask = torch.isfinite(points).all(dim=-1) + mask_binary = mask_binary & finite_mask + + # Get camera-space point map. (Focal here is the focal length relative to half the image diagonal) + if fov_x is None: + # BUG: Recover focal shift numpy method has flipped outputs: https://github.com/microsoft/MoGe/issues/110 + shift, focal = recover_focal_shift(points, mask_binary) + else: + focal = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5 / torch.tan(torch.deg2rad(torch.as_tensor(fov_x, device=points.device, dtype=points.dtype) / 2)) + if focal.ndim == 0: + focal = focal[None].expand(points.shape[0]) + _, shift = recover_focal_shift(points, mask_binary, focal=focal) + fx = focal / 2 * (1 + aspect_ratio ** 2) ** 0.5 / aspect_ratio + fy = focal / 2 * (1 + aspect_ratio ** 2) ** 0.5 + intrinsics = utils3d.torch.intrinsics_from_focal_center(fx, fy, 0.5, 0.5) + depth = points[..., 2] + shift[..., None, None] + + # If projection constraint is forced, recompute the point map using the actual depth map + if force_projection: + points = utils3d.torch.depth_to_points(depth, intrinsics=intrinsics) + else: + shift_stacked = torch.stack([torch.zeros_like(shift), torch.zeros_like(shift), shift], dim=-1)[..., None, None, :] + points = points + shift_stacked + + # Apply mask if needed + if apply_mask: + points = torch.where(mask_binary[..., None], points, torch.inf) + depth = torch.where(mask_binary, depth, torch.inf) + + return_dict = { + 'points': points.squeeze(0) if squeeze_batch else points, + 'intrinsics': intrinsics.squeeze(0) if squeeze_batch else intrinsics, + 'depth': depth.squeeze(0) if squeeze_batch else depth, + 'mask': mask_binary.squeeze(0) if squeeze_batch else mask_binary, + } + + return return_dict diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/utils/__init__.py b/thirdparty/sam3d/sam3d/sam3d_objects/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..71ca4b12c770afea62d06f97064cdf0c97d40ed7 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/utils/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/utils/visualization/__init__.py b/thirdparty/sam3d/sam3d/sam3d_objects/utils/visualization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..65e0c535b199effb04803102b81b65a4f8854e4b --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/utils/visualization/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +from .scene_visualizer import SceneVisualizer +from .plotly.plot_scene import plot_tdfy_scene diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/utils/visualization/image_mesh.py b/thirdparty/sam3d/sam3d/sam3d_objects/utils/visualization/image_mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..81f996f5d5184df7e6b2d9130c2ea5cbf56fe1ed --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/utils/visualization/image_mesh.py @@ -0,0 +1,76 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +from collections import namedtuple +from typing import Tuple, Optional, Union +import numpy as np +import torch +from pytorch3d.structures import Meshes +from pytorch3d.renderer.mesh.textures import TexturesVertex + +from utils3d.numpy import ( + depth_edge, + normals_edge, + points_to_normals, + image_uv, + image_mesh, +) + +np.acos = np.arccos # sam3d_objects-3dfy version of numpy doesn't have acos + +MeshAndTexture = namedtuple( + "mesh_and_texture", ["faces", "vertices", "vertex_colors", "vertex_uvs"] +) + + +def mesh_from_pointmap( + pointmap: np.ndarray, + image: np.ndarray, + mask: Optional[np.ndarray] = None, + depth: Optional[np.ndarray] = None, + filter_edges: bool = True, + depth_edge_rtol: float = 0.03, + depth_edge_tol: float = 5, +) -> MeshAndTexture: + """ + Create a mesh from pointmap and image. + Returns: + faces, vertices, vertex_colors, vertex_uvs: Mesh components + """ + assert pointmap.ndim == 3, pointmap.shape + assert pointmap.shape[-1] == 3, pointmap.shape + assert image.ndim == 3, image.shape + assert image.shape[-1] == 3, image.shape + + if mask is None: + mask = np.ones_like(pointmap[..., 2], dtype=np.float32) > 0 + + if depth is None: + depth = pointmap[..., 2] + + height, width = image.shape[:2] + normals, normals_mask = points_to_normals(pointmap, mask=mask) + + if filter_edges: + mask = mask & ~( + depth_edge(depth, rtol=depth_edge_rtol, mask=mask) + & normals_edge(normals, tol=depth_edge_tol, mask=normals_mask) + ) + + faces, vertices, vertex_colors, vertex_uvs = image_mesh( + pointmap, + image.astype(np.float32), + image_uv(width=width, height=height), + mask=mask, + tri=True, + ) + vertices, vertex_uvs = vertices * [1, 1, 1], vertex_uvs * [1, -1] + [0, 1] + return MeshAndTexture(faces, vertices, vertex_colors, vertex_uvs) + + +def create_textured_mesh( + verts: torch.Tensor, + faces: torch.Tensor, + vert_colors: torch.Tensor, +) -> Meshes: + tex = TexturesVertex(verts_features=[vert_colors]) + mesh = Meshes(verts=[verts], faces=[faces], textures=tex) + return mesh diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/utils/visualization/plotly/plot_scene.py b/thirdparty/sam3d/sam3d/sam3d_objects/utils/visualization/plotly/plot_scene.py new file mode 100644 index 0000000000000000000000000000000000000000..79aab3bb1c84e81457e45042bbfb3a0bb1557788 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/utils/visualization/plotly/plot_scene.py @@ -0,0 +1,563 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# Adapted from pytorch3d.viz.plotly_vis which has license: +# BSD License + +# For PyTorch3D software + +# Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved. + +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# * Neither the name Meta nor the names of its contributors may be used to +# endorse or promote products derived from this software without specific +# prior written permission. + +import math + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import warnings +from typing import Dict, List, Optional, Union + +import numpy as np +import plotly.graph_objects as go +import torch +from plotly.subplots import make_subplots +from pytorch3d.renderer import ( + HeterogeneousRayBundle, + RayBundle, + TexturesAtlas, + TexturesVertex, + ray_bundle_to_ray_points, +) +from pytorch3d.renderer.camera_utils import camera_to_eye_at_up +from pytorch3d.renderer.cameras import CamerasBase +from pytorch3d.structures import ( + Meshes, + Pointclouds, + join_meshes_as_scene, +) +from pytorch3d.vis.plotly_vis import ( + AxisArgs, + Lighting, + _add_camera_trace, + _add_pointcloud_trace, + _add_ray_bundle_trace, + _is_ray_bundle, + _scale_camera_to_bounds, + _update_axes_bounds, +) + + +Struct = Union[CamerasBase, Meshes, Pointclouds, RayBundle, HeterogeneousRayBundle] + + +default_axisargs = dict( + xaxis={"backgroundcolor": "rgb(200, 200, 230)"}, + yaxis={"backgroundcolor": "rgb(230, 200, 200)"}, + zaxis={"backgroundcolor": "rgb(200, 230, 200)"}, + axis_args=AxisArgs(showgrid=True), +) + +NO_BACKGROUND = dict( + xaxis={"backgroundcolor": "rgb(255, 255, 255)", "visible": False}, + yaxis={"backgroundcolor": "rgb(255, 255, 255)", "visible": False}, + zaxis={"backgroundcolor": "rgb(255, 255, 255)", "visible": False}, +) + + +@torch.no_grad() +def plot_tdfy_scene( + plots: Dict[str, Dict[str, Struct]], + *, + viewpoint_cameras: Optional[CamerasBase] = None, + ncols: int = 1, + camera_scale: float = 0.3, + camera_wireframe_width: int = 3, + pointcloud_max_points: int = 20000, + pointcloud_marker_size: int = 1, + raybundle_max_rays: int = 20000, + raybundle_max_points_per_ray: int = 1000, + raybundle_ray_point_marker_size: int = 1, + raybundle_ray_line_width: int = 1, + boxes_wireframe_width: int = 1, + boxes_add_cross_face_bars: bool = False, + boxes_name_int_to_display_name_dict: Optional[Dict[int, str]] = None, + boxes_plot_together: bool = False, + height: int = None, + width: int = None, + use_orthographic: bool = False, + equalticks: bool = True, + ticklen: float = 1.0, + aspectmode: str = "cube", + **kwargs, +): # pragma: no cover + """ + Main function to visualize Cameras, Meshes, Pointclouds, and RayBundle. + Plots input Cameras, Meshes, Pointclouds, and RayBundle data into named subplots, + with named traces based on the dictionary keys. Cameras are + rendered at the camera center location using a wireframe. + + Args: + plots: A dict containing subplot and trace names, + as well as the Meshes, Cameras and Pointclouds objects to be rendered. + See below for examples of the format. + viewpoint_cameras: an instance of a Cameras object providing a location + to view the plotly plot from. If the batch size is equal + to the number of subplots, it is a one to one mapping. + If the batch size is 1, then that viewpoint will be used + for all the subplots will be viewed from that point. + Otherwise, the viewpoint_cameras will not be used. + ncols: the number of subplots per row + camera_scale: determines the size of the wireframe used to render cameras. + pointcloud_max_points: the maximum number of points to plot from + a pointcloud. If more are present, a random sample of size + pointcloud_max_points is used. + pointcloud_marker_size: the size of the points rendered by plotly + when plotting a pointcloud. + raybundle_max_rays: maximum number of rays of a RayBundle to visualize. Randomly + subsamples without replacement in case the number of rays is bigger than max_rays. + raybundle_max_points_per_ray: the maximum number of points per ray in RayBundle + to visualize. If more are present, a random sample of size + max_points_per_ray is used. + raybundle_ray_point_marker_size: the size of the ray points of a plotted RayBundle + raybundle_ray_line_width: the width of the plotted rays of a RayBundle + **kwargs: Accepts lighting (a Lighting object) and any of the args xaxis, + yaxis and zaxis which Plotly's scene accepts. Accepts axis_args, + which is an AxisArgs object that is applied to all 3 axes. + Example settings for axis_args and lighting are given at the + top of this file. + + Example: + + ..code-block::python + + mesh = ... + point_cloud = ... + fig = plot_scene({ + "subplot_title": { + "mesh_trace_title": mesh, + "pointcloud_trace_title": point_cloud + } + }) + fig.show() + + The above example will render one subplot which has both a mesh and pointcloud. + + If the Meshes, Pointclouds, or Cameras objects are batched, then every object in that batch + will be plotted in a single trace. + + ..code-block::python + mesh = ... # batch size 2 + point_cloud = ... # batch size 2 + fig = plot_scene({ + "subplot_title": { + "mesh_trace_title": mesh, + "pointcloud_trace_title": point_cloud + } + }) + fig.show() + + The above example renders one subplot with 2 traces, each of which renders + both objects from their respective batched data. + + Multiple subplots follow the same pattern: + ..code-block::python + mesh = ... # batch size 2 + point_cloud = ... # batch size 2 + fig = plot_scene({ + "subplot1_title": { + "mesh_trace_title": mesh[0], + "pointcloud_trace_title": point_cloud[0] + }, + "subplot2_title": { + "mesh_trace_title": mesh[1], + "pointcloud_trace_title": point_cloud[1] + } + }, + ncols=2) # specify the number of subplots per row + fig.show() + + The above example will render two subplots, each containing a mesh + and a pointcloud. The ncols argument will render two subplots in one row + instead of having them vertically stacked because the default is one subplot + per row. + + To view plotly plots from a PyTorch3D camera's point of view, we can use + viewpoint_cameras: + ..code-block::python + mesh = ... # batch size 2 + R, T = look_at_view_transform(2.7, 0, [0, 180]) # 2 camera angles, front and back + # Any instance of CamerasBase works, here we use FoVPerspectiveCameras + cameras = FoVPerspectiveCameras(device=device, R=R, T=T) + fig = plot_scene({ + "subplot1_title": { + "mesh_trace_title": mesh[0] + }, + "subplot2_title": { + "mesh_trace_title": mesh[1] + } + }, + viewpoint_cameras=cameras) + fig.show() + + The above example will render the first subplot seen from the camera on the +z axis, + and the second subplot from the viewpoint of the camera on the -z axis. + + We can visualize these cameras as well: + ..code-block::python + mesh = ... + R, T = look_at_view_transform(2.7, 0, [0, 180]) # 2 camera angles, front and back + # Any instance of CamerasBase works, here we use FoVPerspectiveCameras + cameras = FoVPerspectiveCameras(device=device, R=R, T=T) + fig = plot_scene({ + "subplot1_title": { + "mesh_trace_title": mesh, + "cameras_trace_title": cameras, + }, + }) + fig.show() + + The above example will render one subplot with the mesh object + and two cameras. + + RayBundle visualization is also supproted: + ..code-block::python + cameras = PerspectiveCameras(...) + ray_bundle = RayBundle(origins=..., lengths=..., directions=..., xys=...) + fig = plot_scene({ + "subplot1_title": { + "ray_bundle_trace_title": ray_bundle, + "cameras_trace_title": cameras, + }, + }) + fig.show() + + For an example of using kwargs, see below: + ..code-block::python + mesh = ... + point_cloud = ... + fig = plot_scene({ + "subplot_title": { + "mesh_trace_title": mesh, + "pointcloud_trace_title": point_cloud + } + }, + axis_args=AxisArgs(backgroundcolor="rgb(200,230,200)")) # kwarg axis_args + fig.show() + + The above example will render each axis with the input background color. + + See the tutorials in pytorch3d/docs/tutorials for more examples + (namely rendered_color_points.ipynb and rendered_textured_meshes.ipynb). + """ + + subplots = list(plots.keys()) + fig = _gen_fig_with_subplots(len(subplots), ncols, subplots) + lighting = kwargs.get("lighting", Lighting())._asdict() + axis_args_dict = kwargs.get("axis_args", AxisArgs(showgrid=True))._asdict() + + # Set axis arguments to defaults defined at the top of this file + x_settings = {**axis_args_dict} + y_settings = {**axis_args_dict} + z_settings = {**axis_args_dict} + + # Update the axes with any axis settings passed in as kwargs. + x_settings.update(**kwargs.get("xaxis", {"backgroundcolor": "rgb(200, 200, 230)"})) + y_settings.update(**kwargs.get("yaxis", {"backgroundcolor": "rgb(230, 200, 200)"})) + z_settings.update(**kwargs.get("zaxis", {"backgroundcolor": "rgb(200, 230, 200)"})) + camera = { + "up": { + "x": 0.0, + "y": 0.0, + "z": 1.0, + } # set the up vector to match PyTorch3D world coordinates conventions + } + viewpoints_eye_at_up_world = None + if viewpoint_cameras: + n_viewpoint_cameras = len(viewpoint_cameras) + if n_viewpoint_cameras == len(subplots) or n_viewpoint_cameras == 1: + # Calculate the vectors eye, at, up in world space + # to initialize the position of the camera in + # the plotly figure + viewpoints_eye_at_up_world = camera_to_eye_at_up( + viewpoint_cameras.get_world_to_view_transform().cpu() + ) + else: + msg = "Invalid number {} of viewpoint cameras were provided. Either 1 \ + or {} cameras are required".format( + len(viewpoint_cameras), len(subplots) + ) + warnings.warn(msg) + + for subplot_idx in range(len(subplots)): + subplot_name = subplots[subplot_idx] + traces = plots[subplot_name] + for trace_name, struct in traces.items(): + if isinstance(struct, Meshes): + _add_mesh_trace(fig, struct, trace_name, subplot_idx, ncols, lighting) + elif isinstance(struct, Pointclouds): + _add_pointcloud_trace( + fig, + struct, + trace_name, + subplot_idx, + ncols, + pointcloud_max_points, + pointcloud_marker_size, + ) + elif isinstance(struct, CamerasBase): + _add_camera_trace( + fig, struct, trace_name, subplot_idx, ncols, camera_scale + ) + elif isinstance(struct, CamTrace): + struct._add_camera_trace( + fig=fig, + trace_name=trace_name, + subplot_idx=subplot_idx, + ncols=ncols, + camera_wireframe_width=camera_wireframe_width, + ) + elif _is_ray_bundle(struct): + _add_ray_bundle_trace( + fig, + struct, + trace_name, + subplot_idx, + ncols, + raybundle_max_rays, + raybundle_max_points_per_ray, + raybundle_ray_point_marker_size, + raybundle_ray_line_width, + ) + else: + raise ValueError( + "struct {} is not a Cameras, Meshes, BBoxes3D, Pointclouds,".format( + struct + ) + + "RayBundle or HeterogeneousRayBundle object." + ) + + # Ensure update for every subplot. + plot_scene = "scene" + str(subplot_idx + 1) + current_layout = fig["layout"][plot_scene] + xaxis = current_layout["xaxis"] + yaxis = current_layout["yaxis"] + zaxis = current_layout["zaxis"] + + # mins = min([axis['range'][0] for axis in (xaxis, yaxis, zaxis)]) + # maxes = max([axis['range'][1] for axis in (xaxis, yaxis, zaxis)]) + # xaxis['range'] = [mins, maxes] + # yaxis['range'] = [mins, maxes] + # zaxis['range'] = [mins, maxes] + maxlen = max( + [abs(axis["range"][1] - axis["range"][0]) for axis in (xaxis, yaxis, zaxis)] + ) + halflen = maxlen / 2.0 + nticks = math.ceil(maxlen / ticklen) + xaxis["range"] = [ + sum(xaxis["range"]) / 2.0 + delta for delta in [-halflen, halflen] + ] + yaxis["range"] = [ + sum(yaxis["range"]) / 2.0 + delta for delta in [-halflen, halflen] + ] + zaxis["range"] = [ + sum(zaxis["range"]) / 2.0 + delta for delta in [-halflen, halflen] + ] + + xaxis["nticks"] = nticks + yaxis["nticks"] = nticks + zaxis["nticks"] = nticks + + # Update the axes with our above default and provided settings. + xaxis.update(**x_settings) + yaxis.update(**y_settings) + zaxis.update(**z_settings) + + # update camera viewpoint if provided + if viewpoints_eye_at_up_world is not None: + # Use camera params for batch index or the first camera if only one provided. + viewpoint_idx = min(n_viewpoint_cameras - 1, subplot_idx) + + eye, at, up = (i[viewpoint_idx] for i in viewpoints_eye_at_up_world) + eye_x, eye_y, eye_z = eye.tolist() + at_x, at_y, at_z = at.tolist() + up_x, up_y, up_z = up.tolist() + + # scale camera eye to plotly [-1, 1] ranges + x_range = xaxis["range"] + y_range = yaxis["range"] + z_range = zaxis["range"] + + eye_x = _scale_camera_to_bounds(eye_x, x_range, True) + eye_y = _scale_camera_to_bounds(eye_y, y_range, True) + eye_z = _scale_camera_to_bounds(eye_z, z_range, True) + + at_x = _scale_camera_to_bounds(at_x, x_range, True) + at_y = _scale_camera_to_bounds(at_y, y_range, True) + at_z = _scale_camera_to_bounds(at_z, z_range, True) + + up_x = _scale_camera_to_bounds(up_x, x_range, False) + up_y = _scale_camera_to_bounds(up_y, y_range, False) + up_z = _scale_camera_to_bounds(up_z, z_range, False) + + camera["eye"] = {"x": eye_x, "y": eye_y, "z": eye_z} + camera["center"] = {"x": at_x, "y": at_y, "z": at_z} + camera["up"] = {"x": up_x, "y": up_y, "z": up_z} + camera["projection"] = {"type": "orthographic"} + + current_layout.update( + { + "xaxis": xaxis, + "yaxis": yaxis, + "zaxis": zaxis, + # "aspectmode": "data", + "aspectmode": aspectmode, + # "aspectratio": { + # 'x': 1.0, + # 'y': 1.0, + # 'z': 1.0, + # }, + "camera": camera, + } + ) + if width is not None or height is not None: + fig.update_layout( + width=width, + height=height, + # aspectmode="data" + ) + + if use_orthographic: + # fig.update_scenes(aspectmode='data') + fig.layout.scene.camera.projection.type = "orthographic" + return fig + + +def _gen_fig_with_subplots( + batch_size: int, + ncols: int, + subplot_titles: List[str], + row_heights: Optional[List[int]] = None, + column_widths: Optional[List[int]] = None, +): # pragma: no cover + """ + Takes in the number of objects to be plotted and generate a plotly figure + with the appropriate number and orientation of titled subplots. + Args: + batch_size: the number of elements in the batch of objects to be visualized. + ncols: number of subplots in the same row. + subplot_titles: titles for the subplot(s). list of strings of length batch_size. + + Returns: + Plotly figure with ncols subplots per row, and batch_size subplots. + """ + fig_rows = batch_size // ncols + if batch_size % ncols != 0: + fig_rows += 1 # allow for non-uniform rows + fig_cols = ncols + fig_type = [{"type": "scene"}] + specs = [fig_type * fig_cols] * fig_rows + # subplot_titles must have one title per subplot + fig = make_subplots( + rows=fig_rows, + cols=fig_cols, + specs=specs, + subplot_titles=subplot_titles, + column_widths=[1.0] * fig_cols, + ) + return fig + + +# From https://github.com/facebookresearch/pytorch3d/blob/0a59450f0ebbe12d9a8db3de937814932517633b/pytorch3d/vis/plotly_vis.py#L634 +def _add_mesh_trace( + fig: go.Figure, + meshes: Meshes, + trace_name: str, + subplot_idx: int, + ncols: int, + lighting: Lighting, +) -> None: # pragma: no cover + """ + Adds a trace rendering a Meshes object to the passed in figure, with + a given name and in a specific subplot. + + Args: + fig: plotly figure to add the trace within. + meshes: Meshes object to render. It can be batched. + trace_name: name to label the trace with. + subplot_idx: identifies the subplot, with 0 being the top left. + ncols: the number of subplots per row. + lighting: a Lighting object that specifies the Mesh3D lighting. + """ + + mesh = join_meshes_as_scene(meshes) + mesh = mesh.detach().cpu() + verts = mesh.verts_packed() + faces = mesh.faces_packed() + # If mesh has vertex colors or face colors, use them + # for figure, otherwise use plotly's default colors. + verts_rgb = None + faces_rgb = None + if isinstance(mesh.textures, TexturesVertex): + verts_rgb = mesh.textures.verts_features_packed() + verts_rgb.clamp_(min=0.0, max=1.0) + verts_rgb = (torch.tensor(255.0) * verts_rgb).to(torch.uint8) + if isinstance(mesh.textures, TexturesAtlas): + atlas = mesh.textures.atlas_packed() + # If K==1 + if atlas.shape[1] == 1 and atlas.shape[3] == 3: + faces_rgb = atlas[:, 0, 0] + + # Reposition the unused vertices to be "inside" the object + # (i.e. they won't be visible in the plot). + verts_used = torch.zeros((verts.shape[0],), dtype=torch.bool) + verts_used[torch.unique(faces)] = True + verts_center = verts[verts_used].mean(0) + verts[~verts_used] = verts_center + + row, col = subplot_idx // ncols + 1, subplot_idx % ncols + 1 + fig.add_trace( + go.Mesh3d( + x=verts[:, 0], + y=verts[:, 1], + z=verts[:, 2], + vertexcolor=verts_rgb, + facecolor=faces_rgb, + i=faces[:, 0], + j=faces[:, 1], + k=faces[:, 2], + lighting=lighting, + name=trace_name, + showlegend=True, + ), + row=row, + col=col, + ) + + # Access the current subplot's scene configuration + plot_scene = "scene" + str(subplot_idx + 1) + current_layout = fig["layout"][plot_scene] + + # update the bounds of the axes for the current trace + max_expand = (verts.max(0)[0] - verts.min(0)[0]).max() + _update_axes_bounds(verts_center, max_expand, current_layout) diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/utils/visualization/plotly/save_scene.py b/thirdparty/sam3d/sam3d/sam3d_objects/utils/visualization/plotly/save_scene.py new file mode 100644 index 0000000000000000000000000000000000000000..b37b354ff2743ae223de19680d98bab2ae09d41b --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/utils/visualization/plotly/save_scene.py @@ -0,0 +1,121 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import plotly.graph_objects as go +from plotly.graph_objects import Figure +import numpy as np +import os +import imageio +from pathlib import Path +from typing import List, Optional, Tuple, Union, Dict, Any +from PIL import Image +import io +import numpy as np +from tqdm import tqdm + + +def img_bytes_to_np(img_bytes): + return np.array(Image.open(io.BytesIO(img_bytes))) + + +def make_video( + scene: Figure, + output_path: str = "scene_video.mp4", + fps: int = 15, + duration: int = 1, + camera_trajectory: Optional[List[Dict[str, Any]]] = None, + temp_dir: Optional[str] = None, + trajectory_kwargs: Optional[Dict[str, Any]] = None, +) -> str: + """ + Creates a video by updating the camera view location and saving snapshots. + + Args: + scene: A Plotly Figure object, typically created with plot_tdfy_scene + output_path: Path to save the output video file + fps: Frames per second for the output video + duration: Duration of the video in seconds + camera_trajectory: List of camera positions. If None, creates a default circular trajectory. + Each item should be a dict with eye, center, and up keys as expected by Plotly's scene.camera. + temp_dir: Directory to store temporary frame images. If None, uses ./tmp_frames + + Returns: + Path to the saved video file + """ + if not scene._has_subplots(): + raise ValueError("Scene must have subplots to create a video") + + num_frames = fps * duration + + if camera_trajectory is None: + if trajectory_kwargs is None: + trajectory_kwargs = {} + camera_trajectory = _create_default_camera_trajectory( + num_frames, **trajectory_kwargs + ) + + frames = [] + for i, camera_pos in tqdm(enumerate(camera_trajectory), total=num_frames): + # update the camera position + scene.update_scenes(camera=camera_pos) + img_as_png = scene.to_image(engine="kaleido") + frames.append(img_bytes_to_np(img_as_png)) + + return frames + + +def _create_default_camera_trajectory( + num_frames: int, + axis: str = "y", + elevation: float = 1.0, + radius: float = 2.0, + **kwargs, +) -> List[Dict[str, Any]]: + """ + Creates a default camera trajectory, rotating around the scene in a circle. + + Args: + num_frames: Number of frames in the trajectory + axis: Axis to rotate around ('x', 'y', or 'z') + + Returns: + List of camera positions + """ + trajectory = [] + + # Create a circular path + for i in range(num_frames): + angle = (i / num_frames) * 2 * np.pi + + # Default position (all zeros) + eye_x, eye_y, eye_z = 0.0, 0.0, 0.0 + + # Calculate camera position based on selected axis + if axis.lower() == "z": + # Rotate in the xy-plane (around z-axis) + eye_x = radius * np.sin(angle) + eye_y = radius * np.cos(angle) + eye_z = elevation # Slightly above the scene + up = {"x": 0, "y": 0, "z": 1} + elif axis.lower() == "y": + # Rotate in the xz-plane (around y-axis) + eye_x = radius * np.sin(angle) + eye_z = radius * np.cos(angle) + eye_y = elevation # Slightly offset from y-axis + up = {"x": 0, "y": 1, "z": 0} + elif axis.lower() == "x": + # Rotate in the yz-plane (around x-axis) + eye_y = radius * np.sin(angle) + eye_z = radius * np.cos(angle) + eye_x = elevation # Slightly offset from x-axis + up = {"x": 1, "y": 0, "z": 0} + else: + raise ValueError(f"Invalid axis: {axis}. Must be 'x', 'y', or 'z'") + + camera_pos = { + "eye": {"x": eye_x, "y": eye_y, "z": eye_z}, + "center": {"x": 0, "y": 0, "z": 0}, # Look at center + "up": up, # Orientation based on rotation axis + } + + trajectory.append(camera_pos) + + return trajectory diff --git a/thirdparty/sam3d/sam3d/sam3d_objects/utils/visualization/scene_visualizer.py b/thirdparty/sam3d/sam3d/sam3d_objects/utils/visualization/scene_visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..35f5d60c937a90a6dc6ca0f4f92b75f375039bc1 --- /dev/null +++ b/thirdparty/sam3d/sam3d/sam3d_objects/utils/visualization/scene_visualizer.py @@ -0,0 +1,242 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +import torch +from typing import Optional + +from pytorch3d.renderer.cameras import PerspectiveCameras +from pytorch3d.structures import Pointclouds +from pytorch3d.transforms import quaternion_to_matrix + +from sam3d_objects.data.dataset.tdfy.transforms_3d import compose_transform +from sam3d_objects.utils.visualization.plotly.plot_scene import plot_tdfy_scene +from sam3d_objects.utils.visualization.image_mesh import ( + mesh_from_pointmap, + create_textured_mesh, +) + +from sam3d_objects.utils.visualization.plotly.plot_scene import NO_BACKGROUND, default_axisargs +from sam3d_objects.utils.visualization.plotly.save_scene import make_video as make_scene_video +import seaborn as sns +import copy + + +class SceneVisualizer: + make_video_from_fig = make_scene_video + + @staticmethod + def plot_scene( + points_local: torch.Tensor, + instance_quaternions_l2c: torch.Tensor, + instance_positions_l2c: torch.Tensor, + instance_scales_l2c: torch.Tensor, + pointmap: Optional[torch.Tensor] = None, + image: Optional[torch.Tensor] = None, + title: str = "Tdfy Scene", + height: int = 1000, + show_pointmap_as_mesh: bool = True, + clip_pointmap_colors_for_vis: bool = False, + filter_pointmap_edges: bool = True, + ): + cam = SceneVisualizer.camera() + + object_points = SceneVisualizer.object_pointcloud( + points_local=points_local.unsqueeze(0), + quat_l2c=instance_quaternions_l2c, + trans_l2c=instance_positions_l2c, + scale_l2c=instance_scales_l2c, + # colors=torch.ones_like(sample["instance_points_local"]) * torch.tensor([1, 0, 0]), + ) + + pointmap_struct_dict = SceneVisualizer._create_pointmap_structure( + pointmap=pointmap, + image=image, + show_pointmap_as_mesh=show_pointmap_as_mesh, + clip_pointmap_colors_for_vis=clip_pointmap_colors_for_vis, + filter_pointmap_edges=filter_pointmap_edges, + ) + return plot_tdfy_scene( + { + title: { + "camera": cam, + "object_points": object_points, + **pointmap_struct_dict, + } + }, + height=height, + ) + + @staticmethod + def plot_multi_objects( + pose_targets, + mask_names=None, + pointmap=None, + pointmap_colors=None, + mask_colors=None, + plot_tdfy_kwargs=None, + title="Tdfy Scene", + ): + if mask_colors is None: + mask_colors = sns.color_palette("husl", len(mask_names)) + if mask_names is None: + mask_names = [str(i) for i in range(len(pose_targets))] + + cam = SceneVisualizer.camera() + objects = {} + for i, mask_name in enumerate(mask_names): + if mask_name == None: + continue + + objects[mask_name] = SceneVisualizer.object_pointcloud( + points_local=pose_targets[i]["xyz_local"].unsqueeze(0), + quat_l2c=pose_targets[i]["rotation"], + trans_l2c=pose_targets[i]["translation"], + scale_l2c=pose_targets[i]["scale"], + colors=mask_colors[i], + ) + + pointmap_dict = {} + if pointmap is not None: + pointmap[pointmap.isnan()] = 0 + pointmap_dict = SceneVisualizer._create_pointmap_structure( + pointmap=pointmap, + image=pointmap_colors, + filter_pointmap_edges=True, + ) + + if plot_tdfy_kwargs is None: + plot_tdfy_kwargs = copy.deepcopy(NO_BACKGROUND) + if "height" not in plot_tdfy_kwargs: + plot_tdfy_kwargs["height"] = 1000 + if "width" not in plot_tdfy_kwargs: + plot_tdfy_kwargs["width"] = 1000 + + fig = plot_tdfy_scene( + { + title: { + "camera": cam, + **objects, + **pointmap_dict, + } + }, + **plot_tdfy_kwargs, + ) + return fig + + @staticmethod + def _create_pointmap_structure( + pointmap: torch.Tensor, + image: torch.Tensor, + show_pointmap_as_mesh: bool = True, + clip_pointmap_colors_for_vis: bool = True, + filter_pointmap_edges: bool = True, + ): + if pointmap is None: + return {} + + if show_pointmap_as_mesh: + if image is None: + image = torch.zeros_like(pointmap) + struct = SceneVisualizer.pointmap_to_mesh( + pointmap=pointmap, + image=image, + clip_pointmap_colors_for_vis=clip_pointmap_colors_for_vis, + filter_edges=filter_pointmap_edges, + ) + return {"Pointmap mesh": struct} + else: + struct = SceneVisualizer.pointmap_to_pointcloud( + pointmap=pointmap, image=image + ) + return {"Pointmap pointcloud": struct} + + @staticmethod + def camera( + quaternion: Optional[torch.Tensor] = None, + translation: Optional[torch.Tensor] = None, + ): + """ + Args: + quaternion: (4,) tensor of quaternion + translation: (3,) tensor of translation + """ + if quaternion is None: + quaternion = torch.tensor([1, 0, 0, 0]).unsqueeze(0) + if translation is None: + translation = torch.tensor([0, 0, 0]).unsqueeze(0) + R = quaternion_to_matrix(quaternion) + return PerspectiveCameras(R=R, T=translation) + + @staticmethod + def object_pointcloud( + points_local: torch.Tensor, + quat_l2c: torch.Tensor, + trans_l2c: torch.Tensor, + scale_l2c: torch.Tensor, + colors: Optional[torch.Tensor] = None, + ): + """ + Args: + points_local: (N, 3) tensor of point coordinates + colors: (N, 3) tensor of colors + """ + if colors is None: + colors = torch.ones_like(points_local) * torch.tensor( + (1.0, 0.0, 0.0), device=points_local.device + ) + elif isinstance(colors, tuple): + colors = torch.ones_like(points_local) * torch.tensor( + colors, device=points_local.device + ) + + R_l2c = quaternion_to_matrix(quat_l2c) + l2c_transform = compose_transform( + scale=scale_l2c, rotation=R_l2c, translation=trans_l2c + ) + points_world = l2c_transform.transform_points(points_local) + return Pointclouds(points=points_world, features=colors) + + @staticmethod + def pointmap_to_pointcloud(pointmap: torch.Tensor, image: torch.Tensor): + """ + Args: + pointmap: (H, W, 3) tensor of point coordinates + image: (H, W, 3) tensor of image + """ + if image is not None: + if image.shape[0] == 3: + image = image.permute(1, 2, 0) + image = image.reshape(-1, 3).unsqueeze(0).float() + + return Pointclouds( + points=pointmap.reshape(-1, 3).unsqueeze(0), + features=image, + ) + + @staticmethod + def pointmap_to_mesh( + pointmap: torch.Tensor, + image: torch.Tensor, + clip_pointmap_colors_for_vis: bool = True, + filter_edges: bool = True, + clamp_eps: float = 1 / 254, + ): + """ + Args: + pointmap: (H, W, 3) tensor of point coordinates + image: (H, W, 3) tensor of image + """ + pointmap = pointmap.cpu().numpy() + if image is None: + image = torch.zeros_like(pointmap) + if image.shape[0] == 3: + image = image.permute(1, 2, 0) + + if clip_pointmap_colors_for_vis: + # Not sure why, but this is needed to avoid underflow in the visualization + # We also clip to prevent overflow, just in case and since this is just for visualization + image = image.clamp(clamp_eps, 1 - clamp_eps) + image = image.cpu().numpy() + mesh = mesh_from_pointmap(pointmap, image, filter_edges=filter_edges) + vertices = torch.from_numpy(mesh.vertices) + faces = torch.from_numpy(mesh.faces) + vertex_colors = torch.from_numpy(mesh.vertex_colors) + return create_textured_mesh(vertices, faces, vertex_colors)