zlai committed on
Commit 715f79d
0 Parent(s):

Initial commit

.gitattributes ADDED
@@ -0,0 +1,37 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.mp4 filter=lfs diff=lfs merge=lfs -text
+ *.jpg filter=lfs diff=lfs merge=lfs -text
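The rules above route heavyweight binary formats (model weights, archives, media) through Git LFS; each line pairs a glob with the `filter=lfs diff=lfs merge=lfs -text` attributes. As a quick sanity check, here is a minimal sketch (illustrative only — `lfs_patterns` is not part of this repo) that extracts the LFS-tracked globs from a `.gitattributes` blob:

```python
# Illustrative helper: list which glob patterns in a .gitattributes
# blob are routed through Git LFS.
def lfs_patterns(gitattributes_text):
    patterns = []
    for line in gitattributes_text.splitlines():
        parts = line.split()
        # An entry is "<pattern> <attr> <attr> ..."; LFS entries carry filter=lfs.
        if len(parts) > 1 and "filter=lfs" in parts[1:]:
            patterns.append(parts[0])
    return patterns

sample = "\n".join([
    "*.7z filter=lfs diff=lfs merge=lfs -text",
    "*.mp4 filter=lfs diff=lfs merge=lfs -text",
    "*.md text",
])
print(lfs_patterns(sample))  # ['*.7z', '*.mp4']
```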
.gitignore ADDED
@@ -0,0 +1,42 @@
+ # Python
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ *.so
+ *.egg
+ *.egg-info/
+ dist/
+ build/
+ eggs/
+ .eggs/
+
+ # Temporary files
+ tmp/
+ *.tmp
+ *.temp
+
+ # OS files
+ .DS_Store
+ .DS_Store?
+ ._*
+ Thumbs.db
+
+ # IDE
+ .vscode/
+ .idea/
+ *.swp
+ *.swo
+
+ # Jupyter
+ .ipynb_checkpoints/
+
+ # Model checkpoints (use LFS for large files)
+ # *.pth
+ # *.pt
+
+ # Environment
+ .env
+ .venv/
+ venv/
+ ENV/
+
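The `.gitignore` above keeps Python bytecode, editor swap files, and environment directories out of the index. A minimal sketch of how such globs match file names (hypothetical `is_ignored` helper; Git's real matching rules — directory patterns, negation, anchoring — are richer than `fnmatch`):

```python
import fnmatch

# A few simple globs lifted from the .gitignore above (illustrative subset).
IGNORE_GLOBS = ["*.py[cod]", "*.swp", "*.tmp"]

def is_ignored(filename, globs=IGNORE_GLOBS):
    """Return True if filename matches any of the simple ignore globs."""
    return any(fnmatch.fnmatch(filename, g) for g in globs)

print(is_ignored("module.pyc"))  # True  ([cod] is a character class)
print(is_ignored("app.py"))      # False
```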
.gitmodules ADDED
@@ -0,0 +1,6 @@
+ [submodule "cowtracker/thirdparty/DepthAnythingV2"]
+ path = cowtracker/thirdparty/DepthAnythingV2
+ url = https://github.com/DepthAnything/Depth-Anything-V2.git
+ [submodule "cowtracker/thirdparty/vggt"]
+ path = cowtracker/thirdparty/vggt
+ url = https://github.com/facebookresearch/vggt.git
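The two submodules pin the third-party DepthAnythingV2 and vggt dependencies; cloning with `git clone --recurse-submodules` (or running `git submodule update --init --recursive` afterwards) fetches them. Since `.gitmodules` is INI-style, a minimal sketch of reading it with the stdlib (the sample is inlined with unindented keys, which the stdlib parser accepts; git itself writes tab-indented keys):

```python
import configparser

# Inline stand-in for the repo's .gitmodules (same sections, keys unindented).
GITMODULES = """\
[submodule "cowtracker/thirdparty/DepthAnythingV2"]
path = cowtracker/thirdparty/DepthAnythingV2
url = https://github.com/DepthAnything/Depth-Anything-V2.git
[submodule "cowtracker/thirdparty/vggt"]
path = cowtracker/thirdparty/vggt
url = https://github.com/facebookresearch/vggt.git
"""

cfg = configparser.ConfigParser()
cfg.read_string(GITMODULES)
for section in cfg.sections():
    # Each section maps a working-tree path to the pinned upstream URL.
    print(cfg[section]["path"], "->", cfg[section]["url"])
```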
CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,80 @@
+ # Code of Conduct
+
+ ## Our Pledge
+
+ In the interest of fostering an open and welcoming environment, we as
+ contributors and maintainers pledge to make participation in our project and
+ our community a harassment-free experience for everyone, regardless of age, body
+ size, disability, ethnicity, sex characteristics, gender identity and expression,
+ level of experience, education, socio-economic status, nationality, personal
+ appearance, race, religion, or sexual identity and orientation.
+
+ ## Our Standards
+
+ Examples of behavior that contributes to creating a positive environment
+ include:
+
+ * Using welcoming and inclusive language
+ * Being respectful of differing viewpoints and experiences
+ * Gracefully accepting constructive criticism
+ * Focusing on what is best for the community
+ * Showing empathy towards other community members
+
+ Examples of unacceptable behavior by participants include:
+
+ * The use of sexualized language or imagery and unwelcome sexual attention or
+   advances
+ * Trolling, insulting/derogatory comments, and personal or political attacks
+ * Public or private harassment
+ * Publishing others' private information, such as a physical or electronic
+   address, without explicit permission
+ * Other conduct which could reasonably be considered inappropriate in a
+   professional setting
+
+ ## Our Responsibilities
+
+ Project maintainers are responsible for clarifying the standards of acceptable
+ behavior and are expected to take appropriate and fair corrective action in
+ response to any instances of unacceptable behavior.
+
+ Project maintainers have the right and responsibility to remove, edit, or
+ reject comments, commits, code, wiki edits, issues, and other contributions
+ that are not aligned to this Code of Conduct, or to ban temporarily or
+ permanently any contributor for other behaviors that they deem inappropriate,
+ threatening, offensive, or harmful.
+
+ ## Scope
+
+ This Code of Conduct applies within all project spaces, and it also applies when
+ an individual is representing the project or its community in public spaces.
+ Examples of representing a project or community include using an official
+ project e-mail address, posting via an official social media account, or acting
+ as an appointed representative at an online or offline event. Representation of
+ a project may be further defined and clarified by project maintainers.
+
+ This Code of Conduct also applies outside the project spaces when there is a
+ reasonable belief that an individual's behavior may have a negative impact on
+ the project or its community.
+
+ ## Enforcement
+
+ Instances of abusive, harassing, or otherwise unacceptable behavior may be
+ reported by contacting the project team at <opensource-conduct@meta.com>. All
+ complaints will be reviewed and investigated and will result in a response that
+ is deemed necessary and appropriate to the circumstances. The project team is
+ obligated to maintain confidentiality with regard to the reporter of an incident.
+ Further details of specific enforcement policies may be posted separately.
+
+ Project maintainers who do not follow or enforce the Code of Conduct in good
+ faith may face temporary or permanent repercussions as determined by other
+ members of the project's leadership.
+
+ ## Attribution
+
+ This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
+ available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
+
+ [homepage]: https://www.contributor-covenant.org
+
+ For answers to common questions about this code of conduct, see
+ https://www.contributor-covenant.org/faq
CONTRIBUTING.md ADDED
@@ -0,0 +1,31 @@
+ # Contributing to cowtracker
+ We want to make contributing to this project as easy and transparent as
+ possible.
+
+ ## Pull Requests
+ We actively welcome your pull requests.
+
+ 1. Fork the repo and create your branch from `main`.
+ 2. If you've added code that should be tested, add tests.
+ 3. If you've changed APIs, update the documentation.
+ 4. Ensure the test suite passes.
+ 5. Make sure your code lints.
+ 6. If you haven't already, complete the Contributor License Agreement ("CLA").
+
+ ## Contributor License Agreement ("CLA")
+ In order to accept your pull request, we need you to submit a CLA. You only need
+ to do this once to work on any of Facebook's open source projects.
+
+ Complete your CLA here: <https://code.facebook.com/cla>
+
+ ## Issues
+ We use GitHub issues to track public bugs. Please ensure your description is
+ clear and has sufficient instructions to be able to reproduce the issue.
+
+ Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
+ disclosure of security bugs. In those cases, please go through the process
+ outlined on that page and do not file a public issue.
+
+ ## License
+ By contributing to cowtracker, you agree that your contributions will be licensed
+ under the LICENSE file in the root directory of this source tree.
LICENSE ADDED
@@ -0,0 +1,124 @@
+ FAIR Noncommercial Research License
+ v1 Last Updated: December 22, 2025
+
+ “Acceptable Use Policy” means the FAIR Acceptable Use Policy, applicable to Research Materials, that is incorporated into this Agreement.
+
+ “Agreement” means the terms and conditions for use, reproduction, distribution and modification of the Research Materials set forth herein.
+
+ “Documentation” means the specifications, manuals and documentation accompanying
+ Research Materials distributed by Meta.
+
+ “Licensee” or “you” means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity’s behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf.
+
+ “Meta” or “we” means Meta Platforms Ireland Limited (if you are located in or, if you are an entity, your principal place of business is in the EEA or Switzerland) and Meta Platforms, Inc. (if you are located outside of the EEA or Switzerland).
+
+ “Noncommercial Research Uses” means noncommercial research use cases related to research, development, education, processing, or analysis and in each case, is not primarily intended for commercial advantage or monetary compensation to you or others.
+
+ “Research Materials” means, collectively, Documentation and the models, software and algorithms, including machine-learning model code, trained model weights, inference-enabling code, training-enabling code, fine-tuning enabling code, demonstration materials and other elements of the foregoing distributed by Meta and made available under this Agreement.
+
+ By clicking “I Accept” below or by using or distributing any portion or element of the Research Materials, you agree to be bound by this Agreement.
+
+ 1. License Rights and Redistribution.
+
+ a. Grant of Rights. You are granted a non-exclusive, worldwide, non-transferable and royalty-free limited license under Meta’s intellectual property or other rights owned by Meta embodied in the Research Materials to use, reproduce, distribute, copy, create derivative works of, and make modifications to the Research Materials.
+
+ b. Redistribution and Use.
+ i. You will not use the Research Materials or any outputs or results of the Research Materials in connection with any commercial uses or for any uses other than Noncommercial Research Uses;
+
+ ii. Distribution of Research Materials, and any derivative works thereof, are subject to the terms of this Agreement. If you distribute or make the Research Materials, or any derivative works thereof, available to a third party, you may only do so under the terms of this Agreement. You shall also provide a copy of this Agreement to such third party.
+
+ iii. If you submit for publication the results of research you perform on, using, or otherwise in connection with Research Materials, you must acknowledge the use of Research Materials in your publication.
+
+ iv. Your use of the Research Materials must comply with applicable laws and regulations (including Trade Control Laws) and adhere to the FAIR Acceptable Use Policy, which is hereby incorporated by reference into this Agreement.
+
+ 2. User Support. Your Noncommercial Research Use of the Research Materials is done at your own discretion; Meta does not process any information nor provide any service in relation to such use. Meta is under no obligation to provide any support services for the Research Materials. Any support provided is “as is”, “with all faults”, and without warranty of any kind.
+
+ 3. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE RESEARCH MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN “AS IS” BASIS, WITHOUT WARRANTIES OF ANY KIND, AND META DISCLAIMS ALL WARRANTIES OF ANY KIND, BOTH EXPRESS AND IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE RESEARCH MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE RESEARCH MATERIALS AND ANY OUTPUT AND RESULTS.
+
+ 4. Limitation of Liability. IN NO EVENT WILL META OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT OR INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
+
+ 5. Intellectual Property.
+
+ a. Subject to Meta’s ownership of Research Materials and derivatives made by or for Meta, with respect to any derivative works and modifications of the Research Materials that are made by you, as between you and Meta, you are and will be the owner of such derivative works and modifications.
+
+ b. If you institute litigation or other proceedings against Meta or any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Research Materials, outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Meta from and against any claim by any third party arising out of or related to your use or distribution of the Research Materials.
+
+ 6. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the Research Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Meta may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of the Research Materials. Sections 3, 4 and 7 shall survive the termination of this Agreement.
+
+ 7. Governing Law and Jurisdiction. This Agreement will be governed and construed under the laws of the State of California without regard to choice of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement. The courts of California shall have exclusive jurisdiction of any dispute arising out of this Agreement.
+
+ 8. Modifications and Amendments. Meta may modify this Agreement from time to time; provided that they are similar in spirit to the current version of the Agreement, but may differ in detail to address new problems or concerns. All such changes will be effective immediately. Your continued use of the Research Materials after any modification to this Agreement constitutes your agreement to such modification. Except as provided in this Agreement, no modification or addition to any provision of this Agreement will be binding unless it is in writing and signed by an authorized representative of both you and Meta.
+
+ FAIR Acceptable Use Policy
+
+ The Fundamental AI Research (FAIR) team at Meta seeks to further understanding of new and existing research domains with the mission of advancing the state-of-the-art in artificial intelligence through open research for the benefit of all.
+
+ As part of this mission, Meta makes certain research materials available for noncommercial research use. Meta is committed to promoting the safe and responsible use of such research materials.
+
+ Prohibited Uses
+
+ You agree you will not use, or allow others to use, Research Materials to:
+
+ 1. Violate the law or others’ rights, including to:
+ Engage in, promote, generate, contribute to, encourage, plan, incite, or further illegal or unlawful activity or content, such as:
+ Violence or terrorism
+ Exploitation or harm to children, including the solicitation, creation, acquisition, or dissemination of child exploitative content or failure to report Child Sexual Abuse Material
+ Human trafficking, exploitation, and sexual violence
+ The illegal distribution of information or materials to minors, including obscene materials, or failure to employ legally required age-gating in connection with such information or materials.
+ Sexual solicitation
+ Any other criminal activity
+
+ Engage in, promote, incite, or facilitate the harassment, abuse, threatening, or bullying of individuals or groups of individuals
+
+ Engage in, promote, incite, or facilitate discrimination or other unlawful or harmful conduct in the provision of employment, employment benefits, credit, housing, other economic benefits, or other essential goods and services
+
+ Engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or related professional practices
+
+ Collect, process, disclose, generate, or infer health, demographic, or other sensitive personal or private information about individuals without rights and consents required by applicable laws
+
+ Engage in or facilitate any action or generate any content that infringes, misappropriates, or otherwise violates any third-party rights, including the outputs or results of any technology using FAIR research materials
+
+ Create, generate, or facilitate the creation of malicious code, malware, computer viruses or do anything else that could disable, overburden, interfere with or impair the proper working, integrity, operation or appearance of a website or computer system
+
+ 2. Engage in, promote, incite, facilitate, or assist in the planning or development of activities that present a risk of death or bodily harm to individuals, including use of research artifacts related to the following:
+
+ Military, warfare, nuclear industries or applications, espionage, use for materials or activities that are subject to the International Traffic Arms Regulations (ITAR) maintained by the United States Department of State
+
+ Guns and illegal weapons (including weapon development)
+
+ Illegal drugs and regulated/controlled substances
+
+ Operation of critical infrastructure, transportation technologies, or heavy machinery
+
+ Self-harm or harm to others, including suicide, cutting, and eating disorders
+
+ Any content intended to incite or promote violence, abuse, or any infliction of bodily harm to an individual
+
+ 3. Intentionally deceive or mislead others, including use of FAIR Research Materials related to the following:
+
+ Generating, promoting, or furthering fraud or the creation or promotion of disinformation
+
+ Generating, promoting, or furthering defamatory content, including the creation of defamatory statements, images, or other content
+
+ Generating, promoting, or further distributing spam
+
+ Impersonating another individual without consent, authorization, or legal right
+
+ Representing that outputs of FAIR research materials or outputs from technology using FAIR research materials are human-generated
+
+ Generating or facilitating false online engagement, including fake reviews and other means of fake online engagement
+
+ 4. Fail to appropriately disclose to end users any known dangers of your Research Materials.
+
+ Please report any violation of this Policy or other problems that could lead to a violation of this Policy by submitting a report here [https://docs.google.com/forms/d/e/1FAIpQLSeb11cryAopJ7LNrC4nxEUXrHY26hfkXQMf_uH-oFgA3WlYZQ/viewform].
README.md ADDED
@@ -0,0 +1,48 @@
+ ---
+ title: CoWTracker
+ emoji: 🐮
+ colorFrom: green
+ colorTo: yellow
+ sdk: gradio
+ sdk_version: 6.2.0
+ app_file: app.py
+ pinned: false
+ license: cc-by-nc-4.0
+ suggested_hardware: a10g-small
+ short_description: Dense Point Tracking with Cost-Volume Free Warping
+ ---
+
+ # 🐮 CoWTracker
+
+ **Cost-Volume Free Warping-Based Dense Point Tracking**
+
+ Zihang Lai<sup>1,2</sup>, Eldar Insafutdinov<sup>1</sup>, Edgar Sucar<sup>1</sup>, Andrea Vedaldi<sup>1,2</sup>
+
+ <sup>1</sup>Visual Geometry Group, University of Oxford &nbsp;&nbsp; <sup>2</sup>Meta AI
+
+ ---
+
+ Upload a video and CoWTracker will track every pixel through time, visualizing the motion with colorful trajectories.
+
+ ## Features
+
+ - **Dense Tracking**: Track all pixels simultaneously
+ - **Bidirectional**: Track forwards and backwards from any query frame
+ - **Interactive**: Choose query frame and visualization settings
+ - **Fast**: Efficient warping-based architecture
+
+ ## Links
+
+ - [Project Page](https://cowtracker.github.io/)
+ - [GitHub](https://github.com/facebookresearch/cowtracker/)
+ - [Paper](#)
+
+ ## Usage
+
+ 1. Upload a video (or select an example)
+ 2. Click "Process Video"
+ 3. Select query frame using the slider
+ 4. Click "Start Tracking"
+ 5. Adjust visualization settings as needed
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,591 @@
+ #!/usr/bin/env python3
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """
+ CoWTracker Gradio Demo.
+
+ Interactive web demo for dense point tracking using CoWTracker.
+
+ Usage:
+     python app.py
+     python app.py --checkpoint /path/to/model.pth --port 8086
+ """
+
+ import glob
+ import os
+ import tempfile
+ import uuid
+ from typing import Optional
+
+ import gradio as gr
+ import spaces
+ import matplotlib
+ import mediapy
+ import numpy as np
+ import PIL.Image
+ import torch
+
+ from cowtracker import CoWTracker
+ from cowtracker.utils.padding import (
+     apply_padding,
+     compute_padding_params,
+     remove_padding_and_scale_back,
+ )
+ from cowtracker.utils.visualization import (
+     get_2d_colors,
+     get_colors_from_cmap,
+     paint_point_track,
+ )
+
+ # --- Constants ---
+ PREVIEW_WIDTH = 1024
+ PREVIEW_HEIGHT = 1024
+ FRAME_LIMIT = 512
+ # Default checkpoint: None means use the model's default HuggingFace URL
+ DEFAULT_CHECKPOINT = None
+
+ # --- Model Initialization ---
+
+
+ def initialize_model(checkpoint_path: Optional[str] = None):
+     """Initialize and load the CoWTracker model once at startup.
+
+     Args:
+         checkpoint_path: Path to local checkpoint file.
+             If None, downloads from HuggingFace Hub.
+     """
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     dtype = torch.float16 if device == "cuda" else torch.float32
+
+     ckpt_path = checkpoint_path if checkpoint_path is not None else DEFAULT_CHECKPOINT
+     if ckpt_path:
+         print(f"Initializing CoWTracker model from {ckpt_path}...")
+     else:
+         print("Initializing CoWTracker model from HuggingFace Hub...")
+
+     model = CoWTracker.from_checkpoint(
+         ckpt_path,
+         device=device,
+         dtype=dtype,
+     )
+
+     print("Model initialized successfully!")
+     return model
+
+
+ # Initialize model once at module level
+ GLOBAL_MODEL = None
+
+
+ def get_model():
+     """Get the global model, initializing if needed."""
+     global GLOBAL_MODEL
+     if GLOBAL_MODEL is None:
+         GLOBAL_MODEL = initialize_model()
+     return GLOBAL_MODEL
+
+
+ # --- Core Logic Functions ---
+
+
+ def preprocess_video_input(video_path):
+     """Process uploaded video for tracking."""
+     if not video_path:
+         return None
+
+     video_arr = mediapy.read_video(video_path)
+     video_fps = video_arr.metadata.fps
+     num_frames = video_arr.shape[0]
+
+     if num_frames > FRAME_LIMIT:
+         gr.Warning(
+             f"Video is too long. Truncating to first {FRAME_LIMIT} frames.", duration=5
+         )
+         video_arr = video_arr[:FRAME_LIMIT]
+         num_frames = FRAME_LIMIT
+
+     height, width = video_arr.shape[1:3]
+     if height > width:
+         new_height, new_width = PREVIEW_HEIGHT, int(PREVIEW_WIDTH * width / height)
+     else:
+         new_height, new_width = int(PREVIEW_WIDTH * height / width), PREVIEW_WIDTH
+
+     # Resize logic to keep manageable size
+     if new_height * new_width > 768 * 1024:
+         new_height = new_height * 3 // 4
+         new_width = new_width * 3 // 4
+
+     # Make divisible by 16 for ffmpeg compatibility
+     new_height, new_width = new_height // 16 * 16, new_width // 16 * 16
+
+     preview_video = mediapy.resize_video(video_arr, (new_height, new_width))
+     input_video = preview_video  # using preview size for processing
+
+     preview_video = np.array(preview_video)
+     input_video = np.array(input_video)
+
+     return (
+         video_arr,
+         preview_video,
+         preview_video.copy(),
+         input_video,
+         video_fps,
+         preview_video[0],
+         gr.update(minimum=0, maximum=num_frames - 1, value=0, interactive=True),
+         gr.update(interactive=True),
+     )
+
+
+ def choose_frame(frame_num, video_preview_array):
+     """Select frame for preview."""
+     if video_preview_array is None:
+         return None
+     return video_preview_array[int(frame_num)]
+
+
+ def paint_video(
+     video_preview,
+     query_frame,
+     video_fps,
+     tracks,
+     visibs,
+     rate=1,
+     show_bkg=True,
+     cmap="gist_rainbow",
+ ):
+     """Paint tracks onto video and save to file."""
+     T, H, W, _ = video_preview.shape
+
+     # Get colors based on colormap choice
+     if cmap == "bremm":
+         xy0 = tracks[:, query_frame]
+         colors = get_2d_colors(xy0, H, W)
+     else:
+         query_count = tracks.shape[0]
+         colors = get_colors_from_cmap(query_count, cmap)
+
+     painted_video = paint_point_track(
+         video_preview, tracks, visibs, colors, rate=rate, show_bkg=show_bkg
+     )
+
+     # Save video to temp directory
+     video_file_name = uuid.uuid4().hex + ".mp4"
+     video_path = os.path.join(tempfile.gettempdir(), "cowtracker_output")
+     video_file_path = os.path.join(video_path, video_file_name)
+     os.makedirs(video_path, exist_ok=True)
+
+     # Cleanup old jpgs
+     for f in glob.glob(os.path.join(video_path, "*.jpg")):
+         os.remove(f)
+
+     # Save frames and compile with ffmpeg
+     for ti in range(T):
+         temp_out_f = "%s/%03d.jpg" % (video_path, ti)
+         im = PIL.Image.fromarray(painted_video[ti])
+         im.save(temp_out_f)
+
+     os.system(
+         f'ffmpeg -y -hide_banner -loglevel error -f image2 -framerate {video_fps} '
+         f'-pattern_type glob -i "{video_path}/*.jpg" -c:v libx264 -crf 20 '
+         f'-pix_fmt yuv420p {video_file_path}'
+     )
+
+     # Cleanup used jpgs
+     for ti in range(T):
+         temp_out_f = "%s/%03d.jpg" % (video_path, ti)
+         if os.path.exists(temp_out_f):
201
+ os.remove(temp_out_f)
202
+
203
+ return video_file_path
204
+
205
+ @spaces.GPU
206
+ def update_vis(
207
+ rate, show_bkg, cmap, video_preview, query_frame, video_fps, tracks, visibs
208
+ ):
209
+ """Update visualization with new settings."""
210
+ if video_preview is None or len(tracks) == 0:
211
+ return None
212
+ T, H, W, _ = video_preview.shape
213
+ tracks_ = tracks.reshape(H, W, T, 2)[::rate, ::rate].reshape(-1, T, 2)
214
+ visibs_ = visibs.reshape(H, W, T)[::rate, ::rate].reshape(-1, T)
215
+ return paint_video(
216
+ video_preview,
217
+ query_frame,
218
+ video_fps,
219
+ tracks_,
220
+ visibs_,
221
+ rate=rate,
222
+ show_bkg=show_bkg,
223
+ cmap=cmap,
224
+ )
225
+
226
+
227
+ @spaces.GPU
228
+ def track(video_preview, video_input, video_fps, query_frame, rate, show_bkg, cmap):
229
+ """Run tracking on video with bidirectional propagation."""
230
+ device = "cuda" if torch.cuda.is_available() else "cpu"
231
+ dtype = torch.float16 if device == "cuda" else torch.float32
232
+
233
+ video_tensor = torch.tensor(video_input).unsqueeze(0).to(dtype)
234
+
235
+ # Use the globally initialized model
236
+ model = get_model()
237
+ print("Using pre-loaded model for tracking...")
238
+
239
+ video_tensor = video_tensor.permute(0, 1, 4, 2, 3)
240
+ _, T, _, H, W = video_tensor.shape
241
+
242
+ # Store original resolution
243
+ orig_H, orig_W = H, W
244
+
245
+ # Configure inference size and compute padding parameters
246
+ inf_H, inf_W = 336, 560
247
+ skip_upscaling = True
248
+
249
+ print(f"Original video size: {orig_H}x{orig_W}")
250
+ print(f"Inference size: {inf_H}x{inf_W}")
251
+
252
+ # Compute padding parameters
253
+ padding_info = compute_padding_params(
254
+ orig_H, orig_W, inf_H, inf_W, skip_upscaling=skip_upscaling
255
+ )
256
+ print(f"Scale factor: {padding_info['scale']:.4f}")
257
+ if padding_info["upscaling_skipped"]:
258
+ print(
259
+ f"Upscaling skipped (scale > 1.0) - using original size: {orig_H}x{orig_W}"
260
+ )
261
+ else:
262
+ print(
263
+ f"Scaled size (before padding): {padding_info['scaled_H']}x{padding_info['scaled_W']}"
264
+ )
265
+ print(
266
+ f"Padding: top={padding_info['pad_top']}, bottom={padding_info['pad_bottom']}, "
267
+ f"left={padding_info['pad_left']}, right={padding_info['pad_right']}"
268
+ )
269
+
270
+ torch.cuda.empty_cache()
271
+
272
+ # Initialize output tensors for INFERENCE resolution
273
+ traj_maps_e = torch.zeros(
274
+ (1, T, inf_H, inf_W, 2), dtype=torch.float32, device="cpu"
275
+ )
276
+ visconf_maps_e = torch.zeros(
277
+ (1, T, inf_H, inf_W), dtype=torch.float32, device="cpu"
278
+ )
279
+
280
+ with torch.no_grad():
281
+ # Forward pass
282
+ if query_frame < T - 1:
283
+ with torch.amp.autocast(device_type=device, dtype=torch.float16, enabled=device == "cuda"):
284
+ # Apply padding to forward video
285
+ forward_video = video_tensor[0, query_frame:]
286
+ forward_video_padded = apply_padding(forward_video, padding_info).to(device)
287
+
288
+ predictions = model.forward(
289
+ video=forward_video_padded,
290
+ queries=None,
291
+ )
292
+
293
+ # Extract dense predictions (at INFERENCE resolution)
294
+ tracks_dense = predictions["track"][0] # (T_forward, inf_H, inf_W, 2)
295
+ visibility_dense = predictions["vis"][0] # (T_forward, inf_H, inf_W)
296
+ confidence_dense = predictions["conf"][0] # (T_forward, inf_H, inf_W)
297
+
298
+ # Store forward predictions
299
+ T_forward = tracks_dense.shape[0]
300
+ traj_maps_e[0, query_frame : query_frame + T_forward] = (
301
+ tracks_dense.cpu()
302
+ )
303
+ visconf_maps_e[0, query_frame : query_frame + T_forward] = (
304
+ visibility_dense * confidence_dense
305
+ ).cpu()
306
+
307
+ # Backward pass
308
+ if query_frame > 0:
309
+ with torch.amp.autocast(device_type=device, dtype=torch.float16, enabled=device == "cuda"):
310
+ # Flip video for backward tracking and apply padding
311
+ backward_video = video_tensor[0, : query_frame + 1].flip([0])
312
+ backward_video_padded = apply_padding(
313
+ backward_video, padding_info
314
+ ).to(device)
315
+
316
+ predictions = model.forward(
317
+ video=backward_video_padded,
318
+ queries=None,
319
+ )
320
+
321
+ # Extract dense predictions (at INFERENCE resolution)
322
+ tracks_dense = predictions["track"][0] # (T_backward, inf_H, inf_W, 2)
323
+ visibility_dense = predictions["vis"][0] # (T_backward, inf_H, inf_W)
324
+ confidence_dense = predictions["conf"][0] # (T_backward, inf_H, inf_W)
325
+
326
+ # Flip back to original temporal order
327
+ backward_tracks = tracks_dense.flip([0]).cpu()
328
+ backward_visconf = (visibility_dense * confidence_dense).flip([0]).cpu()
329
+
330
+ # Store backward predictions (excluding query frame if needed)
331
+ end_idx = query_frame if query_frame < T - 1 else query_frame + 1
332
+ traj_maps_e[0, :end_idx] = backward_tracks[:end_idx]
333
+ visconf_maps_e[0, :end_idx] = backward_visconf[:end_idx]
334
+
335
+ # Remove padding and scale back to original resolution
336
+ print(f"Removing padding and scaling back to {orig_H}x{orig_W}")
337
+ tracks_final, _, confidence_final = remove_padding_and_scale_back(
338
+ traj_maps_e[0], # (T, inf_H, inf_W, 2)
339
+ torch.ones_like(visconf_maps_e[0]), # dummy visibility (not used here)
340
+ visconf_maps_e[0], # (T, inf_H, inf_W)
341
+ padding_info,
342
+ )
343
+ print(f"Tracks shape after unpadding: {tracks_final.shape}")
344
+ print(f"Confidence shape after unpadding: {confidence_final.shape}")
345
+
346
+ # Convert to numpy format
347
+ tracks = tracks_final.permute(1, 2, 0, 3).reshape(-1, T, 2).numpy()
348
+ confs = confidence_final.permute(1, 2, 0).reshape(-1, T).numpy()
349
+ visibs = confs > 0.1  # points with confidence above 0.1 are treated as visible
350
+
351
+ return (
352
+ update_vis(
353
+ rate, show_bkg, cmap, video_preview, query_frame, video_fps, tracks, visibs
354
+ ),
355
+ tracks,
356
+ visibs,
357
+ gr.update(interactive=True),
358
+ gr.update(interactive=True),
359
+ gr.update(interactive=True),
360
+ )
361
+
362
+
363
+ # --- Gradio UI Layout ---
364
+
365
+ custom_css = """
366
+ h1 {text-align: center; margin-bottom: 0 !important;}
367
+ .contain {max-width: 95% !important;}
368
+ #examples-accordion {margin-top: 10px;}
369
+ """
370
+
371
+
372
+ def create_demo():
373
+ """Create and return the Gradio demo interface."""
374
+ with gr.Blocks(title="CoWTracker Demo", theme=gr.themes.Ocean(), css=custom_css) as demo:
375
+ # State Variables
376
+ video_state = gr.State()
377
+ video_queried_preview = gr.State()
378
+ video_preview = gr.State()
379
+ video_input = gr.State()
380
+ video_fps = gr.State(24)
381
+ tracks = gr.State([])
382
+ visibs = gr.State([])
383
+
384
+ # Header
385
+ gr.Markdown(
386
+ """
387
+ <div style="text-align: center; max-width: 800px; margin: 0 auto;">
388
+ <h1 style="font-weight: 900; margin-bottom: 7px;">CoWTracker</h1>
389
+ <p style="margin-bottom: 10px; font-size: 94%">
390
+ Cost-Volume Free Warping-Based Dense Point Tracking.
391
+ <a href='https://cowtracker.github.io/' target='_blank'>Project Page</a> |
392
+ <a href='https://github.com/facebookresearch/cowtracker/' target='_blank'>GitHub</a> |
393
+ <a href='' target='_blank'>Paper</a>
394
+ </p>
395
+ </div>
396
+ """
397
+ )
398
+
399
+ with gr.Row():
400
+ # --- Left Column: Input & Query ---
401
+ with gr.Column(scale=1):
402
+ with gr.Group():
403
+ gr.Markdown("### 1. Upload Video")
404
+ video_in = gr.Video(label="Input Video", format="mp4", height=300)
405
+ submit_btn = gr.Button("Step 1: Process Video", variant="primary")
406
+
407
+ # Query Frame Preview
408
+ with gr.Group():
409
+ gr.Markdown("### 2. Select Query Frame")
410
+ query_frame_slider = gr.Slider(
411
+ minimum=0,
412
+ maximum=100,
413
+ value=0,
414
+ step=1,
415
+ label="Frame Number",
416
+ interactive=False,
417
+ )
418
+ current_frame = gr.Image(
419
+ label="Query Frame Preview",
420
+ type="numpy",
421
+ interactive=False,
422
+ height=300,
423
+ )
424
+
425
+ # --- Right Column: Visualization & Output ---
426
+ with gr.Column(scale=2):
427
+ gr.Markdown("### 3. Configure & Track")
428
+
429
+ with gr.Group():
430
+ with gr.Row():
431
+ rate_radio = gr.Radio(
432
+ [1, 2, 4, 8],
433
+ value=8,
434
+ label="Subsampling Rate",
435
+ interactive=False,
436
+ )
437
+ cmap_radio = gr.Radio(
438
+ ["gist_rainbow", "rainbow", "jet", "turbo"],
439
+ value="gist_rainbow",
440
+ label="Colormap",
441
+ interactive=False,
442
+ )
443
+
444
+ with gr.Row():
445
+ bkg_check = gr.Checkbox(
446
+ value=True, label="Overlay on Video", interactive=False
447
+ )
448
+ track_button = gr.Button(
449
+ "Step 2: Start Tracking", variant="primary", interactive=False
450
+ )
451
+
452
+ # Output takes entire width of this column
453
+ output_video = gr.Video(
454
+ label="Tracking Result",
455
+ interactive=False,
456
+ autoplay=True,
457
+ loop=True,
458
+ height=550,
459
+ )
460
+
461
+ # --- Full Width Row: Examples ---
462
+ with gr.Row():
463
+ with gr.Column():
464
+ video_folder = "videos"
465
+ gr.Markdown("### 📚 Example Videos")
466
+ video_dir = os.path.join(os.path.dirname(__file__), video_folder)
467
+ video_files = []
468
+ if os.path.exists(video_dir):
469
+ for filename in sorted(os.listdir(video_dir)):
470
+ if filename.endswith((".mp4", ".avi", ".mov", ".mkv", ".webm")):
471
+ video_files.append(os.path.join(video_dir, filename))
472
+
473
+ if video_files:
474
+ gr.Examples(
475
+ examples=video_files,
476
+ inputs=[video_in],
477
+ examples_per_page=16,
478
+ )
479
+
480
+ # --- Interaction Logic ---
481
+
482
+ # 1. Submit Video
483
+ submit_btn.click(
484
+ fn=preprocess_video_input,
485
+ inputs=[video_in],
486
+ outputs=[
487
+ video_state,
488
+ video_preview,
489
+ video_queried_preview,
490
+ video_input,
491
+ video_fps,
492
+ current_frame,
493
+ query_frame_slider,
494
+ track_button,
495
+ ],
496
+ queue=False,
497
+ )
498
+
499
+ # 2. Update Preview Frame on Slider Change
500
+ query_frame_slider.change(
501
+ fn=choose_frame,
502
+ inputs=[query_frame_slider, video_queried_preview],
503
+ outputs=[current_frame],
504
+ queue=False,
505
+ )
506
+
507
+ # 3. Run Tracking
508
+ track_button.click(
509
+ fn=track,
510
+ inputs=[
511
+ video_preview,
512
+ video_input,
513
+ video_fps,
514
+ query_frame_slider,
515
+ rate_radio,
516
+ bkg_check,
517
+ cmap_radio,
518
+ ],
519
+ outputs=[
520
+ output_video,
521
+ tracks,
522
+ visibs,
523
+ rate_radio,
524
+ bkg_check,
525
+ cmap_radio,
526
+ ],
527
+ queue=True,
528
+ )
529
+
530
+ # 4. Instant Updates for Visualization Settings (after tracking is done)
531
+ vis_args = [
532
+ rate_radio,
533
+ bkg_check,
534
+ cmap_radio,
535
+ video_preview,
536
+ query_frame_slider,
537
+ video_fps,
538
+ tracks,
539
+ visibs,
540
+ ]
541
+ rate_radio.change(
542
+ fn=update_vis, inputs=vis_args, outputs=[output_video], queue=False
543
+ )
544
+ cmap_radio.change(
545
+ fn=update_vis, inputs=vis_args, outputs=[output_video], queue=False
546
+ )
547
+ bkg_check.change(
548
+ fn=update_vis, inputs=vis_args, outputs=[output_video], queue=False
549
+ )
550
+
551
+ return demo
552
+
553
+
554
+ if __name__ == "__main__":
555
+ import argparse
556
+
557
+ parser = argparse.ArgumentParser(description="CoWTracker Gradio Demo")
558
+ parser.add_argument(
559
+ "--checkpoint",
560
+ type=str,
561
+ default=None,
562
+ help="Path to model checkpoint",
563
+ )
564
+ parser.add_argument(
565
+ "--port",
566
+ type=int,
567
+ default=7860,
568
+ help="Server port",
569
+ )
570
+ parser.add_argument(
571
+ "--share",
572
+ action="store_true",
573
+ help="Create a public share link",
574
+ )
575
+ args = parser.parse_args()
576
+
577
+ # Initialize model with custom checkpoint if provided
578
+ if args.checkpoint:
579
+ GLOBAL_MODEL = initialize_model(args.checkpoint)
580
+
581
+ print("=" * 60)
582
+ print("Starting CoWTracker Gradio Demo")
583
+ print("=" * 60)
584
+
585
+ demo = create_demo()
586
+ demo.launch(
587
+ share=args.share,
588
+ show_error=True,
589
+ server_port=args.port,
590
+ )
591
+
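The forward/backward stitching in `track()` above (the forward pass fills `[query_frame, T)`, the backward pass fills `[0, end_idx)`) can be sanity-checked with a small standalone sketch; `coverage` is a hypothetical helper written for illustration, not part of the app:

```python
# Sketch of how track() tiles the timeline from a query frame:
# the forward pass covers [query_frame, T), the backward pass covers
# [0, end_idx), with end_idx chosen so every frame is written exactly once.
def coverage(T, query_frame):
    end_idx = query_frame if query_frame < T - 1 else query_frame + 1
    forward = list(range(query_frame, T)) if query_frame < T - 1 else []
    backward = list(range(0, end_idx)) if query_frame > 0 else []
    return sorted(backward + forward)

# Every choice of query frame should cover all frames with no gaps or overlaps.
for q in range(5):
    print(q, coverage(5, q))
```

The edge cases match the app's branches: at `query_frame == 0` only the forward pass runs, at `query_frame == T - 1` only the backward pass runs, and in between the backward results stop just short of the query frame so the forward pass owns it.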
cowtracker/__init__.py ADDED
@@ -0,0 +1,23 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """CoWTracker: Cost-Volume Free Warping-Based Dense Point Tracking."""
8
+
9
+
10
+ def __getattr__(name):
11
+ """Lazy import to avoid import errors when dependencies are missing."""
12
+ if name == "CoWTracker":
13
+ from cowtracker.models.cowtracker import CoWTracker
14
+
15
+ return CoWTracker
16
+ if name == "CoWTrackerWindowed":
17
+ from cowtracker.models.cowtracker_windowed import CoWTrackerWindowed
18
+
19
+ return CoWTrackerWindowed
20
+ raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
21
+
22
+
23
+ __all__ = ["CoWTracker", "CoWTrackerWindowed"]
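The `__getattr__` hook above relies on PEP 562 module-level attribute lookup, so the heavy model imports only run on first access. A minimal self-contained illustration of the pattern, using a synthetic module rather than cowtracker itself:

```python
import sys
import types

# Build a throwaway module that lazily "loads" an attribute on first access,
# mirroring the PEP 562 pattern in cowtracker/__init__.py.
lazy = types.ModuleType("lazy_demo")

def _module_getattr(name):
    if name == "CoWTracker":
        return "CoWTracker placeholder"  # stands in for the real class import
    raise AttributeError(f"module {lazy.__name__!r} has no attribute {name!r}")

lazy.__getattr__ = _module_getattr
sys.modules["lazy_demo"] = lazy

import lazy_demo
print(lazy_demo.CoWTracker)  # attribute resolved lazily via __getattr__
```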
cowtracker/heads/__init__.py ADDED
@@ -0,0 +1,14 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """CowTracker heads."""
8
+
9
+ from cowtracker.heads.tracking_head import CowTrackingHead
10
+ from cowtracker.heads.feature_extractor import FeatureExtractor
11
+ import cowtracker.thirdparty # noqa: F401 - sets up vggt path
12
+ from vggt.heads.dpt_head import DPTHead
13
+
14
+ __all__ = ["CowTrackingHead", "FeatureExtractor", "DPTHead"]
cowtracker/heads/feature_extractor.py ADDED
@@ -0,0 +1,100 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Feature extraction: DPT + ResNet side features."""
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+ import cowtracker.thirdparty # noqa: F401 - sets up vggt path
14
+ from vggt.heads.dpt_head import DPTHead
15
+ from cowtracker.layers.resnet_deconv import ResNet18Deconv
16
+
17
+
18
+ class FeatureExtractor(nn.Module):
19
+ """
20
+ Combined DPT and ResNet feature extractor.
21
+
22
+ Takes aggregated tokens from backbone and raw images,
23
+ outputs combined features for tracking.
24
+ """
25
+
26
+ DIM_IN = 2048 # 2 * embed_dim (1024)
27
+ PATCH_SIZE = 14
28
+ INTERMEDIATE_LAYER_IDX = [4, 11, 17, 23]
29
+
30
+ def __init__(
31
+ self,
32
+ features: int = 128,
33
+ down_ratio: int = 2,
34
+ side_resnet_channels: int = 128,
35
+ ):
36
+ """
37
+ Args:
38
+ features: Number of DPT output features.
39
+ down_ratio: Downsampling ratio relative to input image.
40
+ side_resnet_channels: Number of ResNet side feature channels.
41
+ """
42
+ super().__init__()
43
+
44
+ self.features = features
45
+ self.down_ratio = down_ratio
46
+
47
+ # DPT head for backbone features
48
+ self.dpt_head = DPTHead(
49
+ dim_in=self.DIM_IN,
50
+ patch_size=self.PATCH_SIZE,
51
+ features=features,
52
+ feature_only=True,
53
+ down_ratio=down_ratio,
54
+ pos_embed=False,
55
+ intermediate_layer_idx=self.INTERMEDIATE_LAYER_IDX,
56
+ )
57
+
58
+ # ResNet for raw image features
59
+ self.fnet = ResNet18Deconv(3, side_resnet_channels)
60
+
61
+ self.out_dim = features + side_resnet_channels
62
+
63
+ def forward(
64
+ self,
65
+ aggregated_tokens_list: list,
66
+ images: torch.Tensor,
67
+ patch_start_idx: int,
68
+ ) -> torch.Tensor:
69
+ """
70
+ Extract combined features from backbone tokens and raw images.
71
+
72
+ Args:
73
+ aggregated_tokens_list: List of tokens from aggregator.
74
+ images: Input images [B, S, 3, H, W].
75
+ patch_start_idx: Patch start index for DPT.
76
+
77
+ Returns:
78
+ combined_features: [B, S, C, H_out, W_out] where C = features + side_resnet_channels.
79
+ """
80
+ B, S, _, H_img, W_img = images.shape
81
+
82
+ # DPT features from backbone tokens
83
+ backbone_features = self.dpt_head(aggregated_tokens_list, images, patch_start_idx)
84
+ _, _, _, H_out, W_out = backbone_features.shape
85
+
86
+ # Side ResNet features from raw images
87
+ images_flat = images.view(B * S, 3, H_img, W_img)
88
+ side_features = self.fnet(images_flat)[0]
89
+ _, side_channels, H_side, W_side = side_features.shape
90
+
91
+ # Resize side features to match backbone output if needed
92
+ if H_side != H_out or W_side != W_out:
93
+ side_features = F.interpolate(
94
+ side_features, size=(H_out, W_out), mode="bilinear", align_corners=True
95
+ )
96
+
97
+ side_features = side_features.view(B, S, side_channels, H_out, W_out)
98
+
99
+ return torch.cat([backbone_features, side_features], dim=2)
100
+
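A back-of-envelope check of the FeatureExtractor output shape: channels are the DPT features plus the ResNet side channels, and the spatial dims shrink by `down_ratio`. This sketch assumes inputs divisible by the ratio:

```python
def combined_feature_shape(B, S, H_img, W_img,
                           features=128, side_resnet_channels=128,
                           down_ratio=2):
    # DPT features (C=features) and resized ResNet side features
    # (C=side_resnet_channels) are concatenated along the channel axis.
    H_out, W_out = H_img // down_ratio, W_img // down_ratio
    return (B, S, features + side_resnet_channels, H_out, W_out)

# The demo's 336x560 inference size with the defaults above:
print(combined_feature_shape(1, 24, 336, 560))  # (1, 24, 256, 168, 280)
```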
cowtracker/heads/tracking_head.py ADDED
@@ -0,0 +1,243 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """CowTracker tracking head - Warping-based iterative refinement."""
8
+
9
+ from typing import Tuple
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+
15
+ from cowtracker.layers.video_transformer import MODEL_CONFIGS, VisionTransformerVideo
16
+ from cowtracker.utils.ops import bilinear_sampler, coords_grid
17
+
18
+
19
+ class CowTrackingHead(nn.Module):
20
+ """
21
+ Warping-based iterative refinement module.
22
+
23
+ Responsibility: features -> (tracks, visibility, confidence)
24
+ Does NOT handle: feature extraction, windowing
25
+ """
26
+
27
+ TEMPORAL_INTERLEAVE_STRIDE = 2
28
+ MAX_FRAMES = 256
29
+ MLP_RATIO = 4.0
30
+ REFINE_PATCH_SIZE = 4
31
+
32
+ def __init__(
33
+ self,
34
+ feature_dim: int,
35
+ down_ratio: int = 2,
36
+ warp_iters: int = 5,
37
+ warp_model: str = "vits",
38
+ warp_vit_num_blocks: int = None,
39
+ ):
40
+ """
41
+ Args:
42
+ feature_dim: Input feature dimension (features + side_resnet_channels).
43
+ down_ratio: Feature downsampling ratio relative to original image.
44
+ warp_iters: Number of Warping-based iterative refinement iterations.
45
+ warp_model: Model configuration for video transformer.
46
+ warp_vit_num_blocks: Number of transformer blocks (None = use default).
47
+ """
48
+ super().__init__()
49
+
50
+ self.warp_iters = warp_iters
51
+ self.down_ratio = down_ratio
52
+
53
+ # Warping-based iterative refinement iteration dimension
54
+ self.iter_dim = MODEL_CONFIGS[warp_model]["features"]
55
+
56
+ # Video transformer for temporal attention
57
+ self.refine_net = VisionTransformerVideo(
58
+ warp_model,
59
+ self.iter_dim,
60
+ patch_size=self.REFINE_PATCH_SIZE,
61
+ temporal_interleave_stride=self.TEMPORAL_INTERLEAVE_STRIDE,
62
+ max_frames=self.MAX_FRAMES,
63
+ mlp_ratio=self.MLP_RATIO,
64
+ attn_dropout=0.0,
65
+ proj_dropout=0.0,
66
+ drop_path=0.0,
67
+ num_blocks=warp_vit_num_blocks,
68
+ )
69
+
70
+ # Feature processing layers
71
+ self.fmap_conv = nn.Conv2d(feature_dim, self.iter_dim, 1, 1, 0, bias=True)
72
+ self.hidden_conv = nn.Conv2d(self.iter_dim * 2, self.iter_dim, 1, 1, 0, bias=True)
73
+ self.warp_linear = nn.Conv2d(3 * self.iter_dim + 2, self.iter_dim, 1, 1, 0, bias=True)
74
+ self.refine_transform = nn.Conv2d(self.iter_dim // 2 * 3, self.iter_dim, 1, 1, 0, bias=True)
75
+
76
+ # Upsampling weights
77
+ self.upsample_weight = nn.Sequential(
78
+ nn.Conv2d(self.iter_dim, 2 * self.iter_dim, 3, padding=1, bias=True),
79
+ nn.ReLU(inplace=True),
80
+ nn.Conv2d(2 * self.iter_dim, (down_ratio**2) * 9, 1, padding=0, bias=True),
81
+ )
82
+
83
+ # Flow + visibility + confidence head
84
+ self.flow_head = nn.Sequential(
85
+ nn.Conv2d(self.iter_dim, 2 * self.iter_dim, 3, padding=1, bias=True),
86
+ nn.ReLU(inplace=True),
87
+ nn.Conv2d(2 * self.iter_dim, 4, 1, padding=0, bias=True),
88
+ )
89
+
90
+ print(f"CowTrackingHead initialized: iter_dim={self.iter_dim}, warp_iters={warp_iters}")
91
+
92
+ def forward(
93
+ self,
94
+ features: torch.Tensor,
95
+ image_size: Tuple[int, int],
96
+ first_frame_features: torch.Tensor = None,
97
+ ) -> dict:
98
+ """
99
+ Run Warping-based iterative refinement.
100
+
101
+ Args:
102
+ features: Extracted features [B, S, C, H, W].
103
+ image_size: Original image size (H_img, W_img) for upsampling.
104
+ first_frame_features: Optional first frame features [B, 1, C, H, W]
105
+ for cross-window tracking.
106
+
107
+ Returns:
108
+ dict with:
109
+ - track: Dense tracks [B, S, H_img, W_img, 2].
110
+ - vis: Visibility scores [B, S, H_img, W_img].
111
+ - conf: Confidence scores [B, S, H_img, W_img].
112
+ """
113
+ B, S, _, H, W = features.shape
114
+ H_img, W_img = image_size
115
+
116
+ # Project features to iteration dimension
117
+ fmap = self.fmap_conv(features.view(B * S, -1, H, W)).view(B, S, -1, H, W)
118
+
119
+ # Frame 0 reference features
120
+ if first_frame_features is not None:
121
+ frame0_fmap = self.fmap_conv(first_frame_features.view(B, -1, H, W)).view(B, 1, -1, H, W)
122
+ else:
123
+ frame0_fmap = fmap[:, 0:1]
124
+ frame0_expanded = frame0_fmap.expand(B, S, -1, H, W)
125
+
126
+ # Initialize hidden state from concatenation of frame0 and current features
127
+ net = self.hidden_conv(
128
+ torch.cat([frame0_expanded, fmap], dim=2).view(B * S, -1, H, W)
129
+ ).view(B, S, -1, H, W)
130
+
131
+ # Initialize flow to zero
132
+ flow = torch.zeros(B, S, 2, H, W, device=features.device, dtype=features.dtype)
133
+
134
+ # Iterative refinement
135
+ for _ in range(self.warp_iters):
136
+ flow = flow.detach()
137
+
138
+ # Compute warped coordinates
139
+ coords = coords_grid(B * S, H, W, device=features.device).to(fmap.dtype).view(B, S, 2, H, W)
140
+ coords_warped = coords + flow
141
+
142
+ # Warp features using current flow estimate
143
+ warped_fmap = bilinear_sampler(
144
+ fmap.view(B * S, -1, H, W), coords_warped.view(B * S, 2, H, W).permute(0, 2, 3, 1)
145
+ ).view(B, S, -1, H, W)
146
+
147
+ # Build refinement input
148
+ refine_inp = self.warp_linear(
149
+ torch.cat([frame0_expanded, warped_fmap, net, flow], dim=2).view(B * S, -1, H, W)
150
+ ).view(B, S, -1, H, W)
151
+
152
+ # Apply video transformer with temporal attention
153
+ refine_out = self.refine_net(refine_inp)["out"]
154
+
155
+ # Update hidden state
156
+ net = self.refine_transform(
157
+ torch.cat([refine_out.view(B * S, -1, H, W), net.view(B * S, -1, H, W)], dim=1)
158
+ ).view(B, S, -1, H, W)
159
+
160
+ # Predict flow and info update
161
+ update = self.flow_head(net.view(B * S, -1, H, W)).view(B, S, 4, H, W)
162
+ flow = flow + update[:, :, :2]
163
+ info = update[:, :, 2:]
164
+
165
+ # Upsample to original resolution
166
+ weight = 0.25 * self.upsample_weight(net.view(B * S, -1, H, W)).view(B, S, -1, H, W)
167
+ flow_up, info_up = self._upsample_predictions(flow, info, weight)
168
+
169
+ # Convert flow to absolute track coordinates
170
+ tracks = self._flow_to_tracks(flow_up, H_img, W_img)
171
+
172
+ return {
173
+ "track": tracks,
174
+ "vis": torch.sigmoid(info_up[..., 0]),
175
+ "conf": torch.sigmoid(info_up[..., 1]),
176
+ }
177
+
178
+ def _upsample_predictions(
179
+ self,
180
+ flow: torch.Tensor,
181
+ info: torch.Tensor,
182
+ weight: torch.Tensor,
183
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
184
+ """Upsample flow and info using learned convex combination."""
185
+ B, S, _, H, W = flow.shape
186
+
187
+ flow_ups, info_ups = [], []
188
+ for t in range(S):
189
+ f_up, i_up = self._upsample_single(flow[:, t], info[:, t], weight[:, t])
190
+ flow_ups.append(f_up)
191
+ info_ups.append(i_up)
192
+
193
+ return torch.stack(flow_ups, dim=1), torch.stack(info_ups, dim=1)
194
+
195
+ def _upsample_single(
196
+ self,
197
+ flow: torch.Tensor,
198
+ info: torch.Tensor,
199
+ mask: torch.Tensor,
200
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
201
+ """Upsample single frame using soft convex combination."""
202
+ N, _, H, W = flow.shape
203
+ C = info.shape[1]
204
+ factor = self.down_ratio
205
+
206
+ mask = mask.view(N, 1, 9, factor, factor, H, W)
207
+ mask = torch.softmax(mask, dim=2)
208
+
209
+ up_flow = F.unfold(factor * flow, [3, 3], padding=1).view(N, 2, 9, 1, 1, H, W)
210
+ up_info = F.unfold(info, [3, 3], padding=1).view(N, C, 9, 1, 1, H, W)
211
+
212
+ up_flow = torch.sum(mask * up_flow, dim=2).permute(0, 1, 4, 2, 5, 3)
213
+ up_info = torch.sum(mask * up_info, dim=2).permute(0, 1, 4, 2, 5, 3)
214
+
215
+ return (
216
+ up_flow.reshape(N, 2, factor * H, factor * W).permute(0, 2, 3, 1),
217
+ up_info.reshape(N, C, factor * H, factor * W).permute(0, 2, 3, 1),
218
+ )
219
+
220
+ def _flow_to_tracks(
221
+ self,
222
+ flow: torch.Tensor,
223
+ H_img: int,
224
+ W_img: int,
225
+ ) -> torch.Tensor:
226
+ """Convert flow to absolute track coordinates."""
227
+ B, S = flow.shape[:2]
228
+ device, dtype = flow.device, flow.dtype
229
+
230
+ # Create coordinate grid
231
+ y, x = torch.meshgrid(
232
+ torch.arange(H_img, device=device, dtype=dtype),
233
+ torch.arange(W_img, device=device, dtype=dtype),
234
+ indexing="ij",
235
+ )
236
+ coords = torch.stack([x, y], dim=-1).unsqueeze(0).unsqueeze(0).expand(B, S, -1, -1, -1)
237
+
238
+ # Normalize flow relative to frame 0 during inference
239
+ if not self.training:
240
+ flow = flow - flow[:, 0:1]
241
+ flow[:, 0] = 0
242
+
243
+ return coords + flow
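The inference-time re-anchoring in `_flow_to_tracks` (subtract the frame-0 flow so the query frame maps to itself) can be shown with a tiny numpy sketch; a 1-pixel, 3-frame "video" with scalar coordinates keeps the broadcasting visible:

```python
import numpy as np

# Predicted per-frame flow for one pixel over 3 frames ([B, S] layout,
# one scalar coordinate for illustration).
flow = np.array([[0.5, 1.5, 2.5]])

# Inference-time normalization: re-anchor so frame 0 has zero flow.
flow = flow - flow[:, 0:1]
flow[:, 0] = 0.0

coords = np.zeros_like(flow)  # the pixel sits at coordinate 0
tracks = coords + flow
print(tracks.tolist())  # [[0.0, 1.0, 2.0]]
```

After re-anchoring, the track at the first frame is exactly the pixel's own coordinate, which is what makes the dense output usable as a per-pixel trajectory field.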
cowtracker/inference/__init__.py ADDED
@@ -0,0 +1,12 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Inference utilities for CowTracker."""
8
+
9
+ from cowtracker.inference.windowed import WindowedInference
10
+
11
+ __all__ = ["WindowedInference"]
12
+
cowtracker/inference/windowed.py ADDED
@@ -0,0 +1,144 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Windowed inference for long video processing."""
8
+
9
+ from typing import Dict, List, Tuple
10
+
11
+
12
+ class WindowedInference:
13
+ """
14
+ Manages windowed inference for long videos.
15
+
16
+ Handles window computation, memory frame selection, and prediction merging.
17
+ """
18
+
19
+ def __init__(
20
+ self,
21
+ window_len: int = 100,
22
+ stride: int = 100,
23
+ num_memory_frames: int = 10,
24
+ ):
25
+ """
26
+ Args:
27
+ window_len: Number of frames per window.
28
+ stride: Step size between windows.
29
+ num_memory_frames: Maximum number of memory frames to use.
30
+ """
31
+ self.window_len = window_len
32
+ self.stride = stride
33
+ self.num_memory_frames = num_memory_frames
34
+
35
+ def compute_windows(self, total_frames: int) -> List[Tuple[int, int]]:
36
+ """
37
+ Compute all window (start, end) indices.
38
+
39
+ Args:
40
+ total_frames: Total number of frames in the video.
41
+
42
+ Returns:
43
+ List of (start, end) tuples for each window.
44
+ """
45
+ S = self.window_len
46
+ step = self.stride
47
+
48
+ if total_frames <= S:
49
+ return [(0, total_frames)]
50
+
51
+ windows = []
52
+ start = 0
53
+ while start < total_frames:
54
+ end = min(start + S, total_frames)
55
+ windows.append((start, end))
56
+ if end == total_frames:
57
+ break
58
+ start += step
59
+
60
+ return windows
61
+
62
+ def select_memory_frames(
63
+ self,
64
+ window_idx: int,
65
+ window_start: int,
66
+ ) -> List[int]:
67
+ """
68
+ Select memory frame indices using hybrid strategy.
69
+
70
+ Strategy combines:
71
+ - First frame (always included for global reference)
72
+ - Recent frames (temporal continuity)
73
+ - Uniformly sampled middle frames (long-range context)
74
+
75
+ Args:
76
+ window_idx: Current window index.
77
+ window_start: Start frame index of current window.
78
+
79
+ Returns:
80
+ Sorted list of memory frame indices.
81
+ """
82
+ if window_idx == 0:
83
+ return []
84
+
85
+ memory_indices = [0] # Always include first frame
86
+
87
+ # Recent frames for temporal continuity
88
+ for offset in [2, 1]:
89
+ idx = window_start - offset
90
+ if idx > 0 and idx not in memory_indices:
91
+ memory_indices.append(idx)
92
+
93
+ # Uniform sampling from middle history for long-range context
94
+ if window_start > 10:
95
+ mid_start, mid_end = 5, window_start - 3
96
+ step = (mid_end - mid_start) / 6
97
+ for i in range(5):
98
+ idx = int(mid_start + (i + 1) * step)
99
+ if idx not in memory_indices:
100
+ memory_indices.append(idx)
101
+
102
+ # Limit to maximum number of memory frames
103
+ if len(memory_indices) > self.num_memory_frames:
104
+ memory_indices = sorted(memory_indices)[-self.num_memory_frames :]
105
+
106
+ return sorted(memory_indices)
107
+
108
+ def merge_predictions(
109
+ self,
110
+ window_idx: int,
111
+ window_start: int,
112
+ window_end: int,
113
+ window_pred: Dict,
114
+ accumulated: Dict,
115
+ ) -> None:
116
+ """
117
+ Merge window predictions into accumulated results.
118
+
119
+ Handles overlapping regions by using only non-overlapping parts
120
+ from subsequent windows.
121
+
122
+ Args:
123
+ window_idx: Current window index.
124
+ window_start: Start frame index.
125
+ window_end: End frame index.
126
+ window_pred: Predictions for current window (track, vis, conf).
127
+ accumulated: Accumulated predictions to update in-place.
128
+ """
129
+ S_actual = window_end - window_start
130
+
131
+ if window_idx > 0 and self.stride < self.window_len:
132
+ # Has overlap with previous window - only take non-overlapping part
133
+ overlap_len = min(self.window_len - self.stride, S_actual)
134
+ if overlap_len < S_actual:
135
+ start_offset = overlap_len
136
+ for key in ["track", "vis", "conf"]:
137
+ accumulated[key][:, window_start + start_offset : window_end] = window_pred[key][
138
+ :, start_offset:S_actual
139
+ ]
140
+ else:
141
+ # No overlap or first window - take everything
142
+ for key in ["track", "vis", "conf"]:
143
+ accumulated[key][:, window_start:window_end] = window_pred[key][:, :S_actual]
144
+
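The window layout produced by `compute_windows` can be inspected with the same logic extracted to a standalone function; here an overlapping-stride configuration shows how consecutive windows share `window_len - stride` frames:

```python
# Standalone copy of WindowedInference.compute_windows for experimentation.
def compute_windows(total_frames, window_len=100, stride=100):
    if total_frames <= window_len:
        return [(0, total_frames)]
    windows, start = [], 0
    while start < total_frames:
        end = min(start + window_len, total_frames)
        windows.append((start, end))
        if end == total_frames:
            break
        start += stride
    return windows

# With stride < window_len, consecutive windows overlap by
# window_len - stride frames; merge_predictions keeps only the
# non-overlapping tail of each later window.
print(compute_windows(250, window_len=100, stride=80))  # [(0, 100), (80, 180), (160, 250)]
```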
cowtracker/layers/__init__.py ADDED
@@ -0,0 +1,26 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Network layers and backbone modules for CowTracker."""
8
+
9
+ from cowtracker.layers.temporal_attention import TemporalSelfAttentionBlock
10
+ from cowtracker.layers.video_transformer import (
11
+ MODEL_CONFIGS,
12
+ VisionTransformerVideo,
13
+ FlashAttention3,
14
+ replace_attention_with_flash3,
15
+ )
16
+ from cowtracker.layers.patch_embed import PatchEmbed
17
+
18
+ __all__ = [
19
+ "TemporalSelfAttentionBlock",
20
+ "MODEL_CONFIGS",
21
+ "VisionTransformerVideo",
22
+ "FlashAttention3",
23
+ "replace_attention_with_flash3",
24
+ "PatchEmbed",
25
+ ]
26
+
cowtracker/layers/dpt_head.py ADDED
@@ -0,0 +1,173 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
+
+ """
+ Custom DPTHead with intermediate feature extraction support.
+
+ This module imports the base components from the Depth-Anything-V2 submodule
+ and provides a modified DPTHead that can return intermediate features.
+ """
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ # Import base components from submodule
+ from cowtracker.thirdparty.DepthAnythingV2.depth_anything_v2.util.blocks import (
+     FeatureFusionBlock,
+     _make_scratch,
+ )
+
+
+ def _make_fusion_block(features, use_bn, size=None):
+     return FeatureFusionBlock(
+         features,
+         nn.ReLU(False),
+         deconv=False,
+         bn=use_bn,
+         expand=False,
+         align_corners=True,
+         size=size,
+     )
+
+
+ class DPTHead(nn.Module):
+     """
+     DPT decoder head with support for returning intermediate features.
+
+     This is a modified version that:
+     - Removes output_conv2 (final depth prediction layers)
+     - Removes resConfUnit1 from refinenet4
+     - Supports returning intermediate feature maps via the return_intermediate flag
+     """
+
+     def __init__(
+         self,
+         in_channels,
+         features=256,
+         use_bn=False,
+         out_channels=[256, 512, 1024, 1024],
+         use_clstoken=False,
+     ):
+         super(DPTHead, self).__init__()
+
+         self.use_clstoken = use_clstoken
+
+         self.projects = nn.ModuleList([
+             nn.Conv2d(
+                 in_channels=in_channels,
+                 out_channels=out_channel,
+                 kernel_size=1,
+                 stride=1,
+                 padding=0,
+             ) for out_channel in out_channels
+         ])
+
+         self.resize_layers = nn.ModuleList([
+             nn.ConvTranspose2d(
+                 in_channels=out_channels[0],
+                 out_channels=out_channels[0],
+                 kernel_size=4,
+                 stride=4,
+                 padding=0),
+             nn.ConvTranspose2d(
+                 in_channels=out_channels[1],
+                 out_channels=out_channels[1],
+                 kernel_size=2,
+                 stride=2,
+                 padding=0),
+             nn.Identity(),
+             nn.Conv2d(
+                 in_channels=out_channels[3],
+                 out_channels=out_channels[3],
+                 kernel_size=3,
+                 stride=2,
+                 padding=1)
+         ])
+
+         if use_clstoken:
+             self.readout_projects = nn.ModuleList()
+             for _ in range(len(self.projects)):
+                 self.readout_projects.append(
+                     nn.Sequential(
+                         nn.Linear(2 * in_channels, in_channels),
+                         nn.GELU()))
+
+         self.scratch = _make_scratch(
+             out_channels,
+             features,
+             groups=1,
+             expand=False,
+         )
+
+         self.scratch.stem_transpose = None
+
+         self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
+         self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
+         self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
+         self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
+
+         head_features_1 = features
+
+         self.scratch.output_conv1 = nn.Conv2d(
+             head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1
+         )
+
+         # Remove resConfUnit1 from refinenet4 (not needed for intermediate feature extraction)
+         del self.scratch.refinenet4.resConfUnit1
+
+     def forward(self, out_features, patch_h, patch_w, return_intermediate=True):
+         """
+         Forward pass through the DPT head.
+
+         Args:
+             out_features: List of intermediate features from the encoder
+             patch_h: Height in patches
+             patch_w: Width in patches
+             return_intermediate: If True, return intermediate feature maps
+
+         Returns:
+             If return_intermediate=True:
+                 (out, path_1, path_2, path_3, path_4) - output and intermediate features
+             Else:
+                 out - final output only
+         """
+         out = []
+         for i, x in enumerate(out_features):
+             if self.use_clstoken:
+                 x, cls_token = x[0], x[1]
+                 readout = cls_token.unsqueeze(1).expand_as(x)
+                 x = self.readout_projects[i](torch.cat((x, readout), -1))
+             else:
+                 x = x[0]
+
+             x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
+
+             x = self.projects[i](x)
+             x = self.resize_layers[i](x)
+
+             out.append(x)
+
+         layer_1, layer_2, layer_3, layer_4 = out
+
+         layer_1_rn = self.scratch.layer1_rn(layer_1)
+         layer_2_rn = self.scratch.layer2_rn(layer_2)
+         layer_3_rn = self.scratch.layer3_rn(layer_3)
+         layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+         path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
+         path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
+         path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
+         path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+
+         out = self.scratch.output_conv1(path_1)
+
+         if return_intermediate:
+             return out, path_1, path_2, path_3, path_4
+         else:
+             out = F.relu(out)
+             return out
+
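Reviewer note: the `permute(0, 2, 1).reshape(...)` step in `DPTHead.forward` assumes the ViT tokens arrive in row-major patch order. A minimal pure-Python sketch (no torch; the helper name is illustrative) of the index mapping that reshape performs:

```python
def token_index_to_grid(n: int, patch_w: int) -> tuple:
    """Map a flat row-major token index to its (row, col) cell in the patch grid,
    as permute(0, 2, 1).reshape(B, C, patch_h, patch_w) does implicitly."""
    return (n // patch_w, n % patch_w)


patch_h, patch_w = 2, 3  # a 2x3 patch grid, i.e. Np = 6 tokens
cells = [token_index_to_grid(n, patch_w) for n in range(patch_h * patch_w)]
print(cells)  # [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]
```

If the encoder ever emitted tokens in a different order (e.g. column-major), this reshape would silently scramble the spatial layout, so the row-major assumption is worth keeping in mind.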
cowtracker/layers/patch_embed.py ADDED
@@ -0,0 +1,90 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # References:
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+ from typing import Callable, Optional, Tuple, Union
+
+ import torch.nn as nn
+ from torch import Tensor
+
+
+ def make_2tuple(x):
+     if isinstance(x, tuple):
+         assert len(x) == 2
+         return x
+
+     assert isinstance(x, int)
+     return (x, x)
+
+
+ class PatchEmbed(nn.Module):
+     """
+     2D image to patch embedding: (B, C, H, W) -> (B, N, D)
+
+     Args:
+         img_size: Image size.
+         patch_size: Patch token size.
+         in_chans: Number of input image channels.
+         embed_dim: Number of linear projection output channels.
+         norm_layer: Normalization layer.
+     """
+
+     def __init__(
+         self,
+         img_size: Union[int, Tuple[int, int]] = 224,
+         patch_size: Union[int, Tuple[int, int]] = 16,
+         in_chans: int = 3,
+         embed_dim: int = 768,
+         norm_layer: Optional[Callable] = None,
+         flatten_embedding: bool = True,
+     ) -> None:
+         super().__init__()
+
+         image_HW = make_2tuple(img_size)
+         patch_HW = make_2tuple(patch_size)
+         patch_grid_size = (
+             image_HW[0] // patch_HW[0],
+             image_HW[1] // patch_HW[1],
+         )
+
+         self.img_size = image_HW
+         self.patch_size = patch_HW
+         self.patches_resolution = patch_grid_size
+         self.num_patches = patch_grid_size[0] * patch_grid_size[1]
+
+         self.in_chans = in_chans
+         self.embed_dim = embed_dim
+
+         self.flatten_embedding = flatten_embedding
+
+         self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
+         self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+     def forward(self, x: Tensor) -> Tensor:
+         _, _, H, W = x.shape
+         patch_H, patch_W = self.patch_size
+
+         assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
+         assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width {patch_W}"
+
+         x = self.proj(x)  # B C H W
+         H, W = x.size(2), x.size(3)
+         x = x.flatten(2).transpose(1, 2)  # B HW C
+         x = self.norm(x)
+         if not self.flatten_embedding:
+             x = x.reshape(-1, H, W, self.embed_dim)  # B H W C
+         return x
+
+     def flops(self) -> float:
+         Ho, Wo = self.patches_resolution
+         flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
+         if self.norm is not None:
+             flops += Ho * Wo * self.embed_dim
+         return flops
+
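Reviewer note: the token count produced by `PatchEmbed` is just the patch-grid area, which is why `forward` asserts divisibility first. A quick pure-Python check of that arithmetic (helper name is illustrative):

```python
def patch_token_count(img_hw, patch):
    """Number of patch tokens N for an (H, W) input with square patches,
    mirroring PatchEmbed: N = (H // patch) * (W // patch)."""
    h, w = img_hw
    assert h % patch == 0 and w % patch == 0, "input must be a multiple of patch size"
    return (h // patch) * (w // patch)


print(patch_token_count((224, 224), 16))  # 14 * 14 = 196 tokens
```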
cowtracker/layers/resnet_deconv.py ADDED
@@ -0,0 +1,61 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """ResNet18-style encoder-decoder for image features."""
+
+ import torch.nn as nn
+
+
+ class resconv(nn.Module):
+     """Residual convolution block."""
+
+     def __init__(self, inp, oup, k=3, s=1):
+         super(resconv, self).__init__()
+         self.conv = nn.Sequential(
+             nn.GELU(),
+             nn.Conv2d(inp, oup, kernel_size=k, stride=s, padding=k // 2, bias=True),
+             nn.GELU(),
+             nn.Conv2d(oup, oup, kernel_size=3, stride=1, padding=1, bias=True),
+         )
+         if inp != oup or s != 1:
+             self.skip_conv = nn.Conv2d(
+                 inp, oup, kernel_size=1, stride=s, padding=0, bias=True
+             )
+         else:
+             self.skip_conv = nn.Identity()
+
+     def forward(self, x):
+         return self.conv(x) + self.skip_conv(x)
+
+
+ class ResNet18Deconv(nn.Module):
+     """ResNet18-style encoder-decoder for image features."""
+
+     def __init__(self, inp, oup):
+         super(ResNet18Deconv, self).__init__()
+         self.ds1 = resconv(inp, 64, k=7, s=2)
+         self.conv1 = resconv(64, 64, k=3, s=1)
+         self.conv2 = resconv(64, 128, k=3, s=2)
+         self.conv3 = resconv(128, 256, k=3, s=2)
+         self.conv4 = resconv(256, 512, k=3, s=2)
+         self.up_4 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2, padding=0, bias=True)
+         self.proj_3 = resconv(256, 256, k=3, s=1)
+         self.up_3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2, padding=0, bias=True)
+         self.proj_2 = resconv(128, 128, k=3, s=1)
+         self.up_2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2, padding=0, bias=True)
+         self.proj_1 = resconv(64, oup, k=3, s=1)
+
+     def forward(self, x):
+         out_1 = self.ds1(x)
+         out_1 = self.conv1(out_1)
+         out_2 = self.conv2(out_1)
+         out_3 = self.conv3(out_2)
+         out_4 = self.conv4(out_3)
+         out_3 = self.proj_3(out_3 + self.up_4(out_4))
+         out_2 = self.proj_2(out_2 + self.up_3(out_3))
+         out_1 = self.proj_1(out_1 + self.up_2(out_2))
+         return [out_1, out_2, out_3, out_4]
+
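Reviewer note: with encoder strides (2, 1, 2, 2, 2), the returned pyramid `[out_1, out_2, out_3, out_4]` sits at 1/2, 1/4, 1/8, and 1/16 of the input resolution, which is what makes the `up_*` transposed convolutions (stride 2) line up with the skip connections. A pure-Python sketch of that size bookkeeping (helper name is illustrative):

```python
def pyramid_resolutions(h, w):
    """Feature-map sizes of [out_1, out_2, out_3, out_4] for an HxW input,
    following the ResNet18Deconv encoder strides (2, 1, 2, 2, 2)."""
    scales = [2, 4, 8, 16]  # cumulative downsampling at each pyramid level
    return [(h // s, w // s) for s in scales]


print(pyramid_resolutions(256, 320))  # [(128, 160), (64, 80), (32, 40), (16, 20)]
```

This also implies the input H and W should be multiples of 16 for the decoder additions to match exactly.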
cowtracker/layers/temporal_attention.py ADDED
@@ -0,0 +1,307 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """
+ Enhanced Cross Attention Block implementation.
+ Self-contained version with all necessary components inline.
+ """
+
+ from typing import Callable, Optional
+
+ import torch
+ import torch.nn.functional as F
+ from torch import nn, Tensor
+
+
+ # ============================================================================
+ # Inline helper modules (to avoid external dependencies)
+ # ============================================================================
+
+
+ class DropPath(nn.Module):
+     """Drop paths (Stochastic Depth) per sample."""
+
+     def __init__(self, drop_prob: float = 0.0):
+         super().__init__()
+         self.drop_prob = drop_prob
+
+     def forward(self, x: Tensor) -> Tensor:
+         if self.drop_prob == 0.0 or not self.training:
+             return x
+         keep_prob = 1 - self.drop_prob
+         shape = (x.shape[0],) + (1,) * (x.ndim - 1)
+         random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+         if keep_prob > 0.0:
+             random_tensor.div_(keep_prob)
+         return x * random_tensor
+
+
+ class LayerScale(nn.Module):
+     """Layer scale module."""
+
+     def __init__(self, dim: int, init_values: float = 1e-5):
+         super().__init__()
+         self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+     def forward(self, x: Tensor) -> Tensor:
+         return x * self.gamma
+
+
+ class Mlp(nn.Module):
+     """MLP as used in Vision Transformer."""
+
+     def __init__(
+         self,
+         in_features: int,
+         hidden_features: Optional[int] = None,
+         out_features: Optional[int] = None,
+         act_layer: Callable[..., nn.Module] = nn.GELU,
+         drop: float = 0.0,
+         bias: bool = True,
+     ):
+         super().__init__()
+         out_features = out_features or in_features
+         hidden_features = hidden_features or in_features
+         self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
+         self.act = act_layer()
+         self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
+         self.drop = nn.Dropout(drop)
+
+     def forward(self, x: Tensor) -> Tensor:
+         x = self.fc1(x)
+         x = self.act(x)
+         x = self.drop(x)
+         x = self.fc2(x)
+         x = self.drop(x)
+         return x
+
+
+ class MemEffAttention(nn.Module):
+     """Memory-efficient self-attention using PyTorch's scaled_dot_product_attention."""
+
+     def __init__(
+         self,
+         dim: int,
+         num_heads: int = 8,
+         qkv_bias: bool = True,
+         proj_bias: bool = True,
+         attn_drop: float = 0.0,
+         proj_drop: float = 0.0,
+         norm_layer: nn.Module = nn.LayerNorm,
+         qk_norm: bool = False,
+         fused_attn: bool = True,
+         rope=None,
+     ) -> None:
+         super().__init__()
+         assert dim % num_heads == 0, "dim should be divisible by num_heads"
+         self.num_heads = num_heads
+         self.head_dim = dim // num_heads
+         self.scale = self.head_dim**-0.5
+         self.fused_attn = fused_attn
+
+         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+         self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+         self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+         self.attn_drop = nn.Dropout(attn_drop)
+         self.proj = nn.Linear(dim, dim, bias=proj_bias)
+         self.proj_drop = nn.Dropout(proj_drop)
+         self.rope = rope
+
+     def forward(self, x: Tensor, pos=None) -> Tensor:
+         B, N, C = x.shape
+         qkv = (
+             self.qkv(x)
+             .reshape(B, N, 3, self.num_heads, self.head_dim)
+             .permute(2, 0, 3, 1, 4)
+         )
+         q, k, v = qkv.unbind(0)
+         q, k = self.q_norm(q), self.k_norm(k)
+
+         if self.rope is not None and pos is not None:
+             q = self.rope(q, pos)
+             k = self.rope(k, pos)
+
+         if self.fused_attn:
+             x = F.scaled_dot_product_attention(
+                 q, k, v, dropout_p=self.attn_drop.p if self.training else 0.0
+             )
+         else:
+             q = q * self.scale
+             attn = q @ k.transpose(-2, -1)
+             attn = attn.softmax(dim=-1)
+             attn = self.attn_drop(attn)
+             x = attn @ v
+
+         x = x.transpose(1, 2).reshape(B, N, C)
+         x = self.proj(x)
+         x = self.proj_drop(x)
+         return x
+
+
+ # ============================================================================
+ # Main attention block classes
+ # ============================================================================
+
+
+ class SelfAttentionBlock(nn.Module):
+     """
+     Self-attention block using the same architecture as CrossAttentionBlock, but for self-attention.
+     """
+
+     def __init__(
+         self,
+         dim: int,
+         num_heads: int,
+         mlp_ratio: float = 4.0,
+         qkv_bias: bool = True,
+         proj_bias: bool = True,
+         ffn_bias: bool = True,
+         drop: float = 0.0,
+         attn_drop: float = 0.0,
+         init_values=None,
+         drop_path: float = 0.0,
+         act_layer: Callable[..., nn.Module] = nn.GELU,
+         norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
+         ffn_layer: Callable[..., nn.Module] = Mlp,
+         qk_norm: bool = False,
+         fused_attn: bool = True,
+         rope=None,
+     ) -> None:
+         super().__init__()
+
+         self.norm1 = norm_layer(dim)
+
+         # Use standard attention for self-attention
+         self.attn = MemEffAttention(
+             dim,
+             num_heads=num_heads,
+             qkv_bias=qkv_bias,
+             proj_bias=proj_bias,
+             attn_drop=attn_drop,
+             proj_drop=drop,
+             qk_norm=qk_norm,
+             fused_attn=fused_attn,
+             rope=rope,
+         )
+
+         self.ls1 = (
+             LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+         )
+         self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+         self.norm2 = norm_layer(dim)
+         mlp_hidden_dim = int(dim * mlp_ratio)
+         self.mlp = ffn_layer(
+             in_features=dim,
+             hidden_features=mlp_hidden_dim,
+             act_layer=act_layer,
+             drop=drop,
+             bias=ffn_bias,
+         )
+         self.ls2 = (
+             LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+         )
+         self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+         self.sample_drop_ratio = drop_path
+
+     def forward(self, x: Tensor, pos=None) -> Tensor:
+         def attn_residual_func(x: Tensor, pos=None) -> Tensor:
+             return self.ls1(self.attn(self.norm1(x), pos=pos))
+
+         def ffn_residual_func(x: Tensor) -> Tensor:
+             return self.ls2(self.mlp(self.norm2(x)))
+
+         if self.training and self.sample_drop_ratio > 0.0:
+             x = x + self.drop_path1(attn_residual_func(x, pos))
+             x = x + self.drop_path2(ffn_residual_func(x))
+         else:
+             x = x + attn_residual_func(x, pos)
+             x = x + ffn_residual_func(x)
+
+         return x
+
+
+ class TemporalSelfAttentionBlock(nn.Module):
+     """
+     Temporal self-attention block that applies self-attention across time for each spatial position.
+     Input: [B, S, N, C] -> Output: [B, S, N, C]
+     For each position n, performs self-attention across all time steps.
+     """
+
+     def __init__(
+         self,
+         dim: int,
+         num_heads: int,
+         mlp_ratio: float = 4.0,
+         qkv_bias: bool = True,
+         proj_bias: bool = True,
+         ffn_bias: bool = True,
+         drop: float = 0.0,
+         attn_drop: float = 0.0,
+         init_values=None,
+         drop_path: float = 0.0,
+         act_layer: Callable[..., nn.Module] = nn.GELU,
+         norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
+         ffn_layer: Callable[..., nn.Module] = Mlp,
+         qk_norm: bool = False,
+         fused_attn: bool = True,
+         rope=None,
+     ) -> None:
+         super().__init__()
+
+         self.self_attn_block = SelfAttentionBlock(
+             dim,
+             num_heads,
+             mlp_ratio,
+             qkv_bias,
+             proj_bias,
+             ffn_bias,
+             drop,
+             attn_drop,
+             init_values,
+             drop_path,
+             act_layer,
+             norm_layer,
+             ffn_layer,
+             qk_norm,
+             fused_attn,
+             rope,
+         )
+
+     def forward(self, x: Tensor, pos=None):
+         """
+         Apply temporal self-attention across time for each spatial position.
+
+         Args:
+             x: Input tensor of shape [B, S, N, C]
+             pos: Position encoding
+
+         Returns:
+             Output tensor of shape [B, S, N, C]
+         """
+         if len(x.shape) != 4:
+             raise ValueError(
+                 f"TemporalSelfAttentionBlock expects 4D input [B, S, N, C], got {x.shape}"
+             )
+
+         B, S, N, C = x.shape
+
+         if S <= 1:
+             # No temporal dimension to attend over; return input unchanged
+             return x
+
+         # Reshape to [B*N, S, C] to process each spatial position independently
+         x_reshaped = x.permute(0, 2, 1, 3).reshape(B * N, S, C)  # [B*N, S, C]
+
+         # Apply temporal self-attention
+         x_reshaped = self.self_attn_block(x_reshaped, pos=pos)
+
+         # Reshape back to [B, S, N, C]
+         x = x_reshaped.reshape(B, N, S, C).permute(0, 2, 1, 3)
+
+         return x
+
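Reviewer note: the round-trip in `TemporalSelfAttentionBlock.forward` groups each spatial position's tokens across time before attention, then restores the original layout. A pure-Python sketch of the same regrouping on nested lists (the helper name is illustrative; values encode `(time, position)` so the grouping is visible):

```python
def to_temporal_batches(x):
    """x: nested list [B][S][N] of values -> [B*N][S], grouping each spatial
    position's values across time (mirrors permute(0, 2, 1, 3).reshape(B*N, S, C))."""
    B, S, N = len(x), len(x[0]), len(x[0][0])
    # Row b*N + n holds position n's values for all S time steps of batch b.
    return [[x[b][s][n] for s in range(S)] for b in range(B) for n in range(N)]


# One batch, S=2 time steps, N=3 positions; each value encodes (time, position).
x = [[[(s, n) for n in range(3)] for s in range(2)]]
print(to_temporal_batches(x))
# [[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(0, 2), (1, 2)]]
```

Each row now contains one position's full time series, which is exactly the sequence the inner `SelfAttentionBlock` attends over.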
cowtracker/layers/video_transformer.py ADDED
@@ -0,0 +1,411 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import math
+
+ import timm
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import xformers.ops as xops
+ from timm.models.vision_transformer import Attention as TimmAttention
+
+ from cowtracker.layers.dpt_head import DPTHead
+ from cowtracker.layers.patch_embed import PatchEmbed
+ from cowtracker.layers.temporal_attention import TemporalSelfAttentionBlock
+
+ print("timm version: ", timm.__version__)
+
+
+ def get_1d_sincos_pos_embed_from_grid(
+     embed_dim: int, pos: torch.Tensor
+ ) -> torch.Tensor:
+     """
+     Generate a 1D positional embedding from a given grid using sine and cosine functions.
+
+     Args:
+         embed_dim: The embedding dimension.
+         pos: The positions to generate the embedding from.
+
+     Returns:
+         emb: The generated 1D positional embedding.
+     """
+     assert embed_dim % 2 == 0
+     omega = torch.arange(embed_dim // 2, dtype=torch.double)
+     omega /= embed_dim / 2.0
+     omega = 1.0 / 10000**omega  # (D/2,)
+
+     pos = pos.reshape(-1)  # (M,)
+     out = torch.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
+
+     emb_sin = torch.sin(out)  # (M, D/2)
+     emb_cos = torch.cos(out)  # (M, D/2)
+
+     emb = torch.cat([emb_sin, emb_cos], dim=1)  # (M, D)
+     return emb[None].float()
+
+
+ def _get_flash_attention_ops():
+     """Automatically detect the GPU and return appropriate Flash Attention ops.
+
+     Returns Flash Attention 3 ops for H100 (compute capability >= 9.0),
+     otherwise returns Flash Attention 2 ops.
+     """
+     if not torch.cuda.is_available():
+         return None
+
+     # Get compute capability of the current device
+     major, _ = torch.cuda.get_device_capability()
+     # H100 has compute capability 9.0
+     if major >= 9:
+         # Use Flash Attention 3 for H100 and newer
+         try:
+             return (xops.fmha.flash3.FwOp, xops.fmha.flash3.BwOp)
+         except AttributeError:
+             # Fall back to Flash Attention 2 if Flash Attention 3 is not available
+             print("Flash Attention 3 not available, falling back to Flash Attention 2")
+             return (xops.fmha.flash.FwOp, xops.fmha.flash.BwOp)
+     else:
+         # Use Flash Attention 2 for older GPUs
+         return (xops.fmha.flash.FwOp, xops.fmha.flash.BwOp)
+
+
+ class FlashAttention3(nn.Module):
+     """
+     Drop-in replacement for timm.models.vision_transformer.Attention using xformers Flash Attention 3.
+     """
+
+     def __init__(
+         self,
+         dim: int,
+         num_heads: int = 8,
+         qkv_bias: bool = False,
+         attn_drop: float = 0.0,
+         proj_drop: float = 0.0,
+     ):
+         super().__init__()
+         assert dim % num_heads == 0, "dim should be divisible by num_heads"
+         self.num_heads = num_heads
+         self.head_dim = dim // num_heads
+         self.scale = self.head_dim**-0.5
+
+         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+         self.attn_drop = attn_drop
+         self.proj = nn.Linear(dim, dim)
+         self.proj_drop = nn.Dropout(proj_drop)
+
+         # Get Flash Attention ops
+         self.flash_ops = _get_flash_attention_ops()
+
+     def forward(
+         self, x: torch.Tensor, attn_mask: torch.Tensor | None = None
+     ) -> torch.Tensor:
+         B, N, C = x.shape
+
+         # Compute Q, K, V
+         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
+         q, k, v = qkv.unbind(2)  # Each is (B, N, num_heads, head_dim)
+
+         # xformers expects [B, M, H, K] format, which we already have.
+         # Use xformers memory_efficient_attention with Flash Attention 3
+         x = xops.memory_efficient_attention(
+             q,
+             k,
+             v,
+             attn_bias=attn_mask,  # Pass attention mask if provided
+             p=self.attn_drop if self.training else 0.0,
+             scale=self.scale,
+             op=self.flash_ops,
+         )
+
+         # Reshape back to [B, N, C]
+         x = x.reshape(B, N, C)
+
+         # Output projection
+         x = self.proj(x)
+         x = self.proj_drop(x)
+
+         return x
+
+
+ def replace_attention_with_flash3(model: nn.Module) -> nn.Module:
+     """
+     Recursively replace all timm Attention modules with FlashAttention3.
+     """
+     for name, module in model.named_children():
+         if isinstance(module, TimmAttention):
+             # Extract parameters from the timm Attention module
+             flash_attn = FlashAttention3(
+                 dim=module.qkv.in_features,
+                 num_heads=module.num_heads,
+                 qkv_bias=module.qkv.bias is not None,
+                 attn_drop=module.attn_drop.p if hasattr(module, "attn_drop") else 0.0,
+                 proj_drop=module.proj_drop.p if hasattr(module, "proj_drop") else 0.0,
+             )
+             # Copy weights from the original attention module
+             flash_attn.qkv.weight.data = module.qkv.weight.data.clone()
+             if module.qkv.bias is not None:
+                 flash_attn.qkv.bias.data = module.qkv.bias.data.clone()
+             flash_attn.proj.weight.data = module.proj.weight.data.clone()
+             if module.proj.bias is not None:
+                 flash_attn.proj.bias.data = module.proj.bias.data.clone()
+
+             # Replace the module
+             setattr(model, name, flash_attn)
+         else:
+             # Recursively apply to child modules
+             replace_attention_with_flash3(module)
+
+     return model
+
+
+ MODEL_CONFIGS = {
+     "vitl": {
+         "encoder": "vit_large_patch16_224",
+         "features": 256,
+         "out_channels": [256, 512, 1024, 1024],
+     },
+     "vitb": {
+         "encoder": "vit_base_patch16_224",
+         "features": 128,
+         "out_channels": [96, 192, 384, 768],
+     },
+     "vits": {
+         "encoder": "vit_small_patch16_224",
+         "features": 64,
+         "out_channels": [48, 96, 192, 384],
+     },
+     "vitt": {
+         "encoder": "vit_tiny_patch16_224",
+         "features": 32,
+         "out_channels": [24, 48, 96, 192],
+     },
+ }
+
+
+ class VisionTransformerVideo(nn.Module):
+     """
+     Input: (B, T, C, H, W)
+     Pipeline: per-frame ViT + interleaved temporal attention (across frames)
+     Time pos: 1D sinusoidal encoding + linear interpolation for variable T
+     """
+
+     def __init__(
+         self,
+         model_name,
+         input_dim,
+         patch_size=16,
+         temporal_interleave_stride=2,
+         max_frames=256,
+         mlp_ratio=4.0,
+         attn_dropout=0.0,
+         proj_dropout=0.0,
+         drop_path=0.0,
+         shared_temporal_block=False,
+         num_blocks=None,
+         use_flash_attention3=False,
+     ):
+         super().__init__()
+         model = timm.create_model(
+             MODEL_CONFIGS[model_name]["encoder"],
+             pretrained=False,
+             num_classes=0,
+         )
+         self.intermediate_layer_idx = {
+             "vitt": [2, 5, 8, 11],
+             "vits": [2, 5, 8, 11],
+             "vitb": [2, 5, 8, 11],
+             "vitl": [4, 11, 17, 23],
+             "vitg": [9, 19, 29, 39],
+         }
+         self.idx = self.intermediate_layer_idx[model_name]
+         self.blks = model.blocks if num_blocks is None else model.blocks[:num_blocks]
+
+         # Replace attention with Flash Attention 3 if enabled
+         if use_flash_attention3:
+             self.blks = replace_attention_with_flash3(self.blks)
+             num_fa3_modules = sum(
+                 1 for m in self.blks.modules() if isinstance(m, FlashAttention3)
+             )
+             print(
+                 f"✓ Flash Attention 3 enabled for spatial attention: replaced {num_fa3_modules} attention modules"
+             )
+
+         self.embed_dim = model.embed_dim
+         self.input_dim = input_dim
+         self.img_size = (224, 224)
+         self.patch_size = patch_size
+         self.output_dim = MODEL_CONFIGS[model_name]["features"]
+
+         # Spatial positional embedding (64 tokens = 8x8 grid, interpolated to the actual patch grid)
+         self.pos_embed = nn.Parameter(torch.zeros(1, 64, self.embed_dim))
+         nn.init.trunc_normal_(self.pos_embed, std=0.02)
+
+         # ====== Sinusoidal time positional embedding (buffer) ======
+         self.max_frames = max_frames
+         time_grid = torch.arange(max_frames, dtype=torch.float32)  # (T0,)
+         time_emb = get_1d_sincos_pos_embed_from_grid(
+             self.embed_dim, time_grid
+         )  # (1, T0, D)
+         # Registered as a buffer; resized at runtime via interpolate_time_embed
+         self.register_buffer("time_emb", time_emb, persistent=False)
+
+         # Patch embed and DPT head
+         self.patch_embed = PatchEmbed(
+             img_size=self.img_size,
+             patch_size=self.patch_size,
+             in_chans=input_dim,
+             embed_dim=self.embed_dim,
+         )
+         self.dpt_head = DPTHead(
+             self.embed_dim,
+             MODEL_CONFIGS[model_name]["features"],
+             out_channels=MODEL_CONFIGS[model_name]["out_channels"],
+         )
+
+         # Temporal block(s)
+         num_heads = getattr(model.blocks[0].attn, "num_heads", 8)
+         self.shared_temporal_block = shared_temporal_block
+
+         # Insert a temporal block after every N spatial blocks
+         self.temporal_interleave_stride = max(1, int(temporal_interleave_stride))
+
+         # Calculate how many temporal blocks we need
+         num_temporal_blocks = sum(
+             1
+             for i in range(len(self.blks))
+             if (i + 1) % self.temporal_interleave_stride == 0
+         )
+
+         if shared_temporal_block:
+             # Single shared temporal block for all layers
+             self.temporal_block = TemporalSelfAttentionBlock(
+                 dim=self.embed_dim,
+                 num_heads=num_heads,
+                 mlp_ratio=mlp_ratio,
+                 attn_drop=attn_dropout,
+                 drop=proj_dropout,
+                 drop_path=drop_path,
+             )
+             self.temporal_blocks = None
+         else:
+             # Separate temporal block for each layer
+             self.temporal_block = None
+             self.temporal_blocks = nn.ModuleList(
+                 [
+                     TemporalSelfAttentionBlock(
+                         dim=self.embed_dim,
+                         num_heads=num_heads,
+                         mlp_ratio=mlp_ratio,
+                         attn_drop=attn_dropout,
+                         drop=proj_dropout,
+                         drop_path=drop_path,
+                     )
+                     for _ in range(num_temporal_blocks)
+                 ]
+             )
+
+     def interpolate_time_embed(self, x_like: torch.Tensor, t: int) -> torch.Tensor:
+         """
+         x_like: used only to fetch the dtype (e.g., fp16)
+         Returns: time positional embedding of shape (1, t, D)
+         """
+         previous_dtype = x_like.dtype
+         T0 = self.time_emb.shape[1]
+         if t == T0:
+             return self.time_emb.to(previous_dtype)
+         temb = self.time_emb.float()  # (1, T0, D)
+         temb = F.interpolate(
+             temb.permute(0, 2, 1), size=t, mode="linear", align_corners=False
+         ).permute(0, 2, 1)  # (1, t, D)
+         return temb.to(previous_dtype)
+
+     def interpolate_pos_encoding(self, x, h, w):
+         """
+         Interpolate the 2D spatial positional encoding to match H x W (in patches).
+         """
+         previous_dtype = x.dtype
+         npatch = x.shape[1]
+         N = self.pos_embed.shape[1]
+         if npatch == N and w == h:
+             return self.pos_embed
+         pos_embed = self.pos_embed.float()
+         dim = x.shape[-1]
+         w0 = w // self.patch_size
+         h0 = h // self.patch_size
+         sqrt_N = math.sqrt(N)
+         sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
+         pos_embed = nn.functional.interpolate(
+             pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
+             scale_factor=(sy, sx),
+             mode="bicubic",
+             antialias=False,
+         )
+         assert int(w0) == pos_embed.shape[-1] and int(h0) == pos_embed.shape[-2]
+         pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+         return pos_embed.to(previous_dtype)
+
+     def forward(self, x):
+         """
+         x: (B, T, C, H, W)
+         """
+         B, T, C, H, W = x.shape
+         # Merge time into batch for per-frame spatial encoding
+         x = x.view(B * T, C, H, W)
+         x = self.patch_embed(x)  # (B*T, Np, D)
+
+         x = x.view(B, T, *x.shape[1:])
+         # Get the time positional embedding for the current T via linear interpolation: (1, T, D)
+         tpos = self.interpolate_time_embed(x, T).unsqueeze(2)  # (1, T, 1, D)
+         x = x + tpos  # (B, T, Np, D)
+         x = x.view(B * T, *x.shape[2:])  # (B*T, Np, D)
+
+         x = x + self.interpolate_pos_encoding(x, H, W)
+
+         outputs = []
+         temporal_block_idx = 0
+         for i in range(len(self.blks)):
+             # 1) Spatial self-attention (per frame)
+             x = self.blks[i](x)  # (B*T, Np, D)
+             # 2) Interleaved temporal self-attention (across frames, same spatial patch)
+             if (i + 1) % self.temporal_interleave_stride == 0:
+                 x = x.view(B, T, *x.shape[1:])
+                 if self.shared_temporal_block:
+                     x = self.temporal_block(x)
+                 else:
+                     x = self.temporal_blocks[temporal_block_idx](x)
+                     temporal_block_idx += 1
+                 x = x.view(B * T, *x.shape[2:])
+             # 3) Collect intermediate features for the DPT head
+             if i in self.idx:
+                 outputs.append([x])
+
+         patch_h, patch_w = H // self.patch_size, W // self.patch_size
+         # The DPT head consumes (B*T, Np, D); here the batch is B*T
+         out, path_1, path_2, path_3, path_4 = self.dpt_head.forward(
+             outputs, patch_h, patch_w, return_intermediate=True
+         )
+         # Upsample per frame
+         out = F.interpolate(
+             out, (H, W), mode="bilinear", align_corners=True
+         )  # (B*T, Cout, H, W)
+
+         # Restore (B, T, ...)
+         def bt_to_btensor(tensor_or_none):
+             if tensor_or_none is None:
+                 return None
+             return tensor_or_none.view(B, T, *tensor_or_none.shape[1:])
+
+         return {
+             "out": out.view(B, T, *out.shape[1:]),
+             "path_1": bt_to_btensor(path_1),
+             "path_2": bt_to_btensor(path_2),
+             "path_3": bt_to_btensor(path_3),
+             "path_4": bt_to_btensor(path_4),
+         }
+
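For intuition, the positional-embedding resampling above can be exercised in isolation. The sketch below (`resize_pos_embed` is a hypothetical helper, not part of this diff) mirrors the reshape/permute/bicubic-interpolate/flatten sequence, except that it passes an explicit `size=` to `interpolate` rather than the scale factors used in the code, which sidesteps floating-point rounding of the output size:

```python
import math
import torch
import torch.nn as nn

def resize_pos_embed(pos_embed: torch.Tensor, w0: int, h0: int) -> torch.Tensor:
    # pos_embed: (1, N, dim), learned on a square sqrt(N) x sqrt(N) patch grid
    _, N, dim = pos_embed.shape
    sqrt_N = int(math.sqrt(N))
    out = nn.functional.interpolate(
        pos_embed.reshape(1, sqrt_N, sqrt_N, dim).permute(0, 3, 1, 2),
        size=(h0, w0),  # explicit target patch grid (no scale-factor rounding)
        mode="bicubic",
        antialias=False,
    )
    # (1, dim, h0, w0) -> (1, h0 * w0, dim), matching the token layout
    return out.permute(0, 2, 3, 1).reshape(1, -1, dim)

pe = torch.randn(1, 14 * 14, 64)             # e.g. a 14x14 pretrained grid
pe_new = resize_pos_embed(pe, w0=18, h0=12)  # resample for an 18x12 patch grid
print(tuple(pe_new.shape))  # (1, 216, 64)
```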
cowtracker/models/__init__.py ADDED
@@ -0,0 +1,23 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """CoWTracker models."""
+
+
+ def __getattr__(name):
+     """Lazy import to avoid import errors when dependencies are missing."""
+     if name == "CoWTracker":
+         from cowtracker.models.cowtracker import CoWTracker
+
+         return CoWTracker
+     if name == "CoWTrackerWindowed":
+         from cowtracker.models.cowtracker_windowed import CoWTrackerWindowed
+
+         return CoWTrackerWindowed
+     raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
+
+
+ __all__ = ["CoWTracker", "CoWTrackerWindowed"]
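The module-level `__getattr__` above relies on PEP 562 (Python 3.7+): attribute lookups that miss the module dict fall through to the function, so the heavy submodule imports only run on first access. A minimal demonstration of the mechanism with a synthetic module (`demo_pkg` and `Heavy` are made up for this example):

```python
import sys
import types

# Build a stand-in package module; the real package does the same for
# CoWTracker / CoWTrackerWindowed.
mod = types.ModuleType("demo_pkg")

def _module_getattr(name):
    if name == "Heavy":
        # In the real package, this is where the expensive import happens.
        class Heavy:  # hypothetical stand-in for CoWTracker
            pass
        return Heavy
    raise AttributeError(f"module 'demo_pkg' has no attribute {name!r}")

mod.__getattr__ = _module_getattr
sys.modules["demo_pkg"] = mod

import demo_pkg
print(demo_pkg.Heavy.__name__)  # Heavy
```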
cowtracker/models/cowtracker.py ADDED
@@ -0,0 +1,228 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """CoWTracker: Simple version for short videos."""
+
+ import torch
+ import torch.nn as nn
+ from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
+
+ import cowtracker.thirdparty  # noqa: F401 - sets up vggt path
+ from vggt.models.aggregator import Aggregator
+ from cowtracker.heads.feature_extractor import FeatureExtractor
+ from cowtracker.heads.tracking_head import CowTrackingHead
+
+
+ class CoWTracker(nn.Module, PyTorchModelHubMixin):
+     """
+     CoWTracker simple version: processes the entire video at once.
+
+     Suitable for: short videos / sufficient GPU memory.
+     For long videos, use CoWTrackerWindowed instead.
+     """
+
+     # Backbone configuration
+     IMG_SIZE = 518
+     PATCH_SIZE = 14
+     EMBED_DIM = 1024
+     PATCH_EMBED = "dinov2_vitl14_reg"
+     DEPTH = 24
+
+     # Default HuggingFace repo for model loading
+     DEFAULT_REPO_ID = "facebook/cowtracker"
+     DEFAULT_FILENAME = "cowtracker_model.pth"
+
+     def __init__(
+         self,
+         features: int = 128,
+         side_resnet_channels: int = 128,
+         down_ratio: int = 2,
+         warp_iters: int = 5,
+         warp_vit_num_blocks: int = None,
+     ):
+         """
+         Args:
+             features: Number of DPT output features.
+             side_resnet_channels: Number of ResNet side feature channels.
+             down_ratio: Feature downsampling ratio.
+             warp_iters: Number of warping-based iterative refinement iterations.
+             warp_vit_num_blocks: Number of transformer blocks (None = default).
+         """
+         super().__init__()
+
+         print("Initializing CoWTracker...")
+
+         # Backbone: VGGT aggregator
+         self.aggregator = Aggregator(
+             img_size=self.IMG_SIZE,
+             patch_size=self.PATCH_SIZE,
+             embed_dim=self.EMBED_DIM,
+             patch_embed=self.PATCH_EMBED,
+             depth=self.DEPTH,
+         )
+
+         # High-resolution feature extraction
+         self.feature_extractor = FeatureExtractor(
+             features=features,
+             down_ratio=down_ratio,
+             side_resnet_channels=side_resnet_channels,
+         )
+
+         # Tracking head: warping-based iterative refinement
+         self.tracking_head = CowTrackingHead(
+             feature_dim=self.feature_extractor.out_dim,
+             down_ratio=down_ratio,
+             warp_iters=warp_iters,
+             warp_vit_num_blocks=warp_vit_num_blocks,
+         )
+
+         print(f"  - Features: {features}, Side channels: {side_resnet_channels}")
+         print(f"  - Warping-based iterative refinement iterations: {warp_iters}")
+
+     def forward(self, video: torch.Tensor, queries: torch.Tensor = None) -> dict:
+         """
+         Forward pass for dense tracking.
+
+         Args:
+             video: Input video [B, S, 3, H, W] or [S, 3, H, W] in range [0, 255].
+             queries: Optional query points (unused, for API compatibility).
+
+         Returns:
+             dict with:
+                 - track: Dense tracks [B, S, H, W, 2].
+                 - vis: Visibility scores [B, S, H, W].
+                 - conf: Confidence scores [B, S, H, W].
+         """
+         # Normalize input to [0, 1] and ensure a batch dimension
+         images = video / 255.0
+         if images.ndim == 4:
+             images = images.unsqueeze(0)
+
+         B, S, C, H, W = images.shape
+
+         # Extract backbone tokens
+         tokens, patch_idx = self.aggregator(images)
+
+         # Extract high-resolution features
+         features = self.feature_extractor(tokens, images, patch_idx)
+
+         # Run tracking
+         predictions = self.tracking_head(features, image_size=(H, W))
+
+         if not self.training:
+             predictions["images"] = images
+
+         return predictions
+
+     @staticmethod
+     def _remap_legacy_state_dict(state_dict: dict) -> dict:
+         """
+         Remap legacy checkpoint keys to the new model structure.
+
+         Old structure:
+             tracking_head.aggregator.*        -> aggregator.*
+             tracking_head.feature_extractor.* -> feature_extractor.dpt_head.*
+             tracking_head.fnet.*              -> feature_extractor.fnet.*
+             tracking_head.* (rest)            -> tracking_head.*
+
+         Args:
+             state_dict: Original state dict.
+
+         Returns:
+             Remapped state dict.
+         """
+         new_state_dict = {}
+
+         for key, value in state_dict.items():
+             new_key = key
+
+             # Remap tracking_head.aggregator -> aggregator
+             if key.startswith("tracking_head.aggregator."):
+                 new_key = key.replace("tracking_head.aggregator.", "aggregator.")
+             # Remap tracking_head.feature_extractor -> feature_extractor.dpt_head
+             elif key.startswith("tracking_head.feature_extractor."):
+                 new_key = key.replace(
+                     "tracking_head.feature_extractor.", "feature_extractor.dpt_head."
+                 )
+             # Remap tracking_head.fnet -> feature_extractor.fnet
+             elif key.startswith("tracking_head.fnet."):
+                 new_key = key.replace("tracking_head.fnet.", "feature_extractor.fnet.")
+
+             new_state_dict[new_key] = value
+
+         return new_state_dict
+
+     @classmethod
+     def _load_checkpoint(cls, checkpoint_path: str = None) -> dict:
+         """
+         Load a checkpoint from a local path or the HuggingFace Hub.
+
+         Args:
+             checkpoint_path: Local file path to the checkpoint.
+                 If None, downloads from the default HuggingFace repo.
+
+         Returns:
+             Loaded checkpoint dict.
+         """
+         import os
+
+         if checkpoint_path is None:
+             # Download from HuggingFace Hub (uses HF_TOKEN env var automatically)
+             print(f"Downloading checkpoint from HuggingFace: {cls.DEFAULT_REPO_ID}/{cls.DEFAULT_FILENAME}")
+             checkpoint_path = hf_hub_download(
+                 repo_id=cls.DEFAULT_REPO_ID,
+                 filename=cls.DEFAULT_FILENAME,
+             )
+             print(f"Downloaded to: {checkpoint_path}")
+         else:
+             checkpoint_path = os.path.expanduser(checkpoint_path)
+             print(f"Loading checkpoint from local path: {checkpoint_path}")
+
+         with open(checkpoint_path, "rb") as fp:
+             ckpt = torch.load(fp, map_location="cpu")
+
+         return ckpt
+
+     @classmethod
+     def from_checkpoint(
+         cls,
+         checkpoint_path: str = None,
+         device: str = "cuda",
+         dtype=torch.bfloat16,
+     ):
+         """
+         Load the model from a checkpoint (local path or HuggingFace Hub).
+
+         Args:
+             checkpoint_path: Path to a local checkpoint file.
+                 If None, downloads from the default HuggingFace repo.
+             device: Target device.
+             dtype: Target dtype.
+
+         Returns:
+             Loaded model in eval mode.
+         """
+         model = cls()
+
+         ckpt = cls._load_checkpoint(checkpoint_path)
+         state_dict = ckpt.get("model", ckpt)
+
+         # Remap legacy checkpoint keys if needed
+         legacy_prefixes = [
+             "tracking_head.feature_extractor.",
+             "tracking_head.aggregator.",
+             "tracking_head.fnet.",
+         ]
+         if any(k.startswith(p) for k in state_dict.keys() for p in legacy_prefixes):
+             print("Detected legacy checkpoint format, remapping keys...")
+             state_dict = cls._remap_legacy_state_dict(state_dict)
+
+         msg = model.load_state_dict(state_dict, strict=False)
+         print(f"Load message: {msg}")
+
+         model = model.to(device).to(dtype)
+         model.eval()
+         for p in model.parameters():
+             p.requires_grad = False
+
+         print("Model loaded successfully!")
+         return model
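The key remapping performed by `_remap_legacy_state_dict` is a plain prefix rewrite, so it can be checked without loading any weights. A dependency-free sketch (the prefix table mirrors the docstring above; the sample keys are made up):

```python
# Mutually exclusive (old_prefix, new_prefix) pairs, as in the docstring.
LEGACY_PREFIXES = [
    ("tracking_head.aggregator.", "aggregator."),
    ("tracking_head.feature_extractor.", "feature_extractor.dpt_head."),
    ("tracking_head.fnet.", "feature_extractor.fnet."),
]

def remap_legacy_keys(state_dict: dict) -> dict:
    out = {}
    for key, value in state_dict.items():
        for old, new in LEGACY_PREFIXES:
            if key.startswith(old):
                key = new + key[len(old):]
                break  # at most one prefix can match
        out[key] = value
    return out

legacy = {
    "tracking_head.aggregator.blocks.0.w": 1,
    "tracking_head.fnet.conv1.w": 2,
    "tracking_head.refiner.w": 3,  # no matching prefix: kept as-is
}
remapped = remap_legacy_keys(legacy)
print(remapped)
# {'aggregator.blocks.0.w': 1, 'feature_extractor.fnet.conv1.w': 2, 'tracking_head.refiner.w': 3}
```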
cowtracker/models/cowtracker_windowed.py ADDED
@@ -0,0 +1,218 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """CoWTracker Windowed: Full version for long videos."""
+
+ import torch
+ import torch.nn as nn
+ from huggingface_hub import PyTorchModelHubMixin
+
+ from cowtracker.models.cowtracker import CoWTracker
+ from cowtracker.inference.windowed import WindowedInference
+
+
+ class CoWTrackerWindowed(nn.Module, PyTorchModelHubMixin):
+     """
+     CoWTracker windowed version: processes video in sliding windows.
+
+     Suitable for: long videos / limited GPU memory.
+     Composes CoWTracker with WindowedInference.
+     """
+
+     def __init__(
+         self,
+         # Window parameters
+         window_len: int = 100,
+         stride: int = 100,
+         num_memory_frames: int = 10,
+         # CoWTracker parameters
+         **cow_tracker_kwargs,
+     ):
+         """
+         Args:
+             window_len: Number of frames per window.
+             stride: Step size between windows.
+             num_memory_frames: Maximum number of memory frames.
+             **cow_tracker_kwargs: Arguments passed to CoWTracker.
+         """
+         super().__init__()
+
+         print(f"Initializing CoWTrackerWindowed: window_len={window_len}, stride={stride}")
+
+         self.model = CoWTracker(**cow_tracker_kwargs)
+         self.windowed = WindowedInference(
+             window_len=window_len,
+             stride=stride,
+             num_memory_frames=num_memory_frames,
+         )
+
+     def forward(self, video: torch.Tensor, queries: torch.Tensor = None) -> dict:
+         """
+         Forward pass with windowed inference.
+
+         Args:
+             video: Input video [B, S, 3, H, W] or [S, 3, H, W] in range [0, 255].
+             queries: Optional query points (unused, for API compatibility).
+
+         Returns:
+             dict with:
+                 - track: Dense tracks [B, T, H, W, 2].
+                 - vis: Visibility scores [B, T, H, W].
+                 - conf: Confidence scores [B, T, H, W].
+         """
+         # Normalize input to [0, 1] and ensure a batch dimension
+         images = video / 255.0
+         if images.ndim == 4:
+             images = images.unsqueeze(0)
+
+         B, T, C, H, W = images.shape
+         device, dtype = images.device, images.dtype
+
+         # Initialize accumulated outputs
+         accumulated = {
+             "track": torch.zeros((B, T, H, W, 2), device=device, dtype=dtype),
+             "vis": torch.zeros((B, T, H, W), device=device, dtype=dtype),
+             "conf": torch.zeros((B, T, H, W), device=device, dtype=dtype),
+         }
+
+         windows = self.windowed.compute_windows(T)
+         first_frame = images[:, 0:1]
+         first_frame_features = None
+
+         for window_idx, (start, end) in enumerate(windows):
+             if not self.training:
+                 print(f"Processing window {window_idx + 1}/{len(windows)}: frames [{start}, {end})")
+
+             # Get memory frame indices
+             memory_indices = self.windowed.select_memory_frames(window_idx, start)
+             if not self.training and memory_indices:
+                 print(f"  Memory frames: {memory_indices}")
+
+             # Gather frames: first_frame + memory + window
+             frames = self._gather_frames(images, first_frame, start, end, memory_indices)
+
+             # Extract backbone tokens
+             tokens, patch_idx = self.model.aggregator(frames)
+
+             # Extract combined features
+             features = self.model.feature_extractor(tokens, frames, patch_idx)
+
+             # Split features: first_frame | memory | window
+             first_frame_features = features[:, 0:1]
+             num_memory = len(memory_indices)
+
+             # Run tracking on extended features (memory + window), using first_frame as reference
+             extended_features = features[:, 1:]  # Exclude first_frame from input
+             pred = self.model.tracking_head(
+                 extended_features,
+                 image_size=(H, W),
+                 first_frame_features=first_frame_features,
+             )
+
+             # Extract window predictions (remove memory frames from output)
+             window_pred = {
+                 "track": pred["track"][:, num_memory:],
+                 "vis": pred["vis"][:, num_memory:],
+                 "conf": pred["conf"][:, num_memory:],
+             }
+
+             # Merge into accumulated results
+             self.windowed.merge_predictions(window_idx, start, end, window_pred, accumulated)
+
+             # Cleanup for memory efficiency
+             if not self.training:
+                 del features, tokens, pred
+                 torch.cuda.empty_cache()
+
+         if not self.training:
+             accumulated["images"] = images
+
+         return accumulated
+
+     def _gather_frames(
+         self,
+         images: torch.Tensor,
+         first_frame: torch.Tensor,
+         start: int,
+         end: int,
+         memory_indices: list,
+     ) -> torch.Tensor:
+         """Gather first_frame + memory + window frames."""
+         parts = [first_frame]
+
+         if memory_indices:
+             parts.append(images[:, memory_indices])
+
+         parts.append(images[:, start:end])
+
+         return torch.cat(parts, dim=1)
+
+     # Proxy properties for convenient access to internal components
+     @property
+     def aggregator(self):
+         return self.model.aggregator
+
+     @property
+     def feature_extractor(self):
+         return self.model.feature_extractor
+
+     @property
+     def tracking_head(self):
+         return self.model.tracking_head
+
+     @classmethod
+     def from_checkpoint(
+         cls,
+         checkpoint_path: str = None,
+         window_len: int = 100,
+         stride: int = 100,
+         device: str = "cuda",
+         dtype=torch.bfloat16,
+     ):
+         """
+         Load the model from a checkpoint (local path or HuggingFace Hub).
+
+         Args:
+             checkpoint_path: Path to a local checkpoint file.
+                 If None, downloads from the default HuggingFace repo.
+             window_len: Number of frames per window.
+             stride: Step size between windows.
+             device: Target device.
+             dtype: Target dtype.
+
+         Returns:
+             Loaded model in eval mode.
+         """
+         model = cls(window_len=window_len, stride=stride)
+
+         # Use CoWTracker's checkpoint loading (handles local paths and HuggingFace downloads)
+         ckpt = CoWTracker._load_checkpoint(checkpoint_path)
+         state_dict = ckpt.get("model", ckpt)
+
+         # Remap legacy checkpoint keys if needed (delegate to CoWTracker)
+         legacy_prefixes = [
+             "tracking_head.feature_extractor.",
+             "tracking_head.aggregator.",
+             "tracking_head.fnet.",
+         ]
+         if any(k.startswith(p) for k in state_dict.keys() for p in legacy_prefixes):
+             print("Detected legacy checkpoint format, remapping keys...")
+             state_dict = CoWTracker._remap_legacy_state_dict(state_dict)
+
+         # Add a "model." prefix if the checkpoint came from CoWTracker (no prefix):
+         # CoWTrackerWindowed wraps CoWTracker as self.model, so keys need the prefix
+         if not any(k.startswith("model.") for k in state_dict.keys()):
+             print("Adding 'model.' prefix to state dict keys...")
+             state_dict = {f"model.{k}": v for k, v in state_dict.items()}
+
+         msg = model.load_state_dict(state_dict, strict=False)
+         print(f"Load message: {msg}")
+
+         model = model.to(device).to(dtype)
+         model.eval()
+         for p in model.parameters():
+             p.requires_grad = False
+
+         print("Model loaded successfully!")
+         return model
+
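`WindowedInference.compute_windows` is not part of this diff, but the loop above shows what it must return: half-open `(start, end)` frame ranges covering the whole video. A plausible sketch of that contract, under the assumption that windows advance by `stride` and the last window is clipped to the video length:

```python
def compute_windows(num_frames: int, window_len: int = 100, stride: int = 100):
    """Plausible sketch (the real WindowedInference is defined elsewhere):
    half-open [start, end) windows that together cover all frames."""
    windows = []
    start = 0
    while start < num_frames:
        end = min(start + window_len, num_frames)
        windows.append((start, end))
        if end == num_frames:
            break
        start += stride
    return windows

print(compute_windows(250, window_len=100, stride=100))
# [(0, 100), (100, 200), (200, 250)]
```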
cowtracker/thirdparty/DepthAnythingV2 ADDED
@@ -0,0 +1 @@
+ Subproject commit e5a2732d3ea2cddc081d7bfd708fc0bf09f812f1
cowtracker/thirdparty/__init__.py ADDED
@@ -0,0 +1,19 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """Third-party modules.
+
+ This module sets up sys.path for third-party packages.
+ """
+
+ import os
+ import sys
+
+ # Add vggt to sys.path so that 'from vggt.xxx' imports work
+ _vggt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "vggt")
+ if _vggt_path not in sys.path:
+     sys.path.insert(0, _vggt_path)
+
cowtracker/thirdparty/vggt ADDED
@@ -0,0 +1 @@
+ Subproject commit 44b3afbd1869d8bde4894dd8ea1e293112dd5eba
cowtracker/utils/__init__.py ADDED
@@ -0,0 +1,32 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from cowtracker.utils.padding import (
+     compute_padding_params,
+     apply_padding,
+     remove_padding_and_scale_back,
+ )
+ from cowtracker.utils.visualization import paint_point_track
+ from cowtracker.utils.ops import (
+     bilinear_sampler,
+     coords_grid,
+     Padder,
+     load_ckpt,
+     upflow8,
+ )
+
+ __all__ = [
+     "compute_padding_params",
+     "apply_padding",
+     "remove_padding_and_scale_back",
+     "paint_point_track",
+     "bilinear_sampler",
+     "coords_grid",
+     "Padder",
+     "load_ckpt",
+     "upflow8",
+ ]
+
cowtracker/utils/ops.py ADDED
@@ -0,0 +1,151 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """Common operations for tracking: bilinear sampling, coordinate grids, etc."""
+
+ import cv2
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from scipy import interpolate
+
+
+ def load_ckpt(model, path):
+     """Load a checkpoint onto CPU and apply it non-strictly."""
+     state_dict = torch.load(path, map_location=torch.device("cpu"))
+     model.load_state_dict(state_dict, strict=False)
+
+
+ def resize_data(img1, img2, flow, factor=1.0):
+     """Resize an image pair and its flow by `factor`; flow values are rescaled accordingly."""
+     _, _, h, w = img1.shape
+     h = int(h * factor)
+     w = int(w * factor)
+     img1 = F.interpolate(img1, (h, w), mode="area")
+     img2 = F.interpolate(img2, (h, w), mode="area")
+     flow = F.interpolate(flow, (h, w), mode="area") * factor
+     return img1, img2, flow
+
+
+ class Padder:
+     """Pads images such that their dimensions are divisible by `factor`."""
+
+     def __init__(self, dims, mode="sintel", factor=32):
+         self.ht, self.wd = dims[-2:]
+         pad_ht = (((self.ht + 8) // factor) + 1) * factor - self.ht
+         pad_wd = (((self.wd + 8) // factor) + 1) * factor - self.wd
+         if mode == "sintel":
+             # Center the padding on all four sides
+             self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, pad_ht // 2, pad_ht - pad_ht // 2]
+         else:
+             # Pad only left/right symmetrically and the bottom edge
+             self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht]
+
+     def pad(self, x):
+         return F.pad(x, self._pad, mode="constant", value=0)
+
+     def unpad(self, x):
+         ht, wd = x.shape[-2:]
+         c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]]
+         return x[..., c[0] : c[1], c[2] : c[3]]
+
+
+ def forward_interpolate(flow):
+     """Forward-warp a flow field and fill holes by nearest-neighbor interpolation."""
+     flow = flow.detach().cpu().numpy()
+     dx, dy = flow[0], flow[1]
+
+     ht, wd = dx.shape
+     x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))
+
+     x1 = x0 + dx
+     y1 = y0 + dy
+
+     x1 = x1.reshape(-1)
+     y1 = y1.reshape(-1)
+     dx = dx.reshape(-1)
+     dy = dy.reshape(-1)
+
+     valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
+     x1 = x1[valid]
+     y1 = y1[valid]
+     dx = dx[valid]
+     dy = dy[valid]
+
+     flow_x = interpolate.griddata((x1, y1), dx, (x0, y0), method="nearest", fill_value=0)
+     flow_y = interpolate.griddata((x1, y1), dy, (x0, y0), method="nearest", fill_value=0)
+
+     flow = np.stack([flow_x, flow_y], axis=0)
+     return torch.from_numpy(flow).float()
+
+
+ def bilinear_sampler(img, coords, mode="bilinear", mask=False):
+     """Wrapper for grid_sample that takes pixel coordinates."""
+     H, W = img.shape[-2:]
+     xgrid, ygrid = coords.split([1, 1], dim=-1)
+     # Normalize pixel coordinates to the [-1, 1] range expected by grid_sample
+     xgrid = 2 * xgrid / (W - 1) - 1
+     ygrid = 2 * ygrid / (H - 1) - 1
+
+     grid = torch.cat([xgrid, ygrid], dim=-1)
+     img = F.grid_sample(img, grid, mode=mode, align_corners=True)
+
+     if mask:
+         mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
+         return img, mask.float()
+
+     return img
+
+
+ def coords_grid(batch, ht, wd, device):
+     """Return a (batch, 2, ht, wd) grid of (x, y) pixel coordinates."""
+     coords = torch.meshgrid(
+         torch.arange(ht, device=device), torch.arange(wd, device=device), indexing="ij"
+     )
+     coords = torch.stack(coords[::-1], dim=0).float()
+     return coords[None].repeat(batch, 1, 1, 1)
+
+
+ def upflow8(flow, mode="bilinear"):
+     """Upsample flow by 8x spatially and scale the flow values to match."""
+     new_size = (8 * flow.shape[2], 8 * flow.shape[3])
+     return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
+
+
+ def transform(T, p):
+     """Apply a 4x4 rigid transform T to an (H, W, 3) point map."""
+     assert T.shape == (4, 4)
+     return np.einsum("H W j, i j -> H W i", p, T[:3, :3]) + T[:3, 3]
+
+
+ def from_homog(x):
+     """Convert homogeneous coordinates to Euclidean ones."""
+     return x[..., :-1] / x[..., [-1]]
+
+
+ def reproject(depth1, pose1, pose2, K1, K2):
+     """Reproject the pixels of view 1 into view 2 using depth, poses, and intrinsics."""
+     H, W = depth1.shape
+     x, y = np.meshgrid(np.arange(W), np.arange(H), indexing="xy")
+     img_1_coords = np.stack((x, y, np.ones_like(x)), axis=-1).astype(np.float64)
+     cam1_coords = np.einsum("H W, H W j, i j -> H W i", depth1, img_1_coords, np.linalg.inv(K1))
+     rel_pose = pose2 @ np.linalg.inv(pose1)
+     cam2_coords = transform(rel_pose, cam1_coords)
+     return from_homog(np.einsum("H W j, i j -> H W i", cam2_coords, K2))
+
+
+ def induced_flow(depth0, depth1, data):
+     """Compute the forward and backward flows induced by depth and camera motion."""
+     H, W = depth0.shape
+     coords1 = reproject(depth0, data["T0"], data["T1"], data["K0"], data["K1"])
+     x, y = np.meshgrid(np.arange(W), np.arange(H), indexing="xy")
+     coords0 = np.stack([x, y], axis=-1)
+     flow_01 = coords1 - coords0
+
+     H, W = depth1.shape
+     coords1 = reproject(depth1, data["T1"], data["T0"], data["K1"], data["K0"])
+     x, y = np.meshgrid(np.arange(W), np.arange(H), indexing="xy")
+     coords0 = np.stack([x, y], axis=-1)
+     flow_10 = coords1 - coords0
+     return flow_01, flow_10
+
+
+ def check_cycle_consistency(flow_01, flow_10):
+     """Mask pixels whose forward-backward flow cycle error is below 10% of the image size."""
+     H, W = flow_01.shape[:2]
+     new_coords = flow_01 + np.stack(np.meshgrid(np.arange(W), np.arange(H), indexing="xy"), axis=-1)
+     flow_reprojected = cv2.remap(flow_10, new_coords.astype(np.float32), None, interpolation=cv2.INTER_LINEAR)
+     cycle = flow_reprojected + flow_01
+     cycle = np.linalg.norm(cycle, axis=-1)
+     mask = (cycle < 0.1 * min(H, W)).astype(np.float32)
+     return mask
+
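A quick sanity check of the pixel-to-[-1, 1] normalization used by `bilinear_sampler`: sampling an image at its own coordinate grid (the same layout `coords_grid` produces) should reproduce the image exactly, because with `align_corners=True` pixel 0 maps to -1 and pixel W-1 maps to +1:

```python
import torch
import torch.nn.functional as F

B, C, H, W = 1, 3, 5, 7
img = torch.arange(B * C * H * W, dtype=torch.float32).view(B, C, H, W)

# Pixel-space (x, y) grid, shape (1, H, W, 2)
ys, xs = torch.meshgrid(torch.arange(H), torch.arange(W), indexing="ij")
coords = torch.stack([xs, ys], dim=-1).float()[None]

# Same normalization as bilinear_sampler above
xg = 2 * coords[..., 0:1] / (W - 1) - 1
yg = 2 * coords[..., 1:2] / (H - 1) - 1
grid = torch.cat([xg, yg], dim=-1)

sampled = F.grid_sample(img, grid, mode="bilinear", align_corners=True)
print(torch.allclose(sampled, img, atol=1e-5))  # True
```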
cowtracker/utils/padding.py ADDED
@@ -0,0 +1,168 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """Padding utilities for video preprocessing and postprocessing."""
+
+ import torch
+ import torch.nn.functional as F
+
+
+ def compute_padding_params(orig_H, orig_W, inf_H, inf_W, skip_upscaling=False):
+     """Compute padding parameters that preserve the aspect ratio.
+
+     Args:
+         orig_H: Original height
+         orig_W: Original width
+         inf_H: Inference height
+         inf_W: Inference width
+         skip_upscaling: If True and scale > 1, skip upscaling and just pad
+
+     Returns:
+         Dictionary containing:
+             - scale: Scale factor that would be applied (1.0 if skipped)
+             - scaled_H, scaled_W: Dimensions after scaling (before padding)
+             - pad_top, pad_bottom, pad_left, pad_right: Padding amounts
+             - orig_H, orig_W: Original dimensions (for reference)
+             - upscaling_skipped: Whether upscaling was skipped
+     """
+     scale = min(inf_H / orig_H, inf_W / orig_W)
+
+     upscaling_skipped = False
+     if skip_upscaling and scale > 1.0:
+         scaled_H = orig_H
+         scaled_W = orig_W
+         upscaling_skipped = True
+     else:
+         scaled_H = int(orig_H * scale)
+         scaled_W = int(orig_W * scale)
+
+     pad_H = inf_H - scaled_H
+     pad_W = inf_W - scaled_W
+
+     pad_top = pad_H // 2
+     pad_bottom = pad_H - pad_top
+     pad_left = pad_W // 2
+     pad_right = pad_W - pad_left
+
+     return {
+         "scale": scale,
+         "scaled_H": scaled_H,
+         "scaled_W": scaled_W,
+         "pad_top": pad_top,
+         "pad_bottom": pad_bottom,
+         "pad_left": pad_left,
+         "pad_right": pad_right,
+         "orig_H": orig_H,
+         "orig_W": orig_W,
+         "upscaling_skipped": upscaling_skipped,
+     }
+
+
+ def apply_padding(rgbs, padding_info):
+     """Apply scaling and padding to input images to reach the inference size.
+
+     Args:
+         rgbs: Input tensor (T, C, H, W)
+         padding_info: Dictionary from compute_padding_params
+
+     Returns:
+         Padded tensor (T, C, inf_H, inf_W)
+     """
+     T, C, H, W = rgbs.shape
+     scaled_H = padding_info["scaled_H"]
+     scaled_W = padding_info["scaled_W"]
+
+     if (scaled_H, scaled_W) != (H, W):
+         rgbs_scaled = F.interpolate(
+             rgbs,
+             size=(scaled_H, scaled_W),
+             mode="bilinear",
+             align_corners=False,
+         )
+     else:
+         rgbs_scaled = rgbs
+
+     pad_left = padding_info["pad_left"]
+     pad_right = padding_info["pad_right"]
+     pad_top = padding_info["pad_top"]
+     pad_bottom = padding_info["pad_bottom"]
+
+     rgbs_padded = F.pad(
+         rgbs_scaled,
+         (pad_left, pad_right, pad_top, pad_bottom),
+         mode="constant",
+         value=0,
+     )
+
+     return rgbs_padded
+
+
+ def remove_padding_and_scale_back(tracks, visibility, confidence, padding_info):
+     """Remove padding from model outputs and scale back to the original resolution.
+
+     Args:
+         tracks: Track predictions (T, inf_H, inf_W, 2)
+         visibility: Visibility predictions (T, inf_H, inf_W)
+         confidence: Confidence predictions (T, inf_H, inf_W)
+         padding_info: Dictionary from compute_padding_params
+
+     Returns:
+         Tuple of (tracks, visibility, confidence) scaled to the original resolution
+     """
+     scaled_H = padding_info["scaled_H"]
+     scaled_W = padding_info["scaled_W"]
+     pad_top = padding_info["pad_top"]
+     pad_left = padding_info["pad_left"]
+     orig_H = padding_info["orig_H"]
+     orig_W = padding_info["orig_W"]
+
+     tracks_unpadded = tracks[
+         :, pad_top : pad_top + scaled_H, pad_left : pad_left + scaled_W, :
+     ]
+     visibility_unpadded = visibility[
+         :, pad_top : pad_top + scaled_H, pad_left : pad_left + scaled_W
+     ]
+     confidence_unpadded = confidence[
+         :, pad_top : pad_top + scaled_H, pad_left : pad_left + scaled_W
+     ]
+
+     # Shift track coordinates back into the unpadded frame
+     tracks_unpadded = tracks_unpadded.clone()
+     tracks_unpadded[:, :, :, 0] -= pad_left
+     tracks_unpadded[:, :, :, 1] -= pad_top
+
+     if (scaled_H, scaled_W) != (orig_H, orig_W):
+         tracks_permuted = tracks_unpadded.permute(0, 3, 1, 2)
+         tracks_scaled = F.interpolate(
+             tracks_permuted,
+             size=(orig_H, orig_W),
+             mode="bilinear",
+             align_corners=False,
+         )
+         tracks_final = tracks_scaled.permute(0, 2, 3, 1)
+
+         # Rescale the coordinate values themselves
+         tracks_final[:, :, :, 0] *= orig_W / scaled_W
+         tracks_final[:, :, :, 1] *= orig_H / scaled_H
+
+         visibility_final = F.interpolate(
+             visibility_unpadded.unsqueeze(1),
+             size=(orig_H, orig_W),
+             mode="bilinear",
+             align_corners=False,
+         ).squeeze(1)
+
+         confidence_final = F.interpolate(
+             confidence_unpadded.unsqueeze(1),
+             size=(orig_H, orig_W),
+             mode="bilinear",
+             align_corners=False,
+         ).squeeze(1)
+     else:
+         tracks_final = tracks_unpadded
+         visibility_final = visibility_unpadded
+         confidence_final = confidence_unpadded
+
+     return tracks_final, visibility_final, confidence_final
+
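A worked example of the arithmetic in `compute_padding_params`, re-implemented inline so the numbers can be checked without importing the package. The dimensions are chosen so the scale factor (0.25) is exact in floating point: a 1036x2072 video fit into a 518x518 inference canvas.

```python
orig_H, orig_W, inf_H, inf_W = 1036, 2072, 518, 518

scale = min(inf_H / orig_H, inf_W / orig_W)   # 0.25: width is the limiting side
scaled_H, scaled_W = int(orig_H * scale), int(orig_W * scale)
pad_H, pad_W = inf_H - scaled_H, inf_W - scaled_W

# Padding is split as evenly as possible, with the extra pixel at bottom/right
pad_top, pad_bottom = pad_H // 2, pad_H - pad_H // 2
pad_left, pad_right = pad_W // 2, pad_W - pad_W // 2

print(scale, (scaled_H, scaled_W), (pad_top, pad_bottom, pad_left, pad_right))
# 0.25 (259, 518) (129, 130, 0, 0)
```

Note that `scaled + pads` reconstructs the inference canvas on both axes, which is what lets `remove_padding_and_scale_back` invert the operation exactly.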
cowtracker/utils/visualization.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """Visualization utilities for point tracking."""
+
+ import colorsys
+ import os
+ import random
+ from typing import List, Optional, Tuple, Union
+
+ import matplotlib
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import torch
+
+
+ # Bremm 2D colormap for position-based coloring
+ # This creates a smooth 2D color gradient based on x,y position
+ BREMM_COLORMAP = None  # Lazy loaded
+
+
+ def _create_bremm_colormap():
+     """Create a 2D colormap programmatically (Bremm-style).
+
+     This creates a smooth 2D color gradient where:
+     - X position maps to hue variation
+     - Y position maps to saturation/value variation
+     """
+     size = 256
+     colormap = np.zeros((size, size, 3), dtype=np.uint8)
+
+     for y in range(size):
+         for x in range(size):
+             # Normalize to [0, 1]
+             nx = x / (size - 1)
+             ny = y / (size - 1)
+
+             # Create a 2D color mapping using HSV
+             # Hue varies with x, saturation/value with y
+             hue = (nx * 0.8 + ny * 0.2) % 1.0  # Mix of x and y for hue
+             saturation = 0.6 + 0.4 * (1 - ny)  # Higher saturation at top
+             value = 0.7 + 0.3 * nx  # Higher value on right
+
+             # Convert HSV to RGB
+             rgb = colorsys.hsv_to_rgb(hue, saturation, value)
+             colormap[y, x] = [int(c * 255) for c in rgb]
+
+     return colormap
+
+
+ def _get_bremm_colormap():
+     """Get or create the bremm colormap."""
+     global BREMM_COLORMAP
+     if BREMM_COLORMAP is None:
+         # Try to load from file first
+         colormap_file = os.path.join(os.path.dirname(__file__), "bremm.png")
+         if os.path.exists(colormap_file):
+             BREMM_COLORMAP = (plt.imread(colormap_file) * 255).astype(np.uint8)
+             if BREMM_COLORMAP.shape[2] == 4:  # RGBA
+                 BREMM_COLORMAP = BREMM_COLORMAP[:, :, :3]
+         else:
+             BREMM_COLORMAP = _create_bremm_colormap()
+     return BREMM_COLORMAP
+
+
+ def get_2d_colors(xys: np.ndarray, H: int, W: int) -> np.ndarray:
+     """Get colors based on 2D position using Bremm colormap.
+
+     This creates position-dependent colors where nearby points have
+     similar colors, useful for visualizing spatial coherence.
+
+     Args:
+         xys: Point coordinates [N, 2] in pixel space (x, y)
+         H: Image height
+         W: Image width
+
+     Returns:
+         Array of RGB colors [N, 3] as uint8
+     """
+     colormap = _get_bremm_colormap()
+     height, width = colormap.shape[:2]
+
+     N = xys.shape[0]
+     output = np.zeros((N, 3), dtype=np.uint8)
+
+     # Normalize coordinates to [0, 1]
+     xys_norm = xys.copy().astype(np.float32)
+     xys_norm[:, 0] = xys_norm[:, 0] / max(W - 1, 1)
+     xys_norm[:, 1] = xys_norm[:, 1] / max(H - 1, 1)
+
+     # Clip to valid range
+     xys_norm = np.clip(xys_norm, 0, 1)
+
+     # Map to colormap coordinates
+     for i in range(N):
+         x, y = xys_norm[i]
+         xp = int((width - 1) * x)
+         yp = int((height - 1) * y)
+         output[i] = colormap[yp, xp]
+
+     return output
+
+
+ def get_colors_from_cmap(num_colors: int, cmap: str = "gist_rainbow") -> np.ndarray:
+     """Gets colormap for points using matplotlib colormap.
+
+     Args:
+         num_colors: Number of colors to generate
+         cmap: Matplotlib colormap name (e.g., "gist_rainbow", "jet", "turbo")
+
+     Returns:
+         Array of RGB colors [num_colors, 3] as uint8
+     """
+     cmap_ = matplotlib.colormaps.get_cmap(cmap)
+     colors = []
+     for i in range(num_colors):
+         c = cmap_(i / float(num_colors))
+         colors.append((int(c[0] * 255), int(c[1] * 255), int(c[2] * 255)))
+     return np.array(colors)
+
+
+ def paint_point_track(
+     frames: np.ndarray,
+     point_tracks: np.ndarray,
+     visibles: np.ndarray,
+     colormap: Optional[Union[List[Tuple[int, int, int]], np.ndarray]] = None,
+     rate: int = 1,
+     show_bkg: bool = True,
+ ) -> np.ndarray:
+     """Paint point tracks on video frames using GPU-accelerated scatter.
+
+     Args:
+         frames: Video frames [T, H, W, C] in uint8
+         point_tracks: Track coordinates [P, T, 2] (x, y)
+         visibles: Visibility mask [P, T]
+         colormap: Optional list/array of RGB colors for each point
+         rate: Subsampling rate for visualization (affects point size)
+         show_bkg: Whether to show background (True) or black out (False)
+
+     Returns:
+         Painted frames [T, H, W, C] in uint8
+     """
+     print("Starting visualization...")
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     frames_t = (
+         torch.from_numpy(frames).float().permute(0, 3, 1, 2).to(device)
+     )  # [T,C,H,W]
+
+     if show_bkg:
+         frames_t = frames_t * 0.5  # darken to see tracks better
+     else:
+         frames_t = frames_t * 0.0  # black out background
+
+     point_tracks_t = torch.from_numpy(point_tracks).to(device)  # [P,T,2]
+     visibles_t = torch.from_numpy(visibles).to(device)  # [P,T]
+     T, C, H, W = frames_t.shape
+     P = point_tracks.shape[0]
+
+     # Use gist_rainbow colormap (matching app3.py behavior)
+     if colormap is None:
+         colormap = get_colors_from_cmap(P, "gist_rainbow")
+     colors = torch.tensor(colormap, dtype=torch.float32, device=device)  # [P,3]
+
+     # Adjust radius based on rate
+     if rate == 1:
+         radius = 1
+     elif rate == 2:
+         radius = 1
+     elif rate == 4:
+         radius = 2
+     elif rate == 8:
+         radius = 4
+     else:
+         radius = 6
+
+     sharpness = 0.15 + 0.05 * np.log2(rate)
+
+     D = radius * 2 + 1
+     y = torch.arange(D, device=device).float()[:, None] - radius
+     x = torch.arange(D, device=device).float()[None, :] - radius
+     dist2 = x**2 + y**2
+     icon = torch.clamp(1 - (dist2 - (radius**2) / 2.0) / (radius * 2 * sharpness), 0, 1)
+     icon = icon.view(1, D, D)
+     dx = torch.arange(-radius, radius + 1, device=device)
+     dy = torch.arange(-radius, radius + 1, device=device)
+     disp_y, disp_x = torch.meshgrid(dy, dx, indexing="ij")
+
+     for t in range(T):
+         mask = visibles_t[:, t]
+         if mask.sum() == 0:
+             continue
+         xy = point_tracks_t[mask, t] + 0.5
+         xy[:, 0] = xy[:, 0].clamp(0, W - 1)
+         xy[:, 1] = xy[:, 1].clamp(0, H - 1)
+         colors_now = colors[mask]
+         N = xy.shape[0]
+         cx = xy[:, 0].long()
+         cy = xy[:, 1].long()
+         x_grid = cx[:, None, None] + disp_x
+         y_grid = cy[:, None, None] + disp_y
+         valid = (x_grid >= 0) & (x_grid < W) & (y_grid >= 0) & (y_grid < H)
+         x_valid = x_grid[valid]
+         y_valid = y_grid[valid]
+         icon_weights = icon.expand(N, D, D)[valid]
+         colors_valid = (
+             colors_now[:, :, None, None]
+             .expand(N, 3, D, D)
+             .permute(1, 0, 2, 3)[:, valid]
+         )
+         idx_flat = (y_valid * W + x_valid).long()
+
+         accum = torch.zeros_like(frames_t[t])
+         weight = torch.zeros(1, H * W, device=device)
+         img_flat = accum.view(C, -1)
+         weighted_colors = colors_valid * icon_weights
+         img_flat.scatter_add_(1, idx_flat.unsqueeze(0).expand(C, -1), weighted_colors)
+         weight.scatter_add_(1, idx_flat.unsqueeze(0), icon_weights.unsqueeze(0))
+         weight = weight.view(1, H, W)
+
+         alpha = weight.clamp(0, 1)
+         accum = accum / (weight + 1e-6)
+         frames_t[t] = frames_t[t] * (1 - alpha) + accum * alpha
+
+     print("Visualization done.")
+     return frames_t.clamp(0, 255).byte().permute(0, 2, 3, 1).cpu().numpy()
+
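The 2D colormap in `_create_bremm_colormap` reduces to one small per-position HSV formula. As a minimal, stdlib-only sketch of that mapping (the function name `bremm_color` is illustrative, not part of the module):

```python
import colorsys

def bremm_color(nx: float, ny: float) -> tuple:
    """Map a normalized position (nx, ny) in [0, 1]^2 to an RGB triple,
    mirroring the HSV mixing used by _create_bremm_colormap."""
    hue = (nx * 0.8 + ny * 0.2) % 1.0   # hue mixes x and y
    saturation = 0.6 + 0.4 * (1 - ny)   # more saturated near the top
    value = 0.7 + 0.3 * nx              # brighter on the right
    r, g, b = colorsys.hsv_to_rgb(hue, saturation, value)
    return (int(r * 255), int(g * 255), int(b * 255))

# Nearby positions get similar colors; distant corners get distinct ones.
print(bremm_color(0.0, 0.0))  # top-left: fully saturated, dimmer red
print(bremm_color(1.0, 1.0))  # bottom-right: desaturated, bright
```

Because the mapping is continuous in (nx, ny), spatially adjacent query points are painted in similar hues, which is what makes the visualization useful for judging spatial coherence of tracks.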
demo.py ADDED
@@ -0,0 +1,179 @@
+ #!/usr/bin/env python3
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """
+ Minimal CoWTracker inference demo.
+
+ Usage:
+     python demo.py --video input.mp4 --output output.mp4
+     python demo.py --video input.mp4 --output output.mp4 --checkpoint ~/run168/cow_tracker_model.pth
+ """
+
+ import argparse
+ import os
+
+ import mediapy
+ import numpy as np
+ import torch
+
+ from cowtracker import CoWTracker
+ from cowtracker.utils.visualization import paint_point_track
+
+ inf_dtype = torch.float16
+ def preprocess_video(video_path, max_frames=200, target_size=(336, 560)):
+     """Load and preprocess video.
+
+     Args:
+         video_path: Path to input video
+         max_frames: Maximum number of frames to process
+         target_size: Target size (H, W) for inference
+
+     Returns:
+         Tuple of (video_array, fps)
+     """
+     video_arr = mediapy.read_video(video_path)
+     video_fps = video_arr.metadata.fps
+     num_frames = video_arr.shape[0]
+
+     # Truncate if too long
+     if num_frames > max_frames:
+         print(f"Video is too long. Truncating to first {max_frames} frames.")
+         video_arr = video_arr[:max_frames]
+
+     # Resize to target size
+     video_arr = mediapy.resize_video(video_arr, target_size)
+
+     return np.array(video_arr), video_fps
+
+
+ def run_inference(model, video):
+     """Run tracking inference on video.
+
+     Args:
+         model: CoWTracker model
+         video: Video array [T, H, W, C] in uint8
+
+     Returns:
+         Tuple of (tracks, visibilities, confidences)
+         - tracks: [T, H, W, 2]
+         - visibilities: [T, H, W]
+         - confidences: [T, H, W]
+     """
+     device = next(model.parameters()).device
+
+     # Convert to tensor [T, C, H, W]
+     video_tensor = torch.from_numpy(video).permute(0, 3, 1, 2).float().to(device)
+     T, C, H, W = video_tensor.shape
+     print(f"Video size: {H}x{W}")
+
+     torch.cuda.empty_cache()
+
+     with torch.no_grad():
+         with torch.amp.autocast(device_type="cuda", dtype=inf_dtype):
+             predictions = model.forward(video=video_tensor, queries=None)
+
+     tracks = predictions["track"][0].cpu()
+     visibility = predictions["vis"][0].cpu()
+     confidence = predictions["conf"][0].cpu()
+
+     visconf = visibility * confidence
+     return tracks, visconf > 0.1, visconf
+
+
+ def create_visualization(video, tracks, visibilities, rate=8, fps=30, show_bkg=True):
+     """Create visualization video.
+
+     Args:
+         video: Video array [T, H, W, C]
+         tracks: Tracks [T, H, W, 2]
+         visibilities: Visibility mask [T, H, W]
+         rate: Subsampling rate for points
+         fps: Output video fps
+         show_bkg: Whether to show background
+
+     Returns:
+         Painted video frames [T, H, W, C]
+     """
+     T, H, W, _ = video.shape
+
+     # Subsample tracks for visualization
+     tracks_np = tracks.permute(1, 2, 0, 3).reshape(-1, T, 2).numpy()  # [HW, T, 2]
+     vis_np = visibilities.permute(1, 2, 0).reshape(-1, T).numpy()  # [HW, T]
+
+     # Subsample
+     tracks_sub = tracks_np.reshape(H, W, T, 2)[::rate, ::rate].reshape(-1, T, 2)
+     vis_sub = vis_np.reshape(H, W, T)[::rate, ::rate].reshape(-1, T)
+
+     # Paint tracks
+     painted_video = paint_point_track(
+         video, tracks_sub, vis_sub, rate=rate, show_bkg=show_bkg
+     )
+
+     return painted_video
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="CoWTracker Inference Demo")
+     parser.add_argument("--video", type=str, required=True, help="Path to input video")
+     parser.add_argument("--output", type=str, default=None, help="Path to output video")
+     parser.add_argument(
+         "--checkpoint",
+         type=str,
+         default=None,
+         help="Path to model checkpoint",
+     )
+     parser.add_argument(
+         "--rate", type=int, default=8, help="Subsampling rate for visualization"
+     )
+     parser.add_argument(
+         "--max_frames", type=int, default=200, help="Maximum number of frames"
+     )
+     parser.add_argument("--no_bkg", action="store_true", help="Hide background in visualization")
+     args = parser.parse_args()
+
+     # Set output path
+     if args.output is None:
+         base_name = os.path.splitext(os.path.basename(args.video))[0]
+         args.output = f"{base_name}_tracked.mp4"
+
+     print("=" * 60)
+     print("CoWTracker Inference Demo")
+     print("=" * 60)
+
+     # Load model
+     print("\n[1/4] Loading model...")
+     model = CoWTracker.from_checkpoint(
+         args.checkpoint,
+         device="cuda" if torch.cuda.is_available() else "cpu",
+         dtype=inf_dtype if torch.cuda.is_available() else torch.float32,
+     )
+
+     # Load video
+     print("\n[2/4] Loading video...")
+     video, fps = preprocess_video(args.video, max_frames=args.max_frames)
+     print(f"Video shape: {video.shape}, FPS: {fps}")
+
+     # Run inference
+     print("\n[3/4] Running inference...")
+     tracks, visibilities, confidences = run_inference(model, video)
+     print(f"Tracks shape: {tracks.shape}")
+
+     # Create visualization
+     print("\n[4/4] Creating visualization...")
+     painted_video = create_visualization(
+         video, tracks, visibilities, rate=args.rate, fps=fps, show_bkg=not args.no_bkg
+     )
+
+     # Save output
+     mediapy.write_video(args.output, painted_video, fps=fps)
+     print(f"\nSaved output to: {args.output}")
+     print("=" * 60)
+
+
+ if __name__ == "__main__":
+     main()
+
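The grid subsampling in `create_visualization` (reshape to `[H, W, T, 2]`, stride both spatial axes by `rate`, flatten back to `[P, T, 2]`) can be checked shape-wise with a tiny NumPy sketch; the sizes below are made-up toy values, not the demo's defaults:

```python
import numpy as np

# Toy sizes: T frames over an H x W dense track grid (hypothetical values).
T, H, W, rate = 4, 8, 12, 4

tracks = np.zeros((H, W, T, 2), dtype=np.float32)  # dense per-pixel tracks
vis = np.ones((H, W, T), dtype=bool)               # all points visible

# Keep every `rate`-th grid point along H and W, then flatten to per-point arrays.
tracks_sub = tracks[::rate, ::rate].reshape(-1, T, 2)  # [P, T, 2]
vis_sub = vis[::rate, ::rate].reshape(-1, T)           # [P, T]

# Striding keeps ceil(H/rate) * ceil(W/rate) points.
P = -(-H // rate) * (-(-W // rate))
print(tracks_sub.shape, vis_sub.shape, P)
```

With `rate=8` on the demo's default 336x560 resolution this cuts the dense field to a sparse set of points, which keeps `paint_point_track` fast and the output legible.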
docs/logo.jpg ADDED

Git LFS Details

  • SHA256: ea7971892491cdc682c4592ce208d755e899269e28b57133cdd343020c1c1726
  • Pointer size: 130 Bytes
  • Size of remote file: 69.4 kB
docs/teaser.jpg ADDED

Git LFS Details

  • SHA256: 23100f29cbaf75d0e40a39ef480b5dbc3a6e61513d2ada83d810a3d4b4541f17
  • Pointer size: 131 Bytes
  • Size of remote file: 283 kB
environments.yml ADDED
@@ -0,0 +1,36 @@
+ name: cowtracker
+ channels:
+   - pytorch
+   - nvidia
+   - conda-forge
+   - defaults
+ dependencies:
+   - python=3.12
+   - pip
+   - pip:
+       # Core deep learning
+       - torch>=2.0.0
+       - torchvision>=0.15.0
+       - xformers
+       - timm
+
+       # Numerical / scientific
+       - numpy
+       - scipy
+       - einops
+
+       # Image / video processing
+       - opencv-python
+       - Pillow
+       - mediapy
+       - matplotlib
+
+       # Model hub
+       - huggingface_hub
+
+       # Gradio demo
+       - gradio
+
+       # Optional: development tools
+       - ipython
+       - ipdb
output.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:989c1e01c2f3eb2e55387ba74c78497d3aba3805c63dc0eeb7f581754465bf7c
+ size 2550238
packages.txt ADDED
@@ -0,0 +1,2 @@
+ ffmpeg
+
requirements.txt ADDED
@@ -0,0 +1,26 @@
+ # Core deep learning
+ torch>=2.0.0
+ torchvision>=0.15.0
+ xformers==0.0.33.post1
+ timm
+
+ # Numerical / scientific
+ numpy
+ scipy
+ einops
+
+ # Image / video processing
+ opencv-python-headless
+ Pillow
+ mediapy
+ matplotlib
+
+ # Model hub
+ huggingface_hub
+
+ # Gradio demo
+ gradio>=4.0.0
+
+ # HuggingFace Spaces (ZeroGPU support)
+ spaces
+
videos/apple.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c7f48c5cfb1479e1dbc1df2373d5cad4f55c198bbdb379da0ece10087971542a
+ size 1219872
videos/bear.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:eeffab1780be601b19b2097be81a8c2d4fa2b624ac1028be0a32191d25acca0f
+ size 893943
videos/bmx-bumps.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b4d4aa73e0342d8dc08c4a7e3c9ea10e46507d363f0363f3e38bfb3ececa1588
+ size 3094667
videos/cows.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b4ca9ba3b3f720142917dc20935b03c9bbdc629e55502955526daccec567170d
+ size 5282840
videos/lab-coat.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:cf43f05f1a011e6cf376d4dc17b18b249613fa5e082df9e3af941fa012c34c9e
+ size 1850114
videos/longboard.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:13c977306e09a3c0766952497b7e51d5accfb5b15bdcb5ac32a1c1afc7893f67
+ size 2879038
videos/motocross-jump.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f757c903177154b7eceb4b9fe5386bb38cba20ef4da8645cbfc450e0ac39ffef
+ size 1343986