Commit ·
012c9b1
1
Parent(s): e06ad17
Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +2 -0
- av_hubert/.gitmodules +3 -0
- av_hubert/CODE_OF_CONDUCT.md +80 -0
- av_hubert/CONTRIBUTING.md +31 -0
- av_hubert/LICENSE +159 -0
- av_hubert/README.md +164 -0
- av_hubert/assets/lipreading.gif +3 -0
- av_hubert/avhubert/__init__.py +10 -0
- av_hubert/avhubert/clustering/README.md +100 -0
- av_hubert/avhubert/clustering/dump_hubert_feature.py +177 -0
- av_hubert/avhubert/clustering/dump_km_label.py +99 -0
- av_hubert/avhubert/clustering/dump_mfcc_feature.py +117 -0
- av_hubert/avhubert/clustering/learn_kmeans.py +147 -0
- av_hubert/avhubert/clustering/requirements.txt +6 -0
- av_hubert/avhubert/clustering/submit_cluster.py +132 -0
- av_hubert/avhubert/conf/av-finetune/base_noise_pt_noise_ft_30h.yaml +121 -0
- av_hubert/avhubert/conf/av-finetune/base_noise_pt_noise_ft_433h.yaml +121 -0
- av_hubert/avhubert/conf/av-finetune/large_noise_pt_noise_ft_30h.yaml +124 -0
- av_hubert/avhubert/conf/av-finetune/large_noise_pt_noise_ft_433h.yaml +124 -0
- av_hubert/avhubert/conf/finetune/base_lrs3_30h.yaml +118 -0
- av_hubert/avhubert/conf/finetune/base_lrs3_433h.yaml +118 -0
- av_hubert/avhubert/conf/finetune/base_vox_30h.yaml +118 -0
- av_hubert/avhubert/conf/finetune/base_vox_433h.yaml +118 -0
- av_hubert/avhubert/conf/finetune/large_lrs3_30h.yaml +121 -0
- av_hubert/avhubert/conf/finetune/large_lrs3_433h.yaml +121 -0
- av_hubert/avhubert/conf/finetune/large_vox_30h.yaml +121 -0
- av_hubert/avhubert/conf/finetune/large_vox_433h.yaml +121 -0
- av_hubert/avhubert/conf/finetune/self_large_vox_30h.yaml +121 -0
- av_hubert/avhubert/conf/finetune/self_large_vox_433h.yaml +121 -0
- av_hubert/avhubert/conf/pretrain/base_lrs3_iter1.yaml +112 -0
- av_hubert/avhubert/conf/pretrain/base_lrs3_iter2.yaml +112 -0
- av_hubert/avhubert/conf/pretrain/base_lrs3_iter3.yaml +112 -0
- av_hubert/avhubert/conf/pretrain/base_lrs3_iter4.yaml +112 -0
- av_hubert/avhubert/conf/pretrain/base_lrs3_iter5.yaml +112 -0
- av_hubert/avhubert/conf/pretrain/base_vox_iter1.yaml +113 -0
- av_hubert/avhubert/conf/pretrain/base_vox_iter2.yaml +113 -0
- av_hubert/avhubert/conf/pretrain/base_vox_iter3.yaml +113 -0
- av_hubert/avhubert/conf/pretrain/base_vox_iter4.yaml +112 -0
- av_hubert/avhubert/conf/pretrain/base_vox_iter5.yaml +113 -0
- av_hubert/avhubert/conf/pretrain/large_lrs3_iter5.yaml +117 -0
- av_hubert/avhubert/conf/pretrain/large_vox_iter5.yaml +117 -0
- av_hubert/avhubert/conf/pretrain/noise_base_vox_iter5.yaml +115 -0
- av_hubert/avhubert/conf/pretrain/noise_large_vox_iter5.yaml +119 -0
- av_hubert/avhubert/conf/s2s_decode.yaml +23 -0
- av_hubert/avhubert/decoder.py +243 -0
- av_hubert/avhubert/hubert.py +779 -0
- av_hubert/avhubert/hubert_asr.py +521 -0
- av_hubert/avhubert/hubert_criterion.py +169 -0
- av_hubert/avhubert/hubert_dataset.py +529 -0
- av_hubert/avhubert/hubert_pretraining.py +400 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
av_hubert/assets/lipreading.gif filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
av_hubert/fairseq/docs/fairseq.gif filter=lfs diff=lfs merge=lfs -text
|
av_hubert/.gitmodules
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[submodule "fairseq"]
|
| 2 |
+
path = fairseq
|
| 3 |
+
url = https://github.com/pytorch/fairseq
|
av_hubert/CODE_OF_CONDUCT.md
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Code of Conduct
|
| 2 |
+
|
| 3 |
+
## Our Pledge
|
| 4 |
+
|
| 5 |
+
In the interest of fostering an open and welcoming environment, we as
|
| 6 |
+
contributors and maintainers pledge to make participation in our project and
|
| 7 |
+
our community a harassment-free experience for everyone, regardless of age, body
|
| 8 |
+
size, disability, ethnicity, sex characteristics, gender identity and expression,
|
| 9 |
+
level of experience, education, socio-economic status, nationality, personal
|
| 10 |
+
appearance, race, religion, or sexual identity and orientation.
|
| 11 |
+
|
| 12 |
+
## Our Standards
|
| 13 |
+
|
| 14 |
+
Examples of behavior that contributes to creating a positive environment
|
| 15 |
+
include:
|
| 16 |
+
|
| 17 |
+
* Using welcoming and inclusive language
|
| 18 |
+
* Being respectful of differing viewpoints and experiences
|
| 19 |
+
* Gracefully accepting constructive criticism
|
| 20 |
+
* Focusing on what is best for the community
|
| 21 |
+
* Showing empathy towards other community members
|
| 22 |
+
|
| 23 |
+
Examples of unacceptable behavior by participants include:
|
| 24 |
+
|
| 25 |
+
* The use of sexualized language or imagery and unwelcome sexual attention or
|
| 26 |
+
advances
|
| 27 |
+
* Trolling, insulting/derogatory comments, and personal or political attacks
|
| 28 |
+
* Public or private harassment
|
| 29 |
+
* Publishing others' private information, such as a physical or electronic
|
| 30 |
+
address, without explicit permission
|
| 31 |
+
* Other conduct which could reasonably be considered inappropriate in a
|
| 32 |
+
professional setting
|
| 33 |
+
|
| 34 |
+
## Our Responsibilities
|
| 35 |
+
|
| 36 |
+
Project maintainers are responsible for clarifying the standards of acceptable
|
| 37 |
+
behavior and are expected to take appropriate and fair corrective action in
|
| 38 |
+
response to any instances of unacceptable behavior.
|
| 39 |
+
|
| 40 |
+
Project maintainers have the right and responsibility to remove, edit, or
|
| 41 |
+
reject comments, commits, code, wiki edits, issues, and other contributions
|
| 42 |
+
that are not aligned to this Code of Conduct, or to ban temporarily or
|
| 43 |
+
permanently any contributor for other behaviors that they deem inappropriate,
|
| 44 |
+
threatening, offensive, or harmful.
|
| 45 |
+
|
| 46 |
+
## Scope
|
| 47 |
+
|
| 48 |
+
This Code of Conduct applies within all project spaces, and it also applies when
|
| 49 |
+
an individual is representing the project or its community in public spaces.
|
| 50 |
+
Examples of representing a project or community include using an official
|
| 51 |
+
project e-mail address, posting via an official social media account, or acting
|
| 52 |
+
as an appointed representative at an online or offline event. Representation of
|
| 53 |
+
a project may be further defined and clarified by project maintainers.
|
| 54 |
+
|
| 55 |
+
This Code of Conduct also applies outside the project spaces when there is a
|
| 56 |
+
reasonable belief that an individual's behavior may have a negative impact on
|
| 57 |
+
the project or its community.
|
| 58 |
+
|
| 59 |
+
## Enforcement
|
| 60 |
+
|
| 61 |
+
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
| 62 |
+
reported by contacting the project team at <opensource-conduct@fb.com>. All
|
| 63 |
+
complaints will be reviewed and investigated and will result in a response that
|
| 64 |
+
is deemed necessary and appropriate to the circumstances. The project team is
|
| 65 |
+
obligated to maintain confidentiality with regard to the reporter of an incident.
|
| 66 |
+
Further details of specific enforcement policies may be posted separately.
|
| 67 |
+
|
| 68 |
+
Project maintainers who do not follow or enforce the Code of Conduct in good
|
| 69 |
+
faith may face temporary or permanent repercussions as determined by other
|
| 70 |
+
members of the project's leadership.
|
| 71 |
+
|
| 72 |
+
## Attribution
|
| 73 |
+
|
| 74 |
+
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
|
| 75 |
+
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
|
| 76 |
+
|
| 77 |
+
[homepage]: https://www.contributor-covenant.org
|
| 78 |
+
|
| 79 |
+
For answers to common questions about this code of conduct, see
|
| 80 |
+
https://www.contributor-covenant.org/faq
|
av_hubert/CONTRIBUTING.md
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Contributing to av_hubert
|
| 2 |
+
We want to make contributing to this project as easy and transparent as
|
| 3 |
+
possible.
|
| 4 |
+
|
| 5 |
+
## Pull Requests
|
| 6 |
+
We actively welcome your pull requests.
|
| 7 |
+
|
| 8 |
+
1. Fork the repo and create your branch from `main`.
|
| 9 |
+
2. If you've added code that should be tested, add tests.
|
| 10 |
+
3. If you've changed APIs, update the documentation.
|
| 11 |
+
4. Ensure the test suite passes.
|
| 12 |
+
5. Make sure your code lints.
|
| 13 |
+
6. If you haven't already, complete the Contributor License Agreement ("CLA").
|
| 14 |
+
|
| 15 |
+
## Contributor License Agreement ("CLA")
|
| 16 |
+
In order to accept your pull request, we need you to submit a CLA. You only need
|
| 17 |
+
to do this once to work on any of Facebook's open source projects.
|
| 18 |
+
|
| 19 |
+
Complete your CLA here: <https://code.facebook.com/cla>
|
| 20 |
+
|
| 21 |
+
## Issues
|
| 22 |
+
We use GitHub issues to track public bugs. Please ensure your description is
|
| 23 |
+
clear and has sufficient instructions to be able to reproduce the issue.
|
| 24 |
+
|
| 25 |
+
Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
|
| 26 |
+
disclosure of security bugs. In those cases, please go through the process
|
| 27 |
+
outlined on that page and do not file a public issue.
|
| 28 |
+
|
| 29 |
+
## License
|
| 30 |
+
By contributing to av_hubert, you agree that your contributions will be licensed
|
| 31 |
+
under the LICENSE file in the root directory of this source tree.
|
av_hubert/LICENSE
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
AV-HuBERT LICENSE AGREEMENT
|
| 2 |
+
|
| 3 |
+
This License Agreement (as may be amended in accordance with this License
|
| 4 |
+
Agreement, “License”), between you (“Licensee” or “you”) and Meta Platforms,
|
| 5 |
+
Inc. (“Meta” or “we”) applies to your use of any computer program, algorithm,
|
| 6 |
+
source code, object code, or software that is made available by Meta under this
|
| 7 |
+
License (“Software”) and any specifications, manuals, documentation, and other
|
| 8 |
+
written information provided by Meta related to the Software (“Documentation”).
|
| 9 |
+
|
| 10 |
+
By clicking “I Accept” below or by using the Software, you agree to the terms
|
| 11 |
+
of this License. If you do not agree to this License, then you do not have any
|
| 12 |
+
rights to use the Software or Documentation (collectively, the “Software
|
| 13 |
+
Products”), and you must immediately cease using the Software Products.
|
| 14 |
+
|
| 15 |
+
1. LICENSE GRANT a. Subject to your compliance with the Documentation and
|
| 16 |
+
Sections 2, 3, and 5, Meta grants you a non-exclusive, worldwide,
|
| 17 |
+
non-transferable, non-sublicensable, revocable, royalty free and limited
|
| 18 |
+
license under Meta’s copyright interests to reproduce, distribute, and create
|
| 19 |
+
derivative works of the Software solely for your non-commercial research
|
| 20 |
+
purposes. The foregoing license is personal to you, and you may not assign or
|
| 21 |
+
sublicense this License or any other rights or obligations under this License
|
| 22 |
+
without Meta’s prior written consent; any such assignment or sublicense will be
|
| 23 |
+
void and will automatically and immediately terminate this License.
|
| 24 |
+
|
| 25 |
+
b. You may make a reasonable number of copies of the Documentation solely for
|
| 26 |
+
use in connection with the license to the Software granted above.
|
| 27 |
+
|
| 28 |
+
c. The grant of rights expressly set forth in this Section 1 (License Grant)
|
| 29 |
+
are the complete grant of rights to you in the Software Products, and no other
|
| 30 |
+
licenses are granted, whether by waiver, estoppel, implication, equity or
|
| 31 |
+
otherwise. Meta and its licensors reserve all rights not expressly granted by
|
| 32 |
+
this License.
|
| 33 |
+
|
| 34 |
+
2. RESTRICTIONS
|
| 35 |
+
|
| 36 |
+
You will not, and will not permit, assist or cause any third party to:
|
| 37 |
+
|
| 38 |
+
a. use, modify, copy, reproduce, create derivative works of, or distribute the
|
| 39 |
+
Software Products (or any derivative works thereof, works incorporating the
|
| 40 |
+
Software Products, or any data produced by the Software), in whole or in part,
|
| 41 |
+
for (i) any commercial or production purposes, (ii) military purposes or in the
|
| 42 |
+
service of nuclear technology, (iii) purposes of surveillance, including any
|
| 43 |
+
research or development relating to surveillance, (iv) biometric processing,
|
| 44 |
+
(v) in any manner that infringes, misappropriates, or otherwise violates any
|
| 45 |
+
third-party rights, or (vi) in any manner that violates any applicable law,
|
| 46 |
+
including any privacy or security laws, rules, regulations, directives, or
|
| 47 |
+
governmental requirements (including the General Data Privacy Regulation
|
| 48 |
+
(Regulation (EU) 2016/679), the California Consumer Privacy Act, and any and
|
| 49 |
+
all laws governing the processing of biometric information), as well as all
|
| 50 |
+
amendments and successor laws to any of the foregoing;
|
| 51 |
+
|
| 52 |
+
b. decompile, disassemble, or reverse-engineer the Software, in whole or in
|
| 53 |
+
part;
|
| 54 |
+
|
| 55 |
+
c. alter or remove copyright and other proprietary notices which appear on or
|
| 56 |
+
in the Software Products;
|
| 57 |
+
|
| 58 |
+
d. utilize any equipment, device, software, or other means to circumvent or
|
| 59 |
+
remove any security or protection used by Meta in connection with the Software,
|
| 60 |
+
or to circumvent or remove any usage restrictions, or to enable functionality
|
| 61 |
+
disabled by Meta; or
|
| 62 |
+
|
| 63 |
+
e. offer or impose any terms on the Software Products that alter, restrict, or
|
| 64 |
+
are inconsistent with the terms of this License.
|
| 65 |
+
|
| 66 |
+
3. ATTRIBUTION
|
| 67 |
+
|
| 68 |
+
Together with any copies of the Software Products (as well as derivative works
|
| 69 |
+
thereof or works incorporating the Software Products) that you distribute, you
|
| 70 |
+
must provide (i) a copy of this License, and (ii) the following attribution
|
| 71 |
+
notice: “AV-HuBERT is licensed under the AV-HuBERT license, Copyright (c) Meta
|
| 72 |
+
Platforms, Inc. All Rights Reserved.”
|
| 73 |
+
|
| 74 |
+
4. DISCLAIMERS
|
| 75 |
+
|
| 76 |
+
THE SOFTWARE PRODUCTS ARE PROVIDED “AS IS” and “WITH ALL FAULTS” WITH NO
|
| 77 |
+
WARRANTY OF ANY KIND, EXPRESS OR IMPLIED. META EXPRESSLY DISCLAIMS ALL
|
| 78 |
+
REPRESENTATIONS AND WARRANTIES, EXPRESS OR IMPLIED, WHETHER BY STATUTE, CUSTOM,
|
| 79 |
+
USAGE OR OTHERWISE AS TO ANY MATTERS RELATED TO THE SOFTWARE PRODUCTS,
|
| 80 |
+
INCLUDING BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY,
|
| 81 |
+
FITNESS FOR A PARTICULAR PURPOSE, TITLE, SATISFACTORY QUALITY, OR
|
| 82 |
+
NON-INFRINGEMENT. META MAKES NO WARRANTIES OR REPRESENTATIONS THAT THE SOFTWARE
|
| 83 |
+
PRODUCTS WILL BE ERROR FREE OR FREE OF VIRUSES OR OTHER HARMFUL COMPONENTS, OR
|
| 84 |
+
PRODUCE ANY PARTICULAR RESULTS.
|
| 85 |
+
|
| 86 |
+
5. LIMITATION OF LIABILITY
|
| 87 |
+
|
| 88 |
+
TO THE FULLEST EXTENT PERMITTED BY LAW, IN NO EVENT WILL META BE LIABLE TO YOU
|
| 89 |
+
(A) UNDER ANY THEORY OF LIABILITY, WHETHER BASED IN CONTRACT, TORT, NEGLIGENCE,
|
| 90 |
+
STRICT LIABILITY, WARRANTY, OR OTHERWISE UNDER THIS LICENSE, OR (B) FOR ANY
|
| 91 |
+
INDIRECT, CONSEQUENTIAL, EXEMPLARY, INCIDENTAL, PUNITIVE OR SPECIAL DAMAGES OR
|
| 92 |
+
LOST PROFITS, EVEN IF META HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
|
| 93 |
+
THE SOFTWARE PRODUCTS, THEIR CONSTITUENT COMPONENTS, AND ANY OUTPUT
|
| 94 |
+
(COLLECTIVELY, “SOFTWARE MATERIALS”) ARE NOT DESIGNED OR INTENDED FOR USE IN
|
| 95 |
+
ANY APPLICATION OR SITUATION WHERE FAILURE OR FAULT OF THE SOFTWARE MATERIALS
|
| 96 |
+
COULD REASONABLY BE ANTICIPATED TO LEAD TO SERIOUS INJURY OF ANY PERSON,
|
| 97 |
+
INCLUDING POTENTIAL DISCRIMINATION OR VIOLATION OF AN INDIVIDUAL’S PRIVACY
|
| 98 |
+
RIGHTS, OR TO SEVERE PHYSICAL, PROPERTY, OR ENVIRONMENTAL DAMAGE (EACH, A
|
| 99 |
+
“HIGH-RISK USE”). IF YOU ELECT TO USE ANY OF THE SOFTWARE MATERIALS FOR A
|
| 100 |
+
HIGH-RISK USE, YOU DO SO AT YOUR OWN RISK. YOU AGREE TO DESIGN AND IMPLEMENT
|
| 101 |
+
APPROPRIATE DECISION-MAKING AND RISK-MITIGATION PROCEDURES AND POLICIES IN
|
| 102 |
+
CONNECTION WITH A HIGH-RISK USE SUCH THAT EVEN IF THERE IS A FAILURE OR FAULT
|
| 103 |
+
IN ANY OF THE SOFTWARE MATERIALS, THE SAFETY OF PERSONS OR PROPERTY AFFECTED BY
|
| 104 |
+
THE ACTIVITY STAYS AT A LEVEL THAT IS REASONABLE, APPROPRIATE, AND LAWFUL FOR
|
| 105 |
+
THE FIELD OF THE HIGH-RISK USE.
|
| 106 |
+
|
| 107 |
+
6. TERMINATION; SURVIVAL
|
| 108 |
+
|
| 109 |
+
a. This License will automatically terminate upon any breach by you of the
|
| 110 |
+
terms of this License.
|
| 111 |
+
|
| 112 |
+
b. We may terminate this License, in whole or in part, at any time upon notice
|
| 113 |
+
(including electronic) to you.
|
| 114 |
+
|
| 115 |
+
c. The following sections survive termination of this License: 2
|
| 116 |
+
(Restrictions), 3 (Attribution), 4 (Disclaimers), 5 (Limitation on Liability),
|
| 117 |
+
6 (Termination; Survival), 7 (Third Party Materials), 8 (Trademarks), 9
|
| 118 |
+
(Applicable Law; Dispute Resolution), and 10 (Miscellaneous).
|
| 119 |
+
|
| 120 |
+
7. THIRD PARTY MATERIALS
|
| 121 |
+
|
| 122 |
+
The Software Products may contain third-party software or other components
|
| 123 |
+
(including free and open source software) (all of the foregoing, “Third Party
|
| 124 |
+
Materials”), which are subject to the license terms of the respective
|
| 125 |
+
third-party licensors. Your dealings or correspondence with third parties and
|
| 126 |
+
your use of or interaction with any Third Party Materials are solely between
|
| 127 |
+
you and the third party. Meta does not control or endorse, and makes no
|
| 128 |
+
representations or warranties regarding, any Third Party Materials, and your
|
| 129 |
+
access to and use of such Third Party Materials are at your own risk.
|
| 130 |
+
|
| 131 |
+
8. TRADEMARKS
|
| 132 |
+
|
| 133 |
+
Licensee has not been granted any trademark license as part of this License and
|
| 134 |
+
may not use any name or mark associated with Meta without the prior written
|
| 135 |
+
permission of Meta, except to the extent necessary to make the reference
|
| 136 |
+
required by the “ATTRIBUTION” section of this Agreement.
|
| 137 |
+
|
| 138 |
+
9. APPLICABLE LAW; DISPUTE RESOLUTION
|
| 139 |
+
|
| 140 |
+
This License will be governed and construed under the laws of the State of
|
| 141 |
+
California without regard to conflicts of law provisions. Any suit or
|
| 142 |
+
proceeding arising out of or relating to this License will be brought in the
|
| 143 |
+
federal or state courts, as applicable, in San Mateo County, California, and
|
| 144 |
+
each party irrevocably submits to the jurisdiction and venue of such courts.
|
| 145 |
+
|
| 146 |
+
10. MISCELLANEOUS
|
| 147 |
+
|
| 148 |
+
If any provision or part of a provision of this License is unlawful, void or
|
| 149 |
+
unenforceable, that provision or part of the provision is deemed severed from
|
| 150 |
+
this License, and will not affect the validity and enforceability of any
|
| 151 |
+
remaining provisions. The failure of Meta to exercise or enforce any right or
|
| 152 |
+
provision of this License will not operate as a waiver of such right or
|
| 153 |
+
provision. This License does not confer any third-party beneficiary rights upon
|
| 154 |
+
any other person or entity. This License, together with the Documentation,
|
| 155 |
+
contains the entire understanding between you and Meta regarding the subject
|
| 156 |
+
matter of this License, and supersedes all other written or oral agreements and
|
| 157 |
+
understandings between you and Meta regarding such subject matter. No change or
|
| 158 |
+
addition to any provision of this License will be binding unless it is in
|
| 159 |
+
writing and signed by an authorized representative of both you and Meta.
|
av_hubert/README.md
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# AV-HuBERT (Audio-Visual Hidden Unit BERT)
|
| 2 |
+
[Learning Audio-Visual Speech Representation by Masked Multimodal Cluster Prediction](https://arxiv.org/abs/2201.02184)
|
| 3 |
+
|
| 4 |
+
[Robust Self-Supervised Audio-Visual Speech Recognition](https://arxiv.org/abs/2201.01763)
|
| 5 |
+
|
| 6 |
+

|
| 7 |
+
|
| 8 |
+
## Introduction
|
| 9 |
+
AV-HuBERT is a self-supervised representation learning framework for audio-visual speech. It achieves state-of-the-art results in lip reading, ASR and audio-visual speech recognition on the LRS3 audio-visual speech benchmark.
|
| 10 |
+
|
| 11 |
+
If you find AV-HuBERT useful in your research, please use the following BibTeX entry for citation.
|
| 12 |
+
```BibTeX
|
| 13 |
+
@article{shi2022avhubert,
|
| 14 |
+
author = {Bowen Shi and Wei-Ning Hsu and Kushal Lakhotia and Abdelrahman Mohamed},
|
| 15 |
+
title = {Learning Audio-Visual Speech Representation by Masked Multimodal Cluster Prediction},
|
| 16 |
+
journal = {arXiv preprint arXiv:2201.02184}
|
| 17 |
+
year = {2022}
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
@article{shi2022avsr,
|
| 21 |
+
author = {Bowen Shi and Wei-Ning Hsu and Abdelrahman Mohamed},
|
| 22 |
+
title = {Robust Self-Supervised Audio-Visual Speech Recognition},
|
| 23 |
+
journal = {arXiv preprint arXiv:2201.01763}
|
| 24 |
+
year = {2022}
|
| 25 |
+
}
|
| 26 |
+
```
|
| 27 |
+
|
| 28 |
+
## License
|
| 29 |
+
|
| 30 |
+
AV-HuBERT LICENSE AGREEMENT
|
| 31 |
+
|
| 32 |
+
This License Agreement (as may be amended in accordance with this License
|
| 33 |
+
Agreement, “License”), between you (“Licensee” or “you”) and Meta Platforms,
|
| 34 |
+
Inc. (“Meta” or “we”) applies to your use of any computer program, algorithm,
|
| 35 |
+
source code, object code, or software that is made available by Meta under this
|
| 36 |
+
License (“Software”) and any specifications, manuals, documentation, and other
|
| 37 |
+
written information provided by Meta related to the Software (“Documentation”).
|
| 38 |
+
|
| 39 |
+
By using the Software, you agree to the terms of [this
|
| 40 |
+
License](https://github.com/facebookresearch/av_hubert/blob/main/LICENSE). If
|
| 41 |
+
you do not agree to this License, then you do not have any rights to use the
|
| 42 |
+
Software or Documentation (collectively, the “Software Products”), and you must
|
| 43 |
+
immediately cease using the Software Products.
|
| 44 |
+
|
| 45 |
+
## Pre-trained and fine-tuned models
|
| 46 |
+
|
| 47 |
+
Please find the checkpoints [here](http://facebookresearch.github.io/av_hubert)
|
| 48 |
+
|
| 49 |
+
## Demo
|
| 50 |
+
Run our lip-reading demo using Colab: [](https://colab.research.google.com/drive/1bNXkfpHiVHzXQH8WjGhzQ-fsDxolpUjD)
|
| 51 |
+
|
| 52 |
+
## Installation
|
| 53 |
+
First, create a conda virtual environment and activate it:
|
| 54 |
+
```
|
| 55 |
+
conda create -n avhubert python=3.8 -y
|
| 56 |
+
conda activate avhubert
|
| 57 |
+
```
|
| 58 |
+
Then, clone this directory:
|
| 59 |
+
```
|
| 60 |
+
git clone https://github.com/facebookresearch/av_hubert.git
|
| 61 |
+
cd avhubert
|
| 62 |
+
git submodule init
|
| 63 |
+
git submodule update
|
| 64 |
+
```
|
| 65 |
+
|
| 66 |
+
Lastly, install Fairseq and the other packages:
|
| 67 |
+
```
|
| 68 |
+
pip install -r requirements.txt
|
| 69 |
+
cd fairseq
|
| 70 |
+
pip install --editable ./
|
| 71 |
+
```
|
| 72 |
+
|
| 73 |
+
## Load a pretrained model
|
| 74 |
+
```sh
|
| 75 |
+
$ cd avhubert
|
| 76 |
+
$ python
|
| 77 |
+
>>> import fairseq
|
| 78 |
+
>>> import hubert_pretraining, hubert
|
| 79 |
+
>>> ckpt_path = "/path/to/the/checkpoint.pt"
|
| 80 |
+
>>> models, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([ckpt_path])
|
| 81 |
+
>>> model = models[0]
|
| 82 |
+
```
|
| 83 |
+
|
| 84 |
+
## Train a new model
|
| 85 |
+
|
| 86 |
+
### Data preparation
|
| 87 |
+
|
| 88 |
+
Follow the steps in [`preparation`](avhubert/preparation/) to pre-process:
|
| 89 |
+
- LRS3 and VoxCeleb2 datasets
|
| 90 |
+
|
| 91 |
+
Follow the steps in [`clustering`](avhubert/clustering/) (pre-train only) to create:
|
| 92 |
+
- `{train,valid}.km` frame-aligned pseudo label files.
|
| 93 |
+
The `label_rate` is the same as the feature frame rate used for clustering,
|
| 94 |
+
which is 100Hz for MFCC features and 25Hz for AV-HuBERT features by default.
|
| 95 |
+
|
| 96 |
+
### Pre-train an AV-HuBERT model
|
| 97 |
+
|
| 98 |
+
Suppose `{train,valid}.tsv` are saved at `/path/to/data`, `{train,valid}.km`
|
| 99 |
+
are saved at `/path/to/labels`, the configuration file is saved at `/path/to/conf/conf-name`, and the label rate is 100Hz.
|
| 100 |
+
|
| 101 |
+
To train a model, run:
|
| 102 |
+
```sh
|
| 103 |
+
$ cd avhubert
|
| 104 |
+
$ fairseq-hydra-train --config-dir /path/to/conf/ --config-name conf-name \
|
| 105 |
+
task.data=/path/to/data task.label_dir=/path/to/label \
|
| 106 |
+
model.label_rate=100 hydra.run.dir=/path/to/experiment/pretrain/ \
|
| 107 |
+
common.user_dir=`pwd`
|
| 108 |
+
```
|
| 109 |
+
|
| 110 |
+
### Finetune an AV-HuBERT model with Seq2Seq
|
| 111 |
+
Suppose `{train,valid}.tsv` are saved at `/path/to/data`, `{train,valid}.wrd`
|
| 112 |
+
are saved at `/path/to/labels`, the configuration file is saved at `/path/to/conf/conf-name`.
|
| 113 |
+
|
| 114 |
+
To fine-tune a pre-trained HuBERT model at `/path/to/checkpoint`, run:
|
| 115 |
+
```sh
|
| 116 |
+
$ cd avhubert
|
| 117 |
+
$ fairseq-hydra-train --config-dir /path/to/conf/ --config-name conf-name \
|
| 118 |
+
task.data=/path/to/data task.label_dir=/path/to/label \
|
| 119 |
+
task.tokenizer_bpe_model=/path/to/tokenizer model.w2v_path=/path/to/checkpoint \
|
| 120 |
+
hydra.run.dir=/path/to/experiment/finetune/ common.user_dir=`pwd`
|
| 121 |
+
```
|
| 122 |
+
|
| 123 |
+
### Decode an AV-HuBERT model
|
| 124 |
+
Suppose the `test.tsv` and `test.wrd` are the video list and transcripts of
|
| 125 |
+
the split to be decoded, saved at `/path/to/data`, and the fine-tuned model is
|
| 126 |
+
saved at `/path/to/checkpoint`.
|
| 127 |
+
|
| 128 |
+
#### Seq2Seq decoding
|
| 129 |
+
|
| 130 |
+
`task.normalize` needs to be consistent with the value used during fine-tuning.
|
| 131 |
+
Decoding results will be saved at
|
| 132 |
+
`/path/to/experiment/decode/s2s/test`.
|
| 133 |
+
|
| 134 |
+
```sh
|
| 135 |
+
$ cd avhubert
|
| 136 |
+
$ python -B infer_s2s.py --config-dir ./conf/ --config-name conf-name \
|
| 137 |
+
dataset.gen_subset=test common_eval.path=/path/to/checkpoint \
|
| 138 |
+
common_eval.results_path=/path/to/experiment/decode/s2s/test \
|
| 139 |
+
override.modalities=['video'] common.user_dir=`pwd`
|
| 140 |
+
```
|
| 141 |
+
|
| 142 |
+
The command above uses the default decoding hyperparameter, which can be found
|
| 143 |
+
in `conf/s2s_decode.yaml`. `override.modalities` can be set to `['video']` (for lip reading),
|
| 144 |
+
or `['audio']` (for ASR) or `['audio','video']` (for audio-visual speech recognition).These parameters can be
|
| 145 |
+
configured from the command line. For example, to search with a beam size of
|
| 146 |
+
20, we can append the command above with `generation.beam=20`.
|
| 147 |
+
Important parameters include:
|
| 148 |
+
- generation.beam
|
| 149 |
+
- generation.lenpen
|
| 150 |
+
|
| 151 |
+
#### Different test set
|
| 152 |
+
If your test data are stored in a different directory with the training data, append the following to the above command.
|
| 153 |
+
|
| 154 |
+
`+override.data=/path/to/test +override.label_dir=/path/to/test`
|
| 155 |
+
|
| 156 |
+
, where `/path/to/test` contains `test.{tsv,wrd}`. This is useful when you want to test with the fine-tuned checkpoints we provide.
|
| 157 |
+
|
| 158 |
+
#### Test under noisy environment
|
| 159 |
+
If you want to test your model under noisy environment, append the following to the above command.
|
| 160 |
+
|
| 161 |
+
`+override.noise_wav=/path/to/noise override.noise_prob=1 override.noise_snr={snr}`
|
| 162 |
+
|
| 163 |
+
`{snr}` is the signal-to-noise ratio (SNR) and `/path/to/noise` is a folder containing noise manifest files (`/path/to/noise/{valid,test}.tsv`). See [`preparation`](avhubert/preparation/) for setting up this folder.
|
| 164 |
+
|
av_hubert/assets/lipreading.gif
ADDED
|
Git LFS Details
|
av_hubert/avhubert/__init__.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from .hubert import * # noqa
|
| 7 |
+
from .hubert_asr import * # noqa
|
| 8 |
+
from .hubert_dataset import *
|
| 9 |
+
from .hubert_pretraining import *
|
| 10 |
+
from .hubert_criterion import *
|
av_hubert/avhubert/clustering/README.md
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# AV-HuBERT Label Preparation
|
| 2 |
+
|
| 3 |
+
This folder contains scripts for preparing AV-HUBERT labels from tsv files, the
|
| 4 |
+
steps are:
|
| 5 |
+
1. feature extraction
|
| 6 |
+
2. k-means clustering
|
| 7 |
+
3. k-means application
|
| 8 |
+
|
| 9 |
+
## Installation
|
| 10 |
+
To prepare labels, you need some additional packages:
|
| 11 |
+
```
|
| 12 |
+
pip install -r requirements.txt
|
| 13 |
+
```
|
| 14 |
+
|
| 15 |
+
## Data preparation
|
| 16 |
+
|
| 17 |
+
`*.tsv` files contain a list of audio, where the first line is the root, and
|
| 18 |
+
following lines are the subpath and number of frames of each video and audio separated by `tab`:
|
| 19 |
+
```
|
| 20 |
+
<root-dir>
|
| 21 |
+
<id-1> <video-path-1> <audio-path-1> <video-number-frames-1> <audio-number-frames-1>
|
| 22 |
+
<id-2> <video-path-2> <audio-path-2> <video-number-frames-2> <audio-number-frames-2>
|
| 23 |
+
...
|
| 24 |
+
```
|
| 25 |
+
See [here](../preparation/) for data preparation for LRS3 and VoxCeleb2.
|
| 26 |
+
|
| 27 |
+
## Feature extraction
|
| 28 |
+
|
| 29 |
+
### MFCC feature
|
| 30 |
+
Suppose the tsv file is at `${tsv_dir}/${split}.tsv`. To extract 39-D
|
| 31 |
+
mfcc+delta+ddelta features for the 1st iteration AV-HuBERT training, run:
|
| 32 |
+
```sh
|
| 33 |
+
python dump_mfcc_feature.py ${tsv_dir} ${split} ${nshard} ${rank} ${feat_dir}
|
| 34 |
+
```
|
| 35 |
+
This would shard the tsv file into `${nshard}` and extract features for the
|
| 36 |
+
`${rank}`-th shard, where rank is an integer in `[0, nshard-1]`. Features would
|
| 37 |
+
be saved at `${feat_dir}/${split}_${rank}_${nshard}.{npy,len}`.
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
### AV-HuBERT feature
|
| 41 |
+
To extract features from the `${layer}`-th transformer layer of a trained
|
| 42 |
+
AV-HuBERT model saved at `${ckpt_path}`, run:
|
| 43 |
+
```sh
|
| 44 |
+
python dump_hubert_feature.py ${tsv_dir} ${split} ${ckpt_path} ${layer} ${nshard} ${rank} ${feat_dir} --user_dir `pwd`/../
|
| 45 |
+
```
|
| 46 |
+
Features would also be saved at `${feat_dir}/${split}_${rank}_${nshard}.{npy,len}`.
|
| 47 |
+
|
| 48 |
+
- if out-of-memory, decrease the chunk size with `--max_chunk`
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
## K-means clustering
|
| 52 |
+
To fit a k-means model with `${n_clusters}` clusters on 10% of the `${split}` data, run
|
| 53 |
+
```sh
|
| 54 |
+
python learn_kmeans.py ${feat_dir} ${split} ${nshard} ${km_path} ${n_cluster} --percent 0.1
|
| 55 |
+
```
|
| 56 |
+
This saves the k-means model to `${km_path}`.
|
| 57 |
+
|
| 58 |
+
- set `--percent -1` to use all data
|
| 59 |
+
- more kmeans options can be found with `-h` flag
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
## K-means application
|
| 63 |
+
To apply a trained k-means model `${km_path}` to obtain labels for `${split}`, run
|
| 64 |
+
```sh
|
| 65 |
+
python dump_km_label.py ${feat_dir} ${split} ${km_path} ${nshard} ${rank} ${lab_dir}
|
| 66 |
+
```
|
| 67 |
+
This would extract labels for the `${rank}`-th shard out of `${nshard}` shards
|
| 68 |
+
and dump them to `${lab_dir}/${split}_${rank}_${shard}.km`
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
Finally, merge shards for `${split}` by running
|
| 72 |
+
```sh
|
| 73 |
+
for rank in $(seq 0 $((nshard - 1))); do
|
| 74 |
+
cat $lab_dir/${split}_${rank}_${nshard}.km
|
| 75 |
+
done > $lab_dir/${split}.km
|
| 76 |
+
```
|
| 77 |
+
and create a dictionary of cluster indexes by running
|
| 78 |
+
```sh
|
| 79 |
+
for i in $(seq 1 $((n_cluster-1)));do
|
| 80 |
+
echo $i 10000
|
| 81 |
+
done > $lab_dir/dict.{mfcc,km}.txt
|
| 82 |
+
```
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
## Clustering on slurm
|
| 86 |
+
If you are on slurm, you can combine the above steps (feature extraction + K-means clustering + K-means application) by:
|
| 87 |
+
|
| 88 |
+
- MFCC feature cluster:
|
| 89 |
+
```sh
|
| 90 |
+
python submit_cluster.py --tsv ${tsv_dir} --output ${lab_dir} --ncluster ${n_cluster} \
|
| 91 |
+
--nshard ${nshard} --mfcc --percent 0.1
|
| 92 |
+
```
|
| 93 |
+
|
| 94 |
+
- AV-HuBERT feature cluster:
|
| 95 |
+
```sh
|
| 96 |
+
python submit_cluster.py --tsv ${tsv_dir} --output ${lab_dir} --ckpt ${ckpt_path} --nlayer ${layer} \
|
| 97 |
+
--ncluster ${n_cluster} --nshard ${nshard} --percent 0.1
|
| 98 |
+
```
|
| 99 |
+
|
| 100 |
+
This would dump labels to `${lab_dir}/{train,valid}.km`.
|
av_hubert/avhubert/clustering/dump_hubert_feature.py
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
import math
|
| 9 |
+
import os
|
| 10 |
+
import sys
|
| 11 |
+
|
| 12 |
+
import fairseq
|
| 13 |
+
import soundfile as sf
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn.functional as F
|
| 16 |
+
import tqdm
|
| 17 |
+
from npy_append_array import NpyAppendArray
|
| 18 |
+
import numpy as np
|
| 19 |
+
from python_speech_features import logfbank
|
| 20 |
+
from scipy.io import wavfile
|
| 21 |
+
|
| 22 |
+
logging.basicConfig(
|
| 23 |
+
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
|
| 24 |
+
datefmt="%Y-%m-%d %H:%M:%S",
|
| 25 |
+
level=os.environ.get("LOGLEVEL", "INFO").upper(),
|
| 26 |
+
stream=sys.stdout,
|
| 27 |
+
)
|
| 28 |
+
logger = logging.getLogger("dump_hubert_feature")
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class HubertFeatureReader(object):
    """Extract frame-level features from a trained AV-HuBERT checkpoint.

    Loads the model onto the GPU, rebuilds the deterministic video
    preprocessing pipeline from the task config, and exposes ``get_feats``,
    which maps a (video_path, audio_path) pair to a (T, D) feature tensor.
    """

    def __init__(self, ckpt_path, layer, max_chunk=1600000, custom_utils=None):
        # `custom_utils` is the avhubert `utils` module, injected by the
        # caller so this script can run from outside the package directory.
        (
            model,
            cfg,
            task,
        ) = fairseq.checkpoint_utils.load_model_ensemble_and_task([ckpt_path])
        self.model = model[0].eval().cuda()
        self.task = task
        self.layer = layer  # transformer layer to read; 0 selects the conv front-end output
        self.max_chunk = max_chunk
        self.stack_order_audio = self.task.cfg.stack_order_audio
        image_crop_size, image_mean, image_std = self.task.cfg.image_crop_size, self.task.cfg.image_mean, self.task.cfg.image_std
        # Test-time video transform: scale to [0, 1], center-crop, normalize.
        self.transform = custom_utils.Compose([
            custom_utils.Normalize( 0.0,255.0 ),
            custom_utils.CenterCrop((image_crop_size, image_crop_size)),
            custom_utils.Normalize(image_mean, image_std) ])

        self.custom_utils = custom_utils
        logger.info(f"TASK CONFIG:\n{self.task.cfg}")
        logger.info(f" max_chunk = {self.max_chunk}")
        logger.info(f"Transform: {self.transform}")

    def load_feature(self, mix_name, ref_len=None):
        """Load and length-align video/audio features for one utterance.

        ``mix_name`` is a (video_path, audio_path) pair; the audio path may
        carry a ``:id`` suffix, which is stripped.  Returns
        (video_feats, audio_feats) with equal frame counts.
        """
        def stacker(feats, stack_order):
            # Stack every `stack_order` consecutive audio frames into one
            # vector (zero-padding the tail to a multiple of `stack_order`)
            # so the audio frame rate matches the video's.
            feat_dim = feats.shape[1]
            if len(feats) % stack_order != 0:
                res = stack_order - len(feats) % stack_order
                res = np.zeros([res, feat_dim]).astype(feats.dtype)
                feats = np.concatenate([feats, res], axis=0)
            feats = feats.reshape((-1, stack_order, feat_dim)).reshape(-1, stack_order*feat_dim)
            return feats
        video_fn, audio_fn = mix_name
        video_feats = self.load_image(video_fn)

        audio_fn = audio_fn.split(':')[0]
        sample_rate, wav_data = wavfile.read(audio_fn)
        assert sample_rate == 16_000 and len(wav_data.shape) == 1
        audio_feats = logfbank(wav_data, samplerate=sample_rate).astype(np.float32)
        audio_feats = stacker(audio_feats, self.stack_order_audio)

        # Pad or trim audio so both streams have the same number of frames.
        diff = len(audio_feats) - len(video_feats)
        if diff < 0:
            audio_feats = np.concatenate([audio_feats, np.zeros([-diff, audio_feats.shape[-1]], dtype=audio_feats.dtype)])
        elif diff > 0:
            audio_feats = audio_feats[:-diff]
        return video_feats, audio_feats

    def load_image(self, audio_name):
        """Load a video file and apply the test-time transform."""
        feats = self.custom_utils.load_video(audio_name)
        feats = self.transform(feats)
        feats = np.expand_dims(feats, axis=-1)  # append a channel dimension
        return feats

    def get_feats(self, path, ref_len=None):
        """Return the (T, D) feature tensor for one utterance."""
        video_feats, audio_feats = self.load_feature(path, ref_len)
        with torch.no_grad():
            audio_feats, video_feats = torch.from_numpy(audio_feats.astype(np.float32)).cuda(), torch.from_numpy(video_feats.astype(np.float32)).cuda()
            if self.task.cfg.normalize:
                audio_feats = F.layer_norm(audio_feats, audio_feats.shape[1:])
            # video: (T, H, W, C) -> (B, C, T, H, W); audio: (T, D) -> (B, D, T).
            video_feats = video_feats.unsqueeze(dim=0).permute((0, 4, 1, 2, 3)).contiguous()
            audio_feats = audio_feats.unsqueeze(dim=0).transpose(1, 2)
            source = {'audio': audio_feats, 'video': video_feats}
            if self.layer == 0:
                # Layer 0 means: return the convolutional front-end output.
                ret_conv, output_layer = True, None
            else:
                ret_conv, output_layer = False, self.layer
            feat, _ = self.model.extract_features(
                source=source,
                padding_mask=None,
                mask=False,
                output_layer=output_layer,
                ret_conv=ret_conv
                # output_layer=self.layer,
            )
            return feat.squeeze(dim=0)
def get_path_iterator(tsv, nshard, rank):
    """Shard a manifest tsv and return (iterate, n) for the ``rank``-th shard.

    The manifest's first line is the data root; each following line is
    ``id<TAB>video_path<TAB>audio_path<TAB>video_nframes<TAB>audio_nframes``.
    ``iterate()`` yields ((video_path, "audio_path:id"), video_nframes).
    """
    with open(tsv, "r") as f:
        root = f.readline().rstrip()
        lines = [line.rstrip() for line in f]
        tot = len(lines)
        shard_size = math.ceil(tot / nshard)
        start, end = rank * shard_size, min((rank + 1) * shard_size, tot)
        # Bug fix: the assertion message was a plain string, so {start}/{end}
        # were never interpolated; it must be an f-string.
        assert start < end, f"start={start}, end={end}"
        logger.info(
            f"rank {rank} of {nshard}, process {end-start} "
            f"({start}-{end}) out of {tot}"
        )

        lines = lines[start:end]

        def iterate():
            for line in lines:
                items = line.strip().split("\t")
                # audio_path = f"{items[1]}:{items[0]}"
                yield (items[1], items[2]+':'+items[0]), int(items[3])

        return iterate, len(lines)
def dump_feature(
    tsv_dir, split, ckpt_path, layer, nshard, rank, feat_dir, max_chunk, custom_utils=None, **kwargs
):
    """Extract AV-HuBERT features for one manifest shard.

    Writes two files under ``feat_dir``:
      - ``{split}_{rank}_{nshard}.npy``: all frames concatenated, and
      - ``{split}_{rank}_{nshard}.len``: per-utterance frame counts.

    Extra ``**kwargs`` (e.g. CLI-only options such as ``user_dir``) are
    accepted and ignored.
    """
    reader = HubertFeatureReader(ckpt_path, layer, max_chunk, custom_utils=custom_utils)
    generator, num = get_path_iterator(f"{tsv_dir}/{split}.tsv", nshard, rank)
    iterator = generator()

    feat_path = f"{feat_dir}/{split}_{rank}_{nshard}.npy"
    leng_path = f"{feat_dir}/{split}_{rank}_{nshard}.len"

    os.makedirs(feat_dir, exist_ok=True)
    # NpyAppendArray appends to an existing file, so remove stale output.
    if os.path.exists(feat_path):
        os.remove(feat_path)

    feat_f = NpyAppendArray(feat_path)
    with open(leng_path, "w") as leng_f:
        for path, nsample in tqdm.tqdm(iterator, total=num):
            feat = reader.get_feats(path, nsample)
            feat_f.append(feat.cpu().numpy())
            leng_f.write(f"{len(feat)}\n")
    logger.info("finished successfully")
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("tsv_dir")
    parser.add_argument("split")
    parser.add_argument("ckpt_path")
    parser.add_argument("layer", type=int)
    parser.add_argument("nshard", type=int)
    parser.add_argument("rank", type=int)
    parser.add_argument("feat_dir")
    parser.add_argument("--max_chunk", type=int, default=1600000)
    parser.add_argument("--user_dir", type=str, default=None)

    args = parser.parse_args()
    logger.info(args)
    # Register the avhubert user module with fairseq, then make its `utils`
    # module importable so the video-loading helpers can be injected below.
    fairseq.utils.import_user_module(args)
    sys.path.append(args.user_dir)
    import utils as custom_utils
    kwargs = vars(args)
    kwargs.update({'custom_utils': custom_utils})
    # `user_dir` rides along in kwargs and is swallowed by **kwargs.
    dump_feature(**kwargs)
|
av_hubert/avhubert/clustering/dump_km_label.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
import os
|
| 9 |
+
import sys
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
|
| 13 |
+
import joblib
|
| 14 |
+
import torch
|
| 15 |
+
import tqdm
|
| 16 |
+
|
| 17 |
+
logging.basicConfig(
|
| 18 |
+
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
|
| 19 |
+
datefmt="%Y-%m-%d %H:%M:%S",
|
| 20 |
+
level=os.environ.get("LOGLEVEL", "INFO").upper(),
|
| 21 |
+
stream=sys.stdout,
|
| 22 |
+
)
|
| 23 |
+
logger = logging.getLogger("dump_km_label")
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class ApplyKmeans(object):
    """Assign feature frames to their nearest k-means centroid.

    Loads a fitted k-means model and precomputes the transposed centroid
    matrix plus its per-centroid squared norms, so that nearest-centroid
    search reduces to a single matrix multiplication.  ``__call__`` accepts
    either a torch tensor or a numpy array of shape (frames, dim) and
    returns a numpy array of cluster indices.
    """

    def __init__(self, km_path):
        self.km_model = joblib.load(km_path)
        # (dim, n_clusters) centroid matrix and (1, n_clusters) squared norms.
        self.C_np = self.km_model.cluster_centers_.transpose()
        self.Cnorm_np = (self.C_np ** 2).sum(0, keepdims=True)

        self.C = torch.from_numpy(self.C_np)
        self.Cnorm = torch.from_numpy(self.Cnorm_np)
        if torch.cuda.is_available():
            self.C = self.C.cuda()
            self.Cnorm = self.Cnorm.cuda()

    def __call__(self, x):
        # ||x - c||^2 = ||x||^2 - 2 x.c + ||c||^2, minimized over centroids.
        if isinstance(x, torch.Tensor):
            sq_norm = x.pow(2).sum(1, keepdim=True)
            cross = torch.matmul(x, self.C)
            dist = sq_norm - 2 * cross + self.Cnorm
            return dist.argmin(dim=1).cpu().numpy()
        sq_norm = (x ** 2).sum(1, keepdims=True)
        cross = np.matmul(x, self.C_np)
        dist = sq_norm - 2 * cross + self.Cnorm_np
        return np.argmin(dist, axis=1)
def get_feat_iterator(feat_dir, split, nshard, rank):
    """Yield per-utterance feature matrices from one dumped shard.

    Reads the ``.len`` sidecar to learn each utterance's frame count, then
    lazily slices the memory-mapped ``.npy`` feature matrix into consecutive
    per-utterance chunks.

    Returns:
        (iterate, num_utterances), where ``iterate()`` is a generator of
        2-D feature arrays, one per utterance.
    """
    shard_stem = f"{feat_dir}/{split}_{rank}_{nshard}"
    feat_path = shard_stem + ".npy"
    leng_path = shard_stem + ".len"
    with open(leng_path, "r") as f:
        lengs = [int(line.rstrip()) for line in f]
    # Starting row of each utterance within the concatenated feature matrix.
    offsets = [0] + np.cumsum(lengs[:-1]).tolist()

    def iterate():
        feat = np.load(feat_path, mmap_mode="r")
        assert feat.shape[0] == (offsets[-1] + lengs[-1])
        for begin, count in zip(offsets, lengs):
            yield feat[begin: begin + count]

    return iterate, len(lengs)
def dump_label(feat_dir, split, km_path, nshard, rank, lab_dir):
    """Assign k-means labels to one dumped feature shard.

    Writes ``{lab_dir}/{split}_{rank}_{nshard}.km`` with one line per
    utterance containing space-separated cluster indices.
    """
    apply_kmeans = ApplyKmeans(km_path)
    generator, num = get_feat_iterator(feat_dir, split, nshard, rank)
    iterator = generator()

    lab_path = f"{lab_dir}/{split}_{rank}_{nshard}.km"
    os.makedirs(lab_dir, exist_ok=True)
    with open(lab_path, "w") as f:
        for feat in tqdm.tqdm(iterator, total=num):
            # feat = torch.from_numpy(feat).cuda()
            lab = apply_kmeans(feat).tolist()
            f.write(" ".join(map(str, lab)) + "\n")
    logger.info("finished successfully")
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("feat_dir")
    parser.add_argument("split")
    parser.add_argument("km_path")
    parser.add_argument("nshard", type=int)
    parser.add_argument("rank", type=int)
    parser.add_argument("lab_dir")
    args = parser.parse_args()
    # NOTE(review): uses the root logger (logging.info) while the sibling
    # scripts use the module logger — presumably unintentional; confirm
    # before changing.
    logging.info(str(args))

    dump_label(**vars(args))
|
av_hubert/avhubert/clustering/dump_mfcc_feature.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
import math
|
| 9 |
+
import os
|
| 10 |
+
import sys
|
| 11 |
+
|
| 12 |
+
import soundfile as sf
|
| 13 |
+
import torch
|
| 14 |
+
import torchaudio
|
| 15 |
+
import tqdm
|
| 16 |
+
from npy_append_array import NpyAppendArray
|
| 17 |
+
|
| 18 |
+
logging.basicConfig(
|
| 19 |
+
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
|
| 20 |
+
datefmt="%Y-%m-%d %H:%M:%S",
|
| 21 |
+
level=os.environ.get("LOGLEVEL", "INFO").upper(),
|
| 22 |
+
stream=sys.stdout,
|
| 23 |
+
)
|
| 24 |
+
logger = logging.getLogger("dump_mfcc_feature")
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class MfccFeatureReader(object):
    """Compute MFCC + delta + delta-delta features for a wav file."""

    def __init__(self, sample_rate):
        self.sample_rate = sample_rate  # expected sampling rate of every input file

    def read_audio(self, path, ref_len=None):
        """Read a wav file as a mono 1-D array; warn on length mismatch."""
        wav, sr = sf.read(path)
        assert sr == self.sample_rate, sr
        if wav.ndim == 2:
            # Down-mix multi-channel audio by averaging the channels.
            wav = wav.mean(-1)
        assert wav.ndim == 1, wav.ndim
        # Tolerate small manifest/file length differences (<= 160 samples).
        if ref_len is not None and abs(ref_len - len(wav)) > 160:
            logging.warning(f"ref {ref_len} != read {len(wav)} ({path})")
        return wav

    def get_feats(self, path, ref_len=None):
        """Return a (time, 39) tensor of MFCC + delta + ddelta features."""
        x = self.read_audio(path, ref_len)
        with torch.no_grad():
            x = torch.from_numpy(x).float()
            x = x.view(1, -1)

            mfccs = torchaudio.compliance.kaldi.mfcc(
                waveform=x,
                sample_frequency=self.sample_rate,
                use_energy=False,
            )  # (time, freq)
            mfccs = mfccs.transpose(0, 1)  # (freq, time)
            deltas = torchaudio.functional.compute_deltas(mfccs)
            ddeltas = torchaudio.functional.compute_deltas(deltas)
            concat = torch.cat([mfccs, deltas, ddeltas], dim=0)
            concat = concat.transpose(0, 1).contiguous()  # back to (time, freq)
            return concat
def get_path_iterator(tsv, nshard, rank):
    """Shard a manifest tsv and return (iterate, n) for the ``rank``-th shard.

    The manifest's first line is the data root; each following line is
    ``id<TAB>video_path<TAB>audio_path<TAB>video_nframes<TAB>audio_nframes``.
    ``iterate()`` yields (absolute_wav_path, audio_nframes).
    """
    with open(tsv, "r") as f:
        root = f.readline().rstrip()
        lines = [line.rstrip() for line in f]
        tot = len(lines)
        shard_size = math.ceil(tot / nshard)
        start, end = rank * shard_size, min((rank + 1) * shard_size, tot)
        # Bug fix: the assertion message was a plain string, so {start}/{end}
        # were never interpolated; it must be an f-string.
        assert start < end, f"start={start}, end={end}"
        logger.info(
            f"rank {rank} of {nshard}, process {end-start} "
            f"({start}-{end}) out of {tot}"
        )

        lines = lines[start:end]

        def iterate():
            for line in lines:
                _, video_path, wav_path, nsample_video, nsample_wav = line.split("\t")
                yield f"{root}/{wav_path}", int(nsample_wav)

        return iterate, len(lines)
def dump_feature(tsv_dir, split, nshard, rank, feat_dir, sample_rate=16_000):
    """Extract 39-D MFCC+delta+ddelta features for one manifest shard.

    Writes two files under ``feat_dir``:
      - ``{split}_{rank}_{nshard}.npy``: all frames concatenated, and
      - ``{split}_{rank}_{nshard}.len``: per-utterance frame counts.
    """
    reader = MfccFeatureReader(sample_rate)
    generator, num = get_path_iterator(f"{tsv_dir}/{split}.tsv", nshard, rank)
    iterator = generator()

    feat_path = f"{feat_dir}/{split}_{rank}_{nshard}.npy"
    leng_path = f"{feat_dir}/{split}_{rank}_{nshard}.len"

    os.makedirs(feat_dir, exist_ok=True)
    # NpyAppendArray appends to an existing file, so remove stale output.
    if os.path.exists(feat_path):
        os.remove(feat_path)

    feat_f = NpyAppendArray(feat_path)
    with open(leng_path, "w") as leng_f:
        for path, nsample in tqdm.tqdm(iterator, total=num):
            feat = reader.get_feats(path, nsample)
            feat_f.append(feat.cpu().numpy())
            leng_f.write(f"{len(feat)}\n")
    logger.info("finished successfully")
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("tsv_dir")
    parser.add_argument("split")
    parser.add_argument("nshard", type=int)
    parser.add_argument("rank", type=int)
    parser.add_argument("feat_dir")
    parser.add_argument("--sample_rate", type=int, default=16000)
    args = parser.parse_args()
    logger.info(args)

    # Argument names match dump_feature's signature exactly.
    dump_feature(**vars(args))
|
av_hubert/avhubert/clustering/learn_kmeans.py
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
import os
|
| 9 |
+
import sys
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
from sklearn.cluster import MiniBatchKMeans
|
| 13 |
+
|
| 14 |
+
import joblib
|
| 15 |
+
|
| 16 |
+
logging.basicConfig(
|
| 17 |
+
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
|
| 18 |
+
datefmt="%Y-%m-%d %H:%M:%S",
|
| 19 |
+
level=os.environ.get("LOGLEVEL", "INFO").upper(),
|
| 20 |
+
stream=sys.stdout,
|
| 21 |
+
)
|
| 22 |
+
logger = logging.getLogger("learn_kmeans")
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def get_km_model(
    n_clusters,
    init,
    max_iter,
    batch_size,
    tol,
    max_no_improvement,
    n_init,
    reassignment_ratio,
):
    """Build an (unfitted) sklearn MiniBatchKMeans with the given settings.

    ``compute_labels`` is disabled because only the fitted centroids are
    needed here; labels are assigned later by ``dump_km_label.py``.
    """
    return MiniBatchKMeans(
        n_clusters=n_clusters,
        init=init,
        max_iter=max_iter,
        batch_size=batch_size,
        verbose=1,
        compute_labels=False,
        tol=tol,
        max_no_improvement=max_no_improvement,
        init_size=None,
        n_init=n_init,
        reassignment_ratio=reassignment_ratio,
    )
def load_feature_shard(feat_dir, split, nshard, rank, percent):
    """Load features for one shard, optionally subsampling utterances.

    With ``percent < 0`` the whole shard is returned as a memory-mapped
    array; otherwise ``ceil(n_utts * percent)`` utterances are drawn without
    replacement (via numpy's global RNG) and their frames concatenated.
    """
    shard_stem = f"{feat_dir}/{split}_{rank}_{nshard}"
    feat_path = shard_stem + ".npy"
    leng_path = shard_stem + ".len"
    with open(leng_path, "r") as f:
        lengs = [int(line.rstrip()) for line in f]
    # Starting row of each utterance within the concatenated feature matrix.
    offsets = [0] + np.cumsum(lengs[:-1]).tolist()

    if percent < 0:
        # Use every frame; keep the memory-mapped view to avoid a copy.
        return np.load(feat_path, mmap_mode="r")

    nsample = int(np.ceil(len(lengs) * percent))
    indices = np.random.choice(len(lengs), nsample, replace=False)
    feat = np.load(feat_path, mmap_mode="r")
    sampled_feat = np.concatenate(
        [feat[offsets[i]: offsets[i] + lengs[i]] for i in indices], axis=0
    )
    logger.info(
        (
            f"sampled {nsample} utterances, {len(sampled_feat)} frames "
            f"from shard {rank}/{nshard}"
        )
    )
    return sampled_feat
def load_feature(feat_dir, split, nshard, seed, percent):
    """Concatenate (sub)sampled features from all shards of ``split``.

    ``seed`` is accepted but unused here — the caller (``learn_kmeans``)
    seeds numpy's global RNG before the per-shard sampling happens.
    """
    assert percent <= 1.0
    feat = np.concatenate(
        [
            load_feature_shard(feat_dir, split, nshard, r, percent)
            for r in range(nshard)
        ],
        axis=0,
    )
    logging.info(f"loaded feature with dimension {feat.shape}")
    return feat
def learn_kmeans(
    feat_dir,
    split,
    nshard,
    km_path,
    n_clusters,
    seed,
    percent,
    init,
    max_iter,
    batch_size,
    tol,
    n_init,
    reassignment_ratio,
    max_no_improvement,
):
    """Fit a MiniBatchKMeans model on dumped features and save it.

    Loads ``percent`` of the utterances from every shard (all of them if
    ``percent < 0``), fits the model, serializes it to ``km_path`` with
    joblib, and logs the per-frame inertia on the training features.
    """
    np.random.seed(seed)
    feat = load_feature(feat_dir, split, nshard, seed, percent)
    km_model = get_km_model(
        n_clusters,
        init,
        max_iter,
        batch_size,
        tol,
        max_no_improvement,
        n_init,
        reassignment_ratio,
    )
    km_model.fit(feat)
    joblib.dump(km_model, km_path)

    # score() returns the negative inertia of the data; normalize per frame.
    inertia = -km_model.score(feat) / len(feat)
    logger.info("total intertia: %.5f", inertia)
    logger.info("finished successfully")
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("feat_dir", type=str)
    parser.add_argument("split", type=str)
    parser.add_argument("nshard", type=int)
    parser.add_argument("km_path", type=str)
    parser.add_argument("n_clusters", type=int)
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument(
        "--percent", default=-1, type=float, help="sample a subset; -1 for all"
    )
    parser.add_argument("--init", default="k-means++")
    parser.add_argument("--max_iter", default=100, type=int)
    parser.add_argument("--batch_size", default=10000, type=int)
    parser.add_argument("--tol", default=0.0, type=float)
    parser.add_argument("--max_no_improvement", default=100, type=int)
    parser.add_argument("--n_init", default=20, type=int)
    parser.add_argument("--reassignment_ratio", default=0.0, type=float)
    args = parser.parse_args()
    # NOTE(review): uses the root logger (logging.info) rather than the
    # module logger — presumably unintentional; confirm before changing.
    logging.info(str(args))

    learn_kmeans(**vars(args))
|
av_hubert/avhubert/clustering/requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
soundfile
|
| 2 |
+
joblib
|
| 3 |
+
sklearn
|
| 4 |
+
torchaudio==0.10.1
|
| 5 |
+
npy-append-array==0.9.13
|
| 6 |
+
submitit==1.4.1
|
av_hubert/avhubert/clustering/submit_cluster.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import os, subprocess
|
| 8 |
+
import submitit
|
| 9 |
+
import argparse
|
| 10 |
+
from argparse import Namespace
|
| 11 |
+
|
| 12 |
+
def dump_av_hubert(*args, **kwargs):
    """Submitit entry point: dump AV-HuBERT features for one shard.

    Imports are done lazily inside the function so they execute on the
    worker node after submitit unpickles the job.
    """
    from dump_hubert_feature import dump_feature
    import fairseq
    import sys
    av_hubert_dir = os.path.join(os.getcwd(), '..')
    # Register the avhubert package with fairseq and make its `utils`
    # module importable on the worker.
    fairseq.utils.import_user_module(Namespace(user_dir=av_hubert_dir))
    sys.path.append(av_hubert_dir)
    import utils as custom_utils
    kwargs.update({'custom_utils': custom_utils})
    args = args[0]  # positional args arrive packed as a single list
    dump_feature(*args, **kwargs)
    return
def dump_mfcc(*args, **kwargs):
    """Submitit entry point: dump MFCC features for one shard."""
    from dump_mfcc_feature import dump_feature
    job_args = args[0]  # positional args arrive packed as a single list
    dump_feature(*job_args, **kwargs)
def run_kmeans(*args, **kwargs):
    """Submitit entry point: fit the k-means model on dumped features."""
    from learn_kmeans import learn_kmeans
    learn_kmeans(*args, **kwargs)
def apply_kmeans(*args, **kwargs):
    """Submitit entry point: assign k-means labels to one feature shard."""
    from dump_km_label import dump_label
    job_args = args[0]  # positional args arrive packed as a single list
    dump_label(*job_args, **kwargs)
def concatenate(*args, **kwargs):
    """Submitit entry point: merge per-shard label files."""
    from concat import main as concat_fn
    job_args = args[0]  # positional args arrive packed as a single list
    concat_fn(*job_args, **kwargs)
def main():
    """Run the full AV-HuBERT clustering pipeline on a SLURM cluster.

    Stages, each submitted through submitit and awaited before the next:
      1. Dump features per shard (MFCC with ``--mfcc``, otherwise AV-HuBERT
         layer activations from ``--ckpt``/``--nlayer``).
      2. Fit a mini-batch k-means model on ``--percent`` of the training
         features.
      3. Apply the model to label every shard.
      4. Concatenate the per-shard labels and write the label dictionary.
    """
    parser = argparse.ArgumentParser(description='clustering', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--tsv', type=str, help='tsv dir')
    parser.add_argument('--output', type=str, help='output dir (labels)')
    parser.add_argument('--ckpt', type=str, help='checkpoint of last iteration')
    parser.add_argument('--nlayer', type=int, default=12, help='layer index for clustering')
    parser.add_argument('--ncluster', type=int, default=500, help='number of clusters')
    parser.add_argument('--nshard', type=int, default=100, help='number of shards')
    parser.add_argument('--percent', type=float, default=0.05, help='Percentage for clustering')
    parser.add_argument('--mfcc', action='store_true', help='extracting MFCC feature')
    parser.add_argument('--slurm-partition', type=str, help='slurm partitions')
    args = parser.parse_args()
    tsv_dir = args.tsv
    output_dir = args.output
    km_dir = output_dir
    feat_dir = output_dir
    ckpt_path = args.ckpt
    nlayer = args.nlayer
    nshard = args.nshard
    n_clusters = args.ncluster
    slurm_partition = args.slurm_partition
    is_mfcc = args.mfcc
    timeout_min = 240
    # BUG FIX: was hard-coded to 0.1, silently ignoring the --percent flag
    # (whose documented default is 0.05).
    percent = args.percent
    log_folder = "log_submit/%j"
    km_path = f"{km_dir}/kmeans.mdl"
    os.makedirs(output_dir, exist_ok=True)
    ext = submitit.AutoExecutor(folder=log_folder)

    # Stage 1: dump features — one SLURM array task per training shard plus
    # a single task covering the whole validation split.
    args_array = []
    if is_mfcc:
        print("Dump MFCC feature")
        for rank in range(nshard):
            args_array.append([tsv_dir, 'train', nshard, rank, output_dir])
        args_array.append([tsv_dir, 'valid', 1, 0, output_dir])
        ext.update_parameters(timeout_min=60, slurm_partition=slurm_partition, cpus_per_task=1, slurm_array_parallelism=100)
        jobs = ext.map_array(dump_mfcc, args_array)
    else:
        print("Dump AV-Hubert feature")
        for rank in range(nshard):
            args_array.append([tsv_dir, 'train', ckpt_path, nlayer, nshard, rank, output_dir, 1600000])
        args_array.append([tsv_dir, 'valid', ckpt_path, nlayer, 1, 0, output_dir, 1600000])
        # Model feature extraction needs a GPU; MFCC does not.
        ext.update_parameters(timeout_min=60, slurm_partition=slurm_partition, cpus_per_task=1, gpus_per_node=1, slurm_array_parallelism=100)
        jobs = ext.map_array(dump_av_hubert, args_array)
    [job.result() for job in jobs]  # block until every dump job finishes

    # Stage 2: fit mini-batch k-means on `percent` of the training features.
    print("Learn K-means")
    batch_size = 20000
    ext.update_parameters(timeout_min=timeout_min, slurm_partition=slurm_partition, cpus_per_task=8, mem_gb=128)
    # Renamed from `args`/`kwargs` to avoid shadowing the parsed CLI `args`.
    km_args = [feat_dir, 'train', nshard, km_path, n_clusters]
    km_kwargs = vars(Namespace(seed=0, percent=percent, init="k-means++", max_iter=100, batch_size=batch_size, tol=0.0, n_init=20, reassignment_ratio=0.0, max_no_improvement=100))
    print(km_args, km_kwargs)
    job = ext.submit(run_kmeans, *km_args, **km_kwargs)
    job.result()

    # Stage 3: label every shard with the learned model (cheap CPU jobs).
    print("Apply K-means")
    args_array = []
    for rank in range(nshard):
        args_array.append([feat_dir, 'train', km_path, nshard, rank, output_dir])
    args_array.append([feat_dir, 'valid', km_path, 1, 0, output_dir])
    ext.update_parameters(timeout_min=10, slurm_partition=slurm_partition, cpus_per_task=1, slurm_array_parallelism=500)
    jobs = ext.map_array(apply_kmeans, args_array)
    [job.result() for job in jobs]

    # Stage 4: merge per-shard labels locally via the shell and write the
    # label dictionary (fairseq only needs the symbols; counts are dummies).
    print("Concatenate labels")
    cont = f"for rank in $(seq 0 {nshard-1}); do cat {output_dir}/train_${{rank}}_{nshard}.km; done > {output_dir}/train.km"
    print(cont)
    subprocess.call(cont, shell=True)
    cont = f"cp {output_dir}/valid*.km {output_dir}/valid.km"
    print(cont)
    subprocess.call(cont, shell=True)
    with open(f"{output_dir}/dict.km.txt", 'w') as fo:
        for i in range(n_clusters):
            fo.write(f"{i} {10000}\n")
    print(f"Please delete intermediate files to save space: rm {output_dir}/*npy")
    return
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
# Script entry point: submit the full clustering pipeline to SLURM.
if __name__ == '__main__':
    main()
|
av_hubert/avhubert/conf/av-finetune/base_noise_pt_noise_ft_30h.yaml
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _group_
|
| 2 |
+
|
| 3 |
+
common:
|
| 4 |
+
fp16: true
|
| 5 |
+
log_format: json
|
| 6 |
+
log_interval: 200
|
| 7 |
+
tensorboard_logdir: tblog
|
| 8 |
+
seed: 1337
|
| 9 |
+
user_dir: ???
|
| 10 |
+
|
| 11 |
+
checkpoint:
|
| 12 |
+
save_interval: 2
|
| 13 |
+
keep_interval_updates: 1
|
| 14 |
+
no_epoch_checkpoints: true
|
| 15 |
+
best_checkpoint_metric: accuracy
|
| 16 |
+
maximize_best_checkpoint_metric: true
|
| 17 |
+
|
| 18 |
+
distributed_training:
|
| 19 |
+
ddp_backend: c10d
|
| 20 |
+
find_unused_parameters: true
|
| 21 |
+
distributed_world_size: 8
|
| 22 |
+
distributed_port: 29671
|
| 23 |
+
nprocs_per_node: 8
|
| 24 |
+
|
| 25 |
+
task:
|
| 26 |
+
_name: av_hubert_pretraining
|
| 27 |
+
is_s2s: true
|
| 28 |
+
data: ???
|
| 29 |
+
label_dir: ???
|
| 30 |
+
tokenizer_bpe_model: ???
|
| 31 |
+
normalize: true # must be consistent with pre-training
|
| 32 |
+
labels: ["wrd"]
|
| 33 |
+
single_target: true
|
| 34 |
+
fine_tuning: true
|
| 35 |
+
stack_order_audio: 4
|
| 36 |
+
tokenizer_bpe_name: sentencepiece
|
| 37 |
+
max_sample_size: 500
|
| 38 |
+
modalities: ["video","audio"]
|
| 39 |
+
image_aug: true
|
| 40 |
+
pad_audio: true
|
| 41 |
+
random_crop: false
|
| 42 |
+
noise_prob: 0.25
|
| 43 |
+
noise_snr: 0
|
| 44 |
+
noise_wav: ???
|
| 45 |
+
|
| 46 |
+
dataset:
|
| 47 |
+
num_workers: 6
|
| 48 |
+
max_tokens: 1000
|
| 49 |
+
validate_after_updates: 0
|
| 50 |
+
validate_interval: 2
|
| 51 |
+
train_subset: train
|
| 52 |
+
valid_subset: valid
|
| 53 |
+
|
| 54 |
+
criterion:
|
| 55 |
+
_name: label_smoothed_cross_entropy
|
| 56 |
+
report_accuracy: true
|
| 57 |
+
label_smoothing: 0.1
|
| 58 |
+
|
| 59 |
+
optimization:
|
| 60 |
+
max_update: 30000
|
| 61 |
+
lr: [0.001]
|
| 62 |
+
sentence_avg: true
|
| 63 |
+
update_freq: [1]
|
| 64 |
+
|
| 65 |
+
optimizer:
|
| 66 |
+
_name: adam
|
| 67 |
+
adam_betas: (0.9,0.98)
|
| 68 |
+
adam_eps: 1e-08
|
| 69 |
+
|
| 70 |
+
lr_scheduler:
|
| 71 |
+
_name: tri_stage
|
| 72 |
+
warmup_steps: 10000
|
| 73 |
+
hold_steps: 0
|
| 74 |
+
decay_steps: 20000
|
| 75 |
+
final_lr_scale: 0.05
|
| 76 |
+
|
| 77 |
+
model:
|
| 78 |
+
_name: av_hubert_seq2seq
|
| 79 |
+
w2v_path: ???
|
| 80 |
+
apply_mask: false
|
| 81 |
+
mask_selection: static
|
| 82 |
+
mask_length: 10
|
| 83 |
+
mask_other: 0
|
| 84 |
+
mask_prob: 0.75
|
| 85 |
+
mask_channel_selection: static
|
| 86 |
+
mask_channel_length: 64
|
| 87 |
+
mask_channel_other: 0
|
| 88 |
+
mask_channel_prob: 0.5
|
| 89 |
+
layerdrop: 0.1
|
| 90 |
+
dropout: 0.0
|
| 91 |
+
activation_dropout: 0.1
|
| 92 |
+
attention_dropout: 0.0
|
| 93 |
+
feature_grad_mult: 1.0
|
| 94 |
+
decoder_layers: 6
|
| 95 |
+
decoder_dropout: 0.1
|
| 96 |
+
decoder_attention_dropout: 0.0
|
| 97 |
+
decoder_activation_dropout: 0.1
|
| 98 |
+
freeze_finetune_updates: 24000
|
| 99 |
+
share_decoder_input_output_embed: true
|
| 100 |
+
decoder_normalize_before: true
|
| 101 |
+
|
| 102 |
+
hydra:
|
| 103 |
+
job:
|
| 104 |
+
config:
|
| 105 |
+
override_dirname:
|
| 106 |
+
kv_sep: '-'
|
| 107 |
+
item_sep: '__'
|
| 108 |
+
exclude_keys:
|
| 109 |
+
- run
|
| 110 |
+
- task.data
|
| 111 |
+
- task.label_dir
|
| 112 |
+
- model.w2v_path
|
| 113 |
+
- dataset.train_subset
|
| 114 |
+
- dataset.valid_subset
|
| 115 |
+
- criterion.wer_kenlm_model
|
| 116 |
+
- criterion.wer_lexicon
|
| 117 |
+
run:
|
| 118 |
+
dir: ???
|
| 119 |
+
sweep:
|
| 120 |
+
dir: ???
|
| 121 |
+
subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
|
av_hubert/avhubert/conf/av-finetune/base_noise_pt_noise_ft_433h.yaml
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _group_
|
| 2 |
+
|
| 3 |
+
common:
|
| 4 |
+
fp16: true
|
| 5 |
+
log_format: json
|
| 6 |
+
log_interval: 200
|
| 7 |
+
tensorboard_logdir: tblog
|
| 8 |
+
seed: 1337
|
| 9 |
+
user_dir: ???
|
| 10 |
+
|
| 11 |
+
checkpoint:
|
| 12 |
+
save_interval: 2
|
| 13 |
+
keep_interval_updates: 1
|
| 14 |
+
no_epoch_checkpoints: true
|
| 15 |
+
best_checkpoint_metric: accuracy
|
| 16 |
+
maximize_best_checkpoint_metric: true
|
| 17 |
+
|
| 18 |
+
distributed_training:
|
| 19 |
+
ddp_backend: c10d
|
| 20 |
+
find_unused_parameters: true
|
| 21 |
+
distributed_world_size: 8
|
| 22 |
+
distributed_port: 29671
|
| 23 |
+
nprocs_per_node: 8
|
| 24 |
+
|
| 25 |
+
task:
|
| 26 |
+
_name: av_hubert_pretraining
|
| 27 |
+
is_s2s: true
|
| 28 |
+
data: ???
|
| 29 |
+
label_dir: ???
|
| 30 |
+
tokenizer_bpe_model: ???
|
| 31 |
+
normalize: true # must be consistent with pre-training
|
| 32 |
+
labels: ["wrd"]
|
| 33 |
+
single_target: true
|
| 34 |
+
fine_tuning: true
|
| 35 |
+
stack_order_audio: 4
|
| 36 |
+
tokenizer_bpe_name: sentencepiece
|
| 37 |
+
max_sample_size: 500
|
| 38 |
+
modalities: ["video","audio"]
|
| 39 |
+
image_aug: true
|
| 40 |
+
pad_audio: true
|
| 41 |
+
random_crop: false
|
| 42 |
+
noise_prob: 0.25
|
| 43 |
+
noise_snr: 0
|
| 44 |
+
noise_wav: ???
|
| 45 |
+
|
| 46 |
+
dataset:
|
| 47 |
+
num_workers: 6
|
| 48 |
+
max_tokens: 1000
|
| 49 |
+
validate_after_updates: 0
|
| 50 |
+
validate_interval: 2
|
| 51 |
+
train_subset: train
|
| 52 |
+
valid_subset: valid
|
| 53 |
+
|
| 54 |
+
criterion:
|
| 55 |
+
_name: label_smoothed_cross_entropy
|
| 56 |
+
report_accuracy: true
|
| 57 |
+
label_smoothing: 0.1
|
| 58 |
+
|
| 59 |
+
optimization:
|
| 60 |
+
max_update: 60000
|
| 61 |
+
lr: [0.001]
|
| 62 |
+
sentence_avg: true
|
| 63 |
+
update_freq: [1]
|
| 64 |
+
|
| 65 |
+
optimizer:
|
| 66 |
+
_name: adam
|
| 67 |
+
adam_betas: (0.9,0.98)
|
| 68 |
+
adam_eps: 1e-08
|
| 69 |
+
|
| 70 |
+
lr_scheduler:
|
| 71 |
+
_name: tri_stage
|
| 72 |
+
warmup_steps: 20000
|
| 73 |
+
hold_steps: 0
|
| 74 |
+
decay_steps: 40000
|
| 75 |
+
final_lr_scale: 0.05
|
| 76 |
+
|
| 77 |
+
model:
|
| 78 |
+
_name: av_hubert_seq2seq
|
| 79 |
+
w2v_path: ???
|
| 80 |
+
apply_mask: false
|
| 81 |
+
mask_selection: static
|
| 82 |
+
mask_length: 10
|
| 83 |
+
mask_other: 0
|
| 84 |
+
mask_prob: 0.75
|
| 85 |
+
mask_channel_selection: static
|
| 86 |
+
mask_channel_length: 64
|
| 87 |
+
mask_channel_other: 0
|
| 88 |
+
mask_channel_prob: 0.5
|
| 89 |
+
layerdrop: 0.1
|
| 90 |
+
dropout: 0.0
|
| 91 |
+
activation_dropout: 0.1
|
| 92 |
+
attention_dropout: 0.0
|
| 93 |
+
feature_grad_mult: 1.0
|
| 94 |
+
decoder_layers: 6
|
| 95 |
+
decoder_dropout: 0.1
|
| 96 |
+
decoder_attention_dropout: 0.0
|
| 97 |
+
decoder_activation_dropout: 0.1
|
| 98 |
+
freeze_finetune_updates: 48000
|
| 99 |
+
share_decoder_input_output_embed: true
|
| 100 |
+
decoder_normalize_before: true
|
| 101 |
+
|
| 102 |
+
hydra:
|
| 103 |
+
job:
|
| 104 |
+
config:
|
| 105 |
+
override_dirname:
|
| 106 |
+
kv_sep: '-'
|
| 107 |
+
item_sep: '__'
|
| 108 |
+
exclude_keys:
|
| 109 |
+
- run
|
| 110 |
+
- task.data
|
| 111 |
+
- task.label_dir
|
| 112 |
+
- model.w2v_path
|
| 113 |
+
- dataset.train_subset
|
| 114 |
+
- dataset.valid_subset
|
| 115 |
+
- criterion.wer_kenlm_model
|
| 116 |
+
- criterion.wer_lexicon
|
| 117 |
+
run:
|
| 118 |
+
dir: ???
|
| 119 |
+
sweep:
|
| 120 |
+
dir: ???
|
| 121 |
+
subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
|
av_hubert/avhubert/conf/av-finetune/large_noise_pt_noise_ft_30h.yaml
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _group_
|
| 2 |
+
|
| 3 |
+
common:
|
| 4 |
+
fp16: true
|
| 5 |
+
log_format: json
|
| 6 |
+
log_interval: 200
|
| 7 |
+
tensorboard_logdir: tblog
|
| 8 |
+
seed: 1337
|
| 9 |
+
user_dir: ???
|
| 10 |
+
|
| 11 |
+
checkpoint:
|
| 12 |
+
save_interval: 2
|
| 13 |
+
keep_interval_updates: 1
|
| 14 |
+
no_epoch_checkpoints: true
|
| 15 |
+
best_checkpoint_metric: accuracy
|
| 16 |
+
maximize_best_checkpoint_metric: true
|
| 17 |
+
|
| 18 |
+
distributed_training:
|
| 19 |
+
ddp_backend: c10d
|
| 20 |
+
find_unused_parameters: true
|
| 21 |
+
distributed_world_size: 8
|
| 22 |
+
distributed_port: 29671
|
| 23 |
+
nprocs_per_node: 8
|
| 24 |
+
|
| 25 |
+
task:
|
| 26 |
+
_name: av_hubert_pretraining
|
| 27 |
+
is_s2s: true
|
| 28 |
+
data: ???
|
| 29 |
+
label_dir: ???
|
| 30 |
+
tokenizer_bpe_model: ???
|
| 31 |
+
normalize: true # must be consistent with pre-training
|
| 32 |
+
labels: ["wrd"]
|
| 33 |
+
single_target: true
|
| 34 |
+
fine_tuning: true
|
| 35 |
+
stack_order_audio: 4
|
| 36 |
+
tokenizer_bpe_name: sentencepiece
|
| 37 |
+
max_sample_size: 500
|
| 38 |
+
modalities: ["video","audio"]
|
| 39 |
+
image_aug: true
|
| 40 |
+
pad_audio: true
|
| 41 |
+
random_crop: false
|
| 42 |
+
noise_prob: 0.25
|
| 43 |
+
noise_snr: 0
|
| 44 |
+
noise_wav: ???
|
| 45 |
+
|
| 46 |
+
dataset:
|
| 47 |
+
num_workers: 6
|
| 48 |
+
max_tokens: 1000
|
| 49 |
+
validate_after_updates: 0
|
| 50 |
+
validate_interval: 2
|
| 51 |
+
train_subset: train
|
| 52 |
+
valid_subset: valid
|
| 53 |
+
|
| 54 |
+
criterion:
|
| 55 |
+
_name: label_smoothed_cross_entropy
|
| 56 |
+
report_accuracy: true
|
| 57 |
+
label_smoothing: 0.1
|
| 58 |
+
|
| 59 |
+
optimization:
|
| 60 |
+
max_update: 18000
|
| 61 |
+
lr: [0.001]
|
| 62 |
+
sentence_avg: true
|
| 63 |
+
update_freq: [1]
|
| 64 |
+
|
| 65 |
+
optimizer:
|
| 66 |
+
_name: adam
|
| 67 |
+
adam_betas: (0.9,0.98)
|
| 68 |
+
adam_eps: 1e-08
|
| 69 |
+
|
| 70 |
+
lr_scheduler:
|
| 71 |
+
_name: tri_stage
|
| 72 |
+
warmup_steps: 6000
|
| 73 |
+
hold_steps: 0
|
| 74 |
+
decay_steps: 18000
|
| 75 |
+
final_lr_scale: 0.05
|
| 76 |
+
|
| 77 |
+
model:
|
| 78 |
+
_name: av_hubert_seq2seq
|
| 79 |
+
w2v_path: ???
|
| 80 |
+
apply_mask: false
|
| 81 |
+
mask_selection: static
|
| 82 |
+
mask_length: 10
|
| 83 |
+
mask_other: 0
|
| 84 |
+
mask_prob: 0.75
|
| 85 |
+
mask_channel_selection: static
|
| 86 |
+
mask_channel_length: 64
|
| 87 |
+
mask_channel_other: 0
|
| 88 |
+
mask_channel_prob: 0.5
|
| 89 |
+
layerdrop: 0.1
|
| 90 |
+
dropout: 0.0
|
| 91 |
+
activation_dropout: 0.1
|
| 92 |
+
attention_dropout: 0.0
|
| 93 |
+
feature_grad_mult: 1.0
|
| 94 |
+
decoder_layers: 9
|
| 95 |
+
decoder_dropout: 0.1
|
| 96 |
+
decoder_attention_dropout: 0.0
|
| 97 |
+
decoder_activation_dropout: 0.1
|
| 98 |
+
freeze_finetune_updates: 30000
|
| 99 |
+
share_decoder_input_output_embed: true
|
| 100 |
+
decoder_normalize_before: true
|
| 101 |
+
decoder_embed_dim: 1024
|
| 102 |
+
decoder_ffn_embed_dim: 4096
|
| 103 |
+
decoder_attention_heads: 8
|
| 104 |
+
|
| 105 |
+
hydra:
|
| 106 |
+
job:
|
| 107 |
+
config:
|
| 108 |
+
override_dirname:
|
| 109 |
+
kv_sep: '-'
|
| 110 |
+
item_sep: '__'
|
| 111 |
+
exclude_keys:
|
| 112 |
+
- run
|
| 113 |
+
- task.data
|
| 114 |
+
- task.label_dir
|
| 115 |
+
- model.w2v_path
|
| 116 |
+
- dataset.train_subset
|
| 117 |
+
- dataset.valid_subset
|
| 118 |
+
- criterion.wer_kenlm_model
|
| 119 |
+
- criterion.wer_lexicon
|
| 120 |
+
run:
|
| 121 |
+
dir: ???
|
| 122 |
+
sweep:
|
| 123 |
+
dir: ???
|
| 124 |
+
subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
|
av_hubert/avhubert/conf/av-finetune/large_noise_pt_noise_ft_433h.yaml
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _group_
|
| 2 |
+
|
| 3 |
+
common:
|
| 4 |
+
fp16: true
|
| 5 |
+
log_format: json
|
| 6 |
+
log_interval: 200
|
| 7 |
+
tensorboard_logdir: tblog
|
| 8 |
+
seed: 1337
|
| 9 |
+
user_dir: ???
|
| 10 |
+
|
| 11 |
+
checkpoint:
|
| 12 |
+
save_interval: 2
|
| 13 |
+
keep_interval_updates: 1
|
| 14 |
+
no_epoch_checkpoints: true
|
| 15 |
+
best_checkpoint_metric: accuracy
|
| 16 |
+
maximize_best_checkpoint_metric: true
|
| 17 |
+
|
| 18 |
+
distributed_training:
|
| 19 |
+
ddp_backend: c10d
|
| 20 |
+
find_unused_parameters: true
|
| 21 |
+
distributed_world_size: 8
|
| 22 |
+
distributed_port: 29671
|
| 23 |
+
nprocs_per_node: 8
|
| 24 |
+
|
| 25 |
+
task:
|
| 26 |
+
_name: av_hubert_pretraining
|
| 27 |
+
is_s2s: true
|
| 28 |
+
data: ???
|
| 29 |
+
label_dir: ???
|
| 30 |
+
tokenizer_bpe_model: ???
|
| 31 |
+
normalize: true # must be consistent with pre-training
|
| 32 |
+
labels: ["wrd"]
|
| 33 |
+
single_target: true
|
| 34 |
+
fine_tuning: true
|
| 35 |
+
stack_order_audio: 4
|
| 36 |
+
tokenizer_bpe_name: sentencepiece
|
| 37 |
+
max_sample_size: 500
|
| 38 |
+
modalities: ["video","audio"]
|
| 39 |
+
image_aug: true
|
| 40 |
+
pad_audio: true
|
| 41 |
+
random_crop: false
|
| 42 |
+
noise_prob: 0.25
|
| 43 |
+
noise_snr: 0
|
| 44 |
+
noise_wav: ???
|
| 45 |
+
|
| 46 |
+
dataset:
|
| 47 |
+
num_workers: 6
|
| 48 |
+
max_tokens: 1000
|
| 49 |
+
validate_after_updates: 0
|
| 50 |
+
validate_interval: 2
|
| 51 |
+
train_subset: train
|
| 52 |
+
valid_subset: valid
|
| 53 |
+
|
| 54 |
+
criterion:
|
| 55 |
+
_name: label_smoothed_cross_entropy
|
| 56 |
+
report_accuracy: true
|
| 57 |
+
label_smoothing: 0.1
|
| 58 |
+
|
| 59 |
+
optimization:
|
| 60 |
+
max_update: 60000
|
| 61 |
+
lr: [0.001]
|
| 62 |
+
sentence_avg: true
|
| 63 |
+
update_freq: [1]
|
| 64 |
+
|
| 65 |
+
optimizer:
|
| 66 |
+
_name: adam
|
| 67 |
+
adam_betas: (0.9,0.98)
|
| 68 |
+
adam_eps: 1e-08
|
| 69 |
+
|
| 70 |
+
lr_scheduler:
|
| 71 |
+
_name: tri_stage
|
| 72 |
+
warmup_steps: 20000
|
| 73 |
+
hold_steps: 0
|
| 74 |
+
decay_steps: 40000
|
| 75 |
+
final_lr_scale: 0.05
|
| 76 |
+
|
| 77 |
+
model:
|
| 78 |
+
_name: av_hubert_seq2seq
|
| 79 |
+
w2v_path: ???
|
| 80 |
+
apply_mask: false
|
| 81 |
+
mask_selection: static
|
| 82 |
+
mask_length: 10
|
| 83 |
+
mask_other: 0
|
| 84 |
+
mask_prob: 0.75
|
| 85 |
+
mask_channel_selection: static
|
| 86 |
+
mask_channel_length: 64
|
| 87 |
+
mask_channel_other: 0
|
| 88 |
+
mask_channel_prob: 0.5
|
| 89 |
+
layerdrop: 0.1
|
| 90 |
+
dropout: 0.0
|
| 91 |
+
activation_dropout: 0.1
|
| 92 |
+
attention_dropout: 0.0
|
| 93 |
+
feature_grad_mult: 1.0
|
| 94 |
+
decoder_layers: 9
|
| 95 |
+
decoder_dropout: 0.1
|
| 96 |
+
decoder_attention_dropout: 0.0
|
| 97 |
+
decoder_activation_dropout: 0.1
|
| 98 |
+
freeze_finetune_updates: 48000
|
| 99 |
+
share_decoder_input_output_embed: true
|
| 100 |
+
decoder_normalize_before: true
|
| 101 |
+
decoder_embed_dim: 1024
|
| 102 |
+
decoder_ffn_embed_dim: 4096
|
| 103 |
+
decoder_attention_heads: 8
|
| 104 |
+
|
| 105 |
+
hydra:
|
| 106 |
+
job:
|
| 107 |
+
config:
|
| 108 |
+
override_dirname:
|
| 109 |
+
kv_sep: '-'
|
| 110 |
+
item_sep: '__'
|
| 111 |
+
exclude_keys:
|
| 112 |
+
- run
|
| 113 |
+
- task.data
|
| 114 |
+
- task.label_dir
|
| 115 |
+
- model.w2v_path
|
| 116 |
+
- dataset.train_subset
|
| 117 |
+
- dataset.valid_subset
|
| 118 |
+
- criterion.wer_kenlm_model
|
| 119 |
+
- criterion.wer_lexicon
|
| 120 |
+
run:
|
| 121 |
+
dir: ???
|
| 122 |
+
sweep:
|
| 123 |
+
dir: ???
|
| 124 |
+
subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
|
av_hubert/avhubert/conf/finetune/base_lrs3_30h.yaml
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _group_
|
| 2 |
+
|
| 3 |
+
common:
|
| 4 |
+
fp16: true
|
| 5 |
+
log_format: json
|
| 6 |
+
log_interval: 200
|
| 7 |
+
tensorboard_logdir: tblog
|
| 8 |
+
seed: 1337
|
| 9 |
+
user_dir: ???
|
| 10 |
+
|
| 11 |
+
checkpoint:
|
| 12 |
+
save_interval: 2
|
| 13 |
+
keep_interval_updates: 1
|
| 14 |
+
no_epoch_checkpoints: true
|
| 15 |
+
best_checkpoint_metric: accuracy
|
| 16 |
+
maximize_best_checkpoint_metric: true
|
| 17 |
+
|
| 18 |
+
distributed_training:
|
| 19 |
+
ddp_backend: c10d
|
| 20 |
+
find_unused_parameters: true
|
| 21 |
+
distributed_world_size: 8
|
| 22 |
+
distributed_port: 29671
|
| 23 |
+
nprocs_per_node: 8
|
| 24 |
+
|
| 25 |
+
task:
|
| 26 |
+
_name: av_hubert_pretraining
|
| 27 |
+
is_s2s: true
|
| 28 |
+
data: ???
|
| 29 |
+
label_dir: ???
|
| 30 |
+
tokenizer_bpe_model: ???
|
| 31 |
+
normalize: true # must be consistent with pre-training
|
| 32 |
+
labels: ["wrd"]
|
| 33 |
+
single_target: true
|
| 34 |
+
fine_tuning: true
|
| 35 |
+
stack_order_audio: 4
|
| 36 |
+
tokenizer_bpe_name: sentencepiece
|
| 37 |
+
max_sample_size: 500
|
| 38 |
+
modalities: ["video"]
|
| 39 |
+
image_aug: true
|
| 40 |
+
pad_audio: true
|
| 41 |
+
random_crop: false
|
| 42 |
+
|
| 43 |
+
dataset:
|
| 44 |
+
num_workers: 6
|
| 45 |
+
max_tokens: 1000
|
| 46 |
+
validate_after_updates: 0
|
| 47 |
+
validate_interval: 2
|
| 48 |
+
train_subset: train
|
| 49 |
+
valid_subset: valid
|
| 50 |
+
|
| 51 |
+
criterion:
|
| 52 |
+
_name: label_smoothed_cross_entropy
|
| 53 |
+
report_accuracy: true
|
| 54 |
+
label_smoothing: 0.1
|
| 55 |
+
|
| 56 |
+
optimization:
|
| 57 |
+
max_update: 30000
|
| 58 |
+
lr: [0.001]
|
| 59 |
+
sentence_avg: true
|
| 60 |
+
update_freq: [1]
|
| 61 |
+
|
| 62 |
+
optimizer:
|
| 63 |
+
_name: adam
|
| 64 |
+
adam_betas: (0.9,0.98)
|
| 65 |
+
adam_eps: 1e-08
|
| 66 |
+
|
| 67 |
+
lr_scheduler:
|
| 68 |
+
_name: tri_stage
|
| 69 |
+
warmup_steps: 10000
|
| 70 |
+
hold_steps: 0
|
| 71 |
+
decay_steps: 20000
|
| 72 |
+
final_lr_scale: 0.05
|
| 73 |
+
|
| 74 |
+
model:
|
| 75 |
+
_name: av_hubert_seq2seq
|
| 76 |
+
w2v_path: ???
|
| 77 |
+
apply_mask: false
|
| 78 |
+
mask_selection: static
|
| 79 |
+
mask_length: 10
|
| 80 |
+
mask_other: 0
|
| 81 |
+
mask_prob: 0.75
|
| 82 |
+
mask_channel_selection: static
|
| 83 |
+
mask_channel_length: 64
|
| 84 |
+
mask_channel_other: 0
|
| 85 |
+
mask_channel_prob: 0.5
|
| 86 |
+
layerdrop: 0.1
|
| 87 |
+
dropout: 0.0
|
| 88 |
+
activation_dropout: 0.1
|
| 89 |
+
attention_dropout: 0.0
|
| 90 |
+
feature_grad_mult: 1.0
|
| 91 |
+
decoder_layers: 6
|
| 92 |
+
decoder_dropout: 0.1
|
| 93 |
+
decoder_attention_dropout: 0.0
|
| 94 |
+
decoder_activation_dropout: 0.1
|
| 95 |
+
freeze_finetune_updates: 30000
|
| 96 |
+
share_decoder_input_output_embed: true
|
| 97 |
+
decoder_normalize_before: true
|
| 98 |
+
|
| 99 |
+
hydra:
|
| 100 |
+
job:
|
| 101 |
+
config:
|
| 102 |
+
override_dirname:
|
| 103 |
+
kv_sep: '-'
|
| 104 |
+
item_sep: '__'
|
| 105 |
+
exclude_keys:
|
| 106 |
+
- run
|
| 107 |
+
- task.data
|
| 108 |
+
- task.label_dir
|
| 109 |
+
- model.w2v_path
|
| 110 |
+
- dataset.train_subset
|
| 111 |
+
- dataset.valid_subset
|
| 112 |
+
- criterion.wer_kenlm_model
|
| 113 |
+
- criterion.wer_lexicon
|
| 114 |
+
run:
|
| 115 |
+
dir: ???
|
| 116 |
+
sweep:
|
| 117 |
+
dir: ???
|
| 118 |
+
subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
|
av_hubert/avhubert/conf/finetune/base_lrs3_433h.yaml
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _group_
|
| 2 |
+
|
| 3 |
+
common:
|
| 4 |
+
fp16: true
|
| 5 |
+
log_format: json
|
| 6 |
+
log_interval: 200
|
| 7 |
+
tensorboard_logdir: tblog
|
| 8 |
+
seed: 1337
|
| 9 |
+
user_dir: ???
|
| 10 |
+
|
| 11 |
+
checkpoint:
|
| 12 |
+
save_interval: 2
|
| 13 |
+
keep_interval_updates: 1
|
| 14 |
+
no_epoch_checkpoints: true
|
| 15 |
+
best_checkpoint_metric: accuracy
|
| 16 |
+
maximize_best_checkpoint_metric: true
|
| 17 |
+
|
| 18 |
+
distributed_training:
|
| 19 |
+
ddp_backend: c10d
|
| 20 |
+
find_unused_parameters: true
|
| 21 |
+
distributed_world_size: 8
|
| 22 |
+
distributed_port: 29671
|
| 23 |
+
nprocs_per_node: 8
|
| 24 |
+
|
| 25 |
+
task:
|
| 26 |
+
_name: av_hubert_pretraining
|
| 27 |
+
is_s2s: true
|
| 28 |
+
data: ???
|
| 29 |
+
label_dir: ???
|
| 30 |
+
tokenizer_bpe_model: ???
|
| 31 |
+
normalize: true # must be consistent with pre-training
|
| 32 |
+
labels: ["wrd"]
|
| 33 |
+
single_target: true
|
| 34 |
+
fine_tuning: true
|
| 35 |
+
stack_order_audio: 4
|
| 36 |
+
tokenizer_bpe_name: sentencepiece
|
| 37 |
+
max_sample_size: 500
|
| 38 |
+
modalities: ["video"]
|
| 39 |
+
image_aug: true
|
| 40 |
+
pad_audio: true
|
| 41 |
+
random_crop: false
|
| 42 |
+
|
| 43 |
+
dataset:
|
| 44 |
+
num_workers: 6
|
| 45 |
+
max_tokens: 1000
|
| 46 |
+
validate_after_updates: 0
|
| 47 |
+
validate_interval: 2
|
| 48 |
+
train_subset: train
|
| 49 |
+
valid_subset: valid
|
| 50 |
+
|
| 51 |
+
criterion:
|
| 52 |
+
_name: label_smoothed_cross_entropy
|
| 53 |
+
report_accuracy: true
|
| 54 |
+
label_smoothing: 0.1
|
| 55 |
+
|
| 56 |
+
optimization:
|
| 57 |
+
max_update: 120000
|
| 58 |
+
lr: [0.001]
|
| 59 |
+
sentence_avg: true
|
| 60 |
+
update_freq: [1]
|
| 61 |
+
|
| 62 |
+
optimizer:
|
| 63 |
+
_name: adam
|
| 64 |
+
adam_betas: (0.9,0.98)
|
| 65 |
+
adam_eps: 1e-08
|
| 66 |
+
|
| 67 |
+
lr_scheduler:
|
| 68 |
+
_name: tri_stage
|
| 69 |
+
warmup_steps: 40000
|
| 70 |
+
hold_steps: 0
|
| 71 |
+
decay_steps: 80000
|
| 72 |
+
final_lr_scale: 0.05
|
| 73 |
+
|
| 74 |
+
model:
|
| 75 |
+
_name: av_hubert_seq2seq
|
| 76 |
+
w2v_path: ???
|
| 77 |
+
apply_mask: false
|
| 78 |
+
mask_selection: static
|
| 79 |
+
mask_length: 10
|
| 80 |
+
mask_other: 0
|
| 81 |
+
mask_prob: 0.75
|
| 82 |
+
mask_channel_selection: static
|
| 83 |
+
mask_channel_length: 64
|
| 84 |
+
mask_channel_other: 0
|
| 85 |
+
mask_channel_prob: 0.5
|
| 86 |
+
layerdrop: 0.1
|
| 87 |
+
dropout: 0.0
|
| 88 |
+
activation_dropout: 0.1
|
| 89 |
+
attention_dropout: 0.0
|
| 90 |
+
feature_grad_mult: 1.0
|
| 91 |
+
decoder_layers: 6
|
| 92 |
+
decoder_dropout: 0.1
|
| 93 |
+
decoder_attention_dropout: 0.0
|
| 94 |
+
decoder_activation_dropout: 0.1
|
| 95 |
+
freeze_finetune_updates: 60000
|
| 96 |
+
share_decoder_input_output_embed: true
|
| 97 |
+
decoder_normalize_before: true
|
| 98 |
+
|
| 99 |
+
hydra:
|
| 100 |
+
job:
|
| 101 |
+
config:
|
| 102 |
+
override_dirname:
|
| 103 |
+
kv_sep: '-'
|
| 104 |
+
item_sep: '__'
|
| 105 |
+
exclude_keys:
|
| 106 |
+
- run
|
| 107 |
+
- task.data
|
| 108 |
+
- task.label_dir
|
| 109 |
+
- model.w2v_path
|
| 110 |
+
- dataset.train_subset
|
| 111 |
+
- dataset.valid_subset
|
| 112 |
+
- criterion.wer_kenlm_model
|
| 113 |
+
- criterion.wer_lexicon
|
| 114 |
+
run:
|
| 115 |
+
dir: ???
|
| 116 |
+
sweep:
|
| 117 |
+
dir: ???
|
| 118 |
+
subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
|
av_hubert/avhubert/conf/finetune/base_vox_30h.yaml
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _group_
|
| 2 |
+
|
| 3 |
+
common:
|
| 4 |
+
fp16: true
|
| 5 |
+
log_format: json
|
| 6 |
+
log_interval: 200
|
| 7 |
+
tensorboard_logdir: tblog
|
| 8 |
+
seed: 1337
|
| 9 |
+
user_dir: ???
|
| 10 |
+
|
| 11 |
+
checkpoint:
|
| 12 |
+
save_interval: 2
|
| 13 |
+
keep_interval_updates: 1
|
| 14 |
+
no_epoch_checkpoints: true
|
| 15 |
+
best_checkpoint_metric: accuracy
|
| 16 |
+
maximize_best_checkpoint_metric: true
|
| 17 |
+
|
| 18 |
+
distributed_training:
|
| 19 |
+
ddp_backend: c10d
|
| 20 |
+
find_unused_parameters: true
|
| 21 |
+
distributed_world_size: 8
|
| 22 |
+
distributed_port: 29671
|
| 23 |
+
nprocs_per_node: 8
|
| 24 |
+
|
| 25 |
+
task:
|
| 26 |
+
_name: av_hubert_pretraining
|
| 27 |
+
is_s2s: true
|
| 28 |
+
data: ???
|
| 29 |
+
label_dir: ???
|
| 30 |
+
tokenizer_bpe_model: ???
|
| 31 |
+
normalize: true # must be consistent with pre-training
|
| 32 |
+
labels: ["wrd"]
|
| 33 |
+
single_target: true
|
| 34 |
+
fine_tuning: true
|
| 35 |
+
stack_order_audio: 4
|
| 36 |
+
tokenizer_bpe_name: sentencepiece
|
| 37 |
+
max_sample_size: 500
|
| 38 |
+
modalities: ["video"]
|
| 39 |
+
image_aug: true
|
| 40 |
+
pad_audio: true
|
| 41 |
+
random_crop: false
|
| 42 |
+
|
| 43 |
+
dataset:
|
| 44 |
+
num_workers: 6
|
| 45 |
+
max_tokens: 1000
|
| 46 |
+
validate_after_updates: 0
|
| 47 |
+
validate_interval: 2
|
| 48 |
+
train_subset: train
|
| 49 |
+
valid_subset: valid
|
| 50 |
+
|
| 51 |
+
criterion:
|
| 52 |
+
_name: label_smoothed_cross_entropy
|
| 53 |
+
report_accuracy: true
|
| 54 |
+
label_smoothing: 0.1
|
| 55 |
+
|
| 56 |
+
optimization:
|
| 57 |
+
max_update: 30000
|
| 58 |
+
lr: [0.001]
|
| 59 |
+
sentence_avg: true
|
| 60 |
+
update_freq: [1]
|
| 61 |
+
|
| 62 |
+
optimizer:
|
| 63 |
+
_name: adam
|
| 64 |
+
adam_betas: (0.9,0.98)
|
| 65 |
+
adam_eps: 1e-08
|
| 66 |
+
|
| 67 |
+
lr_scheduler:
|
| 68 |
+
_name: tri_stage
|
| 69 |
+
warmup_steps: 10000
|
| 70 |
+
hold_steps: 0
|
| 71 |
+
decay_steps: 20000
|
| 72 |
+
final_lr_scale: 0.05
|
| 73 |
+
|
| 74 |
+
model:
|
| 75 |
+
_name: av_hubert_seq2seq
|
| 76 |
+
w2v_path: ???
|
| 77 |
+
apply_mask: false
|
| 78 |
+
mask_selection: static
|
| 79 |
+
mask_length: 10
|
| 80 |
+
mask_other: 0
|
| 81 |
+
mask_prob: 0.75
|
| 82 |
+
mask_channel_selection: static
|
| 83 |
+
mask_channel_length: 64
|
| 84 |
+
mask_channel_other: 0
|
| 85 |
+
mask_channel_prob: 0.5
|
| 86 |
+
layerdrop: 0.1
|
| 87 |
+
dropout: 0.0
|
| 88 |
+
activation_dropout: 0.1
|
| 89 |
+
attention_dropout: 0.0
|
| 90 |
+
feature_grad_mult: 1.0
|
| 91 |
+
decoder_layers: 6
|
| 92 |
+
decoder_dropout: 0.1
|
| 93 |
+
decoder_attention_dropout: 0.0
|
| 94 |
+
decoder_activation_dropout: 0.1
|
| 95 |
+
freeze_finetune_updates: 24000
|
| 96 |
+
share_decoder_input_output_embed: true
|
| 97 |
+
decoder_normalize_before: true
|
| 98 |
+
|
| 99 |
+
hydra:
|
| 100 |
+
job:
|
| 101 |
+
config:
|
| 102 |
+
override_dirname:
|
| 103 |
+
kv_sep: '-'
|
| 104 |
+
item_sep: '__'
|
| 105 |
+
exclude_keys:
|
| 106 |
+
- run
|
| 107 |
+
- task.data
|
| 108 |
+
- task.label_dir
|
| 109 |
+
- model.w2v_path
|
| 110 |
+
- dataset.train_subset
|
| 111 |
+
- dataset.valid_subset
|
| 112 |
+
- criterion.wer_kenlm_model
|
| 113 |
+
- criterion.wer_lexicon
|
| 114 |
+
run:
|
| 115 |
+
dir: ???
|
| 116 |
+
sweep:
|
| 117 |
+
dir: ???
|
| 118 |
+
subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
|
av_hubert/avhubert/conf/finetune/base_vox_433h.yaml
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _group_
|
| 2 |
+
|
| 3 |
+
common:
|
| 4 |
+
fp16: true
|
| 5 |
+
log_format: json
|
| 6 |
+
log_interval: 200
|
| 7 |
+
tensorboard_logdir: tblog
|
| 8 |
+
seed: 1337
|
| 9 |
+
user_dir: ???
|
| 10 |
+
|
| 11 |
+
checkpoint:
|
| 12 |
+
save_interval: 2
|
| 13 |
+
keep_interval_updates: 1
|
| 14 |
+
no_epoch_checkpoints: true
|
| 15 |
+
best_checkpoint_metric: accuracy
|
| 16 |
+
maximize_best_checkpoint_metric: true
|
| 17 |
+
|
| 18 |
+
distributed_training:
|
| 19 |
+
ddp_backend: c10d
|
| 20 |
+
find_unused_parameters: true
|
| 21 |
+
distributed_world_size: 8
|
| 22 |
+
distributed_port: 29671
|
| 23 |
+
nprocs_per_node: 8
|
| 24 |
+
|
| 25 |
+
task:
|
| 26 |
+
_name: av_hubert_pretraining
|
| 27 |
+
is_s2s: true
|
| 28 |
+
data: ???
|
| 29 |
+
label_dir: ???
|
| 30 |
+
tokenizer_bpe_model: ???
|
| 31 |
+
normalize: true # must be consistent with pre-training
|
| 32 |
+
labels: ["wrd"]
|
| 33 |
+
single_target: true
|
| 34 |
+
fine_tuning: true
|
| 35 |
+
stack_order_audio: 4
|
| 36 |
+
tokenizer_bpe_name: sentencepiece
|
| 37 |
+
max_sample_size: 500
|
| 38 |
+
modalities: ["video"]
|
| 39 |
+
image_aug: true
|
| 40 |
+
pad_audio: true
|
| 41 |
+
random_crop: false
|
| 42 |
+
|
| 43 |
+
dataset:
|
| 44 |
+
num_workers: 6
|
| 45 |
+
max_tokens: 1000
|
| 46 |
+
validate_after_updates: 0
|
| 47 |
+
validate_interval: 2
|
| 48 |
+
train_subset: train
|
| 49 |
+
valid_subset: valid
|
| 50 |
+
|
| 51 |
+
criterion:
|
| 52 |
+
_name: label_smoothed_cross_entropy
|
| 53 |
+
report_accuracy: true
|
| 54 |
+
label_smoothing: 0.1
|
| 55 |
+
|
| 56 |
+
optimization:
|
| 57 |
+
max_update: 45000
|
| 58 |
+
lr: [0.001]
|
| 59 |
+
sentence_avg: true
|
| 60 |
+
update_freq: [1]
|
| 61 |
+
|
| 62 |
+
optimizer:
|
| 63 |
+
_name: adam
|
| 64 |
+
adam_betas: (0.9,0.98)
|
| 65 |
+
adam_eps: 1e-08
|
| 66 |
+
|
| 67 |
+
lr_scheduler:
|
| 68 |
+
_name: tri_stage
|
| 69 |
+
warmup_steps: 15000
|
| 70 |
+
hold_steps: 0
|
| 71 |
+
decay_steps: 30000
|
| 72 |
+
final_lr_scale: 0.05
|
| 73 |
+
|
| 74 |
+
model:
|
| 75 |
+
_name: av_hubert_seq2seq
|
| 76 |
+
w2v_path: ???
|
| 77 |
+
apply_mask: false
|
| 78 |
+
mask_selection: static
|
| 79 |
+
mask_length: 10
|
| 80 |
+
mask_other: 0
|
| 81 |
+
mask_prob: 0.75
|
| 82 |
+
mask_channel_selection: static
|
| 83 |
+
mask_channel_length: 64
|
| 84 |
+
mask_channel_other: 0
|
| 85 |
+
mask_channel_prob: 0.5
|
| 86 |
+
layerdrop: 0.1
|
| 87 |
+
dropout: 0.0
|
| 88 |
+
activation_dropout: 0.1
|
| 89 |
+
attention_dropout: 0.0
|
| 90 |
+
feature_grad_mult: 1.0
|
| 91 |
+
decoder_layers: 6
|
| 92 |
+
decoder_dropout: 0.1
|
| 93 |
+
decoder_attention_dropout: 0.0
|
| 94 |
+
decoder_activation_dropout: 0.1
|
| 95 |
+
freeze_finetune_updates: 22500
|
| 96 |
+
share_decoder_input_output_embed: true
|
| 97 |
+
decoder_normalize_before: true
|
| 98 |
+
|
| 99 |
+
hydra:
|
| 100 |
+
job:
|
| 101 |
+
config:
|
| 102 |
+
override_dirname:
|
| 103 |
+
kv_sep: '-'
|
| 104 |
+
item_sep: '__'
|
| 105 |
+
exclude_keys:
|
| 106 |
+
- run
|
| 107 |
+
- task.data
|
| 108 |
+
- task.label_dir
|
| 109 |
+
- model.w2v_path
|
| 110 |
+
- dataset.train_subset
|
| 111 |
+
- dataset.valid_subset
|
| 112 |
+
- criterion.wer_kenlm_model
|
| 113 |
+
- criterion.wer_lexicon
|
| 114 |
+
run:
|
| 115 |
+
dir: ???
|
| 116 |
+
sweep:
|
| 117 |
+
dir: ???
|
| 118 |
+
subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
|
av_hubert/avhubert/conf/finetune/large_lrs3_30h.yaml
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _group_
|
| 2 |
+
|
| 3 |
+
common:
|
| 4 |
+
fp16: true
|
| 5 |
+
log_format: json
|
| 6 |
+
log_interval: 200
|
| 7 |
+
tensorboard_logdir: tblog
|
| 8 |
+
seed: 1337
|
| 9 |
+
user_dir: ???
|
| 10 |
+
|
| 11 |
+
checkpoint:
|
| 12 |
+
save_interval: 2
|
| 13 |
+
keep_interval_updates: 1
|
| 14 |
+
no_epoch_checkpoints: true
|
| 15 |
+
best_checkpoint_metric: accuracy
|
| 16 |
+
maximize_best_checkpoint_metric: true
|
| 17 |
+
|
| 18 |
+
distributed_training:
|
| 19 |
+
ddp_backend: c10d
|
| 20 |
+
find_unused_parameters: true
|
| 21 |
+
distributed_world_size: 8
|
| 22 |
+
distributed_port: 29671
|
| 23 |
+
nprocs_per_node: 8
|
| 24 |
+
|
| 25 |
+
task:
|
| 26 |
+
_name: av_hubert_pretraining
|
| 27 |
+
is_s2s: true
|
| 28 |
+
data: ???
|
| 29 |
+
label_dir: ???
|
| 30 |
+
tokenizer_bpe_model: ???
|
| 31 |
+
normalize: true # must be consistent with pre-training
|
| 32 |
+
labels: ["wrd"]
|
| 33 |
+
single_target: true
|
| 34 |
+
fine_tuning: true
|
| 35 |
+
stack_order_audio: 4
|
| 36 |
+
tokenizer_bpe_name: sentencepiece
|
| 37 |
+
max_sample_size: 500
|
| 38 |
+
modalities: ["video"]
|
| 39 |
+
image_aug: true
|
| 40 |
+
pad_audio: true
|
| 41 |
+
random_crop: false
|
| 42 |
+
|
| 43 |
+
dataset:
|
| 44 |
+
num_workers: 6
|
| 45 |
+
max_tokens: 1000
|
| 46 |
+
validate_after_updates: 0
|
| 47 |
+
validate_interval: 2
|
| 48 |
+
train_subset: train
|
| 49 |
+
valid_subset: valid
|
| 50 |
+
|
| 51 |
+
criterion:
|
| 52 |
+
_name: label_smoothed_cross_entropy
|
| 53 |
+
report_accuracy: true
|
| 54 |
+
label_smoothing: 0.1
|
| 55 |
+
|
| 56 |
+
optimization:
|
| 57 |
+
max_update: 18000
|
| 58 |
+
lr: [0.001]
|
| 59 |
+
sentence_avg: true
|
| 60 |
+
update_freq: [1]
|
| 61 |
+
|
| 62 |
+
optimizer:
|
| 63 |
+
_name: adam
|
| 64 |
+
adam_betas: (0.9,0.98)
|
| 65 |
+
adam_eps: 1e-08
|
| 66 |
+
|
| 67 |
+
lr_scheduler:
|
| 68 |
+
_name: tri_stage
|
| 69 |
+
warmup_steps: 6000
|
| 70 |
+
hold_steps: 0
|
| 71 |
+
decay_steps: 12000
|
| 72 |
+
final_lr_scale: 0.05
|
| 73 |
+
|
| 74 |
+
model:
|
| 75 |
+
_name: av_hubert_seq2seq
|
| 76 |
+
w2v_path: ???
|
| 77 |
+
apply_mask: false
|
| 78 |
+
mask_selection: static
|
| 79 |
+
mask_length: 10
|
| 80 |
+
mask_other: 0
|
| 81 |
+
mask_prob: 0.75
|
| 82 |
+
mask_channel_selection: static
|
| 83 |
+
mask_channel_length: 64
|
| 84 |
+
mask_channel_other: 0
|
| 85 |
+
mask_channel_prob: 0.5
|
| 86 |
+
layerdrop: 0.1
|
| 87 |
+
dropout: 0.0
|
| 88 |
+
activation_dropout: 0.1
|
| 89 |
+
attention_dropout: 0.0
|
| 90 |
+
feature_grad_mult: 1.0
|
| 91 |
+
decoder_layers: 9
|
| 92 |
+
decoder_dropout: 0.1
|
| 93 |
+
decoder_attention_dropout: 0.0
|
| 94 |
+
decoder_activation_dropout: 0.1
|
| 95 |
+
freeze_finetune_updates: 14400
|
| 96 |
+
share_decoder_input_output_embed: true
|
| 97 |
+
decoder_normalize_before: true
|
| 98 |
+
decoder_embed_dim: 1024
|
| 99 |
+
decoder_ffn_embed_dim: 4096
|
| 100 |
+
decoder_attention_heads: 8
|
| 101 |
+
|
| 102 |
+
hydra:
|
| 103 |
+
job:
|
| 104 |
+
config:
|
| 105 |
+
override_dirname:
|
| 106 |
+
kv_sep: '-'
|
| 107 |
+
item_sep: '__'
|
| 108 |
+
exclude_keys:
|
| 109 |
+
- run
|
| 110 |
+
- task.data
|
| 111 |
+
- task.label_dir
|
| 112 |
+
- model.w2v_path
|
| 113 |
+
- dataset.train_subset
|
| 114 |
+
- dataset.valid_subset
|
| 115 |
+
- criterion.wer_kenlm_model
|
| 116 |
+
- criterion.wer_lexicon
|
| 117 |
+
run:
|
| 118 |
+
dir: ???
|
| 119 |
+
sweep:
|
| 120 |
+
dir: ???
|
| 121 |
+
subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
|
av_hubert/avhubert/conf/finetune/large_lrs3_433h.yaml
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _group_
|
| 2 |
+
|
| 3 |
+
common:
|
| 4 |
+
fp16: true
|
| 5 |
+
log_format: json
|
| 6 |
+
log_interval: 200
|
| 7 |
+
tensorboard_logdir: tblog
|
| 8 |
+
seed: 1337
|
| 9 |
+
user_dir: ???
|
| 10 |
+
|
| 11 |
+
checkpoint:
|
| 12 |
+
save_interval: 2
|
| 13 |
+
keep_interval_updates: 1
|
| 14 |
+
no_epoch_checkpoints: true
|
| 15 |
+
best_checkpoint_metric: accuracy
|
| 16 |
+
maximize_best_checkpoint_metric: true
|
| 17 |
+
|
| 18 |
+
distributed_training:
|
| 19 |
+
ddp_backend: c10d
|
| 20 |
+
find_unused_parameters: true
|
| 21 |
+
distributed_world_size: 8
|
| 22 |
+
distributed_port: 29671
|
| 23 |
+
nprocs_per_node: 8
|
| 24 |
+
|
| 25 |
+
task:
|
| 26 |
+
_name: av_hubert_pretraining
|
| 27 |
+
is_s2s: true
|
| 28 |
+
data: ???
|
| 29 |
+
label_dir: ???
|
| 30 |
+
tokenizer_bpe_model: ???
|
| 31 |
+
normalize: true # must be consistent with pre-training
|
| 32 |
+
labels: ["wrd"]
|
| 33 |
+
single_target: true
|
| 34 |
+
fine_tuning: true
|
| 35 |
+
stack_order_audio: 4
|
| 36 |
+
tokenizer_bpe_name: sentencepiece
|
| 37 |
+
max_sample_size: 500
|
| 38 |
+
modalities: ["video"]
|
| 39 |
+
image_aug: true
|
| 40 |
+
pad_audio: true
|
| 41 |
+
random_crop: false
|
| 42 |
+
|
| 43 |
+
dataset:
|
| 44 |
+
num_workers: 6
|
| 45 |
+
max_tokens: 1000
|
| 46 |
+
validate_after_updates: 0
|
| 47 |
+
validate_interval: 2
|
| 48 |
+
train_subset: train
|
| 49 |
+
valid_subset: valid
|
| 50 |
+
|
| 51 |
+
criterion:
|
| 52 |
+
_name: label_smoothed_cross_entropy
|
| 53 |
+
report_accuracy: true
|
| 54 |
+
label_smoothing: 0.1
|
| 55 |
+
|
| 56 |
+
optimization:
|
| 57 |
+
max_update: 30000
|
| 58 |
+
lr: [0.001]
|
| 59 |
+
sentence_avg: true
|
| 60 |
+
update_freq: [1]
|
| 61 |
+
|
| 62 |
+
optimizer:
|
| 63 |
+
_name: adam
|
| 64 |
+
adam_betas: (0.9,0.98)
|
| 65 |
+
adam_eps: 1e-08
|
| 66 |
+
|
| 67 |
+
lr_scheduler:
|
| 68 |
+
_name: tri_stage
|
| 69 |
+
warmup_steps: 10000
|
| 70 |
+
hold_steps: 0
|
| 71 |
+
decay_steps: 20000
|
| 72 |
+
final_lr_scale: 0.05
|
| 73 |
+
|
| 74 |
+
model:
|
| 75 |
+
_name: av_hubert_seq2seq
|
| 76 |
+
w2v_path: ???
|
| 77 |
+
apply_mask: false
|
| 78 |
+
mask_selection: static
|
| 79 |
+
mask_length: 10
|
| 80 |
+
mask_other: 0
|
| 81 |
+
mask_prob: 0.75
|
| 82 |
+
mask_channel_selection: static
|
| 83 |
+
mask_channel_length: 64
|
| 84 |
+
mask_channel_other: 0
|
| 85 |
+
mask_channel_prob: 0.5
|
| 86 |
+
layerdrop: 0.1
|
| 87 |
+
dropout: 0.0
|
| 88 |
+
activation_dropout: 0.1
|
| 89 |
+
attention_dropout: 0.0
|
| 90 |
+
feature_grad_mult: 1.0
|
| 91 |
+
decoder_layers: 9
|
| 92 |
+
decoder_dropout: 0.1
|
| 93 |
+
decoder_attention_dropout: 0.0
|
| 94 |
+
decoder_activation_dropout: 0.1
|
| 95 |
+
freeze_finetune_updates: 18000
|
| 96 |
+
share_decoder_input_output_embed: true
|
| 97 |
+
decoder_normalize_before: true
|
| 98 |
+
decoder_embed_dim: 1024
|
| 99 |
+
decoder_ffn_embed_dim: 4096
|
| 100 |
+
decoder_attention_heads: 8
|
| 101 |
+
|
| 102 |
+
hydra:
|
| 103 |
+
job:
|
| 104 |
+
config:
|
| 105 |
+
override_dirname:
|
| 106 |
+
kv_sep: '-'
|
| 107 |
+
item_sep: '__'
|
| 108 |
+
exclude_keys:
|
| 109 |
+
- run
|
| 110 |
+
- task.data
|
| 111 |
+
- task.label_dir
|
| 112 |
+
- model.w2v_path
|
| 113 |
+
- dataset.train_subset
|
| 114 |
+
- dataset.valid_subset
|
| 115 |
+
- criterion.wer_kenlm_model
|
| 116 |
+
- criterion.wer_lexicon
|
| 117 |
+
run:
|
| 118 |
+
dir: ???
|
| 119 |
+
sweep:
|
| 120 |
+
dir: ???
|
| 121 |
+
subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
|
av_hubert/avhubert/conf/finetune/large_vox_30h.yaml
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _group_
|
| 2 |
+
|
| 3 |
+
common:
|
| 4 |
+
fp16: true
|
| 5 |
+
log_format: json
|
| 6 |
+
log_interval: 200
|
| 7 |
+
tensorboard_logdir: tblog
|
| 8 |
+
seed: 1337
|
| 9 |
+
user_dir: ???
|
| 10 |
+
|
| 11 |
+
checkpoint:
|
| 12 |
+
save_interval: 2
|
| 13 |
+
keep_interval_updates: 1
|
| 14 |
+
no_epoch_checkpoints: true
|
| 15 |
+
best_checkpoint_metric: accuracy
|
| 16 |
+
maximize_best_checkpoint_metric: true
|
| 17 |
+
|
| 18 |
+
distributed_training:
|
| 19 |
+
ddp_backend: c10d
|
| 20 |
+
find_unused_parameters: true
|
| 21 |
+
distributed_world_size: 8
|
| 22 |
+
distributed_port: 29671
|
| 23 |
+
nprocs_per_node: 8
|
| 24 |
+
|
| 25 |
+
task:
|
| 26 |
+
_name: av_hubert_pretraining
|
| 27 |
+
is_s2s: true
|
| 28 |
+
data: ???
|
| 29 |
+
label_dir: ???
|
| 30 |
+
tokenizer_bpe_model: ???
|
| 31 |
+
normalize: true # must be consistent with pre-training
|
| 32 |
+
labels: ["wrd"]
|
| 33 |
+
single_target: true
|
| 34 |
+
fine_tuning: true
|
| 35 |
+
stack_order_audio: 4
|
| 36 |
+
tokenizer_bpe_name: sentencepiece
|
| 37 |
+
max_sample_size: 500
|
| 38 |
+
modalities: ["video"]
|
| 39 |
+
image_aug: true
|
| 40 |
+
pad_audio: true
|
| 41 |
+
random_crop: false
|
| 42 |
+
|
| 43 |
+
dataset:
|
| 44 |
+
num_workers: 6
|
| 45 |
+
max_tokens: 1000
|
| 46 |
+
validate_after_updates: 0
|
| 47 |
+
validate_interval: 2
|
| 48 |
+
train_subset: train
|
| 49 |
+
valid_subset: valid
|
| 50 |
+
|
| 51 |
+
criterion:
|
| 52 |
+
_name: label_smoothed_cross_entropy
|
| 53 |
+
report_accuracy: true
|
| 54 |
+
label_smoothing: 0.1
|
| 55 |
+
|
| 56 |
+
optimization:
|
| 57 |
+
max_update: 30000
|
| 58 |
+
lr: [0.001]
|
| 59 |
+
sentence_avg: true
|
| 60 |
+
update_freq: [1]
|
| 61 |
+
|
| 62 |
+
optimizer:
|
| 63 |
+
_name: adam
|
| 64 |
+
adam_betas: (0.9,0.98)
|
| 65 |
+
adam_eps: 1e-08
|
| 66 |
+
|
| 67 |
+
lr_scheduler:
|
| 68 |
+
_name: tri_stage
|
| 69 |
+
warmup_steps: 10000
|
| 70 |
+
hold_steps: 0
|
| 71 |
+
decay_steps: 20000
|
| 72 |
+
final_lr_scale: 0.05
|
| 73 |
+
|
| 74 |
+
model:
|
| 75 |
+
_name: av_hubert_seq2seq
|
| 76 |
+
w2v_path: ???
|
| 77 |
+
apply_mask: false
|
| 78 |
+
mask_selection: static
|
| 79 |
+
mask_length: 10
|
| 80 |
+
mask_other: 0
|
| 81 |
+
mask_prob: 0.75
|
| 82 |
+
mask_channel_selection: static
|
| 83 |
+
mask_channel_length: 64
|
| 84 |
+
mask_channel_other: 0
|
| 85 |
+
mask_channel_prob: 0.5
|
| 86 |
+
layerdrop: 0.1
|
| 87 |
+
dropout: 0.0
|
| 88 |
+
activation_dropout: 0.1
|
| 89 |
+
attention_dropout: 0.0
|
| 90 |
+
feature_grad_mult: 1.0
|
| 91 |
+
decoder_layers: 9
|
| 92 |
+
decoder_dropout: 0.1
|
| 93 |
+
decoder_attention_dropout: 0.0
|
| 94 |
+
decoder_activation_dropout: 0.1
|
| 95 |
+
freeze_finetune_updates: 30000
|
| 96 |
+
share_decoder_input_output_embed: true
|
| 97 |
+
decoder_normalize_before: true
|
| 98 |
+
decoder_embed_dim: 1024
|
| 99 |
+
decoder_ffn_embed_dim: 4096
|
| 100 |
+
decoder_attention_heads: 8
|
| 101 |
+
|
| 102 |
+
hydra:
|
| 103 |
+
job:
|
| 104 |
+
config:
|
| 105 |
+
override_dirname:
|
| 106 |
+
kv_sep: '-'
|
| 107 |
+
item_sep: '__'
|
| 108 |
+
exclude_keys:
|
| 109 |
+
- run
|
| 110 |
+
- task.data
|
| 111 |
+
- task.label_dir
|
| 112 |
+
- model.w2v_path
|
| 113 |
+
- dataset.train_subset
|
| 114 |
+
- dataset.valid_subset
|
| 115 |
+
- criterion.wer_kenlm_model
|
| 116 |
+
- criterion.wer_lexicon
|
| 117 |
+
run:
|
| 118 |
+
dir: ???
|
| 119 |
+
sweep:
|
| 120 |
+
dir: ???
|
| 121 |
+
subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
|
av_hubert/avhubert/conf/finetune/large_vox_433h.yaml
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _group_
|
| 2 |
+
|
| 3 |
+
common:
|
| 4 |
+
fp16: true
|
| 5 |
+
log_format: json
|
| 6 |
+
log_interval: 200
|
| 7 |
+
tensorboard_logdir: tblog
|
| 8 |
+
seed: 1337
|
| 9 |
+
user_dir: ???
|
| 10 |
+
|
| 11 |
+
checkpoint:
|
| 12 |
+
save_interval: 2
|
| 13 |
+
keep_interval_updates: 1
|
| 14 |
+
no_epoch_checkpoints: true
|
| 15 |
+
best_checkpoint_metric: accuracy
|
| 16 |
+
maximize_best_checkpoint_metric: true
|
| 17 |
+
|
| 18 |
+
distributed_training:
|
| 19 |
+
ddp_backend: c10d
|
| 20 |
+
find_unused_parameters: true
|
| 21 |
+
distributed_world_size: 8
|
| 22 |
+
distributed_port: 29671
|
| 23 |
+
nprocs_per_node: 8
|
| 24 |
+
|
| 25 |
+
task:
|
| 26 |
+
_name: av_hubert_pretraining
|
| 27 |
+
is_s2s: true
|
| 28 |
+
data: ???
|
| 29 |
+
label_dir: ???
|
| 30 |
+
tokenizer_bpe_model: ???
|
| 31 |
+
normalize: true # must be consistent with pre-training
|
| 32 |
+
labels: ["wrd"]
|
| 33 |
+
single_target: true
|
| 34 |
+
fine_tuning: true
|
| 35 |
+
stack_order_audio: 4
|
| 36 |
+
tokenizer_bpe_name: sentencepiece
|
| 37 |
+
max_sample_size: 500
|
| 38 |
+
modalities: ["video"]
|
| 39 |
+
image_aug: true
|
| 40 |
+
pad_audio: true
|
| 41 |
+
random_crop: false
|
| 42 |
+
|
| 43 |
+
dataset:
|
| 44 |
+
num_workers: 6
|
| 45 |
+
max_tokens: 1000
|
| 46 |
+
validate_after_updates: 0
|
| 47 |
+
validate_interval: 2
|
| 48 |
+
train_subset: train
|
| 49 |
+
valid_subset: valid
|
| 50 |
+
|
| 51 |
+
criterion:
|
| 52 |
+
_name: label_smoothed_cross_entropy
|
| 53 |
+
report_accuracy: true
|
| 54 |
+
label_smoothing: 0.1
|
| 55 |
+
|
| 56 |
+
optimization:
|
| 57 |
+
max_update: 30000
|
| 58 |
+
lr: [0.001]
|
| 59 |
+
sentence_avg: true
|
| 60 |
+
update_freq: [1]
|
| 61 |
+
|
| 62 |
+
optimizer:
|
| 63 |
+
_name: adam
|
| 64 |
+
adam_betas: (0.9,0.98)
|
| 65 |
+
adam_eps: 1e-08
|
| 66 |
+
|
| 67 |
+
lr_scheduler:
|
| 68 |
+
_name: tri_stage
|
| 69 |
+
warmup_steps: 10000
|
| 70 |
+
hold_steps: 0
|
| 71 |
+
decay_steps: 20000
|
| 72 |
+
final_lr_scale: 0.05
|
| 73 |
+
|
| 74 |
+
model:
|
| 75 |
+
_name: av_hubert_seq2seq
|
| 76 |
+
w2v_path: ???
|
| 77 |
+
apply_mask: false
|
| 78 |
+
mask_selection: static
|
| 79 |
+
mask_length: 10
|
| 80 |
+
mask_other: 0
|
| 81 |
+
mask_prob: 0.75
|
| 82 |
+
mask_channel_selection: static
|
| 83 |
+
mask_channel_length: 64
|
| 84 |
+
mask_channel_other: 0
|
| 85 |
+
mask_channel_prob: 0.5
|
| 86 |
+
layerdrop: 0.1
|
| 87 |
+
dropout: 0.0
|
| 88 |
+
activation_dropout: 0.1
|
| 89 |
+
attention_dropout: 0.0
|
| 90 |
+
feature_grad_mult: 1.0
|
| 91 |
+
decoder_layers: 9
|
| 92 |
+
decoder_dropout: 0.1
|
| 93 |
+
decoder_attention_dropout: 0.0
|
| 94 |
+
decoder_activation_dropout: 0.1
|
| 95 |
+
freeze_finetune_updates: 30000
|
| 96 |
+
share_decoder_input_output_embed: true
|
| 97 |
+
decoder_normalize_before: true
|
| 98 |
+
decoder_embed_dim: 1024
|
| 99 |
+
decoder_ffn_embed_dim: 4096
|
| 100 |
+
decoder_attention_heads: 8
|
| 101 |
+
|
| 102 |
+
hydra:
|
| 103 |
+
job:
|
| 104 |
+
config:
|
| 105 |
+
override_dirname:
|
| 106 |
+
kv_sep: '-'
|
| 107 |
+
item_sep: '__'
|
| 108 |
+
exclude_keys:
|
| 109 |
+
- run
|
| 110 |
+
- task.data
|
| 111 |
+
- task.label_dir
|
| 112 |
+
- model.w2v_path
|
| 113 |
+
- dataset.train_subset
|
| 114 |
+
- dataset.valid_subset
|
| 115 |
+
- criterion.wer_kenlm_model
|
| 116 |
+
- criterion.wer_lexicon
|
| 117 |
+
run:
|
| 118 |
+
dir: ???
|
| 119 |
+
sweep:
|
| 120 |
+
dir: ???
|
| 121 |
+
subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
|
av_hubert/avhubert/conf/finetune/self_large_vox_30h.yaml
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _group_
|
| 2 |
+
|
| 3 |
+
common:
|
| 4 |
+
fp16: true
|
| 5 |
+
log_format: json
|
| 6 |
+
log_interval: 200
|
| 7 |
+
tensorboard_logdir: tblog
|
| 8 |
+
seed: 1337
|
| 9 |
+
user_dir: ???
|
| 10 |
+
|
| 11 |
+
checkpoint:
|
| 12 |
+
save_interval: 2
|
| 13 |
+
keep_interval_updates: 1
|
| 14 |
+
no_epoch_checkpoints: true
|
| 15 |
+
best_checkpoint_metric: accuracy
|
| 16 |
+
maximize_best_checkpoint_metric: true
|
| 17 |
+
|
| 18 |
+
distributed_training:
|
| 19 |
+
ddp_backend: c10d
|
| 20 |
+
find_unused_parameters: true
|
| 21 |
+
distributed_world_size: 32
|
| 22 |
+
distributed_port: 29671
|
| 23 |
+
nprocs_per_node: 8
|
| 24 |
+
|
| 25 |
+
task:
|
| 26 |
+
_name: av_hubert_pretraining
|
| 27 |
+
is_s2s: true
|
| 28 |
+
data: ???
|
| 29 |
+
label_dir: ???
|
| 30 |
+
tokenizer_bpe_model: ???
|
| 31 |
+
normalize: true # must be consistent with pre-training
|
| 32 |
+
labels: ["wrd"]
|
| 33 |
+
single_target: true
|
| 34 |
+
fine_tuning: true
|
| 35 |
+
stack_order_audio: 4
|
| 36 |
+
tokenizer_bpe_name: sentencepiece
|
| 37 |
+
max_sample_size: 500
|
| 38 |
+
modalities: ["video"]
|
| 39 |
+
image_aug: true
|
| 40 |
+
pad_audio: true
|
| 41 |
+
random_crop: false
|
| 42 |
+
|
| 43 |
+
dataset:
|
| 44 |
+
num_workers: 6
|
| 45 |
+
max_tokens: 1000
|
| 46 |
+
validate_after_updates: 0
|
| 47 |
+
validate_interval: 2
|
| 48 |
+
train_subset: train
|
| 49 |
+
valid_subset: valid
|
| 50 |
+
|
| 51 |
+
criterion:
|
| 52 |
+
_name: label_smoothed_cross_entropy
|
| 53 |
+
report_accuracy: true
|
| 54 |
+
label_smoothing: 0.1
|
| 55 |
+
|
| 56 |
+
optimization:
|
| 57 |
+
max_update: 100000
|
| 58 |
+
lr: [0.001]
|
| 59 |
+
sentence_avg: true
|
| 60 |
+
update_freq: [1]
|
| 61 |
+
|
| 62 |
+
optimizer:
|
| 63 |
+
_name: adam
|
| 64 |
+
adam_betas: (0.9,0.98)
|
| 65 |
+
adam_eps: 1e-08
|
| 66 |
+
|
| 67 |
+
lr_scheduler:
|
| 68 |
+
_name: tri_stage
|
| 69 |
+
warmup_steps: 10000
|
| 70 |
+
hold_steps: 0
|
| 71 |
+
decay_steps: 90000
|
| 72 |
+
final_lr_scale: 0.05
|
| 73 |
+
|
| 74 |
+
model:
|
| 75 |
+
_name: av_hubert_seq2seq
|
| 76 |
+
w2v_path: ???
|
| 77 |
+
apply_mask: false
|
| 78 |
+
mask_selection: static
|
| 79 |
+
mask_length: 10
|
| 80 |
+
mask_other: 0
|
| 81 |
+
mask_prob: 0.75
|
| 82 |
+
mask_channel_selection: static
|
| 83 |
+
mask_channel_length: 64
|
| 84 |
+
mask_channel_other: 0
|
| 85 |
+
mask_channel_prob: 0.5
|
| 86 |
+
layerdrop: 0.1
|
| 87 |
+
dropout: 0.0
|
| 88 |
+
activation_dropout: 0.1
|
| 89 |
+
attention_dropout: 0.0
|
| 90 |
+
feature_grad_mult: 1.0
|
| 91 |
+
decoder_layers: 9
|
| 92 |
+
decoder_dropout: 0.1
|
| 93 |
+
decoder_attention_dropout: 0.0
|
| 94 |
+
decoder_activation_dropout: 0.1
|
| 95 |
+
freeze_finetune_updates: 80000
|
| 96 |
+
share_decoder_input_output_embed: true
|
| 97 |
+
decoder_normalize_before: true
|
| 98 |
+
decoder_embed_dim: 1024
|
| 99 |
+
decoder_ffn_embed_dim: 4096
|
| 100 |
+
decoder_attention_heads: 8
|
| 101 |
+
|
| 102 |
+
hydra:
|
| 103 |
+
job:
|
| 104 |
+
config:
|
| 105 |
+
override_dirname:
|
| 106 |
+
kv_sep: '-'
|
| 107 |
+
item_sep: '__'
|
| 108 |
+
exclude_keys:
|
| 109 |
+
- run
|
| 110 |
+
- task.data
|
| 111 |
+
- task.label_dir
|
| 112 |
+
- model.w2v_path
|
| 113 |
+
- dataset.train_subset
|
| 114 |
+
- dataset.valid_subset
|
| 115 |
+
- criterion.wer_kenlm_model
|
| 116 |
+
- criterion.wer_lexicon
|
| 117 |
+
run:
|
| 118 |
+
dir: ???
|
| 119 |
+
sweep:
|
| 120 |
+
dir: ???
|
| 121 |
+
subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
|
av_hubert/avhubert/conf/finetune/self_large_vox_433h.yaml
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _group_
|
| 2 |
+
|
| 3 |
+
common:
|
| 4 |
+
fp16: true
|
| 5 |
+
log_format: json
|
| 6 |
+
log_interval: 200
|
| 7 |
+
tensorboard_logdir: tblog
|
| 8 |
+
seed: 1337
|
| 9 |
+
user_dir: ???
|
| 10 |
+
|
| 11 |
+
checkpoint:
|
| 12 |
+
save_interval: 2
|
| 13 |
+
keep_interval_updates: 1
|
| 14 |
+
no_epoch_checkpoints: true
|
| 15 |
+
best_checkpoint_metric: accuracy
|
| 16 |
+
maximize_best_checkpoint_metric: true
|
| 17 |
+
|
| 18 |
+
distributed_training:
|
| 19 |
+
ddp_backend: c10d
|
| 20 |
+
find_unused_parameters: true
|
| 21 |
+
distributed_world_size: 32
|
| 22 |
+
distributed_port: 29671
|
| 23 |
+
nprocs_per_node: 8
|
| 24 |
+
|
| 25 |
+
task:
|
| 26 |
+
_name: av_hubert_pretraining
|
| 27 |
+
is_s2s: true
|
| 28 |
+
data: ???
|
| 29 |
+
label_dir: ???
|
| 30 |
+
tokenizer_bpe_model: ???
|
| 31 |
+
normalize: true # must be consistent with pre-training
|
| 32 |
+
labels: ["wrd"]
|
| 33 |
+
single_target: true
|
| 34 |
+
fine_tuning: true
|
| 35 |
+
stack_order_audio: 4
|
| 36 |
+
tokenizer_bpe_name: sentencepiece
|
| 37 |
+
max_sample_size: 500
|
| 38 |
+
modalities: ["video"]
|
| 39 |
+
image_aug: true
|
| 40 |
+
pad_audio: true
|
| 41 |
+
random_crop: false
|
| 42 |
+
|
| 43 |
+
dataset:
|
| 44 |
+
num_workers: 6
|
| 45 |
+
max_tokens: 1000
|
| 46 |
+
validate_after_updates: 0
|
| 47 |
+
validate_interval: 2
|
| 48 |
+
train_subset: train
|
| 49 |
+
valid_subset: valid
|
| 50 |
+
|
| 51 |
+
criterion:
|
| 52 |
+
_name: label_smoothed_cross_entropy
|
| 53 |
+
report_accuracy: true
|
| 54 |
+
label_smoothing: 0.1
|
| 55 |
+
|
| 56 |
+
optimization:
|
| 57 |
+
max_update: 100000
|
| 58 |
+
lr: [0.001]
|
| 59 |
+
sentence_avg: true
|
| 60 |
+
update_freq: [1]
|
| 61 |
+
|
| 62 |
+
optimizer:
|
| 63 |
+
_name: adam
|
| 64 |
+
adam_betas: (0.9,0.98)
|
| 65 |
+
adam_eps: 1e-08
|
| 66 |
+
|
| 67 |
+
lr_scheduler:
|
| 68 |
+
_name: tri_stage
|
| 69 |
+
warmup_steps: 10000
|
| 70 |
+
hold_steps: 0
|
| 71 |
+
decay_steps: 90000
|
| 72 |
+
final_lr_scale: 0.05
|
| 73 |
+
|
| 74 |
+
model:
|
| 75 |
+
_name: av_hubert_seq2seq
|
| 76 |
+
w2v_path: ???
|
| 77 |
+
apply_mask: false
|
| 78 |
+
mask_selection: static
|
| 79 |
+
mask_length: 10
|
| 80 |
+
mask_other: 0
|
| 81 |
+
mask_prob: 0.75
|
| 82 |
+
mask_channel_selection: static
|
| 83 |
+
mask_channel_length: 64
|
| 84 |
+
mask_channel_other: 0
|
| 85 |
+
mask_channel_prob: 0.5
|
| 86 |
+
layerdrop: 0.1
|
| 87 |
+
dropout: 0.0
|
| 88 |
+
activation_dropout: 0.1
|
| 89 |
+
attention_dropout: 0.0
|
| 90 |
+
feature_grad_mult: 1.0
|
| 91 |
+
decoder_layers: 9
|
| 92 |
+
decoder_dropout: 0.1
|
| 93 |
+
decoder_attention_dropout: 0.0
|
| 94 |
+
decoder_activation_dropout: 0.1
|
| 95 |
+
freeze_finetune_updates: 80000
|
| 96 |
+
share_decoder_input_output_embed: true
|
| 97 |
+
decoder_normalize_before: true
|
| 98 |
+
decoder_embed_dim: 1024
|
| 99 |
+
decoder_ffn_embed_dim: 4096
|
| 100 |
+
decoder_attention_heads: 8
|
| 101 |
+
|
| 102 |
+
hydra:
|
| 103 |
+
job:
|
| 104 |
+
config:
|
| 105 |
+
override_dirname:
|
| 106 |
+
kv_sep: '-'
|
| 107 |
+
item_sep: '__'
|
| 108 |
+
exclude_keys:
|
| 109 |
+
- run
|
| 110 |
+
- task.data
|
| 111 |
+
- task.label_dir
|
| 112 |
+
- model.w2v_path
|
| 113 |
+
- dataset.train_subset
|
| 114 |
+
- dataset.valid_subset
|
| 115 |
+
- criterion.wer_kenlm_model
|
| 116 |
+
- criterion.wer_lexicon
|
| 117 |
+
run:
|
| 118 |
+
dir: ???
|
| 119 |
+
sweep:
|
| 120 |
+
dir: ???
|
| 121 |
+
subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
|
av_hubert/avhubert/conf/pretrain/base_lrs3_iter1.yaml
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _group_
|
| 2 |
+
|
| 3 |
+
common:
|
| 4 |
+
fp16: true
|
| 5 |
+
log_format: json
|
| 6 |
+
log_interval: 200
|
| 7 |
+
seed: 1337
|
| 8 |
+
user_dir: ???
|
| 9 |
+
empty_cache_freq: 10000
|
| 10 |
+
|
| 11 |
+
checkpoint:
|
| 12 |
+
save_interval_updates: 25000
|
| 13 |
+
keep_interval_updates: 1
|
| 14 |
+
no_epoch_checkpoints: true
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
distributed_training:
|
| 18 |
+
ddp_backend: no_c10d
|
| 19 |
+
distributed_backend: 'nccl'
|
| 20 |
+
distributed_world_size: 32
|
| 21 |
+
distributed_port: 29671
|
| 22 |
+
nprocs_per_node: 8
|
| 23 |
+
|
| 24 |
+
task:
|
| 25 |
+
_name: av_hubert_pretraining
|
| 26 |
+
data: ???
|
| 27 |
+
label_dir: ???
|
| 28 |
+
labels: ["mfcc"]
|
| 29 |
+
label_rate: ${model.label_rate}
|
| 30 |
+
sample_rate: 25
|
| 31 |
+
max_sample_size: 500
|
| 32 |
+
min_sample_size: 5
|
| 33 |
+
pad_audio: true
|
| 34 |
+
random_crop: false
|
| 35 |
+
normalize: true
|
| 36 |
+
stack_order_audio: 4
|
| 37 |
+
# stack_order: 1
|
| 38 |
+
input_modality: image
|
| 39 |
+
image_aug: true
|
| 40 |
+
|
| 41 |
+
dataset:
|
| 42 |
+
num_workers: 6
|
| 43 |
+
max_tokens: 1000
|
| 44 |
+
skip_invalid_size_inputs_valid_test: true
|
| 45 |
+
validate_interval: 5
|
| 46 |
+
validate_interval_updates: 10000
|
| 47 |
+
|
| 48 |
+
criterion:
|
| 49 |
+
_name: av_hubert
|
| 50 |
+
pred_masked_weight: 1.0
|
| 51 |
+
pred_nomask_weight: 0.0
|
| 52 |
+
loss_weights: [10,]
|
| 53 |
+
|
| 54 |
+
optimization:
|
| 55 |
+
max_update: 400000
|
| 56 |
+
lr: [0.0005]
|
| 57 |
+
clip_norm: 10.0
|
| 58 |
+
|
| 59 |
+
optimizer:
|
| 60 |
+
_name: adam
|
| 61 |
+
adam_betas: (0.9,0.98)
|
| 62 |
+
adam_eps: 1e-06
|
| 63 |
+
weight_decay: 0.01
|
| 64 |
+
|
| 65 |
+
lr_scheduler:
|
| 66 |
+
_name: polynomial_decay
|
| 67 |
+
warmup_updates: 32000
|
| 68 |
+
|
| 69 |
+
model:
|
| 70 |
+
_name: av_hubert
|
| 71 |
+
label_rate: 100
|
| 72 |
+
skip_masked: false
|
| 73 |
+
skip_nomask: false
|
| 74 |
+
modality_dropout: 0
|
| 75 |
+
audio_dropout: 0.5
|
| 76 |
+
modality_fuse: concat
|
| 77 |
+
selection_type: same_seq
|
| 78 |
+
masking_type: feature
|
| 79 |
+
mask_prob_image: 0.8
|
| 80 |
+
mask_length_image: 10
|
| 81 |
+
mask_prob_audio: 0.8
|
| 82 |
+
mask_length_audio: 10
|
| 83 |
+
extractor_mode: default
|
| 84 |
+
# conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
|
| 85 |
+
final_dim: 256
|
| 86 |
+
encoder_layerdrop: 0.05
|
| 87 |
+
dropout_input: 0.1
|
| 88 |
+
dropout_features: 0.1
|
| 89 |
+
dropout: 0.1
|
| 90 |
+
attention_dropout: 0.1
|
| 91 |
+
feature_grad_mult: 0.1
|
| 92 |
+
untie_final_proj: true
|
| 93 |
+
activation_dropout: 0.0
|
| 94 |
+
wav_input: false
|
| 95 |
+
layer_norm_first: true
|
| 96 |
+
audio_feat_dim: 104
|
| 97 |
+
|
| 98 |
+
hydra:
|
| 99 |
+
job:
|
| 100 |
+
config:
|
| 101 |
+
override_dirname:
|
| 102 |
+
kv_sep: '-'
|
| 103 |
+
item_sep: '__'
|
| 104 |
+
exclude_keys:
|
| 105 |
+
- run
|
| 106 |
+
- task.data
|
| 107 |
+
- task.label_dir
|
| 108 |
+
run:
|
| 109 |
+
dir: ???
|
| 110 |
+
sweep:
|
| 111 |
+
dir: ???
|
| 112 |
+
subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
|
av_hubert/avhubert/conf/pretrain/base_lrs3_iter2.yaml
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _group_
|
| 2 |
+
|
| 3 |
+
common:
|
| 4 |
+
fp16: true
|
| 5 |
+
log_format: json
|
| 6 |
+
log_interval: 200
|
| 7 |
+
seed: 1337
|
| 8 |
+
user_dir: ???
|
| 9 |
+
empty_cache_freq: 10000
|
| 10 |
+
|
| 11 |
+
checkpoint:
|
| 12 |
+
save_interval_updates: 25000
|
| 13 |
+
keep_interval_updates: 1
|
| 14 |
+
no_epoch_checkpoints: true
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
distributed_training:
|
| 18 |
+
ddp_backend: no_c10d
|
| 19 |
+
distributed_backend: 'nccl'
|
| 20 |
+
distributed_world_size: 32
|
| 21 |
+
distributed_port: 29671
|
| 22 |
+
nprocs_per_node: 8
|
| 23 |
+
|
| 24 |
+
task:
|
| 25 |
+
_name: av_hubert_pretraining
|
| 26 |
+
data: ???
|
| 27 |
+
label_dir: ???
|
| 28 |
+
labels: ["mfcc"]
|
| 29 |
+
label_rate: ${model.label_rate}
|
| 30 |
+
sample_rate: 25
|
| 31 |
+
max_sample_size: 500
|
| 32 |
+
min_sample_size: 5
|
| 33 |
+
pad_audio: true
|
| 34 |
+
random_crop: false
|
| 35 |
+
normalize: true
|
| 36 |
+
stack_order_audio: 4
|
| 37 |
+
# stack_order: 1
|
| 38 |
+
input_modality: image
|
| 39 |
+
image_aug: true
|
| 40 |
+
|
| 41 |
+
dataset:
|
| 42 |
+
num_workers: 6
|
| 43 |
+
max_tokens: 1000
|
| 44 |
+
skip_invalid_size_inputs_valid_test: true
|
| 45 |
+
validate_interval: 5
|
| 46 |
+
validate_interval_updates: 10000
|
| 47 |
+
|
| 48 |
+
criterion:
|
| 49 |
+
_name: av_hubert
|
| 50 |
+
pred_masked_weight: 1.0
|
| 51 |
+
pred_nomask_weight: 0.0
|
| 52 |
+
loss_weights: [10,]
|
| 53 |
+
|
| 54 |
+
optimization:
|
| 55 |
+
max_update: 400000
|
| 56 |
+
lr: [0.0005]
|
| 57 |
+
clip_norm: 10.0
|
| 58 |
+
|
| 59 |
+
optimizer:
|
| 60 |
+
_name: adam
|
| 61 |
+
adam_betas: (0.9,0.98)
|
| 62 |
+
adam_eps: 1e-06
|
| 63 |
+
weight_decay: 0.01
|
| 64 |
+
|
| 65 |
+
lr_scheduler:
|
| 66 |
+
_name: polynomial_decay
|
| 67 |
+
warmup_updates: 32000
|
| 68 |
+
|
| 69 |
+
model:
|
| 70 |
+
_name: av_hubert
|
| 71 |
+
label_rate: 25
|
| 72 |
+
skip_masked: false
|
| 73 |
+
skip_nomask: false
|
| 74 |
+
modality_dropout: 0
|
| 75 |
+
audio_dropout: 0.5
|
| 76 |
+
modality_fuse: concat
|
| 77 |
+
selection_type: same_seq
|
| 78 |
+
masking_type: feature
|
| 79 |
+
mask_prob_image: 0.8
|
| 80 |
+
mask_length_image: 10
|
| 81 |
+
mask_prob_audio: 0.8
|
| 82 |
+
mask_length_audio: 10
|
| 83 |
+
extractor_mode: default
|
| 84 |
+
# conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
|
| 85 |
+
final_dim: 256
|
| 86 |
+
encoder_layerdrop: 0.05
|
| 87 |
+
dropout_input: 0.1
|
| 88 |
+
dropout_features: 0.1
|
| 89 |
+
dropout: 0.1
|
| 90 |
+
attention_dropout: 0.1
|
| 91 |
+
feature_grad_mult: 0.1
|
| 92 |
+
untie_final_proj: true
|
| 93 |
+
activation_dropout: 0.0
|
| 94 |
+
wav_input: false
|
| 95 |
+
layer_norm_first: true
|
| 96 |
+
audio_feat_dim: 104
|
| 97 |
+
|
| 98 |
+
hydra:
|
| 99 |
+
job:
|
| 100 |
+
config:
|
| 101 |
+
override_dirname:
|
| 102 |
+
kv_sep: '-'
|
| 103 |
+
item_sep: '__'
|
| 104 |
+
exclude_keys:
|
| 105 |
+
- run
|
| 106 |
+
- task.data
|
| 107 |
+
- task.label_dir
|
| 108 |
+
run:
|
| 109 |
+
dir: ???
|
| 110 |
+
sweep:
|
| 111 |
+
dir: ???
|
| 112 |
+
subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
|
av_hubert/avhubert/conf/pretrain/base_lrs3_iter3.yaml
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _group_
|
| 2 |
+
|
| 3 |
+
common:
|
| 4 |
+
fp16: true
|
| 5 |
+
log_format: json
|
| 6 |
+
log_interval: 200
|
| 7 |
+
seed: 1337
|
| 8 |
+
user_dir: ???
|
| 9 |
+
empty_cache_freq: 10000
|
| 10 |
+
|
| 11 |
+
checkpoint:
|
| 12 |
+
save_interval_updates: 25000
|
| 13 |
+
keep_interval_updates: 1
|
| 14 |
+
no_epoch_checkpoints: true
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
distributed_training:
|
| 18 |
+
ddp_backend: no_c10d
|
| 19 |
+
distributed_backend: 'nccl'
|
| 20 |
+
distributed_world_size: 32
|
| 21 |
+
distributed_port: 29671
|
| 22 |
+
nprocs_per_node: 8
|
| 23 |
+
|
| 24 |
+
task:
|
| 25 |
+
_name: av_hubert_pretraining
|
| 26 |
+
data: ???
|
| 27 |
+
label_dir: ???
|
| 28 |
+
labels: ["mfcc"]
|
| 29 |
+
label_rate: ${model.label_rate}
|
| 30 |
+
sample_rate: 25
|
| 31 |
+
max_sample_size: 500
|
| 32 |
+
min_sample_size: 5
|
| 33 |
+
pad_audio: true
|
| 34 |
+
random_crop: false
|
| 35 |
+
normalize: true
|
| 36 |
+
stack_order_audio: 4
|
| 37 |
+
# stack_order: 1
|
| 38 |
+
input_modality: image
|
| 39 |
+
image_aug: true
|
| 40 |
+
|
| 41 |
+
dataset:
|
| 42 |
+
num_workers: 6
|
| 43 |
+
max_tokens: 1000
|
| 44 |
+
skip_invalid_size_inputs_valid_test: true
|
| 45 |
+
validate_interval: 5
|
| 46 |
+
validate_interval_updates: 10000
|
| 47 |
+
|
| 48 |
+
criterion:
|
| 49 |
+
_name: av_hubert
|
| 50 |
+
pred_masked_weight: 1.0
|
| 51 |
+
pred_nomask_weight: 0.0
|
| 52 |
+
loss_weights: [10,]
|
| 53 |
+
|
| 54 |
+
optimization:
|
| 55 |
+
max_update: 400000
|
| 56 |
+
lr: [0.0005]
|
| 57 |
+
clip_norm: 10.0
|
| 58 |
+
|
| 59 |
+
optimizer:
|
| 60 |
+
_name: adam
|
| 61 |
+
adam_betas: (0.9,0.98)
|
| 62 |
+
adam_eps: 1e-06
|
| 63 |
+
weight_decay: 0.01
|
| 64 |
+
|
| 65 |
+
lr_scheduler:
|
| 66 |
+
_name: polynomial_decay
|
| 67 |
+
warmup_updates: 32000
|
| 68 |
+
|
| 69 |
+
model:
|
| 70 |
+
_name: av_hubert
|
| 71 |
+
label_rate: 25
|
| 72 |
+
skip_masked: false
|
| 73 |
+
skip_nomask: false
|
| 74 |
+
modality_dropout: 0
|
| 75 |
+
audio_dropout: 0.5
|
| 76 |
+
modality_fuse: concat
|
| 77 |
+
selection_type: same_seq
|
| 78 |
+
masking_type: feature
|
| 79 |
+
mask_prob_image: 0.8
|
| 80 |
+
mask_length_image: 10
|
| 81 |
+
mask_prob_audio: 0.8
|
| 82 |
+
mask_length_audio: 10
|
| 83 |
+
extractor_mode: default
|
| 84 |
+
# conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
|
| 85 |
+
final_dim: 256
|
| 86 |
+
encoder_layerdrop: 0.05
|
| 87 |
+
dropout_input: 0.1
|
| 88 |
+
dropout_features: 0.1
|
| 89 |
+
dropout: 0.1
|
| 90 |
+
attention_dropout: 0.1
|
| 91 |
+
feature_grad_mult: 0.1
|
| 92 |
+
untie_final_proj: true
|
| 93 |
+
activation_dropout: 0.0
|
| 94 |
+
wav_input: false
|
| 95 |
+
layer_norm_first: true
|
| 96 |
+
audio_feat_dim: 104
|
| 97 |
+
|
| 98 |
+
hydra:
|
| 99 |
+
job:
|
| 100 |
+
config:
|
| 101 |
+
override_dirname:
|
| 102 |
+
kv_sep: '-'
|
| 103 |
+
item_sep: '__'
|
| 104 |
+
exclude_keys:
|
| 105 |
+
- run
|
| 106 |
+
- task.data
|
| 107 |
+
- task.label_dir
|
| 108 |
+
run:
|
| 109 |
+
dir: ???
|
| 110 |
+
sweep:
|
| 111 |
+
dir: ???
|
| 112 |
+
subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
|
av_hubert/avhubert/conf/pretrain/base_lrs3_iter4.yaml
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _group_
|
| 2 |
+
|
| 3 |
+
common:
|
| 4 |
+
fp16: true
|
| 5 |
+
log_format: json
|
| 6 |
+
log_interval: 200
|
| 7 |
+
seed: 1337
|
| 8 |
+
user_dir: ???
|
| 9 |
+
empty_cache_freq: 10000
|
| 10 |
+
|
| 11 |
+
checkpoint:
|
| 12 |
+
save_interval_updates: 25000
|
| 13 |
+
keep_interval_updates: 1
|
| 14 |
+
no_epoch_checkpoints: true
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
distributed_training:
|
| 18 |
+
ddp_backend: no_c10d
|
| 19 |
+
distributed_backend: 'nccl'
|
| 20 |
+
distributed_world_size: 32
|
| 21 |
+
distributed_port: 29671
|
| 22 |
+
nprocs_per_node: 8
|
| 23 |
+
|
| 24 |
+
task:
|
| 25 |
+
_name: av_hubert_pretraining
|
| 26 |
+
data: ???
|
| 27 |
+
label_dir: ???
|
| 28 |
+
labels: ["mfcc"]
|
| 29 |
+
label_rate: ${model.label_rate}
|
| 30 |
+
sample_rate: 25
|
| 31 |
+
max_sample_size: 500
|
| 32 |
+
min_sample_size: 5
|
| 33 |
+
pad_audio: true
|
| 34 |
+
random_crop: false
|
| 35 |
+
normalize: true
|
| 36 |
+
stack_order_audio: 4
|
| 37 |
+
# stack_order: 1
|
| 38 |
+
input_modality: image
|
| 39 |
+
image_aug: true
|
| 40 |
+
|
| 41 |
+
dataset:
|
| 42 |
+
num_workers: 6
|
| 43 |
+
max_tokens: 1000
|
| 44 |
+
skip_invalid_size_inputs_valid_test: true
|
| 45 |
+
validate_interval: 5
|
| 46 |
+
validate_interval_updates: 10000
|
| 47 |
+
|
| 48 |
+
criterion:
|
| 49 |
+
_name: av_hubert
|
| 50 |
+
pred_masked_weight: 1.0
|
| 51 |
+
pred_nomask_weight: 0.0
|
| 52 |
+
loss_weights: [10,]
|
| 53 |
+
|
| 54 |
+
optimization:
|
| 55 |
+
max_update: 400000
|
| 56 |
+
lr: [0.0005]
|
| 57 |
+
clip_norm: 10.0
|
| 58 |
+
|
| 59 |
+
optimizer:
|
| 60 |
+
_name: adam
|
| 61 |
+
adam_betas: (0.9,0.98)
|
| 62 |
+
adam_eps: 1e-06
|
| 63 |
+
weight_decay: 0.01
|
| 64 |
+
|
| 65 |
+
lr_scheduler:
|
| 66 |
+
_name: polynomial_decay
|
| 67 |
+
warmup_updates: 32000
|
| 68 |
+
|
| 69 |
+
model:
|
| 70 |
+
_name: av_hubert
|
| 71 |
+
label_rate: 25
|
| 72 |
+
skip_masked: false
|
| 73 |
+
skip_nomask: false
|
| 74 |
+
modality_dropout: 0
|
| 75 |
+
audio_dropout: 0.5
|
| 76 |
+
modality_fuse: concat
|
| 77 |
+
selection_type: same_seq
|
| 78 |
+
masking_type: feature
|
| 79 |
+
mask_prob_image: 0.8
|
| 80 |
+
mask_length_image: 10
|
| 81 |
+
mask_prob_audio: 0.8
|
| 82 |
+
mask_length_audio: 10
|
| 83 |
+
extractor_mode: default
|
| 84 |
+
# conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
|
| 85 |
+
final_dim: 256
|
| 86 |
+
encoder_layerdrop: 0.05
|
| 87 |
+
dropout_input: 0.1
|
| 88 |
+
dropout_features: 0.1
|
| 89 |
+
dropout: 0.1
|
| 90 |
+
attention_dropout: 0.1
|
| 91 |
+
feature_grad_mult: 0.1
|
| 92 |
+
untie_final_proj: true
|
| 93 |
+
activation_dropout: 0.0
|
| 94 |
+
wav_input: false
|
| 95 |
+
layer_norm_first: true
|
| 96 |
+
audio_feat_dim: 104
|
| 97 |
+
|
| 98 |
+
hydra:
|
| 99 |
+
job:
|
| 100 |
+
config:
|
| 101 |
+
override_dirname:
|
| 102 |
+
kv_sep: '-'
|
| 103 |
+
item_sep: '__'
|
| 104 |
+
exclude_keys:
|
| 105 |
+
- run
|
| 106 |
+
- task.data
|
| 107 |
+
- task.label_dir
|
| 108 |
+
run:
|
| 109 |
+
dir: ???
|
| 110 |
+
sweep:
|
| 111 |
+
dir: ???
|
| 112 |
+
subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
|
av_hubert/avhubert/conf/pretrain/base_lrs3_iter5.yaml
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _group_
|
| 2 |
+
|
| 3 |
+
common:
|
| 4 |
+
fp16: true
|
| 5 |
+
log_format: json
|
| 6 |
+
log_interval: 200
|
| 7 |
+
seed: 1337
|
| 8 |
+
user_dir: ???
|
| 9 |
+
empty_cache_freq: 10000
|
| 10 |
+
|
| 11 |
+
checkpoint:
|
| 12 |
+
save_interval_updates: 25000
|
| 13 |
+
keep_interval_updates: 1
|
| 14 |
+
no_epoch_checkpoints: true
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
distributed_training:
|
| 18 |
+
ddp_backend: no_c10d
|
| 19 |
+
distributed_backend: 'nccl'
|
| 20 |
+
distributed_world_size: 32
|
| 21 |
+
distributed_port: 29671
|
| 22 |
+
nprocs_per_node: 8
|
| 23 |
+
|
| 24 |
+
task:
|
| 25 |
+
_name: av_hubert_pretraining
|
| 26 |
+
data: ???
|
| 27 |
+
label_dir: ???
|
| 28 |
+
labels: ["km"]
|
| 29 |
+
label_rate: ${model.label_rate}
|
| 30 |
+
sample_rate: 25
|
| 31 |
+
max_sample_size: 500
|
| 32 |
+
min_sample_size: 5
|
| 33 |
+
pad_audio: true
|
| 34 |
+
random_crop: false
|
| 35 |
+
normalize: true
|
| 36 |
+
stack_order_audio: 4
|
| 37 |
+
# stack_order: 1
|
| 38 |
+
input_modality: image
|
| 39 |
+
image_aug: true
|
| 40 |
+
|
| 41 |
+
dataset:
|
| 42 |
+
num_workers: 6
|
| 43 |
+
max_tokens: 1000
|
| 44 |
+
skip_invalid_size_inputs_valid_test: true
|
| 45 |
+
validate_interval: 5
|
| 46 |
+
validate_interval_updates: 10000
|
| 47 |
+
|
| 48 |
+
criterion:
|
| 49 |
+
_name: av_hubert
|
| 50 |
+
pred_masked_weight: 1.0
|
| 51 |
+
pred_nomask_weight: 0.0
|
| 52 |
+
loss_weights: [10,]
|
| 53 |
+
|
| 54 |
+
optimization:
|
| 55 |
+
max_update: 400000
|
| 56 |
+
lr: [0.0005]
|
| 57 |
+
clip_norm: 10.0
|
| 58 |
+
|
| 59 |
+
optimizer:
|
| 60 |
+
_name: adam
|
| 61 |
+
adam_betas: (0.9,0.98)
|
| 62 |
+
adam_eps: 1e-06
|
| 63 |
+
weight_decay: 0.01
|
| 64 |
+
|
| 65 |
+
lr_scheduler:
|
| 66 |
+
_name: polynomial_decay
|
| 67 |
+
warmup_updates: 32000
|
| 68 |
+
|
| 69 |
+
model:
|
| 70 |
+
_name: av_hubert
|
| 71 |
+
label_rate: ???
|
| 72 |
+
skip_masked: false
|
| 73 |
+
skip_nomask: false
|
| 74 |
+
modality_dropout: 0.5
|
| 75 |
+
audio_dropout: 0.5
|
| 76 |
+
modality_fuse: concat
|
| 77 |
+
selection_type: same_seq
|
| 78 |
+
masking_type: input
|
| 79 |
+
mask_prob_image: 0.3
|
| 80 |
+
mask_length_image: 5
|
| 81 |
+
mask_prob_audio: 0.8
|
| 82 |
+
mask_length_audio: 10
|
| 83 |
+
extractor_mode: default
|
| 84 |
+
# conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
|
| 85 |
+
final_dim: 256
|
| 86 |
+
encoder_layerdrop: 0.05
|
| 87 |
+
dropout_input: 0.1
|
| 88 |
+
dropout_features: 0.1
|
| 89 |
+
dropout: 0.1
|
| 90 |
+
attention_dropout: 0.1
|
| 91 |
+
feature_grad_mult: 0.1
|
| 92 |
+
untie_final_proj: true
|
| 93 |
+
activation_dropout: 0.0
|
| 94 |
+
wav_input: false
|
| 95 |
+
layer_norm_first: true
|
| 96 |
+
audio_feat_dim: 104
|
| 97 |
+
|
| 98 |
+
hydra:
|
| 99 |
+
job:
|
| 100 |
+
config:
|
| 101 |
+
override_dirname:
|
| 102 |
+
kv_sep: '-'
|
| 103 |
+
item_sep: '__'
|
| 104 |
+
exclude_keys:
|
| 105 |
+
- run
|
| 106 |
+
- task.data
|
| 107 |
+
- task.label_dir
|
| 108 |
+
run:
|
| 109 |
+
dir: ???
|
| 110 |
+
sweep:
|
| 111 |
+
dir: ???
|
| 112 |
+
subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
|
av_hubert/avhubert/conf/pretrain/base_vox_iter1.yaml
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _group_
|
| 2 |
+
|
| 3 |
+
common:
|
| 4 |
+
fp16: true
|
| 5 |
+
log_format: json
|
| 6 |
+
log_interval: 200
|
| 7 |
+
seed: 1337
|
| 8 |
+
user_dir: ???
|
| 9 |
+
empty_cache_freq: 10000
|
| 10 |
+
|
| 11 |
+
checkpoint:
|
| 12 |
+
save_interval_updates: 25000
|
| 13 |
+
keep_interval_updates: 1
|
| 14 |
+
no_epoch_checkpoints: true
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
distributed_training:
|
| 18 |
+
ddp_backend: no_c10d
|
| 19 |
+
distributed_backend: 'nccl'
|
| 20 |
+
distributed_world_size: 32
|
| 21 |
+
distributed_port: 29671
|
| 22 |
+
nprocs_per_node: 8
|
| 23 |
+
|
| 24 |
+
task:
|
| 25 |
+
_name: av_hubert_pretraining
|
| 26 |
+
data: ???
|
| 27 |
+
label_dir: ???
|
| 28 |
+
labels: ["mfcc"]
|
| 29 |
+
label_rate: ${model.label_rate}
|
| 30 |
+
sample_rate: 25
|
| 31 |
+
max_sample_size: 2000
|
| 32 |
+
min_sample_size: 5
|
| 33 |
+
pad_audio: false
|
| 34 |
+
random_crop: true
|
| 35 |
+
normalize: true
|
| 36 |
+
stack_order_audio: 4
|
| 37 |
+
# stack_order: 1
|
| 38 |
+
input_modality: image
|
| 39 |
+
image_aug: true
|
| 40 |
+
max_trim_sample_size: 400
|
| 41 |
+
|
| 42 |
+
dataset:
|
| 43 |
+
num_workers: 6
|
| 44 |
+
max_tokens: 1000
|
| 45 |
+
skip_invalid_size_inputs_valid_test: true
|
| 46 |
+
validate_interval: 5
|
| 47 |
+
validate_interval_updates: 10000
|
| 48 |
+
|
| 49 |
+
criterion:
|
| 50 |
+
_name: av_hubert
|
| 51 |
+
pred_masked_weight: 1.0
|
| 52 |
+
pred_nomask_weight: 0.0
|
| 53 |
+
loss_weights: [10,]
|
| 54 |
+
|
| 55 |
+
optimization:
|
| 56 |
+
max_update: 800000
|
| 57 |
+
lr: [0.002]
|
| 58 |
+
clip_norm: 10.0
|
| 59 |
+
|
| 60 |
+
optimizer:
|
| 61 |
+
_name: adam
|
| 62 |
+
adam_betas: (0.9,0.98)
|
| 63 |
+
adam_eps: 1e-06
|
| 64 |
+
weight_decay: 0.01
|
| 65 |
+
|
| 66 |
+
lr_scheduler:
|
| 67 |
+
_name: polynomial_decay
|
| 68 |
+
warmup_updates: 64000
|
| 69 |
+
|
| 70 |
+
model:
|
| 71 |
+
_name: av_hubert
|
| 72 |
+
label_rate: 100
|
| 73 |
+
skip_masked: false
|
| 74 |
+
skip_nomask: false
|
| 75 |
+
modality_dropout: 0
|
| 76 |
+
audio_dropout: 0.5
|
| 77 |
+
modality_fuse: concat
|
| 78 |
+
selection_type: same_seq
|
| 79 |
+
masking_type: feature
|
| 80 |
+
mask_prob_image: 0.8
|
| 81 |
+
mask_length_image: 10
|
| 82 |
+
mask_prob_audio: 0.8
|
| 83 |
+
mask_length_audio: 10
|
| 84 |
+
extractor_mode: default
|
| 85 |
+
# conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
|
| 86 |
+
final_dim: 256
|
| 87 |
+
encoder_layerdrop: 0.05
|
| 88 |
+
dropout_input: 0.1
|
| 89 |
+
dropout_features: 0.1
|
| 90 |
+
dropout: 0.1
|
| 91 |
+
attention_dropout: 0.1
|
| 92 |
+
feature_grad_mult: 0.1
|
| 93 |
+
untie_final_proj: true
|
| 94 |
+
activation_dropout: 0.0
|
| 95 |
+
wav_input: false
|
| 96 |
+
layer_norm_first: true
|
| 97 |
+
audio_feat_dim: 104
|
| 98 |
+
|
| 99 |
+
hydra:
|
| 100 |
+
job:
|
| 101 |
+
config:
|
| 102 |
+
override_dirname:
|
| 103 |
+
kv_sep: '-'
|
| 104 |
+
item_sep: '__'
|
| 105 |
+
exclude_keys:
|
| 106 |
+
- run
|
| 107 |
+
- task.data
|
| 108 |
+
- task.label_dir
|
| 109 |
+
run:
|
| 110 |
+
dir: ???
|
| 111 |
+
sweep:
|
| 112 |
+
dir: ???
|
| 113 |
+
subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
|
av_hubert/avhubert/conf/pretrain/base_vox_iter2.yaml
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _group_
|
| 2 |
+
|
| 3 |
+
common:
|
| 4 |
+
fp16: true
|
| 5 |
+
log_format: json
|
| 6 |
+
log_interval: 200
|
| 7 |
+
seed: 1337
|
| 8 |
+
user_dir: ???
|
| 9 |
+
empty_cache_freq: 10000
|
| 10 |
+
|
| 11 |
+
checkpoint:
|
| 12 |
+
save_interval_updates: 25000
|
| 13 |
+
keep_interval_updates: 1
|
| 14 |
+
no_epoch_checkpoints: true
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
distributed_training:
|
| 18 |
+
ddp_backend: no_c10d
|
| 19 |
+
distributed_backend: 'nccl'
|
| 20 |
+
distributed_world_size: 32
|
| 21 |
+
distributed_port: 29671
|
| 22 |
+
nprocs_per_node: 8
|
| 23 |
+
|
| 24 |
+
task:
|
| 25 |
+
_name: av_hubert_pretraining
|
| 26 |
+
data: ???
|
| 27 |
+
label_dir: ???
|
| 28 |
+
labels: ["km"]
|
| 29 |
+
label_rate: ${model.label_rate}
|
| 30 |
+
sample_rate: 25
|
| 31 |
+
max_sample_size: 2000
|
| 32 |
+
min_sample_size: 5
|
| 33 |
+
pad_audio: false
|
| 34 |
+
random_crop: true
|
| 35 |
+
normalize: true
|
| 36 |
+
stack_order_audio: 4
|
| 37 |
+
# stack_order: 1
|
| 38 |
+
input_modality: image
|
| 39 |
+
image_aug: true
|
| 40 |
+
max_trim_sample_size: 400
|
| 41 |
+
|
| 42 |
+
dataset:
|
| 43 |
+
num_workers: 6
|
| 44 |
+
max_tokens: 1000
|
| 45 |
+
skip_invalid_size_inputs_valid_test: true
|
| 46 |
+
validate_interval: 5
|
| 47 |
+
validate_interval_updates: 10000
|
| 48 |
+
|
| 49 |
+
criterion:
|
| 50 |
+
_name: av_hubert
|
| 51 |
+
pred_masked_weight: 1.0
|
| 52 |
+
pred_nomask_weight: 0.0
|
| 53 |
+
loss_weights: [10,]
|
| 54 |
+
|
| 55 |
+
optimization:
|
| 56 |
+
max_update: 800000
|
| 57 |
+
lr: [0.002]
|
| 58 |
+
clip_norm: 10.0
|
| 59 |
+
|
| 60 |
+
optimizer:
|
| 61 |
+
_name: adam
|
| 62 |
+
adam_betas: (0.9,0.98)
|
| 63 |
+
adam_eps: 1e-06
|
| 64 |
+
weight_decay: 0.01
|
| 65 |
+
|
| 66 |
+
lr_scheduler:
|
| 67 |
+
_name: polynomial_decay
|
| 68 |
+
warmup_updates: 64000
|
| 69 |
+
|
| 70 |
+
model:
|
| 71 |
+
_name: av_hubert
|
| 72 |
+
label_rate: 25
|
| 73 |
+
skip_masked: false
|
| 74 |
+
skip_nomask: false
|
| 75 |
+
modality_dropout: 0.5
|
| 76 |
+
audio_dropout: 0.5
|
| 77 |
+
modality_fuse: concat
|
| 78 |
+
selection_type: same_seq
|
| 79 |
+
masking_type: feature
|
| 80 |
+
mask_prob_image: 0.8
|
| 81 |
+
mask_length_image: 10
|
| 82 |
+
mask_prob_audio: 0.8
|
| 83 |
+
mask_length_audio: 10
|
| 84 |
+
extractor_mode: default
|
| 85 |
+
# conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
|
| 86 |
+
final_dim: 256
|
| 87 |
+
encoder_layerdrop: 0.05
|
| 88 |
+
dropout_input: 0.1
|
| 89 |
+
dropout_features: 0.1
|
| 90 |
+
dropout: 0.1
|
| 91 |
+
attention_dropout: 0.1
|
| 92 |
+
feature_grad_mult: 0.1
|
| 93 |
+
untie_final_proj: true
|
| 94 |
+
activation_dropout: 0.0
|
| 95 |
+
wav_input: false
|
| 96 |
+
layer_norm_first: true
|
| 97 |
+
audio_feat_dim: 104
|
| 98 |
+
|
| 99 |
+
hydra:
|
| 100 |
+
job:
|
| 101 |
+
config:
|
| 102 |
+
override_dirname:
|
| 103 |
+
kv_sep: '-'
|
| 104 |
+
item_sep: '__'
|
| 105 |
+
exclude_keys:
|
| 106 |
+
- run
|
| 107 |
+
- task.data
|
| 108 |
+
- task.label_dir
|
| 109 |
+
run:
|
| 110 |
+
dir: ???
|
| 111 |
+
sweep:
|
| 112 |
+
dir: ???
|
| 113 |
+
subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
|
av_hubert/avhubert/conf/pretrain/base_vox_iter3.yaml
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _group_
|
| 2 |
+
|
| 3 |
+
common:
|
| 4 |
+
fp16: true
|
| 5 |
+
log_format: json
|
| 6 |
+
log_interval: 200
|
| 7 |
+
seed: 1337
|
| 8 |
+
user_dir: ???
|
| 9 |
+
empty_cache_freq: 10000
|
| 10 |
+
|
| 11 |
+
checkpoint:
|
| 12 |
+
save_interval_updates: 25000
|
| 13 |
+
keep_interval_updates: 1
|
| 14 |
+
no_epoch_checkpoints: true
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
distributed_training:
|
| 18 |
+
ddp_backend: no_c10d
|
| 19 |
+
distributed_backend: 'nccl'
|
| 20 |
+
distributed_world_size: 32
|
| 21 |
+
distributed_port: 29671
|
| 22 |
+
nprocs_per_node: 8
|
| 23 |
+
|
| 24 |
+
task:
|
| 25 |
+
_name: av_hubert_pretraining
|
| 26 |
+
data: ???
|
| 27 |
+
label_dir: ???
|
| 28 |
+
labels: ["km"]
|
| 29 |
+
label_rate: ${model.label_rate}
|
| 30 |
+
sample_rate: 25
|
| 31 |
+
max_sample_size: 2000
|
| 32 |
+
min_sample_size: 5
|
| 33 |
+
pad_audio: false
|
| 34 |
+
random_crop: true
|
| 35 |
+
normalize: true
|
| 36 |
+
stack_order_audio: 4
|
| 37 |
+
# stack_order: 1
|
| 38 |
+
input_modality: image
|
| 39 |
+
image_aug: true
|
| 40 |
+
max_trim_sample_size: 400
|
| 41 |
+
|
| 42 |
+
dataset:
|
| 43 |
+
num_workers: 6
|
| 44 |
+
max_tokens: 1000
|
| 45 |
+
skip_invalid_size_inputs_valid_test: true
|
| 46 |
+
validate_interval: 5
|
| 47 |
+
validate_interval_updates: 10000
|
| 48 |
+
|
| 49 |
+
criterion:
|
| 50 |
+
_name: av_hubert
|
| 51 |
+
pred_masked_weight: 1.0
|
| 52 |
+
pred_nomask_weight: 0.0
|
| 53 |
+
loss_weights: [10,]
|
| 54 |
+
|
| 55 |
+
optimization:
|
| 56 |
+
max_update: 800000
|
| 57 |
+
lr: [0.002]
|
| 58 |
+
clip_norm: 10.0
|
| 59 |
+
|
| 60 |
+
optimizer:
|
| 61 |
+
_name: adam
|
| 62 |
+
adam_betas: (0.9,0.98)
|
| 63 |
+
adam_eps: 1e-06
|
| 64 |
+
weight_decay: 0.01
|
| 65 |
+
|
| 66 |
+
lr_scheduler:
|
| 67 |
+
_name: polynomial_decay
|
| 68 |
+
warmup_updates: 64000
|
| 69 |
+
|
| 70 |
+
model:
|
| 71 |
+
_name: av_hubert
|
| 72 |
+
label_rate: 25
|
| 73 |
+
skip_masked: false
|
| 74 |
+
skip_nomask: false
|
| 75 |
+
modality_dropout: 0.5
|
| 76 |
+
audio_dropout: 0.5
|
| 77 |
+
modality_fuse: concat
|
| 78 |
+
selection_type: same_seq
|
| 79 |
+
masking_type: feature
|
| 80 |
+
mask_prob_image: 0.8
|
| 81 |
+
mask_length_image: 10
|
| 82 |
+
mask_prob_audio: 0.8
|
| 83 |
+
mask_length_audio: 10
|
| 84 |
+
extractor_mode: default
|
| 85 |
+
# conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
|
| 86 |
+
final_dim: 256
|
| 87 |
+
encoder_layerdrop: 0.05
|
| 88 |
+
dropout_input: 0.1
|
| 89 |
+
dropout_features: 0.1
|
| 90 |
+
dropout: 0.1
|
| 91 |
+
attention_dropout: 0.1
|
| 92 |
+
feature_grad_mult: 0.1
|
| 93 |
+
untie_final_proj: true
|
| 94 |
+
activation_dropout: 0.0
|
| 95 |
+
wav_input: false
|
| 96 |
+
layer_norm_first: true
|
| 97 |
+
audio_feat_dim: 104
|
| 98 |
+
|
| 99 |
+
hydra:
|
| 100 |
+
job:
|
| 101 |
+
config:
|
| 102 |
+
override_dirname:
|
| 103 |
+
kv_sep: '-'
|
| 104 |
+
item_sep: '__'
|
| 105 |
+
exclude_keys:
|
| 106 |
+
- run
|
| 107 |
+
- task.data
|
| 108 |
+
- task.label_dir
|
| 109 |
+
run:
|
| 110 |
+
dir: ???
|
| 111 |
+
sweep:
|
| 112 |
+
dir: ???
|
| 113 |
+
subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
|
av_hubert/avhubert/conf/pretrain/base_vox_iter4.yaml
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _group_
|
| 2 |
+
|
| 3 |
+
common:
|
| 4 |
+
fp16: true
|
| 5 |
+
log_format: json
|
| 6 |
+
log_interval: 200
|
| 7 |
+
seed: 1337
|
| 8 |
+
user_dir: ???
|
| 9 |
+
empty_cache_freq: 10000
|
| 10 |
+
|
| 11 |
+
checkpoint:
|
| 12 |
+
save_interval_updates: 25000
|
| 13 |
+
keep_interval_updates: 1
|
| 14 |
+
no_epoch_checkpoints: true
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
distributed_training:
|
| 18 |
+
ddp_backend: no_c10d
|
| 19 |
+
distributed_backend: 'nccl'
|
| 20 |
+
distributed_world_size: 32
|
| 21 |
+
distributed_port: 29671
|
| 22 |
+
nprocs_per_node: 8
|
| 23 |
+
|
| 24 |
+
task:
|
| 25 |
+
_name: av_hubert_pretraining
|
| 26 |
+
data: ???
|
| 27 |
+
label_dir: ???
|
| 28 |
+
labels: ["km"]
|
| 29 |
+
label_rate: ${model.label_rate}
|
| 30 |
+
sample_rate: 25
|
| 31 |
+
max_sample_size: 2000
|
| 32 |
+
min_sample_size: 5
|
| 33 |
+
pad_audio: false
|
| 34 |
+
random_crop: true
|
| 35 |
+
normalize: true
|
| 36 |
+
stack_order_audio: 4
|
| 37 |
+
# stack_order: 1
|
| 38 |
+
input_modality: image
|
| 39 |
+
image_aug: true
|
| 40 |
+
max_trim_sample_size: 400
|
| 41 |
+
|
| 42 |
+
dataset:
|
| 43 |
+
num_workers: 6
|
| 44 |
+
max_tokens: 1000
|
| 45 |
+
skip_invalid_size_inputs_valid_test: true
|
| 46 |
+
validate_interval: 5
|
| 47 |
+
validate_interval_updates: 10000
|
| 48 |
+
|
| 49 |
+
criterion:
|
| 50 |
+
_name: av_hubert
|
| 51 |
+
pred_masked_weight: 1.0
|
| 52 |
+
pred_nomask_weight: 0.0
|
| 53 |
+
loss_weights: [10,]
|
| 54 |
+
|
| 55 |
+
optimization:
|
| 56 |
+
max_update: 800000
|
| 57 |
+
lr: [0.002]
|
| 58 |
+
clip_norm: 10.0
|
| 59 |
+
|
| 60 |
+
optimizer:
|
| 61 |
+
_name: adam
|
| 62 |
+
adam_betas: (0.9,0.98)
|
| 63 |
+
adam_eps: 1e-06
|
| 64 |
+
weight_decay: 0.01
|
| 65 |
+
|
| 66 |
+
lr_scheduler:
|
| 67 |
+
_name: polynomial_decay
|
| 68 |
+
warmup_updates: 64000
|
| 69 |
+
|
| 70 |
+
model:
|
| 71 |
+
_name: av_hubert
|
| 72 |
+
label_rate: 25
|
| 73 |
+
skip_masked: false
|
| 74 |
+
skip_nomask: false
|
| 75 |
+
modality_dropout: 0.5
|
| 76 |
+
audio_dropout: 0.5
|
| 77 |
+
modality_fuse: concat
|
| 78 |
+
masking_type: feature
|
| 79 |
+
mask_prob_image: 0.8
|
| 80 |
+
mask_length_image: 10
|
| 81 |
+
mask_prob_audio: 0.8
|
| 82 |
+
mask_length_audio: 10
|
| 83 |
+
extractor_mode: default
|
| 84 |
+
# conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
|
| 85 |
+
final_dim: 256
|
| 86 |
+
encoder_layerdrop: 0.05
|
| 87 |
+
dropout_input: 0.1
|
| 88 |
+
dropout_features: 0.1
|
| 89 |
+
dropout: 0.1
|
| 90 |
+
attention_dropout: 0.1
|
| 91 |
+
feature_grad_mult: 0.1
|
| 92 |
+
untie_final_proj: true
|
| 93 |
+
activation_dropout: 0.0
|
| 94 |
+
wav_input: false
|
| 95 |
+
layer_norm_first: true
|
| 96 |
+
audio_feat_dim: 104
|
| 97 |
+
|
| 98 |
+
hydra:
|
| 99 |
+
job:
|
| 100 |
+
config:
|
| 101 |
+
override_dirname:
|
| 102 |
+
kv_sep: '-'
|
| 103 |
+
item_sep: '__'
|
| 104 |
+
exclude_keys:
|
| 105 |
+
- run
|
| 106 |
+
- task.data
|
| 107 |
+
- task.label_dir
|
| 108 |
+
run:
|
| 109 |
+
dir: ???
|
| 110 |
+
sweep:
|
| 111 |
+
dir: ???
|
| 112 |
+
subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
|
av_hubert/avhubert/conf/pretrain/base_vox_iter5.yaml
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _group_
|
| 2 |
+
|
| 3 |
+
common:
|
| 4 |
+
fp16: true
|
| 5 |
+
log_format: json
|
| 6 |
+
log_interval: 200
|
| 7 |
+
seed: 1337
|
| 8 |
+
user_dir: ???
|
| 9 |
+
empty_cache_freq: 10000
|
| 10 |
+
|
| 11 |
+
checkpoint:
|
| 12 |
+
save_interval_updates: 25000
|
| 13 |
+
keep_interval_updates: 1
|
| 14 |
+
no_epoch_checkpoints: true
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
distributed_training:
|
| 18 |
+
ddp_backend: no_c10d
|
| 19 |
+
distributed_backend: 'nccl'
|
| 20 |
+
distributed_world_size: 32
|
| 21 |
+
distributed_port: 29671
|
| 22 |
+
nprocs_per_node: 8
|
| 23 |
+
|
| 24 |
+
task:
|
| 25 |
+
_name: av_hubert_pretraining
|
| 26 |
+
data: ???
|
| 27 |
+
label_dir: ???
|
| 28 |
+
labels: ["km"]
|
| 29 |
+
label_rate: ${model.label_rate}
|
| 30 |
+
sample_rate: 25
|
| 31 |
+
max_sample_size: 2000
|
| 32 |
+
min_sample_size: 5
|
| 33 |
+
pad_audio: false
|
| 34 |
+
random_crop: true
|
| 35 |
+
normalize: true
|
| 36 |
+
stack_order_audio: 4
|
| 37 |
+
# stack_order: 1
|
| 38 |
+
input_modality: image
|
| 39 |
+
image_aug: true
|
| 40 |
+
max_trim_sample_size: 400
|
| 41 |
+
|
| 42 |
+
dataset:
|
| 43 |
+
num_workers: 6
|
| 44 |
+
max_tokens: 1000
|
| 45 |
+
skip_invalid_size_inputs_valid_test: true
|
| 46 |
+
validate_interval: 5
|
| 47 |
+
validate_interval_updates: 10000
|
| 48 |
+
|
| 49 |
+
criterion:
|
| 50 |
+
_name: av_hubert
|
| 51 |
+
pred_masked_weight: 1.0
|
| 52 |
+
pred_nomask_weight: 0.0
|
| 53 |
+
loss_weights: [10,]
|
| 54 |
+
|
| 55 |
+
optimization:
|
| 56 |
+
max_update: 800000
|
| 57 |
+
lr: [0.002]
|
| 58 |
+
clip_norm: 10.0
|
| 59 |
+
|
| 60 |
+
optimizer:
|
| 61 |
+
_name: adam
|
| 62 |
+
adam_betas: (0.9,0.98)
|
| 63 |
+
adam_eps: 1e-06
|
| 64 |
+
weight_decay: 0.01
|
| 65 |
+
|
| 66 |
+
lr_scheduler:
|
| 67 |
+
_name: polynomial_decay
|
| 68 |
+
warmup_updates: 64000
|
| 69 |
+
|
| 70 |
+
model:
|
| 71 |
+
_name: av_hubert
|
| 72 |
+
label_rate: ???
|
| 73 |
+
skip_masked: false
|
| 74 |
+
skip_nomask: false
|
| 75 |
+
modality_dropout: 0.5
|
| 76 |
+
audio_dropout: 0.5
|
| 77 |
+
modality_fuse: concat
|
| 78 |
+
selection_type: same_seq
|
| 79 |
+
masking_type: input
|
| 80 |
+
mask_prob_image: 0.3
|
| 81 |
+
mask_length_image: 5
|
| 82 |
+
mask_prob_audio: 0.8
|
| 83 |
+
mask_length_audio: 10
|
| 84 |
+
extractor_mode: default
|
| 85 |
+
# conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
|
| 86 |
+
final_dim: 256
|
| 87 |
+
encoder_layerdrop: 0.05
|
| 88 |
+
dropout_input: 0.1
|
| 89 |
+
dropout_features: 0.1
|
| 90 |
+
dropout: 0.1
|
| 91 |
+
attention_dropout: 0.1
|
| 92 |
+
feature_grad_mult: 0.1
|
| 93 |
+
untie_final_proj: true
|
| 94 |
+
activation_dropout: 0.0
|
| 95 |
+
wav_input: false
|
| 96 |
+
layer_norm_first: true
|
| 97 |
+
audio_feat_dim: 104
|
| 98 |
+
|
| 99 |
+
hydra:
|
| 100 |
+
job:
|
| 101 |
+
config:
|
| 102 |
+
override_dirname:
|
| 103 |
+
kv_sep: '-'
|
| 104 |
+
item_sep: '__'
|
| 105 |
+
exclude_keys:
|
| 106 |
+
- run
|
| 107 |
+
- task.data
|
| 108 |
+
- task.label_dir
|
| 109 |
+
run:
|
| 110 |
+
dir: ???
|
| 111 |
+
sweep:
|
| 112 |
+
dir: ???
|
| 113 |
+
subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
|
av_hubert/avhubert/conf/pretrain/large_lrs3_iter5.yaml
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _group_
|
| 2 |
+
|
| 3 |
+
common:
|
| 4 |
+
fp16: true
|
| 5 |
+
log_format: json
|
| 6 |
+
log_interval: 200
|
| 7 |
+
seed: 1337
|
| 8 |
+
user_dir: ???
|
| 9 |
+
empty_cache_freq: 10000
|
| 10 |
+
|
| 11 |
+
checkpoint:
|
| 12 |
+
save_interval_updates: 25000
|
| 13 |
+
keep_interval_updates: 1
|
| 14 |
+
no_epoch_checkpoints: true
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
distributed_training:
|
| 18 |
+
ddp_backend: no_c10d
|
| 19 |
+
distributed_backend: 'nccl'
|
| 20 |
+
distributed_world_size: 64
|
| 21 |
+
distributed_port: 29671
|
| 22 |
+
nprocs_per_node: 8
|
| 23 |
+
|
| 24 |
+
task:
|
| 25 |
+
_name: av_hubert_pretraining
|
| 26 |
+
data: ???
|
| 27 |
+
label_dir: ???
|
| 28 |
+
labels: ["km"]
|
| 29 |
+
label_rate: ${model.label_rate}
|
| 30 |
+
sample_rate: 25
|
| 31 |
+
max_sample_size: 2000
|
| 32 |
+
min_sample_size: 5
|
| 33 |
+
pad_audio: false
|
| 34 |
+
random_crop: true
|
| 35 |
+
normalize: true
|
| 36 |
+
stack_order_audio: 4
|
| 37 |
+
# stack_order: 1
|
| 38 |
+
input_modality: image
|
| 39 |
+
image_aug: true
|
| 40 |
+
max_trim_sample_size: 400
|
| 41 |
+
|
| 42 |
+
dataset:
|
| 43 |
+
num_workers: 6
|
| 44 |
+
max_tokens: 1000
|
| 45 |
+
skip_invalid_size_inputs_valid_test: true
|
| 46 |
+
validate_interval: 5
|
| 47 |
+
validate_interval_updates: 10000
|
| 48 |
+
|
| 49 |
+
criterion:
|
| 50 |
+
_name: av_hubert
|
| 51 |
+
pred_masked_weight: 1.0
|
| 52 |
+
pred_nomask_weight: 1.0
|
| 53 |
+
loss_weights: [10,]
|
| 54 |
+
|
| 55 |
+
optimization:
|
| 56 |
+
max_update: 400000
|
| 57 |
+
lr: [0.002]
|
| 58 |
+
clip_norm: 10.0
|
| 59 |
+
|
| 60 |
+
optimizer:
|
| 61 |
+
_name: adam
|
| 62 |
+
adam_betas: (0.9,0.98)
|
| 63 |
+
adam_eps: 1e-06
|
| 64 |
+
weight_decay: 0.01
|
| 65 |
+
|
| 66 |
+
lr_scheduler:
|
| 67 |
+
_name: polynomial_decay
|
| 68 |
+
warmup_updates: 32000
|
| 69 |
+
|
| 70 |
+
model:
|
| 71 |
+
_name: av_hubert
|
| 72 |
+
label_rate: 25
|
| 73 |
+
skip_masked: false
|
| 74 |
+
skip_nomask: false
|
| 75 |
+
modality_dropout: 0.5
|
| 76 |
+
audio_dropout: 0.5
|
| 77 |
+
modality_fuse: concat
|
| 78 |
+
selection_type: same_seq
|
| 79 |
+
masking_type: input
|
| 80 |
+
mask_prob_image: 0.3
|
| 81 |
+
mask_length_image: 5
|
| 82 |
+
mask_prob_audio: 0.8
|
| 83 |
+
mask_length_audio: 10
|
| 84 |
+
extractor_mode: default
|
| 85 |
+
# conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
|
| 86 |
+
final_dim: 256
|
| 87 |
+
encoder_layerdrop: 0.05
|
| 88 |
+
dropout_input: 0.1
|
| 89 |
+
dropout_features: 0.1
|
| 90 |
+
dropout: 0.1
|
| 91 |
+
attention_dropout: 0.1
|
| 92 |
+
feature_grad_mult: 0.1
|
| 93 |
+
untie_final_proj: true
|
| 94 |
+
activation_dropout: 0.0
|
| 95 |
+
wav_input: false
|
| 96 |
+
layer_norm_first: true
|
| 97 |
+
audio_feat_dim: 104
|
| 98 |
+
encoder_layers: 24
|
| 99 |
+
encoder_embed_dim: 1024
|
| 100 |
+
encoder_ffn_embed_dim: 4096
|
| 101 |
+
encoder_attention_heads: 16
|
| 102 |
+
|
| 103 |
+
hydra:
|
| 104 |
+
job:
|
| 105 |
+
config:
|
| 106 |
+
override_dirname:
|
| 107 |
+
kv_sep: '-'
|
| 108 |
+
item_sep: '__'
|
| 109 |
+
exclude_keys:
|
| 110 |
+
- run
|
| 111 |
+
- task.data
|
| 112 |
+
- task.label_dir
|
| 113 |
+
run:
|
| 114 |
+
dir: ???
|
| 115 |
+
sweep:
|
| 116 |
+
dir: ???
|
| 117 |
+
subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
|
av_hubert/avhubert/conf/pretrain/large_vox_iter5.yaml
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _group_
|
| 2 |
+
|
| 3 |
+
common:
|
| 4 |
+
fp16: true
|
| 5 |
+
log_format: json
|
| 6 |
+
log_interval: 200
|
| 7 |
+
seed: 1337
|
| 8 |
+
user_dir: ???
|
| 9 |
+
empty_cache_freq: 10000
|
| 10 |
+
|
| 11 |
+
checkpoint:
|
| 12 |
+
save_interval_updates: 25000
|
| 13 |
+
keep_interval_updates: 1
|
| 14 |
+
no_epoch_checkpoints: true
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
distributed_training:
|
| 18 |
+
ddp_backend: no_c10d
|
| 19 |
+
distributed_backend: 'nccl'
|
| 20 |
+
distributed_world_size: 64
|
| 21 |
+
distributed_port: 29671
|
| 22 |
+
nprocs_per_node: 8
|
| 23 |
+
|
| 24 |
+
task:
|
| 25 |
+
_name: av_hubert_pretraining
|
| 26 |
+
data: ???
|
| 27 |
+
label_dir: ???
|
| 28 |
+
labels: ["km"]
|
| 29 |
+
label_rate: ${model.label_rate}
|
| 30 |
+
sample_rate: 25
|
| 31 |
+
max_sample_size: 2000
|
| 32 |
+
min_sample_size: 5
|
| 33 |
+
pad_audio: false
|
| 34 |
+
random_crop: true
|
| 35 |
+
normalize: true
|
| 36 |
+
stack_order_audio: 4
|
| 37 |
+
# stack_order: 1
|
| 38 |
+
input_modality: image
|
| 39 |
+
image_aug: true
|
| 40 |
+
max_trim_sample_size: 400
|
| 41 |
+
|
| 42 |
+
dataset:
|
| 43 |
+
num_workers: 6
|
| 44 |
+
max_tokens: 1000
|
| 45 |
+
skip_invalid_size_inputs_valid_test: true
|
| 46 |
+
validate_interval: 5
|
| 47 |
+
validate_interval_updates: 10000
|
| 48 |
+
|
| 49 |
+
criterion:
|
| 50 |
+
_name: av_hubert
|
| 51 |
+
pred_masked_weight: 1.0
|
| 52 |
+
pred_nomask_weight: 1.0
|
| 53 |
+
loss_weights: [10,]
|
| 54 |
+
|
| 55 |
+
optimization:
|
| 56 |
+
max_update: 600000
|
| 57 |
+
lr: [0.002]
|
| 58 |
+
clip_norm: 10.0
|
| 59 |
+
|
| 60 |
+
optimizer:
|
| 61 |
+
_name: adam
|
| 62 |
+
adam_betas: (0.9,0.98)
|
| 63 |
+
adam_eps: 1e-06
|
| 64 |
+
weight_decay: 0.01
|
| 65 |
+
|
| 66 |
+
lr_scheduler:
|
| 67 |
+
_name: polynomial_decay
|
| 68 |
+
warmup_updates: 48000
|
| 69 |
+
|
| 70 |
+
model:
|
| 71 |
+
_name: av_hubert
|
| 72 |
+
label_rate: ???
|
| 73 |
+
skip_masked: false
|
| 74 |
+
skip_nomask: false
|
| 75 |
+
modality_dropout: 0.5
|
| 76 |
+
audio_dropout: 0.5
|
| 77 |
+
modality_fuse: concat
|
| 78 |
+
selection_type: same_seq
|
| 79 |
+
masking_type: input
|
| 80 |
+
mask_prob_image: 0.3
|
| 81 |
+
mask_length_image: 5
|
| 82 |
+
mask_prob_audio: 0.8
|
| 83 |
+
mask_length_audio: 10
|
| 84 |
+
extractor_mode: default
|
| 85 |
+
# conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
|
| 86 |
+
final_dim: 256
|
| 87 |
+
encoder_layerdrop: 0.05
|
| 88 |
+
dropout_input: 0.1
|
| 89 |
+
dropout_features: 0.1
|
| 90 |
+
dropout: 0.1
|
| 91 |
+
attention_dropout: 0.1
|
| 92 |
+
feature_grad_mult: 0.1
|
| 93 |
+
untie_final_proj: true
|
| 94 |
+
activation_dropout: 0.0
|
| 95 |
+
wav_input: false
|
| 96 |
+
layer_norm_first: true
|
| 97 |
+
audio_feat_dim: 104
|
| 98 |
+
encoder_layers: 24
|
| 99 |
+
encoder_embed_dim: 1024
|
| 100 |
+
encoder_ffn_embed_dim: 4096
|
| 101 |
+
encoder_attention_heads: 16
|
| 102 |
+
|
| 103 |
+
hydra:
|
| 104 |
+
job:
|
| 105 |
+
config:
|
| 106 |
+
override_dirname:
|
| 107 |
+
kv_sep: '-'
|
| 108 |
+
item_sep: '__'
|
| 109 |
+
exclude_keys:
|
| 110 |
+
- run
|
| 111 |
+
- task.data
|
| 112 |
+
- task.label_dir
|
| 113 |
+
run:
|
| 114 |
+
dir: ???
|
| 115 |
+
sweep:
|
| 116 |
+
dir: ???
|
| 117 |
+
subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
|
av_hubert/avhubert/conf/pretrain/noise_base_vox_iter5.yaml
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _group_
|
| 2 |
+
|
| 3 |
+
common:
|
| 4 |
+
fp16: true
|
| 5 |
+
log_format: json
|
| 6 |
+
log_interval: 200
|
| 7 |
+
seed: 1337
|
| 8 |
+
user_dir: ???
|
| 9 |
+
empty_cache_freq: 10000
|
| 10 |
+
|
| 11 |
+
checkpoint:
|
| 12 |
+
save_interval_updates: 25000
|
| 13 |
+
keep_interval_updates: 1
|
| 14 |
+
no_epoch_checkpoints: true
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
distributed_training:
|
| 18 |
+
ddp_backend: no_c10d
|
| 19 |
+
distributed_backend: 'nccl'
|
| 20 |
+
distributed_world_size: 32
|
| 21 |
+
distributed_port: 29671
|
| 22 |
+
nprocs_per_node: 8
|
| 23 |
+
|
| 24 |
+
task:
|
| 25 |
+
_name: av_hubert_pretraining
|
| 26 |
+
data: ???
|
| 27 |
+
label_dir: ???
|
| 28 |
+
labels: ["km"]
|
| 29 |
+
label_rate: ${model.label_rate}
|
| 30 |
+
sample_rate: 25
|
| 31 |
+
max_sample_size: 2000
|
| 32 |
+
min_sample_size: 5
|
| 33 |
+
pad_audio: false
|
| 34 |
+
random_crop: true
|
| 35 |
+
normalize: true
|
| 36 |
+
stack_order_audio: 4
|
| 37 |
+
input_modality: image
|
| 38 |
+
image_aug: true
|
| 39 |
+
max_trim_sample_size: 400
|
| 40 |
+
noise_prob: 0.25
|
| 41 |
+
noise_snr: 0
|
| 42 |
+
noise_wav: ???
|
| 43 |
+
|
| 44 |
+
dataset:
|
| 45 |
+
num_workers: 6
|
| 46 |
+
max_tokens: 1000
|
| 47 |
+
skip_invalid_size_inputs_valid_test: true
|
| 48 |
+
validate_interval: 5
|
| 49 |
+
validate_interval_updates: 10000
|
| 50 |
+
|
| 51 |
+
criterion:
|
| 52 |
+
_name: av_hubert
|
| 53 |
+
pred_masked_weight: 1.0
|
| 54 |
+
pred_nomask_weight: 0.0
|
| 55 |
+
loss_weights: [10,]
|
| 56 |
+
|
| 57 |
+
optimization:
|
| 58 |
+
max_update: 800000
|
| 59 |
+
lr: [0.002]
|
| 60 |
+
clip_norm: 10.0
|
| 61 |
+
|
| 62 |
+
optimizer:
|
| 63 |
+
_name: adam
|
| 64 |
+
adam_betas: (0.9,0.98)
|
| 65 |
+
adam_eps: 1e-06
|
| 66 |
+
weight_decay: 0.01
|
| 67 |
+
|
| 68 |
+
lr_scheduler:
|
| 69 |
+
_name: polynomial_decay
|
| 70 |
+
warmup_updates: 64000
|
| 71 |
+
|
| 72 |
+
model:
|
| 73 |
+
_name: av_hubert
|
| 74 |
+
label_rate: ???
|
| 75 |
+
skip_masked: false
|
| 76 |
+
skip_nomask: false
|
| 77 |
+
modality_dropout: 0.5
|
| 78 |
+
audio_dropout: 0.5
|
| 79 |
+
modality_fuse: concat
|
| 80 |
+
selection_type: same_seq
|
| 81 |
+
masking_type: input
|
| 82 |
+
mask_prob_image: 0.3
|
| 83 |
+
mask_length_image: 5
|
| 84 |
+
mask_prob_audio: 0.8
|
| 85 |
+
mask_length_audio: 10
|
| 86 |
+
extractor_mode: default
|
| 87 |
+
# conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
|
| 88 |
+
final_dim: 256
|
| 89 |
+
encoder_layerdrop: 0.05
|
| 90 |
+
dropout_input: 0.1
|
| 91 |
+
dropout_features: 0.1
|
| 92 |
+
dropout: 0.1
|
| 93 |
+
attention_dropout: 0.1
|
| 94 |
+
feature_grad_mult: 0.1
|
| 95 |
+
untie_final_proj: true
|
| 96 |
+
activation_dropout: 0.0
|
| 97 |
+
wav_input: false
|
| 98 |
+
layer_norm_first: true
|
| 99 |
+
audio_feat_dim: 104
|
| 100 |
+
|
| 101 |
+
hydra:
|
| 102 |
+
job:
|
| 103 |
+
config:
|
| 104 |
+
override_dirname:
|
| 105 |
+
kv_sep: '-'
|
| 106 |
+
item_sep: '__'
|
| 107 |
+
exclude_keys:
|
| 108 |
+
- run
|
| 109 |
+
- task.data
|
| 110 |
+
- task.label_dir
|
| 111 |
+
run:
|
| 112 |
+
dir: ???
|
| 113 |
+
sweep:
|
| 114 |
+
dir: ???
|
| 115 |
+
subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
|
av_hubert/avhubert/conf/pretrain/noise_large_vox_iter5.yaml
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _group_
|
| 2 |
+
|
| 3 |
+
common:
|
| 4 |
+
fp16: true
|
| 5 |
+
log_format: json
|
| 6 |
+
log_interval: 200
|
| 7 |
+
seed: 1337
|
| 8 |
+
user_dir: ???
|
| 9 |
+
empty_cache_freq: 10000
|
| 10 |
+
|
| 11 |
+
checkpoint:
|
| 12 |
+
save_interval_updates: 25000
|
| 13 |
+
keep_interval_updates: 1
|
| 14 |
+
no_epoch_checkpoints: true
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
distributed_training:
|
| 18 |
+
ddp_backend: no_c10d
|
| 19 |
+
distributed_backend: 'nccl'
|
| 20 |
+
distributed_world_size: 64
|
| 21 |
+
distributed_port: 29671
|
| 22 |
+
nprocs_per_node: 8
|
| 23 |
+
|
| 24 |
+
task:
|
| 25 |
+
_name: av_hubert_pretraining
|
| 26 |
+
data: ???
|
| 27 |
+
label_dir: ???
|
| 28 |
+
labels: ["km"]
|
| 29 |
+
label_rate: ${model.label_rate}
|
| 30 |
+
sample_rate: 25
|
| 31 |
+
max_sample_size: 2000
|
| 32 |
+
min_sample_size: 5
|
| 33 |
+
pad_audio: false
|
| 34 |
+
random_crop: true
|
| 35 |
+
normalize: true
|
| 36 |
+
stack_order_audio: 4
|
| 37 |
+
input_modality: image
|
| 38 |
+
image_aug: true
|
| 39 |
+
max_trim_sample_size: 400
|
| 40 |
+
noise_prob: 0.25
|
| 41 |
+
noise_snr: 0
|
| 42 |
+
noise_wav: ???
|
| 43 |
+
|
| 44 |
+
dataset:
|
| 45 |
+
num_workers: 6
|
| 46 |
+
max_tokens: 1000
|
| 47 |
+
skip_invalid_size_inputs_valid_test: true
|
| 48 |
+
validate_interval: 5
|
| 49 |
+
validate_interval_updates: 10000
|
| 50 |
+
|
| 51 |
+
criterion:
|
| 52 |
+
_name: av_hubert
|
| 53 |
+
pred_masked_weight: 1.0
|
| 54 |
+
pred_nomask_weight: 1.0
|
| 55 |
+
loss_weights: [10,]
|
| 56 |
+
|
| 57 |
+
optimization:
|
| 58 |
+
max_update: 600000
|
| 59 |
+
lr: [0.002]
|
| 60 |
+
clip_norm: 10.0
|
| 61 |
+
|
| 62 |
+
optimizer:
|
| 63 |
+
_name: adam
|
| 64 |
+
adam_betas: (0.9,0.98)
|
| 65 |
+
adam_eps: 1e-06
|
| 66 |
+
weight_decay: 0.01
|
| 67 |
+
|
| 68 |
+
lr_scheduler:
|
| 69 |
+
_name: polynomial_decay
|
| 70 |
+
warmup_updates: 48000
|
| 71 |
+
|
| 72 |
+
model:
|
| 73 |
+
_name: av_hubert
|
| 74 |
+
label_rate: ???
|
| 75 |
+
skip_masked: false
|
| 76 |
+
skip_nomask: false
|
| 77 |
+
modality_dropout: 0.5
|
| 78 |
+
audio_dropout: 0.5
|
| 79 |
+
modality_fuse: concat
|
| 80 |
+
selection_type: same_seq
|
| 81 |
+
masking_type: input
|
| 82 |
+
mask_prob_image: 0.3
|
| 83 |
+
mask_length_image: 5
|
| 84 |
+
mask_prob_audio: 0.8
|
| 85 |
+
mask_length_audio: 10
|
| 86 |
+
extractor_mode: default
|
| 87 |
+
# conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
|
| 88 |
+
final_dim: 256
|
| 89 |
+
encoder_layerdrop: 0.05
|
| 90 |
+
dropout_input: 0.1
|
| 91 |
+
dropout_features: 0.1
|
| 92 |
+
dropout: 0.1
|
| 93 |
+
attention_dropout: 0.1
|
| 94 |
+
feature_grad_mult: 0.1
|
| 95 |
+
untie_final_proj: true
|
| 96 |
+
activation_dropout: 0.0
|
| 97 |
+
wav_input: false
|
| 98 |
+
layer_norm_first: true
|
| 99 |
+
audio_feat_dim: 104
|
| 100 |
+
encoder_layers: 24
|
| 101 |
+
encoder_embed_dim: 1024
|
| 102 |
+
encoder_ffn_embed_dim: 4096
|
| 103 |
+
encoder_attention_heads: 16
|
| 104 |
+
|
| 105 |
+
hydra:
|
| 106 |
+
job:
|
| 107 |
+
config:
|
| 108 |
+
override_dirname:
|
| 109 |
+
kv_sep: '-'
|
| 110 |
+
item_sep: '__'
|
| 111 |
+
exclude_keys:
|
| 112 |
+
- run
|
| 113 |
+
- task.data
|
| 114 |
+
- task.label_dir
|
| 115 |
+
run:
|
| 116 |
+
dir: ???
|
| 117 |
+
sweep:
|
| 118 |
+
dir: ???
|
| 119 |
+
subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
|
av_hubert/avhubert/conf/s2s_decode.yaml
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
common:
|
| 2 |
+
user_dir: ???
|
| 3 |
+
|
| 4 |
+
generation:
|
| 5 |
+
beam: 50
|
| 6 |
+
max_len_a: 1.0
|
| 7 |
+
max_len_b: 0
|
| 8 |
+
lenpen: 1.0
|
| 9 |
+
lm_weight: 0
|
| 10 |
+
|
| 11 |
+
common_eval:
|
| 12 |
+
results_path: ???
|
| 13 |
+
path: ???
|
| 14 |
+
|
| 15 |
+
dataset:
|
| 16 |
+
max_tokens: 1000
|
| 17 |
+
gen_subset: valid
|
| 18 |
+
num_workers: 0
|
| 19 |
+
|
| 20 |
+
override:
|
| 21 |
+
noise_prob: 0.0
|
| 22 |
+
noise_snr: 0
|
| 23 |
+
modalities: ???
|
av_hubert/avhubert/decoder.py
ADDED
|
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from argparse import Namespace
|
| 8 |
+
import contextlib
|
| 9 |
+
import copy
|
| 10 |
+
import math
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
from dataclasses import dataclass, field
|
| 16 |
+
from omegaconf import MISSING, II, open_dict
|
| 17 |
+
from typing import Any, Optional
|
| 18 |
+
|
| 19 |
+
from fairseq import checkpoint_utils, tasks, utils
|
| 20 |
+
from fairseq.dataclass import FairseqDataclass
|
| 21 |
+
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
|
| 22 |
+
from fairseq.tasks import FairseqTask
|
| 23 |
+
from fairseq.models import (
|
| 24 |
+
BaseFairseqModel,
|
| 25 |
+
FairseqEncoder,
|
| 26 |
+
FairseqEncoderDecoderModel,
|
| 27 |
+
FairseqIncrementalDecoder,
|
| 28 |
+
register_model,
|
| 29 |
+
)
|
| 30 |
+
# from fairseq.models.wav2vec.wav2vec2 import MASKING_DISTRIBUTION_CHOICES
|
| 31 |
+
from fairseq.modules import (
|
| 32 |
+
LayerNorm,
|
| 33 |
+
PositionalEmbedding,
|
| 34 |
+
TransformerDecoderLayer,
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class TransformerDecoder(FairseqIncrementalDecoder):
    """Transformer decoder made of ``cfg.decoder_layers`` stacked
    :class:`TransformerDecoderLayer` modules.

    Args:
        cfg: decoder configuration (namespace/dataclass with the
            ``decoder_*`` hyper-parameters read below).
        dictionary (~fairseq.data.Dictionary): target-side vocabulary.
        embed_tokens (torch.nn.Embedding): token embedding table.
        no_encoder_attn (bool, optional): if True, layers skip
            encoder-side cross-attention (default: False).
    """

    def __init__(
        self,
        cfg,
        dictionary,
        embed_tokens,
        no_encoder_attn=False,
    ):
        super().__init__(dictionary)

        self.dropout = cfg.decoder_dropout
        self.share_input_output_embed = cfg.share_decoder_input_output_embed

        token_dim = embed_tokens.embedding_dim
        model_dim = cfg.decoder_embed_dim
        self.output_embed_dim = cfg.decoder_embed_dim

        self.layerdrop = cfg.decoder_layerdrop

        pad_idx = embed_tokens.padding_idx
        self.max_target_positions = cfg.max_target_positions

        self.embed_tokens = embed_tokens
        # Scale embeddings by sqrt(d_model) unless explicitly disabled.
        self.embed_scale = 1.0 if cfg.no_scale_embedding else math.sqrt(model_dim)

        # Project token embeddings up/down only when dimensions differ.
        if model_dim != token_dim:
            self.project_in_dim = Linear(token_dim, model_dim, bias=False)
        else:
            self.project_in_dim = None

        if cfg.no_token_positional_embeddings:
            self.embed_positions = None
        else:
            self.embed_positions = PositionalEmbedding(
                cfg.max_target_positions,
                model_dim,
                pad_idx,
                learned=cfg.decoder_learned_pos,
            )

        # TODO: update this when transformer gets converted to dataclass configs.
        # Copy the config and alias the decoder-specific dropout fields onto the
        # generic names that TransformerDecoderLayer reads.
        layer_cfg = copy.deepcopy(cfg)
        layer_cfg.dropout = layer_cfg.decoder_dropout
        layer_cfg.attention_dropout = layer_cfg.decoder_attention_dropout
        layer_cfg.activation_dropout = layer_cfg.decoder_activation_dropout

        self.layers = nn.ModuleList(
            TransformerDecoderLayer(layer_cfg, no_encoder_attn)
            for _ in range(layer_cfg.decoder_layers)
        )

        if not self.share_input_output_embed:
            # Separate output projection matrix, initialized like the embeddings.
            self.embed_out = nn.Parameter(
                torch.Tensor(len(dictionary), self.output_embed_dim)
            )
            nn.init.normal_(self.embed_out, mean=0, std=self.output_embed_dim ** -0.5)

        if layer_cfg.decoder_normalize_before:
            self.layer_norm = LayerNorm(model_dim)
        else:
            self.layer_norm = None

    def forward(
        self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused
    ):
        """Run the decoder.

        Args:
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for teacher forcing.
            encoder_out (dict, optional): encoder output used for
                cross-attention.
            incremental_state (dict, optional): state store for incremental
                decoding.

        Returns:
            tuple of (logits of shape `(batch, tgt_len, vocab)`,
            dict of model-specific outputs).
        """
        prev_output_tokens = prev_output_tokens.long()
        features, extra = self.extract_features(
            prev_output_tokens, encoder_out, incremental_state
        )
        logits = self.output_layer(features)
        return logits, extra

    def extract_features(
        self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused
    ):
        """Like :meth:`forward` but stop before the output projection.

        Returns:
            tuple of (features of shape `(batch, tgt_len, embed_dim)`,
            dict with ``attn`` and ``inner_states``).
        """
        # Positional embeddings (None when disabled by config).
        if self.embed_positions is not None:
            positions = self.embed_positions(
                prev_output_tokens, incremental_state=incremental_state
            )
        else:
            positions = None

        # During incremental decoding only the newest token is processed.
        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
            if positions is not None:
                positions = positions[:, -1:]

        # Embed tokens (and positions), then apply input dropout.
        x = self.embed_scale * self.embed_tokens(prev_output_tokens)
        if self.project_in_dim is not None:
            x = self.project_in_dim(x)
        if positions is not None:
            x += positions
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)
        attn = None
        inner_states = [x]

        # Decoder layers, with LayerDrop applied at training time.
        for layer in self.layers:
            drop_p = np.random.random()
            if not self.training or (drop_p > self.layerdrop):
                x, attn, _ = layer(
                    x,
                    encoder_out["encoder_out"] if encoder_out is not None else None,
                    encoder_out["padding_mask"] if encoder_out is not None else None,
                    incremental_state,
                    self_attn_mask=(
                        self.buffered_future_mask(x)
                        if incremental_state is None
                        else None
                    ),
                )
                inner_states.append(x)

        if self.layer_norm:
            x = self.layer_norm(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        return x, {"attn": attn, "inner_states": inner_states}

    def output_layer(self, features, **kwargs):
        """Project decoder features to vocabulary-size logits."""
        if self.share_input_output_embed:
            weight_matrix = self.embed_tokens.weight
        else:
            weight_matrix = self.embed_out
        return torch.matmul(features, weight_matrix.transpose(0, 1))

    def max_positions(self):
        """Maximum output length supported by the decoder."""
        if self.embed_positions is None:
            return self.max_target_positions
        return min(self.max_target_positions, self.embed_positions.max_positions)

    def buffered_future_mask(self, tensor):
        # Cache an upper-triangular -inf mask; rebuild when it is missing,
        # lives on the wrong device, or is too small for the current length.
        dim = tensor.size(0)
        needs_rebuild = (
            not hasattr(self, "_future_mask")
            or self._future_mask is None
            or self._future_mask.device != tensor.device
            or self._future_mask.size(0) < dim
        )
        if needs_rebuild:
            self._future_mask = torch.triu(
                utils.fill_with_neg_inf(tensor.new(dim, dim)), 1
            )
        return self._future_mask[:dim, :dim]

    def upgrade_state_dict_named(self, state_dict, name):
        # No checkpoint migration needed for this decoder.
        return state_dict
|
| 243 |
+
|
av_hubert/avhubert/hubert.py
ADDED
|
@@ -0,0 +1,779 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import os,sys
|
| 8 |
+
import logging
|
| 9 |
+
from typing import Dict, List, Optional, Tuple
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn as nn
|
| 15 |
+
from dataclasses import dataclass, field
|
| 16 |
+
from fairseq import utils
|
| 17 |
+
from fairseq.data.data_utils import compute_mask_indices
|
| 18 |
+
from fairseq.data.dictionary import Dictionary
|
| 19 |
+
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
|
| 20 |
+
from fairseq.models import BaseFairseqModel, register_model
|
| 21 |
+
from fairseq.models.wav2vec.wav2vec2 import (
|
| 22 |
+
ConvFeatureExtractionModel,
|
| 23 |
+
TransformerEncoder,
|
| 24 |
+
)
|
| 25 |
+
from fairseq.modules import GradMultiply, LayerNorm
|
| 26 |
+
from copy import deepcopy
|
| 27 |
+
|
| 28 |
+
DBG=True if len(sys.argv) == 1 else False
|
| 29 |
+
|
| 30 |
+
if DBG:
|
| 31 |
+
from hubert_pretraining import (
|
| 32 |
+
AVHubertPretrainingConfig,
|
| 33 |
+
AVHubertPretrainingTask,
|
| 34 |
+
)
|
| 35 |
+
from resnet import ResEncoder
|
| 36 |
+
logging.basicConfig(
|
| 37 |
+
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
|
| 38 |
+
datefmt="%Y-%m-%d %H:%M:%S",
|
| 39 |
+
level=os.environ.get("LOGLEVEL", "INFO").upper(),
|
| 40 |
+
stream=sys.stdout,
|
| 41 |
+
)
|
| 42 |
+
from utils import compute_mask_indices
|
| 43 |
+
from decoder import TransformerDecoder
|
| 44 |
+
|
| 45 |
+
else:
|
| 46 |
+
from .hubert_pretraining import (
|
| 47 |
+
AVHubertPretrainingConfig,
|
| 48 |
+
AVHubertPretrainingTask,
|
| 49 |
+
)
|
| 50 |
+
from .resnet import ResEncoder
|
| 51 |
+
from .utils import compute_mask_indices
|
| 52 |
+
from .decoder import TransformerDecoder
|
| 53 |
+
|
| 54 |
+
from omegaconf import II
|
| 55 |
+
|
| 56 |
+
logger = logging.getLogger(__name__)
|
| 57 |
+
|
| 58 |
+
EXTRACTOR_MODE_CHOICES = ChoiceEnum(["default", "layer_norm"])
|
| 59 |
+
MASKING_DISTRIBUTION_CHOICES = ChoiceEnum(
|
| 60 |
+
["static", "uniform", "normal", "poisson"]
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
@dataclass
class AVHubertConfig(FairseqDataclass):
    """Configuration for the AV-HuBERT model.

    Groups transformer encoder/decoder sizes, dropout rates, span- and
    channel-masking parameters, and audio/video fusion options. Fields
    marked with II(...) are resolved from the task config at runtime.
    """
    label_rate: int = II("task.label_rate")
    input_modality: str = II("task.input_modality")
    extractor_mode: EXTRACTOR_MODE_CHOICES = field(
        default="default",
        metadata={
            "help": "mode for feature extractor. default has a single group "
            "norm with d groups in the first conv block, whereas layer_norm "
            "has layer norms in every block (meant to use with normalize=True)"
        },
    )
    encoder_layers: int = field(
        default=12, metadata={"help": "num encoder layers in the transformer"}
    )
    encoder_embed_dim: int = field(
        default=768, metadata={"help": "encoder embedding dimension"}
    )
    encoder_ffn_embed_dim: int = field(
        default=3072, metadata={"help": "encoder embedding dimension for FFN"}
    )
    encoder_attention_heads: int = field(
        default=12, metadata={"help": "num encoder attention heads"}
    )
    activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
        default="gelu", metadata={"help": "activation function to use"}
    )

    # dropouts
    dropout: float = field(
        default=0.1,
        metadata={"help": "dropout probability for the transformer"},
    )
    attention_dropout: float = field(
        default=0.1,
        metadata={"help": "dropout probability for attention weights"},
    )
    activation_dropout: float = field(
        default=0.0,
        metadata={"help": "dropout probability after activation in FFN"},
    )
    encoder_layerdrop: float = field(
        default=0.0,
        metadata={"help": "probability of dropping a tarnsformer layer"},
    )
    dropout_input: float = field(
        default=0.0,
        metadata={"help": "dropout to apply to the input (after feat extr)"},
    )
    dropout_features: float = field(
        default=0.0,
        metadata={
            "help": "dropout to apply to the features (after feat extr)"
        },
    )

    final_dim: int = field(
        default=0,
        metadata={
            "help": "project final representations and targets to this many "
            "dimensions. set to encoder_embed_dim is <= 0"
        },
    )
    untie_final_proj: bool = field(
        default=False,
        metadata={"help": "use separate projection for each target"},
    )
    layer_norm_first: bool = field(
        default=False,
        metadata={"help": "apply layernorm first in the transformer"},
    )
    conv_feature_layers: str = field(
        default="[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2",
        metadata={
            "help": "string describing convolutional feature extraction "
            "layers in form of a python list that contains "
            "[(dim, kernel_size, stride), ...]"
        },
    )
    conv_bias: bool = field(
        default=False, metadata={"help": "include bias in conv encoder"}
    )
    logit_temp: float = field(
        default=0.1, metadata={"help": "temperature to divide logits by"}
    )
    target_glu: bool = field(
        default=False, metadata={"help": "adds projection + glu to targets"}
    )
    feature_grad_mult: float = field(
        default=1.0,
        metadata={"help": "multiply feature extractor var grads by this"},
    )

    # masking (separate span-mask settings per modality)
    mask_length_audio: int = field(default=10, metadata={"help": "mask length"})
    mask_prob_audio: float = field(
        default=0.65,
        metadata={"help": "probability of replacing a token with mask"},
    )
    mask_length_image: int = field(default=10, metadata={"help": "mask length"})
    mask_prob_image: float = field(
        default=0.65,
        metadata={"help": "probability of replacing a token with mask"},
    )
    mask_selection: MASKING_DISTRIBUTION_CHOICES = field(
        default="static", metadata={"help": "how to choose mask length"}
    )
    mask_other: float = field(
        default=0,
        metadata={
            "help": "secondary mask argument "
            "(used for more complex distributions), "
            "see help in compute_mask_indicesh"
        },
    )
    no_mask_overlap: bool = field(
        default=False, metadata={"help": "whether to allow masks to overlap"}
    )
    mask_min_space: int = field(
        default=1,
        metadata={
            "help": "min space between spans (if no overlap is enabled)"
        },
    )

    # channel masking
    mask_channel_length: int = field(
        default=10,
        metadata={"help": "length of the mask for features (channels)"},
    )
    mask_channel_prob: float = field(
        default=0.0,
        metadata={"help": "probability of replacing a feature with 0"},
    )
    mask_channel_selection: MASKING_DISTRIBUTION_CHOICES = field(
        default="static",
        metadata={"help": "how to choose mask length for channel masking"},
    )
    mask_channel_other: float = field(
        default=0,
        metadata={
            "help": "secondary mask argument "
            "(used for more complex distributions), "
            "see help in compute_mask_indicesh"
        },
    )
    no_mask_channel_overlap: bool = field(
        default=False,
        metadata={"help": "whether to allow channel masks to overlap"},
    )
    mask_channel_min_space: int = field(
        default=1,
        metadata={
            "help": "min space between spans (if no overlap is enabled)"
        },
    )

    # positional embeddings
    conv_pos: int = field(
        default=128,
        metadata={
            "help": "number of filters for convolutional positional embeddings"
        },
    )
    conv_pos_groups: int = field(
        default=16,
        metadata={
            "help": "number of groups for convolutional positional embedding"
        },
    )

    latent_temp: Tuple[float, float, float] = field(
        default=(2, 0.5, 0.999995),
        metadata={"help": "legacy (to be removed)"},
    )

    # loss computation
    skip_masked: bool = field(
        default=False,
        metadata={"help": "skip computing losses over masked frames"},
    )
    skip_nomask: bool = field(
        default=False,
        metadata={"help": "skip computing losses over unmasked frames"},
    )
    resnet_relu_type: str = field(default='prelu', metadata={"help": 'relu type for resnet'})
    resnet_weights: Optional[str] = field(default=None, metadata={"help": 'resnet weights'})
    sim_type: str = field(default='cosine', metadata={"help": 'similarity type'})

    # AV-specific options: per-modality sub-encoders, modality dropout, fusion
    sub_encoder_layers: int = field(default=0, metadata={'help': 'number of transformer layers for single modality'})
    audio_feat_dim: int = field(default=-1, metadata={'help': 'audio feature dimension'})
    modality_dropout: float = field(default=0, metadata={'help': 'drop one modality'})
    audio_dropout: float = field(default=0, metadata={'help': 'drop audio feature'})
    modality_fuse: str = field(default='concat', metadata={'help': 'fusing two modalities: add,concat'})
    selection_type : str = field(default='same_other_seq', metadata={'help': 'type of selectig images, same_other_seq: replace masked span with span from another sequence, same_seq: repace masked span with span of the same sequence'})
    masking_type : str = field(default='input', metadata={'help': 'input or feature masking'})

    # seq2seq decoder (used for fine-tuning)
    decoder_embed_dim: int = field(
        default=768, metadata={"help": "decoder embedding dimension"}
    )
    decoder_ffn_embed_dim: int = field(
        default=3072, metadata={"help": "decoder embedding dimension for FFN"}
    )
    decoder_layers: int = field(
        default=6, metadata={"help": "num of decoder layers"}
    )
    decoder_layerdrop: float = field(
        default=0.0, metadata={"help": "decoder layerdrop chance"}
    )
    decoder_attention_heads: int = field(
        default=4, metadata={"help": "num decoder attention heads"}
    )
    decoder_learned_pos: bool = field(
        default=False,
        metadata={"help": "use learned positional embeddings in the decoder"},
    )
    decoder_normalize_before: bool = field(
        default=False,
        metadata={"help": "apply layernorm before each decoder block"},
    )
    no_token_positional_embeddings: bool = field(
        default=False,
        metadata={
            "help": "if set, disables positional embeddings "
            "(outside self attention)"
        },
    )
    decoder_dropout: float = field(
        default=0.1, metadata={"help": "dropout probability in the decoder"}
    )
    decoder_attention_dropout: float = field(
        default=0.1,
        metadata={
            "help": "dropout probability for attention weights "
            "inside the decoder"
        },
    )
    decoder_activation_dropout: float = field(
        default=0.0,
        metadata={
            "help": "dropout probability after activation in FFN "
            "inside the decoder"
        },
    )
    max_target_positions: int = field(
        default=2048, metadata={"help": "max target positions"}
    )
    share_decoder_input_output_embed: bool = field(
        default=False,
        metadata={"help": "share decoder input and output embeddings"},
    )
    no_scale_embedding: bool = field(default=True, metadata={'help': 'scale embedding'})
|
| 316 |
+
|
| 317 |
+
class SubModel(nn.Module):
    """Single-modality front-end.

    Optional ResNet backbone, a linear projection to the shared encoder
    dimension, and an optional per-modality transformer encoder.
    """

    def __init__(self, resnet=None, input_dim=None, cfg=None):
        super().__init__()
        self.resnet = resnet
        self.proj = nn.Linear(input_dim, cfg.encoder_embed_dim)
        # No modality-specific transformer when cfg.encoder_layers == 0.
        self.encoder = TransformerEncoder(cfg) if cfg.encoder_layers > 0 else None

    def forward(self, x):
        """Map raw input to features of shape [B, encoder_embed_dim, T]."""
        if self.resnet is not None:
            x = self.resnet(x)
        # Channel-first input -> [B, T, F] for the linear projection.
        projected = self.proj(x.transpose(1, 2))
        if self.encoder is None:
            out = projected
        else:
            out = self.encoder(projected)[0]
        # Return to channel-first layout: [B, T, F] -> [B, F, T].
        return out.transpose(1, 2)
|
| 333 |
+
|
| 334 |
+
@register_model("av_hubert", dataclass=AVHubertConfig)
|
| 335 |
+
class AVHubertModel(BaseFairseqModel):
|
| 336 |
+
def __init__(
    self,
    cfg: AVHubertConfig,
    task_cfg: AVHubertPretrainingConfig,
    dictionaries: List[Dictionary],
    **kwargs
) -> None:
    """Build the AV-HuBERT model from its config and task config.

    A `None` entry in `dictionaries` signals fine-tuning mode, in which
    case the pre-training target embeddings are not created.
    """
    super().__init__()
    logger.info(f"HubertModel Config: {cfg}")

    # Features are produced at the input frame rate (no conv downsampling),
    # so the feature-to-target ratio depends only on label/sample rates.
    feature_ds_rate = 1
    self.feat2tar_ratio = cfg.label_rate * feature_ds_rate / task_cfg.sample_rate
    # Sub-encoders reuse the main config but with their own layer count.
    sub_cfg = deepcopy(cfg)
    sub_cfg.encoder_layers = sub_cfg.sub_encoder_layers
    resnet = ResEncoder(relu_type=cfg.resnet_relu_type, weights=cfg.resnet_weights)
    # Audio goes straight through a projection; video first through the ResNet.
    self.feature_extractor_audio = SubModel(resnet=None, input_dim=cfg.audio_feat_dim, cfg=sub_cfg)
    self.feature_extractor_video = SubModel(resnet=resnet, input_dim=resnet.backend_out, cfg=sub_cfg)
    self.modality_dropout, self.audio_dropout = cfg.modality_dropout, cfg.audio_dropout
    self.modality_fuse = cfg.modality_fuse
    self.encoder_embed_dim = cfg.encoder_embed_dim
    # Fused feature width: doubled for concat, unchanged for add.
    if self.modality_fuse == 'concat':
        self.embed = cfg.encoder_embed_dim * 2
    elif self.modality_fuse == 'add':
        self.embed = cfg.encoder_embed_dim
    # Project fused features back down only when widths differ.
    self.post_extract_proj = (
        nn.Linear(self.embed, cfg.encoder_embed_dim)
        if self.embed != cfg.encoder_embed_dim
        else None
    )

    # Span-masking hyperparameters (per modality).
    self.mask_prob_image, self.mask_prob_audio = cfg.mask_prob_image, cfg.mask_prob_audio
    self.mask_selection = cfg.mask_selection
    self.mask_other = cfg.mask_other
    self.mask_length_image, self.mask_length_audio = cfg.mask_length_image, cfg.mask_length_audio
    self.no_mask_overlap = cfg.no_mask_overlap
    self.mask_min_space = cfg.mask_min_space

    # Channel-masking hyperparameters.
    self.mask_channel_prob = cfg.mask_channel_prob
    self.mask_channel_selection = cfg.mask_channel_selection
    self.mask_channel_other = cfg.mask_channel_other
    self.mask_channel_length = cfg.mask_channel_length
    self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
    self.mask_channel_min_space = cfg.mask_channel_min_space

    self.dropout_input = nn.Dropout(cfg.dropout_input)
    self.dropout_features = nn.Dropout(cfg.dropout_features)

    self.feature_grad_mult = cfg.feature_grad_mult
    self.logit_temp = cfg.logit_temp
    self.skip_masked = cfg.skip_masked
    self.skip_nomask = cfg.skip_nomask
    self.sim_type = cfg.sim_type
    self.selection_type = cfg.selection_type
    self.masking_type = cfg.masking_type

    final_dim = (
        cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim
    )

    # Learned mask embedding; its width depends on where masking happens:
    # raw audio features for input masking, encoder features otherwise.
    self.mask_emb = nn.Parameter(
        torch.FloatTensor(cfg.audio_feat_dim).uniform_() if self.masking_type == 'input' else torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
    )

    self.encoder = TransformerEncoder(cfg)
    self.layer_norm = LayerNorm(self.embed)

    self.target_glu = None
    if cfg.target_glu:
        self.target_glu = nn.Sequential(
            nn.Linear(final_dim, final_dim * 2), nn.GLU()
        )

    # One projection per target when untied, otherwise a single shared one.
    self.untie_final_proj = cfg.untie_final_proj
    if self.untie_final_proj:
        self.final_proj = nn.Linear(
            cfg.encoder_embed_dim, final_dim * len(dictionaries)
        )
    else:
        self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim)

    # modules below are not needed during fine-tuning
    if any([d is None for d in dictionaries]):
        logger.info(
            "cannot find dictionary. assume will be used for fine-tuning"
        )
    else:
        self.num_classes = [len(d) for d in dictionaries]
        # Embeddings for all target codewords, concatenated across dictionaries.
        self.label_embs_concat = nn.Parameter(
            torch.FloatTensor(sum(self.num_classes), final_dim)
        )
        nn.init.uniform_(self.label_embs_concat)
|
| 427 |
+
|
| 428 |
+
def upgrade_state_dict_named(self, state_dict, name):
    """Upgrade a (possibly old) state dict for new versions of fairseq."""

    # Delegate any generic key migration to the fairseq base class;
    # no AV-HuBERT-specific renaming is required.
    super().upgrade_state_dict_named(state_dict, name)
    return state_dict
|
| 433 |
+
|
| 434 |
+
@classmethod
def build_model(cls, cfg: AVHubertConfig, task: AVHubertPretrainingTask):
    """Build a new model instance."""

    # Note: constructs AVHubertModel directly (not via cls), so subclasses
    # calling build_model still get the base model type.
    kwargs = {}
    model = AVHubertModel(cfg, task.cfg, task.dictionaries, **kwargs)
    return model
|
| 441 |
+
|
| 442 |
+
def apply_input_mask(self, x, padding_mask, target_list):
    """Mask spans of the raw input for one modality.

    Audio input is 3-D ([B, C, T] — assumed channel-first; confirm against
    caller), video is 5-D ([B, C, T, H, W]), so T is always dim 2.
    Masked audio frames are replaced by the learned mask embedding; masked
    video frames are replaced according to `selection_type` (frames from
    another sequence, or other frames of the same sequence).

    Returns:
        (masked input, boolean mask of shape [B, T] or None if no masking).
    """
    B, C, T = x.shape[:3]
    is_audio = True if len(x.shape) == 3 else False
    if is_audio:
        mask_prob, mask_length = self.mask_prob_audio, self.mask_length_audio
    else:
        mask_prob, mask_length = self.mask_prob_image, self.mask_length_image
    if mask_prob > 0:

        mask_indices, starts, ends, batch_indexes = compute_mask_indices(
            (B, T),
            padding_mask,
            mask_prob,
            mask_length,
            self.mask_selection,
            self.mask_other,
            min_masks=2,
            no_overlap=self.no_mask_overlap,
            min_space=self.mask_min_space,
        )
        mask_indices_np = mask_indices
        mask_indices = torch.from_numpy(mask_indices).to(x.device)
        # Bring T to dim 1 so boolean indexing selects frames.
        x = x.transpose(1, 2).contiguous() # [B, T, C, H, W]
        if B == 1:
            # Single-sequence batch: cross-sequence substitution is
            # impossible (and randint(low=1, high=1) would fail), so
            # masked frames are simply zeroed.
            x[mask_indices] = 0
        elif is_audio:
            x[mask_indices] = self.mask_emb
        elif self.selection_type == 'same_other_seq':
            # Replace masked spans with the same positions from a randomly
            # rotated permutation of the batch (never the sequence itself).
            perm = (torch.arange(B) + torch.randint(low=1, high=B, size=(1,))) % B
            x_perm = x[perm]
            x[mask_indices] = x_perm[mask_indices]
        elif self.selection_type == 'same_seq':
            # Replace each masked span with an equally long span sampled
            # from the same sequence, outside the masked region.
            batch_indexes_, other_indexes = [], []
            for batch_index, start, end in zip(batch_indexes, starts, ends):
                length = end-start
                other_start = np.setdiff1d(np.arange(T), np.arange(max(0, start-length), end))
                if len(other_start) > 0:
                    other_start = np.random.choice(other_start, size=1)
                else:
                    other_start = 0
                other_end = other_start + length
                # Clip so the donor span never reads past the sequence end.
                other_indexes.append(np.arange(other_start, other_end).clip(max=T-1))
                batch_indexes_.append(np.zeros([length], dtype=np.int64)+batch_index)
            batch_indexes, other_indexes = np.concatenate(batch_indexes_), np.concatenate(other_indexes)
            x[mask_indices] = x[batch_indexes, other_indexes]

        x = x.transpose(1, 2).contiguous()
    else:
        mask_indices = None

    if self.mask_channel_prob > 0:
        # Channel masking is only supported in feature-masking mode.
        logger.info(f"No mask channel prob for input masking")
    return x, mask_indices
|
| 495 |
+
|
| 496 |
+
def apply_feature_mask(self, x, padding_mask, target_list):
    """Mask spans (and optionally channels) of fused encoder features.

    x: [B, T, C] feature tensor, modified in place. Masked time steps are
    replaced by the learned mask embedding; masked channels are zeroed.

    Returns:
        (masked features, boolean time mask [B, T] or None).
    """
    B, T, C = x.shape
    # Feature masking happens after fusion, so a single prob/length pair
    # must apply to both modalities.
    assert self.mask_prob_audio == self.mask_prob_image and self.mask_length_audio == self.mask_length_image, f"masking prob/length for image/audio be same for feature masking"
    mask_prob, mask_length = self.mask_prob_audio, self.mask_length_image
    if mask_prob > 0:
        mask_indices, _, _, _ = compute_mask_indices(
            (B, T),
            padding_mask,
            mask_prob,
            mask_length,
            self.mask_selection,
            self.mask_other,
            min_masks=2,
            no_overlap=self.no_mask_overlap,
            min_space=self.mask_min_space,
        )
        mask_indices = torch.from_numpy(mask_indices).to(x.device)
        x[mask_indices] = self.mask_emb
    else:
        mask_indices = None

    if self.mask_channel_prob > 0:
        mask_channel_indices, _, _, _ = compute_mask_indices(
            (B, C),
            None,
            self.mask_channel_prob,
            self.mask_channel_length,
            self.mask_channel_selection,
            self.mask_channel_other,
            no_overlap=self.no_mask_channel_overlap,
            min_space=self.mask_channel_min_space,
        )
        # Broadcast the per-channel mask across all time steps.
        mask_channel_indices = (
            torch.from_numpy(mask_channel_indices)
            .to(x.device)
            .unsqueeze(1)
            .expand(-1, T, -1)
        )
        x[mask_channel_indices] = 0

    return x, mask_indices
|
| 537 |
+
|
| 538 |
+
def forward_features(self, source: torch.Tensor, modality: str) -> torch.Tensor:
    """Run the per-modality feature extractor over `source`.

    Args:
        source: raw modality input (audio features or video frames).
        modality: 'audio' or 'video'; selects self.feature_extractor_<modality>.

    Returns:
        Extracted features. Gradients into the extractor are scaled by
        feature_grad_mult, and disabled entirely when it is <= 0.
    """
    # getattr replaces the original eval() string lookup: same behavior,
    # no dynamic code evaluation.
    extractor = getattr(self, f"feature_extractor_{modality}")
    if self.feature_grad_mult > 0:
        features = extractor(source)
        if self.feature_grad_mult != 1.0:
            features = GradMultiply.apply(features, self.feature_grad_mult)
    else:
        # Frozen extractor: no gradients flow into it.
        with torch.no_grad():
            features = extractor(source)
    return features
|
| 548 |
+
|
| 549 |
+
def forward_targets(
    self, features: torch.Tensor, mask_indices: torch.Tensor, target_list: List[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Trim features to ensure labels exist and then get aligned labels
    # features: [B, F, T] (time-last); targets: [B, T_label].
    feat_tsz = features.size(2)
    targ_tsz = min([t.size(1) for t in target_list])
    # If the features extend past the shortest label sequence, cut them back.
    if self.feat2tar_ratio * feat_tsz > targ_tsz:
        feat_tsz = int(targ_tsz / self.feat2tar_ratio)
        features = features[..., :feat_tsz]
        if mask_indices is not None:
            mask_indices = mask_indices[..., :feat_tsz]
    # Subsample labels at the feature frame rate (nearest-lower index).
    target_inds = torch.arange(feat_tsz).float() * self.feat2tar_ratio
    target_list = [t[:, target_inds.long()] for t in target_list]
    return features, mask_indices, target_list
|
| 563 |
+
|
| 564 |
+
def forward_padding_mask(
    self, features: torch.Tensor, padding_mask: torch.Tensor,
) -> torch.Tensor:
    """Downsample a raw-frame padding mask to the feature frame rate.

    A feature frame counts as padding only when every raw frame it
    covers is padding.
    """
    batch, feat_len = padding_mask.size(0), features.size(1)
    # Drop trailing raw frames that do not fill a whole feature frame.
    leftover = padding_mask.size(1) % feat_len
    if leftover:
        padding_mask = padding_mask[:, :-leftover]
    grouped = padding_mask.view(batch, feat_len, -1)
    return grouped.all(-1)
|
| 575 |
+
|
| 576 |
+
def compute_logits(self, feats, emb_mat):
    """Similarity logits between frame features and codeword embeddings.

    feats: [B, T, F]; emb_mat: [V, F]. Returns [B, T, V], scaled by
    1 / logit_temp. Supports 'dot' and 'cosine' similarity.
    """
    if self.sim_type == 'dot':
        scores = torch.matmul(feats, emb_mat.transpose(0, 1))
    elif self.sim_type == 'cosine':
        bsz, steps, dim = feats.size()
        flat = feats.view(-1, dim)
        # Cosine similarity spelled out: <f, e> / (|f| * |e|), with the
        # denominator clamped to avoid division by zero.
        numer = (flat.unsqueeze(dim=1) * emb_mat.unsqueeze(dim=0)).sum(dim=-1)  # [B*T, V]
        scale = (flat ** 2).sum(dim=-1).sqrt().unsqueeze(dim=1) * (emb_mat ** 2).sum(dim=-1).sqrt().unsqueeze(dim=0)  # [B*T, V]
        scores = (numer / scale.clamp(min=1e-6)).view(bsz, steps, -1)
    else:
        raise NotImplementedError
    return scores / self.logit_temp
|
| 590 |
+
|
| 591 |
+
def forward(
    self,
    source: torch.Tensor,
    target_list: Optional[List[torch.Tensor]] = None,
    padding_mask: Optional[torch.Tensor] = None,
    mask: bool = True,
    features_only: bool = False,
    output_layer: Optional[int] = None
) -> Dict[str, torch.Tensor]:
    """Pretraining forward pass over audio-visual input.

    source: dict with 'audio' and 'video' entries (one stream each).
    output_layer is 1-based; None runs the full encoder stack.
    With features_only=True, returns encoder features and skips the
    logit/target computation (used by extract_features).
    NOTE(review): when features_only=False this code indexes with both
    mask_indices and ~padding_mask, so it assumes neither is None —
    confirm callers always pass a padding mask during pretraining.
    """
    src_audio, src_video = source['audio'], source['video']
    if mask and self.masking_type == 'input':
        # Mask each modality independently, then take the union so a frame
        # counts as masked if either stream was masked there.
        src_video, mask_indices_video = self.apply_input_mask(src_video, padding_mask, target_list)
        src_audio, mask_indices_audio = self.apply_input_mask(src_audio, padding_mask, target_list)
        mask_indices = torch.logical_or(mask_indices_audio, mask_indices_video)
    else:
        src_audio, src_video, mask_indices = src_audio, src_video, None

    features_audio = self.forward_features(src_audio, modality='audio') # features: [B, F, T]
    features_video = self.forward_features(src_video, modality='video')
    # Modality dropout: during training, occasionally zero out one entire
    # modality so the model cannot rely on a single stream.
    modality_drop_prob, audio_drop_prob = np.random.random(), np.random.random()
    if self.training:
        if modality_drop_prob < self.modality_dropout:
            if audio_drop_prob < self.audio_dropout:
                features_audio = 0 * features_audio
            else:
                features_video = 0 * features_video
    if self.modality_fuse == 'concat':
        features = torch.cat([features_audio, features_video], dim=1)
    elif self.modality_fuse == 'add':
        features = features_audio + features_video
    if target_list is not None:
        features, mask_indices, target_list = self.forward_targets(features, mask_indices, target_list)

    # Auxiliary L2 penalty on the raw fused features.
    features_pen = features.float().pow(2).mean()

    features = features.transpose(1, 2)
    features = self.layer_norm(features)

    if padding_mask is not None:
        padding_mask = self.forward_padding_mask(features, padding_mask)

    if self.post_extract_proj is not None:
        features = self.post_extract_proj(features)

    features = self.dropout_input(features)
    if self.masking_type == 'feature' and mask:
        x, mask_indices = self.apply_feature_mask(features, padding_mask, target_list)
    else:
        x = features

    # feature: (B, T, D), float
    # target: (B, T), long
    # x: (B, T, D), float
    # padding_mask: (B, T), bool
    # mask_indices: (B, T), bool
    x, _ = self.encoder(
        x,
        padding_mask=padding_mask,
        layer=None if output_layer is None else output_layer - 1
    )

    if features_only:
        return {"x": x, "padding_mask": padding_mask, "features": features}

    # One embedding table per target stream, concatenated in
    # label_embs_concat and split back here.
    label_embs_list = self.label_embs_concat.split(self.num_classes, 0)
    proj_x = self.final_proj(x)
    if self.untie_final_proj:
        proj_x_list = proj_x.chunk(len(self.num_classes), dim=-1)
    else:
        proj_x_list = [proj_x for _ in self.num_classes]
    logit_list = [self.compute_logits(proj, emb).view(-1, num_class) for proj, emb, num_class in zip(proj_x_list, label_embs_list, self.num_classes)] # [[B*T, V]]
    # NOTE: the boolean parameter `mask` is shadowed below by the flattened
    # masked-frame selector; non-padded frames are split into masked/unmasked.
    mask, unmask = torch.logical_and(mask_indices, ~padding_mask).view(-1), torch.logical_and(~mask_indices, ~padding_mask).view(-1) # [B*T]
    logit_m_list, logit_u_list = [logit[mask] for logit in logit_list], [logit[unmask] for logit in logit_list]
    target_m_list, target_u_list = [target.view(-1)[mask].long() for target in target_list], [target.view(-1)[unmask].long() for target in target_list]
    result = {
        "logit_m_list": logit_m_list,
        "logit_u_list": logit_u_list,
        "target_m_list": target_m_list,
        "target_u_list": target_u_list,
        "padding_mask": padding_mask,
        "features_pen": features_pen,
    }
    return result
|
| 675 |
+
|
| 676 |
+
def extract_features(
    self,
    source: torch.Tensor,
    padding_mask: Optional[torch.Tensor] = None,
    mask: bool = False,
    ret_conv: bool = False,
    output_layer: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run the model in features-only mode and return (features, padding_mask).

    With ret_conv=True the pre-encoder convolutional features are returned
    instead of the transformer output. output_layer is 1-based.
    """
    res = self.forward(
        source,
        padding_mask=padding_mask,
        mask=mask,
        features_only=True,
        output_layer=output_layer,
    )
    key = "features" if ret_conv else "x"
    return res[key], res["padding_mask"]
|
| 693 |
+
|
| 694 |
+
def extract_finetune(self, source, padding_mask=None, mask=False, ret_conv=False, output_layer=None):
    """Forward pass used during fine-tuning and inference.

    Accepts audio-only, video-only, or audio-visual input; a missing
    modality is replaced by an all-zero feature tensor of matching shape.
    Returns (encoder_output, downsampled_padding_mask).
    ``ret_conv`` is accepted for interface parity but is unused here.
    """
    src_audio, src_video = source['audio'], source['video']
    if mask and self.masking_type == 'input':
        src_video, mask_indices_video = self.apply_input_mask(src_video, padding_mask, target_list=None)
        src_audio, mask_indices_audio = self.apply_input_mask(src_audio, padding_mask, target_list=None)
        mask_indices = torch.logical_or(mask_indices_audio, mask_indices_video) # mask_indices not used in fine-tuning
    else:
        src_audio, src_video, mask_indices = src_audio, src_video, None

    # NOTE(review): if both modalities are None, features_audio/features_video
    # are never bound and the fusion below raises NameError — callers must
    # supply at least one modality.
    if src_audio is not None and src_video is None:
        features_audio = self.forward_features(src_audio, modality='audio') # features: [B, F, T]
        features_video = features_audio.new_zeros(features_audio.size(0), self.encoder_embed_dim, features_audio.size(-1))
    elif src_audio is None and src_video is not None:
        features_video = self.forward_features(src_video, modality='video')
        features_audio = features_video.new_zeros(features_video.size(0), self.encoder_embed_dim, features_video.size(-1))
    elif src_audio is not None and src_video is not None:
        features_video = self.forward_features(src_video, modality='video')
        features_audio = self.forward_features(src_audio, modality='audio') # features: [B, F, T]

    if self.modality_fuse == 'concat':
        features = torch.cat([features_audio, features_video], dim=1)
    elif self.modality_fuse == 'add':
        features = features_audio + features_video
    # L2 penalty kept for parity with pretraining; not returned here.
    features_pen = features.float().pow(2).mean()

    features = features.transpose(1, 2)
    features = self.layer_norm(features)
    unmasked_features = features.clone()

    if padding_mask is not None:
        padding_mask = self.forward_padding_mask(features, padding_mask)

    if self.post_extract_proj is not None:
        features = self.post_extract_proj(features)

    features = self.dropout_input(features)
    # NOTE(review): unmasked_features is computed but never used afterwards
    # in this method — appears to be dead code carried over from pretraining.
    unmasked_features = self.dropout_features(unmasked_features)
    x = features
    mask_indices = None

    # feature: (B, T, D), float
    # target: (B, T), long
    # x: (B, T, D), float
    # padding_mask: (B, T), bool
    # mask_indices: (B, T), bool
    x, _ = self.encoder(
        x,
        padding_mask=padding_mask,
        layer=None if output_layer is None else output_layer - 1
    )

    return x, padding_mask
|
| 746 |
+
|
| 747 |
+
|
| 748 |
+
def get_extra_losses(self, net_output):
    """Collect auxiliary losses from a forward output.

    Currently only the feature L2 penalty is reported. Returns a pair of
    parallel lists (loss values, loss names).
    """
    losses = []
    names = []
    if "features_pen" in net_output:
        losses.append(net_output["features_pen"])
        names.append("features_pen")
    return losses, names
|
| 756 |
+
|
| 757 |
+
def remove_pretraining_modules(self):
    """Drop modules only needed for pretraining so fine-tuned models stay lean."""
    for attr in ("target_glu", "final_proj"):
        setattr(self, attr, None)
|
| 760 |
+
|
| 761 |
+
def get_logits(self, net_output, is_masked=True):
    # Interface stub kept for parity with fairseq's HubertModel; this model
    # exposes logits directly in the forward output dict instead.
    raise NotImplementedError
|
| 763 |
+
|
| 764 |
+
def get_targets(self, net_output, is_masked=True):
    # Interface stub kept for parity with fairseq's HubertModel; targets are
    # returned in the forward output dict instead.
    raise NotImplementedError
|
| 766 |
+
|
| 767 |
+
def compute_nce(self, x, pos, negs):
    """Temperature-scaled cosine logits for InfoNCE.

    The positive is prepended as candidate 0, followed by the negatives.
    Negatives identical to the positive are forced to -inf so they cannot
    compete with it. Returns logits of shape (num_x, num_candidates).
    """
    duplicate_neg = (pos == negs).all(-1)
    candidates = torch.cat([pos.unsqueeze(0), negs], dim=0)

    sims = torch.cosine_similarity(
        x.float(), candidates.float(), dim=-1
    ).type_as(x)
    sims = sims / self.logit_temp
    if duplicate_neg.any():
        sims[1:][duplicate_neg] = float("-inf")
    return sims.transpose(0, 1)  # (num_x, num_cls+1)
|
av_hubert/avhubert/hubert_asr.py
ADDED
|
@@ -0,0 +1,521 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import sys,logging
|
| 8 |
+
import contextlib
|
| 9 |
+
import tempfile
|
| 10 |
+
from argparse import Namespace
|
| 11 |
+
from typing import Any, Optional
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn as nn
|
| 15 |
+
from dataclasses import dataclass, field
|
| 16 |
+
from fairseq import checkpoint_utils, tasks, utils
|
| 17 |
+
from fairseq.dataclass import FairseqDataclass
|
| 18 |
+
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
|
| 19 |
+
from fairseq.models import BaseFairseqModel, FairseqEncoder, FairseqEncoderDecoderModel, register_model
|
| 20 |
+
from fairseq.models.hubert.hubert import MASKING_DISTRIBUTION_CHOICES
|
| 21 |
+
from fairseq.tasks import FairseqTask
|
| 22 |
+
from omegaconf import II, MISSING
|
| 23 |
+
|
| 24 |
+
# When the module is launched with no CLI arguments we are most likely
# running the file directly for debugging, so fall back to plain
# (non-package) imports; otherwise use relative package imports.
DBG=True if len(sys.argv) == 1 else False

if DBG:
    from hubert import AVHubertModel
    from decoder import TransformerDecoder
else:
    from .hubert import AVHubertModel
    from .decoder import TransformerDecoder

logger = logging.getLogger(__name__)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@dataclass
class AVHubertAsrConfig(FairseqDataclass):
    """Base configuration for fine-tuning a pretrained AV-HuBERT model on ASR.

    Holds the checkpoint path, dropout overrides applied on top of the
    pretrained model, and the masking (SpecAugment-style) parameters used
    during fine-tuning.
    """

    w2v_path: str = field(
        default=MISSING, metadata={"help": "path to hubert model"}
    )
    no_pretrained_weights: bool = field(
        default=False,
        metadata={"help": "if true, does not load pretrained weights"},
    )
    dropout_input: float = field(
        default=0.0,
        metadata={"help": "dropout to apply to the input (after feat extr)"},
    )
    final_dropout: float = field(
        default=0.0,
        metadata={
            "help": "dropout after transformer and before final projection"
        },
    )
    dropout: float = field(
        default=0.0,
        metadata={"help": "dropout probability inside hubert model"},
    )
    attention_dropout: float = field(
        default=0.0,
        metadata={
            "help": "dropout probability for attention weights "
            "inside hubert model"
        },
    )
    activation_dropout: float = field(
        default=0.0,
        metadata={
            "help": "dropout probability after activation in FFN "
            "inside hubert model"
        },
    )

    # masking (time-axis)
    apply_mask: bool = field(
        default=False, metadata={"help": "apply masking during fine-tuning"}
    )
    mask_length: int = field(
        default=10, metadata={"help": "repeat the mask indices multiple times"}
    )
    mask_prob: float = field(
        default=0.5,
        metadata={
            "help": "probability of replacing a token with mask "
            "(normalized by length)"
        },
    )
    mask_selection: MASKING_DISTRIBUTION_CHOICES = field(
        default="static", metadata={"help": "how to choose masks"}
    )
    mask_other: float = field(
        default=0,
        metadata={
            "help": "secondary mask argument "
            "(used for more complex distributions), "
            "see help in compute_mask_indices"
        },
    )
    no_mask_overlap: bool = field(
        default=False, metadata={"help": "whether to allow masks to overlap"}
    )

    # channel masking (feature-axis)
    mask_channel_length: int = field(
        default=10,
        metadata={"help": "length of the mask for features (channels)"},
    )
    mask_channel_prob: float = field(
        default=0.0,
        metadata={"help": "probability of replacing a feature with 0"},
    )
    mask_channel_selection: MASKING_DISTRIBUTION_CHOICES = field(
        default="static",
        metadata={"help": "how to choose mask length for channel masking"},
    )
    mask_channel_other: float = field(
        default=0,
        metadata={
            "help": "secondary mask argument "
            "(used for more complex distributions), "
            "see help in compute_mask_indices"
        },
    )
    no_mask_channel_overlap: bool = field(
        default=False,
        metadata={"help": "whether to allow channel masks to overlap"},
    )
    freeze_finetune_updates: int = field(
        default=0,
        metadata={"help": "dont finetune hubert for this many updates"},
    )
    feature_grad_mult: float = field(
        default=0.0,
        metadata={"help": "reset feature grad mult in hubert to this"},
    )
    layerdrop: float = field(
        default=0.0,
        metadata={"help": "probability of dropping a layer in hubert"},
    )
    # Interpolated from the task config at runtime (omegaconf II).
    normalize: bool = II("task.normalize")
    data: str = II("task.data")

    # this holds the loaded hubert args
    w2v_args: Any = None
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
@dataclass
class AVHubertCtcConfig(AVHubertAsrConfig):
    """CTC fine-tuning config; identical to the shared ASR base config."""
    pass
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
@register_model("av_hubert_ctc", dataclass=AVHubertCtcConfig)
class AVHubertCtc(BaseFairseqModel):
    """CTC head on top of a pretrained AV-HuBERT encoder."""

    def __init__(self, cfg: AVHubertCtcConfig, w2v_encoder: BaseFairseqModel):
        super().__init__()
        self.cfg = cfg
        self.w2v_encoder = w2v_encoder

    def upgrade_state_dict_named(self, state_dict, name):
        super().upgrade_state_dict_named(state_dict, name)
        return state_dict

    @classmethod
    def build_model(cls, cfg: AVHubertCtcConfig, task: FairseqTask):
        """Build a new model instance."""
        w2v_encoder = HubertEncoder(cfg, task.target_dictionary)
        return cls(cfg, w2v_encoder)

    def get_normalized_probs(self, net_output, log_probs):
        """Get normalized probabilities (or log probs) from a net's output."""

        logits = net_output["encoder_out"]
        if log_probs:
            return utils.log_softmax(logits.float(), dim=-1)
        else:
            return utils.softmax(logits.float(), dim=-1)

    def get_logits(self, net_output):
        """Return encoder logits with padded timesteps forced to blank.

        Padded positions get probability mass only on index 0 (the CTC
        blank): logit 0 there, -inf everywhere else.
        """
        logits = net_output["encoder_out"]  # T x B x C
        padding = net_output["encoder_padding_mask"]  # B x T
        if padding is not None and padding.any():
            # BUGFIX: the previous `logits[padding][..., 0] = 0` wrote into a
            # temporary copy created by boolean advanced indexing, so the
            # masking silently had no effect. Assign through a single
            # boolean-mask index instead, which modifies `logits` in place.
            num_classes = logits.size(-1)
            masking_tensor = torch.full(
                (num_classes,), float("-inf"),
                device=logits.device, dtype=logits.dtype,
            )
            masking_tensor[0] = 0
            logits[padding.T] = masking_tensor

        return logits

    def forward(self, **kwargs):
        x = self.w2v_encoder(**kwargs)
        return x
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
@dataclass
class AVHubertSeq2SeqConfig(AVHubertAsrConfig):
    """Config for the encoder-decoder (seq2seq) AV-HuBERT ASR model.

    Extends the shared ASR base config with transformer-decoder
    hyperparameters.
    """

    decoder_embed_dim: int = field(
        default=768, metadata={"help": "decoder embedding dimension"}
    )
    decoder_ffn_embed_dim: int = field(
        default=3072, metadata={"help": "decoder embedding dimension for FFN"}
    )
    decoder_layers: int = field(
        default=6, metadata={"help": "num of decoder layers"}
    )
    decoder_layerdrop: float = field(
        default=0.0, metadata={"help": "decoder layerdrop chance"}
    )
    decoder_attention_heads: int = field(
        default=4, metadata={"help": "num decoder attention heads"}
    )
    decoder_learned_pos: bool = field(
        default=False,
        metadata={"help": "use learned positional embeddings in the decoder"},
    )
    decoder_normalize_before: bool = field(
        default=False,
        metadata={"help": "apply layernorm before each decoder block"},
    )
    no_token_positional_embeddings: bool = field(
        default=False,
        metadata={
            "help": "if set, disables positional embeddings "
            "(outside self attention)"
        },
    )
    decoder_dropout: float = field(
        default=0.0, metadata={"help": "dropout probability in the decoder"}
    )
    decoder_attention_dropout: float = field(
        default=0.0,
        metadata={
            "help": "dropout probability for attention weights "
            "inside the decoder"
        },
    )
    decoder_activation_dropout: float = field(
        default=0.0,
        metadata={
            "help": "dropout probability after activation in FFN "
            "inside the decoder"
        },
    )
    max_target_positions: int = field(
        default=2048, metadata={"help": "max target positions"}
    )
    share_decoder_input_output_embed: bool = field(
        default=False,
        metadata={"help": "share decoder input and output embeddings"},
    )
    no_scale_embedding: bool = field(default=True, metadata={'help': 'scale embedding'})
|
| 250 |
+
|
| 251 |
+
class HubertEncoder(FairseqEncoder):
    """FairseqEncoder that wraps a pretrained AV-HuBERT model for CTC.

    Loads the pretrained checkpoint (unless cfg.no_pretrained_weights),
    rebuilds the pretraining task/model with the fine-tuning dropout
    overrides, and optionally adds a projection to the target vocabulary
    (when ``tgt_dict`` is given) or to the decoder dimension.
    """

    def __init__(self, cfg: AVHubertAsrConfig, tgt_dict=None):
        # NOTE(review): assigned before super().__init__ runs; safe here
        # because it's a plain bool, not an nn.Module/Parameter.
        self.apply_mask = cfg.apply_mask

        # Dropout/masking values from the fine-tuning config override the
        # values stored in the pretrained checkpoint.
        arg_overrides = {
            "dropout": cfg.dropout,
            "activation_dropout": cfg.activation_dropout,
            "dropout_input": cfg.dropout_input,
            "attention_dropout": cfg.attention_dropout,
            "mask_length": cfg.mask_length,
            "mask_prob": cfg.mask_prob,
            "mask_selection": cfg.mask_selection,
            "mask_other": cfg.mask_other,
            "no_mask_overlap": cfg.no_mask_overlap,
            "mask_channel_length": cfg.mask_channel_length,
            "mask_channel_prob": cfg.mask_channel_prob,
            "mask_channel_selection": cfg.mask_channel_selection,
            "mask_channel_other": cfg.mask_channel_other,
            "no_mask_channel_overlap": cfg.no_mask_channel_overlap,
            "encoder_layerdrop": cfg.layerdrop,
            "feature_grad_mult": cfg.feature_grad_mult,
        }

        if cfg.w2v_args is None:
            # First load: read the checkpoint and cache its config on cfg.
            state = checkpoint_utils.load_checkpoint_to_cpu(
                cfg.w2v_path, arg_overrides
            )
            w2v_args = state.get("cfg", None)
            if w2v_args is None:
                # Older checkpoints store an argparse Namespace instead.
                w2v_args = convert_namespace_to_omegaconf(state["args"])
            cfg.w2v_args = w2v_args
        else:
            # Config already cached (e.g. when re-building the model);
            # weights are not reloaded in this branch.
            state = None
            w2v_args = cfg.w2v_args
            if isinstance(w2v_args, Namespace):
                cfg.w2v_args = w2v_args = convert_namespace_to_omegaconf(
                    w2v_args
                )

        assert cfg.normalize == w2v_args.task.normalize, (
            "Fine-tuning works best when data normalization is the same. "
            "Please check that --normalize is set or unset for "
            "both pre-training and here"
        )

        # Point the pretraining task at the fine-tuning data directory.
        w2v_args.task.data = cfg.data

        task = tasks.setup_task(w2v_args.task)
        model = task.build_model(w2v_args.model)

        if state is not None and not cfg.no_pretrained_weights:
            # set strict=False because we omit some modules
            model.load_state_dict(state["model"], strict=False)

        model.remove_pretraining_modules()

        super().__init__(task.source_dictionary)

        d = model.encoder.embedding_dim

        self.w2v_model = model

        self.final_dropout = nn.Dropout(cfg.final_dropout)
        self.freeze_finetune_updates = cfg.freeze_finetune_updates
        self.num_updates = 0

        if tgt_dict is not None:
            # CTC head: project to the target vocabulary.
            self.proj = Linear(d, len(tgt_dict))
        elif getattr(cfg, "decoder_embed_dim", d) != d:
            # Bridge to a decoder of a different width.
            self.proj = Linear(d, cfg.decoder_embed_dim)
        else:
            self.proj = None

    def set_num_updates(self, num_updates):
        """Set the number of parameters updates."""
        super().set_num_updates(num_updates)
        self.num_updates = num_updates

    def forward(self, source, padding_mask, tbc=True, **kwargs):

        w2v_args = {
            "source": source,
            "padding_mask": padding_mask,
            # Masking is only applied while training and when configured.
            "mask": self.apply_mask and self.training,
        }
        # Freeze the pretrained encoder for the first N updates.
        ft = self.freeze_finetune_updates <= self.num_updates

        with torch.no_grad() if not ft else contextlib.ExitStack():
            x, padding_mask = self.w2v_model.extract_finetune(**w2v_args)

            if tbc:
                # B x T x C -> T x B x C
                x = x.transpose(0, 1)

            x = self.final_dropout(x)

            if self.proj:
                x = self.proj(x)

        return {
            "encoder_out": x,  # T x B x C
            "encoder_padding_mask": padding_mask,  # B x T
            "padding_mask": padding_mask,
        }

    def reorder_encoder_out(self, encoder_out, new_order):
        # encoder_out is T x B x C -> reorder batch dim 1;
        # padding mask is B x T -> reorder dim 0.
        if encoder_out["encoder_out"] is not None:
            encoder_out["encoder_out"] = encoder_out[
                "encoder_out"
            ].index_select(1, new_order)
        if encoder_out["encoder_padding_mask"] is not None:
            encoder_out["encoder_padding_mask"] = encoder_out[
                "encoder_padding_mask"
            ].index_select(0, new_order)
        return encoder_out

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return None

    def upgrade_state_dict_named(self, state_dict, name):
        return state_dict
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
class HubertEncoderWrapper(FairseqEncoder):
    """Thin adapter that exposes an AVHubertModel as a FairseqEncoder."""

    def __init__(self, w2v_model):
        super().__init__(None)
        self.w2v_model = w2v_model

    def forward(self, source, padding_mask, **kwargs):
        """Encode `source`, returning the fairseq encoder-out dictionary."""
        features, padding_mask = self.w2v_model.extract_finetune(
            source=source, padding_mask=padding_mask
        )
        # B x T x C -> T x B x C
        features = features.transpose(0, 1)

        return {
            "encoder_out": features,  # T x B x C
            "encoder_padding_mask": padding_mask,  # B x T
            "padding_mask": padding_mask
        }

    def reorder_encoder_out(self, encoder_out, new_order):
        """Reorder batch entries to match `new_order` (used by beam search)."""
        def _reorder(key, batch_dim):
            if encoder_out[key] is not None:
                encoder_out[key] = encoder_out[key].index_select(
                    batch_dim, new_order
                )

        _reorder("encoder_out", 1)           # T x B x C -> batch is dim 1
        _reorder("encoder_padding_mask", 0)  # B x T -> batch is dim 0
        _reorder("padding_mask", 0)
        return encoder_out
|
| 410 |
+
|
| 411 |
+
@register_model("av_hubert_seq2seq", dataclass=AVHubertSeq2SeqConfig)
|
| 412 |
+
class AVHubertSeq2Seq(FairseqEncoderDecoderModel):
|
| 413 |
+
def __init__(self, encoder, decoder, tgt_dict, cfg):
|
| 414 |
+
super().__init__(encoder, decoder)
|
| 415 |
+
self.cfg = cfg
|
| 416 |
+
self.freeze_finetune_updates = cfg.freeze_finetune_updates
|
| 417 |
+
|
| 418 |
+
@classmethod
|
| 419 |
+
def build_model(cls, cfg, task):
|
| 420 |
+
"""Build a new model instance."""
|
| 421 |
+
|
| 422 |
+
arg_overrides = {
|
| 423 |
+
"dropout": cfg.dropout,
|
| 424 |
+
"activation_dropout": cfg.activation_dropout,
|
| 425 |
+
"dropout_input": cfg.dropout_input,
|
| 426 |
+
"attention_dropout": cfg.attention_dropout,
|
| 427 |
+
"mask_length": cfg.mask_length,
|
| 428 |
+
"mask_prob": cfg.mask_prob,
|
| 429 |
+
"mask_selection": cfg.mask_selection,
|
| 430 |
+
"mask_other": cfg.mask_other,
|
| 431 |
+
"no_mask_overlap": cfg.no_mask_overlap,
|
| 432 |
+
"mask_channel_length": cfg.mask_channel_length,
|
| 433 |
+
"mask_channel_prob": cfg.mask_channel_prob,
|
| 434 |
+
"mask_channel_selection": cfg.mask_channel_selection,
|
| 435 |
+
"mask_channel_other": cfg.mask_channel_other,
|
| 436 |
+
"no_mask_channel_overlap": cfg.no_mask_channel_overlap,
|
| 437 |
+
"encoder_layerdrop": cfg.layerdrop,
|
| 438 |
+
"feature_grad_mult": cfg.feature_grad_mult,
|
| 439 |
+
}
|
| 440 |
+
|
| 441 |
+
if cfg.w2v_args is None:
|
| 442 |
+
state = checkpoint_utils.load_checkpoint_to_cpu(
|
| 443 |
+
cfg.w2v_path, arg_overrides
|
| 444 |
+
)
|
| 445 |
+
w2v_args = state.get("cfg", None)
|
| 446 |
+
if w2v_args is None:
|
| 447 |
+
w2v_args = convert_namespace_to_omegaconf(state["args"])
|
| 448 |
+
cfg.w2v_args = w2v_args
|
| 449 |
+
else:
|
| 450 |
+
state = None
|
| 451 |
+
w2v_args = cfg.w2v_args
|
| 452 |
+
if isinstance(w2v_args, Namespace):
|
| 453 |
+
cfg.w2v_args = w2v_args = convert_namespace_to_omegaconf(
|
| 454 |
+
w2v_args
|
| 455 |
+
)
|
| 456 |
+
|
| 457 |
+
assert cfg.normalize == w2v_args.task.normalize, (
|
| 458 |
+
"Fine-tuning works best when data normalization is the same. "
|
| 459 |
+
"Please check that --normalize is set or unset for "
|
| 460 |
+
"both pre-training and here"
|
| 461 |
+
)
|
| 462 |
+
|
| 463 |
+
w2v_args.task.data = cfg.data
|
| 464 |
+
|
| 465 |
+
task_pretrain = tasks.setup_task(w2v_args.task)
|
| 466 |
+
if state is not None:
|
| 467 |
+
task_pretrain.load_state_dict(state['task_state'])
|
| 468 |
+
|
| 469 |
+
encoder_ = task_pretrain.build_model(w2v_args.model)
|
| 470 |
+
|
| 471 |
+
encoder = HubertEncoderWrapper(encoder_)
|
| 472 |
+
if state is not None and not cfg.no_pretrained_weights:
|
| 473 |
+
# set strict=False because we omit some modules
|
| 474 |
+
del state['model']['mask_emb']
|
| 475 |
+
encoder.w2v_model.load_state_dict(state["model"], strict=False)
|
| 476 |
+
|
| 477 |
+
encoder.w2v_model.remove_pretraining_modules()
|
| 478 |
+
|
| 479 |
+
src_dict, tgt_dict = task.source_dictionary, task.target_dictionary
|
| 480 |
+
|
| 481 |
+
def build_embedding(dictionary, embed_dim):
|
| 482 |
+
num_embeddings = len(dictionary)
|
| 483 |
+
padding_idx = dictionary.pad()
|
| 484 |
+
emb = Embedding(num_embeddings, embed_dim, padding_idx=padding_idx)
|
| 485 |
+
return emb
|
| 486 |
+
|
| 487 |
+
decoder_embed_tokens = build_embedding(tgt_dict, cfg.decoder_embed_dim)
|
| 488 |
+
decoder = TransformerDecoder(cfg, tgt_dict, decoder_embed_tokens)
|
| 489 |
+
|
| 490 |
+
return AVHubertSeq2Seq(encoder, decoder, tgt_dict, cfg)
|
| 491 |
+
|
| 492 |
+
|
| 493 |
+
def forward(self, **kwargs):
|
| 494 |
+
ft = self.freeze_finetune_updates <= self.num_updates
|
| 495 |
+
with torch.no_grad() if not ft else contextlib.ExitStack():
|
| 496 |
+
output = self.encoder(**kwargs)
|
| 497 |
+
decoder_out = self.decoder(prev_output_tokens=kwargs['prev_output_tokens'], encoder_out=output)
|
| 498 |
+
return decoder_out
|
| 499 |
+
|
| 500 |
+
def upgrade_state_dict_named(self, state_dict, name):
|
| 501 |
+
super().upgrade_state_dict_named(state_dict, name)
|
| 502 |
+
return state_dict
|
| 503 |
+
|
| 504 |
+
def set_num_updates(self, num_updates):
    """Set the number of parameters updates."""
    super().set_num_updates(num_updates)
    # Cache the counter locally so forward() can compare it against
    # freeze_finetune_updates to decide whether the encoder is still frozen.
    self.num_updates = num_updates
|
| 508 |
+
|
| 509 |
+
def Embedding(num_embeddings, embedding_dim, padding_idx):
    """Build an nn.Embedding with scaled-normal init and a zeroed padding row."""
    layer = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    # Transformer-style init: std shrinks with embedding width.
    init_std = embedding_dim ** -0.5
    nn.init.normal_(layer.weight, mean=0, std=init_std)
    nn.init.constant_(layer.weight[padding_idx], 0)
    return layer
|
| 514 |
+
|
| 515 |
+
|
| 516 |
+
def Linear(in_features, out_features, bias=True):
    """Build an nn.Linear with Xavier-uniform weights and zero bias (if any)."""
    layer = nn.Linear(in_features, out_features, bias)
    nn.init.xavier_uniform_(layer.weight)
    if bias:
        nn.init.constant_(layer.bias, 0.0)
    return layer
|
av_hubert/avhubert/hubert_criterion.py
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import math
|
| 8 |
+
import re
|
| 9 |
+
from dataclasses import dataclass, field
|
| 10 |
+
from typing import List, Optional
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
from fairseq import metrics, utils
|
| 15 |
+
from fairseq.criterions import FairseqCriterion, register_criterion
|
| 16 |
+
from fairseq.dataclass import FairseqDataclass
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@dataclass
class AVHubertCriterionConfig(FairseqDataclass):
    """Options for the AV-HuBERT pre-training criterion."""

    # Weight on the cross-entropy over masked frames.
    pred_masked_weight: float = field(
        default=1.0,
        metadata={"help": "weight for predictive loss for masked frames"},
    )
    # Weight on the cross-entropy over unmasked frames.
    pred_nomask_weight: float = field(
        default=0.0,
        metadata={"help": "weight for predictive loss for unmasked frames"},
    )
    # Coefficients for auxiliary losses reported by the model.
    loss_weights: Optional[List[float]] = field(
        default=None,
        metadata={"help": "weights for additional loss terms (not first one)"},
    )
    # Keys from the network output to copy into the logging output.
    log_keys: List[str] = field(
        default_factory=list,
        metadata={"help": "output keys to log"},
    )
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
@register_criterion("av_hubert", dataclass=AVHubertCriterionConfig)
class AVHubertCriterion(FairseqCriterion):
    """Masked-prediction criterion for AV-HuBERT pre-training.

    Sums cross-entropy over the model's masked-frame and unmasked-frame
    logits (weighted by ``pred_masked_weight`` / ``pred_nomask_weight``),
    plus any extra losses the model reports (weighted by ``loss_weights``).
    """

    def __init__(self, task, pred_masked_weight, pred_nomask_weight, loss_weights=None, log_keys=None):
        super().__init__(task)
        self.pred_masked_weight = pred_masked_weight
        self.pred_nomask_weight = pred_nomask_weight
        self.loss_weights = loss_weights
        self.log_keys = [] if log_keys is None else log_keys

    def forward(self, model, sample, reduce=True, log_pred=False):
        """Compute the loss for the given sample.
        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        net_output = model(target_list=sample["target_list"], **sample["net_input"])
        loss = 0.
        sample_size = 0
        logging_output = {}
        reduction = "sum" if reduce else "none"

        # Cross-entropy over masked frames, one term per label set.
        loss_m_list = []
        logp_m_list, targ_m_list = net_output['logit_m_list'], net_output['target_m_list']
        for i, (logp_m, targ_m) in enumerate(zip(logp_m_list, targ_m_list)):
            loss_m = F.cross_entropy(logp_m, targ_m, reduction=reduction)
            loss_m_list.append(loss_m)
            # NOTE(review): .item() needs a scalar, so reduce=False would fail
            # on this line — confirm callers always pass reduce=True.
            logging_output[f"loss_m_{i}"] = loss_m.detach().item()
        if self.pred_masked_weight > 0:
            loss += self.pred_masked_weight * sum(loss_m_list)
            sample_size += targ_m_list[0].numel()

        # Cross-entropy over unmasked frames, one term per label set.
        loss_u_list = []
        logp_u_list, targ_u_list = net_output['logit_u_list'], net_output['target_u_list']
        for i, (logp_u, targ_u) in enumerate(zip(logp_u_list, targ_u_list)):
            loss_u = F.cross_entropy(logp_u, targ_u, reduction=reduction)
            loss_u_list.append(loss_u)
            logging_output[f"loss_u_{i}"] = loss_u.detach().item()
        if self.pred_nomask_weight > 0:
            loss += self.pred_nomask_weight * sum(loss_u_list)
            sample_size += targ_u_list[0].numel()

        # Extra model-reported losses, scaled by sample_size so they are
        # comparable with the summed cross-entropy terms above.
        if self.loss_weights is not None:
            assert hasattr(model, "get_extra_losses")
            extra_losses, names = model.get_extra_losses(net_output)
            if torch.is_tensor(extra_losses):
                extra_losses = [extra_losses]
                names = [names]
            if len(self.loss_weights) == 1 and len(extra_losses) != 1:
                # Broadcast a single weight across all extra losses.
                self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
            assert len(extra_losses) == len(self.loss_weights), f"{len(extra_losses)}, {len(self.loss_weights)}"
            for p, n, coef in zip(extra_losses, names, self.loss_weights):
                if coef != 0 and p is not None:
                    p = coef * p.float() * sample_size
                    loss += p
                    logging_output[f"loss_{n}"] = p.item()

        # Standard fairseq logging fields, merged with the per-term entries.
        logging_output = {
            "loss": loss.item() if reduce else loss,
            "ntokens": sample_size,
            "nsentences": sample["id"].numel(),
            "sample_size": sample_size,
            **logging_output,
        }

        # Copy through any requested raw network outputs.
        for lk in self.log_keys:
            if lk in net_output:
                logging_output[lk] = float((net_output[lk]))

        # Frame-level prediction accuracy for masked / unmasked positions.
        with torch.no_grad():
            for i, logp_m in enumerate(logp_m_list):
                # corr_m, count_m = compute_correct(logp_m)
                if logp_m.numel() == 0:
                    corr_m, count_m = 0, 0
                else:
                    corr_m, count_m = (logp_m.argmax(dim=-1)==targ_m_list[i]).sum().item(), len(targ_m_list[i])
                logging_output[f"correct_m_{i}"] = corr_m
                logging_output[f"count_m_{i}"] = count_m

            for i, logp_u in enumerate(logp_u_list):
                if logp_u.numel() == 0:
                    corr_u, count_u = 0, 0
                else:
                    corr_u, count_u = (logp_u.argmax(dim=-1)==targ_u_list[i]).sum().item(), len(targ_u_list[i])
                logging_output[f"correct_u_{i}"] = corr_u
                logging_output[f"count_u_{i}"] = count_u

        return loss, sample_size, logging_output

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training (copied from normal cross entropy)."""
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)

        # Losses are logged per-frame in base 2 (hence / math.log(2)).
        metrics.log_scalar("loss", loss_sum / sample_size / math.log(2), sample_size, round=3)
        if sample_size != ntokens:
            metrics.log_scalar("nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3)
            metrics.log_derived("ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg))
        else:
            metrics.log_derived("ppl", lambda meters: utils.get_perplexity(meters["loss"].avg))

        # First pass: total the count_* keys so accuracies can be normalized.
        counts = {}
        for lk in logging_outputs[0].keys():
            if lk.startswith("count_"):
                val = sum(log[lk] for log in logging_outputs)
                metrics.log_scalar(lk, val)
                counts[lk] = val

        # Second pass: per-term losses and correct_*/count_* accuracies.
        for lk in logging_outputs[0].keys():
            if lk.startswith("loss_"):
                val = sum(log[lk] for log in logging_outputs)
                metrics.log_scalar(lk, val / sample_size / math.log(2), round=3)
            elif lk.startswith("correct_"):
                val = sum(log[lk] for log in logging_outputs)
                metrics.log_scalar(lk, val / counts[re.sub("correct", "count", lk)])

    @staticmethod
    def aggregate_logging_outputs(logging_outputs):
        """Aggregate logging outputs from data parallel training."""
        raise NotImplementedError()

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improves distributed training speed.
        """
        return False
|
av_hubert/avhubert/hubert_dataset.py
ADDED
|
@@ -0,0 +1,529 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import itertools
|
| 8 |
+
import logging
|
| 9 |
+
import os
|
| 10 |
+
import sys
|
| 11 |
+
import time
|
| 12 |
+
from typing import Any, List, Optional, Union
|
| 13 |
+
|
| 14 |
+
import numpy as np
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
import torch.nn.functional as F
|
| 18 |
+
from fairseq.data import data_utils
|
| 19 |
+
from fairseq.data.fairseq_dataset import FairseqDataset
|
| 20 |
+
from python_speech_features import logfbank
|
| 21 |
+
from scipy.io import wavfile
|
| 22 |
+
|
| 23 |
+
DBG=True if len(sys.argv) == 1 else False
|
| 24 |
+
|
| 25 |
+
if DBG:
|
| 26 |
+
import utils as custom_utils
|
| 27 |
+
logging.basicConfig(
|
| 28 |
+
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
|
| 29 |
+
datefmt="%Y-%m-%d %H:%M:%S",
|
| 30 |
+
level=os.environ.get("LOGLEVEL", "DEBUG").upper(),
|
| 31 |
+
stream=sys.stdout,
|
| 32 |
+
)
|
| 33 |
+
else:
|
| 34 |
+
from . import utils as custom_utils
|
| 35 |
+
|
| 36 |
+
logger = logging.getLogger(__name__)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def load_audio_visual(manifest_path, max_keep, min_keep, frame_rate, label_paths, label_rates, tol=0.1):
    """Parse a tsv manifest and filter samples by length and label alignment.

    Args:
        manifest_path: tsv whose first line is the data root; each following
            line is ``id<TAB>video<TAB>audio<TAB>...`` with the number of
            audio samples in the second-to-last column.
        max_keep / min_keep: keep samples whose sample count lies within
            these bounds (either may be None to disable the check).
        frame_rate: sampling rate used to convert sample counts to seconds.
        label_paths / label_rates: one label file and frame rate per label
            set; a rate of -1 marks sequence labels (alignment not checked).
        tol: maximum allowed |audio_dur - label_dur| in seconds.

    Returns:
        (root, names, inds, tot, sizes) where names[i] is
        (video_path, "audio_path:utt_id") and inds are kept row indices.
    """
    def is_audio_label_aligned(audio_dur, label_durs):
        return all(abs(audio_dur - label_dur) < tol for label_dur in label_durs)

    n_long, n_short, n_unaligned = 0, 0, 0
    names, inds, sizes = [], [], []
    dur_from_label_list = []
    is_seq_label = any(x == -1 for x in label_rates)
    for label_path, label_rate in zip(label_paths, label_rates):
        # Context manager: the original leaked these file handles.
        with open(label_path) as f:
            dur_from_label_list.append(
                [len(line.rstrip().split()) / label_rate for line in f]
            )
    dur_from_label_list = list(zip(*dur_from_label_list))

    tot = 0  # guard: the original left `ind` unbound on an empty manifest
    with open(manifest_path) as f:
        root = f.readline().strip()
        for ind, line in enumerate(f):
            items = line.strip().split("\t")
            sz = int(items[-2])  # number of audio samples
            if min_keep is not None and sz < min_keep:
                n_short += 1
            elif max_keep is not None and sz > max_keep:
                n_long += 1
            elif (not is_seq_label) and (not is_audio_label_aligned(sz / frame_rate, dur_from_label_list[ind])):
                n_unaligned += 1
            else:
                video_path, audio_path, audio_id = items[1], items[2], items[0]
                names.append((video_path, audio_path + ':' + audio_id))
                inds.append(ind)
                sizes.append(sz)
            tot = ind + 1
    logger.info(
        (
            f"max_keep={max_keep}, min_keep={min_keep}, "
            f"loaded {len(names)}, skipped {n_short} short and {n_long} long and {n_unaligned} unaligned, "
            # guard: max()/min() of an empty list would raise
            f"longest-loaded={max(sizes) if sizes else 'n/a'}, "
            f"shortest-loaded={min(sizes) if sizes else 'n/a'}"
        )
    )
    return root, names, inds, tot, sizes
|
| 79 |
+
|
| 80 |
+
def load_label(label_path, inds, tot):
    """Read every label line from `label_path`, then keep the rows in `inds`."""
    with open(label_path) as f:
        labels = [line.rstrip() for line in f]
    assert (
        len(labels) == tot
    ), f"number of labels does not match ({len(labels)} != {tot})"
    selected = [labels[i] for i in inds]
    return selected
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def load_label_offset(label_path, inds, tot):
    """Byte-offset (start, end) pairs of the selected lines in `label_path`."""
    with open(label_path) as f:
        code_lengths = [len(line.encode("utf-8")) for line in f]
    assert (
        len(code_lengths) == tot
    ), f"number of labels does not match ({len(code_lengths)} != {tot})"
    # Running byte totals; entry i is where line i starts.
    cum = list(itertools.accumulate([0] + code_lengths))
    return [(cum[i], cum[i + 1]) for i in inds]
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def verify_label_lengths(
    audio_sizes,
    audio_rate,
    label_path,
    label_rate,
    inds,
    tot,
    tol=0.1,  # tolerance in seconds
):
    """Cross-check audio durations against label durations; warn on mismatch."""
    if label_rate < 0:
        # Sequence labels carry no frame rate, so there is nothing to verify.
        logger.info(f"{label_path} is sequence label. skipped")
        return

    with open(label_path) as f:
        all_lengths = [len(line.rstrip().split()) for line in f]
    assert len(all_lengths) == tot
    kept_lengths = [all_lengths[i] for i in inds]

    num_invalid = 0
    for i, ind in enumerate(inds):
        dur_from_audio = audio_sizes[i] / audio_rate
        dur_from_label = kept_lengths[i] / label_rate
        if abs(dur_from_audio - dur_from_label) > tol:
            logger.warning(
                (
                    f"audio and label duration differ too much "
                    f"(|{dur_from_audio} - {dur_from_label}| > {tol}) "
                    f"in line {ind+1} of {label_path}. Check if `label_rate` "
                    f"is correctly set (currently {label_rate}). "
                    f"num. of samples = {audio_sizes[i]}; "
                    f"label length = {kept_lengths[i]}"
                )
            )
            num_invalid += 1
    if num_invalid > 0:
        logger.warning(
            f"total {num_invalid} (audio, label) pairs with mismatched lengths"
        )
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
class AVHubertDataset(FairseqDataset):
|
| 141 |
+
def __init__(
    self,
    manifest_path: str,
    sample_rate: float,
    label_paths: List[str],
    label_rates: Union[List[float], float],  # -1 for sequence labels
    pad_list: List[str],
    eos_list: List[str],
    label_processors: Optional[List[Any]] = None,
    max_keep_sample_size: Optional[int] = None,
    min_keep_sample_size: Optional[int] = None,
    max_sample_size: Optional[int] = None,
    shuffle: bool = True,
    pad_audio: bool = False,
    normalize: bool = False,
    store_labels: bool = True,
    random_crop: bool = False,
    single_target: bool = False,
    stack_order_audio: int=1,
    skip_verify: bool=False,
    image_mean: float=0,
    image_std: float=1,
    image_crop_size: int=88,
    image_aug: bool=False,
    modalities: Optional[List[str]]=None,
    is_s2s=False,
    noise_fn=None,
    noise_prob=0,
    noise_snr=0,
    noise_num=1
):
    """Build the audio-visual dataset from a manifest plus label files.

    Loads/filters the manifest, keeps or indexes labels, optionally
    verifies audio/label alignment, and sets up image transforms and
    noise-augmentation parameters.
    """
    # Broadcast a scalar label rate across all label sets.
    # NOTE(review): the check is isinstance(..., int), so a scalar float
    # rate would NOT be wrapped in a list — confirm callers pass int or list.
    self.label_rates = (
        [label_rates for _ in range(len(label_paths))]
        if isinstance(label_rates, int)
        else label_rates
    )
    self.modalities = set(modalities)
    self.audio_root, self.names, inds, tot, self.sizes = load_audio_visual(manifest_path, max_keep_sample_size, min_keep_sample_size, frame_rate=sample_rate, label_paths=label_paths, label_rates=self.label_rates)
    self.sample_rate = sample_rate
    self.stack_order_audio = stack_order_audio
    self.shuffle = shuffle
    self.random_crop = random_crop

    self.num_labels = len(label_paths)
    self.pad_list = pad_list
    self.eos_list = eos_list
    self.label_processors = label_processors
    self.single_target = single_target
    self.store_labels = store_labels
    self.is_s2s = is_s2s
    # noise_fn is a file listing one noise-wav path per line; without it,
    # the noise list is empty and add_noise is effectively unused.
    self.noise_wav, self.noise_prob, self.noise_snr, self.noise_num = [ln.strip() for ln in open(noise_fn).readlines()] if noise_fn is not None else [], noise_prob, noise_snr, noise_num

    assert self.single_target == (self.label_rates[0] == -1), f"single target should be equivalent to sequence label (label_rate==-1)"
    if store_labels:
        # Keep all labels in memory.
        self.label_list = [load_label(p, inds, tot) for p in label_paths]
    else:
        # Keep only byte offsets; get_label() seeks into the files lazily.
        self.label_paths = label_paths
        self.label_offsets_list = [
            load_label_offset(p, inds, tot) for p in label_paths
        ]
    assert (
        label_processors is None
        or len(label_processors) == self.num_labels
    )
    if not skip_verify:
        for label_path, label_rate in zip(label_paths, self.label_rates):
            verify_label_lengths(self.sizes, self.sample_rate, label_path, label_rate, inds, tot)
    else:
        logger.info(f"Skip label alignment verifying")

    self.max_sample_size = (
        max_sample_size if max_sample_size is not None else sys.maxsize
    )
    self.pad_audio = pad_audio
    self.normalize = normalize
    # Training gets random crop + horizontal flip; otherwise center crop.
    if image_aug:
        self.transform = custom_utils.Compose([
            custom_utils.Normalize( 0.0,255.0 ),
            custom_utils.RandomCrop((image_crop_size, image_crop_size)),
            custom_utils.HorizontalFlip(0.5),
            custom_utils.Normalize(image_mean, image_std) ])
    else:
        self.transform = custom_utils.Compose([
            custom_utils.Normalize( 0.0,255.0 ),
            custom_utils.CenterCrop((image_crop_size, image_crop_size)),
            custom_utils.Normalize(image_mean, image_std) ])
    logger.info(f"image transform: {self.transform}")

    logger.info(
        f"pad_audio={pad_audio}, random_crop={random_crop}, "
        f"normalize={normalize}, max_sample_size={self.max_sample_size}, "
        f"seqs2seq data={self.is_s2s},")
    logger.info(
        f"Noise wav: {noise_fn}->{len(self.noise_wav)} wav, Prob: {self.noise_prob}, SNR: {self.noise_snr}, Number of mixture: {self.noise_num}"
    )
|
| 236 |
+
|
| 237 |
+
def get_label(self, index, label_idx):
    """Fetch one label, from memory or by seeking into the label file."""
    if self.store_labels:
        label = self.label_list[label_idx][index]
    else:
        offset_s, offset_e = self.label_offsets_list[label_idx][index]
        with open(self.label_paths[label_idx]) as f:
            f.seek(offset_s)
            label = f.read(offset_e - offset_s)

    if self.label_processors is not None:
        label = self.label_processors[label_idx](label)
    return label
|
| 249 |
+
|
| 250 |
+
def get_labels(self, index):
    """Collect all labels for one sample, one entry per label set."""
    labels = []
    for label_idx in range(self.num_labels):
        labels.append(self.get_label(index, label_idx))
    return labels
|
| 252 |
+
|
| 253 |
+
def load_feature(self, mix_name):
    """
    Load image and audio feature
    Returns:
    video_feats: numpy.ndarray of shape [T, H, W, 1], audio_feats: numpy.ndarray of shape [T, F]
    """
    def stacker(feats, stack_order):
        """
        Concatenating consecutive audio frames
        Args:
        feats - numpy.ndarray of shape [T, F]
        stack_order - int (number of neighboring frames to concatenate
        Returns:
        feats - numpy.ndarray of shape [T', F']
        """
        feat_dim = feats.shape[1]
        # Zero-pad so the frame count is a multiple of stack_order.
        if len(feats) % stack_order != 0:
            res = stack_order - len(feats) % stack_order
            res = np.zeros([res, feat_dim]).astype(feats.dtype)
            feats = np.concatenate([feats, res], axis=0)
        feats = feats.reshape((-1, stack_order, feat_dim)).reshape(-1, stack_order*feat_dim)
        return feats
    video_fn, audio_fn = mix_name
    if 'video' in self.modalities:
        video_feats = self.load_video(video_fn) # [T, H, W, 1]
    else:
        video_feats = None
    if 'audio' in self.modalities:
        # Names are stored as "path:utt_id"; keep only the path.
        audio_fn = audio_fn.split(':')[0]
        sample_rate, wav_data = wavfile.read(audio_fn)
        assert sample_rate == 16_000 and len(wav_data.shape) == 1
        # Noise augmentation fires with probability noise_prob.
        if np.random.rand() < self.noise_prob:
            wav_data = self.add_noise(wav_data)
        audio_feats = logfbank(wav_data, samplerate=sample_rate).astype(np.float32) # [T, F]
        audio_feats = stacker(audio_feats, self.stack_order_audio) # [T/stack_order_audio, F*stack_order_audio]
    else:
        audio_feats = None
    if audio_feats is not None and video_feats is not None:
        # Sync modality lengths to the video: pad or trim the audio.
        diff = len(audio_feats) - len(video_feats)
        if diff < 0:
            audio_feats = np.concatenate([audio_feats, np.zeros([-diff, audio_feats.shape[-1]], dtype=audio_feats.dtype)])
        elif diff > 0:
            audio_feats = audio_feats[:-diff]
    return video_feats, audio_feats
|
| 297 |
+
|
| 298 |
+
def load_video(self, audio_name):
    """Load a video, apply the image transform, and append a channel axis."""
    path = os.path.join(self.audio_root, audio_name)
    frames = custom_utils.load_video(path)
    frames = self.transform(frames)
    return np.expand_dims(frames, axis=-1)
|
| 303 |
+
|
| 304 |
+
def select_noise(self):
    """Pick noise_num random noise wavs; average them when more than one."""
    picked = np.random.randint(0, len(self.noise_wav), size=self.noise_num)
    wavs = [wavfile.read(self.noise_wav[i])[1].astype(np.float32) for i in picked]
    if self.noise_num == 1:
        return wavs[0]
    # Truncate all wavs to the shortest one, then average sample-wise.
    shortest = min(len(w) for w in wavs)
    stacked = np.stack([w[:shortest] for w in wavs])
    return np.floor(stacked.mean(axis=0))
|
| 316 |
+
|
| 317 |
+
def add_noise(self, clean_wav):
    """Mix noise into `clean_wav` at the configured SNR; return int16 audio.

    The noise is tiled/cropped to the clean length, scaled so the clean/noise
    RMS ratio matches `self.noise_snr` (in dB), added, and the mixture is
    rescaled if it would clip the int16 range.
    """
    clean_wav = clean_wav.astype(np.float32)
    noise_wav = self.select_noise()
    # Accept a fixed SNR (number) or a random inclusive range (tuple/list).
    # The original `type(x) ==` checks left `snr` unbound for a list-typed
    # noise_snr (NameError); isinstance also covers numeric subclasses.
    if isinstance(self.noise_snr, (int, float)):
        snr = self.noise_snr
    elif isinstance(self.noise_snr, (list, tuple)):
        snr = np.random.randint(self.noise_snr[0], self.noise_snr[1]+1)
    else:
        raise ValueError(f"unsupported noise_snr type: {type(self.noise_snr)}")
    clean_rms = np.sqrt(np.mean(np.square(clean_wav), axis=-1))
    if len(clean_wav) > len(noise_wav):
        # Tile the noise until it covers the clean waveform.
        ratio = int(np.ceil(len(clean_wav)/len(noise_wav)))
        noise_wav = np.concatenate([noise_wav for _ in range(ratio)])
    if len(clean_wav) < len(noise_wav):
        start = 0
        noise_wav = noise_wav[start: start + len(clean_wav)]
    noise_rms = np.sqrt(np.mean(np.square(noise_wav), axis=-1))
    # Scale the noise so 20*log10(clean_rms / noise_rms) == snr.
    adjusted_noise_rms = clean_rms / (10**(snr/20))
    adjusted_noise_wav = noise_wav * (adjusted_noise_rms / noise_rms)
    mixed = clean_wav + adjusted_noise_wav

    # Avoid clipping noise
    max_int16 = np.iinfo(np.int16).max
    min_int16 = np.iinfo(np.int16).min
    if mixed.max(axis=0) > max_int16 or mixed.min(axis=0) < min_int16:
        if mixed.max(axis=0) >= abs(mixed.min(axis=0)):
            reduction_rate = max_int16 / mixed.max(axis=0)
        else:
            reduction_rate = min_int16 / mixed.min(axis=0)
        mixed = mixed * (reduction_rate)
    mixed = mixed.astype(np.int16)
    return mixed
|
| 347 |
+
|
| 348 |
+
def __getitem__(self, index):
    """Return one sample: id, utterance id, audio/video tensors, labels."""
    video_feats, audio_feats = self.load_feature(self.names[index])
    if audio_feats is not None:
        audio_feats = torch.from_numpy(audio_feats.astype(np.float32))
    if video_feats is not None:
        video_feats = torch.from_numpy(video_feats.astype(np.float32))
    if self.normalize and 'audio' in self.modalities:
        with torch.no_grad():
            audio_feats = F.layer_norm(audio_feats, audio_feats.shape[1:])
    labels = self.get_labels(index)
    # Names are stored as (video_path, "audio_path:utt_id").
    fid = self.names[index][1].split(':')[1]
    return {"id": index, 'fid': fid, "video_source": video_feats, 'audio_source': audio_feats, "label_list": labels}
|
| 357 |
+
|
| 358 |
+
def __len__(self):
    """Number of samples kept after manifest filtering."""
    return len(self.sizes)
|
| 360 |
+
|
| 361 |
+
def crop_to_max_size(self, wav, target_size, start=None):
    """Crop `wav` to `target_size` frames; return (cropped, start_offset)."""
    excess = len(wav) - target_size
    if excess <= 0:
        # Already short enough; nothing to crop.
        return wav, 0
    if start is None:
        if self.random_crop:
            start = np.random.randint(0, excess + 1)
        else:
            start = 0
    return wav[start:start + target_size], start
|
| 375 |
+
|
| 376 |
+
def collater(self, samples):
    """Collate per-sample dicts into a fairseq batch.

    Crops/pads audio and video to a common length, stacks them, and
    attaches collated label targets plus bookkeeping fields.
    """
    samples = [s for s in samples if s["id"] is not None]
    if len(samples) == 0:
        return {}

    audio_source, video_source = [s["audio_source"] for s in samples], [s["video_source"] for s in samples]
    # A None in the first sample means the whole batch lacks that modality.
    if audio_source[0] is None:
        audio_source = None
    if video_source[0] is None:
        video_source = None
    if audio_source is not None:
        audio_sizes = [len(s) for s in audio_source]
    else:
        audio_sizes = [len(s) for s in video_source]
    # pad_audio: pad everything up to the longest; otherwise crop down to
    # the shortest. Both are capped by max_sample_size.
    if self.pad_audio:
        audio_size = min(max(audio_sizes), self.max_sample_size)
    else:
        audio_size = min(min(audio_sizes), self.max_sample_size)
    if audio_source is not None:
        collated_audios, padding_mask, audio_starts = self.collater_audio(audio_source, audio_size)
    else:
        collated_audios, audio_starts = None, None
    # Reuse the audio crop offsets so both modalities stay aligned.
    if video_source is not None:
        collated_videos, padding_mask, audio_starts = self.collater_audio(video_source, audio_size, audio_starts)
    else:
        collated_videos = None
    targets_by_label = [
        [s["label_list"][i] for s in samples]
        for i in range(self.num_labels)
    ]
    targets_list, lengths_list, ntokens_list = self.collater_label(
        targets_by_label, audio_size, audio_starts
    )
    source = {"audio": collated_audios, "video": collated_videos}
    net_input = {"source": source, "padding_mask": padding_mask}
    batch = {
        "id": torch.LongTensor([s["id"] for s in samples]),
        "net_input": net_input,
        "utt_id": [s['fid'] for s in samples]
    }

    if self.single_target:
        batch["target_lengths"] = lengths_list[0]
        batch["ntokens"] = ntokens_list[0]
        if self.is_s2s:
            # Seq2seq targets come as (target, prev_output_tokens) pairs.
            batch['target'], net_input['prev_output_tokens'] = targets_list[0][0], targets_list[0][1]
        else:
            batch["target"] = targets_list[0]
    else:
        batch["target_lengths_list"] = lengths_list
        batch["ntokens_list"] = ntokens_list
        batch["target_list"] = targets_list
    return batch
|
| 429 |
+
|
| 430 |
+
def collater_audio(self, audios, audio_size, audio_starts=None):
|
| 431 |
+
audio_feat_shape = list(audios[0].shape[1:])
|
| 432 |
+
collated_audios = audios[0].new_zeros([len(audios), audio_size]+audio_feat_shape)
|
| 433 |
+
padding_mask = (
|
| 434 |
+
torch.BoolTensor(len(audios), audio_size).fill_(False) #
|
| 435 |
+
)
|
| 436 |
+
start_known = audio_starts is not None
|
| 437 |
+
audio_starts = [0 for _ in audios] if not start_known else audio_starts
|
| 438 |
+
for i, audio in enumerate(audios):
|
| 439 |
+
diff = len(audio) - audio_size
|
| 440 |
+
if diff == 0:
|
| 441 |
+
collated_audios[i] = audio
|
| 442 |
+
elif diff < 0:
|
| 443 |
+
assert self.pad_audio
|
| 444 |
+
collated_audios[i] = torch.cat(
|
| 445 |
+
[audio, audio.new_full([-diff]+audio_feat_shape, 0.0)]
|
| 446 |
+
)
|
| 447 |
+
padding_mask[i, diff:] = True
|
| 448 |
+
else:
|
| 449 |
+
collated_audios[i], audio_starts[i] = self.crop_to_max_size(
|
| 450 |
+
audio, audio_size, audio_starts[i] if start_known else None
|
| 451 |
+
)
|
| 452 |
+
if len(audios[0].shape) == 2:
|
| 453 |
+
collated_audios = collated_audios.transpose(1, 2) # [B, T, F] -> [B, F, T]
|
| 454 |
+
else:
|
| 455 |
+
collated_audios = collated_audios.permute((0, 4, 1, 2, 3)).contiguous() # [B, T, H, W, C] -> [B, C, T, H, W]
|
| 456 |
+
return collated_audios, padding_mask, audio_starts
|
| 457 |
+
|
| 458 |
+
    def collater_frm_label(
        self, targets, audio_size, audio_starts, label_rate, pad
    ):
        """Batch frame-level label sequences, aligned to the audio crops.

        Converts the per-sample audio crop offsets (``audio_starts``, in
        audio samples) into label-frame offsets using ``label_rate`` /
        ``self.sample_rate``, slices each label sequence accordingly, and
        pads the slices into one tensor.

        Returns (targets, lengths, ntokens): the padded [B, T_label] tensor,
        a LongTensor of per-sample lengths, and the total token count.
        """
        assert label_rate > 0
        s2f = label_rate / self.sample_rate # num label per sample
        frm_starts = [int(round(s * s2f)) for s in audio_starts]
        frm_size = int(round(audio_size * s2f))
        if not self.pad_audio:
            # Without padding, the batch is cut to the shortest remaining
            # label span so no sample runs past the end of its labels.
            rem_size = [len(t) - s for t, s in zip(targets, frm_starts)]
            frm_size = min(frm_size, *rem_size)
        targets = [t[s: s + frm_size] for t, s in zip(targets, frm_starts)]
        logger.debug(f"audio_starts={audio_starts}")
        logger.debug(f"frame_starts={frm_starts}")
        logger.debug(f"frame_size={frm_size}")

        lengths = torch.LongTensor([len(t) for t in targets])
        ntokens = lengths.sum().item()
        targets = data_utils.collate_tokens(
            targets, pad_idx=pad, left_pad=False
        )
        return targets, lengths, ntokens
| 479 |
+
|
| 480 |
+
def collater_seq_label(self, targets, pad):
|
| 481 |
+
lengths = torch.LongTensor([len(t) for t in targets])
|
| 482 |
+
ntokens = lengths.sum().item()
|
| 483 |
+
targets = data_utils.collate_tokens(
|
| 484 |
+
targets, pad_idx=pad, left_pad=False
|
| 485 |
+
)
|
| 486 |
+
return targets, lengths, ntokens
|
| 487 |
+
|
| 488 |
+
    def collater_seq_label_s2s(self, targets, pad):
        """Batch sequence labels for seq2seq fine-tuning.

        Builds both the padded target tensor and ``prev_output_tokens`` (the
        same sequences with EOS moved to the front) for teacher forcing.

        NOTE(review): the ``pad`` parameter is immediately overwritten below
        from ``label_processors[0].dictionary`` — the caller-supplied value is
        never used; presumably intentional for s2s, but worth confirming.
        """
        lengths = torch.LongTensor([len(t) for t in targets])
        ntokens = lengths.sum().item()
        pad, eos = self.label_processors[0].dictionary.pad(), self.label_processors[0].dictionary.eos()
        targets_ = data_utils.collate_tokens(targets, pad_idx=pad, eos_idx=eos, left_pad=False)
        prev_output_tokens = data_utils.collate_tokens(targets, pad_idx=pad, eos_idx=eos, left_pad=False, move_eos_to_beginning=True)
        return (targets_, prev_output_tokens), lengths, ntokens
|
| 495 |
+
|
| 496 |
+
def collater_label(self, targets_by_label, audio_size, audio_starts):
|
| 497 |
+
targets_list, lengths_list, ntokens_list = [], [], []
|
| 498 |
+
itr = zip(targets_by_label, self.label_rates, self.pad_list)
|
| 499 |
+
for targets, label_rate, pad in itr:
|
| 500 |
+
if label_rate == -1:
|
| 501 |
+
if self.is_s2s:
|
| 502 |
+
targets, lengths, ntokens = self.collater_seq_label_s2s(targets, pad)
|
| 503 |
+
else:
|
| 504 |
+
targets, lengths, ntokens = self.collater_seq_label(targets, pad)
|
| 505 |
+
else:
|
| 506 |
+
targets, lengths, ntokens = self.collater_frm_label(
|
| 507 |
+
targets, audio_size, audio_starts, label_rate, pad
|
| 508 |
+
)
|
| 509 |
+
targets_list.append(targets)
|
| 510 |
+
lengths_list.append(lengths)
|
| 511 |
+
ntokens_list.append(ntokens)
|
| 512 |
+
return targets_list, lengths_list, ntokens_list
|
| 513 |
+
|
| 514 |
+
    def num_tokens(self, index):
        """Return the cost of sample ``index`` for fairseq batch budgeting.

        Delegates to :meth:`size`, so cropping (``max_sample_size``) is
        already accounted for.
        """
        return self.size(index)
|
| 516 |
+
|
| 517 |
+
def size(self, index):
|
| 518 |
+
if self.pad_audio:
|
| 519 |
+
return self.sizes[index]
|
| 520 |
+
return min(self.sizes[index], self.max_sample_size)
|
| 521 |
+
|
| 522 |
+
def ordered_indices(self):
|
| 523 |
+
if self.shuffle:
|
| 524 |
+
order = [np.random.permutation(len(self))]
|
| 525 |
+
else:
|
| 526 |
+
order = [np.arange(len(self))]
|
| 527 |
+
|
| 528 |
+
order.append(self.sizes)
|
| 529 |
+
return np.lexsort(order)[::-1]
|
av_hubert/avhubert/hubert_pretraining.py
ADDED
|
@@ -0,0 +1,400 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
import os, glob
|
| 9 |
+
import sys
|
| 10 |
+
from typing import Dict, List, Optional, Tuple
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
|
| 14 |
+
from dataclasses import dataclass, field
|
| 15 |
+
from fairseq import metrics, search
|
| 16 |
+
from fairseq.data import Dictionary, encoders
|
| 17 |
+
from fairseq.dataclass.configs import FairseqDataclass
|
| 18 |
+
from fairseq.tasks import register_task
|
| 19 |
+
from fairseq.tasks.fairseq_task import FairseqTask
|
| 20 |
+
from omegaconf import MISSING, II
|
| 21 |
+
import numpy as np
|
| 22 |
+
from argparse import Namespace
|
| 23 |
+
|
| 24 |
+
# Ad-hoc debug switch: an empty argv (e.g. running this module directly in an
# interactive session) selects plain top-level imports of the sibling modules;
# under normal fairseq invocation the package-relative imports are used.
DBG=True if len(sys.argv) == 1 else False

if DBG:
    from hubert_dataset import AVHubertDataset
    from sequence_generator import SequenceGenerator
else:
    from .hubert_dataset import AVHubertDataset
    from .sequence_generator import SequenceGenerator
|
| 32 |
+
|
| 33 |
+
logger = logging.getLogger(__name__)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class LabelEncoder(object):
    """Callable that maps a label line to token ids via a fairseq Dictionary.

    Used as a label processor for frame-level (non-s2s) targets: no EOS is
    appended and unknown symbols are not added to the dictionary.
    """

    def __init__(self, dictionary: Dictionary) -> None:
        self.dictionary = dictionary

    def __call__(self, label: str) -> List[str]:
        """Encode one label line into a tensor of token ids."""
        encode = self.dictionary.encode_line
        return encode(label, append_eos=False, add_if_not_exist=False)
|
| 44 |
+
|
| 45 |
+
class LabelEncoderS2SToken(object):
    """Label processor for seq2seq fine-tuning: BPE-tokenize, then index.

    Lower-cases the label, runs it through the BPE tokenizer, and encodes the
    result with the dictionary (EOS appended). Also provides the inverse
    mapping via :meth:`decode`.
    """

    def __init__(self, dictionary: Dictionary, bpe_tokenizer) -> None:
        self.bpe_tokenizer = bpe_tokenizer
        self.dictionary = dictionary

    def __call__(self, label: str) -> List[str]:
        """Encode one label line into a LongTensor of token ids."""
        pieces = self.bpe_tokenizer.encode(label.lower())
        encoded = self.dictionary.encode_line(
            pieces, append_eos=True, add_if_not_exist=False,
        )
        return encoded.long()

    def decode(self, tok, symbols_ignore=None):
        """Map token ids back to text, undoing BPE when a tokenizer is set."""
        text = self.dictionary.string(tok, extra_symbols_to_ignore=symbols_ignore)
        if self.bpe_tokenizer:
            text = self.bpe_tokenizer.decode(text)
        return text
|
| 61 |
+
|
| 62 |
+
@dataclass
class AVHubertPretrainingConfig(FairseqDataclass):
    """Task configuration for AV-HuBERT pre-training and fine-tuning.

    Per-field semantics are given in each ``metadata={"help": ...}`` entry;
    fields below are grouped as: data/labels, audio processing, batching,
    image/video processing, seq2seq options, and noise augmentation.
    """

    # --- data and labels ---
    data: str = field(
        default=MISSING, metadata={"help": "path to data directory"}
    )
    labels: List[str] = field(
        default_factory=lambda: ["ltr"],
        metadata={
            "help": (
                "extension of the label files to load, frame-level labels for"
                " pre-training, and sequence-level label for fine-tuning"
            )
        },
    )
    label_dir: Optional[str] = field(
        default=None,
        metadata={
            "help": "if set, looks for labels in this directory instead",
        },
    )
    label_rate: int = field(
        default=-1,
        metadata={"help": "label frame rate. -1 for sequence label"},
    )

    # --- audio processing ---
    sample_rate: int = field(
        default=16_000,
        metadata={
            "help": "target sample rate. audio files will be up/down "
            "sampled to this rate"
        },
    )
    normalize: bool = field(
        default=False,
        metadata={
            "help": "if set, normalizes input to have 0 mean and unit variance"
        },
    )
    enable_padding: bool = field(
        default=False,
        metadata={"help": "pad shorter samples instead of cropping"},
    )
    max_sample_size: Optional[int] = field(
        default=None,
        metadata={"help": "max sample size to keep in training"},
    )
    min_sample_size: Optional[int] = field(
        default=None,
        metadata={"help": "min sample size to keep in training"},
    )
    # II(...) interpolates from task.max_sample_size unless overridden.
    max_trim_sample_size: Optional[int] = field(
        default=II("task.max_sample_size"),
        metadata={"help": "max sample size to trim to for batching"},
    )
    single_target: Optional[bool] = field(
        default=False,
        metadata={
            "help": "if set, AddTargetDatasets outputs same keys "
            "as AddTargetDataset"
        },
    )
    random_crop: Optional[bool] = field(
        default=True,
        metadata={"help": "always crop from the beginning if false"},
    )
    pad_audio: Optional[bool] = field(
        default=False,
        metadata={"help": "pad audio to the longest one in the batch if true"},
    )
    pdb: Optional[bool] = field(
        default=False,
        metadata={"help": "pdb"},
    )
    stack_order_audio: int = field(
        default=1,
        metadata={"help": "concatenate n consecutive audio frames for one step"},
    )
    skip_verify: Optional[bool] = field(
        default=False,
        metadata={"help": "skip verifying label-audio alignment"},
    )
    # --- image/video processing ---
    image_aug: bool = field(default=False, metadata={'help': 'image data augmentation'})
    image_crop_size: int = field(
        default=88, metadata={"help": "image ROI size"})
    image_mean: float = field(
        default=0.421, metadata={"help": "image mean"})
    image_std: float = field(
        default=0.165, metadata={"help": "image std"})
    modalities: Optional[List[str]] = field(default_factory=lambda: ["audio", "video"], metadata={'help': 'modalities to load'})
    # --- seq2seq fine-tuning ---
    is_s2s: bool=field(default=False, metadata={'help': 'seq2seq fine-tuning only'})
    tokenizer_bpe_name: Optional[str] = field(default=None, metadata={'help': 'tokenizer model name'})
    tokenizer_bpe_model: Optional[str] = field(default=None, metadata={'help': 'tokenizer model path'})
    # --- noise augmentation ---
    noise_wav: Optional[str] = field(default=None, metadata={'help': 'manifest of noise wav files (one wav file path per line)'})
    noise_prob: float = field(default=0, metadata={'help': 'noise probability'})
    # NOTE(review): noise_snr is a string that is eval()'d in load_dataset.
    noise_snr: Optional[str] = field(default='0', metadata={'help': 'noise SNR in audio'})
    noise_num: int = field(default=1, metadata={'help': 'number of noise wav files to mix'})
    fine_tuning: bool = field(default=False, metadata={"help": "set to true if fine-tuning AV-Hubert"})
|
| 159 |
+
|
| 160 |
+
@register_task("av_hubert_pretraining", dataclass=AVHubertPretrainingConfig)
class AVHubertPretrainingTask(FairseqTask):
    """Fairseq task for AV-HuBERT pre-training and fine-tuning.

    In fine-tuning mode a single target dictionary (and optionally a BPE
    tokenizer for seq2seq) is registered as a lazy state factory; in
    pre-training mode one dictionary per label stream is registered.
    """

    cfg: AVHubertPretrainingConfig

    def __init__(
        self,
        cfg: AVHubertPretrainingConfig,
    ) -> None:
        super().__init__(cfg)

        logger.info(f"current directory is {os.getcwd()}")
        logger.info(f"AVHubertPretrainingTask Config {cfg}")

        self.fine_tuning = cfg.fine_tuning
        # State factories defer dictionary/tokenizer loading until first use.
        if cfg.fine_tuning:
            self.state.add_factory("target_dictionary", self.load_dictionaries)
            if cfg.is_s2s:
                self.state.add_factory("s2s_tokenizer", self.load_tokenizer)
        else:
            self.state.add_factory("dictionaries", self.load_dictionaries)

        self.blank_symbol = "<s>"

    @property
    def source_dictionary(self) -> Optional[Dictionary]:
        """No source dictionary: inputs are raw audio/video, not tokens."""
        return None # self._source_dictionary

    @property
    def target_dictionary(self) -> Optional[Dictionary]:
        """Target dictionary (fine-tuning only; lazily built via the state factory)."""
        return self.state.target_dictionary # self._target_dictionary

    @property
    def dictionaries(self) -> List[Dictionary]:
        """One dictionary per label stream (pre-training mode)."""
        return self.state.dictionaries

    def load_dictionaries(self):
        """Load dict.<label>.txt files from the label (or data) directory.

        Returns a single dictionary when fine-tuning, else a list of them.
        """
        label_dir = self.cfg.data if self.cfg.label_dir is None else self.cfg.label_dir
        dictionaries = [
            Dictionary.load(f"{label_dir}/dict.{label}.txt")
            for label in self.cfg.labels
        ]
        return dictionaries[0] if self.cfg.fine_tuning else dictionaries

    def load_tokenizer(self):
        """Build the BPE tokenizer named by the config for s2s fine-tuning."""
        bpe_args = Namespace(**{'bpe': self.cfg.tokenizer_bpe_name, f"{self.cfg.tokenizer_bpe_name}_model": self.cfg.tokenizer_bpe_model})
        bpe_tokenizer = encoders.build_bpe(bpe_args)
        return bpe_tokenizer

    @property
    def s2s_tokenizer(self):
        # Lazily constructed via the "s2s_tokenizer" state factory.
        return self.state.s2s_tokenizer

    @classmethod
    def setup_task(
        cls, cfg: AVHubertPretrainingConfig, **kwargs
    ) -> "AVHubertPretrainingTask":
        """Standard fairseq entry point; optionally drops into pdb first."""
        if cfg.pdb:
            import pdb
            pdb.set_trace()
        return cls(cfg)

    def get_label_dir(self) -> str:
        """Directory holding label files: label_dir if set, else the data dir."""
        if self.cfg.label_dir is None:
            return self.cfg.data
        return self.cfg.label_dir

    def load_dataset(self, split: str, **kwargs) -> None:
        """Construct the AVHubertDataset for one split from <data>/<split>.tsv."""
        manifest = f"{self.cfg.data}/{split}.tsv"
        dictionaries = [self.target_dictionary] if self.fine_tuning else self.dictionaries
        pad_list = [dictionary.pad() for dictionary in dictionaries]
        eos_list = [dictionary.eos() for dictionary in dictionaries]
        if not self.cfg.is_s2s:
            procs = [LabelEncoder(dictionary) for dictionary in dictionaries]
        else:
            logger.info(f"Using tokenizer")
            bpe_tokenizer = self.s2s_tokenizer
            procs = [LabelEncoderS2SToken(dictionary, bpe_tokenizer) for dictionary in dictionaries]
        paths = [
            f"{self.get_label_dir()}/{split}.{l}" for l in self.cfg.labels
        ]
        # Image augmentation is applied on the training split only.
        image_aug = self.cfg.image_aug if split == 'train' else False
        # NOTE(review): eval() parses the SNR spec from config text (e.g. a
        # number or a tuple). Safe only as long as configs are trusted input.
        noise_fn, noise_snr = f"{self.cfg.noise_wav}/{split}.tsv" if self.cfg.noise_wav is not None else None, eval(self.cfg.noise_snr)
        noise_num = self.cfg.noise_num #
        self.datasets[split] = AVHubertDataset(
            manifest,
            sample_rate=self.cfg.sample_rate,
            label_paths=paths,
            label_rates=self.cfg.label_rate,
            pad_list=pad_list,
            eos_list=eos_list,
            label_processors=procs,
            max_keep_sample_size=self.cfg.max_sample_size,
            min_keep_sample_size=self.cfg.min_sample_size,
            max_sample_size=self.cfg.max_trim_sample_size,
            pad_audio=self.cfg.pad_audio,
            normalize=self.cfg.normalize,
            store_labels=False,
            random_crop=self.cfg.random_crop,
            single_target=self.cfg.single_target,
            stack_order_audio=self.cfg.stack_order_audio,
            skip_verify=self.cfg.skip_verify,
            image_mean=self.cfg.image_mean,
            image_std=self.cfg.image_std,
            image_crop_size=self.cfg.image_crop_size,
            image_aug=image_aug,
            modalities=self.cfg.modalities,
            is_s2s=self.cfg.is_s2s,
            noise_fn=noise_fn,
            noise_prob=self.cfg.noise_prob,
            noise_snr=noise_snr,
            noise_num=noise_num
        )

    def max_positions(self) -> Tuple[int, int]:
        """Effectively unlimited positions; the dataset handles size limits."""
        return (sys.maxsize, sys.maxsize)

    def filter_indices_by_size(
        self, indices: np.array, *args, **kwargs
    ) -> np.array:
        # Intentional no-op: size filtering is done inside the dataset.
        return indices

    def build_generator(
        self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None, prefix_allowed_tokens_fn=None,
    ):
        """
        Build a :class:`~fairseq.SequenceGenerator` instance for this
        task.
        Args:
            models (List[~fairseq.models.FairseqModel]): ensemble of models
            args (fairseq.dataclass.configs.GenerationConfig):
                configuration object (dataclass) for generation
            extra_gen_cls_kwargs (Dict[str, Any]): extra options to pass
                through to SequenceGenerator
            prefix_allowed_tokens_fn (Callable[[int, torch.Tensor], List[int]]):
                If provided, this function constrains the beam search to
                allowed tokens only at each step. The provided function
                should take 2 arguments: the batch ID (`batch_id: int`)
                and a unidimensional tensor of token ids (`inputs_ids:
                torch.Tensor`). It has to return a `List[int]` with the
                allowed tokens for the next generation step conditioned
                on the previously generated tokens (`inputs_ids`) and
                the batch ID (`batch_id`). This argument is useful for
                constrained generation conditioned on the prefix, as
                described in "Autoregressive Entity Retrieval"
                (https://arxiv.org/abs/2010.00904) and
                https://github.com/facebookresearch/GENRE.
        """
        if getattr(args, "score_reference", False):
            from fairseq.sequence_scorer import SequenceScorer

            return SequenceScorer(
                self.target_dictionary,
                compute_alignment=getattr(args, "print_alignment", False),
            )

        # Choose search strategy. Defaults to Beam Search.
        sampling = getattr(args, "sampling", False)
        sampling_topk = getattr(args, "sampling_topk", -1)
        sampling_topp = getattr(args, "sampling_topp", -1.0)
        diverse_beam_groups = getattr(args, "diverse_beam_groups", -1)
        diverse_beam_strength = getattr(args, "diverse_beam_strength", 0.5)
        match_source_len = getattr(args, "match_source_len", False)
        diversity_rate = getattr(args, "diversity_rate", -1)
        constrained = getattr(args, "constraints", False)
        if prefix_allowed_tokens_fn is None:
            prefix_allowed_tokens_fn = getattr(args, "prefix_allowed_tokens_fn", None)
        # The search modes below are mutually exclusive; reject combinations.
        if (
            sum(
                int(cond)
                for cond in [
                    sampling,
                    diverse_beam_groups > 0,
                    match_source_len,
                    diversity_rate > 0,
                ]
            )
            > 1
        ):
            raise ValueError("Provided Search parameters are mutually exclusive.")
        assert sampling_topk < 0 or sampling, "--sampling-topk requires --sampling"
        assert sampling_topp < 0 or sampling, "--sampling-topp requires --sampling"

        if sampling:
            search_strategy = search.Sampling(
                self.target_dictionary, sampling_topk, sampling_topp
            )
        elif diverse_beam_groups > 0:
            search_strategy = search.DiverseBeamSearch(
                self.target_dictionary, diverse_beam_groups, diverse_beam_strength
            )
        elif match_source_len:
            # this is useful for tagging applications where the output
            # length should match the input length, so we hardcode the
            # length constraints for simplicity
            search_strategy = search.LengthConstrainedBeamSearch(
                self.target_dictionary,
                min_len_a=1,
                min_len_b=0,
                max_len_a=1,
                max_len_b=0,
            )
        elif diversity_rate > -1:
            search_strategy = search.DiverseSiblingsSearch(
                self.target_dictionary, diversity_rate
            )
        elif constrained:
            search_strategy = search.LexicallyConstrainedBeamSearch(
                self.target_dictionary, args.constraints
            )
        elif prefix_allowed_tokens_fn:
            search_strategy = search.PrefixConstrainedBeamSearch(
                self.target_dictionary, prefix_allowed_tokens_fn
            )
        else:
            search_strategy = search.BeamSearch(self.target_dictionary)

        extra_gen_cls_kwargs = extra_gen_cls_kwargs or {}
        if seq_gen_cls is None:
            if getattr(args, "print_alignment", False):
                # NOTE(review): SequenceGeneratorWithAlignment is not imported
                # in this module — this branch raises NameError if
                # print_alignment is ever set. Confirm and add the import.
                seq_gen_cls = SequenceGeneratorWithAlignment
                extra_gen_cls_kwargs["print_alignment"] = args.print_alignment
            else:
                seq_gen_cls = SequenceGenerator

        return seq_gen_cls(
            models,
            self.target_dictionary,
            beam_size=getattr(args, "beam", 5),
            max_len_a=getattr(args, "max_len_a", 0),
            max_len_b=getattr(args, "max_len_b", 200),
            min_len=getattr(args, "min_len", 1),
            normalize_scores=(not getattr(args, "unnormalized", False)),
            len_penalty=getattr(args, "lenpen", 1),
            unk_penalty=getattr(args, "unkpen", 0),
            temperature=getattr(args, "temperature", 1.0),
            match_source_len=getattr(args, "match_source_len", False),
            no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0),
            search_strategy=search_strategy,
            **extra_gen_cls_kwargs,
        )
|