Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
715f79d
0
Parent(s):
Initial commit
Browse files- .gitattributes +37 -0
- .gitignore +42 -0
- .gitmodules +6 -0
- CODE_OF_CONDUCT.md +80 -0
- CONTRIBUTING.md +31 -0
- LICENSE +124 -0
- README.md +48 -0
- app.py +591 -0
- cowtracker/__init__.py +23 -0
- cowtracker/heads/__init__.py +14 -0
- cowtracker/heads/feature_extractor.py +100 -0
- cowtracker/heads/tracking_head.py +243 -0
- cowtracker/inference/__init__.py +12 -0
- cowtracker/inference/windowed.py +144 -0
- cowtracker/layers/__init__.py +26 -0
- cowtracker/layers/dpt_head.py +173 -0
- cowtracker/layers/patch_embed.py +90 -0
- cowtracker/layers/resnet_deconv.py +61 -0
- cowtracker/layers/temporal_attention.py +307 -0
- cowtracker/layers/video_transformer.py +411 -0
- cowtracker/models/__init__.py +23 -0
- cowtracker/models/cowtracker.py +228 -0
- cowtracker/models/cowtracker_windowed.py +218 -0
- cowtracker/thirdparty/DepthAnythingV2 +1 -0
- cowtracker/thirdparty/__init__.py +19 -0
- cowtracker/thirdparty/vggt +1 -0
- cowtracker/utils/__init__.py +32 -0
- cowtracker/utils/ops.py +151 -0
- cowtracker/utils/padding.py +168 -0
- cowtracker/utils/visualization.py +229 -0
- demo.py +179 -0
- docs/logo.jpg +3 -0
- docs/teaser.jpg +3 -0
- environments.yml +36 -0
- output.mp4 +3 -0
- packages.txt +2 -0
- requirements.txt +26 -0
- videos/apple.mp4 +3 -0
- videos/bear.mp4 +3 -0
- videos/bmx-bumps.mp4 +3 -0
- videos/cows.mp4 +3 -0
- videos/lab-coat.mp4 +3 -0
- videos/longboard.mp4 +3 -0
- videos/motocross-jump.mp4 +3 -0
.gitattributes
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
*.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
*.jpg filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
*.so
|
| 6 |
+
*.egg
|
| 7 |
+
*.egg-info/
|
| 8 |
+
dist/
|
| 9 |
+
build/
|
| 10 |
+
eggs/
|
| 11 |
+
.eggs/
|
| 12 |
+
|
| 13 |
+
# Temporary files
|
| 14 |
+
tmp/
|
| 15 |
+
*.tmp
|
| 16 |
+
*.temp
|
| 17 |
+
|
| 18 |
+
# OS files
|
| 19 |
+
.DS_Store
|
| 20 |
+
.DS_Store?
|
| 21 |
+
._*
|
| 22 |
+
Thumbs.db
|
| 23 |
+
|
| 24 |
+
# IDE
|
| 25 |
+
.vscode/
|
| 26 |
+
.idea/
|
| 27 |
+
*.swp
|
| 28 |
+
*.swo
|
| 29 |
+
|
| 30 |
+
# Jupyter
|
| 31 |
+
.ipynb_checkpoints/
|
| 32 |
+
|
| 33 |
+
# Model checkpoints (use LFS for large files)
|
| 34 |
+
# *.pth
|
| 35 |
+
# *.pt
|
| 36 |
+
|
| 37 |
+
# Environment
|
| 38 |
+
.env
|
| 39 |
+
.venv/
|
| 40 |
+
venv/
|
| 41 |
+
ENV/
|
| 42 |
+
|
.gitmodules
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[submodule "cowtracker/thirdparty/DepthAnythingV2"]
|
| 2 |
+
path = cowtracker/thirdparty/DepthAnythingV2
|
| 3 |
+
url = https://github.com/DepthAnything/Depth-Anything-V2.git
|
| 4 |
+
[submodule "cowtracker/thirdparty/vggt"]
|
| 5 |
+
path = cowtracker/thirdparty/vggt
|
| 6 |
+
url = https://github.com/facebookresearch/vggt.git
|
CODE_OF_CONDUCT.md
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Code of Conduct
|
| 2 |
+
|
| 3 |
+
## Our Pledge
|
| 4 |
+
|
| 5 |
+
In the interest of fostering an open and welcoming environment, we as
|
| 6 |
+
contributors and maintainers pledge to make participation in our project and
|
| 7 |
+
our community a harassment-free experience for everyone, regardless of age, body
|
| 8 |
+
size, disability, ethnicity, sex characteristics, gender identity and expression,
|
| 9 |
+
level of experience, education, socio-economic status, nationality, personal
|
| 10 |
+
appearance, race, religion, or sexual identity and orientation.
|
| 11 |
+
|
| 12 |
+
## Our Standards
|
| 13 |
+
|
| 14 |
+
Examples of behavior that contributes to creating a positive environment
|
| 15 |
+
include:
|
| 16 |
+
|
| 17 |
+
* Using welcoming and inclusive language
|
| 18 |
+
* Being respectful of differing viewpoints and experiences
|
| 19 |
+
* Gracefully accepting constructive criticism
|
| 20 |
+
* Focusing on what is best for the community
|
| 21 |
+
* Showing empathy towards other community members
|
| 22 |
+
|
| 23 |
+
Examples of unacceptable behavior by participants include:
|
| 24 |
+
|
| 25 |
+
* The use of sexualized language or imagery and unwelcome sexual attention or
|
| 26 |
+
advances
|
| 27 |
+
* Trolling, insulting/derogatory comments, and personal or political attacks
|
| 28 |
+
* Public or private harassment
|
| 29 |
+
* Publishing others' private information, such as a physical or electronic
|
| 30 |
+
address, without explicit permission
|
| 31 |
+
* Other conduct which could reasonably be considered inappropriate in a
|
| 32 |
+
professional setting
|
| 33 |
+
|
| 34 |
+
## Our Responsibilities
|
| 35 |
+
|
| 36 |
+
Project maintainers are responsible for clarifying the standards of acceptable
|
| 37 |
+
behavior and are expected to take appropriate and fair corrective action in
|
| 38 |
+
response to any instances of unacceptable behavior.
|
| 39 |
+
|
| 40 |
+
Project maintainers have the right and responsibility to remove, edit, or
|
| 41 |
+
reject comments, commits, code, wiki edits, issues, and other contributions
|
| 42 |
+
that are not aligned to this Code of Conduct, or to ban temporarily or
|
| 43 |
+
permanently any contributor for other behaviors that they deem inappropriate,
|
| 44 |
+
threatening, offensive, or harmful.
|
| 45 |
+
|
| 46 |
+
## Scope
|
| 47 |
+
|
| 48 |
+
This Code of Conduct applies within all project spaces, and it also applies when
|
| 49 |
+
an individual is representing the project or its community in public spaces.
|
| 50 |
+
Examples of representing a project or community include using an official
|
| 51 |
+
project e-mail address, posting via an official social media account, or acting
|
| 52 |
+
as an appointed representative at an online or offline event. Representation of
|
| 53 |
+
a project may be further defined and clarified by project maintainers.
|
| 54 |
+
|
| 55 |
+
This Code of Conduct also applies outside the project spaces when there is a
|
| 56 |
+
reasonable belief that an individual's behavior may have a negative impact on
|
| 57 |
+
the project or its community.
|
| 58 |
+
|
| 59 |
+
## Enforcement
|
| 60 |
+
|
| 61 |
+
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
| 62 |
+
reported by contacting the project team at <opensource-conduct@meta.com>. All
|
| 63 |
+
complaints will be reviewed and investigated and will result in a response that
|
| 64 |
+
is deemed necessary and appropriate to the circumstances. The project team is
|
| 65 |
+
obligated to maintain confidentiality with regard to the reporter of an incident.
|
| 66 |
+
Further details of specific enforcement policies may be posted separately.
|
| 67 |
+
|
| 68 |
+
Project maintainers who do not follow or enforce the Code of Conduct in good
|
| 69 |
+
faith may face temporary or permanent repercussions as determined by other
|
| 70 |
+
members of the project's leadership.
|
| 71 |
+
|
| 72 |
+
## Attribution
|
| 73 |
+
|
| 74 |
+
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
|
| 75 |
+
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
|
| 76 |
+
|
| 77 |
+
[homepage]: https://www.contributor-covenant.org
|
| 78 |
+
|
| 79 |
+
For answers to common questions about this code of conduct, see
|
| 80 |
+
https://www.contributor-covenant.org/faq
|
CONTRIBUTING.md
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Contributing to cowtracker
|
| 2 |
+
We want to make contributing to this project as easy and transparent as
|
| 3 |
+
possible.
|
| 4 |
+
|
| 5 |
+
## Pull Requests
|
| 6 |
+
We actively welcome your pull requests.
|
| 7 |
+
|
| 8 |
+
1. Fork the repo and create your branch from `main`.
|
| 9 |
+
2. If you've added code that should be tested, add tests.
|
| 10 |
+
3. If you've changed APIs, update the documentation.
|
| 11 |
+
4. Ensure the test suite passes.
|
| 12 |
+
5. Make sure your code lints.
|
| 13 |
+
6. If you haven't already, complete the Contributor License Agreement ("CLA").
|
| 14 |
+
|
| 15 |
+
## Contributor License Agreement ("CLA")
|
| 16 |
+
In order to accept your pull request, we need you to submit a CLA. You only need
|
| 17 |
+
to do this once to work on any of Facebook's open source projects.
|
| 18 |
+
|
| 19 |
+
Complete your CLA here: <https://code.facebook.com/cla>
|
| 20 |
+
|
| 21 |
+
## Issues
|
| 22 |
+
We use GitHub issues to track public bugs. Please ensure your description is
|
| 23 |
+
clear and has sufficient instructions to be able to reproduce the issue.
|
| 24 |
+
|
| 25 |
+
Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
|
| 26 |
+
disclosure of security bugs. In those cases, please go through the process
|
| 27 |
+
outlined on that page and do not file a public issue.
|
| 28 |
+
|
| 29 |
+
## License
|
| 30 |
+
By contributing to cowtracker, you agree that your contributions will be licensed
|
| 31 |
+
under the LICENSE file in the root directory of this source tree.
|
LICENSE
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FAIR Noncommercial Research License
|
| 2 |
+
v1 Last Updated: December 22, 2025
|
| 3 |
+
|
| 4 |
+
“Acceptable Use Policy” means the FAIR Acceptable Use Policy, applicable to Research Materials, that is incorporated into this Agreement.
|
| 5 |
+
|
| 6 |
+
“Agreement” means the terms and conditions for use, reproduction, distribution and modification of the Research Materials set forth herein.
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
“Documentation” means the specifications, manuals and documentation accompanying
|
| 10 |
+
Research Materials distributed by Meta.
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
“Licensee” or “you” means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity’s behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf.
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
“Meta” or “we” means Meta Platforms Ireland Limited (if you are located in or, if you are an entity, your principal place of business is in the EEA or Switzerland) and Meta Platforms, Inc. (if you are located outside of the EEA or Switzerland).
|
| 17 |
+
|
| 18 |
+
“Noncommercial Research Uses” means noncommercial research use cases related to research, development, education, processing, or analysis and in each case, is not primarily intended for commercial advantage or monetary compensation to you or others.
|
| 19 |
+
|
| 20 |
+
“Research Materials” means, collectively, Documentation and the models, software and algorithms, including machine-learning model code, trained model weights, inference-enabling code, training-enabling code, fine-tuning enabling code, demonstration materials and other elements of the foregoing distributed by Meta and made available under this Agreement.
|
| 21 |
+
|
| 22 |
+
By clicking “I Accept” below or by using or distributing any portion or element of the Research Materials, you agree to be bound by this Agreement.
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
1. License Rights and Redistribution.
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
a. Grant of Rights. You are granted a non-exclusive, worldwide, non-transferable and royalty-free limited license under Meta’s intellectual property or other rights owned by Meta embodied in the Research Materials to use, reproduce, distribute, copy, create derivative works of, and make modifications to the Research Materials.
|
| 29 |
+
|
| 30 |
+
b. Redistribution and Use.
|
| 31 |
+
i. You will not use the Research Materials or any outputs or results of the Research Materials in connection with any commercial uses or for any uses other than Noncommercial Research Uses;
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
ii. Distribution of Research Materials, and any derivative works thereof, are subject to the terms of this Agreement. If you distribute or make the Research Materials, or any derivative works thereof, available to a third party, you may only do so under the terms of this Agreement. You shall also provide a copy of this Agreement to such third party.
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
iii. If you submit for publication the results of research you perform on, using, or otherwise in connection with Research Materials, you must acknowledge the use of Research Materials in your publication.
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
iv. Your use of the Research Materials must comply with applicable laws and regulations (including Trade Control Laws) and adhere to the FAIR Acceptable Use Policy, which is hereby incorporated by reference into this Agreement.
|
| 41 |
+
2. User Support. Your Noncommercial Research Use of the Research Materials is done at your own discretion; Meta does not process any information nor provide any service in relation to such use. Meta is under no obligation to provide any support services for the Research Materials. Any support provided is “as is”, “with all faults”, and without warranty of any kind.
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
3. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE RESEARCH MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN “AS IS” BASIS, WITHOUT WARRANTIES OF ANY KIND, AND META DISCLAIMS ALL WARRANTIES OF ANY KIND, BOTH EXPRESS AND IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE RESEARCH MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE RESEARCH MATERIALS AND ANY OUTPUT AND RESULTS.
|
| 45 |
+
|
| 46 |
+
4. Limitation of Liability. IN NO EVENT WILL META OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT OR INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
|
| 47 |
+
|
| 48 |
+
5. Intellectual Property.
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
a. Subject to Meta’s ownership of Research Materials and derivatives made by or for Meta, with respect to any derivative works and modifications of the Research Materials that are made by you, as between you and Meta, you are and will be the owner of such derivative works and modifications.
|
| 52 |
+
|
| 53 |
+
b. If you institute litigation or other proceedings against Meta or any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Research Materials, outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Meta from and against any claim by any third party arising out of or related to your use or distribution of the Research Materials.
|
| 54 |
+
|
| 55 |
+
6. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the Research Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Meta may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of the Research Materials. Sections 3, 4 and 7 shall survive the termination of this Agreement.
|
| 56 |
+
|
| 57 |
+
7. Governing Law and Jurisdiction. This Agreement will be governed and construed under the laws of the State of California without regard to choice of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement. The courts of California shall have exclusive jurisdiction of any dispute arising out of this Agreement.
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
8. Modifications and Amendments. Meta may modify this Agreement from time to time; provided that they are similar in spirit to the current version of the Agreement, but may differ in detail to address new problems or concerns. All such changes will be effective immediately. Your continued use of the Research Materials after any modification to this Agreement constitutes your agreement to such modification. Except as provided in this Agreement, no modification or addition to any provision of this Agreement will be binding unless it is in writing and signed by an authorized representative of both you and Meta.
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
FAIR Acceptable Use Policy
|
| 64 |
+
|
| 65 |
+
The Fundamental AI Research (FAIR) team at Meta seeks to further understanding of new and existing research domains with the mission of advancing the state-of-the-art in artificial intelligence through open research for the benefit of all.
|
| 66 |
+
|
| 67 |
+
As part of this mission, Meta makes certain research materials available for noncommercial research use. Meta is committed to promoting the safe and responsible use of such research materials.
|
| 68 |
+
|
| 69 |
+
Prohibited Uses
|
| 70 |
+
|
| 71 |
+
You agree you will not use, or allow others to use, Research Materials to:
|
| 72 |
+
|
| 73 |
+
Violate the law or others’ rights, including to:
|
| 74 |
+
Engage in, promote, generate, contribute to, encourage, plan, incite, or further illegal or unlawful activity or content, such as:
|
| 75 |
+
Violence or terrorism
|
| 76 |
+
Exploitation or harm to children, including the solicitation, creation, acquisition, or dissemination of child exploitative content or failure to report Child Sexual Abuse Material
|
| 77 |
+
Human trafficking, exploitation, and sexual violence
|
| 78 |
+
The illegal distribution of information or materials to minors, including obscene materials, or failure to employ legally required age-gating in connection with such information or materials.
|
| 79 |
+
Sexual solicitation
|
| 80 |
+
Any other criminal activity
|
| 81 |
+
|
| 82 |
+
Engage in, promote, incite, or facilitate the harassment, abuse, threatening, or bullying of individuals or groups of individuals
|
| 83 |
+
|
| 84 |
+
Engage in, promote, incite, or facilitate discrimination or other unlawful or harmful conduct in the provision of employment, employment benefits, credit, housing, other economic benefits, or other essential goods and services
|
| 85 |
+
|
| 86 |
+
Engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or related professional practices
|
| 87 |
+
|
| 88 |
+
Collect, process, disclose, generate, or infer health, demographic, or other sensitive personal or private information about individuals without rights and consents required by applicable laws
|
| 89 |
+
|
| 90 |
+
Engage in or facilitate any action or generate any content that infringes, misappropriates, or otherwise violates any third-party rights, including the outputs or results of any technology using FAIR research materials
|
| 91 |
+
|
| 92 |
+
Create, generate, or facilitate the creation of malicious code, malware, computer viruses or do anything else that could disable, overburden, interfere with or impair the proper working, integrity, operation or appearance of a website or computer system
|
| 93 |
+
|
| 94 |
+
2. Engage in, promote, incite, facilitate, or assist in the planning or development of activities that present a risk of death or bodily harm to individuals, including use of research artifacts related to the following:
|
| 95 |
+
|
| 96 |
+
Military, warfare, nuclear industries or applications, espionage, use for materials or activities that are subject to the International Traffic Arms Regulations (ITAR) maintained by the United States Department of State
|
| 97 |
+
|
| 98 |
+
Guns and illegal weapons (including weapon development)
|
| 99 |
+
|
| 100 |
+
Illegal drugs and regulated/controlled substances
|
| 101 |
+
|
| 102 |
+
Operation of critical infrastructure, transportation technologies, or heavy machinery
|
| 103 |
+
|
| 104 |
+
Self-harm or harm to others, including suicide, cutting, and eating disorders
|
| 105 |
+
|
| 106 |
+
Any content intended to incite or promote violence, abuse, or any infliction of bodily harm to an individual
|
| 107 |
+
|
| 108 |
+
3. Intentionally deceive or mislead others, including use of FAIR Research Materials related to the following:
|
| 109 |
+
|
| 110 |
+
Generating, promoting, or furthering fraud or the creation or promotion of disinformation
|
| 111 |
+
|
| 112 |
+
Generating, promoting, or furthering defamatory content, including the creation of defamatory statements, images, or other content
|
| 113 |
+
|
| 114 |
+
Generating, promoting, or further distributing spam
|
| 115 |
+
|
| 116 |
+
Impersonating another individual without consent, authorization, or legal right
|
| 117 |
+
|
| 118 |
+
Representing that outputs of FAIR research materials or outputs from technology using FAIR research materials are human-generated
|
| 119 |
+
|
| 120 |
+
Generating or facilitating false online engagement, including fake reviews and other means of fake online engagement
|
| 121 |
+
|
| 122 |
+
4. Fail to appropriately disclose to end users any known dangers of your Research Materials.
|
| 123 |
+
|
| 124 |
+
Please report any violation of this Policy or other problems that could lead to a violation of this Policy by submitting a report here [https://docs.google.com/forms/d/e/1FAIpQLSeb11cryAopJ7LNrC4nxEUXrHY26hfkXQMf_uH-oFgA3WlYZQ/viewform].
|
README.md
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: CoWTracker
|
| 3 |
+
emoji: 🐮
|
| 4 |
+
colorFrom: green
|
| 5 |
+
colorTo: yellow
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 6.2.0
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
license: cc-by-nc-4.0
|
| 11 |
+
suggested_hardware: a10g-small
|
| 12 |
+
short_description: Dense Point Tracking with Cost-Volume Free Warping
|
| 13 |
+
---
|
| 14 |
+
|
| 15 |
+
# 🐮 CoWTracker
|
| 16 |
+
|
| 17 |
+
**Cost-Volume Free Warping-Based Dense Point Tracking**
|
| 18 |
+
|
| 19 |
+
Zihang Lai<sup>1,2</sup>, Eldar Insafutdinov<sup>1</sup>, Edgar Sucar<sup>1</sup>, Andrea Vedaldi<sup>1,2</sup>
|
| 20 |
+
|
| 21 |
+
<sup>1</sup>Visual Geometry Group, University of Oxford <sup>2</sup>Meta AI
|
| 22 |
+
|
| 23 |
+
---
|
| 24 |
+
|
| 25 |
+
Upload a video and CoWTracker will track every pixel through time, visualizing the motion with colorful trajectories.
|
| 26 |
+
|
| 27 |
+
## Features
|
| 28 |
+
|
| 29 |
+
- **Dense Tracking**: Track all pixels simultaneously
|
| 30 |
+
- **Bidirectional**: Track forwards and backwards from any query frame
|
| 31 |
+
- **Interactive**: Choose query frame and visualization settings
|
| 32 |
+
- **Fast**: Efficient warping-based architecture
|
| 33 |
+
|
| 34 |
+
## Links
|
| 35 |
+
|
| 36 |
+
- [Project Page](https://cowtracker.github.io/)
|
| 37 |
+
- [GitHub](https://github.com/facebookresearch/cowtracker/)
|
| 38 |
+
- [Paper](#)
|
| 39 |
+
|
| 40 |
+
## Usage
|
| 41 |
+
|
| 42 |
+
1. Upload a video (or select an example)
|
| 43 |
+
2. Click "Process Video"
|
| 44 |
+
3. Select query frame using the slider
|
| 45 |
+
4. Click "Start Tracking"
|
| 46 |
+
5. Adjust visualization settings as needed
|
| 47 |
+
|
| 48 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
|
@@ -0,0 +1,591 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 3 |
+
# All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# This source code is licensed under the license found in the
|
| 6 |
+
# LICENSE file in the root directory of this source tree.
|
| 7 |
+
|
| 8 |
+
"""
|
| 9 |
+
CoWTracker Gradio Demo.
|
| 10 |
+
|
| 11 |
+
Interactive web demo for dense point tracking using CoWTracker.
|
| 12 |
+
|
| 13 |
+
Usage:
|
| 14 |
+
python app.py
|
| 15 |
+
python app.py --checkpoint /path/to/model.pth --port 8086
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
import glob
|
| 19 |
+
import os
|
| 20 |
+
import tempfile
|
| 21 |
+
import uuid
|
| 22 |
+
from typing import Optional
|
| 23 |
+
|
| 24 |
+
import gradio as gr
|
| 25 |
+
import spaces
|
| 26 |
+
import matplotlib
|
| 27 |
+
import mediapy
|
| 28 |
+
import numpy as np
|
| 29 |
+
import PIL.Image
|
| 30 |
+
import torch
|
| 31 |
+
|
| 32 |
+
from cowtracker import CoWTracker
|
| 33 |
+
from cowtracker.utils.padding import (
|
| 34 |
+
apply_padding,
|
| 35 |
+
compute_padding_params,
|
| 36 |
+
remove_padding_and_scale_back,
|
| 37 |
+
)
|
| 38 |
+
from cowtracker.utils.visualization import (
|
| 39 |
+
get_2d_colors,
|
| 40 |
+
get_colors_from_cmap,
|
| 41 |
+
paint_point_track,
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
# --- Constants ---
|
| 45 |
+
PREVIEW_WIDTH = 1024
|
| 46 |
+
PREVIEW_HEIGHT = 1024
|
| 47 |
+
FRAME_LIMIT = 512
|
| 48 |
+
# Default checkpoint: None means use the model's default HuggingFace URL
|
| 49 |
+
DEFAULT_CHECKPOINT = None
|
| 50 |
+
|
| 51 |
+
# --- Model Initialization ---
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def initialize_model(checkpoint_path: Optional[str] = None):
|
| 55 |
+
"""Initialize and load the CoWTracker model once at startup.
|
| 56 |
+
|
| 57 |
+
Args:
|
| 58 |
+
checkpoint_path: Path to local checkpoint file.
|
| 59 |
+
If None, downloads from HuggingFace Hub.
|
| 60 |
+
"""
|
| 61 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 62 |
+
dtype = torch.float16 if device == "cuda" else torch.float32
|
| 63 |
+
|
| 64 |
+
ckpt_path = checkpoint_path if checkpoint_path is not None else DEFAULT_CHECKPOINT
|
| 65 |
+
if ckpt_path:
|
| 66 |
+
print(f"Initializing CoWTracker model from {ckpt_path}...")
|
| 67 |
+
else:
|
| 68 |
+
print("Initializing CoWTracker model from HuggingFace Hub...")
|
| 69 |
+
|
| 70 |
+
model = CoWTracker.from_checkpoint(
|
| 71 |
+
ckpt_path,
|
| 72 |
+
device=device,
|
| 73 |
+
dtype=dtype,
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
print("Model initialized successfully!")
|
| 77 |
+
return model
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
# Initialize model once at module level
|
| 81 |
+
GLOBAL_MODEL = None
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def get_model():
|
| 85 |
+
"""Get the global model, initializing if needed."""
|
| 86 |
+
global GLOBAL_MODEL
|
| 87 |
+
if GLOBAL_MODEL is None:
|
| 88 |
+
GLOBAL_MODEL = initialize_model()
|
| 89 |
+
return GLOBAL_MODEL
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
# --- Core Logic Functions ---
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def preprocess_video_input(video_path):
|
| 96 |
+
"""Process uploaded video for tracking."""
|
| 97 |
+
if not video_path:
|
| 98 |
+
return None
|
| 99 |
+
|
| 100 |
+
video_arr = mediapy.read_video(video_path)
|
| 101 |
+
video_fps = video_arr.metadata.fps
|
| 102 |
+
num_frames = video_arr.shape[0]
|
| 103 |
+
|
| 104 |
+
if num_frames > FRAME_LIMIT:
|
| 105 |
+
gr.Warning(
|
| 106 |
+
f"Video is too long. Truncating to first {FRAME_LIMIT} frames.", duration=5
|
| 107 |
+
)
|
| 108 |
+
video_arr = video_arr[:FRAME_LIMIT]
|
| 109 |
+
num_frames = FRAME_LIMIT
|
| 110 |
+
|
| 111 |
+
height, width = video_arr.shape[1:3]
|
| 112 |
+
if height > width:
|
| 113 |
+
new_height, new_width = PREVIEW_HEIGHT, int(PREVIEW_WIDTH * width / height)
|
| 114 |
+
else:
|
| 115 |
+
new_height, new_width = int(PREVIEW_WIDTH * height / width), PREVIEW_WIDTH
|
| 116 |
+
|
| 117 |
+
# Resize logic to keep manageable size
|
| 118 |
+
if new_height * new_width > 768 * 1024:
|
| 119 |
+
new_height = new_height * 3 // 4
|
| 120 |
+
new_width = new_width * 3 // 4
|
| 121 |
+
|
| 122 |
+
# Make divisible by 16 for ffmpeg compatibility
|
| 123 |
+
new_height, new_width = new_height // 16 * 16, new_width // 16 * 16
|
| 124 |
+
|
| 125 |
+
preview_video = mediapy.resize_video(video_arr, (new_height, new_width))
|
| 126 |
+
input_video = preview_video # using preview size for processing
|
| 127 |
+
|
| 128 |
+
preview_video = np.array(preview_video)
|
| 129 |
+
input_video = np.array(input_video)
|
| 130 |
+
|
| 131 |
+
return (
|
| 132 |
+
video_arr,
|
| 133 |
+
preview_video,
|
| 134 |
+
preview_video.copy(),
|
| 135 |
+
input_video,
|
| 136 |
+
video_fps,
|
| 137 |
+
preview_video[0],
|
| 138 |
+
gr.update(minimum=0, maximum=num_frames - 1, value=0, interactive=True),
|
| 139 |
+
gr.update(interactive=True),
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def choose_frame(frame_num, video_preview_array):
|
| 144 |
+
"""Select frame for preview."""
|
| 145 |
+
if video_preview_array is None:
|
| 146 |
+
return None
|
| 147 |
+
return video_preview_array[int(frame_num)]
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def paint_video(
|
| 151 |
+
video_preview,
|
| 152 |
+
query_frame,
|
| 153 |
+
video_fps,
|
| 154 |
+
tracks,
|
| 155 |
+
visibs,
|
| 156 |
+
rate=1,
|
| 157 |
+
show_bkg=True,
|
| 158 |
+
cmap="gist_rainbow",
|
| 159 |
+
):
|
| 160 |
+
"""Paint tracks onto video and save to file."""
|
| 161 |
+
T, H, W, _ = video_preview.shape
|
| 162 |
+
|
| 163 |
+
# Get colors based on colormap choice
|
| 164 |
+
if cmap == "bremm":
|
| 165 |
+
xy0 = tracks[:, query_frame]
|
| 166 |
+
colors = get_2d_colors(xy0, H, W)
|
| 167 |
+
else:
|
| 168 |
+
query_count = tracks.shape[0]
|
| 169 |
+
colors = get_colors_from_cmap(query_count, cmap)
|
| 170 |
+
|
| 171 |
+
painted_video = paint_point_track(
|
| 172 |
+
video_preview, tracks, visibs, colors, rate=rate, show_bkg=show_bkg
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
# Save video to temp directory
|
| 176 |
+
video_file_name = uuid.uuid4().hex + ".mp4"
|
| 177 |
+
video_path = os.path.join(tempfile.gettempdir(), "cowtracker_output")
|
| 178 |
+
video_file_path = os.path.join(video_path, video_file_name)
|
| 179 |
+
os.makedirs(video_path, exist_ok=True)
|
| 180 |
+
|
| 181 |
+
# Cleanup old jpgs
|
| 182 |
+
for f in glob.glob(os.path.join(video_path, "*.jpg")):
|
| 183 |
+
os.remove(f)
|
| 184 |
+
|
| 185 |
+
# Save frames and compile with ffmpeg
|
| 186 |
+
for ti in range(T):
|
| 187 |
+
temp_out_f = "%s/%03d.jpg" % (video_path, ti)
|
| 188 |
+
im = PIL.Image.fromarray(painted_video[ti])
|
| 189 |
+
im.save(temp_out_f)
|
| 190 |
+
|
| 191 |
+
os.system(
|
| 192 |
+
f'ffmpeg -y -hide_banner -loglevel error -f image2 -framerate {video_fps} '
|
| 193 |
+
f'-pattern_type glob -i "{video_path}/*.jpg" -c:v libx264 -crf 20 '
|
| 194 |
+
f'-pix_fmt yuv420p {video_file_path}'
|
| 195 |
+
)
|
| 196 |
+
|
| 197 |
+
# Cleanup used jpgs
|
| 198 |
+
for ti in range(T):
|
| 199 |
+
temp_out_f = "%s/%03d.jpg" % (video_path, ti)
|
| 200 |
+
if os.path.exists(temp_out_f):
|
| 201 |
+
os.remove(temp_out_f)
|
| 202 |
+
|
| 203 |
+
return video_file_path
|
| 204 |
+
|
| 205 |
+
@spaces.GPU
|
| 206 |
+
def update_vis(
|
| 207 |
+
rate, show_bkg, cmap, video_preview, query_frame, video_fps, tracks, visibs
|
| 208 |
+
):
|
| 209 |
+
"""Update visualization with new settings."""
|
| 210 |
+
if video_preview is None or len(tracks) == 0:
|
| 211 |
+
return None
|
| 212 |
+
T, H, W, _ = video_preview.shape
|
| 213 |
+
tracks_ = tracks.reshape(H, W, T, 2)[::rate, ::rate].reshape(-1, T, 2)
|
| 214 |
+
visibs_ = visibs.reshape(H, W, T)[::rate, ::rate].reshape(-1, T)
|
| 215 |
+
return paint_video(
|
| 216 |
+
video_preview,
|
| 217 |
+
query_frame,
|
| 218 |
+
video_fps,
|
| 219 |
+
tracks_,
|
| 220 |
+
visibs_,
|
| 221 |
+
rate=rate,
|
| 222 |
+
show_bkg=show_bkg,
|
| 223 |
+
cmap=cmap,
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
@spaces.GPU
|
| 228 |
+
def track(video_preview, video_input, video_fps, query_frame, rate, show_bkg, cmap):
|
| 229 |
+
"""Run tracking on video with bidirectional propagation."""
|
| 230 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 231 |
+
dtype = torch.float16 if device == "cuda" else torch.float32
|
| 232 |
+
|
| 233 |
+
video_tensor = torch.tensor(video_input).unsqueeze(0).to(dtype)
|
| 234 |
+
|
| 235 |
+
# Use the globally initialized model
|
| 236 |
+
model = get_model()
|
| 237 |
+
print("Using pre-loaded model for tracking...")
|
| 238 |
+
|
| 239 |
+
video_tensor = video_tensor.permute(0, 1, 4, 2, 3)
|
| 240 |
+
_, T, _, H, W = video_tensor.shape
|
| 241 |
+
|
| 242 |
+
# Store original resolution
|
| 243 |
+
orig_H, orig_W = H, W
|
| 244 |
+
|
| 245 |
+
# Configure inference size and compute padding parameters
|
| 246 |
+
inf_H, inf_W = 336, 560
|
| 247 |
+
skip_upscaling = True
|
| 248 |
+
|
| 249 |
+
print(f"Original video size: {orig_H}x{orig_W}")
|
| 250 |
+
print(f"Inference size: {inf_H}x{inf_W}")
|
| 251 |
+
|
| 252 |
+
# Compute padding parameters
|
| 253 |
+
padding_info = compute_padding_params(
|
| 254 |
+
orig_H, orig_W, inf_H, inf_W, skip_upscaling=skip_upscaling
|
| 255 |
+
)
|
| 256 |
+
print(f"Scale factor: {padding_info['scale']:.4f}")
|
| 257 |
+
if padding_info["upscaling_skipped"]:
|
| 258 |
+
print(
|
| 259 |
+
f"Upscaling skipped (scale > 1.0) - using original size: {orig_H}x{orig_W}"
|
| 260 |
+
)
|
| 261 |
+
else:
|
| 262 |
+
print(
|
| 263 |
+
f"Scaled size (before padding): {padding_info['scaled_H']}x{padding_info['scaled_W']}"
|
| 264 |
+
)
|
| 265 |
+
print(
|
| 266 |
+
f"Padding: top={padding_info['pad_top']}, bottom={padding_info['pad_bottom']}, "
|
| 267 |
+
f"left={padding_info['pad_left']}, right={padding_info['pad_right']}"
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
+
torch.cuda.empty_cache()
|
| 271 |
+
|
| 272 |
+
# Initialize output tensors for INFERENCE resolution
|
| 273 |
+
traj_maps_e = torch.zeros(
|
| 274 |
+
(1, T, inf_H, inf_W, 2), dtype=torch.float32, device="cpu"
|
| 275 |
+
)
|
| 276 |
+
visconf_maps_e = torch.zeros(
|
| 277 |
+
(1, T, inf_H, inf_W), dtype=torch.float32, device="cpu"
|
| 278 |
+
)
|
| 279 |
+
|
| 280 |
+
with torch.no_grad():
|
| 281 |
+
# Forward pass
|
| 282 |
+
if query_frame < T - 1:
|
| 283 |
+
with torch.amp.autocast(device_type="cuda", dtype=torch.float16):
|
| 284 |
+
# Apply padding to forward video
|
| 285 |
+
forward_video = video_tensor[0, query_frame:]
|
| 286 |
+
forward_video_padded = apply_padding(forward_video, padding_info).to(device)
|
| 287 |
+
|
| 288 |
+
predictions = model.forward(
|
| 289 |
+
video=forward_video_padded,
|
| 290 |
+
queries=None,
|
| 291 |
+
)
|
| 292 |
+
|
| 293 |
+
# Extract dense predictions (at INFERENCE resolution)
|
| 294 |
+
tracks_dense = predictions["track"][0] # (T_forward, inf_H, inf_W, 2)
|
| 295 |
+
visibility_dense = predictions["vis"][0] # (T_forward, inf_H, inf_W)
|
| 296 |
+
confidence_dense = predictions["conf"][0] # (T_forward, inf_H, inf_W)
|
| 297 |
+
|
| 298 |
+
# Store forward predictions
|
| 299 |
+
T_forward = tracks_dense.shape[0]
|
| 300 |
+
traj_maps_e[0, query_frame : query_frame + T_forward] = (
|
| 301 |
+
tracks_dense.cpu()
|
| 302 |
+
)
|
| 303 |
+
visconf_maps_e[0, query_frame : query_frame + T_forward] = (
|
| 304 |
+
visibility_dense * confidence_dense
|
| 305 |
+
).cpu()
|
| 306 |
+
|
| 307 |
+
# Backward pass
|
| 308 |
+
if query_frame > 0:
|
| 309 |
+
with torch.amp.autocast(device_type="cuda", dtype=torch.float16):
|
| 310 |
+
# Flip video for backward tracking and apply padding
|
| 311 |
+
backward_video = video_tensor[0, : query_frame + 1].flip([0])
|
| 312 |
+
backward_video_padded = apply_padding(
|
| 313 |
+
backward_video, padding_info
|
| 314 |
+
).to(device)
|
| 315 |
+
|
| 316 |
+
predictions = model.forward(
|
| 317 |
+
video=backward_video_padded,
|
| 318 |
+
queries=None,
|
| 319 |
+
)
|
| 320 |
+
|
| 321 |
+
# Extract dense predictions (at INFERENCE resolution)
|
| 322 |
+
tracks_dense = predictions["track"][0] # (T_backward, inf_H, inf_W, 2)
|
| 323 |
+
visibility_dense = predictions["vis"][0] # (T_backward, inf_H, inf_W)
|
| 324 |
+
confidence_dense = predictions["conf"][0] # (T_backward, inf_H, inf_W)
|
| 325 |
+
|
| 326 |
+
# Flip back to original temporal order
|
| 327 |
+
backward_tracks = tracks_dense.flip([0]).cpu()
|
| 328 |
+
backward_visconf = (visibility_dense * confidence_dense).flip([0]).cpu()
|
| 329 |
+
|
| 330 |
+
# Store backward predictions (excluding query frame if needed)
|
| 331 |
+
end_idx = query_frame if query_frame < T - 1 else query_frame + 1
|
| 332 |
+
traj_maps_e[0, :end_idx] = backward_tracks[:end_idx]
|
| 333 |
+
visconf_maps_e[0, :end_idx] = backward_visconf[:end_idx]
|
| 334 |
+
|
| 335 |
+
# Remove padding and scale back to original resolution
|
| 336 |
+
print(f"Removing padding and scaling back to {orig_H}x{orig_W}")
|
| 337 |
+
tracks_final, _, confidence_final = remove_padding_and_scale_back(
|
| 338 |
+
traj_maps_e[0], # (T, inf_H, inf_W, 2)
|
| 339 |
+
torch.ones_like(visconf_maps_e[0]), # dummy visibility (not used here)
|
| 340 |
+
visconf_maps_e[0], # (T, inf_H, inf_W)
|
| 341 |
+
padding_info,
|
| 342 |
+
)
|
| 343 |
+
print(f"Tracks shape after unpadding: {tracks_final.shape}")
|
| 344 |
+
print(f"Confidence shape after unpadding: {confidence_final.shape}")
|
| 345 |
+
|
| 346 |
+
# Convert to numpy format
|
| 347 |
+
tracks = tracks_final.permute(1, 2, 0, 3).reshape(-1, T, 2).numpy()
|
| 348 |
+
confs = confidence_final.permute(1, 2, 0).reshape(-1, T).numpy()
|
| 349 |
+
visibs = confs > 0.1
|
| 350 |
+
|
| 351 |
+
return (
|
| 352 |
+
update_vis(
|
| 353 |
+
rate, show_bkg, cmap, video_preview, query_frame, video_fps, tracks, visibs
|
| 354 |
+
),
|
| 355 |
+
tracks,
|
| 356 |
+
visibs,
|
| 357 |
+
gr.update(interactive=True),
|
| 358 |
+
gr.update(interactive=True),
|
| 359 |
+
gr.update(interactive=True),
|
| 360 |
+
)
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
# --- Gradio UI Layout ---
|
| 364 |
+
|
| 365 |
+
custom_css = """
|
| 366 |
+
h1 {text-align: center; margin-bottom: 0 !important;}
|
| 367 |
+
.contain {max-width: 95% !important;}
|
| 368 |
+
#examples-accordion {margin-top: 10px;}
|
| 369 |
+
"""
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
def create_demo():
|
| 373 |
+
"""Create and return the Gradio demo interface."""
|
| 374 |
+
with gr.Blocks(title="CoWTracker Demo", theme=gr.themes.Ocean(), css=custom_css) as demo:
|
| 375 |
+
# State Variables
|
| 376 |
+
video_state = gr.State()
|
| 377 |
+
video_queried_preview = gr.State()
|
| 378 |
+
video_preview = gr.State()
|
| 379 |
+
video_input = gr.State()
|
| 380 |
+
video_fps = gr.State(24)
|
| 381 |
+
tracks = gr.State([])
|
| 382 |
+
visibs = gr.State([])
|
| 383 |
+
|
| 384 |
+
# Header
|
| 385 |
+
gr.Markdown(
|
| 386 |
+
"""
|
| 387 |
+
<div style="text-align: center; max-width: 800px; margin: 0 auto;">
|
| 388 |
+
<h1 style="font-weight: 900; margin-bottom: 7px;">CoWTracker</h1>
|
| 389 |
+
<p style="margin-bottom: 10px; font-size: 94%">
|
| 390 |
+
Cost-Volume Free Warping-Based Dense Point Tracking.
|
| 391 |
+
<a href='https://cowtracker.github.io/' target='_blank'>Project Page</a> |
|
| 392 |
+
<a href='https://github.com/facebookresearch/cowtracker/' target='_blank'>GitHub</a> |
|
| 393 |
+
<a href='' target='_blank'>Paper</a>
|
| 394 |
+
</p>
|
| 395 |
+
</div>
|
| 396 |
+
"""
|
| 397 |
+
)
|
| 398 |
+
|
| 399 |
+
with gr.Row():
|
| 400 |
+
# --- Left Column: Input & Query ---
|
| 401 |
+
with gr.Column(scale=1):
|
| 402 |
+
with gr.Group():
|
| 403 |
+
gr.Markdown("### 1. Upload Video")
|
| 404 |
+
video_in = gr.Video(label="Input Video", format="mp4", height=300)
|
| 405 |
+
submit_btn = gr.Button("Step 1: Process Video", variant="primary")
|
| 406 |
+
|
| 407 |
+
# Query Frame Preview
|
| 408 |
+
with gr.Group():
|
| 409 |
+
gr.Markdown("### 2. Select Query Frame")
|
| 410 |
+
query_frame_slider = gr.Slider(
|
| 411 |
+
minimum=0,
|
| 412 |
+
maximum=100,
|
| 413 |
+
value=0,
|
| 414 |
+
step=1,
|
| 415 |
+
label="Frame Number",
|
| 416 |
+
interactive=False,
|
| 417 |
+
)
|
| 418 |
+
current_frame = gr.Image(
|
| 419 |
+
label="Query Frame Preview",
|
| 420 |
+
type="numpy",
|
| 421 |
+
interactive=False,
|
| 422 |
+
height=300,
|
| 423 |
+
)
|
| 424 |
+
|
| 425 |
+
# --- Right Column: Visualization & Output ---
|
| 426 |
+
with gr.Column(scale=2):
|
| 427 |
+
gr.Markdown("### 3. Configure & Track")
|
| 428 |
+
|
| 429 |
+
with gr.Group():
|
| 430 |
+
with gr.Row():
|
| 431 |
+
rate_radio = gr.Radio(
|
| 432 |
+
[1, 2, 4, 8],
|
| 433 |
+
value=8,
|
| 434 |
+
label="Subsampling Rate",
|
| 435 |
+
interactive=False,
|
| 436 |
+
)
|
| 437 |
+
cmap_radio = gr.Radio(
|
| 438 |
+
["gist_rainbow", "rainbow", "jet", "turbo"],
|
| 439 |
+
value="gist_rainbow",
|
| 440 |
+
label="Colormap",
|
| 441 |
+
interactive=False,
|
| 442 |
+
)
|
| 443 |
+
|
| 444 |
+
with gr.Row():
|
| 445 |
+
bkg_check = gr.Checkbox(
|
| 446 |
+
value=True, label="Overlay on Video", interactive=False
|
| 447 |
+
)
|
| 448 |
+
track_button = gr.Button(
|
| 449 |
+
"Step 2: Start Tracking", variant="primary", interactive=False
|
| 450 |
+
)
|
| 451 |
+
|
| 452 |
+
# Output takes entire width of this column
|
| 453 |
+
output_video = gr.Video(
|
| 454 |
+
label="Tracking Result",
|
| 455 |
+
interactive=False,
|
| 456 |
+
autoplay=True,
|
| 457 |
+
loop=True,
|
| 458 |
+
height=550,
|
| 459 |
+
)
|
| 460 |
+
|
| 461 |
+
# --- Full Width Row: Examples ---
|
| 462 |
+
with gr.Row():
|
| 463 |
+
with gr.Column():
|
| 464 |
+
video_folder = "videos"
|
| 465 |
+
gr.Markdown("### 📚 Example Videos")
|
| 466 |
+
video_dir = os.path.join(os.path.dirname(__file__), video_folder)
|
| 467 |
+
video_files = []
|
| 468 |
+
if os.path.exists(video_dir):
|
| 469 |
+
for filename in sorted(os.listdir(video_dir)):
|
| 470 |
+
if filename.endswith((".mp4", ".avi", ".mov", ".mkv", ".webm")):
|
| 471 |
+
video_files.append(os.path.join(video_dir, filename))
|
| 472 |
+
|
| 473 |
+
if video_files:
|
| 474 |
+
gr.Examples(
|
| 475 |
+
examples=video_files,
|
| 476 |
+
inputs=[video_in],
|
| 477 |
+
examples_per_page=16,
|
| 478 |
+
)
|
| 479 |
+
|
| 480 |
+
# --- Interaction Logic ---
|
| 481 |
+
|
| 482 |
+
# 1. Submit Video
|
| 483 |
+
submit_btn.click(
|
| 484 |
+
fn=preprocess_video_input,
|
| 485 |
+
inputs=[video_in],
|
| 486 |
+
outputs=[
|
| 487 |
+
video_state,
|
| 488 |
+
video_preview,
|
| 489 |
+
video_queried_preview,
|
| 490 |
+
video_input,
|
| 491 |
+
video_fps,
|
| 492 |
+
current_frame,
|
| 493 |
+
query_frame_slider,
|
| 494 |
+
track_button,
|
| 495 |
+
],
|
| 496 |
+
queue=False,
|
| 497 |
+
)
|
| 498 |
+
|
| 499 |
+
# 2. Update Preview Frame on Slider Change
|
| 500 |
+
query_frame_slider.change(
|
| 501 |
+
fn=choose_frame,
|
| 502 |
+
inputs=[query_frame_slider, video_queried_preview],
|
| 503 |
+
outputs=[current_frame],
|
| 504 |
+
queue=False,
|
| 505 |
+
)
|
| 506 |
+
|
| 507 |
+
# 3. Run Tracking
|
| 508 |
+
track_button.click(
|
| 509 |
+
fn=track,
|
| 510 |
+
inputs=[
|
| 511 |
+
video_preview,
|
| 512 |
+
video_input,
|
| 513 |
+
video_fps,
|
| 514 |
+
query_frame_slider,
|
| 515 |
+
rate_radio,
|
| 516 |
+
bkg_check,
|
| 517 |
+
cmap_radio,
|
| 518 |
+
],
|
| 519 |
+
outputs=[
|
| 520 |
+
output_video,
|
| 521 |
+
tracks,
|
| 522 |
+
visibs,
|
| 523 |
+
rate_radio,
|
| 524 |
+
bkg_check,
|
| 525 |
+
cmap_radio,
|
| 526 |
+
],
|
| 527 |
+
queue=True,
|
| 528 |
+
)
|
| 529 |
+
|
| 530 |
+
# 4. Instant Updates for Visualization Settings (after tracking is done)
|
| 531 |
+
vis_args = [
|
| 532 |
+
rate_radio,
|
| 533 |
+
bkg_check,
|
| 534 |
+
cmap_radio,
|
| 535 |
+
video_preview,
|
| 536 |
+
query_frame_slider,
|
| 537 |
+
video_fps,
|
| 538 |
+
tracks,
|
| 539 |
+
visibs,
|
| 540 |
+
]
|
| 541 |
+
rate_radio.change(
|
| 542 |
+
fn=update_vis, inputs=vis_args, outputs=[output_video], queue=False
|
| 543 |
+
)
|
| 544 |
+
cmap_radio.change(
|
| 545 |
+
fn=update_vis, inputs=vis_args, outputs=[output_video], queue=False
|
| 546 |
+
)
|
| 547 |
+
bkg_check.change(
|
| 548 |
+
fn=update_vis, inputs=vis_args, outputs=[output_video], queue=False
|
| 549 |
+
)
|
| 550 |
+
|
| 551 |
+
return demo
|
| 552 |
+
|
| 553 |
+
|
| 554 |
+
if __name__ == "__main__":
|
| 555 |
+
import argparse
|
| 556 |
+
|
| 557 |
+
parser = argparse.ArgumentParser(description="CoWTracker Gradio Demo")
|
| 558 |
+
parser.add_argument(
|
| 559 |
+
"--checkpoint",
|
| 560 |
+
type=str,
|
| 561 |
+
default=None,
|
| 562 |
+
help="Path to model checkpoint",
|
| 563 |
+
)
|
| 564 |
+
parser.add_argument(
|
| 565 |
+
"--port",
|
| 566 |
+
type=int,
|
| 567 |
+
default=7860,
|
| 568 |
+
help="Server port",
|
| 569 |
+
)
|
| 570 |
+
parser.add_argument(
|
| 571 |
+
"--share",
|
| 572 |
+
action="store_true",
|
| 573 |
+
help="Create a public share link",
|
| 574 |
+
)
|
| 575 |
+
args = parser.parse_args()
|
| 576 |
+
|
| 577 |
+
# Initialize model with custom checkpoint if provided
|
| 578 |
+
if args.checkpoint:
|
| 579 |
+
GLOBAL_MODEL = initialize_model(args.checkpoint)
|
| 580 |
+
|
| 581 |
+
print("=" * 60)
|
| 582 |
+
print("Starting CoWTracker Gradio Demo")
|
| 583 |
+
print("=" * 60)
|
| 584 |
+
|
| 585 |
+
demo = create_demo()
|
| 586 |
+
demo.launch(
|
| 587 |
+
share=args.share,
|
| 588 |
+
show_error=True,
|
| 589 |
+
server_port=args.port,
|
| 590 |
+
)
|
| 591 |
+
|
cowtracker/__init__.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""CoWTracker: Cost-Volume Free Warping-Based Dense Point Tracking."""
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def __getattr__(name):
|
| 11 |
+
"""Lazy import to avoid import errors when dependencies are missing."""
|
| 12 |
+
if name == "CoWTracker":
|
| 13 |
+
from cowtracker.models.cowtracker import CoWTracker
|
| 14 |
+
|
| 15 |
+
return CoWTracker
|
| 16 |
+
if name == "CoWTrackerWindowed":
|
| 17 |
+
from cowtracker.models.cowtracker_windowed import CoWTrackerWindowed
|
| 18 |
+
|
| 19 |
+
return CoWTrackerWindowed
|
| 20 |
+
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
__all__ = ["CoWTracker", "CoWTrackerWindowed"]
|
cowtracker/heads/__init__.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""CowTracker heads."""
|
| 8 |
+
|
| 9 |
+
from cowtracker.heads.tracking_head import CowTrackingHead
|
| 10 |
+
from cowtracker.heads.feature_extractor import FeatureExtractor
|
| 11 |
+
import cowtracker.thirdparty # noqa: F401 - sets up vggt path
|
| 12 |
+
from vggt.heads.dpt_head import DPTHead
|
| 13 |
+
|
| 14 |
+
__all__ = ["CowTrackingHead", "FeatureExtractor", "DPTHead"]
|
cowtracker/heads/feature_extractor.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""Feature extraction: DPT + ResNet side features."""
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
|
| 13 |
+
import cowtracker.thirdparty # noqa: F401 - sets up vggt path
|
| 14 |
+
from vggt.heads.dpt_head import DPTHead
|
| 15 |
+
from cowtracker.layers.resnet_deconv import ResNet18Deconv
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class FeatureExtractor(nn.Module):
|
| 19 |
+
"""
|
| 20 |
+
Combined DPT and ResNet feature extractor.
|
| 21 |
+
|
| 22 |
+
Takes aggregated tokens from backbone and raw images,
|
| 23 |
+
outputs combined features for tracking.
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
DIM_IN = 2048 # 2 * embed_dim (1024)
|
| 27 |
+
PATCH_SIZE = 14
|
| 28 |
+
INTERMEDIATE_LAYER_IDX = [4, 11, 17, 23]
|
| 29 |
+
|
| 30 |
+
def __init__(
|
| 31 |
+
self,
|
| 32 |
+
features: int = 128,
|
| 33 |
+
down_ratio: int = 2,
|
| 34 |
+
side_resnet_channels: int = 128,
|
| 35 |
+
):
|
| 36 |
+
"""
|
| 37 |
+
Args:
|
| 38 |
+
features: Number of DPT output features.
|
| 39 |
+
down_ratio: Downsampling ratio relative to input image.
|
| 40 |
+
side_resnet_channels: Number of ResNet side feature channels.
|
| 41 |
+
"""
|
| 42 |
+
super().__init__()
|
| 43 |
+
|
| 44 |
+
self.features = features
|
| 45 |
+
self.down_ratio = down_ratio
|
| 46 |
+
|
| 47 |
+
# DPT head for backbone features
|
| 48 |
+
self.dpt_head = DPTHead(
|
| 49 |
+
dim_in=self.DIM_IN,
|
| 50 |
+
patch_size=self.PATCH_SIZE,
|
| 51 |
+
features=features,
|
| 52 |
+
feature_only=True,
|
| 53 |
+
down_ratio=down_ratio,
|
| 54 |
+
pos_embed=False,
|
| 55 |
+
intermediate_layer_idx=self.INTERMEDIATE_LAYER_IDX,
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
# ResNet for raw image features
|
| 59 |
+
self.fnet = ResNet18Deconv(3, side_resnet_channels)
|
| 60 |
+
|
| 61 |
+
self.out_dim = features + side_resnet_channels
|
| 62 |
+
|
| 63 |
+
def forward(
|
| 64 |
+
self,
|
| 65 |
+
aggregated_tokens_list: list,
|
| 66 |
+
images: torch.Tensor,
|
| 67 |
+
patch_start_idx: int,
|
| 68 |
+
) -> torch.Tensor:
|
| 69 |
+
"""
|
| 70 |
+
Extract combined features from backbone tokens and raw images.
|
| 71 |
+
|
| 72 |
+
Args:
|
| 73 |
+
aggregated_tokens_list: List of tokens from aggregator.
|
| 74 |
+
images: Input images [B, S, 3, H, W].
|
| 75 |
+
patch_start_idx: Patch start index for DPT.
|
| 76 |
+
|
| 77 |
+
Returns:
|
| 78 |
+
combined_features: [B, S, C, H_out, W_out] where C = features + side_resnet_channels.
|
| 79 |
+
"""
|
| 80 |
+
B, S, _, H_img, W_img = images.shape
|
| 81 |
+
|
| 82 |
+
# DPT features from backbone tokens
|
| 83 |
+
backbone_features = self.dpt_head(aggregated_tokens_list, images, patch_start_idx)
|
| 84 |
+
_, _, _, H_out, W_out = backbone_features.shape
|
| 85 |
+
|
| 86 |
+
# Side ResNet features from raw images
|
| 87 |
+
images_flat = images.view(B * S, 3, H_img, W_img)
|
| 88 |
+
side_features = self.fnet(images_flat)[0]
|
| 89 |
+
_, side_channels, H_side, W_side = side_features.shape
|
| 90 |
+
|
| 91 |
+
# Resize side features to match backbone output if needed
|
| 92 |
+
if H_side != H_out or W_side != W_out:
|
| 93 |
+
side_features = F.interpolate(
|
| 94 |
+
side_features, size=(H_out, W_out), mode="bilinear", align_corners=True
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
side_features = side_features.view(B, S, side_channels, H_out, W_out)
|
| 98 |
+
|
| 99 |
+
return torch.cat([backbone_features, side_features], dim=2)
|
| 100 |
+
|
cowtracker/heads/tracking_head.py
ADDED
|
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""CowTracker tracking head - Warping-based iterative refinement."""
|
| 8 |
+
|
| 9 |
+
from typing import Tuple
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
|
| 15 |
+
from cowtracker.layers.video_transformer import MODEL_CONFIGS, VisionTransformerVideo
|
| 16 |
+
from cowtracker.utils.ops import bilinear_sampler, coords_grid
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class CowTrackingHead(nn.Module):
|
| 20 |
+
"""
|
| 21 |
+
Warping-based iterative refinement module.
|
| 22 |
+
|
| 23 |
+
Responsibility: features -> (tracks, visibility, confidence)
|
| 24 |
+
Does NOT handle: feature extraction, windowing
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
TEMPORAL_INTERLEAVE_STRIDE = 2
|
| 28 |
+
MAX_FRAMES = 256
|
| 29 |
+
MLP_RATIO = 4.0
|
| 30 |
+
REFINE_PATCH_SIZE = 4
|
| 31 |
+
|
| 32 |
+
def __init__(
|
| 33 |
+
self,
|
| 34 |
+
feature_dim: int,
|
| 35 |
+
down_ratio: int = 2,
|
| 36 |
+
warp_iters: int = 5,
|
| 37 |
+
warp_model: str = "vits",
|
| 38 |
+
warp_vit_num_blocks: int = None,
|
| 39 |
+
):
|
| 40 |
+
"""
|
| 41 |
+
Args:
|
| 42 |
+
feature_dim: Input feature dimension (features + side_resnet_channels).
|
| 43 |
+
down_ratio: Feature downsampling ratio relative to original image.
|
| 44 |
+
warp_iters: Number of Warping-based iterative refinement iterations.
|
| 45 |
+
warp_model: Model configuration for video transformer.
|
| 46 |
+
warp_vit_num_blocks: Number of transformer blocks (None = use default).
|
| 47 |
+
"""
|
| 48 |
+
super().__init__()
|
| 49 |
+
|
| 50 |
+
self.warp_iters = warp_iters
|
| 51 |
+
self.down_ratio = down_ratio
|
| 52 |
+
|
| 53 |
+
# Warping-based iterative refinement iteration dimension
|
| 54 |
+
self.iter_dim = MODEL_CONFIGS[warp_model]["features"]
|
| 55 |
+
|
| 56 |
+
# Video transformer for temporal attention
|
| 57 |
+
self.refine_net = VisionTransformerVideo(
|
| 58 |
+
warp_model,
|
| 59 |
+
self.iter_dim,
|
| 60 |
+
patch_size=self.REFINE_PATCH_SIZE,
|
| 61 |
+
temporal_interleave_stride=self.TEMPORAL_INTERLEAVE_STRIDE,
|
| 62 |
+
max_frames=self.MAX_FRAMES,
|
| 63 |
+
mlp_ratio=self.MLP_RATIO,
|
| 64 |
+
attn_dropout=0.0,
|
| 65 |
+
proj_dropout=0.0,
|
| 66 |
+
drop_path=0.0,
|
| 67 |
+
num_blocks=warp_vit_num_blocks,
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
# Feature processing layers
|
| 71 |
+
self.fmap_conv = nn.Conv2d(feature_dim, self.iter_dim, 1, 1, 0, bias=True)
|
| 72 |
+
self.hidden_conv = nn.Conv2d(self.iter_dim * 2, self.iter_dim, 1, 1, 0, bias=True)
|
| 73 |
+
self.warp_linear = nn.Conv2d(3 * self.iter_dim + 2, self.iter_dim, 1, 1, 0, bias=True)
|
| 74 |
+
self.refine_transform = nn.Conv2d(self.iter_dim // 2 * 3, self.iter_dim, 1, 1, 0, bias=True)
|
| 75 |
+
|
| 76 |
+
# Upsampling weights
|
| 77 |
+
self.upsample_weight = nn.Sequential(
|
| 78 |
+
nn.Conv2d(self.iter_dim, 2 * self.iter_dim, 3, padding=1, bias=True),
|
| 79 |
+
nn.ReLU(inplace=True),
|
| 80 |
+
nn.Conv2d(2 * self.iter_dim, (down_ratio**2) * 9, 1, padding=0, bias=True),
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
# Flow + visibility + confidence head
|
| 84 |
+
self.flow_head = nn.Sequential(
|
| 85 |
+
nn.Conv2d(self.iter_dim, 2 * self.iter_dim, 3, padding=1, bias=True),
|
| 86 |
+
nn.ReLU(inplace=True),
|
| 87 |
+
nn.Conv2d(2 * self.iter_dim, 4, 1, padding=0, bias=True),
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
print(f"CowTrackingHead initialized: iter_dim={self.iter_dim}, warp_iters={warp_iters}")
|
| 91 |
+
|
| 92 |
+
def forward(
|
| 93 |
+
self,
|
| 94 |
+
features: torch.Tensor,
|
| 95 |
+
image_size: Tuple[int, int],
|
| 96 |
+
first_frame_features: torch.Tensor = None,
|
| 97 |
+
) -> dict:
|
| 98 |
+
"""
|
| 99 |
+
Run Warping-based iterative refinement.
|
| 100 |
+
|
| 101 |
+
Args:
|
| 102 |
+
features: Extracted features [B, S, C, H, W].
|
| 103 |
+
image_size: Original image size (H_img, W_img) for upsampling.
|
| 104 |
+
first_frame_features: Optional first frame features [B, 1, C, H, W]
|
| 105 |
+
for cross-window tracking.
|
| 106 |
+
|
| 107 |
+
Returns:
|
| 108 |
+
dict with:
|
| 109 |
+
- track: Dense tracks [B, S, H_img, W_img, 2].
|
| 110 |
+
- vis: Visibility scores [B, S, H_img, W_img].
|
| 111 |
+
- conf: Confidence scores [B, S, H_img, W_img].
|
| 112 |
+
"""
|
| 113 |
+
B, S, _, H, W = features.shape
|
| 114 |
+
H_img, W_img = image_size
|
| 115 |
+
|
| 116 |
+
# Project features to iteration dimension
|
| 117 |
+
fmap = self.fmap_conv(features.view(B * S, -1, H, W)).view(B, S, -1, H, W)
|
| 118 |
+
|
| 119 |
+
# Frame 0 reference features
|
| 120 |
+
if first_frame_features is not None:
|
| 121 |
+
frame0_fmap = self.fmap_conv(first_frame_features.view(B, -1, H, W)).view(B, 1, -1, H, W)
|
| 122 |
+
else:
|
| 123 |
+
frame0_fmap = fmap[:, 0:1]
|
| 124 |
+
frame0_expanded = frame0_fmap.expand(B, S, -1, H, W)
|
| 125 |
+
|
| 126 |
+
# Initialize hidden state from concatenation of frame0 and current features
|
| 127 |
+
net = self.hidden_conv(
|
| 128 |
+
torch.cat([frame0_expanded, fmap], dim=2).view(B * S, -1, H, W)
|
| 129 |
+
).view(B, S, -1, H, W)
|
| 130 |
+
|
| 131 |
+
# Initialize flow to zero
|
| 132 |
+
flow = torch.zeros(B, S, 2, H, W, device=features.device, dtype=features.dtype)
|
| 133 |
+
|
| 134 |
+
# Iterative refinement
|
| 135 |
+
for _ in range(self.warp_iters):
|
| 136 |
+
flow = flow.detach()
|
| 137 |
+
|
| 138 |
+
# Compute warped coordinates
|
| 139 |
+
coords = coords_grid(B * S, H, W, device=features.device).to(fmap.dtype).view(B, S, 2, H, W)
|
| 140 |
+
coords_warped = coords + flow
|
| 141 |
+
|
| 142 |
+
# Warp features using current flow estimate
|
| 143 |
+
warped_fmap = bilinear_sampler(
|
| 144 |
+
fmap.view(B * S, -1, H, W), coords_warped.view(B * S, 2, H, W).permute(0, 2, 3, 1)
|
| 145 |
+
).view(B, S, -1, H, W)
|
| 146 |
+
|
| 147 |
+
# Build refinement input
|
| 148 |
+
refine_inp = self.warp_linear(
|
| 149 |
+
torch.cat([frame0_expanded, warped_fmap, net, flow], dim=2).view(B * S, -1, H, W)
|
| 150 |
+
).view(B, S, -1, H, W)
|
| 151 |
+
|
| 152 |
+
# Apply video transformer with temporal attention
|
| 153 |
+
refine_out = self.refine_net(refine_inp)["out"]
|
| 154 |
+
|
| 155 |
+
# Update hidden state
|
| 156 |
+
net = self.refine_transform(
|
| 157 |
+
torch.cat([refine_out.view(B * S, -1, H, W), net.view(B * S, -1, H, W)], dim=1)
|
| 158 |
+
).view(B, S, -1, H, W)
|
| 159 |
+
|
| 160 |
+
# Predict flow and info update
|
| 161 |
+
update = self.flow_head(net.view(B * S, -1, H, W)).view(B, S, 4, H, W)
|
| 162 |
+
flow = flow + update[:, :, :2]
|
| 163 |
+
info = update[:, :, 2:]
|
| 164 |
+
|
| 165 |
+
# Upsample to original resolution
|
| 166 |
+
weight = 0.25 * self.upsample_weight(net.view(B * S, -1, H, W)).view(B, S, -1, H, W)
|
| 167 |
+
flow_up, info_up = self._upsample_predictions(flow, info, weight)
|
| 168 |
+
|
| 169 |
+
# Convert flow to absolute track coordinates
|
| 170 |
+
tracks = self._flow_to_tracks(flow_up, H_img, W_img)
|
| 171 |
+
|
| 172 |
+
return {
|
| 173 |
+
"track": tracks,
|
| 174 |
+
"vis": torch.sigmoid(info_up[..., 0]),
|
| 175 |
+
"conf": torch.sigmoid(info_up[..., 1]),
|
| 176 |
+
}
|
| 177 |
+
|
| 178 |
+
def _upsample_predictions(
|
| 179 |
+
self,
|
| 180 |
+
flow: torch.Tensor,
|
| 181 |
+
info: torch.Tensor,
|
| 182 |
+
weight: torch.Tensor,
|
| 183 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 184 |
+
"""Upsample flow and info using learned convex combination."""
|
| 185 |
+
B, S, _, H, W = flow.shape
|
| 186 |
+
|
| 187 |
+
flow_ups, info_ups = [], []
|
| 188 |
+
for t in range(S):
|
| 189 |
+
f_up, i_up = self._upsample_single(flow[:, t], info[:, t], weight[:, t])
|
| 190 |
+
flow_ups.append(f_up)
|
| 191 |
+
info_ups.append(i_up)
|
| 192 |
+
|
| 193 |
+
return torch.stack(flow_ups, dim=1), torch.stack(info_ups, dim=1)
|
| 194 |
+
|
| 195 |
+
def _upsample_single(
|
| 196 |
+
self,
|
| 197 |
+
flow: torch.Tensor,
|
| 198 |
+
info: torch.Tensor,
|
| 199 |
+
mask: torch.Tensor,
|
| 200 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 201 |
+
"""Upsample single frame using soft convex combination."""
|
| 202 |
+
N, _, H, W = flow.shape
|
| 203 |
+
C = info.shape[1]
|
| 204 |
+
factor = self.down_ratio
|
| 205 |
+
|
| 206 |
+
mask = mask.view(N, 1, 9, factor, factor, H, W)
|
| 207 |
+
mask = torch.softmax(mask, dim=2)
|
| 208 |
+
|
| 209 |
+
up_flow = F.unfold(factor * flow, [3, 3], padding=1).view(N, 2, 9, 1, 1, H, W)
|
| 210 |
+
up_info = F.unfold(info, [3, 3], padding=1).view(N, C, 9, 1, 1, H, W)
|
| 211 |
+
|
| 212 |
+
up_flow = torch.sum(mask * up_flow, dim=2).permute(0, 1, 4, 2, 5, 3)
|
| 213 |
+
up_info = torch.sum(mask * up_info, dim=2).permute(0, 1, 4, 2, 5, 3)
|
| 214 |
+
|
| 215 |
+
return (
|
| 216 |
+
up_flow.reshape(N, 2, factor * H, factor * W).permute(0, 2, 3, 1),
|
| 217 |
+
up_info.reshape(N, C, factor * H, factor * W).permute(0, 2, 3, 1),
|
| 218 |
+
)
|
| 219 |
+
|
| 220 |
+
def _flow_to_tracks(
|
| 221 |
+
self,
|
| 222 |
+
flow: torch.Tensor,
|
| 223 |
+
H_img: int,
|
| 224 |
+
W_img: int,
|
| 225 |
+
) -> torch.Tensor:
|
| 226 |
+
"""Convert flow to absolute track coordinates."""
|
| 227 |
+
B, S = flow.shape[:2]
|
| 228 |
+
device, dtype = flow.device, flow.dtype
|
| 229 |
+
|
| 230 |
+
# Create coordinate grid
|
| 231 |
+
y, x = torch.meshgrid(
|
| 232 |
+
torch.arange(H_img, device=device, dtype=dtype),
|
| 233 |
+
torch.arange(W_img, device=device, dtype=dtype),
|
| 234 |
+
indexing="ij",
|
| 235 |
+
)
|
| 236 |
+
coords = torch.stack([x, y], dim=-1).unsqueeze(0).unsqueeze(0).expand(B, S, -1, -1, -1)
|
| 237 |
+
|
| 238 |
+
# Normalize flow relative to frame 0 during inference
|
| 239 |
+
if not self.training:
|
| 240 |
+
flow = flow - flow[:, 0:1]
|
| 241 |
+
flow[:, 0] = 0
|
| 242 |
+
|
| 243 |
+
return coords + flow
|
cowtracker/inference/__init__.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""Inference utilities for CowTracker."""
|
| 8 |
+
|
| 9 |
+
from cowtracker.inference.windowed import WindowedInference
|
| 10 |
+
|
| 11 |
+
__all__ = ["WindowedInference"]
|
| 12 |
+
|
cowtracker/inference/windowed.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""Windowed inference for long video processing."""
|
| 8 |
+
|
| 9 |
+
from typing import Dict, List, Tuple
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class WindowedInference:
|
| 13 |
+
"""
|
| 14 |
+
Manages windowed inference for long videos.
|
| 15 |
+
|
| 16 |
+
Handles window computation, memory frame selection, and prediction merging.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
def __init__(
|
| 20 |
+
self,
|
| 21 |
+
window_len: int = 100,
|
| 22 |
+
stride: int = 100,
|
| 23 |
+
num_memory_frames: int = 10,
|
| 24 |
+
):
|
| 25 |
+
"""
|
| 26 |
+
Args:
|
| 27 |
+
window_len: Number of frames per window.
|
| 28 |
+
stride: Step size between windows.
|
| 29 |
+
num_memory_frames: Maximum number of memory frames to use.
|
| 30 |
+
"""
|
| 31 |
+
self.window_len = window_len
|
| 32 |
+
self.stride = stride
|
| 33 |
+
self.num_memory_frames = num_memory_frames
|
| 34 |
+
|
| 35 |
+
def compute_windows(self, total_frames: int) -> List[Tuple[int, int]]:
|
| 36 |
+
"""
|
| 37 |
+
Compute all window (start, end) indices.
|
| 38 |
+
|
| 39 |
+
Args:
|
| 40 |
+
total_frames: Total number of frames in the video.
|
| 41 |
+
|
| 42 |
+
Returns:
|
| 43 |
+
List of (start, end) tuples for each window.
|
| 44 |
+
"""
|
| 45 |
+
S = self.window_len
|
| 46 |
+
step = self.stride
|
| 47 |
+
|
| 48 |
+
if total_frames <= S:
|
| 49 |
+
return [(0, total_frames)]
|
| 50 |
+
|
| 51 |
+
windows = []
|
| 52 |
+
start = 0
|
| 53 |
+
while start < total_frames:
|
| 54 |
+
end = min(start + S, total_frames)
|
| 55 |
+
windows.append((start, end))
|
| 56 |
+
if end == total_frames:
|
| 57 |
+
break
|
| 58 |
+
start += step
|
| 59 |
+
|
| 60 |
+
return windows
|
| 61 |
+
|
| 62 |
+
def select_memory_frames(
|
| 63 |
+
self,
|
| 64 |
+
window_idx: int,
|
| 65 |
+
window_start: int,
|
| 66 |
+
) -> List[int]:
|
| 67 |
+
"""
|
| 68 |
+
Select memory frame indices using hybrid strategy.
|
| 69 |
+
|
| 70 |
+
Strategy combines:
|
| 71 |
+
- First frame (always included for global reference)
|
| 72 |
+
- Recent frames (temporal continuity)
|
| 73 |
+
- Uniformly sampled middle frames (long-range context)
|
| 74 |
+
|
| 75 |
+
Args:
|
| 76 |
+
window_idx: Current window index.
|
| 77 |
+
window_start: Start frame index of current window.
|
| 78 |
+
|
| 79 |
+
Returns:
|
| 80 |
+
Sorted list of memory frame indices.
|
| 81 |
+
"""
|
| 82 |
+
if window_idx == 0:
|
| 83 |
+
return []
|
| 84 |
+
|
| 85 |
+
memory_indices = [0] # Always include first frame
|
| 86 |
+
|
| 87 |
+
# Recent frames for temporal continuity
|
| 88 |
+
for offset in [2, 1]:
|
| 89 |
+
idx = window_start - offset
|
| 90 |
+
if idx > 0 and idx not in memory_indices:
|
| 91 |
+
memory_indices.append(idx)
|
| 92 |
+
|
| 93 |
+
# Uniform sampling from middle history for long-range context
|
| 94 |
+
if window_start > 10:
|
| 95 |
+
mid_start, mid_end = 5, window_start - 3
|
| 96 |
+
step = (mid_end - mid_start) / 6
|
| 97 |
+
for i in range(5):
|
| 98 |
+
idx = int(mid_start + (i + 1) * step)
|
| 99 |
+
if idx not in memory_indices:
|
| 100 |
+
memory_indices.append(idx)
|
| 101 |
+
|
| 102 |
+
# Limit to maximum number of memory frames
|
| 103 |
+
if len(memory_indices) > self.num_memory_frames:
|
| 104 |
+
memory_indices = sorted(memory_indices)[-self.num_memory_frames :]
|
| 105 |
+
|
| 106 |
+
return sorted(memory_indices)
|
| 107 |
+
|
| 108 |
+
def merge_predictions(
|
| 109 |
+
self,
|
| 110 |
+
window_idx: int,
|
| 111 |
+
window_start: int,
|
| 112 |
+
window_end: int,
|
| 113 |
+
window_pred: Dict,
|
| 114 |
+
accumulated: Dict,
|
| 115 |
+
) -> None:
|
| 116 |
+
"""
|
| 117 |
+
Merge window predictions into accumulated results.
|
| 118 |
+
|
| 119 |
+
Handles overlapping regions by using only non-overlapping parts
|
| 120 |
+
from subsequent windows.
|
| 121 |
+
|
| 122 |
+
Args:
|
| 123 |
+
window_idx: Current window index.
|
| 124 |
+
window_start: Start frame index.
|
| 125 |
+
window_end: End frame index.
|
| 126 |
+
window_pred: Predictions for current window (track, vis, conf).
|
| 127 |
+
accumulated: Accumulated predictions to update in-place.
|
| 128 |
+
"""
|
| 129 |
+
S_actual = window_end - window_start
|
| 130 |
+
|
| 131 |
+
if window_idx > 0 and self.stride < self.window_len:
|
| 132 |
+
# Has overlap with previous window - only take non-overlapping part
|
| 133 |
+
overlap_len = min(self.window_len - self.stride, S_actual)
|
| 134 |
+
if overlap_len < S_actual:
|
| 135 |
+
start_offset = overlap_len
|
| 136 |
+
for key in ["track", "vis", "conf"]:
|
| 137 |
+
accumulated[key][:, window_start + start_offset : window_end] = window_pred[key][
|
| 138 |
+
:, start_offset:S_actual
|
| 139 |
+
]
|
| 140 |
+
else:
|
| 141 |
+
# No overlap or first window - take everything
|
| 142 |
+
for key in ["track", "vis", "conf"]:
|
| 143 |
+
accumulated[key][:, window_start:window_end] = window_pred[key][:, :S_actual]
|
| 144 |
+
|
cowtracker/layers/__init__.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""Network layers and backbone modules for CowTracker."""
|
| 8 |
+
|
| 9 |
+
from cowtracker.layers.temporal_attention import TemporalSelfAttentionBlock
|
| 10 |
+
from cowtracker.layers.video_transformer import (
|
| 11 |
+
MODEL_CONFIGS,
|
| 12 |
+
VisionTransformerVideo,
|
| 13 |
+
FlashAttention3,
|
| 14 |
+
replace_attention_with_flash3,
|
| 15 |
+
)
|
| 16 |
+
from cowtracker.layers.patch_embed import PatchEmbed
|
| 17 |
+
|
| 18 |
+
__all__ = [
|
| 19 |
+
"TemporalSelfAttentionBlock",
|
| 20 |
+
"MODEL_CONFIGS",
|
| 21 |
+
"VisionTransformerVideo",
|
| 22 |
+
"FlashAttention3",
|
| 23 |
+
"replace_attention_with_flash3",
|
| 24 |
+
"PatchEmbed",
|
| 25 |
+
]
|
| 26 |
+
|
cowtracker/layers/dpt_head.py
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Custom DPTHead with intermediate feature extraction support.
|
| 9 |
+
|
| 10 |
+
This module imports the base components from the Depth-Anything-V2 submodule
|
| 11 |
+
and provides a modified DPTHead that can return intermediate features.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
+
|
| 18 |
+
# Import base components from submodule
|
| 19 |
+
from cowtracker.thirdparty.DepthAnythingV2.depth_anything_v2.util.blocks import (
|
| 20 |
+
FeatureFusionBlock,
|
| 21 |
+
_make_scratch,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def _make_fusion_block(features, use_bn, size=None):
|
| 26 |
+
return FeatureFusionBlock(
|
| 27 |
+
features,
|
| 28 |
+
nn.ReLU(False),
|
| 29 |
+
deconv=False,
|
| 30 |
+
bn=use_bn,
|
| 31 |
+
expand=False,
|
| 32 |
+
align_corners=True,
|
| 33 |
+
size=size,
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class DPTHead(nn.Module):
|
| 38 |
+
"""
|
| 39 |
+
DPT decoder head with support for returning intermediate features.
|
| 40 |
+
|
| 41 |
+
This is a modified version that:
|
| 42 |
+
- Removes output_conv2 (final depth prediction layers)
|
| 43 |
+
- Removes resConfUnit1 from refinenet4
|
| 44 |
+
- Supports returning intermediate feature maps via return_intermediate flag
|
| 45 |
+
"""
|
| 46 |
+
|
| 47 |
+
def __init__(
|
| 48 |
+
self,
|
| 49 |
+
in_channels,
|
| 50 |
+
features=256,
|
| 51 |
+
use_bn=False,
|
| 52 |
+
out_channels=[256, 512, 1024, 1024],
|
| 53 |
+
use_clstoken=False
|
| 54 |
+
):
|
| 55 |
+
super(DPTHead, self).__init__()
|
| 56 |
+
|
| 57 |
+
self.use_clstoken = use_clstoken
|
| 58 |
+
|
| 59 |
+
self.projects = nn.ModuleList([
|
| 60 |
+
nn.Conv2d(
|
| 61 |
+
in_channels=in_channels,
|
| 62 |
+
out_channels=out_channel,
|
| 63 |
+
kernel_size=1,
|
| 64 |
+
stride=1,
|
| 65 |
+
padding=0,
|
| 66 |
+
) for out_channel in out_channels
|
| 67 |
+
])
|
| 68 |
+
|
| 69 |
+
self.resize_layers = nn.ModuleList([
|
| 70 |
+
nn.ConvTranspose2d(
|
| 71 |
+
in_channels=out_channels[0],
|
| 72 |
+
out_channels=out_channels[0],
|
| 73 |
+
kernel_size=4,
|
| 74 |
+
stride=4,
|
| 75 |
+
padding=0),
|
| 76 |
+
nn.ConvTranspose2d(
|
| 77 |
+
in_channels=out_channels[1],
|
| 78 |
+
out_channels=out_channels[1],
|
| 79 |
+
kernel_size=2,
|
| 80 |
+
stride=2,
|
| 81 |
+
padding=0),
|
| 82 |
+
nn.Identity(),
|
| 83 |
+
nn.Conv2d(
|
| 84 |
+
in_channels=out_channels[3],
|
| 85 |
+
out_channels=out_channels[3],
|
| 86 |
+
kernel_size=3,
|
| 87 |
+
stride=2,
|
| 88 |
+
padding=1)
|
| 89 |
+
])
|
| 90 |
+
|
| 91 |
+
if use_clstoken:
|
| 92 |
+
self.readout_projects = nn.ModuleList()
|
| 93 |
+
for _ in range(len(self.projects)):
|
| 94 |
+
self.readout_projects.append(
|
| 95 |
+
nn.Sequential(
|
| 96 |
+
nn.Linear(2 * in_channels, in_channels),
|
| 97 |
+
nn.GELU()))
|
| 98 |
+
|
| 99 |
+
self.scratch = _make_scratch(
|
| 100 |
+
out_channels,
|
| 101 |
+
features,
|
| 102 |
+
groups=1,
|
| 103 |
+
expand=False,
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
self.scratch.stem_transpose = None
|
| 107 |
+
|
| 108 |
+
self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
|
| 109 |
+
self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
|
| 110 |
+
self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
|
| 111 |
+
self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
|
| 112 |
+
|
| 113 |
+
head_features_1 = features
|
| 114 |
+
|
| 115 |
+
self.scratch.output_conv1 = nn.Conv2d(
|
| 116 |
+
head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
# Remove resConfUnit1 from refinenet4 (not needed for intermediate feature extraction)
|
| 120 |
+
del self.scratch.refinenet4.resConfUnit1
|
| 121 |
+
|
| 122 |
+
def forward(self, out_features, patch_h, patch_w, return_intermediate=True):
|
| 123 |
+
"""
|
| 124 |
+
Forward pass through the DPT head.
|
| 125 |
+
|
| 126 |
+
Args:
|
| 127 |
+
out_features: List of intermediate features from the encoder
|
| 128 |
+
patch_h: Height in patches
|
| 129 |
+
patch_w: Width in patches
|
| 130 |
+
return_intermediate: If True, return intermediate feature maps
|
| 131 |
+
|
| 132 |
+
Returns:
|
| 133 |
+
If return_intermediate=True:
|
| 134 |
+
(out, path_1, path_2, path_3, path_4) - output and intermediate features
|
| 135 |
+
Else:
|
| 136 |
+
out - final output only
|
| 137 |
+
"""
|
| 138 |
+
out = []
|
| 139 |
+
for i, x in enumerate(out_features):
|
| 140 |
+
if self.use_clstoken:
|
| 141 |
+
x, cls_token = x[0], x[1]
|
| 142 |
+
readout = cls_token.unsqueeze(1).expand_as(x)
|
| 143 |
+
x = self.readout_projects[i](torch.cat((x, readout), -1))
|
| 144 |
+
else:
|
| 145 |
+
x = x[0]
|
| 146 |
+
|
| 147 |
+
x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
|
| 148 |
+
|
| 149 |
+
x = self.projects[i](x)
|
| 150 |
+
x = self.resize_layers[i](x)
|
| 151 |
+
|
| 152 |
+
out.append(x)
|
| 153 |
+
|
| 154 |
+
layer_1, layer_2, layer_3, layer_4 = out
|
| 155 |
+
|
| 156 |
+
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
| 157 |
+
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
| 158 |
+
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
| 159 |
+
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
| 160 |
+
|
| 161 |
+
path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
|
| 162 |
+
path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
|
| 163 |
+
path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
|
| 164 |
+
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
|
| 165 |
+
|
| 166 |
+
out = self.scratch.output_conv1(path_1)
|
| 167 |
+
|
| 168 |
+
if return_intermediate:
|
| 169 |
+
return out, path_1, path_2, path_3, path_4
|
| 170 |
+
else:
|
| 171 |
+
out = F.relu(out)
|
| 172 |
+
return out
|
| 173 |
+
|
cowtracker/layers/patch_embed.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# References:
|
| 8 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 9 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
| 10 |
+
|
| 11 |
+
from typing import Callable, Optional, Tuple, Union
|
| 12 |
+
|
| 13 |
+
from torch import Tensor
|
| 14 |
+
import torch.nn as nn
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def make_2tuple(x):
|
| 18 |
+
if isinstance(x, tuple):
|
| 19 |
+
assert len(x) == 2
|
| 20 |
+
return x
|
| 21 |
+
|
| 22 |
+
assert isinstance(x, int)
|
| 23 |
+
return (x, x)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class PatchEmbed(nn.Module):
|
| 27 |
+
"""
|
| 28 |
+
2D image to patch embedding: (B,C,H,W) -> (B,N,D)
|
| 29 |
+
|
| 30 |
+
Args:
|
| 31 |
+
img_size: Image size.
|
| 32 |
+
patch_size: Patch token size.
|
| 33 |
+
in_chans: Number of input image channels.
|
| 34 |
+
embed_dim: Number of linear projection output channels.
|
| 35 |
+
norm_layer: Normalization layer.
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
def __init__(
|
| 39 |
+
self,
|
| 40 |
+
img_size: Union[int, Tuple[int, int]] = 224,
|
| 41 |
+
patch_size: Union[int, Tuple[int, int]] = 16,
|
| 42 |
+
in_chans: int = 3,
|
| 43 |
+
embed_dim: int = 768,
|
| 44 |
+
norm_layer: Optional[Callable] = None,
|
| 45 |
+
flatten_embedding: bool = True,
|
| 46 |
+
) -> None:
|
| 47 |
+
super().__init__()
|
| 48 |
+
|
| 49 |
+
image_HW = make_2tuple(img_size)
|
| 50 |
+
patch_HW = make_2tuple(patch_size)
|
| 51 |
+
patch_grid_size = (
|
| 52 |
+
image_HW[0] // patch_HW[0],
|
| 53 |
+
image_HW[1] // patch_HW[1],
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
self.img_size = image_HW
|
| 57 |
+
self.patch_size = patch_HW
|
| 58 |
+
self.patches_resolution = patch_grid_size
|
| 59 |
+
self.num_patches = patch_grid_size[0] * patch_grid_size[1]
|
| 60 |
+
|
| 61 |
+
self.in_chans = in_chans
|
| 62 |
+
self.embed_dim = embed_dim
|
| 63 |
+
|
| 64 |
+
self.flatten_embedding = flatten_embedding
|
| 65 |
+
|
| 66 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
|
| 67 |
+
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
| 68 |
+
|
| 69 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 70 |
+
_, _, H, W = x.shape
|
| 71 |
+
patch_H, patch_W = self.patch_size
|
| 72 |
+
|
| 73 |
+
assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
|
| 74 |
+
assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
|
| 75 |
+
|
| 76 |
+
x = self.proj(x) # B C H W
|
| 77 |
+
H, W = x.size(2), x.size(3)
|
| 78 |
+
x = x.flatten(2).transpose(1, 2) # B HW C
|
| 79 |
+
x = self.norm(x)
|
| 80 |
+
if not self.flatten_embedding:
|
| 81 |
+
x = x.reshape(-1, H, W, self.embed_dim) # B H W C
|
| 82 |
+
return x
|
| 83 |
+
|
| 84 |
+
def flops(self) -> float:
|
| 85 |
+
Ho, Wo = self.patches_resolution
|
| 86 |
+
flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
|
| 87 |
+
if self.norm is not None:
|
| 88 |
+
flops += Ho * Wo * self.embed_dim
|
| 89 |
+
return flops
|
| 90 |
+
|
cowtracker/layers/resnet_deconv.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""ResNet18-style encoder-decoder for image features."""
|
| 8 |
+
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class resconv(nn.Module):
|
| 13 |
+
"""Residual convolution block."""
|
| 14 |
+
|
| 15 |
+
def __init__(self, inp, oup, k=3, s=1):
|
| 16 |
+
super(resconv, self).__init__()
|
| 17 |
+
self.conv = nn.Sequential(
|
| 18 |
+
nn.GELU(),
|
| 19 |
+
nn.Conv2d(inp, oup, kernel_size=k, stride=s, padding=k // 2, bias=True),
|
| 20 |
+
nn.GELU(),
|
| 21 |
+
nn.Conv2d(oup, oup, kernel_size=3, stride=1, padding=1, bias=True),
|
| 22 |
+
)
|
| 23 |
+
if inp != oup or s != 1:
|
| 24 |
+
self.skip_conv = nn.Conv2d(
|
| 25 |
+
inp, oup, kernel_size=1, stride=s, padding=0, bias=True
|
| 26 |
+
)
|
| 27 |
+
else:
|
| 28 |
+
self.skip_conv = nn.Identity()
|
| 29 |
+
|
| 30 |
+
def forward(self, x):
|
| 31 |
+
return self.conv(x) + self.skip_conv(x)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class ResNet18Deconv(nn.Module):
|
| 35 |
+
"""ResNet18-style encoder-decoder for image features."""
|
| 36 |
+
|
| 37 |
+
def __init__(self, inp, oup):
|
| 38 |
+
super(ResNet18Deconv, self).__init__()
|
| 39 |
+
self.ds1 = resconv(inp, 64, k=7, s=2)
|
| 40 |
+
self.conv1 = resconv(64, 64, k=3, s=1)
|
| 41 |
+
self.conv2 = resconv(64, 128, k=3, s=2)
|
| 42 |
+
self.conv3 = resconv(128, 256, k=3, s=2)
|
| 43 |
+
self.conv4 = resconv(256, 512, k=3, s=2)
|
| 44 |
+
self.up_4 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2, padding=0, bias=True)
|
| 45 |
+
self.proj_3 = resconv(256, 256, k=3, s=1)
|
| 46 |
+
self.up_3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2, padding=0, bias=True)
|
| 47 |
+
self.proj_2 = resconv(128, 128, k=3, s=1)
|
| 48 |
+
self.up_2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2, padding=0, bias=True)
|
| 49 |
+
self.proj_1 = resconv(64, oup, k=3, s=1)
|
| 50 |
+
|
| 51 |
+
def forward(self, x):
|
| 52 |
+
out_1 = self.ds1(x)
|
| 53 |
+
out_1 = self.conv1(out_1)
|
| 54 |
+
out_2 = self.conv2(out_1)
|
| 55 |
+
out_3 = self.conv3(out_2)
|
| 56 |
+
out_4 = self.conv4(out_3)
|
| 57 |
+
out_3 = self.proj_3(out_3 + self.up_4(out_4))
|
| 58 |
+
out_2 = self.proj_2(out_2 + self.up_3(out_3))
|
| 59 |
+
out_1 = self.proj_1(out_1 + self.up_2(out_2))
|
| 60 |
+
return [out_1, out_2, out_3, out_4]
|
| 61 |
+
|
cowtracker/layers/temporal_attention.py
ADDED
|
@@ -0,0 +1,307 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Enhanced Cross Attention Block implementation.
|
| 9 |
+
Self-contained version with all necessary components inline.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from typing import Callable, Optional
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn.functional as F
|
| 16 |
+
from torch import nn, Tensor
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# ============================================================================
|
| 20 |
+
# Inline helper modules (to avoid external dependencies)
|
| 21 |
+
# ============================================================================
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class DropPath(nn.Module):
|
| 25 |
+
"""Drop paths (Stochastic Depth) per sample."""
|
| 26 |
+
|
| 27 |
+
def __init__(self, drop_prob: float = 0.0):
|
| 28 |
+
super().__init__()
|
| 29 |
+
self.drop_prob = drop_prob
|
| 30 |
+
|
| 31 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 32 |
+
if self.drop_prob == 0.0 or not self.training:
|
| 33 |
+
return x
|
| 34 |
+
keep_prob = 1 - self.drop_prob
|
| 35 |
+
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
|
| 36 |
+
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
|
| 37 |
+
if keep_prob > 0.0:
|
| 38 |
+
random_tensor.div_(keep_prob)
|
| 39 |
+
return x * random_tensor
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class LayerScale(nn.Module):
|
| 43 |
+
"""Layer scale module."""
|
| 44 |
+
|
| 45 |
+
def __init__(self, dim: int, init_values: float = 1e-5):
|
| 46 |
+
super().__init__()
|
| 47 |
+
self.gamma = nn.Parameter(init_values * torch.ones(dim))
|
| 48 |
+
|
| 49 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 50 |
+
return x * self.gamma
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class Mlp(nn.Module):
|
| 54 |
+
"""MLP as used in Vision Transformer."""
|
| 55 |
+
|
| 56 |
+
def __init__(
|
| 57 |
+
self,
|
| 58 |
+
in_features: int,
|
| 59 |
+
hidden_features: int = None,
|
| 60 |
+
out_features: int = None,
|
| 61 |
+
act_layer: Callable[..., nn.Module] = nn.GELU,
|
| 62 |
+
drop: float = 0.0,
|
| 63 |
+
bias: bool = True,
|
| 64 |
+
):
|
| 65 |
+
super().__init__()
|
| 66 |
+
out_features = out_features or in_features
|
| 67 |
+
hidden_features = hidden_features or in_features
|
| 68 |
+
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
|
| 69 |
+
self.act = act_layer()
|
| 70 |
+
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
|
| 71 |
+
self.drop = nn.Dropout(drop)
|
| 72 |
+
|
| 73 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 74 |
+
x = self.fc1(x)
|
| 75 |
+
x = self.act(x)
|
| 76 |
+
x = self.drop(x)
|
| 77 |
+
x = self.fc2(x)
|
| 78 |
+
x = self.drop(x)
|
| 79 |
+
return x
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class MemEffAttention(nn.Module):
|
| 83 |
+
"""Memory efficient self-attention using PyTorch's scaled_dot_product_attention."""
|
| 84 |
+
|
| 85 |
+
def __init__(
|
| 86 |
+
self,
|
| 87 |
+
dim: int,
|
| 88 |
+
num_heads: int = 8,
|
| 89 |
+
qkv_bias: bool = True,
|
| 90 |
+
proj_bias: bool = True,
|
| 91 |
+
attn_drop: float = 0.0,
|
| 92 |
+
proj_drop: float = 0.0,
|
| 93 |
+
norm_layer: nn.Module = nn.LayerNorm,
|
| 94 |
+
qk_norm: bool = False,
|
| 95 |
+
fused_attn: bool = True,
|
| 96 |
+
rope=None,
|
| 97 |
+
) -> None:
|
| 98 |
+
super().__init__()
|
| 99 |
+
assert dim % num_heads == 0, "dim should be divisible by num_heads"
|
| 100 |
+
self.num_heads = num_heads
|
| 101 |
+
self.head_dim = dim // num_heads
|
| 102 |
+
self.scale = self.head_dim**-0.5
|
| 103 |
+
self.fused_attn = fused_attn
|
| 104 |
+
|
| 105 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 106 |
+
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
| 107 |
+
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
| 108 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
| 109 |
+
self.proj = nn.Linear(dim, dim, bias=proj_bias)
|
| 110 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 111 |
+
self.rope = rope
|
| 112 |
+
|
| 113 |
+
def forward(self, x: Tensor, pos=None) -> Tensor:
|
| 114 |
+
B, N, C = x.shape
|
| 115 |
+
qkv = (
|
| 116 |
+
self.qkv(x)
|
| 117 |
+
.reshape(B, N, 3, self.num_heads, self.head_dim)
|
| 118 |
+
.permute(2, 0, 3, 1, 4)
|
| 119 |
+
)
|
| 120 |
+
q, k, v = qkv.unbind(0)
|
| 121 |
+
q, k = self.q_norm(q), self.k_norm(k)
|
| 122 |
+
|
| 123 |
+
if self.rope is not None and pos is not None:
|
| 124 |
+
q = self.rope(q, pos)
|
| 125 |
+
k = self.rope(k, pos)
|
| 126 |
+
|
| 127 |
+
if self.fused_attn:
|
| 128 |
+
x = F.scaled_dot_product_attention(
|
| 129 |
+
q, k, v, dropout_p=self.attn_drop.p if self.training else 0.0
|
| 130 |
+
)
|
| 131 |
+
else:
|
| 132 |
+
q = q * self.scale
|
| 133 |
+
attn = q @ k.transpose(-2, -1)
|
| 134 |
+
attn = attn.softmax(dim=-1)
|
| 135 |
+
attn = self.attn_drop(attn)
|
| 136 |
+
x = attn @ v
|
| 137 |
+
|
| 138 |
+
x = x.transpose(1, 2).reshape(B, N, C)
|
| 139 |
+
x = self.proj(x)
|
| 140 |
+
x = self.proj_drop(x)
|
| 141 |
+
return x
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
# ============================================================================
|
| 145 |
+
# Main attention block classes
|
| 146 |
+
# ============================================================================
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
class SelfAttentionBlock(nn.Module):
|
| 150 |
+
"""
|
| 151 |
+
Self attention block using the same architecture as CrossAttentionBlock but for self attention.
|
| 152 |
+
"""
|
| 153 |
+
|
| 154 |
+
def __init__(
|
| 155 |
+
self,
|
| 156 |
+
dim: int,
|
| 157 |
+
num_heads: int,
|
| 158 |
+
mlp_ratio: float = 4.0,
|
| 159 |
+
qkv_bias: bool = True,
|
| 160 |
+
proj_bias: bool = True,
|
| 161 |
+
ffn_bias: bool = True,
|
| 162 |
+
drop: float = 0.0,
|
| 163 |
+
attn_drop: float = 0.0,
|
| 164 |
+
init_values=None,
|
| 165 |
+
drop_path: float = 0.0,
|
| 166 |
+
act_layer: Callable[..., nn.Module] = nn.GELU,
|
| 167 |
+
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
|
| 168 |
+
ffn_layer: Callable[..., nn.Module] = Mlp,
|
| 169 |
+
qk_norm: bool = False,
|
| 170 |
+
fused_attn: bool = True,
|
| 171 |
+
rope=None,
|
| 172 |
+
) -> None:
|
| 173 |
+
super().__init__()
|
| 174 |
+
|
| 175 |
+
self.norm1 = norm_layer(dim)
|
| 176 |
+
|
| 177 |
+
# Use standard Attention for self attention
|
| 178 |
+
self.attn = MemEffAttention(
|
| 179 |
+
dim,
|
| 180 |
+
num_heads=num_heads,
|
| 181 |
+
qkv_bias=qkv_bias,
|
| 182 |
+
proj_bias=proj_bias,
|
| 183 |
+
attn_drop=attn_drop,
|
| 184 |
+
proj_drop=drop,
|
| 185 |
+
qk_norm=qk_norm,
|
| 186 |
+
fused_attn=fused_attn,
|
| 187 |
+
rope=rope,
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
self.ls1 = (
|
| 191 |
+
LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
| 192 |
+
)
|
| 193 |
+
self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
| 194 |
+
|
| 195 |
+
self.norm2 = norm_layer(dim)
|
| 196 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
| 197 |
+
self.mlp = ffn_layer(
|
| 198 |
+
in_features=dim,
|
| 199 |
+
hidden_features=mlp_hidden_dim,
|
| 200 |
+
act_layer=act_layer,
|
| 201 |
+
drop=drop,
|
| 202 |
+
bias=ffn_bias,
|
| 203 |
+
)
|
| 204 |
+
self.ls2 = (
|
| 205 |
+
LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
| 206 |
+
)
|
| 207 |
+
self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
| 208 |
+
|
| 209 |
+
self.sample_drop_ratio = drop_path
|
| 210 |
+
|
| 211 |
+
def forward(self, x: Tensor, pos=None) -> Tensor:
|
| 212 |
+
def attn_residual_func(x: Tensor, pos=None) -> Tensor:
|
| 213 |
+
return self.ls1(self.attn(self.norm1(x), pos=pos))
|
| 214 |
+
|
| 215 |
+
def ffn_residual_func(x: Tensor) -> Tensor:
|
| 216 |
+
return self.ls2(self.mlp(self.norm2(x)))
|
| 217 |
+
|
| 218 |
+
if self.training and self.sample_drop_ratio > 0.0:
|
| 219 |
+
x = x + self.drop_path1(attn_residual_func(x, pos))
|
| 220 |
+
x = x + self.drop_path2(ffn_residual_func(x))
|
| 221 |
+
else:
|
| 222 |
+
x = x + attn_residual_func(x, pos)
|
| 223 |
+
x = x + ffn_residual_func(x)
|
| 224 |
+
|
| 225 |
+
return x
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
class TemporalSelfAttentionBlock(nn.Module):
|
| 229 |
+
"""
|
| 230 |
+
Temporal self attention block that applies self-attention across time for each spatial position.
|
| 231 |
+
Input: [B, S, N, C] -> Output: [B, S, N, C]
|
| 232 |
+
For each position n, performs self-attention across all time steps.
|
| 233 |
+
"""
|
| 234 |
+
|
| 235 |
+
def __init__(
|
| 236 |
+
self,
|
| 237 |
+
dim: int,
|
| 238 |
+
num_heads: int,
|
| 239 |
+
mlp_ratio: float = 4.0,
|
| 240 |
+
qkv_bias: bool = True,
|
| 241 |
+
proj_bias: bool = True,
|
| 242 |
+
ffn_bias: bool = True,
|
| 243 |
+
drop: float = 0.0,
|
| 244 |
+
attn_drop: float = 0.0,
|
| 245 |
+
init_values=None,
|
| 246 |
+
drop_path: float = 0.0,
|
| 247 |
+
act_layer: Callable[..., nn.Module] = nn.GELU,
|
| 248 |
+
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
|
| 249 |
+
ffn_layer: Callable[..., nn.Module] = Mlp,
|
| 250 |
+
qk_norm: bool = False,
|
| 251 |
+
fused_attn: bool = True,
|
| 252 |
+
rope=None,
|
| 253 |
+
) -> None:
|
| 254 |
+
super().__init__()
|
| 255 |
+
|
| 256 |
+
self.self_attn_block = SelfAttentionBlock(
|
| 257 |
+
dim,
|
| 258 |
+
num_heads,
|
| 259 |
+
mlp_ratio,
|
| 260 |
+
qkv_bias,
|
| 261 |
+
proj_bias,
|
| 262 |
+
ffn_bias,
|
| 263 |
+
drop,
|
| 264 |
+
attn_drop,
|
| 265 |
+
init_values,
|
| 266 |
+
drop_path,
|
| 267 |
+
act_layer,
|
| 268 |
+
norm_layer,
|
| 269 |
+
ffn_layer,
|
| 270 |
+
qk_norm,
|
| 271 |
+
fused_attn,
|
| 272 |
+
rope,
|
| 273 |
+
)
|
| 274 |
+
|
| 275 |
+
def forward(self, x: Tensor, pos=None):
|
| 276 |
+
"""
|
| 277 |
+
Apply temporal self-attention across time for each spatial position.
|
| 278 |
+
|
| 279 |
+
Args:
|
| 280 |
+
x: Input tensor of shape [B, S, N, C]
|
| 281 |
+
pos: Position encoding
|
| 282 |
+
|
| 283 |
+
Returns:
|
| 284 |
+
Output tensor of shape [B, S, N, C]
|
| 285 |
+
"""
|
| 286 |
+
if len(x.shape) != 4:
|
| 287 |
+
raise ValueError(
|
| 288 |
+
f"TemporalSelfAttentionBlock expects 4D input [B, S, N, C], got {x.shape}"
|
| 289 |
+
)
|
| 290 |
+
|
| 291 |
+
B, S, N, C = x.shape
|
| 292 |
+
|
| 293 |
+
if S <= 1:
|
| 294 |
+
# No temporal dimension to attend over, return input unchanged
|
| 295 |
+
return x
|
| 296 |
+
|
| 297 |
+
# Reshape to [B*N, S, C] to process each spatial position independently
|
| 298 |
+
x_reshaped = x.permute(0, 2, 1, 3).reshape(B * N, S, C) # [B*N, S, C]
|
| 299 |
+
|
| 300 |
+
# Apply temporal self-attention
|
| 301 |
+
x_reshaped = self.self_attn_block(x_reshaped, pos=pos)
|
| 302 |
+
|
| 303 |
+
# Reshape back to [B, S, N, C]
|
| 304 |
+
x = x_reshaped.reshape(B, N, S, C).permute(0, 2, 1, 3)
|
| 305 |
+
|
| 306 |
+
return x
|
| 307 |
+
|
cowtracker/layers/video_transformer.py
ADDED
|
@@ -0,0 +1,411 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import math
|
| 8 |
+
|
| 9 |
+
import timm
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
import xformers.ops as xops
|
| 14 |
+
from timm.models.vision_transformer import Attention as TimmAttention
|
| 15 |
+
from cowtracker.layers.temporal_attention import TemporalSelfAttentionBlock
|
| 16 |
+
from cowtracker.layers.patch_embed import PatchEmbed
|
| 17 |
+
from cowtracker.layers.dpt_head import DPTHead
|
| 18 |
+
|
| 19 |
+
print("timm version: ", timm.__version__)
|
| 20 |
+
|
| 21 |
+
def get_1d_sincos_pos_embed_from_grid(
|
| 22 |
+
embed_dim: int, pos: torch.Tensor
|
| 23 |
+
) -> torch.Tensor:
|
| 24 |
+
"""
|
| 25 |
+
This function generates a 1D positional embedding from a given grid using sine and cosine functions.
|
| 26 |
+
|
| 27 |
+
Args:
|
| 28 |
+
- embed_dim: The embedding dimension.
|
| 29 |
+
- pos: The position to generate the embedding from.
|
| 30 |
+
|
| 31 |
+
Returns:
|
| 32 |
+
- emb: The generated 1D positional embedding.
|
| 33 |
+
"""
|
| 34 |
+
assert embed_dim % 2 == 0
|
| 35 |
+
omega = torch.arange(embed_dim // 2, dtype=torch.double)
|
| 36 |
+
omega /= embed_dim / 2.0
|
| 37 |
+
omega = 1.0 / 10000**omega # (D/2,)
|
| 38 |
+
|
| 39 |
+
pos = pos.reshape(-1) # (M,)
|
| 40 |
+
out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
| 41 |
+
|
| 42 |
+
emb_sin = torch.sin(out) # (M, D/2)
|
| 43 |
+
emb_cos = torch.cos(out) # (M, D/2)
|
| 44 |
+
|
| 45 |
+
emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
|
| 46 |
+
return emb[None].float()
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def _get_flash_attention_ops():
|
| 50 |
+
"""Automatically detect GPU and return appropriate flash attention ops.
|
| 51 |
+
|
| 52 |
+
Returns Flash Attention 3 ops for H100 (compute capability >= 9.0),
|
| 53 |
+
otherwise returns Flash Attention 2 ops.
|
| 54 |
+
"""
|
| 55 |
+
if not torch.cuda.is_available():
|
| 56 |
+
return None
|
| 57 |
+
|
| 58 |
+
# Get compute capability of current device
|
| 59 |
+
major, _ = torch.cuda.get_device_capability()
|
| 60 |
+
# print("compute capability: ", torch.cuda.get_device_capability())
|
| 61 |
+
# H100 has compute capability 9.0
|
| 62 |
+
if major >= 9:
|
| 63 |
+
# Use Flash Attention 3 for H100 and newer
|
| 64 |
+
try:
|
| 65 |
+
return (xops.fmha.flash3.FwOp, xops.fmha.flash3.BwOp)
|
| 66 |
+
except AttributeError:
|
| 67 |
+
# Fall back to flash2 if flash3 not available
|
| 68 |
+
print("Flash Attention 3 not available, falling back to Flash Attention 2")
|
| 69 |
+
return (xops.fmha.flash.FwOp, xops.fmha.flash.BwOp)
|
| 70 |
+
else:
|
| 71 |
+
# Use Flash Attention 2 for older GPUs
|
| 72 |
+
return (xops.fmha.flash.FwOp, xops.fmha.flash.BwOp)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class FlashAttention3(nn.Module):
|
| 76 |
+
"""
|
| 77 |
+
Drop-in replacement for timm.models.vision_transformer.Attention using xformers Flash Attention 3.
|
| 78 |
+
"""
|
| 79 |
+
|
| 80 |
+
def __init__(
|
| 81 |
+
self,
|
| 82 |
+
dim: int,
|
| 83 |
+
num_heads: int = 8,
|
| 84 |
+
qkv_bias: bool = False,
|
| 85 |
+
attn_drop: float = 0.0,
|
| 86 |
+
proj_drop: float = 0.0,
|
| 87 |
+
):
|
| 88 |
+
super().__init__()
|
| 89 |
+
assert dim % num_heads == 0, "dim should be divisible by num_heads"
|
| 90 |
+
self.num_heads = num_heads
|
| 91 |
+
self.head_dim = dim // num_heads
|
| 92 |
+
self.scale = self.head_dim**-0.5
|
| 93 |
+
|
| 94 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 95 |
+
self.attn_drop = attn_drop
|
| 96 |
+
self.proj = nn.Linear(dim, dim)
|
| 97 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 98 |
+
|
| 99 |
+
# Get Flash Attention ops
|
| 100 |
+
self.flash_ops = _get_flash_attention_ops()
|
| 101 |
+
|
| 102 |
+
def forward(
|
| 103 |
+
self, x: torch.Tensor, attn_mask: torch.Tensor | None = None
|
| 104 |
+
) -> torch.Tensor:
|
| 105 |
+
B, N, C = x.shape
|
| 106 |
+
|
| 107 |
+
# Compute Q, K, V
|
| 108 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
|
| 109 |
+
q, k, v = qkv.unbind(2) # Each is (B, N, num_heads, head_dim)
|
| 110 |
+
|
| 111 |
+
# xformers expects [B, M, H, K] format - we already have it!
|
| 112 |
+
# Use xformers memory_efficient_attention with Flash Attention 3
|
| 113 |
+
x = xops.memory_efficient_attention(
|
| 114 |
+
q,
|
| 115 |
+
k,
|
| 116 |
+
v,
|
| 117 |
+
attn_bias=attn_mask, # Pass attention mask if provided
|
| 118 |
+
p=self.attn_drop if self.training else 0.0,
|
| 119 |
+
scale=self.scale,
|
| 120 |
+
op=self.flash_ops,
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
# Reshape back to [B, N, C]
|
| 124 |
+
x = x.reshape(B, N, C)
|
| 125 |
+
|
| 126 |
+
# Output projection
|
| 127 |
+
x = self.proj(x)
|
| 128 |
+
x = self.proj_drop(x)
|
| 129 |
+
|
| 130 |
+
return x
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def replace_attention_with_flash3(model: nn.Module) -> nn.Module:
|
| 134 |
+
"""
|
| 135 |
+
Recursively replace all timm.Attention modules with FlashAttention3.
|
| 136 |
+
"""
|
| 137 |
+
for name, module in model.named_children():
|
| 138 |
+
if isinstance(module, TimmAttention):
|
| 139 |
+
# Extract parameters from timm Attention
|
| 140 |
+
flash_attn = FlashAttention3(
|
| 141 |
+
dim=module.qkv.in_features,
|
| 142 |
+
num_heads=module.num_heads,
|
| 143 |
+
qkv_bias=module.qkv.bias is not None,
|
| 144 |
+
attn_drop=module.attn_drop.p if hasattr(module, "attn_drop") else 0.0,
|
| 145 |
+
proj_drop=module.proj_drop.p if hasattr(module, "proj_drop") else 0.0,
|
| 146 |
+
)
|
| 147 |
+
# Copy weights from original attention
|
| 148 |
+
flash_attn.qkv.weight.data = module.qkv.weight.data.clone()
|
| 149 |
+
if module.qkv.bias is not None:
|
| 150 |
+
flash_attn.qkv.bias.data = module.qkv.bias.data.clone()
|
| 151 |
+
flash_attn.proj.weight.data = module.proj.weight.data.clone()
|
| 152 |
+
if module.proj.bias is not None:
|
| 153 |
+
flash_attn.proj.bias.data = module.proj.bias.data.clone()
|
| 154 |
+
|
| 155 |
+
# Replace the module
|
| 156 |
+
setattr(model, name, flash_attn)
|
| 157 |
+
# print(
|
| 158 |
+
# f" Replaced attention module '{name}' (dim={module.qkv.in_features}, heads={module.num_heads})"
|
| 159 |
+
# )
|
| 160 |
+
else:
|
| 161 |
+
# Recursively apply to child modules
|
| 162 |
+
replace_attention_with_flash3(module)
|
| 163 |
+
|
| 164 |
+
return model
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
MODEL_CONFIGS = {
|
| 168 |
+
"vitl": {
|
| 169 |
+
"encoder": "vit_large_patch16_224",
|
| 170 |
+
"features": 256,
|
| 171 |
+
"out_channels": [256, 512, 1024, 1024],
|
| 172 |
+
},
|
| 173 |
+
"vitb": {
|
| 174 |
+
"encoder": "vit_base_patch16_224",
|
| 175 |
+
"features": 128,
|
| 176 |
+
"out_channels": [96, 192, 384, 768],
|
| 177 |
+
},
|
| 178 |
+
"vits": {
|
| 179 |
+
"encoder": "vit_small_patch16_224",
|
| 180 |
+
"features": 64,
|
| 181 |
+
"out_channels": [48, 96, 192, 384],
|
| 182 |
+
},
|
| 183 |
+
"vitt": {
|
| 184 |
+
"encoder": "vit_tiny_patch16_224",
|
| 185 |
+
"features": 32,
|
| 186 |
+
"out_channels": [24, 48, 96, 192],
|
| 187 |
+
},
|
| 188 |
+
}
|
| 189 |
+
|
| 190 |
+
class VisionTransformerVideo(nn.Module):
|
| 191 |
+
"""
|
| 192 |
+
Input: (B, T, C, H, W)
|
| 193 |
+
Pipeline: per-frame ViT + interleaved Temporal Attention (across frames)
|
| 194 |
+
Time pos: 1D sinusoidal encoding + linear interpolation for variable T
|
| 195 |
+
"""
|
| 196 |
+
|
| 197 |
+
def __init__(
|
| 198 |
+
self,
|
| 199 |
+
model_name,
|
| 200 |
+
input_dim,
|
| 201 |
+
patch_size=16,
|
| 202 |
+
temporal_interleave_stride=2,
|
| 203 |
+
max_frames=256,
|
| 204 |
+
mlp_ratio=4.0,
|
| 205 |
+
attn_dropout=0.0,
|
| 206 |
+
proj_dropout=0.0,
|
| 207 |
+
drop_path=0.0,
|
| 208 |
+
shared_temporal_block=False,
|
| 209 |
+
num_blocks=None,
|
| 210 |
+
use_flash_attention3=False,
|
| 211 |
+
):
|
| 212 |
+
super().__init__()
|
| 213 |
+
model = timm.create_model(
|
| 214 |
+
MODEL_CONFIGS[model_name]["encoder"],
|
| 215 |
+
pretrained=False,
|
| 216 |
+
num_classes=0,
|
| 217 |
+
)
|
| 218 |
+
self.intermediate_layer_idx = {
|
| 219 |
+
"vitt": [2, 5, 8, 11],
|
| 220 |
+
"vits": [2, 5, 8, 11],
|
| 221 |
+
"vitb": [2, 5, 8, 11],
|
| 222 |
+
"vitl": [4, 11, 17, 23],
|
| 223 |
+
"vitg": [9, 19, 29, 39],
|
| 224 |
+
}
|
| 225 |
+
self.idx = self.intermediate_layer_idx[model_name]
|
| 226 |
+
self.blks = model.blocks if num_blocks is None else model.blocks[:num_blocks]
|
| 227 |
+
|
| 228 |
+
# Replace attention with Flash Attention 3 if enabled
|
| 229 |
+
if use_flash_attention3:
|
| 230 |
+
self.blks = replace_attention_with_flash3(self.blks)
|
| 231 |
+
num_fa3_modules = sum(
|
| 232 |
+
1 for m in self.blks.modules() if isinstance(m, FlashAttention3)
|
| 233 |
+
)
|
| 234 |
+
print(
|
| 235 |
+
f"✓ Flash Attention 3 enabled for spatial attention: replaced {num_fa3_modules} attention modules"
|
| 236 |
+
)
|
| 237 |
+
|
| 238 |
+
self.embed_dim = model.embed_dim
|
| 239 |
+
self.input_dim = input_dim
|
| 240 |
+
self.img_size = (224, 224)
|
| 241 |
+
self.patch_size = patch_size
|
| 242 |
+
self.output_dim = MODEL_CONFIGS[model_name]["features"]
|
| 243 |
+
|
| 244 |
+
# Spatial positional embedding (64 corresponds to 224/16 = 14x14)
|
| 245 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, 64, self.embed_dim))
|
| 246 |
+
nn.init.trunc_normal_(self.pos_embed, std=0.02)
|
| 247 |
+
|
| 248 |
+
# ====== New: sinusoidal time positional embedding (buffer) ======
|
| 249 |
+
self.max_frames = max_frames
|
| 250 |
+
time_grid = torch.arange(max_frames, dtype=torch.float32) # (T0,)
|
| 251 |
+
time_emb = get_1d_sincos_pos_embed_from_grid(
|
| 252 |
+
self.embed_dim, time_grid
|
| 253 |
+
) # (1, T0, D)
|
| 254 |
+
# Follow your pattern: register_buffer + interpolate_time_embed
|
| 255 |
+
self.register_buffer("time_emb", time_emb, persistent=False)
|
| 256 |
+
|
| 257 |
+
# Patch embed and DPT head
|
| 258 |
+
self.patch_embed = PatchEmbed(
|
| 259 |
+
img_size=self.img_size,
|
| 260 |
+
patch_size=self.patch_size,
|
| 261 |
+
in_chans=input_dim,
|
| 262 |
+
embed_dim=self.embed_dim,
|
| 263 |
+
)
|
| 264 |
+
self.dpt_head = DPTHead(
|
| 265 |
+
self.embed_dim,
|
| 266 |
+
MODEL_CONFIGS[model_name]["features"],
|
| 267 |
+
out_channels=MODEL_CONFIGS[model_name]["out_channels"],
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
+
# Temporal block(s)
|
| 271 |
+
num_heads = getattr(model.blocks[0].attn, "num_heads", 8)
|
| 272 |
+
self.shared_temporal_block = shared_temporal_block
|
| 273 |
+
|
| 274 |
+
# Insert a temporal block after every N spatial blocks
|
| 275 |
+
self.temporal_interleave_stride = max(1, int(temporal_interleave_stride))
|
| 276 |
+
|
| 277 |
+
# Calculate how many temporal blocks we need
|
| 278 |
+
num_temporal_blocks = sum(
|
| 279 |
+
1
|
| 280 |
+
for i in range(len(self.blks))
|
| 281 |
+
if (i + 1) % self.temporal_interleave_stride == 0
|
| 282 |
+
)
|
| 283 |
+
|
| 284 |
+
if shared_temporal_block:
|
| 285 |
+
# Single shared temporal block for all layers
|
| 286 |
+
self.temporal_block = TemporalSelfAttentionBlock(
|
| 287 |
+
dim=self.embed_dim,
|
| 288 |
+
num_heads=num_heads,
|
| 289 |
+
mlp_ratio=mlp_ratio,
|
| 290 |
+
attn_drop=attn_dropout,
|
| 291 |
+
drop=proj_dropout,
|
| 292 |
+
drop_path=drop_path,
|
| 293 |
+
)
|
| 294 |
+
self.temporal_blocks = None
|
| 295 |
+
else:
|
| 296 |
+
# Separate temporal block for each layer
|
| 297 |
+
self.temporal_block = None
|
| 298 |
+
self.temporal_blocks = nn.ModuleList(
|
| 299 |
+
[
|
| 300 |
+
TemporalSelfAttentionBlock(
|
| 301 |
+
dim=self.embed_dim,
|
| 302 |
+
num_heads=num_heads,
|
| 303 |
+
mlp_ratio=mlp_ratio,
|
| 304 |
+
attn_drop=attn_dropout,
|
| 305 |
+
drop=proj_dropout,
|
| 306 |
+
drop_path=drop_path,
|
| 307 |
+
)
|
| 308 |
+
for _ in range(num_temporal_blocks)
|
| 309 |
+
]
|
| 310 |
+
)
|
| 311 |
+
|
| 312 |
+
# ====== New: interpolate temporal positional embedding ======
|
| 313 |
+
def interpolate_time_embed(self, x_like: torch.Tensor, t: int) -> torch.Tensor:
|
| 314 |
+
"""
|
| 315 |
+
x_like: used only to fetch dtype (e.g., fp16)
|
| 316 |
+
Return: time positional embedding of shape (1, t, D)
|
| 317 |
+
"""
|
| 318 |
+
previous_dtype = x_like.dtype
|
| 319 |
+
T0 = self.time_emb.shape[1]
|
| 320 |
+
if t == T0:
|
| 321 |
+
return self.time_emb.to(previous_dtype)
|
| 322 |
+
temb = self.time_emb.float() # (1, T0, D)
|
| 323 |
+
temb = F.interpolate(
|
| 324 |
+
temb.permute(0, 2, 1), size=t, mode="linear", align_corners=False
|
| 325 |
+
).permute(0, 2, 1) # (1, t, D)
|
| 326 |
+
return temb.to(previous_dtype)
|
| 327 |
+
|
| 328 |
+
def interpolate_pos_encoding(self, x, h, w):
|
| 329 |
+
"""
|
| 330 |
+
Interpolate the 2D spatial positional encoding to match HxW (in patches).
|
| 331 |
+
"""
|
| 332 |
+
previous_dtype = x.dtype
|
| 333 |
+
npatch = x.shape[1]
|
| 334 |
+
N = self.pos_embed.shape[1]
|
| 335 |
+
if npatch == N and w == h:
|
| 336 |
+
return self.pos_embed
|
| 337 |
+
pos_embed = self.pos_embed.float()
|
| 338 |
+
dim = x.shape[-1]
|
| 339 |
+
w0 = w // self.patch_size
|
| 340 |
+
h0 = h // self.patch_size
|
| 341 |
+
sqrt_N = math.sqrt(N)
|
| 342 |
+
sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
|
| 343 |
+
pos_embed = nn.functional.interpolate(
|
| 344 |
+
pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
|
| 345 |
+
scale_factor=(sy, sx),
|
| 346 |
+
mode="bicubic",
|
| 347 |
+
antialias=False,
|
| 348 |
+
)
|
| 349 |
+
assert int(w0) == pos_embed.shape[-1] and int(h0) == pos_embed.shape[-2]
|
| 350 |
+
pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
| 351 |
+
return pos_embed.to(previous_dtype)
|
| 352 |
+
|
| 353 |
+
def forward(self, x):
|
| 354 |
+
"""
|
| 355 |
+
x: (B, T, C, H, W)
|
| 356 |
+
"""
|
| 357 |
+
B, T, C, H, W = x.shape
|
| 358 |
+
# Merge time into batch for per-frame spatial encoding
|
| 359 |
+
x = x.view(B * T, C, H, W)
|
| 360 |
+
x = self.patch_embed(x) # (B*T, Np, D)
|
| 361 |
+
|
| 362 |
+
x = x.view(B, T, *x.shape[1:])
|
| 363 |
+
# Get time positional embedding for current T via linear interpolation: (1, T, D)
|
| 364 |
+
tpos = self.interpolate_time_embed(x, T).unsqueeze(2) # (1, T, 1, D)
|
| 365 |
+
x = x + tpos # (B, T, Np, D)
|
| 366 |
+
x = x.view(B * T, *x.shape[2:]) # (B*T, Np, D)
|
| 367 |
+
|
| 368 |
+
x = x + self.interpolate_pos_encoding(x, H, W)
|
| 369 |
+
|
| 370 |
+
outputs = []
|
| 371 |
+
temporal_block_idx = 0
|
| 372 |
+
for i in range(len(self.blks)):
|
| 373 |
+
# 1) Spatial self-attention (per frame)
|
| 374 |
+
x = self.blks[i](x) # (B*T, Np, D)
|
| 375 |
+
# 2) Interleave temporal self-attention (across frames, same spatial patch)
|
| 376 |
+
if (i + 1) % self.temporal_interleave_stride == 0:
|
| 377 |
+
x = x.view(B, T, *x.shape[1:])
|
| 378 |
+
if self.shared_temporal_block:
|
| 379 |
+
x = self.temporal_block(x)
|
| 380 |
+
else:
|
| 381 |
+
x = self.temporal_blocks[temporal_block_idx](x)
|
| 382 |
+
temporal_block_idx += 1
|
| 383 |
+
x = x.view(B * T, *x.shape[2:])
|
| 384 |
+
# 3) Collect intermediate features for DPT head
|
| 385 |
+
if i in self.idx:
|
| 386 |
+
outputs.append([x])
|
| 387 |
+
|
| 388 |
+
patch_h, patch_w = H // self.patch_size, W // self.patch_size
|
| 389 |
+
# DPT head consumes (B*T, Np, D); here batch is B*T
|
| 390 |
+
out, path_1, path_2, path_3, path_4 = self.dpt_head.forward(
|
| 391 |
+
outputs, patch_h, patch_w, return_intermediate=True
|
| 392 |
+
)
|
| 393 |
+
# Upsample per frame
|
| 394 |
+
out = F.interpolate(
|
| 395 |
+
out, (H, W), mode="bilinear", align_corners=True
|
| 396 |
+
) # (B*T, Cout, H, W)
|
| 397 |
+
|
| 398 |
+
# Restore (B, T, ...)
|
| 399 |
+
def bt_to_btensor(tensor_or_none):
|
| 400 |
+
if tensor_or_none is None:
|
| 401 |
+
return None
|
| 402 |
+
return tensor_or_none.view(B, T, *tensor_or_none.shape[1:])
|
| 403 |
+
|
| 404 |
+
return {
|
| 405 |
+
"out": out.view(B, T, *out.shape[1:]),
|
| 406 |
+
"path_1": bt_to_btensor(path_1),
|
| 407 |
+
"path_2": bt_to_btensor(path_2),
|
| 408 |
+
"path_3": bt_to_btensor(path_3),
|
| 409 |
+
"path_4": bt_to_btensor(path_4),
|
| 410 |
+
}
|
| 411 |
+
|
cowtracker/models/__init__.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""CoWTracker models."""
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def __getattr__(name):
|
| 11 |
+
"""Lazy import to avoid import errors when dependencies are missing."""
|
| 12 |
+
if name == "CoWTracker":
|
| 13 |
+
from cowtracker.models.cowtracker import CoWTracker
|
| 14 |
+
|
| 15 |
+
return CoWTracker
|
| 16 |
+
if name == "CoWTrackerWindowed":
|
| 17 |
+
from cowtracker.models.cowtracker_windowed import CoWTrackerWindowed
|
| 18 |
+
|
| 19 |
+
return CoWTrackerWindowed
|
| 20 |
+
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
__all__ = ["CoWTracker", "CoWTrackerWindowed"]
|
cowtracker/models/cowtracker.py
ADDED
|
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""CoWTracker: Simple version for short videos."""
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
|
| 12 |
+
|
| 13 |
+
import cowtracker.thirdparty # noqa: F401 - sets up vggt path
|
| 14 |
+
from vggt.models.aggregator import Aggregator
|
| 15 |
+
from cowtracker.heads.feature_extractor import FeatureExtractor
|
| 16 |
+
from cowtracker.heads.tracking_head import CowTrackingHead
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class CoWTracker(nn.Module, PyTorchModelHubMixin):
|
| 20 |
+
"""
|
| 21 |
+
CoWTracker simple version: processes entire video at once.
|
| 22 |
+
|
| 23 |
+
Suitable for: short videos / sufficient GPU memory.
|
| 24 |
+
For long videos, use CoWTrackerWindowed instead.
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
# Backbone configuration
|
| 28 |
+
IMG_SIZE = 518
|
| 29 |
+
PATCH_SIZE = 14
|
| 30 |
+
EMBED_DIM = 1024
|
| 31 |
+
PATCH_EMBED = "dinov2_vitl14_reg"
|
| 32 |
+
DEPTH = 24
|
| 33 |
+
|
| 34 |
+
# Default HuggingFace repo for model loading
|
| 35 |
+
DEFAULT_REPO_ID = "facebook/cowtracker"
|
| 36 |
+
DEFAULT_FILENAME = "cowtracker_model.pth"
|
| 37 |
+
|
| 38 |
+
def __init__(
|
| 39 |
+
self,
|
| 40 |
+
features: int = 128,
|
| 41 |
+
side_resnet_channels: int = 128,
|
| 42 |
+
down_ratio: int = 2,
|
| 43 |
+
warp_iters: int = 5,
|
| 44 |
+
warp_vit_num_blocks: int = None,
|
| 45 |
+
):
|
| 46 |
+
"""
|
| 47 |
+
Args:
|
| 48 |
+
features: Number of DPT output features.
|
| 49 |
+
side_resnet_channels: Number of ResNet side feature channels.
|
| 50 |
+
down_ratio: Feature downsampling ratio.
|
| 51 |
+
warp_iters: Number of Warping-based iterative refinement iterations.
|
| 52 |
+
warp_vit_num_blocks: Number of transformer blocks (None = default).
|
| 53 |
+
"""
|
| 54 |
+
super().__init__()
|
| 55 |
+
|
| 56 |
+
print("Initializing CoWTracker...")
|
| 57 |
+
|
| 58 |
+
# Backbone: VGGT backbone
|
| 59 |
+
self.aggregator = Aggregator(
|
| 60 |
+
img_size=self.IMG_SIZE,
|
| 61 |
+
patch_size=self.PATCH_SIZE,
|
| 62 |
+
embed_dim=self.EMBED_DIM,
|
| 63 |
+
patch_embed=self.PATCH_EMBED,
|
| 64 |
+
depth=self.DEPTH,
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
# High Resolution Feature extraction
|
| 68 |
+
self.feature_extractor = FeatureExtractor(
|
| 69 |
+
features=features,
|
| 70 |
+
down_ratio=down_ratio,
|
| 71 |
+
side_resnet_channels=side_resnet_channels,
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
# Tracking head: warping-based iterative refinement
|
| 75 |
+
self.tracking_head = CowTrackingHead(
|
| 76 |
+
feature_dim=self.feature_extractor.out_dim,
|
| 77 |
+
down_ratio=down_ratio,
|
| 78 |
+
warp_iters=warp_iters,
|
| 79 |
+
warp_vit_num_blocks=warp_vit_num_blocks,
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
print(f" - Features: {features}, Side channels: {side_resnet_channels}")
|
| 83 |
+
print(f" - Warping-based iterative refinement iterations: {warp_iters}")
|
| 84 |
+
|
| 85 |
+
def forward(self, video: torch.Tensor, queries: torch.Tensor = None) -> dict:
|
| 86 |
+
"""
|
| 87 |
+
Forward pass for dense tracking.
|
| 88 |
+
|
| 89 |
+
Args:
|
| 90 |
+
video: Input video [B, S, 3, H, W] or [S, 3, H, W] in range [0, 255].
|
| 91 |
+
queries: Optional query points (unused, for API compatibility).
|
| 92 |
+
|
| 93 |
+
Returns:
|
| 94 |
+
dict with:
|
| 95 |
+
- track: Dense tracks [B, S, H, W, 2].
|
| 96 |
+
- vis: Visibility scores [B, S, H, W].
|
| 97 |
+
- conf: Confidence scores [B, S, H, W].
|
| 98 |
+
"""
|
| 99 |
+
# Normalize input
|
| 100 |
+
images = video / 255.0
|
| 101 |
+
if images.ndim == 4:
|
| 102 |
+
images = images.unsqueeze(0)
|
| 103 |
+
|
| 104 |
+
B, S, C, H, W = images.shape
|
| 105 |
+
|
| 106 |
+
# Extract backbone tokens
|
| 107 |
+
tokens, patch_idx = self.aggregator(images)
|
| 108 |
+
|
| 109 |
+
# Extract high resolution features
|
| 110 |
+
features = self.feature_extractor(tokens, images, patch_idx)
|
| 111 |
+
|
| 112 |
+
# Run tracking
|
| 113 |
+
predictions = self.tracking_head(features, image_size=(H, W))
|
| 114 |
+
|
| 115 |
+
if not self.training:
|
| 116 |
+
predictions["images"] = images
|
| 117 |
+
|
| 118 |
+
return predictions
|
| 119 |
+
|
| 120 |
+
@staticmethod
|
| 121 |
+
def _remap_legacy_state_dict(state_dict: dict) -> dict:
|
| 122 |
+
"""
|
| 123 |
+
Remap legacy checkpoint keys to new model structure.
|
| 124 |
+
|
| 125 |
+
Old structure:
|
| 126 |
+
tracking_head.aggregator.* -> aggregator.*
|
| 127 |
+
tracking_head.feature_extractor.* -> feature_extractor.dpt_head.*
|
| 128 |
+
tracking_head.fnet.* -> feature_extractor.fnet.*
|
| 129 |
+
tracking_head.* (rest) -> tracking_head.*
|
| 130 |
+
|
| 131 |
+
Args:
|
| 132 |
+
state_dict: Original state dict.
|
| 133 |
+
|
| 134 |
+
Returns:
|
| 135 |
+
Remapped state dict.
|
| 136 |
+
"""
|
| 137 |
+
new_state_dict = {}
|
| 138 |
+
|
| 139 |
+
for key, value in state_dict.items():
|
| 140 |
+
new_key = key
|
| 141 |
+
|
| 142 |
+
# Remap tracking_head.aggregator -> aggregator
|
| 143 |
+
if key.startswith("tracking_head.aggregator."):
|
| 144 |
+
new_key = key.replace("tracking_head.aggregator.", "aggregator.")
|
| 145 |
+
# Remap tracking_head.feature_extractor -> feature_extractor.dpt_head
|
| 146 |
+
elif key.startswith("tracking_head.feature_extractor."):
|
| 147 |
+
new_key = key.replace(
|
| 148 |
+
"tracking_head.feature_extractor.", "feature_extractor.dpt_head."
|
| 149 |
+
)
|
| 150 |
+
# Remap tracking_head.fnet -> feature_extractor.fnet
|
| 151 |
+
elif key.startswith("tracking_head.fnet."):
|
| 152 |
+
new_key = key.replace("tracking_head.fnet.", "feature_extractor.fnet.")
|
| 153 |
+
|
| 154 |
+
new_state_dict[new_key] = value
|
| 155 |
+
|
| 156 |
+
return new_state_dict
|
| 157 |
+
|
| 158 |
+
@classmethod
|
| 159 |
+
def _load_checkpoint(cls, checkpoint_path: str = None) -> dict:
|
| 160 |
+
"""
|
| 161 |
+
Load checkpoint from local path or HuggingFace Hub.
|
| 162 |
+
|
| 163 |
+
Args:
|
| 164 |
+
checkpoint_path: Local file path to checkpoint.
|
| 165 |
+
If None, downloads from default HuggingFace repo.
|
| 166 |
+
|
| 167 |
+
Returns:
|
| 168 |
+
Loaded checkpoint dict.
|
| 169 |
+
"""
|
| 170 |
+
import os
|
| 171 |
+
|
| 172 |
+
if checkpoint_path is None:
|
| 173 |
+
# Download from HuggingFace Hub (uses HF_TOKEN env var automatically)
|
| 174 |
+
print(f"Downloading checkpoint from HuggingFace: {cls.DEFAULT_REPO_ID}/{cls.DEFAULT_FILENAME}")
|
| 175 |
+
checkpoint_path = hf_hub_download(
|
| 176 |
+
repo_id=cls.DEFAULT_REPO_ID,
|
| 177 |
+
filename=cls.DEFAULT_FILENAME,
|
| 178 |
+
)
|
| 179 |
+
print(f"Downloaded to: {checkpoint_path}")
|
| 180 |
+
else:
|
| 181 |
+
checkpoint_path = os.path.expanduser(checkpoint_path)
|
| 182 |
+
print(f"Loading checkpoint from local path: {checkpoint_path}")
|
| 183 |
+
|
| 184 |
+
with open(checkpoint_path, "rb") as fp:
|
| 185 |
+
ckpt = torch.load(fp, map_location="cpu")
|
| 186 |
+
|
| 187 |
+
return ckpt
|
| 188 |
+
|
| 189 |
+
@classmethod
|
| 190 |
+
def from_checkpoint(
|
| 191 |
+
cls,
|
| 192 |
+
checkpoint_path: str = None,
|
| 193 |
+
device: str = "cuda",
|
| 194 |
+
dtype=torch.bfloat16,
|
| 195 |
+
):
|
| 196 |
+
"""
|
| 197 |
+
Load model from checkpoint (local path or HuggingFace Hub).
|
| 198 |
+
|
| 199 |
+
Args:
|
| 200 |
+
checkpoint_path: Path to local checkpoint file.
|
| 201 |
+
If None, downloads from default HuggingFace repo.
|
| 202 |
+
device: Target device.
|
| 203 |
+
dtype: Target dtype.
|
| 204 |
+
|
| 205 |
+
Returns:
|
| 206 |
+
Loaded model in eval mode.
|
| 207 |
+
"""
|
| 208 |
+
model = cls()
|
| 209 |
+
|
| 210 |
+
ckpt = cls._load_checkpoint(checkpoint_path)
|
| 211 |
+
state_dict = ckpt.get("model", ckpt)
|
| 212 |
+
|
| 213 |
+
# Remap legacy checkpoint keys if needed
|
| 214 |
+
legacy_prefixes = ["tracking_head.feature_extractor.", "tracking_head.aggregator.", "tracking_head.fnet."]
|
| 215 |
+
if any(k.startswith(p) for k in state_dict.keys() for p in legacy_prefixes):
|
| 216 |
+
print("Detected legacy checkpoint format, remapping keys...")
|
| 217 |
+
state_dict = cls._remap_legacy_state_dict(state_dict)
|
| 218 |
+
|
| 219 |
+
msg = model.load_state_dict(state_dict, strict=False)
|
| 220 |
+
print(f"Load message: {msg}")
|
| 221 |
+
|
| 222 |
+
model = model.to(device).to(dtype)
|
| 223 |
+
model.eval()
|
| 224 |
+
for p in model.parameters():
|
| 225 |
+
p.requires_grad = False
|
| 226 |
+
|
| 227 |
+
print("Model loaded successfully!")
|
| 228 |
+
return model
|
cowtracker/models/cowtracker_windowed.py
ADDED
|
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""CoWTracker Windowed: Full version for long videos."""
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
from huggingface_hub import PyTorchModelHubMixin
|
| 12 |
+
|
| 13 |
+
from cowtracker.models.cowtracker import CoWTracker
|
| 14 |
+
from cowtracker.inference.windowed import WindowedInference
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class CoWTrackerWindowed(nn.Module, PyTorchModelHubMixin):
|
| 18 |
+
"""
|
| 19 |
+
CoWTracker windowed version: processes video in sliding windows.
|
| 20 |
+
|
| 21 |
+
Suitable for: long videos / limited GPU memory.
|
| 22 |
+
Composes CoWTracker with WindowedInference.
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
def __init__(
|
| 26 |
+
self,
|
| 27 |
+
# Window parameters
|
| 28 |
+
window_len: int = 100,
|
| 29 |
+
stride: int = 100,
|
| 30 |
+
num_memory_frames: int = 10,
|
| 31 |
+
# CoWTracker parameters
|
| 32 |
+
**cow_tracker_kwargs,
|
| 33 |
+
):
|
| 34 |
+
"""
|
| 35 |
+
Args:
|
| 36 |
+
window_len: Number of frames per window.
|
| 37 |
+
stride: Step size between windows.
|
| 38 |
+
num_memory_frames: Maximum number of memory frames.
|
| 39 |
+
**cow_tracker_kwargs: Arguments passed to CoWTracker.
|
| 40 |
+
"""
|
| 41 |
+
super().__init__()
|
| 42 |
+
|
| 43 |
+
print(f"Initializing CoWTrackerWindowed: window_len={window_len}, stride={stride}")
|
| 44 |
+
|
| 45 |
+
self.model = CoWTracker(**cow_tracker_kwargs)
|
| 46 |
+
self.windowed = WindowedInference(
|
| 47 |
+
window_len=window_len,
|
| 48 |
+
stride=stride,
|
| 49 |
+
num_memory_frames=num_memory_frames,
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
def forward(self, video: torch.Tensor, queries: torch.Tensor = None) -> dict:
|
| 53 |
+
"""
|
| 54 |
+
Forward pass with windowed inference.
|
| 55 |
+
|
| 56 |
+
Args:
|
| 57 |
+
video: Input video [B, S, 3, H, W] or [S, 3, H, W] in range [0, 255].
|
| 58 |
+
queries: Optional query points (unused, for API compatibility).
|
| 59 |
+
|
| 60 |
+
Returns:
|
| 61 |
+
dict with:
|
| 62 |
+
- track: Dense tracks [B, T, H, W, 2].
|
| 63 |
+
- vis: Visibility scores [B, T, H, W].
|
| 64 |
+
- conf: Confidence scores [B, T, H, W].
|
| 65 |
+
"""
|
| 66 |
+
# Normalize input
|
| 67 |
+
images = video / 255.0
|
| 68 |
+
if images.ndim == 4:
|
| 69 |
+
images = images.unsqueeze(0)
|
| 70 |
+
|
| 71 |
+
B, T, C, H, W = images.shape
|
| 72 |
+
device, dtype = images.device, images.dtype
|
| 73 |
+
|
| 74 |
+
# Initialize accumulated outputs
|
| 75 |
+
accumulated = {
|
| 76 |
+
"track": torch.zeros((B, T, H, W, 2), device=device, dtype=dtype),
|
| 77 |
+
"vis": torch.zeros((B, T, H, W), device=device, dtype=dtype),
|
| 78 |
+
"conf": torch.zeros((B, T, H, W), device=device, dtype=dtype),
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
windows = self.windowed.compute_windows(T)
|
| 82 |
+
first_frame = images[:, 0:1]
|
| 83 |
+
first_frame_features = None
|
| 84 |
+
|
| 85 |
+
for window_idx, (start, end) in enumerate(windows):
|
| 86 |
+
if not self.training:
|
| 87 |
+
print(f"Processing window {window_idx + 1}/{len(windows)}: frames [{start}, {end})")
|
| 88 |
+
|
| 89 |
+
# Get memory frame indices
|
| 90 |
+
memory_indices = self.windowed.select_memory_frames(window_idx, start)
|
| 91 |
+
if not self.training and memory_indices:
|
| 92 |
+
print(f" Memory frames: {memory_indices}")
|
| 93 |
+
|
| 94 |
+
# Gather frames: first_frame + memory + window
|
| 95 |
+
frames = self._gather_frames(images, first_frame, start, end, memory_indices)
|
| 96 |
+
|
| 97 |
+
# Extract backbone tokens
|
| 98 |
+
tokens, patch_idx = self.model.aggregator(frames)
|
| 99 |
+
|
| 100 |
+
# Extract combined features
|
| 101 |
+
features = self.model.feature_extractor(tokens, frames, patch_idx)
|
| 102 |
+
|
| 103 |
+
# Split features: first_frame | memory | window
|
| 104 |
+
first_frame_features = features[:, 0:1]
|
| 105 |
+
num_memory = len(memory_indices)
|
| 106 |
+
offset = 1 + num_memory
|
| 107 |
+
|
| 108 |
+
# Run tracking on extended features (memory + window), using first_frame as reference
|
| 109 |
+
extended_features = features[:, 1:] # Exclude first_frame from input
|
| 110 |
+
pred = self.model.tracking_head(
|
| 111 |
+
extended_features,
|
| 112 |
+
image_size=(H, W),
|
| 113 |
+
first_frame_features=first_frame_features,
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
# Extract window predictions (remove memory frames from output)
|
| 117 |
+
window_pred = {
|
| 118 |
+
"track": pred["track"][:, num_memory:],
|
| 119 |
+
"vis": pred["vis"][:, num_memory:],
|
| 120 |
+
"conf": pred["conf"][:, num_memory:],
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
# Merge into accumulated results
|
| 124 |
+
self.windowed.merge_predictions(window_idx, start, end, window_pred, accumulated)
|
| 125 |
+
|
| 126 |
+
# Cleanup for memory efficiency
|
| 127 |
+
if not self.training:
|
| 128 |
+
del features, tokens, pred
|
| 129 |
+
torch.cuda.empty_cache()
|
| 130 |
+
|
| 131 |
+
if not self.training:
|
| 132 |
+
accumulated["images"] = images
|
| 133 |
+
|
| 134 |
+
return accumulated
|
| 135 |
+
|
| 136 |
+
def _gather_frames(
|
| 137 |
+
self,
|
| 138 |
+
images: torch.Tensor,
|
| 139 |
+
first_frame: torch.Tensor,
|
| 140 |
+
start: int,
|
| 141 |
+
end: int,
|
| 142 |
+
memory_indices: list,
|
| 143 |
+
) -> torch.Tensor:
|
| 144 |
+
"""Gather first_frame + memory + window frames."""
|
| 145 |
+
parts = [first_frame]
|
| 146 |
+
|
| 147 |
+
if memory_indices:
|
| 148 |
+
parts.append(images[:, memory_indices])
|
| 149 |
+
|
| 150 |
+
parts.append(images[:, start:end])
|
| 151 |
+
|
| 152 |
+
return torch.cat(parts, dim=1)
|
| 153 |
+
|
| 154 |
+
# Proxy properties for convenient access to internal components
|
| 155 |
+
@property
|
| 156 |
+
def aggregator(self):
|
| 157 |
+
return self.model.aggregator
|
| 158 |
+
|
| 159 |
+
@property
|
| 160 |
+
def feature_extractor(self):
|
| 161 |
+
return self.model.feature_extractor
|
| 162 |
+
|
| 163 |
+
@property
|
| 164 |
+
def tracking_head(self):
|
| 165 |
+
return self.model.tracking_head
|
| 166 |
+
|
| 167 |
+
@classmethod
|
| 168 |
+
def from_checkpoint(
|
| 169 |
+
cls,
|
| 170 |
+
checkpoint_path: str = None,
|
| 171 |
+
window_len: int = 100,
|
| 172 |
+
stride: int = 100,
|
| 173 |
+
device: str = "cuda",
|
| 174 |
+
dtype=torch.bfloat16,
|
| 175 |
+
):
|
| 176 |
+
"""
|
| 177 |
+
Load model from checkpoint (local path or HuggingFace Hub).
|
| 178 |
+
|
| 179 |
+
Args:
|
| 180 |
+
checkpoint_path: Path to local checkpoint file.
|
| 181 |
+
If None, downloads from default HuggingFace repo.
|
| 182 |
+
window_len: Number of frames per window.
|
| 183 |
+
stride: Step size between windows.
|
| 184 |
+
device: Target device.
|
| 185 |
+
dtype: Target dtype.
|
| 186 |
+
|
| 187 |
+
Returns:
|
| 188 |
+
Loaded model in eval mode.
|
| 189 |
+
"""
|
| 190 |
+
model = cls(window_len=window_len, stride=stride)
|
| 191 |
+
|
| 192 |
+
# Use CoWTracker's checkpoint loading method (handles local path and HuggingFace download)
|
| 193 |
+
ckpt = CoWTracker._load_checkpoint(checkpoint_path)
|
| 194 |
+
state_dict = ckpt.get("model", ckpt)
|
| 195 |
+
|
| 196 |
+
# Remap legacy checkpoint keys if needed (delegate to CoWTracker)
|
| 197 |
+
legacy_prefixes = ["tracking_head.feature_extractor.", "tracking_head.aggregator.", "tracking_head.fnet."]
|
| 198 |
+
if any(k.startswith(p) for k in state_dict.keys() for p in legacy_prefixes):
|
| 199 |
+
print("Detected legacy checkpoint format, remapping keys...")
|
| 200 |
+
state_dict = CoWTracker._remap_legacy_state_dict(state_dict)
|
| 201 |
+
|
| 202 |
+
# Add "model." prefix if checkpoint is from CoWTracker (no prefix)
|
| 203 |
+
# CoWTrackerWindowed wraps CoWTracker as self.model, so keys need "model." prefix
|
| 204 |
+
if not any(k.startswith("model.") for k in state_dict.keys()):
|
| 205 |
+
print("Adding 'model.' prefix to state dict keys...")
|
| 206 |
+
state_dict = {f"model.{k}": v for k, v in state_dict.items()}
|
| 207 |
+
|
| 208 |
+
msg = model.load_state_dict(state_dict, strict=False)
|
| 209 |
+
print(f"Load message: {msg}")
|
| 210 |
+
|
| 211 |
+
model = model.to(device).to(dtype)
|
| 212 |
+
model.eval()
|
| 213 |
+
for p in model.parameters():
|
| 214 |
+
p.requires_grad = False
|
| 215 |
+
|
| 216 |
+
print("Model loaded successfully!")
|
| 217 |
+
return model
|
| 218 |
+
|
cowtracker/thirdparty/DepthAnythingV2
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
Subproject commit e5a2732d3ea2cddc081d7bfd708fc0bf09f812f1
|
cowtracker/thirdparty/__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""Third-party modules.
|
| 8 |
+
|
| 9 |
+
This module sets up sys.path for third-party packages.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import os
|
| 13 |
+
import sys
|
| 14 |
+
|
| 15 |
+
# Add vggt to sys.path so that 'from vggt.xxx' imports work
|
| 16 |
+
_vggt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "vggt")
|
| 17 |
+
if _vggt_path not in sys.path:
|
| 18 |
+
sys.path.insert(0, _vggt_path)
|
| 19 |
+
|
cowtracker/thirdparty/vggt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
Subproject commit 44b3afbd1869d8bde4894dd8ea1e293112dd5eba
|
cowtracker/utils/__init__.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from cowtracker.utils.padding import (
|
| 8 |
+
compute_padding_params,
|
| 9 |
+
apply_padding,
|
| 10 |
+
remove_padding_and_scale_back,
|
| 11 |
+
)
|
| 12 |
+
from cowtracker.utils.visualization import paint_point_track
|
| 13 |
+
from cowtracker.utils.ops import (
|
| 14 |
+
bilinear_sampler,
|
| 15 |
+
coords_grid,
|
| 16 |
+
Padder,
|
| 17 |
+
load_ckpt,
|
| 18 |
+
upflow8,
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
__all__ = [
|
| 22 |
+
"compute_padding_params",
|
| 23 |
+
"apply_padding",
|
| 24 |
+
"remove_padding_and_scale_back",
|
| 25 |
+
"paint_point_track",
|
| 26 |
+
"bilinear_sampler",
|
| 27 |
+
"coords_grid",
|
| 28 |
+
"Padder",
|
| 29 |
+
"load_ckpt",
|
| 30 |
+
"upflow8",
|
| 31 |
+
]
|
| 32 |
+
|
cowtracker/utils/ops.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""Common operations for tracking: bilinear sampling, coordinate grids, etc."""
|
| 8 |
+
|
| 9 |
+
import cv2
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
import numpy as np
|
| 13 |
+
from scipy import interpolate
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def load_ckpt(model, path):
|
| 17 |
+
"""Load checkpoint."""
|
| 18 |
+
state_dict = torch.load(path, map_location=torch.device("cpu"))
|
| 19 |
+
model.load_state_dict(state_dict, strict=False)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def resize_data(img1, img2, flow, factor=1.0):
|
| 23 |
+
_, _, h, w = img1.shape
|
| 24 |
+
h = int(h * factor)
|
| 25 |
+
w = int(w * factor)
|
| 26 |
+
img1 = F.interpolate(img1, (h, w), mode="area")
|
| 27 |
+
img2 = F.interpolate(img2, (h, w), mode="area")
|
| 28 |
+
flow = F.interpolate(flow, (h, w), mode="area") * factor
|
| 29 |
+
return img1, img2, flow
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class Padder:
|
| 33 |
+
"""Pads images such that dimensions are divisible by factor."""
|
| 34 |
+
|
| 35 |
+
def __init__(self, dims, mode="sintel", factor=32):
|
| 36 |
+
self.ht, self.wd = dims[-2:]
|
| 37 |
+
pad_ht = (((self.ht + 8) // factor) + 1) * factor - self.ht
|
| 38 |
+
pad_wd = (((self.wd + 8) // factor) + 1) * factor - self.wd
|
| 39 |
+
if mode == "sintel":
|
| 40 |
+
self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, pad_ht // 2, pad_ht - pad_ht // 2]
|
| 41 |
+
else:
|
| 42 |
+
self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht]
|
| 43 |
+
|
| 44 |
+
def pad(self, x):
|
| 45 |
+
return F.pad(x, self._pad, mode="constant", value=0)
|
| 46 |
+
|
| 47 |
+
def unpad(self, x):
|
| 48 |
+
ht, wd = x.shape[-2:]
|
| 49 |
+
c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]]
|
| 50 |
+
return x[..., c[0] : c[1], c[2] : c[3]]
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def forward_interpolate(flow):
|
| 54 |
+
flow = flow.detach().cpu().numpy()
|
| 55 |
+
dx, dy = flow[0], flow[1]
|
| 56 |
+
|
| 57 |
+
ht, wd = dx.shape
|
| 58 |
+
x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))
|
| 59 |
+
|
| 60 |
+
x1 = x0 + dx
|
| 61 |
+
y1 = y0 + dy
|
| 62 |
+
|
| 63 |
+
x1 = x1.reshape(-1)
|
| 64 |
+
y1 = y1.reshape(-1)
|
| 65 |
+
dx = dx.reshape(-1)
|
| 66 |
+
dy = dy.reshape(-1)
|
| 67 |
+
|
| 68 |
+
valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
|
| 69 |
+
x1 = x1[valid]
|
| 70 |
+
y1 = y1[valid]
|
| 71 |
+
dx = dx[valid]
|
| 72 |
+
dy = dy[valid]
|
| 73 |
+
|
| 74 |
+
flow_x = interpolate.griddata((x1, y1), dx, (x0, y0), method="nearest", fill_value=0)
|
| 75 |
+
|
| 76 |
+
flow_y = interpolate.griddata((x1, y1), dy, (x0, y0), method="nearest", fill_value=0)
|
| 77 |
+
|
| 78 |
+
flow = np.stack([flow_x, flow_y], axis=0)
|
| 79 |
+
return torch.from_numpy(flow).float()
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def bilinear_sampler(img, coords, mode="bilinear", mask=False):
|
| 83 |
+
"""Wrapper for grid_sample, uses pixel coordinates."""
|
| 84 |
+
H, W = img.shape[-2:]
|
| 85 |
+
xgrid, ygrid = coords.split([1, 1], dim=-1)
|
| 86 |
+
xgrid = 2 * xgrid / (W - 1) - 1
|
| 87 |
+
ygrid = 2 * ygrid / (H - 1) - 1
|
| 88 |
+
|
| 89 |
+
grid = torch.cat([xgrid, ygrid], dim=-1)
|
| 90 |
+
img = F.grid_sample(img, grid, align_corners=True)
|
| 91 |
+
|
| 92 |
+
if mask:
|
| 93 |
+
mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
|
| 94 |
+
return img, mask.float()
|
| 95 |
+
|
| 96 |
+
return img
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def coords_grid(batch, ht, wd, device):
|
| 100 |
+
coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device))
|
| 101 |
+
coords = torch.stack(coords[::-1], dim=0).float()
|
| 102 |
+
return coords[None].repeat(batch, 1, 1, 1)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def upflow8(flow, mode="bilinear"):
|
| 106 |
+
new_size = (8 * flow.shape[2], 8 * flow.shape[3])
|
| 107 |
+
return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def transform(T, p):
|
| 111 |
+
assert T.shape == (4, 4)
|
| 112 |
+
return np.einsum("H W j, i j -> H W i", p, T[:3, :3]) + T[:3, 3]
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def from_homog(x):
|
| 116 |
+
return x[..., :-1] / x[..., [-1]]
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def reproject(depth1, pose1, pose2, K1, K2):
|
| 120 |
+
H, W = depth1.shape
|
| 121 |
+
x, y = np.meshgrid(np.arange(W), np.arange(H), indexing="xy")
|
| 122 |
+
img_1_coords = np.stack((x, y, np.ones_like(x)), axis=-1).astype(np.float64)
|
| 123 |
+
cam1_coords = np.einsum("H W, H W j, i j -> H W i", depth1, img_1_coords, np.linalg.inv(K1))
|
| 124 |
+
rel_pose = pose2 @ np.linalg.inv(pose1)
|
| 125 |
+
cam2_coords = transform(rel_pose, cam1_coords)
|
| 126 |
+
return from_homog(np.einsum("H W j, i j -> H W i", cam2_coords, K2))
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def induced_flow(depth0, depth1, data):
|
| 130 |
+
H, W = depth0.shape
|
| 131 |
+
coords1 = reproject(depth0, data["T0"], data["T1"], data["K0"], data["K1"])
|
| 132 |
+
x, y = np.meshgrid(np.arange(W), np.arange(H), indexing="xy")
|
| 133 |
+
coords0 = np.stack([x, y], axis=-1)
|
| 134 |
+
flow_01 = coords1 - coords0
|
| 135 |
+
H, W = depth1.shape
|
| 136 |
+
coords1 = reproject(depth1, data["T1"], data["T0"], data["K1"], data["K0"])
|
| 137 |
+
x, y = np.meshgrid(np.arange(W), np.arange(H), indexing="xy")
|
| 138 |
+
coords0 = np.stack([x, y], axis=-1)
|
| 139 |
+
flow_10 = coords1 - coords0
|
| 140 |
+
return flow_01, flow_10
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def check_cycle_consistency(flow_01, flow_10):
|
| 144 |
+
H, W = flow_01.shape[:2]
|
| 145 |
+
new_coords = flow_01 + np.stack(np.meshgrid(np.arange(W), np.arange(H), indexing="xy"), axis=-1)
|
| 146 |
+
flow_reprojected = cv2.remap(flow_10, new_coords.astype(np.float32), None, interpolation=cv2.INTER_LINEAR)
|
| 147 |
+
cycle = flow_reprojected + flow_01
|
| 148 |
+
cycle = np.linalg.norm(cycle, axis=-1)
|
| 149 |
+
mask = (cycle < 0.1 * min(H, W)).astype(np.float32)
|
| 150 |
+
return mask
|
| 151 |
+
|
cowtracker/utils/padding.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""Padding utilities for video preprocessing and postprocessing."""
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def compute_padding_params(orig_H, orig_W, inf_H, inf_W, skip_upscaling=False):
|
| 14 |
+
"""Compute padding parameters to preserve aspect ratio.
|
| 15 |
+
|
| 16 |
+
Args:
|
| 17 |
+
orig_H: Original height
|
| 18 |
+
orig_W: Original width
|
| 19 |
+
inf_H: Inference height
|
| 20 |
+
inf_W: Inference width
|
| 21 |
+
skip_upscaling: If True and scale > 1, skip upscaling and just pad
|
| 22 |
+
|
| 23 |
+
Returns:
|
| 24 |
+
Dictionary containing:
|
| 25 |
+
- scale: Scale factor that would be applied (1.0 if skipped)
|
| 26 |
+
- scaled_H, scaled_W: Dimensions after scaling (before padding)
|
| 27 |
+
- pad_top, pad_bottom, pad_left, pad_right: Padding amounts
|
| 28 |
+
- orig_H, orig_W: Original dimensions (for reference)
|
| 29 |
+
- upscaling_skipped: Whether upscaling was skipped
|
| 30 |
+
"""
|
| 31 |
+
scale = min(inf_H / orig_H, inf_W / orig_W)
|
| 32 |
+
|
| 33 |
+
upscaling_skipped = False
|
| 34 |
+
if skip_upscaling and scale > 1.0:
|
| 35 |
+
scaled_H = orig_H
|
| 36 |
+
scaled_W = orig_W
|
| 37 |
+
upscaling_skipped = True
|
| 38 |
+
else:
|
| 39 |
+
scaled_H = int(orig_H * scale)
|
| 40 |
+
scaled_W = int(orig_W * scale)
|
| 41 |
+
|
| 42 |
+
pad_H = inf_H - scaled_H
|
| 43 |
+
pad_W = inf_W - scaled_W
|
| 44 |
+
|
| 45 |
+
pad_top = pad_H // 2
|
| 46 |
+
pad_bottom = pad_H - pad_top
|
| 47 |
+
pad_left = pad_W // 2
|
| 48 |
+
pad_right = pad_W - pad_left
|
| 49 |
+
|
| 50 |
+
return {
|
| 51 |
+
"scale": scale,
|
| 52 |
+
"scaled_H": scaled_H,
|
| 53 |
+
"scaled_W": scaled_W,
|
| 54 |
+
"pad_top": pad_top,
|
| 55 |
+
"pad_bottom": pad_bottom,
|
| 56 |
+
"pad_left": pad_left,
|
| 57 |
+
"pad_right": pad_right,
|
| 58 |
+
"orig_H": orig_H,
|
| 59 |
+
"orig_W": orig_W,
|
| 60 |
+
"upscaling_skipped": upscaling_skipped,
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def apply_padding(rgbs, padding_info):
|
| 65 |
+
"""Apply padding to input images to reach inference size.
|
| 66 |
+
|
| 67 |
+
Args:
|
| 68 |
+
rgbs: Input tensor (T, C, H, W)
|
| 69 |
+
padding_info: Dictionary from compute_padding_params
|
| 70 |
+
|
| 71 |
+
Returns:
|
| 72 |
+
Padded tensor (T, C, inf_H, inf_W)
|
| 73 |
+
"""
|
| 74 |
+
T, C, H, W = rgbs.shape
|
| 75 |
+
scaled_H = padding_info["scaled_H"]
|
| 76 |
+
scaled_W = padding_info["scaled_W"]
|
| 77 |
+
|
| 78 |
+
if (scaled_H, scaled_W) != (H, W):
|
| 79 |
+
rgbs_scaled = F.interpolate(
|
| 80 |
+
rgbs,
|
| 81 |
+
size=(scaled_H, scaled_W),
|
| 82 |
+
mode="bilinear",
|
| 83 |
+
align_corners=False,
|
| 84 |
+
)
|
| 85 |
+
else:
|
| 86 |
+
rgbs_scaled = rgbs
|
| 87 |
+
|
| 88 |
+
pad_left = padding_info["pad_left"]
|
| 89 |
+
pad_right = padding_info["pad_right"]
|
| 90 |
+
pad_top = padding_info["pad_top"]
|
| 91 |
+
pad_bottom = padding_info["pad_bottom"]
|
| 92 |
+
|
| 93 |
+
rgbs_padded = F.pad(
|
| 94 |
+
rgbs_scaled,
|
| 95 |
+
(pad_left, pad_right, pad_top, pad_bottom),
|
| 96 |
+
mode="constant",
|
| 97 |
+
value=0,
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
return rgbs_padded
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def remove_padding_and_scale_back(tracks, visibility, confidence, padding_info):
|
| 104 |
+
"""Remove padding from model outputs and scale back to original resolution.
|
| 105 |
+
|
| 106 |
+
Args:
|
| 107 |
+
tracks: Track predictions (T, inf_H, inf_W, 2)
|
| 108 |
+
visibility: Visibility predictions (T, inf_H, inf_W)
|
| 109 |
+
confidence: Confidence predictions (T, inf_H, inf_W)
|
| 110 |
+
padding_info: Dictionary from compute_padding_params
|
| 111 |
+
|
| 112 |
+
Returns:
|
| 113 |
+
Tuple of (tracks, visibility, confidence) scaled to original resolution
|
| 114 |
+
"""
|
| 115 |
+
scaled_H = padding_info["scaled_H"]
|
| 116 |
+
scaled_W = padding_info["scaled_W"]
|
| 117 |
+
pad_top = padding_info["pad_top"]
|
| 118 |
+
pad_left = padding_info["pad_left"]
|
| 119 |
+
orig_H = padding_info["orig_H"]
|
| 120 |
+
orig_W = padding_info["orig_W"]
|
| 121 |
+
|
| 122 |
+
tracks_unpadded = tracks[
|
| 123 |
+
:, pad_top : pad_top + scaled_H, pad_left : pad_left + scaled_W, :
|
| 124 |
+
]
|
| 125 |
+
visibility_unpadded = visibility[
|
| 126 |
+
:, pad_top : pad_top + scaled_H, pad_left : pad_left + scaled_W
|
| 127 |
+
]
|
| 128 |
+
confidence_unpadded = confidence[
|
| 129 |
+
:, pad_top : pad_top + scaled_H, pad_left : pad_left + scaled_W
|
| 130 |
+
]
|
| 131 |
+
|
| 132 |
+
tracks_unpadded = tracks_unpadded.clone()
|
| 133 |
+
tracks_unpadded[:, :, :, 0] -= pad_left
|
| 134 |
+
tracks_unpadded[:, :, :, 1] -= pad_top
|
| 135 |
+
|
| 136 |
+
if (scaled_H, scaled_W) != (orig_H, orig_W):
|
| 137 |
+
tracks_permuted = tracks_unpadded.permute(0, 3, 1, 2)
|
| 138 |
+
tracks_scaled = F.interpolate(
|
| 139 |
+
tracks_permuted,
|
| 140 |
+
size=(orig_H, orig_W),
|
| 141 |
+
mode="bilinear",
|
| 142 |
+
align_corners=False,
|
| 143 |
+
)
|
| 144 |
+
tracks_final = tracks_scaled.permute(0, 2, 3, 1)
|
| 145 |
+
|
| 146 |
+
tracks_final[:, :, :, 0] *= orig_W / scaled_W
|
| 147 |
+
tracks_final[:, :, :, 1] *= orig_H / scaled_H
|
| 148 |
+
|
| 149 |
+
visibility_final = F.interpolate(
|
| 150 |
+
visibility_unpadded.unsqueeze(1),
|
| 151 |
+
size=(orig_H, orig_W),
|
| 152 |
+
mode="bilinear",
|
| 153 |
+
align_corners=False,
|
| 154 |
+
).squeeze(1)
|
| 155 |
+
|
| 156 |
+
confidence_final = F.interpolate(
|
| 157 |
+
confidence_unpadded.unsqueeze(1),
|
| 158 |
+
size=(orig_H, orig_W),
|
| 159 |
+
mode="bilinear",
|
| 160 |
+
align_corners=False,
|
| 161 |
+
).squeeze(1)
|
| 162 |
+
else:
|
| 163 |
+
tracks_final = tracks_unpadded
|
| 164 |
+
visibility_final = visibility_unpadded
|
| 165 |
+
confidence_final = confidence_unpadded
|
| 166 |
+
|
| 167 |
+
return tracks_final, visibility_final, confidence_final
|
| 168 |
+
|
cowtracker/utils/visualization.py
ADDED
|
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""Visualization utilities for point tracking."""
|
| 8 |
+
|
| 9 |
+
import colorsys
|
| 10 |
+
import os
|
| 11 |
+
import random
|
| 12 |
+
from typing import List, Optional, Tuple, Union
|
| 13 |
+
|
| 14 |
+
import matplotlib
|
| 15 |
+
import matplotlib.pyplot as plt
|
| 16 |
+
import numpy as np
|
| 17 |
+
import torch
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# Bremm 2D colormap for position-based coloring
|
| 21 |
+
# This creates a smooth 2D color gradient based on x,y position
|
| 22 |
+
BREMM_COLORMAP = None # Lazy loaded
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def _create_bremm_colormap():
|
| 26 |
+
"""Create a 2D colormap programmatically (Bremm-style).
|
| 27 |
+
|
| 28 |
+
This creates a smooth 2D color gradient where:
|
| 29 |
+
- X position maps to hue variation
|
| 30 |
+
- Y position maps to saturation/value variation
|
| 31 |
+
"""
|
| 32 |
+
size = 256
|
| 33 |
+
colormap = np.zeros((size, size, 3), dtype=np.uint8)
|
| 34 |
+
|
| 35 |
+
for y in range(size):
|
| 36 |
+
for x in range(size):
|
| 37 |
+
# Normalize to [0, 1]
|
| 38 |
+
nx = x / (size - 1)
|
| 39 |
+
ny = y / (size - 1)
|
| 40 |
+
|
| 41 |
+
# Create a 2D color mapping using HSV
|
| 42 |
+
# Hue varies with x, saturation/value with y
|
| 43 |
+
hue = (nx * 0.8 + ny * 0.2) % 1.0 # Mix of x and y for hue
|
| 44 |
+
saturation = 0.6 + 0.4 * (1 - ny) # Higher saturation at top
|
| 45 |
+
value = 0.7 + 0.3 * nx # Higher value on right
|
| 46 |
+
|
| 47 |
+
# Convert HSV to RGB
|
| 48 |
+
rgb = colorsys.hsv_to_rgb(hue, saturation, value)
|
| 49 |
+
colormap[y, x] = [int(c * 255) for c in rgb]
|
| 50 |
+
|
| 51 |
+
return colormap
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _get_bremm_colormap():
|
| 55 |
+
"""Get or create the bremm colormap."""
|
| 56 |
+
global BREMM_COLORMAP
|
| 57 |
+
if BREMM_COLORMAP is None:
|
| 58 |
+
# Try to load from file first
|
| 59 |
+
colormap_file = os.path.join(os.path.dirname(__file__), "bremm.png")
|
| 60 |
+
if os.path.exists(colormap_file):
|
| 61 |
+
BREMM_COLORMAP = (plt.imread(colormap_file) * 255).astype(np.uint8)
|
| 62 |
+
if BREMM_COLORMAP.shape[2] == 4: # RGBA
|
| 63 |
+
BREMM_COLORMAP = BREMM_COLORMAP[:, :, :3]
|
| 64 |
+
else:
|
| 65 |
+
BREMM_COLORMAP = _create_bremm_colormap()
|
| 66 |
+
return BREMM_COLORMAP
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def get_2d_colors(xys: np.ndarray, H: int, W: int) -> np.ndarray:
|
| 70 |
+
"""Get colors based on 2D position using Bremm colormap.
|
| 71 |
+
|
| 72 |
+
This creates position-dependent colors where nearby points have
|
| 73 |
+
similar colors, useful for visualizing spatial coherence.
|
| 74 |
+
|
| 75 |
+
Args:
|
| 76 |
+
xys: Point coordinates [N, 2] in pixel space (x, y)
|
| 77 |
+
H: Image height
|
| 78 |
+
W: Image width
|
| 79 |
+
|
| 80 |
+
Returns:
|
| 81 |
+
Array of RGB colors [N, 3] as uint8
|
| 82 |
+
"""
|
| 83 |
+
colormap = _get_bremm_colormap()
|
| 84 |
+
height, width = colormap.shape[:2]
|
| 85 |
+
|
| 86 |
+
N = xys.shape[0]
|
| 87 |
+
output = np.zeros((N, 3), dtype=np.uint8)
|
| 88 |
+
|
| 89 |
+
# Normalize coordinates to [0, 1]
|
| 90 |
+
xys_norm = xys.copy().astype(np.float32)
|
| 91 |
+
xys_norm[:, 0] = xys_norm[:, 0] / max(W - 1, 1)
|
| 92 |
+
xys_norm[:, 1] = xys_norm[:, 1] / max(H - 1, 1)
|
| 93 |
+
|
| 94 |
+
# Clip to valid range
|
| 95 |
+
xys_norm = np.clip(xys_norm, 0, 1)
|
| 96 |
+
|
| 97 |
+
# Map to colormap coordinates
|
| 98 |
+
for i in range(N):
|
| 99 |
+
x, y = xys_norm[i]
|
| 100 |
+
xp = int((width - 1) * x)
|
| 101 |
+
yp = int((height - 1) * y)
|
| 102 |
+
output[i] = colormap[yp, xp]
|
| 103 |
+
|
| 104 |
+
return output
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def get_colors_from_cmap(num_colors: int, cmap: str = "gist_rainbow") -> np.ndarray:
|
| 108 |
+
"""Gets colormap for points using matplotlib colormap.
|
| 109 |
+
|
| 110 |
+
Args:
|
| 111 |
+
num_colors: Number of colors to generate
|
| 112 |
+
cmap: Matplotlib colormap name (e.g., "gist_rainbow", "jet", "turbo")
|
| 113 |
+
|
| 114 |
+
Returns:
|
| 115 |
+
Array of RGB colors [num_colors, 3] as uint8
|
| 116 |
+
"""
|
| 117 |
+
cmap_ = matplotlib.colormaps.get_cmap(cmap)
|
| 118 |
+
colors = []
|
| 119 |
+
for i in range(num_colors):
|
| 120 |
+
c = cmap_(i / float(num_colors))
|
| 121 |
+
colors.append((int(c[0] * 255), int(c[1] * 255), int(c[2] * 255)))
|
| 122 |
+
return np.array(colors)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def paint_point_track(
|
| 126 |
+
frames: np.ndarray,
|
| 127 |
+
point_tracks: np.ndarray,
|
| 128 |
+
visibles: np.ndarray,
|
| 129 |
+
colormap: Optional[Union[List[Tuple[int, int, int]], np.ndarray]] = None,
|
| 130 |
+
rate: int = 1,
|
| 131 |
+
show_bkg: bool = True,
|
| 132 |
+
) -> np.ndarray:
|
| 133 |
+
"""Paint point tracks on video frames using GPU-accelerated scatter.
|
| 134 |
+
|
| 135 |
+
Args:
|
| 136 |
+
frames: Video frames [T, H, W, C] in uint8
|
| 137 |
+
point_tracks: Track coordinates [P, T, 2] (x, y)
|
| 138 |
+
visibles: Visibility mask [P, T]
|
| 139 |
+
colormap: Optional list/array of RGB colors for each point
|
| 140 |
+
rate: Subsampling rate for visualization (affects point size)
|
| 141 |
+
show_bkg: Whether to show background (True) or black out (False)
|
| 142 |
+
|
| 143 |
+
Returns:
|
| 144 |
+
Painted frames [T, H, W, C] in uint8
|
| 145 |
+
"""
|
| 146 |
+
print("Starting visualization...")
|
| 147 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 148 |
+
frames_t = (
|
| 149 |
+
torch.from_numpy(frames).float().permute(0, 3, 1, 2).to(device)
|
| 150 |
+
) # [T,C,H,W]
|
| 151 |
+
|
| 152 |
+
if show_bkg:
|
| 153 |
+
frames_t = frames_t * 0.5 # darken to see tracks better
|
| 154 |
+
else:
|
| 155 |
+
frames_t = frames_t * 0.0 # black out background
|
| 156 |
+
|
| 157 |
+
point_tracks_t = torch.from_numpy(point_tracks).to(device) # [P,T,2]
|
| 158 |
+
visibles_t = torch.from_numpy(visibles).to(device) # [P,T]
|
| 159 |
+
T, C, H, W = frames_t.shape
|
| 160 |
+
P = point_tracks.shape[0]
|
| 161 |
+
|
| 162 |
+
# Use gist_rainbow colormap (matching app3.py behavior)
|
| 163 |
+
if colormap is None:
|
| 164 |
+
colormap = get_colors_from_cmap(P, "gist_rainbow")
|
| 165 |
+
colors = torch.tensor(colormap, dtype=torch.float32, device=device) # [P,3]
|
| 166 |
+
|
| 167 |
+
# Adjust radius based on rate
|
| 168 |
+
if rate == 1:
|
| 169 |
+
radius = 1
|
| 170 |
+
elif rate == 2:
|
| 171 |
+
radius = 1
|
| 172 |
+
elif rate == 4:
|
| 173 |
+
radius = 2
|
| 174 |
+
elif rate == 8:
|
| 175 |
+
radius = 4
|
| 176 |
+
else:
|
| 177 |
+
radius = 6
|
| 178 |
+
|
| 179 |
+
sharpness = 0.15 + 0.05 * np.log2(rate)
|
| 180 |
+
|
| 181 |
+
D = radius * 2 + 1
|
| 182 |
+
y = torch.arange(D, device=device).float()[:, None] - radius
|
| 183 |
+
x = torch.arange(D, device=device).float()[None, :] - radius
|
| 184 |
+
dist2 = x**2 + y**2
|
| 185 |
+
icon = torch.clamp(1 - (dist2 - (radius**2) / 2.0) / (radius * 2 * sharpness), 0, 1)
|
| 186 |
+
icon = icon.view(1, D, D)
|
| 187 |
+
dx = torch.arange(-radius, radius + 1, device=device)
|
| 188 |
+
dy = torch.arange(-radius, radius + 1, device=device)
|
| 189 |
+
disp_y, disp_x = torch.meshgrid(dy, dx, indexing="ij")
|
| 190 |
+
|
| 191 |
+
for t in range(T):
|
| 192 |
+
mask = visibles_t[:, t]
|
| 193 |
+
if mask.sum() == 0:
|
| 194 |
+
continue
|
| 195 |
+
xy = point_tracks_t[mask, t] + 0.5
|
| 196 |
+
xy[:, 0] = xy[:, 0].clamp(0, W - 1)
|
| 197 |
+
xy[:, 1] = xy[:, 1].clamp(0, H - 1)
|
| 198 |
+
colors_now = colors[mask]
|
| 199 |
+
N = xy.shape[0]
|
| 200 |
+
cx = xy[:, 0].long()
|
| 201 |
+
cy = xy[:, 1].long()
|
| 202 |
+
x_grid = cx[:, None, None] + disp_x
|
| 203 |
+
y_grid = cy[:, None, None] + disp_y
|
| 204 |
+
valid = (x_grid >= 0) & (x_grid < W) & (y_grid >= 0) & (y_grid < H)
|
| 205 |
+
x_valid = x_grid[valid]
|
| 206 |
+
y_valid = y_grid[valid]
|
| 207 |
+
icon_weights = icon.expand(N, D, D)[valid]
|
| 208 |
+
colors_valid = (
|
| 209 |
+
colors_now[:, :, None, None]
|
| 210 |
+
.expand(N, 3, D, D)
|
| 211 |
+
.permute(1, 0, 2, 3)[:, valid]
|
| 212 |
+
)
|
| 213 |
+
idx_flat = (y_valid * W + x_valid).long()
|
| 214 |
+
|
| 215 |
+
accum = torch.zeros_like(frames_t[t])
|
| 216 |
+
weight = torch.zeros(1, H * W, device=device)
|
| 217 |
+
img_flat = accum.view(C, -1)
|
| 218 |
+
weighted_colors = colors_valid * icon_weights
|
| 219 |
+
img_flat.scatter_add_(1, idx_flat.unsqueeze(0).expand(C, -1), weighted_colors)
|
| 220 |
+
weight.scatter_add_(1, idx_flat.unsqueeze(0), icon_weights.unsqueeze(0))
|
| 221 |
+
weight = weight.view(1, H, W)
|
| 222 |
+
|
| 223 |
+
alpha = weight.clamp(0, 1)
|
| 224 |
+
accum = accum / (weight + 1e-6)
|
| 225 |
+
frames_t[t] = frames_t[t] * (1 - alpha) + accum * alpha
|
| 226 |
+
|
| 227 |
+
print("Visualization done.")
|
| 228 |
+
return frames_t.clamp(0, 255).byte().permute(0, 2, 3, 1).cpu().numpy()
|
| 229 |
+
|
demo.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 3 |
+
# All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# This source code is licensed under the license found in the
|
| 6 |
+
# LICENSE file in the root directory of this source tree.
|
| 7 |
+
|
| 8 |
+
"""
|
| 9 |
+
Minimal CoWTracker inference demo.
|
| 10 |
+
|
| 11 |
+
Usage:
|
| 12 |
+
python demo.py --video input.mp4 --output output.mp4
|
| 13 |
+
python demo.py --video input.mp4 --output output.mp4 --checkpoint ~/run168/cow_tracker_model.pth
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
import argparse
|
| 17 |
+
import os
|
| 18 |
+
|
| 19 |
+
import mediapy
|
| 20 |
+
import numpy as np
|
| 21 |
+
import torch
|
| 22 |
+
|
| 23 |
+
from cowtracker import CoWTracker
|
| 24 |
+
from cowtracker.utils.visualization import paint_point_track
|
| 25 |
+
|
| 26 |
+
inf_dtype = torch.float16
|
| 27 |
+
def preprocess_video(video_path, max_frames=200, target_size=(336, 560)):
|
| 28 |
+
"""Load and preprocess video.
|
| 29 |
+
|
| 30 |
+
Args:
|
| 31 |
+
video_path: Path to input video
|
| 32 |
+
max_frames: Maximum number of frames to process
|
| 33 |
+
target_size: Target size (H, W) for inference
|
| 34 |
+
|
| 35 |
+
Returns:
|
| 36 |
+
Tuple of (video_array, fps)
|
| 37 |
+
"""
|
| 38 |
+
video_arr = mediapy.read_video(video_path)
|
| 39 |
+
video_fps = video_arr.metadata.fps
|
| 40 |
+
num_frames = video_arr.shape[0]
|
| 41 |
+
|
| 42 |
+
# Truncate if too long
|
| 43 |
+
if num_frames > max_frames:
|
| 44 |
+
print(f"Video is too long. Truncating to first {max_frames} frames.")
|
| 45 |
+
video_arr = video_arr[:max_frames]
|
| 46 |
+
|
| 47 |
+
# Resize to target size
|
| 48 |
+
video_arr = mediapy.resize_video(video_arr, target_size)
|
| 49 |
+
|
| 50 |
+
return np.array(video_arr), video_fps
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def run_inference(model, video):
|
| 54 |
+
"""Run tracking inference on video.
|
| 55 |
+
|
| 56 |
+
Args:
|
| 57 |
+
model: CoWTracker model
|
| 58 |
+
video: Video array [T, H, W, C] in uint8
|
| 59 |
+
|
| 60 |
+
Returns:
|
| 61 |
+
Tuple of (tracks, visibilities, confidences)
|
| 62 |
+
- tracks: [T, H, W, 2]
|
| 63 |
+
- visibilities: [T, H, W]
|
| 64 |
+
- confidences: [T, H, W]
|
| 65 |
+
"""
|
| 66 |
+
device = next(model.parameters()).device
|
| 67 |
+
|
| 68 |
+
# Convert to tensor [T, C, H, W]
|
| 69 |
+
video_tensor = torch.from_numpy(video).permute(0, 3, 1, 2).float().to(device)
|
| 70 |
+
T, C, H, W = video_tensor.shape
|
| 71 |
+
print(f"Video size: {H}x{W}")
|
| 72 |
+
|
| 73 |
+
torch.cuda.empty_cache()
|
| 74 |
+
|
| 75 |
+
with torch.no_grad():
|
| 76 |
+
with torch.amp.autocast(device_type="cuda", dtype=inf_dtype):
|
| 77 |
+
predictions = model.forward(video=video_tensor, queries=None)
|
| 78 |
+
|
| 79 |
+
tracks = predictions["track"][0].cpu()
|
| 80 |
+
visibility = predictions["vis"][0].cpu()
|
| 81 |
+
confidence = predictions["conf"][0].cpu()
|
| 82 |
+
|
| 83 |
+
visconf = visibility * confidence
|
| 84 |
+
return tracks, visconf > 0.1, visconf
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def create_visualization(video, tracks, visibilities, rate=8, fps=30, show_bkg=True):
|
| 88 |
+
"""Create visualization video.
|
| 89 |
+
|
| 90 |
+
Args:
|
| 91 |
+
video: Video array [T, H, W, C]
|
| 92 |
+
tracks: Tracks [T, H, W, 2]
|
| 93 |
+
visibilities: Visibility mask [T, H, W]
|
| 94 |
+
rate: Subsampling rate for points
|
| 95 |
+
fps: Output video fps
|
| 96 |
+
show_bkg: Whether to show background
|
| 97 |
+
|
| 98 |
+
Returns:
|
| 99 |
+
Painted video frames [T, H, W, C]
|
| 100 |
+
"""
|
| 101 |
+
T, H, W, _ = video.shape
|
| 102 |
+
|
| 103 |
+
# Subsample tracks for visualization
|
| 104 |
+
tracks_np = tracks.permute(1, 2, 0, 3).reshape(-1, T, 2).numpy() # [HW, T, 2]
|
| 105 |
+
vis_np = visibilities.permute(1, 2, 0).reshape(-1, T).numpy() # [HW, T]
|
| 106 |
+
|
| 107 |
+
# Subsample
|
| 108 |
+
tracks_sub = tracks_np.reshape(H, W, T, 2)[::rate, ::rate].reshape(-1, T, 2)
|
| 109 |
+
vis_sub = vis_np.reshape(H, W, T)[::rate, ::rate].reshape(-1, T)
|
| 110 |
+
|
| 111 |
+
# Paint tracks
|
| 112 |
+
painted_video = paint_point_track(
|
| 113 |
+
video, tracks_sub, vis_sub, rate=rate, show_bkg=show_bkg
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
return painted_video
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def main():
|
| 120 |
+
parser = argparse.ArgumentParser(description="CoWTracker Inference Demo")
|
| 121 |
+
parser.add_argument("--video", type=str, required=True, help="Path to input video")
|
| 122 |
+
parser.add_argument("--output", type=str, default=None, help="Path to output video")
|
| 123 |
+
parser.add_argument(
|
| 124 |
+
"--checkpoint",
|
| 125 |
+
type=str,
|
| 126 |
+
default=None,
|
| 127 |
+
help="Path to model checkpoint",
|
| 128 |
+
)
|
| 129 |
+
parser.add_argument(
|
| 130 |
+
"--rate", type=int, default=8, help="Subsampling rate for visualization"
|
| 131 |
+
)
|
| 132 |
+
parser.add_argument(
|
| 133 |
+
"--max_frames", type=int, default=200, help="Maximum number of frames"
|
| 134 |
+
)
|
| 135 |
+
parser.add_argument("--no_bkg", action="store_true", help="Hide background in visualization")
|
| 136 |
+
args = parser.parse_args()
|
| 137 |
+
|
| 138 |
+
# Set output path
|
| 139 |
+
if args.output is None:
|
| 140 |
+
base_name = os.path.splitext(os.path.basename(args.video))[0]
|
| 141 |
+
args.output = f"{base_name}_tracked.mp4"
|
| 142 |
+
|
| 143 |
+
print("=" * 60)
|
| 144 |
+
print("CoWTracker Inference Demo")
|
| 145 |
+
print("=" * 60)
|
| 146 |
+
|
| 147 |
+
# Load model
|
| 148 |
+
print("\n[1/4] Loading model...")
|
| 149 |
+
model = CoWTracker.from_checkpoint(
|
| 150 |
+
args.checkpoint,
|
| 151 |
+
device="cuda" if torch.cuda.is_available() else "cpu",
|
| 152 |
+
dtype=inf_dtype if torch.cuda.is_available() else torch.float32,
|
| 153 |
+
)
|
| 154 |
+
|
| 155 |
+
# Load video
|
| 156 |
+
print("\n[2/4] Loading video...")
|
| 157 |
+
video, fps = preprocess_video(args.video, max_frames=args.max_frames)
|
| 158 |
+
print(f"Video shape: {video.shape}, FPS: {fps}")
|
| 159 |
+
|
| 160 |
+
# Run inference
|
| 161 |
+
print("\n[3/4] Running inference...")
|
| 162 |
+
tracks, visibilities, confidences = run_inference(model, video)
|
| 163 |
+
print(f"Tracks shape: {tracks.shape}")
|
| 164 |
+
|
| 165 |
+
# Create visualization
|
| 166 |
+
print("\n[4/4] Creating visualization...")
|
| 167 |
+
painted_video = create_visualization(
|
| 168 |
+
video, tracks, visibilities, rate=args.rate, fps=fps, show_bkg=not args.no_bkg
|
| 169 |
+
)
|
| 170 |
+
|
| 171 |
+
# Save output
|
| 172 |
+
mediapy.write_video(args.output, painted_video, fps=fps)
|
| 173 |
+
print(f"\nSaved output to: {args.output}")
|
| 174 |
+
print("=" * 60)
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
if __name__ == "__main__":
|
| 178 |
+
main()
|
| 179 |
+
|
docs/logo.jpg
ADDED
|
|
Git LFS Details
|
docs/teaser.jpg
ADDED
|
|
Git LFS Details
|
environments.yml
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: cowtracker
|
| 2 |
+
channels:
|
| 3 |
+
- pytorch
|
| 4 |
+
- nvidia
|
| 5 |
+
- conda-forge
|
| 6 |
+
- defaults
|
| 7 |
+
dependencies:
|
| 8 |
+
- python=3.12
|
| 9 |
+
- pip
|
| 10 |
+
- pip:
|
| 11 |
+
# Core deep learning
|
| 12 |
+
- torch>=2.0.0
|
| 13 |
+
- torchvision>=0.15.0
|
| 14 |
+
- xformers
|
| 15 |
+
- timm
|
| 16 |
+
|
| 17 |
+
# Numerical / scientific
|
| 18 |
+
- numpy
|
| 19 |
+
- scipy
|
| 20 |
+
- einops
|
| 21 |
+
|
| 22 |
+
# Image / video processing
|
| 23 |
+
- opencv-python
|
| 24 |
+
- Pillow
|
| 25 |
+
- mediapy
|
| 26 |
+
- matplotlib
|
| 27 |
+
|
| 28 |
+
# Model hub
|
| 29 |
+
- huggingface_hub
|
| 30 |
+
|
| 31 |
+
# Gradio demo
|
| 32 |
+
- gradio
|
| 33 |
+
|
| 34 |
+
# Optional: development tools
|
| 35 |
+
- ipython
|
| 36 |
+
- ipdb
|
output.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:989c1e01c2f3eb2e55387ba74c78497d3aba3805c63dc0eeb7f581754465bf7c
|
| 3 |
+
size 2550238
|
packages.txt
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
ffmpeg
|
| 2 |
+
|
requirements.txt
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Core deep learning
|
| 2 |
+
torch>=2.0.0
|
| 3 |
+
torchvision>=0.15.0
|
| 4 |
+
xformers==0.0.33.post1
|
| 5 |
+
timm
|
| 6 |
+
|
| 7 |
+
# Numerical / scientific
|
| 8 |
+
numpy
|
| 9 |
+
scipy
|
| 10 |
+
einops
|
| 11 |
+
|
| 12 |
+
# Image / video processing
|
| 13 |
+
opencv-python-headless
|
| 14 |
+
Pillow
|
| 15 |
+
mediapy
|
| 16 |
+
matplotlib
|
| 17 |
+
|
| 18 |
+
# Model hub
|
| 19 |
+
huggingface_hub
|
| 20 |
+
|
| 21 |
+
# Gradio demo
|
| 22 |
+
gradio>=4.0.0
|
| 23 |
+
|
| 24 |
+
# HuggingFace Spaces (ZeroGPU support)
|
| 25 |
+
spaces
|
| 26 |
+
|
videos/apple.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c7f48c5cfb1479e1dbc1df2373d5cad4f55c198bbdb379da0ece10087971542a
|
| 3 |
+
size 1219872
|
videos/bear.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:eeffab1780be601b19b2097be81a8c2d4fa2b624ac1028be0a32191d25acca0f
|
| 3 |
+
size 893943
|
videos/bmx-bumps.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b4d4aa73e0342d8dc08c4a7e3c9ea10e46507d363f0363f3e38bfb3ececa1588
|
| 3 |
+
size 3094667
|
videos/cows.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b4ca9ba3b3f720142917dc20935b03c9bbdc629e55502955526daccec567170d
|
| 3 |
+
size 5282840
|
videos/lab-coat.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:cf43f05f1a011e6cf376d4dc17b18b249613fa5e082df9e3af941fa012c34c9e
|
| 3 |
+
size 1850114
|
videos/longboard.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:13c977306e09a3c0766952497b7e51d5accfb5b15bdcb5ac32a1c1afc7893f67
|
| 3 |
+
size 2879038
|
videos/motocross-jump.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f757c903177154b7eceb4b9fe5386bb38cba20ef4da8645cbfc450e0ac39ffef
|
| 3 |
+
size 1343986
|