Spaces:
Running
on
Zero
Running
on
Zero
Upload 43 files
Browse files- .gitattributes +2 -0
- .github/workflows/ci.yaml +44 -0
- .gitignore +8 -0
- .pre-commit-config.yaml +15 -0
- CODE_OF_CONDUCT.md +80 -0
- CONTRIBUTING.md +31 -0
- LICENSE +61 -0
- README.md +137 -14
- assets/sam_audio_main_model.png +3 -0
- eval/README.md +100 -0
- eval/dataset/__init__.py +70 -0
- eval/dataset/musdb.py +75 -0
- eval/dataset/sam_audio_bench.py +153 -0
- eval/main.py +162 -0
- eval/metrics/__init__.py +13 -0
- eval/metrics/aes.py +49 -0
- eval/metrics/clap.py +46 -0
- eval/metrics/imagebind.py +52 -0
- eval/metrics/judge.py +44 -0
- examples/assets/office.mp4 +3 -0
- examples/span_prompting.ipynb +0 -0
- examples/text_prompting.ipynb +0 -0
- examples/visual_prompting.ipynb +0 -0
- pyproject.toml +62 -0
- sam_audio/__init__.py +4 -0
- sam_audio/model/__init__.py +4 -0
- sam_audio/model/align.py +50 -0
- sam_audio/model/base.py +62 -0
- sam_audio/model/codec.py +109 -0
- sam_audio/model/config.py +251 -0
- sam_audio/model/judge.py +135 -0
- sam_audio/model/model.py +362 -0
- sam_audio/model/patcher.py +164 -0
- sam_audio/model/rope.py +155 -0
- sam_audio/model/text_encoder.py +37 -0
- sam_audio/model/transformer.py +524 -0
- sam_audio/model/vision_encoder.py +113 -0
- sam_audio/processor.py +382 -0
- sam_audio/ranking/__init__.py +30 -0
- sam_audio/ranking/clap.py +86 -0
- sam_audio/ranking/imagebind.py +197 -0
- sam_audio/ranking/judge.py +42 -0
- sam_audio/ranking/ranker.py +36 -0
- sam_audio/ranking/sound_activity.py +129 -0
.gitattributes
CHANGED
|
@@ -35,3 +35,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
Sam_audio/assets/sam_audio_main_model.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
Sam_audio/examples/assets/office.mp4 filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
Sam_audio/assets/sam_audio_main_model.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
Sam_audio/examples/assets/office.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
assets/sam_audio_main_model.png filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
examples/assets/office.mp4 filter=lfs diff=lfs merge=lfs -text
|
.github/workflows/ci.yaml
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: CI
|
| 2 |
+
on:
|
| 3 |
+
push:
|
| 4 |
+
pull_request:
|
| 5 |
+
|
| 6 |
+
env:
|
| 7 |
+
CACHE_NUMBER: 0
|
| 8 |
+
|
| 9 |
+
jobs:
|
| 10 |
+
build:
|
| 11 |
+
runs-on: 32-core-ubuntu
|
| 12 |
+
steps:
|
| 13 |
+
- uses: actions/checkout@v4
|
| 14 |
+
|
| 15 |
+
- uses: mamba-org/setup-micromamba@v1.8.1
|
| 16 |
+
with:
|
| 17 |
+
environment-name: sam-audio
|
| 18 |
+
init-shell: bash
|
| 19 |
+
create-args: >-
|
| 20 |
+
python=3.11
|
| 21 |
+
pip=24.2
|
| 22 |
+
ruff==
|
| 23 |
+
|
| 24 |
+
- uses: actions/cache@v4
|
| 25 |
+
with:
|
| 26 |
+
path: /home/runner/micromamba/envs/sam-audio
|
| 27 |
+
key: ${{ hashFiles('pyproject.toml') }}-${{ env.CACHE_NUMBER }}
|
| 28 |
+
id: cache
|
| 29 |
+
|
| 30 |
+
- name: Update environment
|
| 31 |
+
shell: bash -l {0}
|
| 32 |
+
run: |
|
| 33 |
+
pip install .
|
| 34 |
+
if: steps.cache.outputs.cache-hit != 'true'
|
| 35 |
+
|
| 36 |
+
- name: Check formatting
|
| 37 |
+
shell: bash -l {0}
|
| 38 |
+
run: |
|
| 39 |
+
ruff format --check .
|
| 40 |
+
|
| 41 |
+
- name: Lint
|
| 42 |
+
shell: bash -l {0}
|
| 43 |
+
run: |
|
| 44 |
+
ruff check .
|
.gitignore
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__
|
| 2 |
+
*.egg-info
|
| 3 |
+
*.pyc
|
| 4 |
+
*.so
|
| 5 |
+
build
|
| 6 |
+
dist
|
| 7 |
+
.checkpoints
|
| 8 |
+
.ipynb_checkpoints
|
.pre-commit-config.yaml
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
repos:
|
| 2 |
+
- repo: https://github.com/pre-commit/pre-commit-hooks
|
| 3 |
+
rev: v4.6.0
|
| 4 |
+
hooks:
|
| 5 |
+
- id: trailing-whitespace
|
| 6 |
+
- id: check-yaml
|
| 7 |
+
args:
|
| 8 |
+
- --allow-multiple-documents
|
| 9 |
+
- id: end-of-file-fixer
|
| 10 |
+
- repo: https://github.com/astral-sh/ruff-pre-commit
|
| 11 |
+
rev: v0.12.0
|
| 12 |
+
hooks:
|
| 13 |
+
- id: ruff
|
| 14 |
+
args: [ --fix ]
|
| 15 |
+
- id: ruff-format
|
CODE_OF_CONDUCT.md
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Code of Conduct
|
| 2 |
+
|
| 3 |
+
## Our Pledge
|
| 4 |
+
|
| 5 |
+
In the interest of fostering an open and welcoming environment, we as
|
| 6 |
+
contributors and maintainers pledge to make participation in our project and
|
| 7 |
+
our community a harassment-free experience for everyone, regardless of age, body
|
| 8 |
+
size, disability, ethnicity, sex characteristics, gender identity and expression,
|
| 9 |
+
level of experience, education, socio-economic status, nationality, personal
|
| 10 |
+
appearance, race, religion, or sexual identity and orientation.
|
| 11 |
+
|
| 12 |
+
## Our Standards
|
| 13 |
+
|
| 14 |
+
Examples of behavior that contributes to creating a positive environment
|
| 15 |
+
include:
|
| 16 |
+
|
| 17 |
+
* Using welcoming and inclusive language
|
| 18 |
+
* Being respectful of differing viewpoints and experiences
|
| 19 |
+
* Gracefully accepting constructive criticism
|
| 20 |
+
* Focusing on what is best for the community
|
| 21 |
+
* Showing empathy towards other community members
|
| 22 |
+
|
| 23 |
+
Examples of unacceptable behavior by participants include:
|
| 24 |
+
|
| 25 |
+
* The use of sexualized language or imagery and unwelcome sexual attention or
|
| 26 |
+
advances
|
| 27 |
+
* Trolling, insulting/derogatory comments, and personal or political attacks
|
| 28 |
+
* Public or private harassment
|
| 29 |
+
* Publishing others' private information, such as a physical or electronic
|
| 30 |
+
address, without explicit permission
|
| 31 |
+
* Other conduct which could reasonably be considered inappropriate in a
|
| 32 |
+
professional setting
|
| 33 |
+
|
| 34 |
+
## Our Responsibilities
|
| 35 |
+
|
| 36 |
+
Project maintainers are responsible for clarifying the standards of acceptable
|
| 37 |
+
behavior and are expected to take appropriate and fair corrective action in
|
| 38 |
+
response to any instances of unacceptable behavior.
|
| 39 |
+
|
| 40 |
+
Project maintainers have the right and responsibility to remove, edit, or
|
| 41 |
+
reject comments, commits, code, wiki edits, issues, and other contributions
|
| 42 |
+
that are not aligned to this Code of Conduct, or to ban temporarily or
|
| 43 |
+
permanently any contributor for other behaviors that they deem inappropriate,
|
| 44 |
+
threatening, offensive, or harmful.
|
| 45 |
+
|
| 46 |
+
## Scope
|
| 47 |
+
|
| 48 |
+
This Code of Conduct applies within all project spaces, and it also applies when
|
| 49 |
+
an individual is representing the project or its community in public spaces.
|
| 50 |
+
Examples of representing a project or community include using an official
|
| 51 |
+
project e-mail address, posting via an official social media account, or acting
|
| 52 |
+
as an appointed representative at an online or offline event. Representation of
|
| 53 |
+
a project may be further defined and clarified by project maintainers.
|
| 54 |
+
|
| 55 |
+
This Code of Conduct also applies outside the project spaces when there is a
|
| 56 |
+
reasonable belief that an individual's behavior may have a negative impact on
|
| 57 |
+
the project or its community.
|
| 58 |
+
|
| 59 |
+
## Enforcement
|
| 60 |
+
|
| 61 |
+
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
| 62 |
+
reported by contacting the project team at <opensource-conduct@meta.com>. All
|
| 63 |
+
complaints will be reviewed and investigated and will result in a response that
|
| 64 |
+
is deemed necessary and appropriate to the circumstances. The project team is
|
| 65 |
+
obligated to maintain confidentiality with regard to the reporter of an incident.
|
| 66 |
+
Further details of specific enforcement policies may be posted separately.
|
| 67 |
+
|
| 68 |
+
Project maintainers who do not follow or enforce the Code of Conduct in good
|
| 69 |
+
faith may face temporary or permanent repercussions as determined by other
|
| 70 |
+
members of the project's leadership.
|
| 71 |
+
|
| 72 |
+
## Attribution
|
| 73 |
+
|
| 74 |
+
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
|
| 75 |
+
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
|
| 76 |
+
|
| 77 |
+
[homepage]: https://www.contributor-covenant.org
|
| 78 |
+
|
| 79 |
+
For answers to common questions about this code of conduct, see
|
| 80 |
+
https://www.contributor-covenant.org/faq
|
CONTRIBUTING.md
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Contributing to segment-anything-model-audio
|
| 2 |
+
We want to make contributing to this project as easy and transparent as
|
| 3 |
+
possible.
|
| 4 |
+
|
| 5 |
+
## Pull Requests
|
| 6 |
+
We actively welcome your pull requests.
|
| 7 |
+
|
| 8 |
+
1. Fork the repo and create your branch from `main`.
|
| 9 |
+
2. If you've added code that should be tested, add tests.
|
| 10 |
+
3. If you've changed APIs, update the documentation.
|
| 11 |
+
4. Ensure the test suite passes.
|
| 12 |
+
5. Make sure your code lints.
|
| 13 |
+
6. If you haven't already, complete the Contributor License Agreement ("CLA").
|
| 14 |
+
|
| 15 |
+
## Contributor License Agreement ("CLA")
|
| 16 |
+
In order to accept your pull request, we need you to submit a CLA. You only need
|
| 17 |
+
to do this once to work on any of Facebook's open source projects.
|
| 18 |
+
|
| 19 |
+
Complete your CLA here: <https://code.facebook.com/cla>
|
| 20 |
+
|
| 21 |
+
## Issues
|
| 22 |
+
We use GitHub issues to track public bugs. Please ensure your description is
|
| 23 |
+
clear and has sufficient instructions to be able to reproduce the issue.
|
| 24 |
+
|
| 25 |
+
Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
|
| 26 |
+
disclosure of security bugs. In those cases, please go through the process
|
| 27 |
+
outlined on that page and do not file a public issue.
|
| 28 |
+
|
| 29 |
+
## License
|
| 30 |
+
By contributing to segment-anything-model-audio, you agree that your contributions will be licensed
|
| 31 |
+
under the LICENSE file in the root directory of this source tree.
|
LICENSE
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
SAM License
|
| 2 |
+
Last Updated: November 19, 2025
|
| 3 |
+
|
| 4 |
+
“Agreement” means the terms and conditions for use, reproduction, distribution and modification of the SAM Materials set forth herein.
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
“SAM Materials” means, collectively, Documentation and the models, software and algorithms, including machine-learning model code, trained model weights, inference-enabling code, training-enabling code, fine-tuning enabling code, and other elements of the foregoing distributed by Meta and made available under this Agreement.
|
| 8 |
+
|
| 9 |
+
“Documentation” means the specifications, manuals and documentation accompanying
|
| 10 |
+
SAM Materials distributed by Meta.
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
“Licensee” or “you” means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity’s behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf.
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
“Meta” or “we” means Meta Platforms Ireland Limited (if you are located in or, if you are an entity, your principal place of business is in the EEA or Switzerland) or Meta Platforms, Inc. (if you are located outside of the EEA or Switzerland).
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
“Sanctions” means any economic or trade sanctions or restrictions administered or enforced by the United States (including the Office of Foreign Assets Control of the U.S. Department of the Treasury (“OFAC”), the U.S. Department of State and the U.S. Department of Commerce), the United Nations, the European Union, or the United Kingdom.
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
“Trade Controls” means any of the following: Sanctions and applicable export and import controls.
|
| 23 |
+
|
| 24 |
+
By using or distributing any portion or element of the SAM Materials, you agree to be bound by this Agreement.
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
1. License Rights and Redistribution.
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
a. Grant of Rights. You are granted a non-exclusive, worldwide, non-transferable and royalty-free limited license under Meta’s intellectual property or other rights owned by Meta embodied in the SAM Materials to use, reproduce, distribute, copy, create derivative works of, and make modifications to the SAM Materials.
|
| 31 |
+
|
| 32 |
+
b. Redistribution and Use.
|
| 33 |
+
i. Distribution of SAM Materials, and any derivative works thereof, are subject to the terms of this Agreement. If you distribute or make the SAM Materials, or any derivative works thereof, available to a third party, you may only do so under the terms of this Agreement and you shall provide a copy of this Agreement with any such SAM Materials.
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
ii. If you submit for publication the results of research you perform on, using, or otherwise in connection with SAM Materials, you must acknowledge the use of SAM Materials in your publication.
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
iii. Your use of the SAM Materials must comply with applicable laws and regulations, including Trade Control Laws and applicable privacy and data protection laws.
|
| 40 |
+
iv. Your use of the SAM Materials will not involve or encourage others to reverse engineer, decompile or discover the underlying components of the SAM Materials.
|
| 41 |
+
v. You are not the target of Trade Controls and your use of SAM Materials must comply with Trade Controls. You agree not to use, or permit others to use, SAM Materials for any activities subject to the International Traffic in Arms Regulations (ITAR) or end uses prohibited by Trade Controls, including those related to military or warfare purposes, nuclear industries or applications, espionage, or the development or use of guns or illegal weapons.
|
| 42 |
+
2. User Support. Your use of the SAM Materials is done at your own discretion; Meta does not process any information nor provide any service in relation to such use. Meta is under no obligation to provide any support services for the SAM Materials. Any support provided is “as is”, “with all faults”, and without warranty of any kind.
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
3. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE SAM MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN “AS IS” BASIS, WITHOUT WARRANTIES OF ANY KIND, AND META DISCLAIMS ALL WARRANTIES OF ANY KIND, BOTH EXPRESS AND IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE SAM MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE SAM MATERIALS AND ANY OUTPUT AND RESULTS.
|
| 46 |
+
|
| 47 |
+
4. Limitation of Liability. IN NO EVENT WILL META OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT OR INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
|
| 48 |
+
|
| 49 |
+
5. Intellectual Property.
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
a. Subject to Meta’s ownership of SAM Materials and derivatives made by or for Meta, with respect to any derivative works and modifications of the SAM Materials that are made by you, as between you and Meta, you are and will be the owner of such derivative works and modifications.
|
| 53 |
+
|
| 54 |
+
b. If you institute litigation or other proceedings against Meta or any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the SAM Materials, outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Meta from and against any claim by any third party arising out of or related to your use or distribution of the SAM Materials.
|
| 55 |
+
|
| 56 |
+
6. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the SAM Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Meta may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of the SAM Materials. Sections 3, 4 and 7 shall survive the termination of this Agreement.
|
| 57 |
+
|
| 58 |
+
7. Governing Law and Jurisdiction. This Agreement will be governed and construed under the laws of the State of California without regard to choice of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement. The courts of California shall have exclusive jurisdiction of any dispute arising out of this Agreement.
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
8. Modifications and Amendments. Meta may modify this Agreement from time to time; provided that they are similar in spirit to the current version of the Agreement, but may differ in detail to address new problems or concerns. All such changes will be effective immediately. Your continued use of the SAM Materials after any modification to this Agreement constitutes your agreement to such modification. Except as provided in this Agreement, no modification or addition to any provision of this Agreement will be binding unless it is in writing and signed by an authorized representative of both you and Meta.
|
README.md
CHANGED
|
@@ -1,14 +1,137 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<div align="center">
|
| 2 |
+
|
| 3 |
+
# SAM-Audio
|
| 4 |
+
|
| 5 |
+

|
| 6 |
+
|
| 7 |
+

|
| 8 |
+
|
| 9 |
+
</div>
|
| 10 |
+
|
| 11 |
+
Segment Anything Model for Audio [[**Blog**](https://ai.meta.com/blog/sam-audio/)] [[**Paper**](https://ai.meta.com/research/publications/sam-audio-segment-anything-in-audio/)] [[**Demo**](https://aidemos.meta.com/segment-anything/editor/segment-audio)]
|
| 12 |
+
|
| 13 |
+
SAM-Audio is a foundation model for isolating any sound in audio using text, visual, or temporal prompts. It can separate specific sounds from complex audio mixtures based on natural language descriptions, visual cues from video, or time spans.
|
| 14 |
+
|
| 15 |
+
SAM-Audio and the Judge model crucially rely on [Perception-Encoder Audio-Visual (PE-AV)](https://huggingface.co/facebook/pe-av-large), which you can read more about [here](https://ai.meta.com/research/publications/pushing-the-frontier-of-audiovisual-perception-with-large-scale-multimodal-correspondence-learning/)
|
| 16 |
+
|
| 17 |
+
## Setup
|
| 18 |
+
|
| 19 |
+
**Requirements:**
|
| 20 |
+
- Python >= 3.10
|
| 21 |
+
- CUDA-compatible GPU (recommended)
|
| 22 |
+
|
| 23 |
+
Install dependencies:
|
| 24 |
+
|
| 25 |
+
```bash
|
| 26 |
+
pip install .
|
| 27 |
+
```
|
| 28 |
+
|
| 29 |
+
## Usage
|
| 30 |
+
|
| 31 |
+
⚠️ Before using SAM Audio, please request access to the checkpoints on the SAM Audio
|
| 32 |
+
Hugging Face [repo](https://huggingface.co/facebook/sam-audio-large). Once accepted, you
|
| 33 |
+
need to be authenticated to download the checkpoints. You can do this by running
|
| 34 |
+
the following [steps](https://huggingface.co/docs/huggingface_hub/en/quick-start#authentication)
|
| 35 |
+
(e.g. `hf auth login` after generating an access token.)
|
| 36 |
+
|
| 37 |
+
### Basic Text Prompting
|
| 38 |
+
|
| 39 |
+
```python
|
| 40 |
+
from sam_audio import SAMAudio, SAMAudioProcessor
|
| 41 |
+
import torchaudio
|
| 42 |
+
import torch
|
| 43 |
+
|
| 44 |
+
model = SAMAudio.from_pretrained("facebook/sam-audio-large")
|
| 45 |
+
processor = SAMAudioProcessor.from_pretrained("facebook/sam-audio-large")
|
| 46 |
+
model = model.eval().cuda()
|
| 47 |
+
|
| 48 |
+
file = "<audio file>" # audio file path or torch tensor
|
| 49 |
+
description = "<description>"
|
| 50 |
+
|
| 51 |
+
batch = processor(
|
| 52 |
+
audios=[file],
|
| 53 |
+
descriptions=[description],
|
| 54 |
+
).to("cuda")
|
| 55 |
+
|
| 56 |
+
with torch.inference_mode():
|
| 57 |
+
# NOTE: `predict_spans` and `reranking_candidates` have a large impact on performance.
|
| 58 |
+
# Setting `predict_span=True` and `reranking_candidates=8` will give you better results at the cost of
|
| 59 |
+
# latency and memory. See the "Span Prediction" section below for more details
|
| 60 |
+
result = model.separate(batch, predict_spans=False, reranking_candidates=1)
|
| 61 |
+
|
| 62 |
+
# Save separated audio
|
| 63 |
+
sample_rate = processor.audio_sampling_rate
|
| 64 |
+
torchaudio.save("target.wav", result.target.cpu(), sample_rate) # The isolated sound
|
| 65 |
+
torchaudio.save("residual.wav", result.residual.cpu(), sample_rate) # Everything else
|
| 66 |
+
```
|
| 67 |
+
|
| 68 |
+
### Prompting Methods
|
| 69 |
+
|
| 70 |
+
SAM-Audio supports three types of prompts:
|
| 71 |
+
|
| 72 |
+
1. **Text Prompting**: Describe the sound you want to isolate using natural language
|
| 73 |
+
```python
|
| 74 |
+
processor(audios=[audio], descriptions=["A man speaking"])
|
| 75 |
+
```
|
| 76 |
+
|
| 77 |
+
2. **Visual Prompting**: Use video frames and masks to isolate sounds associated with visual objects
|
| 78 |
+
```python
|
| 79 |
+
processor(audios=[video], descriptions=[""], masked_videos=processor.mask_videos([frames], [mask]))
|
| 80 |
+
```
|
| 81 |
+
|
| 82 |
+
3. **Span Prompting**: Specify time ranges where the target sound occurs
|
| 83 |
+
```python
|
| 84 |
+
processor(audios=[audio], descriptions=["A horn honking"], anchors=[[["+", 6.3, 7.0]]])
|
| 85 |
+
```
|
| 86 |
+
|
| 87 |
+
See the [examples](examples) directory for more detailed examples
|
| 88 |
+
|
| 89 |
+
### Span Prediction (Optional for Text Prompting)
|
| 90 |
+
|
| 91 |
+
We also provide support for automatically predicting the spans based on the text description, which is especially helpful for separating non-ambience sound events. You can enable this by adding `predict_spans=True` in your call to `separate`
|
| 92 |
+
|
| 93 |
+
```python
|
| 94 |
+
with torch.inference_mode()
|
| 95 |
+
outputs = model.separate(batch, predict_spans=True)
|
| 96 |
+
|
| 97 |
+
# To further improve performance (at the expense of latency), you can add candidate re-ranking
|
| 98 |
+
with torch.inference_mode():
|
| 99 |
+
outputs = model.separate(batch, predict_spans=True, reranking_candidates=8)
|
| 100 |
+
```
|
| 101 |
+
|
| 102 |
+
### Re-Ranking
|
| 103 |
+
|
| 104 |
+
We provide the following models to assess the quality of the separated audio:
|
| 105 |
+
|
| 106 |
+
- [CLAP](https://github.com/LAION-AI/CLAP): measures the similarity between the target audio and text description
|
| 107 |
+
- [Judge](https://huggingface.co/facebook/sam-audio-judge): measures the overall separation quality across 3 axes: precision, recall, and faithfulness (see the [model card](https://huggingface.co/facebook/sam-audio-judge#output-format) for more details)
|
| 108 |
+
- [ImageBind](https://github.com/facebookresearch/ImageBind): for visual prompting, we measure the imagebind embedding similarity between the separated audio and the masked input video
|
| 109 |
+
|
| 110 |
+
We provide support for generating multiple candidates (by setting `reranking_candidates=<k>` in your call to `separate`), which will generate `k` audios, and choose the best one based on the ranking models mentioned above
|
| 111 |
+
|
| 112 |
+
# Models
|
| 113 |
+
|
| 114 |
+
Below is a table of each of the models we released along with their overall subjective evaluation scores
|
| 115 |
+
|
| 116 |
+
| Model | General SFX | Speech | Speaker | Music | Instr(wild) | Instr(pro) |
|
| 117 |
+
|----------|-------------|--------|---------|-------|-------------|------------|
|
| 118 |
+
| [`sam-audio-small`](https://huggingface.co/facebook/sam-audio-small) | 3.62 | 3.99 | 3.12 | 4.11 | 3.56 | 4.24 |
|
| 119 |
+
| [`sam-audio-base`](https://huggingface.co/facebook/sam-audio-base) | 3.28 | 4.25 | 3.57 | 3.87 | 3.66 | 4.27 |
|
| 120 |
+
| [`sam-audio-large`](https://huggingface.co/facebook/sam-audio-large) | 3.50 | 4.03 | 3.60 | 4.22 | 3.66 | 4.49 |
|
| 121 |
+
|
| 122 |
+
We additional release another variant (in each size) that is better specifically on correctness of target sound as well as visual prompting:
|
| 123 |
+
- [`sam-audio-small-tv`](https://huggingface.co/facebook/sam-audio-small-tv)
|
| 124 |
+
- [`sam-audio-base-tv`](https://huggingface.co/facebook/sam-audio-base-tv)
|
| 125 |
+
- [`sam-audio-large-tv`](https://huggingface.co/facebook/sam-audio-large-tv)
|
| 126 |
+
|
| 127 |
+
## Evaluation
|
| 128 |
+
|
| 129 |
+
See the [eval](eval) directory for instructions and scripts to reproduce results from the paper
|
| 130 |
+
|
| 131 |
+
## Contributing
|
| 132 |
+
|
| 133 |
+
See [contributing](CONTRIBUTING.md) and [code of conduct](CODE_OF_CONDUCT.md) for more information.
|
| 134 |
+
|
| 135 |
+
## License
|
| 136 |
+
|
| 137 |
+
This project is licensed under the SAM License - see the [LICENSE](LICENSE) file for details.
|
assets/sam_audio_main_model.png
ADDED
|
Git LFS Details
|
eval/README.md
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Evaluation
|
| 2 |
+
|
| 3 |
+
This directory contains the evaluation code to reproduce the results from the SAM-Audio paper. The evaluation framework supports multiple datasets, prompting modes (text-only, span, visual), and metrics.
|
| 4 |
+
|
| 5 |
+
## Setup
|
| 6 |
+
|
| 7 |
+
Before running evaluation, ensure you have:
|
| 8 |
+
|
| 9 |
+
1. Installed the SAM-Audio package and its dependencies
|
| 10 |
+
2. Authenticated with Hugging Face to access the model checkpoints (see main [README](../README.md))
|
| 11 |
+
|
| 12 |
+
## Quick Start
|
| 13 |
+
|
| 14 |
+
Run evaluation on the default setting (instr-pro):
|
| 15 |
+
|
| 16 |
+
```bash
|
| 17 |
+
python main.py
|
| 18 |
+
```
|
| 19 |
+
|
| 20 |
+
You can also use multiple GPUs to speed up evaluation:
|
| 21 |
+
|
| 22 |
+
```bash
|
| 23 |
+
torchrun --nproc_per_node=<ngpus> python main.py
|
| 24 |
+
```
|
| 25 |
+
|
| 26 |
+
Evaluate on a specific setting:
|
| 27 |
+
|
| 28 |
+
```bash
|
| 29 |
+
python main.py --setting sfx
|
| 30 |
+
```
|
| 31 |
+
|
| 32 |
+
Evaluate on multiple settings:
|
| 33 |
+
|
| 34 |
+
```bash
|
| 35 |
+
python main.py --setting sfx speech music
|
| 36 |
+
```
|
| 37 |
+
|
| 38 |
+
## Available Evaluation Settings
|
| 39 |
+
|
| 40 |
+
Run `python main.py --help` to see all available settings
|
| 41 |
+
|
| 42 |
+
## Command Line Options
|
| 43 |
+
|
| 44 |
+
```bash
|
| 45 |
+
python main.py [OPTIONS]
|
| 46 |
+
```
|
| 47 |
+
|
| 48 |
+
### Options:
|
| 49 |
+
|
| 50 |
+
- `-s, --setting` - Which setting(s) to evaluate (default: `instr-pro`)
|
| 51 |
+
- Choices: See available settings above
|
| 52 |
+
- Can specify multiple settings: `--setting sfx speech music`
|
| 53 |
+
|
| 54 |
+
- `--cache-path` - Where to cache downloaded datasets (default: `~/.cache/sam_audio`)
|
| 55 |
+
|
| 56 |
+
- `-p, --checkpoint-path` - Model checkpoint to evaluate (default: `facebook/sam-audio-1b`)
|
| 57 |
+
- Can use local path or Hugging Face model ID
|
| 58 |
+
|
| 59 |
+
- `-b, --batch-size` - Batch size for evaluation (default: `1`)
|
| 60 |
+
|
| 61 |
+
- `-w, --num-workers` - Number of data loading workers (default: `4`)
|
| 62 |
+
|
| 63 |
+
- `-c, --candidates` - Number of reranking candidates (default: `8`)
|
| 64 |
+
|
| 65 |
+
## Evaluation Metrics
|
| 66 |
+
|
| 67 |
+
The evaluation framework computes the following metrics:
|
| 68 |
+
|
| 69 |
+
- **Judge** - SAM Audio Judge quality assessment metric
|
| 70 |
+
- **Aesthetic** - Aesthetic quality metric
|
| 71 |
+
- **CLAP** - Audio-text alignment metric (CLAP similarity)
|
| 72 |
+
- **ImageBind** - Audio-video alignment metric (for visual settings only)
|
| 73 |
+
|
| 74 |
+
## Output
|
| 75 |
+
|
| 76 |
+
Results are saved to the `results/` directory as JSON files, one per setting:
|
| 77 |
+
|
| 78 |
+
```
|
| 79 |
+
results/
|
| 80 |
+
├── sfx.json
|
| 81 |
+
├── speech.json
|
| 82 |
+
└── music.json
|
| 83 |
+
```
|
| 84 |
+
|
| 85 |
+
Each JSON file contains the averaged metric scores across all samples in that setting.
|
| 86 |
+
|
| 87 |
+
Example output:
|
| 88 |
+
```json
|
| 89 |
+
{
|
| 90 |
+
"JudgeOverall": "4.386",
|
| 91 |
+
"JudgeFaithfulness": "4.708",
|
| 92 |
+
"JudgeRecall": "4.934",
|
| 93 |
+
"JudgePrecision": "4.451",
|
| 94 |
+
"ContentEnjoyment": "5.296",
|
| 95 |
+
"ContentUsefulness": "6.903",
|
| 96 |
+
"ProductionComplexity": "4.301",
|
| 97 |
+
"ProductionQuality": "7.100",
|
| 98 |
+
"CLAPSimilarity": "0.271"
|
| 99 |
+
}
|
| 100 |
+
```
|
eval/dataset/__init__.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n
|
| 2 |
+
|
| 3 |
+
from typing import Callable
|
| 4 |
+
|
| 5 |
+
from .musdb import MUSDB
|
| 6 |
+
from .sam_audio_bench import SAMAudioBench
|
| 7 |
+
|
| 8 |
+
SETTINGS = {
|
| 9 |
+
# Text-only settings
|
| 10 |
+
"sfx": (
|
| 11 |
+
SAMAudioBench,
|
| 12 |
+
{"span": False, "visual": False, "subset": "others-50:text-only"},
|
| 13 |
+
),
|
| 14 |
+
"speech": (
|
| 15 |
+
SAMAudioBench,
|
| 16 |
+
{"span": False, "visual": False, "subset": "speech-clean-50:text-only"},
|
| 17 |
+
),
|
| 18 |
+
"speaker": (
|
| 19 |
+
SAMAudioBench,
|
| 20 |
+
{"span": False, "visual": False, "subset": "spk-50:text-only"},
|
| 21 |
+
),
|
| 22 |
+
"music": (
|
| 23 |
+
SAMAudioBench,
|
| 24 |
+
{"span": False, "visual": False, "subset": "music-clean-50:text-only"},
|
| 25 |
+
),
|
| 26 |
+
"instr-wild": (
|
| 27 |
+
SAMAudioBench,
|
| 28 |
+
{"span": False, "visual": False, "subset": "instr-50:text-only"},
|
| 29 |
+
),
|
| 30 |
+
"instr-pro": (MUSDB, {}),
|
| 31 |
+
# Span settings
|
| 32 |
+
"sfx-span": (
|
| 33 |
+
SAMAudioBench,
|
| 34 |
+
{"span": True, "visual": False, "subset": "others-50:text+span"},
|
| 35 |
+
),
|
| 36 |
+
"speech-span": (
|
| 37 |
+
SAMAudioBench,
|
| 38 |
+
{"span": True, "visual": False, "subset": "speech-clean-50:text+span"},
|
| 39 |
+
),
|
| 40 |
+
"speaker-span": (
|
| 41 |
+
SAMAudioBench,
|
| 42 |
+
{"span": True, "visual": False, "subset": "spk-50:text+span"},
|
| 43 |
+
),
|
| 44 |
+
"music-span": (
|
| 45 |
+
SAMAudioBench,
|
| 46 |
+
{"span": True, "visual": False, "subset": "music-clean-50:text+span"},
|
| 47 |
+
),
|
| 48 |
+
"instr-wild-span": (
|
| 49 |
+
SAMAudioBench,
|
| 50 |
+
{"span": True, "visual": False, "subset": "instr-50:text+span"},
|
| 51 |
+
),
|
| 52 |
+
# Visual settings
|
| 53 |
+
"sfx-visual": (
|
| 54 |
+
SAMAudioBench,
|
| 55 |
+
{"span": False, "visual": True, "subset": "others-onscreen-50:visual-only"},
|
| 56 |
+
),
|
| 57 |
+
"speaker-visual": (
|
| 58 |
+
SAMAudioBench,
|
| 59 |
+
{"span": False, "visual": True, "subset": "spk-onscreen-50:visual-only"},
|
| 60 |
+
),
|
| 61 |
+
"instr-wild-visual": (
|
| 62 |
+
SAMAudioBench,
|
| 63 |
+
{"span": False, "visual": True, "subset": "instr-onscreen-50:visual-only"},
|
| 64 |
+
),
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def make_dataset(setting: str, cache_path: str, collate_fn: Callable):
|
| 69 |
+
dataset, kwargs = SETTINGS[setting]
|
| 70 |
+
return dataset(cache_path=cache_path, collate_fn=collate_fn, **kwargs)
|
eval/dataset/musdb.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
from subprocess import check_call
|
| 5 |
+
|
| 6 |
+
import torchaudio
|
| 7 |
+
from datasets import load_dataset
|
| 8 |
+
from torch.utils.data import Dataset
|
| 9 |
+
from torchcodec.decoders import AudioDecoder
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def cache_file(url, outfile):
|
| 13 |
+
if not os.path.exists(outfile):
|
| 14 |
+
print("Downloading musdb18hq dataset...")
|
| 15 |
+
os.makedirs(os.path.dirname(outfile), exist_ok=True)
|
| 16 |
+
check_call(["curl", "--url", url, "--output", outfile + ".tmp"])
|
| 17 |
+
os.rename(outfile + ".tmp", outfile)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class MUSDB(Dataset):
|
| 21 |
+
def __init__(
|
| 22 |
+
self,
|
| 23 |
+
collate_fn,
|
| 24 |
+
sample_rate: int = 48_000,
|
| 25 |
+
cache_path: str = os.path.expanduser("~/.cache/sam_audio"),
|
| 26 |
+
):
|
| 27 |
+
self.cache_path = os.path.join(cache_path, "musdb18hq")
|
| 28 |
+
self.ds = self.get_dataset(cache_path)
|
| 29 |
+
self.captions = ["bass", "drums", "vocals"]
|
| 30 |
+
self.collate_fn = collate_fn
|
| 31 |
+
self.sample_rate = sample_rate
|
| 32 |
+
|
| 33 |
+
@property
|
| 34 |
+
def visual(self):
|
| 35 |
+
return False
|
| 36 |
+
|
| 37 |
+
def get_dataset(self, cache_path):
|
| 38 |
+
zip_file = os.path.join(cache_path, "musdb18hq.zip")
|
| 39 |
+
url = "https://zenodo.org/records/3338373/files/musdb18hq.zip?download=1"
|
| 40 |
+
cache_file(url, zip_file)
|
| 41 |
+
extracted_dir = os.path.join(cache_path, "musdb18hq")
|
| 42 |
+
if not os.path.exists(extracted_dir):
|
| 43 |
+
check_call(["unzip", zip_file, "-d", extracted_dir + ".tmp"])
|
| 44 |
+
os.rename(extracted_dir + ".tmp", extracted_dir)
|
| 45 |
+
return load_dataset("facebook/sam-audio-musdb18hq-test")["test"]
|
| 46 |
+
|
| 47 |
+
def __len__(self):
|
| 48 |
+
return len(self.ds)
|
| 49 |
+
|
| 50 |
+
def collate(self, items):
|
| 51 |
+
audios, descriptions = zip(*items, strict=False)
|
| 52 |
+
return self.collate_fn(
|
| 53 |
+
audios=audios,
|
| 54 |
+
descriptions=descriptions,
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
def __getitem__(self, idx):
|
| 58 |
+
item = self.ds[idx]
|
| 59 |
+
path = os.path.join(self.cache_path, "test", item["id"], "mixture.wav")
|
| 60 |
+
assert os.path.exists(path), f"{path} does not exist!"
|
| 61 |
+
decoder = AudioDecoder(path)
|
| 62 |
+
data = decoder.get_samples_played_in_range(item["start_time"], item["end_time"])
|
| 63 |
+
wav = data.data
|
| 64 |
+
if data.sample_rate != self.sample_rate:
|
| 65 |
+
wav = torchaudio.functional.resample(
|
| 66 |
+
wav, data.sample_rate, self.sample_rate
|
| 67 |
+
)
|
| 68 |
+
wav = wav.mean(0, keepdim=True)
|
| 69 |
+
return wav, item["description"]
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
if __name__ == "__main__":
|
| 73 |
+
dataset = MUSDB(lambda **kwargs: None)
|
| 74 |
+
print(len(dataset))
|
| 75 |
+
print(dataset[0])
|
eval/dataset/sam_audio_bench.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from io import BytesIO
|
| 6 |
+
from typing import Optional, Tuple
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
import torchaudio
|
| 12 |
+
from datasets import load_dataset
|
| 13 |
+
from torchcodec.decoders import AudioDecoder, VideoDecoder
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@dataclass
|
| 17 |
+
class Item:
|
| 18 |
+
anchors: list[Tuple[str, float, float]]
|
| 19 |
+
masked_video_frames: torch.Tensor
|
| 20 |
+
audio_samples: torch.Tensor
|
| 21 |
+
description: str
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class SAMAudioBench(torch.utils.data.Dataset):
|
| 25 |
+
def __init__(
|
| 26 |
+
self,
|
| 27 |
+
cache_path,
|
| 28 |
+
collate_fn,
|
| 29 |
+
span: bool = True,
|
| 30 |
+
visual: bool = True,
|
| 31 |
+
subset: Optional[str] = None,
|
| 32 |
+
):
|
| 33 |
+
self.dataset = load_dataset("facebook/sam-audio-bench")["test"]
|
| 34 |
+
self.subset = subset
|
| 35 |
+
self._span = span
|
| 36 |
+
self._visual = visual
|
| 37 |
+
if subset is not None:
|
| 38 |
+
self.dataset = self.dataset.filter(lambda x: subset in x["paper_eval_sets"])
|
| 39 |
+
|
| 40 |
+
self.cache_path = os.path.join(cache_path, "sam_audio_bench")
|
| 41 |
+
self.collate_fn = collate_fn
|
| 42 |
+
DATA_MSG = (
|
| 43 |
+
f"`SAMAudioBench` requires the user to create a directory named {self.cache_path} "
|
| 44 |
+
"see the README.md file for how to prepare"
|
| 45 |
+
)
|
| 46 |
+
assert os.path.exists(self.cache_path), DATA_MSG
|
| 47 |
+
|
| 48 |
+
@property
|
| 49 |
+
def visual(self):
|
| 50 |
+
return self._visual
|
| 51 |
+
|
| 52 |
+
def __len__(self):
|
| 53 |
+
return len(self.dataset)
|
| 54 |
+
|
| 55 |
+
def _get_path(
|
| 56 |
+
self, video_id: str, source_dataset: str, start_offset: float, end_offset: float
|
| 57 |
+
) -> str:
|
| 58 |
+
path = f"{self.cache_path}/{source_dataset}/{video_id}.mp4"
|
| 59 |
+
select_frames = True
|
| 60 |
+
|
| 61 |
+
if not os.path.exists(path):
|
| 62 |
+
path = f"{self.cache_path}/{source_dataset}/{video_id}_{int(start_offset * 1000)}_{int(end_offset * 1000)}.mp4"
|
| 63 |
+
select_frames = False
|
| 64 |
+
|
| 65 |
+
if not os.path.exists(path):
|
| 66 |
+
path = f"{self.cache_path}/{source_dataset}/{video_id}_{int(start_offset)}_{int(end_offset)}.mp4"
|
| 67 |
+
|
| 68 |
+
if not os.path.exists(path):
|
| 69 |
+
path = f"{self.cache_path}/{source_dataset}/{video_id}.{int(start_offset * 1000):08d}_{int(end_offset * 1000):08d}.mp4"
|
| 70 |
+
|
| 71 |
+
return path, select_frames
|
| 72 |
+
|
| 73 |
+
def collate(self, items: list[Item]):
|
| 74 |
+
has_video = any(item.masked_video_frames is not None for item in items)
|
| 75 |
+
return self.collate_fn(
|
| 76 |
+
descriptions=[item.description for item in items],
|
| 77 |
+
audios=[item.audio_samples for item in items],
|
| 78 |
+
anchors=[item.anchors for item in items] if self._span else None,
|
| 79 |
+
masked_videos=[item.masked_video_frames for item in items]
|
| 80 |
+
if has_video and self._visual
|
| 81 |
+
else None,
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
def _get_masked_video(self, item, video_path, select_frames):
|
| 85 |
+
if item["mask_bytes"] is None:
|
| 86 |
+
return None
|
| 87 |
+
|
| 88 |
+
mask = torch.from_numpy(np.load(BytesIO(item["mask_bytes"]))["video_masklet"])
|
| 89 |
+
|
| 90 |
+
video_decoder = VideoDecoder(video_path)
|
| 91 |
+
if select_frames:
|
| 92 |
+
video_frames = video_decoder.get_frames_played_in_range(
|
| 93 |
+
item["start_offset"], item["end_offset"]
|
| 94 |
+
).data
|
| 95 |
+
else:
|
| 96 |
+
video_frames = video_decoder[:].data
|
| 97 |
+
|
| 98 |
+
if mask.size(0) != video_frames.size(0):
|
| 99 |
+
# It's possible that the mask and the video frames differ by a small amount
|
| 100 |
+
# we interpolate the mask frame to match
|
| 101 |
+
idxs = (
|
| 102 |
+
torch.linspace(0, mask.size(0) - 1, video_frames.size(0)).round().long()
|
| 103 |
+
)
|
| 104 |
+
mask = mask[idxs]
|
| 105 |
+
|
| 106 |
+
mask = mask.unsqueeze(1)
|
| 107 |
+
|
| 108 |
+
if mask.shape[-2:] != video_frames.shape[-2:]:
|
| 109 |
+
mask = F.interpolate(mask, size=video_frames.shape[-2:])
|
| 110 |
+
|
| 111 |
+
import torchvision
|
| 112 |
+
|
| 113 |
+
torchvision.io.write_video("test.mp4", video_frames.permute(0, 2, 3, 1), 30)
|
| 114 |
+
torchvision.io.write_video(
|
| 115 |
+
"test_mask.mp4", mask.unsqueeze(-1).expand(-1, -1, -1, 3) * 255, 30
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
return video_frames * mask
|
| 119 |
+
|
| 120 |
+
def __getitem__(self, idx) -> Item:
|
| 121 |
+
item = self.dataset[idx]
|
| 122 |
+
|
| 123 |
+
video_path, select_frames = self._get_path(
|
| 124 |
+
item["video_id"],
|
| 125 |
+
item["source_dataset"],
|
| 126 |
+
item["start_offset"],
|
| 127 |
+
item["end_offset"],
|
| 128 |
+
)
|
| 129 |
+
assert os.path.exists(video_path), f"{video_path} does not exist!"
|
| 130 |
+
|
| 131 |
+
audio_decoder = AudioDecoder(video_path)
|
| 132 |
+
audio_samples = audio_decoder.get_samples_played_in_range(
|
| 133 |
+
start_seconds=item["start_offset"] if select_frames else 0,
|
| 134 |
+
stop_seconds=item["end_offset"] if select_frames else None,
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
if audio_samples.sample_rate != self.collate_fn.audio_sampling_rate:
|
| 138 |
+
resampled_audio = torchaudio.functional.resample(
|
| 139 |
+
audio_samples.data,
|
| 140 |
+
audio_samples.sample_rate,
|
| 141 |
+
self.collate_fn.audio_sampling_rate,
|
| 142 |
+
)
|
| 143 |
+
else:
|
| 144 |
+
resampled_audio = audio_samples.data
|
| 145 |
+
|
| 146 |
+
masked_video_frames = self._get_masked_video(item, video_path, select_frames)
|
| 147 |
+
|
| 148 |
+
return Item(
|
| 149 |
+
description=item["description"],
|
| 150 |
+
anchors=[("+", start, end) for start, end in item["spans"]],
|
| 151 |
+
masked_video_frames=masked_video_frames,
|
| 152 |
+
audio_samples=resampled_audio.mean(0, keepdim=True),
|
| 153 |
+
)
|
eval/main.py
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import json
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
import pandas as pd
|
| 8 |
+
import torch
|
| 9 |
+
import torch.distributed as dist
|
| 10 |
+
from dataset import SETTINGS, make_dataset
|
| 11 |
+
from metrics import CLAP, Aesthetic, ImageBind, Judge
|
| 12 |
+
from torch.utils.data import DataLoader
|
| 13 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 14 |
+
from tqdm import tqdm
|
| 15 |
+
|
| 16 |
+
from sam_audio import SAMAudio, SAMAudioProcessor
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def gather_and_average_results(results, world_size):
|
| 20 |
+
if world_size == 1:
|
| 21 |
+
return json.loads(results.mean().to_json())
|
| 22 |
+
|
| 23 |
+
# 1. Gather all dictionaries to all ranks
|
| 24 |
+
all_results = [None for _ in range(world_size)]
|
| 25 |
+
dist.all_gather_object(
|
| 26 |
+
all_results, {"sum": results.sum().to_json(), "count": len(results)}
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
summed = {}
|
| 30 |
+
counts = 0
|
| 31 |
+
|
| 32 |
+
for res in all_results:
|
| 33 |
+
for k, v in json.loads(res["sum"]).items():
|
| 34 |
+
if k not in summed:
|
| 35 |
+
summed[k] = 0.0
|
| 36 |
+
summed[k] += v
|
| 37 |
+
counts += res["count"]
|
| 38 |
+
|
| 39 |
+
# 3. Compute average for keys that appeared at least once
|
| 40 |
+
averaged = {k: summed[k] / counts for k in summed}
|
| 41 |
+
|
| 42 |
+
return averaged
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def main(
|
| 46 |
+
settings: list[str],
|
| 47 |
+
cache_path: str,
|
| 48 |
+
batch_size: int,
|
| 49 |
+
checkpoint_path: str,
|
| 50 |
+
num_workers: int = 4,
|
| 51 |
+
reranking_candidates: int = 8,
|
| 52 |
+
):
|
| 53 |
+
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
| 54 |
+
rank = int(os.environ.get("RANK", 0))
|
| 55 |
+
|
| 56 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 57 |
+
if world_size > 1:
|
| 58 |
+
torch.distributed.init_process_group(backend="nccl")
|
| 59 |
+
device = torch.device(f"cuda:{rank}")
|
| 60 |
+
torch.cuda.set_device(device)
|
| 61 |
+
|
| 62 |
+
model = SAMAudio.from_pretrained(checkpoint_path)
|
| 63 |
+
model = model.eval().to(device)
|
| 64 |
+
processor = SAMAudioProcessor.from_pretrained(checkpoint_path)
|
| 65 |
+
|
| 66 |
+
judge_metric = Judge(device=device)
|
| 67 |
+
aes_metric = Aesthetic(device=device)
|
| 68 |
+
clap_metric = CLAP(device=device)
|
| 69 |
+
imagebind_metric = ImageBind(device=device)
|
| 70 |
+
|
| 71 |
+
for setting in settings:
|
| 72 |
+
print(f"Evaluating: {setting}")
|
| 73 |
+
dset = make_dataset(setting, cache_path=cache_path, collate_fn=processor)
|
| 74 |
+
sampler = None
|
| 75 |
+
if world_size > 1:
|
| 76 |
+
sampler = DistributedSampler(dset)
|
| 77 |
+
|
| 78 |
+
dl = DataLoader(
|
| 79 |
+
dset,
|
| 80 |
+
batch_size=batch_size,
|
| 81 |
+
shuffle=False,
|
| 82 |
+
collate_fn=dset.collate,
|
| 83 |
+
num_workers=num_workers,
|
| 84 |
+
sampler=sampler,
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
all_metrics = [
|
| 88 |
+
judge_metric,
|
| 89 |
+
aes_metric,
|
| 90 |
+
clap_metric,
|
| 91 |
+
]
|
| 92 |
+
|
| 93 |
+
if dset.visual:
|
| 94 |
+
all_metrics.append(imagebind_metric)
|
| 95 |
+
|
| 96 |
+
dfs = []
|
| 97 |
+
with torch.inference_mode():
|
| 98 |
+
for batch in tqdm(dl, disable=rank > 1):
|
| 99 |
+
batch = batch.to(device)
|
| 100 |
+
result = model.separate(
|
| 101 |
+
batch, reranking_candidates=reranking_candidates
|
| 102 |
+
)
|
| 103 |
+
mets = {}
|
| 104 |
+
for metric in all_metrics:
|
| 105 |
+
input_wavs = model.unbatch(batch.audios.squeeze(1), batch.wav_sizes)
|
| 106 |
+
|
| 107 |
+
mets.update(
|
| 108 |
+
metric(
|
| 109 |
+
target_wavs=result.target,
|
| 110 |
+
target_wavs_sample_rate=model.sample_rate,
|
| 111 |
+
descriptions=batch.descriptions,
|
| 112 |
+
input_wavs=input_wavs,
|
| 113 |
+
videos=batch.masked_video,
|
| 114 |
+
)
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
dfs.append(pd.DataFrame.from_dict(mets))
|
| 118 |
+
|
| 119 |
+
df = pd.concat(dfs)
|
| 120 |
+
averaged_results = gather_and_average_results(df, world_size)
|
| 121 |
+
if rank == 0:
|
| 122 |
+
results_dict = {k: f"{v:.3f}" for k, v in averaged_results.items()}
|
| 123 |
+
print(json.dumps(results_dict, indent=4))
|
| 124 |
+
os.makedirs("results", exist_ok=True)
|
| 125 |
+
outfile = f"results/{setting}.json"
|
| 126 |
+
with open(outfile, "w") as fout:
|
| 127 |
+
print(json.dumps(results_dict), file=fout)
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
if __name__ == "__main__":
|
| 131 |
+
parser = argparse.ArgumentParser()
|
| 132 |
+
parser.add_argument(
|
| 133 |
+
"--setting",
|
| 134 |
+
"-s",
|
| 135 |
+
choices=SETTINGS.keys(),
|
| 136 |
+
help=f"Which setting to evaluate. Choices: {SETTINGS.keys()}",
|
| 137 |
+
default=["instr-pro"],
|
| 138 |
+
nargs="+",
|
| 139 |
+
)
|
| 140 |
+
parser.add_argument(
|
| 141 |
+
"--cache-path",
|
| 142 |
+
type=str,
|
| 143 |
+
default=os.path.expanduser("~/.cache/sam_audio"),
|
| 144 |
+
help="Where to cache downloaded datasets",
|
| 145 |
+
)
|
| 146 |
+
parser.add_argument(
|
| 147 |
+
"--checkpoint-path", "-p", type=str, default="facebook/sam-audio-large"
|
| 148 |
+
)
|
| 149 |
+
parser.add_argument("--batch-size", "-b", type=int, default=1, help="Batch size")
|
| 150 |
+
parser.add_argument(
|
| 151 |
+
"--num-workers", "-w", type=int, default=4, help="Number of workers"
|
| 152 |
+
)
|
| 153 |
+
parser.add_argument("--candidates", "-c", type=int, default=8)
|
| 154 |
+
opt = parser.parse_args()
|
| 155 |
+
main(
|
| 156 |
+
settings=opt.setting,
|
| 157 |
+
cache_path=opt.cache_path,
|
| 158 |
+
batch_size=opt.batch_size,
|
| 159 |
+
checkpoint_path=opt.checkpoint_path,
|
| 160 |
+
num_workers=opt.num_workers,
|
| 161 |
+
reranking_candidates=opt.candidates,
|
| 162 |
+
)
|
eval/metrics/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n
|
| 2 |
+
|
| 3 |
+
from metrics.aes import Aesthetic
|
| 4 |
+
from metrics.clap import CLAP
|
| 5 |
+
from metrics.imagebind import ImageBind
|
| 6 |
+
from metrics.judge import Judge
|
| 7 |
+
|
| 8 |
+
__all__ = [
|
| 9 |
+
"Aesthetic",
|
| 10 |
+
"CLAP",
|
| 11 |
+
"ImageBind",
|
| 12 |
+
"Judge",
|
| 13 |
+
]
|
eval/metrics/aes.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n
|
| 2 |
+
|
| 3 |
+
from typing import Optional
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from audiobox_aesthetics.infer import AesPredictor
|
| 7 |
+
|
| 8 |
+
COLUMN_MAP = {
|
| 9 |
+
"CE": "ContentEnjoyment",
|
| 10 |
+
"CU": "ContentUsefulness",
|
| 11 |
+
"PC": "ProductionComplexity",
|
| 12 |
+
"PQ": "ProductionQuality",
|
| 13 |
+
}
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class Aesthetic(torch.nn.Module):
|
| 17 |
+
def __init__(
|
| 18 |
+
self,
|
| 19 |
+
checkpoint: Optional[str] = None,
|
| 20 |
+
device: Optional[torch.device] = None,
|
| 21 |
+
):
|
| 22 |
+
super().__init__()
|
| 23 |
+
self.model = AesPredictor(
|
| 24 |
+
checkpoint_pth=checkpoint,
|
| 25 |
+
data_col="wav",
|
| 26 |
+
)
|
| 27 |
+
self.device = device or torch.device(
|
| 28 |
+
"cuda" if torch.cuda.is_available() else "cpu"
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
def __call__(
|
| 32 |
+
self,
|
| 33 |
+
target_wavs: list[torch.Tensor],
|
| 34 |
+
target_wavs_sample_rate: int = 48_000,
|
| 35 |
+
**kwargs,
|
| 36 |
+
) -> dict[str, list[float]]:
|
| 37 |
+
result = self.model.forward(
|
| 38 |
+
[
|
| 39 |
+
{
|
| 40 |
+
"wav": wav[None] if wav.ndim == 1 else wav,
|
| 41 |
+
"sample_rate": target_wavs_sample_rate,
|
| 42 |
+
}
|
| 43 |
+
for wav in target_wavs
|
| 44 |
+
]
|
| 45 |
+
)
|
| 46 |
+
return {
|
| 47 |
+
long_name: [x[shortname] for x in result]
|
| 48 |
+
for shortname, long_name in COLUMN_MAP.items()
|
| 49 |
+
}
|
eval/metrics/clap.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n
|
| 2 |
+
|
| 3 |
+
from tempfile import TemporaryDirectory
|
| 4 |
+
from typing import Optional
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from torchcodec.encoders import AudioEncoder
|
| 8 |
+
|
| 9 |
+
from sam_audio.ranking.clap import get_model
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class CLAP(torch.nn.Module):
|
| 13 |
+
def __init__(
|
| 14 |
+
self,
|
| 15 |
+
checkpoint: Optional[str] = None,
|
| 16 |
+
device: Optional[torch.device] = None,
|
| 17 |
+
):
|
| 18 |
+
super().__init__()
|
| 19 |
+
self.model = get_model(device)
|
| 20 |
+
self.device = device or torch.device(
|
| 21 |
+
"cuda" if torch.cuda.is_available() else "cpu"
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
def __call__(
|
| 25 |
+
self,
|
| 26 |
+
target_wavs: list[torch.Tensor],
|
| 27 |
+
descriptions: list[str],
|
| 28 |
+
target_wavs_sample_rate: int = 48_000,
|
| 29 |
+
**kwargs,
|
| 30 |
+
) -> list[dict[str, float]]:
|
| 31 |
+
with TemporaryDirectory() as tdir, torch.inference_mode():
|
| 32 |
+
file_list = []
|
| 33 |
+
for i, wav in enumerate(target_wavs):
|
| 34 |
+
file_list.append(f"{tdir}/hyp_{i}.wav")
|
| 35 |
+
encoder = AudioEncoder(
|
| 36 |
+
samples=wav.cpu()[None] if wav.ndim == 1 else wav.cpu(),
|
| 37 |
+
sample_rate=target_wavs_sample_rate,
|
| 38 |
+
)
|
| 39 |
+
encoder.to_file(file_list[-1])
|
| 40 |
+
audio_embs = self.model.get_audio_embedding_from_filelist(
|
| 41 |
+
file_list, use_tensor=True
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
text_embs = self.model.get_text_embedding(descriptions, use_tensor=True)
|
| 45 |
+
sims = audio_embs.unsqueeze(1) @ text_embs.unsqueeze(2)
|
| 46 |
+
return {"CLAPSimilarity": sims.cpu()[:, 0, 0].tolist()}
|
eval/metrics/imagebind.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n
|
| 2 |
+
|
| 3 |
+
from typing import Optional
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from imagebind.models.imagebind_model import ModalityType, imagebind_huge
|
| 7 |
+
|
| 8 |
+
from sam_audio.ranking.imagebind import VideoTransform, load_and_transform_audio_data
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class ImageBind(torch.nn.Module):
|
| 12 |
+
def __init__(
|
| 13 |
+
self,
|
| 14 |
+
checkpoint: Optional[str] = None,
|
| 15 |
+
device: Optional[torch.device] = None,
|
| 16 |
+
):
|
| 17 |
+
super().__init__()
|
| 18 |
+
|
| 19 |
+
self.model = imagebind_huge(pretrained=checkpoint is None)
|
| 20 |
+
if checkpoint is not None:
|
| 21 |
+
self.model.load_state_dict(torch.load(checkpoint, map_location="cpu"))
|
| 22 |
+
self.model = self.model.eval()
|
| 23 |
+
self.video_transform = VideoTransform()
|
| 24 |
+
self.device = device or torch.device(
|
| 25 |
+
"cuda" if torch.cuda.is_available() else "cpu"
|
| 26 |
+
)
|
| 27 |
+
self.model = self.model.to(self.device)
|
| 28 |
+
|
| 29 |
+
def __call__(
|
| 30 |
+
self,
|
| 31 |
+
target_wavs: list[torch.Tensor],
|
| 32 |
+
videos: list[torch.Tensor],
|
| 33 |
+
target_wavs_sample_rate: int = 48_000,
|
| 34 |
+
**kwargs,
|
| 35 |
+
) -> dict[str, list[float]]:
|
| 36 |
+
audio_data = load_and_transform_audio_data(
|
| 37 |
+
target_wavs, input_sample_rate=target_wavs_sample_rate
|
| 38 |
+
)
|
| 39 |
+
durations = [x.size(-1) / target_wavs_sample_rate for x in target_wavs]
|
| 40 |
+
video_data = self.video_transform(videos, durations, audio_data.device)
|
| 41 |
+
|
| 42 |
+
inputs = {ModalityType.AUDIO: audio_data, ModalityType.VISION: video_data}
|
| 43 |
+
embs = self.model(inputs)
|
| 44 |
+
audio_embs, video_embs = embs[ModalityType.AUDIO], embs[ModalityType.VISION]
|
| 45 |
+
audio_embs, video_embs = (
|
| 46 |
+
audio_embs / ((audio_embs**2).sum(dim=-1, keepdims=True) ** 0.5),
|
| 47 |
+
video_embs / ((video_embs**2).sum(dim=-1, keepdims=True) ** 0.5),
|
| 48 |
+
)
|
| 49 |
+
bsz = len(target_wavs)
|
| 50 |
+
candidates = len(audio_embs) // bsz
|
| 51 |
+
scores = audio_embs.view(bsz, candidates, -1) @ video_embs.view(bsz, -1, 1)
|
| 52 |
+
return {"ImageBind": scores.squeeze(1, 2).cpu().tolist()}
|
eval/metrics/judge.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n
|
| 2 |
+
|
| 3 |
+
from typing import Optional
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
from sam_audio import SAMAudioJudgeModel, SAMAudioJudgeProcessor
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class Judge(torch.nn.Module):
|
| 11 |
+
def __init__(
|
| 12 |
+
self,
|
| 13 |
+
checkpoint: str = "facebook/sam-audio-judge",
|
| 14 |
+
device: Optional[torch.device] = None,
|
| 15 |
+
):
|
| 16 |
+
super().__init__()
|
| 17 |
+
self.model = SAMAudioJudgeModel.from_pretrained(checkpoint).to(device)
|
| 18 |
+
self.processor = SAMAudioJudgeProcessor.from_pretrained(checkpoint)
|
| 19 |
+
self.device = device or torch.device(
|
| 20 |
+
"cuda" if torch.cuda.is_available() else "cpu"
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
def forward(
|
| 24 |
+
self,
|
| 25 |
+
input_wavs: list[torch.Tensor],
|
| 26 |
+
target_wavs: list[torch.Tensor],
|
| 27 |
+
descriptions: list[str],
|
| 28 |
+
target_wavs_sample_rate: int = 48_000,
|
| 29 |
+
**kwargs,
|
| 30 |
+
) -> torch.Tensor:
|
| 31 |
+
with torch.inference_mode():
|
| 32 |
+
processed = self.processor(
|
| 33 |
+
text=descriptions,
|
| 34 |
+
input_audio=[x.cpu() for x in input_wavs],
|
| 35 |
+
separated_audio=[x.cpu() for x in target_wavs],
|
| 36 |
+
sampling_rate=target_wavs_sample_rate,
|
| 37 |
+
).to(self.device)
|
| 38 |
+
result = self.model(**processed)
|
| 39 |
+
return {
|
| 40 |
+
"JudgeOverall": result.overall.squeeze(-1).cpu().tolist(),
|
| 41 |
+
"JudgeFaithfulness": result.faithfulness.squeeze(-1).cpu().tolist(),
|
| 42 |
+
"JudgeRecall": result.recall.squeeze(-1).cpu().tolist(),
|
| 43 |
+
"JudgePrecision": result.precision.squeeze(-1).cpu().tolist(),
|
| 44 |
+
}
|
examples/assets/office.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c0f583ff34c5fd9d1a83d640e7c0131ad339755bd69e54f104723b707f213c21
|
| 3 |
+
size 4551702
|
examples/span_prompting.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
examples/text_prompting.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
examples/visual_prompting.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
pyproject.toml
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "sam_audio"
|
| 3 |
+
version = "0.1.0"
|
| 4 |
+
description = "Segment Anything Audio"
|
| 5 |
+
authors = [
|
| 6 |
+
{ name="Andros Tjandra", email="androstj@meta.com" },
|
| 7 |
+
{ name="Ann Lee", email="annl@meta.com" },
|
| 8 |
+
{ name="Bowen Shi", email="bshi@meta.com" },
|
| 9 |
+
{ name="Julius Richter", email="jrichter@meta.com" },
|
| 10 |
+
{ name="Matt Le", email="mattle@meta.com" },
|
| 11 |
+
{ name="Yi-Chiao Wu", email="yichiaowu@meta.com" },
|
| 12 |
+
]
|
| 13 |
+
|
| 14 |
+
readme = "README.md"
|
| 15 |
+
license = { file="LICENSE" }
|
| 16 |
+
requires-python = ">=3.10"
|
| 17 |
+
dependencies = [
|
| 18 |
+
"dacvae@git+https://github.com/facebookresearch/dacvae.git",
|
| 19 |
+
"audiobox_aesthetics",
|
| 20 |
+
"einops",
|
| 21 |
+
"imagebind@git+https://github.com/facebookresearch/ImageBind.git",
|
| 22 |
+
"laion-clap@git+https://github.com/lematt1991/CLAP.git",
|
| 23 |
+
"numpy",
|
| 24 |
+
"perception-models@git+https://github.com/facebookresearch/perception_models@unpin-deps",
|
| 25 |
+
"pydub",
|
| 26 |
+
"torch",
|
| 27 |
+
"torchaudio",
|
| 28 |
+
"torchcodec",
|
| 29 |
+
"torchdiffeq",
|
| 30 |
+
"torchvision",
|
| 31 |
+
"transformers>=4.54.0",
|
| 32 |
+
]
|
| 33 |
+
|
| 34 |
+
[tool.setuptools.packages.find]
|
| 35 |
+
include = ["sam_audio*"]
|
| 36 |
+
|
| 37 |
+
[tool.ruff]
|
| 38 |
+
|
| 39 |
+
target-version = "py310"
|
| 40 |
+
|
| 41 |
+
lint.select=[
|
| 42 |
+
"B",
|
| 43 |
+
"C",
|
| 44 |
+
"E",
|
| 45 |
+
"W",
|
| 46 |
+
"F",
|
| 47 |
+
"I",
|
| 48 |
+
]
|
| 49 |
+
lint.ignore = [
|
| 50 |
+
"E501",
|
| 51 |
+
"E731",
|
| 52 |
+
"C901",
|
| 53 |
+
"B006",
|
| 54 |
+
]
|
| 55 |
+
|
| 56 |
+
[project.urls]
|
| 57 |
+
Homepage = "https://github.com/facebookresearch/sam-audio"
|
| 58 |
+
Repository = "https://github.com/facebookresearch/sam-audio"
|
| 59 |
+
|
| 60 |
+
[build-system]
|
| 61 |
+
requires = ["setuptools>=61.0", "wheel"]
|
| 62 |
+
build-backend = "setuptools.build_meta"
|
sam_audio/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n
|
| 2 |
+
|
| 3 |
+
from .model import * # noqa
|
| 4 |
+
from .processor import * # noqa
|
sam_audio/model/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n
|
| 2 |
+
|
| 3 |
+
from .model import * # noqa
|
| 4 |
+
from .judge import * # noqa
|
sam_audio/model/align.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n
|
| 2 |
+
|
| 3 |
+
from typing import Optional
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class AlignModalities(torch.nn.Module):
|
| 9 |
+
def __init__(
|
| 10 |
+
self,
|
| 11 |
+
in_channels: int,
|
| 12 |
+
out_channels: int,
|
| 13 |
+
normalize: bool = True,
|
| 14 |
+
with_gate: bool = True,
|
| 15 |
+
):
|
| 16 |
+
super().__init__()
|
| 17 |
+
self.conv = torch.nn.Conv1d(
|
| 18 |
+
in_channels=in_channels, out_channels=out_channels, kernel_size=1
|
| 19 |
+
)
|
| 20 |
+
self.normalize = normalize
|
| 21 |
+
if self.normalize:
|
| 22 |
+
self.layer_norm = torch.nn.LayerNorm(out_channels)
|
| 23 |
+
|
| 24 |
+
self.gate = None
|
| 25 |
+
if with_gate:
|
| 26 |
+
self.gate = torch.nn.Parameter(torch.tensor([0.0]))
|
| 27 |
+
|
| 28 |
+
self.out_channels = out_channels
|
| 29 |
+
|
| 30 |
+
def forward(self, anchor: torch.Tensor, tgt: Optional[torch.Tensor] = None):
|
| 31 |
+
"""
|
| 32 |
+
Align video features to the input audio features
|
| 33 |
+
|
| 34 |
+
Args:
|
| 35 |
+
anchor (torch.Tensor): Input anchor tensor of shape (B, T, C), where B is batch size, C is channel size, and T is sequence length.
|
| 36 |
+
tgt (Optional[torch.Tensor]): Optional features tensor to be aligned to anchor, expected shape (B, in_channels, T).
|
| 37 |
+
"""
|
| 38 |
+
if tgt is None:
|
| 39 |
+
return anchor
|
| 40 |
+
|
| 41 |
+
post_conv = self.conv(tgt)
|
| 42 |
+
post_conv = post_conv.permute(0, 2, 1) # BCT -> BTC
|
| 43 |
+
|
| 44 |
+
if self.normalize:
|
| 45 |
+
post_conv = self.layer_norm(post_conv)
|
| 46 |
+
|
| 47 |
+
if self.gate is None:
|
| 48 |
+
return post_conv
|
| 49 |
+
else:
|
| 50 |
+
return anchor + self.gate.tanh() * post_conv
|
sam_audio/model/base.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
from typing import Callable, Dict, Optional, Union
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from huggingface_hub import ModelHubMixin, snapshot_download
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class BaseModel(torch.nn.Module, ModelHubMixin):
|
| 12 |
+
config_cls: Callable
|
| 13 |
+
|
| 14 |
+
def device(self):
|
| 15 |
+
return next(self.parameters()).device
|
| 16 |
+
|
| 17 |
+
@classmethod
|
| 18 |
+
def _from_pretrained(
|
| 19 |
+
cls,
|
| 20 |
+
*,
|
| 21 |
+
model_id: str,
|
| 22 |
+
cache_dir: str,
|
| 23 |
+
force_download: bool,
|
| 24 |
+
proxies: Optional[Dict],
|
| 25 |
+
resume_download: bool,
|
| 26 |
+
local_files_only: bool,
|
| 27 |
+
token: Union[str, bool, None],
|
| 28 |
+
map_location: str = "cpu",
|
| 29 |
+
strict: bool = True,
|
| 30 |
+
revision: Optional[str] = None,
|
| 31 |
+
**model_kwargs,
|
| 32 |
+
):
|
| 33 |
+
if os.path.isdir(model_id):
|
| 34 |
+
cached_model_dir = model_id
|
| 35 |
+
else:
|
| 36 |
+
cached_model_dir = snapshot_download(
|
| 37 |
+
repo_id=model_id,
|
| 38 |
+
revision=cls.revision,
|
| 39 |
+
cache_dir=cache_dir,
|
| 40 |
+
force_download=force_download,
|
| 41 |
+
proxies=proxies,
|
| 42 |
+
resume_download=resume_download,
|
| 43 |
+
token=token,
|
| 44 |
+
local_files_only=local_files_only,
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
with open(os.path.join(cached_model_dir, "config.json")) as fin:
|
| 48 |
+
config = json.load(fin)
|
| 49 |
+
|
| 50 |
+
for key, value in model_kwargs.items():
|
| 51 |
+
if key in config:
|
| 52 |
+
config[key] = value
|
| 53 |
+
|
| 54 |
+
config = cls.config_cls(**config)
|
| 55 |
+
model = cls(config)
|
| 56 |
+
state_dict = torch.load(
|
| 57 |
+
os.path.join(cached_model_dir, "checkpoint.pt"),
|
| 58 |
+
weights_only=True,
|
| 59 |
+
map_location=map_location,
|
| 60 |
+
)
|
| 61 |
+
model.load_state_dict(state_dict, strict=strict)
|
| 62 |
+
return model
|
sam_audio/model/codec.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
from abc import ABCMeta, abstractmethod
|
| 5 |
+
from typing import Union
|
| 6 |
+
|
| 7 |
+
import dacvae
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
from sam_audio.model.config import DACVAEConfig
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class Encoder(torch.nn.Module, metaclass=ABCMeta):
|
| 14 |
+
@abstractmethod
|
| 15 |
+
def forward(self, waveform: torch.Tensor) -> torch.Tensor: ...
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class Codec(Encoder):
|
| 19 |
+
@abstractmethod
|
| 20 |
+
def decode(self, encoded_frames: torch.Tensor) -> torch.Tensor: ...
|
| 21 |
+
|
| 22 |
+
@abstractmethod
|
| 23 |
+
def wav_idx_to_feature_idx(
|
| 24 |
+
self, wav_idx: Union[torch.Tensor, int], sample_rate=None
|
| 25 |
+
) -> Union[torch.Tensor, int]: ...
|
| 26 |
+
|
| 27 |
+
@abstractmethod
|
| 28 |
+
def feature_idx_to_wav_idx(
|
| 29 |
+
self, feature_idx: Union[torch.Tensor, int], sample_rate=None
|
| 30 |
+
) -> Union[torch.Tensor, int]: ...
|
| 31 |
+
|
| 32 |
+
@staticmethod
|
| 33 |
+
def cast_to_int(
|
| 34 |
+
x: Union[int, torch.Tensor],
|
| 35 |
+
) -> Union[int, torch.Tensor]:
|
| 36 |
+
if isinstance(x, torch.Tensor):
|
| 37 |
+
return x.int()
|
| 38 |
+
else:
|
| 39 |
+
return int(x)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class DACVAEEncoder(Encoder):
|
| 43 |
+
def __init__(self, config: DACVAEConfig) -> None:
|
| 44 |
+
super().__init__()
|
| 45 |
+
model = dacvae.DACVAE(
|
| 46 |
+
encoder_dim=config.encoder_dim,
|
| 47 |
+
encoder_rates=config.encoder_rates,
|
| 48 |
+
latent_dim=config.latent_dim,
|
| 49 |
+
decoder_dim=config.decoder_dim,
|
| 50 |
+
decoder_rates=config.decoder_rates,
|
| 51 |
+
n_codebooks=config.n_codebooks,
|
| 52 |
+
codebook_size=config.codebook_size,
|
| 53 |
+
codebook_dim=config.codebook_dim,
|
| 54 |
+
quantizer_dropout=config.quantizer_dropout,
|
| 55 |
+
sample_rate=config.sample_rate,
|
| 56 |
+
).eval()
|
| 57 |
+
self._setup_model(model)
|
| 58 |
+
self.hop_length = config.hop_length
|
| 59 |
+
self.sample_rate = config.sample_rate
|
| 60 |
+
|
| 61 |
+
def _setup_model(self, model):
|
| 62 |
+
self.encoder = model.encoder
|
| 63 |
+
self.quantizer = model.quantizer
|
| 64 |
+
|
| 65 |
+
def forward(self, waveform: torch.Tensor) -> torch.Tensor:
|
| 66 |
+
with torch.no_grad(), torch.backends.cudnn.flags(enabled=False):
|
| 67 |
+
z = self.encoder(self._pad(waveform))
|
| 68 |
+
mean, _ = self.quantizer.in_proj(z).chunk(2, dim=1)
|
| 69 |
+
encoded_frames = mean
|
| 70 |
+
return encoded_frames
|
| 71 |
+
|
| 72 |
+
def _pad(self, wavs):
|
| 73 |
+
length = wavs.size(-1)
|
| 74 |
+
if length % self.hop_length:
|
| 75 |
+
p1d = (0, self.hop_length - (length % self.hop_length))
|
| 76 |
+
return torch.nn.functional.pad(wavs, p1d, "reflect")
|
| 77 |
+
else:
|
| 78 |
+
return wavs
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class DACVAE(DACVAEEncoder, Codec):
|
| 82 |
+
def _setup_model(self, model):
|
| 83 |
+
super()._setup_model(model)
|
| 84 |
+
self.decoder = model.decoder
|
| 85 |
+
|
| 86 |
+
def decode(self, encoded_frames: torch.Tensor) -> torch.Tensor:
|
| 87 |
+
with torch.backends.cudnn.flags(enabled=False):
|
| 88 |
+
emb = self.quantizer.out_proj(encoded_frames)
|
| 89 |
+
return self.decoder(emb)
|
| 90 |
+
|
| 91 |
+
def feature_idx_to_wav_idx(self, feature_idx, sample_rate=None):
|
| 92 |
+
if sample_rate is None:
|
| 93 |
+
sample_rate = self.sample_rate
|
| 94 |
+
orig_freq = sample_rate
|
| 95 |
+
new_freq = self.sample_rate
|
| 96 |
+
wav_chunklen = feature_idx * self.hop_length * (orig_freq / new_freq)
|
| 97 |
+
return self.cast_to_int(wav_chunklen)
|
| 98 |
+
|
| 99 |
+
def wav_idx_to_feature_idx(self, wav_idx, sample_rate=None):
|
| 100 |
+
ceil = math.ceil
|
| 101 |
+
if torch.is_tensor(wav_idx):
|
| 102 |
+
ceil = torch.ceil
|
| 103 |
+
if sample_rate is None:
|
| 104 |
+
sample_rate = self.sample_rate
|
| 105 |
+
orig_freq = sample_rate
|
| 106 |
+
new_freq = self.sample_rate
|
| 107 |
+
target_length = ceil(new_freq * wav_idx / orig_freq)
|
| 108 |
+
res = ceil(target_length / self.hop_length)
|
| 109 |
+
return self.cast_to_int(res)
|
sam_audio/model/config.py
ADDED
|
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n
|
| 2 |
+
|
| 3 |
+
from typing import Optional, Tuple
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
from core.audio_visual_encoder.config import TransformerConfig as PEAVTransformerConfig
|
| 7 |
+
from transformers import ModernBertConfig
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class DACVAEConfig:
|
| 11 |
+
def __init__(
|
| 12 |
+
self,
|
| 13 |
+
encoder_dim: int = 64,
|
| 14 |
+
encoder_rates: list[int] = [2, 8, 10, 12],
|
| 15 |
+
latent_dim: int = 1024,
|
| 16 |
+
decoder_dim: int = 1536,
|
| 17 |
+
decoder_rates: list[int] = [12, 10, 8, 2],
|
| 18 |
+
n_codebooks: int = 16,
|
| 19 |
+
codebook_size: int = 1024,
|
| 20 |
+
codebook_dim: int = 128,
|
| 21 |
+
quantizer_dropout: bool = False,
|
| 22 |
+
sample_rate: int = 48_000,
|
| 23 |
+
mean: float = 0.0,
|
| 24 |
+
std: float = 1.0,
|
| 25 |
+
):
|
| 26 |
+
self.encoder_dim = encoder_dim
|
| 27 |
+
self.encoder_rates = encoder_rates
|
| 28 |
+
self.latent_dim = latent_dim
|
| 29 |
+
self.decoder_dim = decoder_dim
|
| 30 |
+
self.decoder_rates = decoder_rates
|
| 31 |
+
self.n_codebooks = n_codebooks
|
| 32 |
+
self.codebook_size = codebook_size
|
| 33 |
+
self.codebook_dim = codebook_dim
|
| 34 |
+
self.quantizer_dropout = quantizer_dropout
|
| 35 |
+
self.sample_rate = sample_rate
|
| 36 |
+
self.mean = mean
|
| 37 |
+
self.std = std
|
| 38 |
+
|
| 39 |
+
@property
|
| 40 |
+
def hop_length(self):
|
| 41 |
+
return int(np.prod(self.encoder_rates))
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class TextEncoderConfig:
|
| 45 |
+
def __init__(self, dim: int = 768):
|
| 46 |
+
self.dim = dim
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class T5EncoderConfig(TextEncoderConfig):
|
| 50 |
+
def __init__(
|
| 51 |
+
self,
|
| 52 |
+
name: str = "t5-base",
|
| 53 |
+
max_length: Optional[int] = 512,
|
| 54 |
+
pad_mode: str = "longest",
|
| 55 |
+
dim: int = 768,
|
| 56 |
+
):
|
| 57 |
+
super().__init__(dim=dim)
|
| 58 |
+
self.name = name
|
| 59 |
+
self.max_length = max_length
|
| 60 |
+
self.pad_mode = pad_mode
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class VisionEncoderConfig:
|
| 64 |
+
def __init__(self, dim: int = 1024, batch_size: int = 300):
|
| 65 |
+
self.dim = dim
|
| 66 |
+
self.batch_size = batch_size
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class PerceptionEncoderConfig(VisionEncoderConfig):
|
| 70 |
+
def __init__(
|
| 71 |
+
self,
|
| 72 |
+
dim: int = 1024,
|
| 73 |
+
batch_size: int = 300,
|
| 74 |
+
name: str = "PE-Core-L14-336",
|
| 75 |
+
normalize_feature: bool = True,
|
| 76 |
+
interpolation_mode: str = "BICUBIC",
|
| 77 |
+
image_size: int = 336,
|
| 78 |
+
):
|
| 79 |
+
super().__init__(dim=dim, batch_size=batch_size)
|
| 80 |
+
self.name = name
|
| 81 |
+
self.normalize_feature = normalize_feature
|
| 82 |
+
self.interpolation_mode = interpolation_mode
|
| 83 |
+
self.image_size = image_size
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class TransformerConfig:
|
| 87 |
+
def __init__(
|
| 88 |
+
self,
|
| 89 |
+
dim: int = 2048,
|
| 90 |
+
n_heads: int = 16,
|
| 91 |
+
n_layers: int = 16,
|
| 92 |
+
dropout: float = 0.1,
|
| 93 |
+
norm_eps: float = 1.0e-05,
|
| 94 |
+
qk_norm: bool = True,
|
| 95 |
+
fc_bias: bool = False,
|
| 96 |
+
ffn_exp: int = 4,
|
| 97 |
+
ffn_dim_multiplier: int = 1,
|
| 98 |
+
multiple_of: int = 64,
|
| 99 |
+
non_linearity: str = "swiglu",
|
| 100 |
+
use_rope: bool = True,
|
| 101 |
+
max_positions: int = 10000,
|
| 102 |
+
frequency_embedding_dim: int = 256,
|
| 103 |
+
timestep_non_linearity: str = "swiglu",
|
| 104 |
+
t_block_non_linearity: str = "silu",
|
| 105 |
+
t_block_bias: bool = True,
|
| 106 |
+
context_dim: int = 2048,
|
| 107 |
+
context_non_linearity: str = "swiglu",
|
| 108 |
+
context_embedder_dropout: float = 0.0,
|
| 109 |
+
context_norm: bool = False,
|
| 110 |
+
out_channels: int = 256,
|
| 111 |
+
in_channels: Optional[int] = None,
|
| 112 |
+
):
|
| 113 |
+
self.dim = dim
|
| 114 |
+
self.n_heads = n_heads
|
| 115 |
+
self.n_layers = n_layers
|
| 116 |
+
self.dropout = dropout
|
| 117 |
+
self.norm_eps = norm_eps
|
| 118 |
+
self.qk_norm = qk_norm
|
| 119 |
+
self.fc_bias = fc_bias
|
| 120 |
+
self.ffn_exp = ffn_exp
|
| 121 |
+
self.ffn_dim_multiplier = ffn_dim_multiplier
|
| 122 |
+
self.multiple_of = multiple_of
|
| 123 |
+
self.non_linearity = non_linearity
|
| 124 |
+
self.use_rope = use_rope
|
| 125 |
+
self.max_positions = max_positions
|
| 126 |
+
self.frequency_embedding_dim = frequency_embedding_dim
|
| 127 |
+
self.timestep_non_linearity = timestep_non_linearity
|
| 128 |
+
self.t_block_non_linearity = t_block_non_linearity
|
| 129 |
+
self.t_block_bias = t_block_bias
|
| 130 |
+
self.context_dim = context_dim
|
| 131 |
+
self.context_non_linearity = context_non_linearity
|
| 132 |
+
self.context_embedder_dropout = context_embedder_dropout
|
| 133 |
+
self.context_norm = context_norm
|
| 134 |
+
self.out_channels = out_channels
|
| 135 |
+
self.in_channels = in_channels
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
class RankerConfig:
|
| 139 |
+
kind: str
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
class ImageBindRankerConfig(RankerConfig):
|
| 143 |
+
kind: str = "imagebind"
|
| 144 |
+
|
| 145 |
+
def __init__(self, checkpoint: Optional[str] = None):
|
| 146 |
+
self.checkpoint = checkpoint
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
class ClapRankerConfig(RankerConfig):
|
| 150 |
+
kind: str = "clap"
|
| 151 |
+
|
| 152 |
+
def __init__(self, checkpoint: Optional[str] = None):
|
| 153 |
+
self.checkpoint = checkpoint
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
class JudgeRankerConfig(RankerConfig):
|
| 157 |
+
kind: str = "judge"
|
| 158 |
+
|
| 159 |
+
def __init__(self, checkpoint_or_model_id: str = "facebook/sam-audio-judge"):
|
| 160 |
+
self.checkpoint_or_model_id = checkpoint_or_model_id
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
class SoundActivityRankerConfig(RankerConfig):
|
| 164 |
+
kind: str = "sound_activity"
|
| 165 |
+
|
| 166 |
+
def __init__(
|
| 167 |
+
self,
|
| 168 |
+
threshold_mode: str = "rel_to_max",
|
| 169 |
+
sil_threshold: float = -40,
|
| 170 |
+
metric: str = "iou",
|
| 171 |
+
):
|
| 172 |
+
self.threshold_mode = threshold_mode
|
| 173 |
+
self.sil_threshold = sil_threshold
|
| 174 |
+
self.metric = metric
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
class EnsembleRankerConfig(RankerConfig):
|
| 178 |
+
kind: str = "ensemble"
|
| 179 |
+
|
| 180 |
+
def __init__(self, rankers: dict[str, Tuple[RankerConfig, float]]):
|
| 181 |
+
self.rankers = rankers
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def parse_ranker_config(config_dict: dict):
|
| 185 |
+
kind = config_dict.pop("kind")
|
| 186 |
+
match kind:
|
| 187 |
+
case ImageBindRankerConfig.kind:
|
| 188 |
+
return ImageBindRankerConfig(**config_dict)
|
| 189 |
+
case ClapRankerConfig.kind:
|
| 190 |
+
return ClapRankerConfig(**config_dict)
|
| 191 |
+
case JudgeRankerConfig.kind:
|
| 192 |
+
return JudgeRankerConfig(**config_dict)
|
| 193 |
+
case SoundActivityRankerConfig.kind:
|
| 194 |
+
return SoundActivityRankerConfig(**config_dict)
|
| 195 |
+
case EnsembleRankerConfig.kind:
|
| 196 |
+
return EnsembleRankerConfig(
|
| 197 |
+
{
|
| 198 |
+
k: (parse_ranker_config(v), w)
|
| 199 |
+
for k, (v, w) in config_dict["rankers"].items()
|
| 200 |
+
}
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
class SAMAudioConfig:
|
| 205 |
+
def __init__(
|
| 206 |
+
self,
|
| 207 |
+
in_channels: int = 768,
|
| 208 |
+
audio_codec=None,
|
| 209 |
+
text_encoder=None,
|
| 210 |
+
vision_encoder=None,
|
| 211 |
+
transformer=None,
|
| 212 |
+
num_anchors: int = 3,
|
| 213 |
+
anchor_embedding_dim: int = 128,
|
| 214 |
+
visual_ranker=None,
|
| 215 |
+
text_ranker=None,
|
| 216 |
+
span_predictor: Optional[str] = "pe-a-frame-large",
|
| 217 |
+
):
|
| 218 |
+
self.in_channels = in_channels
|
| 219 |
+
self.audio_codec = DACVAEConfig(**(audio_codec or {}))
|
| 220 |
+
self.text_encoder = T5EncoderConfig(**(text_encoder or {}))
|
| 221 |
+
self.vision_encoder = PerceptionEncoderConfig(**(vision_encoder or {}))
|
| 222 |
+
self.transformer = TransformerConfig(**(transformer or {}))
|
| 223 |
+
self.num_anchors = num_anchors
|
| 224 |
+
self.anchor_embedding_dim = anchor_embedding_dim
|
| 225 |
+
self.visual_ranker = (
|
| 226 |
+
None if visual_ranker is None else parse_ranker_config(visual_ranker)
|
| 227 |
+
)
|
| 228 |
+
self.text_ranker = (
|
| 229 |
+
None if text_ranker is None else parse_ranker_config(text_ranker)
|
| 230 |
+
)
|
| 231 |
+
self.span_predictor = span_predictor
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
class SAMAudioJudgeConfig:
|
| 235 |
+
def __init__(
|
| 236 |
+
self,
|
| 237 |
+
audio_codec: DACVAEConfig = None,
|
| 238 |
+
transformer: PEAVTransformerConfig = None,
|
| 239 |
+
text_model: ModernBertConfig = None,
|
| 240 |
+
finetune_transformer: PEAVTransformerConfig = None,
|
| 241 |
+
nth_text_layer: int = 22,
|
| 242 |
+
bottleneck_dim: int = 256,
|
| 243 |
+
):
|
| 244 |
+
self.audio_codec = DACVAEConfig(**(audio_codec or {}))
|
| 245 |
+
self.transformer = PEAVTransformerConfig(**(transformer or {}))
|
| 246 |
+
self.text_model = ModernBertConfig(**(text_model or {}))
|
| 247 |
+
self.finetune_transformer = PEAVTransformerConfig(
|
| 248 |
+
**(finetune_transformer or {})
|
| 249 |
+
)
|
| 250 |
+
self.nth_text_layer = nth_text_layer
|
| 251 |
+
self.bottleneck_dim = bottleneck_dim
|
sam_audio/model/judge.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Optional
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from core.audio_visual_encoder.transformer import BaseModelOutputWithPooling
|
| 8 |
+
from core.audio_visual_encoder.transformer import Transformer as PEAVTransformer
|
| 9 |
+
from transformers import AutoModel
|
| 10 |
+
|
| 11 |
+
from .base import BaseModel
|
| 12 |
+
from .codec import DACVAEEncoder
|
| 13 |
+
from .config import SAMAudioJudgeConfig
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@dataclass
|
| 17 |
+
class SAMAudioJudgeOutput:
|
| 18 |
+
r"""
|
| 19 |
+
overall (torch.Tensor, optional): Overall score tensor of shape (batch_size, 1).
|
| 20 |
+
recall (torch.Tensor, optional): Recall score tensor of shape (batch_size, 1).
|
| 21 |
+
precision (torch.Tensor, optional): Precision score tensor of shape (batch_size, 1).
|
| 22 |
+
faithfulness (torch.Tensor, optional): Faithfulness score tensor of shape (batch_size, 1).
|
| 23 |
+
text_model_output (BaseModelOutputWithPooling): Output from the text model.
|
| 24 |
+
audio_model_output (BaseModelOutputWithPooling): Output from the audio model.
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
overall: Optional[torch.Tensor] = None
|
| 28 |
+
recall: Optional[torch.Tensor] = None
|
| 29 |
+
precision: Optional[torch.Tensor] = None
|
| 30 |
+
faithfulness: Optional[torch.Tensor] = None
|
| 31 |
+
text_model_output: BaseModelOutputWithPooling = None
|
| 32 |
+
audio_model_output: BaseModelOutputWithPooling = None
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class SAMAudioJudgeModel(BaseModel):
|
| 36 |
+
config_cls = SAMAudioJudgeConfig
|
| 37 |
+
revision = "sam_audio"
|
| 38 |
+
|
| 39 |
+
def __init__(self, config: SAMAudioJudgeConfig):
|
| 40 |
+
super().__init__()
|
| 41 |
+
self.config = config
|
| 42 |
+
self.data_proj = torch.nn.Linear(
|
| 43 |
+
config.audio_codec.codebook_dim, config.transformer.hidden_size
|
| 44 |
+
)
|
| 45 |
+
self.audio_codec = DACVAEEncoder(config.audio_codec)
|
| 46 |
+
self.transformer = PEAVTransformer(config.transformer)
|
| 47 |
+
self.finetune_transformer = PEAVTransformer(config.finetune_transformer)
|
| 48 |
+
self.text_model = AutoModel.from_config(config.text_model)
|
| 49 |
+
self.cat_audio_proj = torch.nn.Linear(
|
| 50 |
+
2 * config.transformer.hidden_size, config.bottleneck_dim
|
| 51 |
+
)
|
| 52 |
+
self.text_proj1 = torch.nn.Linear(
|
| 53 |
+
in_features=config.text_model.hidden_size,
|
| 54 |
+
out_features=config.transformer.hidden_size,
|
| 55 |
+
bias=False,
|
| 56 |
+
)
|
| 57 |
+
self.text_proj2 = torch.nn.Linear(
|
| 58 |
+
in_features=config.transformer.hidden_size,
|
| 59 |
+
out_features=config.bottleneck_dim,
|
| 60 |
+
)
|
| 61 |
+
self.layer_norm = torch.nn.LayerNorm(config.bottleneck_dim)
|
| 62 |
+
self.proj_audio_and_text = torch.nn.Linear(
|
| 63 |
+
2 * config.bottleneck_dim, config.bottleneck_dim
|
| 64 |
+
)
|
| 65 |
+
self.finetune_data_proj = torch.nn.Linear(
|
| 66 |
+
config.bottleneck_dim, config.finetune_transformer.hidden_size
|
| 67 |
+
)
|
| 68 |
+
self.head = torch.nn.Linear(
|
| 69 |
+
config.finetune_transformer.hidden_size, 4, bias=False
|
| 70 |
+
)
|
| 71 |
+
self.mean = torch.nn.Parameter(torch.zeros(4, requires_grad=False))
|
| 72 |
+
self.std = torch.nn.Parameter(torch.ones(4, requires_grad=False))
|
| 73 |
+
|
| 74 |
+
def _get_text_output(self, input_ids, attention_mask):
|
| 75 |
+
nth_layer = self.config.nth_text_layer
|
| 76 |
+
output = self.text_model(
|
| 77 |
+
input_ids=input_ids,
|
| 78 |
+
attention_mask=attention_mask,
|
| 79 |
+
output_hidden_states=nth_layer is not None,
|
| 80 |
+
)
|
| 81 |
+
if nth_layer is None:
|
| 82 |
+
text_model_output = output.last_hidden_state
|
| 83 |
+
else:
|
| 84 |
+
text_model_output = output.hidden_states[nth_layer]
|
| 85 |
+
|
| 86 |
+
return BaseModelOutputWithPooling(
|
| 87 |
+
last_hidden_state=text_model_output, pooler_output=text_model_output[:, 0]
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
def forward(
|
| 91 |
+
self,
|
| 92 |
+
input_ids: torch.Tensor, # tokenized text
|
| 93 |
+
input_values: torch.Tensor, # input audio waveform
|
| 94 |
+
separated_values: torch.Tensor, # separated audio waveform
|
| 95 |
+
attention_mask: Optional[torch.Tensor] = None, # text attention mask
|
| 96 |
+
padding_mask: Optional[torch.Tensor] = None, # audio padding mask
|
| 97 |
+
) -> SAMAudioJudgeOutput:
|
| 98 |
+
text_features = self.text_proj1(
|
| 99 |
+
self._get_text_output(input_ids, attention_mask).pooler_output
|
| 100 |
+
)
|
| 101 |
+
stacked_audios = torch.cat([input_values, separated_values], dim=0)
|
| 102 |
+
stacked_codec_features = self.audio_codec(stacked_audios)
|
| 103 |
+
feature_padding_mask = None
|
| 104 |
+
if padding_mask is not None:
|
| 105 |
+
feature_padding_mask = padding_mask[
|
| 106 |
+
:, :: self.config.audio_codec.hop_length
|
| 107 |
+
]
|
| 108 |
+
stacked_features = self.transformer(
|
| 109 |
+
self.data_proj(stacked_codec_features.transpose(1, 2)),
|
| 110 |
+
padding_mask=feature_padding_mask,
|
| 111 |
+
)
|
| 112 |
+
input_features, hyp_features = stacked_features.last_hidden_state.chunk(2, 0)
|
| 113 |
+
audio_features = self.cat_audio_proj(
|
| 114 |
+
torch.cat([hyp_features, input_features], dim=2)
|
| 115 |
+
)
|
| 116 |
+
expanded_text = (
|
| 117 |
+
self.layer_norm(self.text_proj2(text_features))
|
| 118 |
+
.unsqueeze(1)
|
| 119 |
+
.expand_as(audio_features)
|
| 120 |
+
)
|
| 121 |
+
audio_and_text = self.proj_audio_and_text(
|
| 122 |
+
torch.cat([audio_features, expanded_text], dim=2)
|
| 123 |
+
)
|
| 124 |
+
finetune_transformer_output = self.finetune_transformer(
|
| 125 |
+
self.finetune_data_proj(audio_and_text), padding_mask=feature_padding_mask
|
| 126 |
+
)
|
| 127 |
+
result = self.head(finetune_transformer_output.last_hidden_state)
|
| 128 |
+
if feature_padding_mask is not None:
|
| 129 |
+
feature_padding_mask = feature_padding_mask.unsqueeze(-1)
|
| 130 |
+
pooled = torch.masked.mean(result, mask=feature_padding_mask, dim=1)
|
| 131 |
+
de_normalized = pooled * self.std + self.mean
|
| 132 |
+
return SAMAudioJudgeOutput(*de_normalized.chunk(4, dim=1))
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
__all__ = ["SAMAudioJudgeModel", "SAMAudioJudgeOutput"]
|
sam_audio/model/model.py
ADDED
|
@@ -0,0 +1,362 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
import re
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
from typing import Any, Dict, Optional
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from core.audio_visual_encoder import PEAudioFrame, PEAudioFrameTransform
|
| 10 |
+
from torchdiffeq import odeint
|
| 11 |
+
|
| 12 |
+
from sam_audio.model.align import AlignModalities
|
| 13 |
+
from sam_audio.model.base import BaseModel
|
| 14 |
+
from sam_audio.model.codec import DACVAE
|
| 15 |
+
from sam_audio.model.config import SAMAudioConfig
|
| 16 |
+
from sam_audio.model.text_encoder import T5TextEncoder
|
| 17 |
+
from sam_audio.model.transformer import DiT
|
| 18 |
+
from sam_audio.model.vision_encoder import PerceptionEncoder
|
| 19 |
+
from sam_audio.processor import Batch
|
| 20 |
+
from sam_audio.ranking import create_ranker
|
| 21 |
+
|
| 22 |
+
DFLT_ODE_OPT = {"method": "midpoint", "options": {"step_size": 2 / 32}}
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class SinusoidalEmbedding(torch.nn.Module):
|
| 26 |
+
def __init__(self, dim, theta=10000):
|
| 27 |
+
super().__init__()
|
| 28 |
+
assert (dim % 2) == 0
|
| 29 |
+
half_dim = dim // 2
|
| 30 |
+
inv_freq = torch.exp(
|
| 31 |
+
-math.log(theta) * torch.arange(half_dim).float() / half_dim
|
| 32 |
+
)
|
| 33 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 34 |
+
|
| 35 |
+
def forward(self, x, pos=None):
|
| 36 |
+
if pos is None:
|
| 37 |
+
seq_len, device = x.shape[1], x.device
|
| 38 |
+
pos = torch.arange(seq_len, device=device)
|
| 39 |
+
|
| 40 |
+
emb = torch.einsum("i, j -> i j", pos, self.inv_freq)
|
| 41 |
+
emb = torch.cat((emb.cos(), emb.sin()), dim=-1)
|
| 42 |
+
return emb
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class EmbedAnchors(torch.nn.Module):
|
| 46 |
+
def __init__(self, num_embeddings: int, embedding_dim: int, out_dim: int):
|
| 47 |
+
super().__init__()
|
| 48 |
+
self.embed = torch.nn.Embedding(
|
| 49 |
+
num_embeddings + 1, embedding_dim, padding_idx=num_embeddings
|
| 50 |
+
)
|
| 51 |
+
self.gate = torch.nn.Parameter(torch.tensor([0.0]))
|
| 52 |
+
self.proj = torch.nn.Linear(embedding_dim, out_dim, bias=False)
|
| 53 |
+
|
| 54 |
+
def forward(
|
| 55 |
+
self,
|
| 56 |
+
x: torch.Tensor,
|
| 57 |
+
anchor_ids: Optional[torch.Tensor] = None,
|
| 58 |
+
anchor_alignment: Optional[torch.Tensor] = None,
|
| 59 |
+
):
|
| 60 |
+
if anchor_ids is None:
|
| 61 |
+
return x
|
| 62 |
+
|
| 63 |
+
embs = self.embed(anchor_ids.gather(1, anchor_alignment))
|
| 64 |
+
proj = self.proj(embs)
|
| 65 |
+
return x + self.gate.tanh() * proj
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
@dataclass
|
| 69 |
+
class SeparationResult:
|
| 70 |
+
target: torch.Tensor
|
| 71 |
+
residual: torch.Tensor
|
| 72 |
+
noise: torch.Tensor
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class SAMAudio(BaseModel):
|
| 76 |
+
config_cls = SAMAudioConfig
|
| 77 |
+
revision = None
|
| 78 |
+
|
| 79 |
+
def __init__(self, cfg: SAMAudioConfig):
|
| 80 |
+
super().__init__()
|
| 81 |
+
self.audio_codec = DACVAE(cfg.audio_codec)
|
| 82 |
+
self.text_encoder = T5TextEncoder(cfg.text_encoder)
|
| 83 |
+
self.vision_encoder = PerceptionEncoder(cfg.vision_encoder)
|
| 84 |
+
self.transformer = DiT(cfg.transformer)
|
| 85 |
+
self.proj = torch.nn.Linear(cfg.in_channels, cfg.transformer.dim)
|
| 86 |
+
self.align_masked_video = AlignModalities(
|
| 87 |
+
cfg.vision_encoder.dim, cfg.transformer.dim
|
| 88 |
+
)
|
| 89 |
+
self.embed_anchors = EmbedAnchors(
|
| 90 |
+
cfg.num_anchors, cfg.anchor_embedding_dim, cfg.transformer.dim
|
| 91 |
+
)
|
| 92 |
+
self.memory_proj = torch.nn.Linear(cfg.text_encoder.dim, cfg.transformer.dim)
|
| 93 |
+
self.timestep_emb = SinusoidalEmbedding(cfg.transformer.dim)
|
| 94 |
+
self.visual_ranker = create_ranker(cfg.visual_ranker)
|
| 95 |
+
self.text_ranker = create_ranker(cfg.text_ranker)
|
| 96 |
+
if cfg.span_predictor is not None:
|
| 97 |
+
self.span_predictor = PEAudioFrame.from_config(
|
| 98 |
+
cfg.span_predictor, pretrained=True
|
| 99 |
+
)
|
| 100 |
+
self.span_predictor_transform = PEAudioFrameTransform.from_config(
|
| 101 |
+
cfg.span_predictor
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
@property
|
| 105 |
+
def sample_rate(self):
|
| 106 |
+
return self.audio_codec.sample_rate
|
| 107 |
+
|
| 108 |
+
def align_inputs(
|
| 109 |
+
self,
|
| 110 |
+
noisy_audio,
|
| 111 |
+
audio_features: torch.Tensor,
|
| 112 |
+
masked_video_features: Optional[torch.Tensor] = None,
|
| 113 |
+
anchor_ids: Optional[torch.Tensor] = None,
|
| 114 |
+
anchor_alignment: Optional[torch.Tensor] = None,
|
| 115 |
+
):
|
| 116 |
+
x = torch.cat(
|
| 117 |
+
[
|
| 118 |
+
noisy_audio,
|
| 119 |
+
torch.zeros_like(audio_features),
|
| 120 |
+
audio_features,
|
| 121 |
+
],
|
| 122 |
+
dim=2,
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
projected = self.proj(x)
|
| 126 |
+
aligned = self.align_masked_video(projected, masked_video_features)
|
| 127 |
+
aligned = self.embed_anchors(aligned, anchor_ids, anchor_alignment)
|
| 128 |
+
return aligned
|
| 129 |
+
|
| 130 |
+
def forward(
|
| 131 |
+
self,
|
| 132 |
+
noisy_audio: torch.Tensor,
|
| 133 |
+
audio_features: torch.Tensor,
|
| 134 |
+
text_features: torch.Tensor,
|
| 135 |
+
time: torch.Tensor,
|
| 136 |
+
masked_video_features: Optional[torch.Tensor] = None,
|
| 137 |
+
text_mask: Optional[torch.Tensor] = None,
|
| 138 |
+
anchor_ids: Optional[torch.Tensor] = None,
|
| 139 |
+
anchor_alignment: Optional[torch.Tensor] = None,
|
| 140 |
+
audio_pad_mask: Optional[torch.Tensor] = None,
|
| 141 |
+
):
|
| 142 |
+
"""
|
| 143 |
+
Forward pass for the model. Represents one function evaluation of the ODE.
|
| 144 |
+
In the below descriptions, B is batch size, T is sequence length, C is channel size.
|
| 145 |
+
Note that the size of C and T may vary across arguments (ex. text_features vs. audio_features),
|
| 146 |
+
it is used only to designate a Channel or time/sequence-length dimension respectively.
|
| 147 |
+
|
| 148 |
+
Args:
|
| 149 |
+
noisy_audio (torch.Tensor): Noisy audio input tensor (being denoised).
|
| 150 |
+
audio_features (torch.Tensor): Clean audio features [B x T x C].
|
| 151 |
+
text_features (torch.Tensor): Encoded text features tensor [B x T x C].
|
| 152 |
+
time (torch.Tensor): Timestep tensor for positional encoding [B].
|
| 153 |
+
masked_video_features (Optional[torch.Tensor], optional): Masked video features tensor. [B x C x T].
|
| 154 |
+
text_mask (Optional[torch.Tensor], optional): Padding mask for text features. [B x T].
|
| 155 |
+
anchor_ids (Optional[torch.Tensor], optional): Anchor IDs tensor. Defaults to None [B x T].
|
| 156 |
+
anchor_alignment (Optional[torch.Tensor], optional): Anchor alignment tensor. B x T.
|
| 157 |
+
audio_pad_mask (Optional[torch.Tensor], optional): Padding mask for audio input. [B x T].
|
| 158 |
+
|
| 159 |
+
Returns:
|
| 160 |
+
torch.Tensor
|
| 161 |
+
"""
|
| 162 |
+
aligned_inputs = self.align_inputs(
|
| 163 |
+
noisy_audio,
|
| 164 |
+
audio_features,
|
| 165 |
+
masked_video_features=masked_video_features,
|
| 166 |
+
anchor_ids=anchor_ids,
|
| 167 |
+
anchor_alignment=anchor_alignment,
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
+
memory = timestep_emb = self.timestep_emb(time, pos=time).unsqueeze(1)
|
| 171 |
+
if text_features is not None:
|
| 172 |
+
memory = self.memory_proj(text_features) + timestep_emb
|
| 173 |
+
|
| 174 |
+
return self.transformer(
|
| 175 |
+
aligned_inputs,
|
| 176 |
+
time,
|
| 177 |
+
padding_mask=audio_pad_mask,
|
| 178 |
+
memory=memory,
|
| 179 |
+
memory_padding_mask=text_mask,
|
| 180 |
+
)
|
| 181 |
+
|
| 182 |
+
def _get_audio_features(self, audios: torch.Tensor):
|
| 183 |
+
audio_features = self.audio_codec(audios).transpose(1, 2)
|
| 184 |
+
return torch.cat([audio_features, audio_features], dim=2)
|
| 185 |
+
|
| 186 |
+
def _get_video_features(self, video, audio_features):
|
| 187 |
+
B, T, _ = audio_features.shape
|
| 188 |
+
if video is None:
|
| 189 |
+
return audio_features.new_zeros(B, self.vision_encoder.dim, T)
|
| 190 |
+
else:
|
| 191 |
+
return self.vision_encoder(video).transpose(1, 2)
|
| 192 |
+
|
| 193 |
+
def _repeat_for_reranking(self, tensor, candidates):
|
| 194 |
+
if candidates > 1:
|
| 195 |
+
B = tensor.size(0)
|
| 196 |
+
rest = tensor.shape[1:]
|
| 197 |
+
return (
|
| 198 |
+
tensor.unsqueeze(1)
|
| 199 |
+
.expand(B, candidates, *rest)
|
| 200 |
+
.reshape(B * candidates, *rest)
|
| 201 |
+
)
|
| 202 |
+
else:
|
| 203 |
+
return tensor
|
| 204 |
+
|
| 205 |
+
def _unrepeat_from_reranking(self, tensor, candidates):
|
| 206 |
+
return tensor[::candidates]
|
| 207 |
+
|
| 208 |
+
def _get_forward_args(self, batch: Batch, candidates: int = 1):
|
| 209 |
+
audio_features = self._get_audio_features(batch.audios)
|
| 210 |
+
text_features, text_mask = self.text_encoder(batch.descriptions)
|
| 211 |
+
masked_video_features = self._get_video_features(
|
| 212 |
+
batch.masked_video, audio_features
|
| 213 |
+
)
|
| 214 |
+
|
| 215 |
+
return {
|
| 216 |
+
"audio_features": self._repeat_for_reranking(audio_features, candidates),
|
| 217 |
+
"text_features": self._repeat_for_reranking(text_features, candidates),
|
| 218 |
+
"text_mask": self._repeat_for_reranking(text_mask, candidates),
|
| 219 |
+
"masked_video_features": self._repeat_for_reranking(
|
| 220 |
+
masked_video_features, candidates
|
| 221 |
+
),
|
| 222 |
+
"anchor_ids": self._repeat_for_reranking(batch.anchor_ids, candidates),
|
| 223 |
+
"anchor_alignment": self._repeat_for_reranking(
|
| 224 |
+
batch.anchor_alignment, candidates
|
| 225 |
+
),
|
| 226 |
+
"audio_pad_mask": self._repeat_for_reranking(
|
| 227 |
+
batch.audio_pad_mask, candidates
|
| 228 |
+
),
|
| 229 |
+
}
|
| 230 |
+
|
| 231 |
+
def predict_spans(
|
| 232 |
+
self, batch: Batch, audio_features: torch.Tensor, audio_pad_mask: torch.Tensor
|
| 233 |
+
) -> Batch:
|
| 234 |
+
input = self.span_predictor_transform(text=batch.descriptions).to(
|
| 235 |
+
audio_features.device
|
| 236 |
+
)
|
| 237 |
+
output = self.span_predictor(
|
| 238 |
+
input_features=audio_features[:, :, :128],
|
| 239 |
+
padding_mask=audio_pad_mask,
|
| 240 |
+
return_spans=True,
|
| 241 |
+
**input,
|
| 242 |
+
)
|
| 243 |
+
anchors = [[["+"] + anchor for anchor in anchors] for anchors in output.spans]
|
| 244 |
+
batch.process_anchors(anchors)
|
| 245 |
+
return batch
|
| 246 |
+
|
| 247 |
+
@torch.inference_mode()
|
| 248 |
+
def separate(
|
| 249 |
+
self,
|
| 250 |
+
batch: Batch,
|
| 251 |
+
noise: Optional[torch.Tensor] = None,
|
| 252 |
+
ode_opt: Dict[str, Any] = DFLT_ODE_OPT,
|
| 253 |
+
reranking_candidates: int = 1,
|
| 254 |
+
predict_spans: bool = False,
|
| 255 |
+
) -> SeparationResult:
|
| 256 |
+
# Encode audio
|
| 257 |
+
forward_args = self._get_forward_args(batch, candidates=reranking_candidates)
|
| 258 |
+
|
| 259 |
+
if predict_spans and hasattr(self, "span_predictor") and batch.anchors is None:
|
| 260 |
+
batch = self.predict_spans(
|
| 261 |
+
batch=batch,
|
| 262 |
+
audio_features=self._unrepeat_from_reranking(
|
| 263 |
+
forward_args["audio_features"], reranking_candidates
|
| 264 |
+
),
|
| 265 |
+
audio_pad_mask=self._unrepeat_from_reranking(
|
| 266 |
+
forward_args["audio_pad_mask"], reranking_candidates
|
| 267 |
+
),
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
+
audio_features = forward_args["audio_features"]
|
| 271 |
+
B, T, C = audio_features.shape
|
| 272 |
+
C = C // 2 # we stack audio_features, so the actual channels is half
|
| 273 |
+
|
| 274 |
+
if noise is None:
|
| 275 |
+
noise = torch.randn_like(audio_features)
|
| 276 |
+
|
| 277 |
+
def vector_field(t, noisy_audio):
|
| 278 |
+
res = self.forward(
|
| 279 |
+
noisy_audio=noisy_audio,
|
| 280 |
+
time=t.expand(noisy_audio.size(0)),
|
| 281 |
+
**forward_args,
|
| 282 |
+
)
|
| 283 |
+
return res
|
| 284 |
+
|
| 285 |
+
states = odeint(
|
| 286 |
+
vector_field,
|
| 287 |
+
noise,
|
| 288 |
+
torch.tensor([0.0, 1.0], device=noise.device),
|
| 289 |
+
**ode_opt,
|
| 290 |
+
)
|
| 291 |
+
generated_features = states[-1].transpose(1, 2)
|
| 292 |
+
# generated_features has shape [B, 2C, T]. Reshape to stack along the batch dimension
|
| 293 |
+
wavs = self.audio_codec.decode(generated_features.reshape(2 * B, C, T)).view(
|
| 294 |
+
B, 2, -1
|
| 295 |
+
)
|
| 296 |
+
|
| 297 |
+
bsz = wavs.size(0) // reranking_candidates
|
| 298 |
+
sizes = self.audio_codec.feature_idx_to_wav_idx(batch.sizes)
|
| 299 |
+
target_wavs = self.unbatch(
|
| 300 |
+
wavs[:, 0].view(bsz, reranking_candidates, -1), sizes
|
| 301 |
+
)
|
| 302 |
+
residual_wavs = self.unbatch(
|
| 303 |
+
wavs[:, 1].view(bsz, reranking_candidates, -1), sizes
|
| 304 |
+
)
|
| 305 |
+
|
| 306 |
+
if (
|
| 307 |
+
reranking_candidates > 1
|
| 308 |
+
and batch.masked_video is not None
|
| 309 |
+
and self.visual_ranker is not None
|
| 310 |
+
):
|
| 311 |
+
scores = self.visual_ranker(
|
| 312 |
+
extracted_audio=target_wavs,
|
| 313 |
+
videos=batch.masked_video,
|
| 314 |
+
sample_rate=self.audio_codec.sample_rate,
|
| 315 |
+
)
|
| 316 |
+
idxs = scores.argmax(dim=1)
|
| 317 |
+
elif reranking_candidates > 1 and self.text_ranker is not None:
|
| 318 |
+
input_audio = [
|
| 319 |
+
audio[:, :size].expand(reranking_candidates, -1)
|
| 320 |
+
for audio, size in zip(batch.audios, sizes, strict=False)
|
| 321 |
+
]
|
| 322 |
+
scores = self.text_ranker(
|
| 323 |
+
extracted_audio=target_wavs,
|
| 324 |
+
input_audio=input_audio,
|
| 325 |
+
descriptions=batch.descriptions,
|
| 326 |
+
sample_rate=self.audio_codec.sample_rate,
|
| 327 |
+
)
|
| 328 |
+
idxs = scores.argmax(dim=1)
|
| 329 |
+
else:
|
| 330 |
+
idxs = torch.zeros(bsz, dtype=torch.long, device=noise.device)
|
| 331 |
+
|
| 332 |
+
return SeparationResult(
|
| 333 |
+
target=[wav[idx] for wav, idx in zip(target_wavs, idxs, strict=False)],
|
| 334 |
+
residual=[
|
| 335 |
+
wavs[idx] for wavs, idx in zip(residual_wavs, idxs, strict=False)
|
| 336 |
+
],
|
| 337 |
+
noise=noise,
|
| 338 |
+
)
|
| 339 |
+
|
| 340 |
+
def unbatch(self, wavs: torch.Tensor, sizes: torch.Tensor, time_dim: int = -1):
|
| 341 |
+
result = []
|
| 342 |
+
for row, size in zip(wavs, sizes, strict=False):
|
| 343 |
+
result.append(row.narrow(dim=time_dim, start=0, length=size))
|
| 344 |
+
return result
|
| 345 |
+
|
| 346 |
+
def load_state_dict(self, state_dict, strict=True):
|
| 347 |
+
if strict:
|
| 348 |
+
missing_keys, unexpected_keys = super().load_state_dict(
|
| 349 |
+
state_dict, strict=False
|
| 350 |
+
)
|
| 351 |
+
# We load this directly from HF, not in checkpoint
|
| 352 |
+
skip_regex = re.compile(
|
| 353 |
+
"(^text_encoder|^visual_ranker|^text_ranker|^span_predictor)"
|
| 354 |
+
)
|
| 355 |
+
missing_keys = [x for x in missing_keys if not re.search(skip_regex, x)]
|
| 356 |
+
if len(missing_keys) > 0 or len(unexpected_keys) > 0:
|
| 357 |
+
raise RuntimeError(
|
| 358 |
+
f"Missing keys: {missing_keys}, unexpected_keys: {unexpected_keys}"
|
| 359 |
+
)
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
__all__ = ["SAMAudio"]
|
sam_audio/model/patcher.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
from typing import Tuple
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from einops import rearrange
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def pad1d(
|
| 12 |
+
x: torch.Tensor,
|
| 13 |
+
paddings: Tuple[int, int],
|
| 14 |
+
mode: str = "constant",
|
| 15 |
+
value: float = 0.0,
|
| 16 |
+
):
|
| 17 |
+
# Copied from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/modules/conv.py
|
| 18 |
+
"""Tiny wrapper around F.pad, just to allow for reflect padding on small input.
|
| 19 |
+
If this is the case, we insert extra 0 padding to the right before the reflection happen.
|
| 20 |
+
"""
|
| 21 |
+
length = x.shape[-1]
|
| 22 |
+
padding_left, padding_right = paddings
|
| 23 |
+
assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
|
| 24 |
+
if mode == "reflect":
|
| 25 |
+
max_pad = max(padding_left, padding_right)
|
| 26 |
+
extra_pad = 0
|
| 27 |
+
if length <= max_pad:
|
| 28 |
+
extra_pad = max_pad - length + 1
|
| 29 |
+
x = F.pad(x, (0, extra_pad))
|
| 30 |
+
padded = F.pad(x, paddings, mode, value)
|
| 31 |
+
end = padded.shape[-1] - extra_pad
|
| 32 |
+
return padded[..., :end]
|
| 33 |
+
else:
|
| 34 |
+
return F.pad(x, paddings, mode, value)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def get_extra_padding_for_conv1d(
|
| 38 |
+
x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
|
| 39 |
+
) -> int:
|
| 40 |
+
# Copied from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/modules/conv.py
|
| 41 |
+
"""See `pad_for_conv1d`."""
|
| 42 |
+
length = x.shape[-1]
|
| 43 |
+
n_frames = (length - kernel_size + padding_total) / stride + 1
|
| 44 |
+
ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
|
| 45 |
+
return ideal_length - length
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class Conv1d(torch.nn.Conv1d):
|
| 49 |
+
def __init__(self, *args, **kwargs):
|
| 50 |
+
super().__init__(*args, **kwargs)
|
| 51 |
+
|
| 52 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 53 |
+
kernel_size = self.kernel_size[0]
|
| 54 |
+
stride = self.stride[0]
|
| 55 |
+
dilation = self.dilation[0]
|
| 56 |
+
kernel_size = (
|
| 57 |
+
kernel_size - 1
|
| 58 |
+
) * dilation + 1 # effective kernel size with dilations
|
| 59 |
+
padding_total = kernel_size - stride
|
| 60 |
+
extra_padding = get_extra_padding_for_conv1d(
|
| 61 |
+
x, kernel_size, stride, padding_total
|
| 62 |
+
)
|
| 63 |
+
# Asymmetric padding required for odd strides
|
| 64 |
+
padding_right = padding_total // 2
|
| 65 |
+
padding_left = padding_total - padding_right
|
| 66 |
+
x = pad1d(x, (padding_left, padding_right + extra_padding))
|
| 67 |
+
return super().forward(x)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class ConvBlock1d(torch.nn.Module):
|
| 71 |
+
def __init__(
|
| 72 |
+
self,
|
| 73 |
+
in_channels: int,
|
| 74 |
+
out_channels: int,
|
| 75 |
+
*,
|
| 76 |
+
kernel_size: int = 3,
|
| 77 |
+
stride: int = 1,
|
| 78 |
+
dilation: int = 1,
|
| 79 |
+
num_groups: int = 8,
|
| 80 |
+
) -> None:
|
| 81 |
+
super().__init__()
|
| 82 |
+
|
| 83 |
+
self.groupnorm = torch.nn.GroupNorm(
|
| 84 |
+
num_groups=num_groups, num_channels=in_channels
|
| 85 |
+
)
|
| 86 |
+
self.activation = torch.nn.SiLU()
|
| 87 |
+
self.project = Conv1d(
|
| 88 |
+
in_channels=in_channels,
|
| 89 |
+
out_channels=out_channels,
|
| 90 |
+
kernel_size=kernel_size,
|
| 91 |
+
stride=stride,
|
| 92 |
+
dilation=dilation,
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
def forward(
|
| 96 |
+
self,
|
| 97 |
+
x: torch.Tensor,
|
| 98 |
+
) -> torch.Tensor:
|
| 99 |
+
x = self.groupnorm(x)
|
| 100 |
+
x = self.activation(x)
|
| 101 |
+
return self.project(x)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
class ResnetBlock1d(torch.nn.Module):
|
| 105 |
+
def __init__(
|
| 106 |
+
self,
|
| 107 |
+
in_channels: int,
|
| 108 |
+
out_channels: int,
|
| 109 |
+
*,
|
| 110 |
+
kernel_size: int = 3,
|
| 111 |
+
stride: int = 1,
|
| 112 |
+
dilation: int = 1,
|
| 113 |
+
num_groups: int = 8,
|
| 114 |
+
) -> None:
|
| 115 |
+
super().__init__()
|
| 116 |
+
|
| 117 |
+
self.block1 = ConvBlock1d(
|
| 118 |
+
in_channels=in_channels,
|
| 119 |
+
out_channels=out_channels,
|
| 120 |
+
kernel_size=kernel_size,
|
| 121 |
+
stride=stride,
|
| 122 |
+
dilation=dilation,
|
| 123 |
+
num_groups=num_groups,
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
self.block2 = ConvBlock1d(
|
| 127 |
+
in_channels=out_channels,
|
| 128 |
+
out_channels=out_channels,
|
| 129 |
+
num_groups=num_groups,
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
self.to_out = (
|
| 133 |
+
Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)
|
| 134 |
+
if in_channels != out_channels
|
| 135 |
+
else torch.nn.Identity()
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 139 |
+
h = self.block1(x)
|
| 140 |
+
h = self.block2(h)
|
| 141 |
+
return h + self.to_out(x)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
class Patcher(torch.nn.Module):
|
| 145 |
+
def __init__(
|
| 146 |
+
self,
|
| 147 |
+
in_channels: int,
|
| 148 |
+
out_channels: int,
|
| 149 |
+
patch_size: int,
|
| 150 |
+
):
|
| 151 |
+
super().__init__()
|
| 152 |
+
assert_message = f"out_channels must be divisible by patch_size ({patch_size})"
|
| 153 |
+
assert out_channels % patch_size == 0, assert_message
|
| 154 |
+
self.patch_size = patch_size
|
| 155 |
+
self.block = ResnetBlock1d(
|
| 156 |
+
in_channels=in_channels,
|
| 157 |
+
out_channels=out_channels // patch_size,
|
| 158 |
+
num_groups=1,
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 162 |
+
x = self.block(x)
|
| 163 |
+
x = rearrange(x, "b c (l p) -> b (c p) l", p=self.patch_size)
|
| 164 |
+
return x
|
sam_audio/model/rope.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
from typing import Tuple
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor, seq_dim: int):
|
| 10 |
+
"""
|
| 11 |
+
Reshape frequency tensor for broadcasting it with another tensor.
|
| 12 |
+
|
| 13 |
+
This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
|
| 14 |
+
for the purpose of broadcasting the frequency tensor during element-wise operations.
|
| 15 |
+
|
| 16 |
+
Args:
|
| 17 |
+
freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
|
| 18 |
+
x (torch.Tensor): Target tensor for broadcasting compatibility.
|
| 19 |
+
seq_dim (int): Sequence dimension index.
|
| 20 |
+
|
| 21 |
+
Returns:
|
| 22 |
+
torch.Tensor: Reshaped frequency tensor.
|
| 23 |
+
"""
|
| 24 |
+
ndim = x.ndim
|
| 25 |
+
assert 0 <= seq_dim < ndim
|
| 26 |
+
assert freqs_cis.shape == (
|
| 27 |
+
x.shape[seq_dim],
|
| 28 |
+
x.shape[-3],
|
| 29 |
+
2,
|
| 30 |
+
2,
|
| 31 |
+
), f"freqs_cis vs x: {(freqs_cis.shape, x.shape)}"
|
| 32 |
+
shape = [
|
| 33 |
+
d if i == seq_dim or i == ndim - 3 else 1 for i, d in enumerate(x.shape[:-2])
|
| 34 |
+
] + [2, 2]
|
| 35 |
+
return freqs_cis.view(*shape)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def apply_rotary_emb(
|
| 39 |
+
xq: torch.Tensor,
|
| 40 |
+
xk: torch.Tensor,
|
| 41 |
+
seq_dim: int,
|
| 42 |
+
freqs_cis: torch.Tensor,
|
| 43 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 44 |
+
xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2) # B S H D -> B S H D/2 1 2
|
| 45 |
+
xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2) # B S H D -> B S H D/2 1 2
|
| 46 |
+
freqs_cis = reshape_for_broadcast(
|
| 47 |
+
freqs_cis, xq_, seq_dim
|
| 48 |
+
).float() # S D/2 2 2 -> 1 S 1 D/2 2 2
|
| 49 |
+
xq_out = (xq_ * freqs_cis).sum(5).flatten(3)
|
| 50 |
+
xk_out = (xk_ * freqs_cis).sum(5).flatten(3)
|
| 51 |
+
return xq_out.type_as(xq), xk_out.type_as(xk)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class RotaryEmbedding(torch.nn.Module):
|
| 55 |
+
"""
|
| 56 |
+
RotaryEmbedding Module
|
| 57 |
+
"""
|
| 58 |
+
|
| 59 |
+
def __init__(
|
| 60 |
+
self,
|
| 61 |
+
theta: float,
|
| 62 |
+
head_dim: int,
|
| 63 |
+
max_seqlen: int = 1024,
|
| 64 |
+
scale_factor: int = 1,
|
| 65 |
+
low_freq_factor: int = 1,
|
| 66 |
+
high_freq_factor: int = 32,
|
| 67 |
+
old_context_len: int = 8192,
|
| 68 |
+
):
|
| 69 |
+
super().__init__()
|
| 70 |
+
|
| 71 |
+
self.theta = theta
|
| 72 |
+
self.head_dim = head_dim
|
| 73 |
+
self.max_seqlen = max_seqlen
|
| 74 |
+
self.scale_factor = scale_factor
|
| 75 |
+
self.low_freq_factor = low_freq_factor
|
| 76 |
+
self.high_freq_factor = high_freq_factor
|
| 77 |
+
self.old_context_len = old_context_len
|
| 78 |
+
if scale_factor != 1:
|
| 79 |
+
self.low_freq_wavelen = old_context_len / low_freq_factor
|
| 80 |
+
self.high_freq_wavelen = old_context_len / high_freq_factor
|
| 81 |
+
assert self.low_freq_wavelen >= self.high_freq_wavelen
|
| 82 |
+
|
| 83 |
+
def reset_parameters(self):
|
| 84 |
+
freqs_cis = self.precompute_freqs_cis(
|
| 85 |
+
dim=self.head_dim, end=self.max_seqlen, theta=self.theta
|
| 86 |
+
)
|
| 87 |
+
S, D, _, _ = freqs_cis.shape
|
| 88 |
+
# S D 2 2 -> 1 S 1 D 2 2
|
| 89 |
+
freqs_cis = freqs_cis.view(1, S, 1, D, 2, 2)
|
| 90 |
+
self.register_buffer(
|
| 91 |
+
"freqs_cis",
|
| 92 |
+
freqs_cis,
|
| 93 |
+
persistent=False,
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
def apply_scaling(self, freqs):
|
| 97 |
+
if self.scale_factor == 1:
|
| 98 |
+
return freqs
|
| 99 |
+
new_freqs = []
|
| 100 |
+
for freq in freqs:
|
| 101 |
+
wavelen = 2 * math.pi / freq
|
| 102 |
+
if wavelen < self.high_freq_wavelen:
|
| 103 |
+
new_freqs.append(freq)
|
| 104 |
+
elif wavelen > self.low_freq_wavelen:
|
| 105 |
+
new_freqs.append(freq / self.scale_factor)
|
| 106 |
+
else:
|
| 107 |
+
assert self.low_freq_wavelen != self.high_freq_wavelen
|
| 108 |
+
smooth = (self.old_context_len / wavelen - self.low_freq_factor) / (
|
| 109 |
+
self.high_freq_factor - self.low_freq_factor
|
| 110 |
+
)
|
| 111 |
+
new_freqs.append(
|
| 112 |
+
(1 - smooth) * freq / self.scale_factor + smooth * freq
|
| 113 |
+
)
|
| 114 |
+
return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
|
| 115 |
+
|
| 116 |
+
def precompute_freqs_cis(
|
| 117 |
+
self,
|
| 118 |
+
dim: int,
|
| 119 |
+
end: int,
|
| 120 |
+
theta: float = 10000.0,
|
| 121 |
+
):
|
| 122 |
+
"""
|
| 123 |
+
Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
|
| 124 |
+
|
| 125 |
+
This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
|
| 126 |
+
and the end index 'end'. The 'theta' parameter scales the frequencies.
|
| 127 |
+
The returned tensor contains complex values in complex64 data type.
|
| 128 |
+
|
| 129 |
+
Args:
|
| 130 |
+
dim (int): Dimension of the frequency tensor.
|
| 131 |
+
end (int): End index for precomputing frequencies.
|
| 132 |
+
theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
|
| 133 |
+
|
| 134 |
+
Returns:
|
| 135 |
+
torch.Tensor: Precomputed frequency tensor with complex exponentials.
|
| 136 |
+
"""
|
| 137 |
+
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
| 138 |
+
freqs = self.apply_scaling(freqs)
|
| 139 |
+
|
| 140 |
+
t = torch.arange(end, device=freqs.device)
|
| 141 |
+
freqs = torch.outer(t, freqs).float()
|
| 142 |
+
|
| 143 |
+
cos, sin = freqs.cos(), freqs.sin()
|
| 144 |
+
|
| 145 |
+
return torch.stack((cos, -sin, sin, cos), dim=-1).view(*freqs.size(), 2, 2)
|
| 146 |
+
|
| 147 |
+
def forward(self, x: torch.Tensor, bhle: bool = False, **kwargs):
|
| 148 |
+
if bhle:
|
| 149 |
+
x = x.transpose(1, 2) # (B H L E) -> (B L H E)
|
| 150 |
+
seqlen = x.size(1)
|
| 151 |
+
x_ = x.reshape(*x.shape[:-1], -1, 1, 2) # B L H E -> B L H E/2 1 2
|
| 152 |
+
x_out = (x_ * self.freqs_cis[:, :seqlen]).sum(5).flatten(3)
|
| 153 |
+
if bhle:
|
| 154 |
+
x_out = x_out.transpose(1, 2)
|
| 155 |
+
return x_out.type_as(x)
|
sam_audio/model/text_encoder.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n
|
| 2 |
+
|
| 3 |
+
from typing import Tuple
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import transformers
|
| 7 |
+
|
| 8 |
+
from sam_audio.model.config import T5EncoderConfig
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class T5TextEncoder(torch.nn.Module):
|
| 12 |
+
def __init__(self, cfg: T5EncoderConfig):
|
| 13 |
+
super().__init__()
|
| 14 |
+
self.model = transformers.T5EncoderModel.from_pretrained(cfg.name)
|
| 15 |
+
self.tokenizer = transformers.AutoTokenizer.from_pretrained(cfg.name)
|
| 16 |
+
self.pad_mode = cfg.pad_mode
|
| 17 |
+
self.max_length = cfg.max_length
|
| 18 |
+
|
| 19 |
+
def forward(self, texts: list[str]) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 20 |
+
device = next(self.model.parameters()).device
|
| 21 |
+
encoded = self.tokenizer(
|
| 22 |
+
texts,
|
| 23 |
+
truncation=True,
|
| 24 |
+
max_length=self.max_length,
|
| 25 |
+
padding=self.pad_mode,
|
| 26 |
+
return_tensors="pt",
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
input_ids = encoded["input_ids"].to(device)
|
| 30 |
+
attention_mask = encoded["attention_mask"].to(device)
|
| 31 |
+
res = self.model(
|
| 32 |
+
input_ids=input_ids,
|
| 33 |
+
attention_mask=attention_mask,
|
| 34 |
+
output_hidden_states=True,
|
| 35 |
+
)["last_hidden_state"]
|
| 36 |
+
|
| 37 |
+
return res, attention_mask.bool()
|
sam_audio/model/transformer.py
ADDED
|
@@ -0,0 +1,524 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
from functools import partial
|
| 5 |
+
from typing import List, Optional, Union
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
from einops import rearrange
|
| 11 |
+
|
| 12 |
+
from .config import TransformerConfig
|
| 13 |
+
from .patcher import Patcher
|
| 14 |
+
from .rope import RotaryEmbedding
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def gate(x, gate):
|
| 18 |
+
return x * gate
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def modulate(x, shift, scale):
|
| 22 |
+
return x * (1 + scale) + shift
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def get_nonlinearity(kind: str):
|
| 26 |
+
return {
|
| 27 |
+
"relu": F.relu,
|
| 28 |
+
"gelu": F.gelu,
|
| 29 |
+
"swiglu": None,
|
| 30 |
+
"approx_gelu": partial(F.gelu, approximate="tanh"),
|
| 31 |
+
"srelu": lambda x: F.relu(x) ** 2,
|
| 32 |
+
"silu": F.silu,
|
| 33 |
+
}[kind]
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class RMSNorm(torch.nn.Module):
|
| 37 |
+
def __init__(self, dim: int, eps: float = 1e-5):
|
| 38 |
+
super().__init__()
|
| 39 |
+
self.eps = eps
|
| 40 |
+
self.weight = torch.nn.Parameter(torch.ones(dim))
|
| 41 |
+
|
| 42 |
+
def _norm(self, x):
|
| 43 |
+
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
| 44 |
+
|
| 45 |
+
def forward(self, x):
|
| 46 |
+
output = self._norm(x.float())
|
| 47 |
+
return (output * self.weight).type_as(x)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class ProjectionLayer(torch.nn.Module):
|
| 51 |
+
def __init__(
|
| 52 |
+
self,
|
| 53 |
+
in_dim: int,
|
| 54 |
+
out_dim: int,
|
| 55 |
+
non_linearity: str,
|
| 56 |
+
dropout: float,
|
| 57 |
+
fc_bias: bool = False,
|
| 58 |
+
):
|
| 59 |
+
super().__init__()
|
| 60 |
+
|
| 61 |
+
self.swiglu = non_linearity == "swiglu"
|
| 62 |
+
self.dropout = dropout
|
| 63 |
+
self.w1 = torch.nn.Linear(in_dim, out_dim, bias=fc_bias)
|
| 64 |
+
|
| 65 |
+
self.w2 = torch.nn.Linear(out_dim, out_dim, bias=fc_bias)
|
| 66 |
+
if self.swiglu:
|
| 67 |
+
self.w3 = torch.nn.Linear(in_dim, out_dim, bias=fc_bias)
|
| 68 |
+
|
| 69 |
+
# non-linearity
|
| 70 |
+
self.non_linearity = get_nonlinearity(non_linearity)
|
| 71 |
+
|
| 72 |
+
def forward(self, x):
|
| 73 |
+
hidden1 = self.w1(x)
|
| 74 |
+
if self.swiglu:
|
| 75 |
+
hidden3 = self.w3(x)
|
| 76 |
+
hidden = F.silu(hidden1) * hidden3
|
| 77 |
+
else:
|
| 78 |
+
hidden = self.non_linearity(hidden1)
|
| 79 |
+
hidden = F.dropout(hidden, p=self.dropout, training=self.training)
|
| 80 |
+
return self.w2(hidden)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class Attention(nn.Module):
|
| 84 |
+
def __init__(
|
| 85 |
+
self,
|
| 86 |
+
dim: int,
|
| 87 |
+
head_dim: int,
|
| 88 |
+
n_heads: int,
|
| 89 |
+
n_kv_heads: int,
|
| 90 |
+
norm_eps: float = 1e-5,
|
| 91 |
+
use_qk_norm: bool = False,
|
| 92 |
+
fc_bias: bool = False,
|
| 93 |
+
):
|
| 94 |
+
super().__init__()
|
| 95 |
+
assert n_heads % n_kv_heads == 0
|
| 96 |
+
|
| 97 |
+
self.head_dim = head_dim
|
| 98 |
+
self.n_heads = n_heads
|
| 99 |
+
self.n_kv_heads = n_kv_heads
|
| 100 |
+
self.use_qk_norm = use_qk_norm
|
| 101 |
+
|
| 102 |
+
self.wq = torch.nn.Linear(dim, n_heads * head_dim, bias=fc_bias)
|
| 103 |
+
self.wk, self.wv = [
|
| 104 |
+
torch.nn.Linear(
|
| 105 |
+
dim,
|
| 106 |
+
n_kv_heads * head_dim,
|
| 107 |
+
bias=fc_bias,
|
| 108 |
+
)
|
| 109 |
+
for _ in range(2)
|
| 110 |
+
]
|
| 111 |
+
self.wo = torch.nn.Linear(
|
| 112 |
+
n_heads * head_dim,
|
| 113 |
+
dim,
|
| 114 |
+
bias=fc_bias,
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
if self.use_qk_norm is True:
|
| 118 |
+
self.q_norm = RMSNorm(head_dim, eps=norm_eps)
|
| 119 |
+
self.k_norm = RMSNorm(head_dim, eps=norm_eps)
|
| 120 |
+
|
| 121 |
+
def reshape_heads(self, x: torch.Tensor, heads: int) -> torch.Tensor:
|
| 122 |
+
B, T, C = x.shape
|
| 123 |
+
# B x T x C -> B x T x C/H x H
|
| 124 |
+
x = x.reshape(B, T, C // heads, heads)
|
| 125 |
+
# B x T x C/H x H -> B x H x T x C/H
|
| 126 |
+
return x.permute(0, 3, 1, 2)
|
| 127 |
+
|
| 128 |
+
def forward(
|
| 129 |
+
self,
|
| 130 |
+
x: torch.Tensor,
|
| 131 |
+
cross_x: Optional[torch.Tensor] = None,
|
| 132 |
+
key_padding_mask: Optional[torch.Tensor] = None,
|
| 133 |
+
rope: Optional[RotaryEmbedding] = None,
|
| 134 |
+
):
|
| 135 |
+
# x: B, T, E
|
| 136 |
+
xq = self.wq(x)
|
| 137 |
+
if cross_x is not None:
|
| 138 |
+
xk, xv = self.wk(cross_x), self.wv(cross_x)
|
| 139 |
+
else:
|
| 140 |
+
xk, xv = self.wk(x), self.wv(x)
|
| 141 |
+
|
| 142 |
+
xk = self.reshape_heads(xk, self.n_kv_heads)
|
| 143 |
+
xv = self.reshape_heads(xv, self.n_kv_heads)
|
| 144 |
+
xq = self.reshape_heads(xq, self.n_heads)
|
| 145 |
+
if self.use_qk_norm:
|
| 146 |
+
xq = self.q_norm(xq)
|
| 147 |
+
xk = self.k_norm(xk)
|
| 148 |
+
|
| 149 |
+
if rope is not None:
|
| 150 |
+
xq = rope(xq, bhle=True)
|
| 151 |
+
xk = rope(xk, bhle=True)
|
| 152 |
+
|
| 153 |
+
attn_mask = None
|
| 154 |
+
|
| 155 |
+
if key_padding_mask is not None:
|
| 156 |
+
attn_mask = key_padding_mask[:, None, None, :]
|
| 157 |
+
|
| 158 |
+
output = F.scaled_dot_product_attention(xq, xk, xv, attn_mask=attn_mask)
|
| 159 |
+
|
| 160 |
+
output = rearrange(output, "b h n d -> b n (h d)")
|
| 161 |
+
return self.wo(output)
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
class FeedForward(torch.nn.Module):
|
| 165 |
+
def __init__(
|
| 166 |
+
self,
|
| 167 |
+
dim: int,
|
| 168 |
+
hidden_dim: int,
|
| 169 |
+
ffn_dim_multiplier: float,
|
| 170 |
+
multiple_of: int,
|
| 171 |
+
dropout: float,
|
| 172 |
+
non_linearity: str = "swiglu",
|
| 173 |
+
fc_bias: bool = False,
|
| 174 |
+
):
|
| 175 |
+
super().__init__()
|
| 176 |
+
self.dropout = dropout
|
| 177 |
+
self.swiglu = non_linearity == "swiglu"
|
| 178 |
+
# swiglu hidden dim factor multiplier (same #params as relu / gelu)
|
| 179 |
+
if self.swiglu:
|
| 180 |
+
hidden_dim = int(2 * hidden_dim / 3)
|
| 181 |
+
|
| 182 |
+
# custom dim factor multiplier
|
| 183 |
+
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
|
| 184 |
+
# round hidden dimension to `multiple_of`
|
| 185 |
+
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
| 186 |
+
# layers
|
| 187 |
+
self.w1 = torch.nn.Linear(dim, hidden_dim, bias=fc_bias)
|
| 188 |
+
self.w2 = torch.nn.Linear(hidden_dim, dim, bias=fc_bias)
|
| 189 |
+
if self.swiglu:
|
| 190 |
+
self.w3 = torch.nn.Linear(dim, hidden_dim, bias=fc_bias)
|
| 191 |
+
|
| 192 |
+
# non-linearity
|
| 193 |
+
self.non_linearity = get_nonlinearity(non_linearity)
|
| 194 |
+
|
| 195 |
+
def forward(
|
| 196 |
+
self,
|
| 197 |
+
x,
|
| 198 |
+
):
|
| 199 |
+
hidden1 = self.w1(x)
|
| 200 |
+
if self.swiglu:
|
| 201 |
+
hidden3 = self.w3(x)
|
| 202 |
+
hidden = F.silu(hidden1) * hidden3
|
| 203 |
+
else:
|
| 204 |
+
hidden = self.non_linearity(hidden1)
|
| 205 |
+
hidden = F.dropout(hidden, p=self.dropout, training=self.training)
|
| 206 |
+
return self.w2(hidden)
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
class TimestepEmbedder(torch.nn.Module):
|
| 210 |
+
def __init__(
|
| 211 |
+
self,
|
| 212 |
+
dim: int,
|
| 213 |
+
frequency_embedding_dim: int,
|
| 214 |
+
non_linearity: str,
|
| 215 |
+
dropout: float,
|
| 216 |
+
fc_bias: bool,
|
| 217 |
+
max_period: int = 10000,
|
| 218 |
+
):
|
| 219 |
+
super().__init__()
|
| 220 |
+
self.frequency_embedding_size = frequency_embedding_dim
|
| 221 |
+
self.projection = ProjectionLayer(
|
| 222 |
+
in_dim=frequency_embedding_dim,
|
| 223 |
+
out_dim=dim,
|
| 224 |
+
non_linearity=non_linearity,
|
| 225 |
+
dropout=dropout,
|
| 226 |
+
fc_bias=fc_bias,
|
| 227 |
+
)
|
| 228 |
+
half = frequency_embedding_dim // 2
|
| 229 |
+
freqs = torch.exp(
|
| 230 |
+
-math.log(max_period)
|
| 231 |
+
* torch.arange(start=0, end=half, dtype=torch.float32)
|
| 232 |
+
/ half
|
| 233 |
+
)
|
| 234 |
+
self.register_buffer("freqs", freqs, persistent=False)
|
| 235 |
+
|
| 236 |
+
def timestep_embedding(self, t, dim):
|
| 237 |
+
"""
|
| 238 |
+
Create sinusoidal timestep embeddings.
|
| 239 |
+
:param t: a 1-D Tensor of N indices, one per batch element.
|
| 240 |
+
These may be fractional.
|
| 241 |
+
:param dim: the dimension of the output.
|
| 242 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
| 243 |
+
:return: an (N, D) Tensor of positional embeddings.
|
| 244 |
+
"""
|
| 245 |
+
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
|
| 246 |
+
self.freqs = self.freqs.to(device=t.device)
|
| 247 |
+
args = t[:, None].float() * self.freqs[None]
|
| 248 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
| 249 |
+
if dim % 2:
|
| 250 |
+
embedding = torch.cat(
|
| 251 |
+
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
|
| 252 |
+
)
|
| 253 |
+
return embedding.to(t)
|
| 254 |
+
|
| 255 |
+
def forward(self, t):
|
| 256 |
+
x = self.timestep_embedding(t, self.frequency_embedding_size)
|
| 257 |
+
return self.projection(x)
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
class ContextEmbedder(torch.nn.Module):
|
| 261 |
+
def __init__(
|
| 262 |
+
self,
|
| 263 |
+
in_dim: int,
|
| 264 |
+
out_dim: int,
|
| 265 |
+
non_linearity: str,
|
| 266 |
+
dropout: float,
|
| 267 |
+
fc_bias: bool,
|
| 268 |
+
norm_eps: float = 1e-5,
|
| 269 |
+
context_norm: bool = False,
|
| 270 |
+
):
|
| 271 |
+
super().__init__()
|
| 272 |
+
self.context_norm = context_norm
|
| 273 |
+
if context_norm:
|
| 274 |
+
self.norm = RMSNorm(in_dim, norm_eps)
|
| 275 |
+
|
| 276 |
+
self.projection = ProjectionLayer(
|
| 277 |
+
in_dim=in_dim,
|
| 278 |
+
out_dim=out_dim,
|
| 279 |
+
non_linearity=non_linearity,
|
| 280 |
+
dropout=dropout,
|
| 281 |
+
fc_bias=fc_bias,
|
| 282 |
+
)
|
| 283 |
+
|
| 284 |
+
def forward(self, x):
|
| 285 |
+
if self.context_norm:
|
| 286 |
+
x = self.norm(x)
|
| 287 |
+
h = self.projection(x)
|
| 288 |
+
return h
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
class DiTBlock(torch.nn.Module):
|
| 292 |
+
def __init__(
|
| 293 |
+
self,
|
| 294 |
+
dim: int,
|
| 295 |
+
n_heads: int,
|
| 296 |
+
n_kv_heads: Optional[int] = None,
|
| 297 |
+
dropout: float = 0.0,
|
| 298 |
+
norm_eps: float = 1e-5,
|
| 299 |
+
qk_norm: bool = False,
|
| 300 |
+
fc_bias: bool = False,
|
| 301 |
+
ffn_exp: int = 1,
|
| 302 |
+
ffn_dim_multiplier: int = 4,
|
| 303 |
+
multiple_of: int = 64,
|
| 304 |
+
non_linearity: str = "silu",
|
| 305 |
+
no_cross_attention: bool = False,
|
| 306 |
+
):
|
| 307 |
+
super().__init__()
|
| 308 |
+
assert dim % n_heads == 0
|
| 309 |
+
self.n_heads = n_heads
|
| 310 |
+
self.n_kv_heads = n_heads if n_kv_heads is None else n_kv_heads
|
| 311 |
+
self.dim = dim
|
| 312 |
+
self.dropout = dropout
|
| 313 |
+
self.head_dim = dim // n_heads
|
| 314 |
+
|
| 315 |
+
assert self.n_heads % self.n_kv_heads == 0
|
| 316 |
+
|
| 317 |
+
self.attention = Attention(
|
| 318 |
+
dim=dim,
|
| 319 |
+
head_dim=self.head_dim,
|
| 320 |
+
n_heads=self.n_heads,
|
| 321 |
+
n_kv_heads=self.n_kv_heads,
|
| 322 |
+
norm_eps=norm_eps,
|
| 323 |
+
use_qk_norm=qk_norm,
|
| 324 |
+
fc_bias=fc_bias,
|
| 325 |
+
)
|
| 326 |
+
self.feed_forward = FeedForward(
|
| 327 |
+
dim=dim,
|
| 328 |
+
hidden_dim=int(ffn_exp * dim),
|
| 329 |
+
ffn_dim_multiplier=ffn_dim_multiplier,
|
| 330 |
+
multiple_of=multiple_of,
|
| 331 |
+
dropout=dropout,
|
| 332 |
+
non_linearity=non_linearity,
|
| 333 |
+
fc_bias=fc_bias,
|
| 334 |
+
)
|
| 335 |
+
|
| 336 |
+
self.attention_norm, self.ffn_norm = [RMSNorm(dim, norm_eps) for _ in range(2)]
|
| 337 |
+
|
| 338 |
+
self.cross_attention = None
|
| 339 |
+
if not no_cross_attention:
|
| 340 |
+
self.cross_attention = Attention(
|
| 341 |
+
dim=dim,
|
| 342 |
+
head_dim=self.head_dim,
|
| 343 |
+
n_heads=self.n_heads,
|
| 344 |
+
n_kv_heads=self.n_heads,
|
| 345 |
+
norm_eps=norm_eps,
|
| 346 |
+
use_qk_norm=qk_norm,
|
| 347 |
+
fc_bias=fc_bias,
|
| 348 |
+
)
|
| 349 |
+
|
| 350 |
+
self.scale_shift_table = nn.Parameter(
|
| 351 |
+
torch.randn(6, self.dim) / self.dim**0.5,
|
| 352 |
+
)
|
| 353 |
+
|
| 354 |
+
def forward(
|
| 355 |
+
self,
|
| 356 |
+
x: torch.Tensor,
|
| 357 |
+
cross_x: Optional[torch.Tensor],
|
| 358 |
+
t: torch.Tensor,
|
| 359 |
+
padding_mask: Optional[torch.Tensor],
|
| 360 |
+
memory_padding_mask: Optional[torch.Tensor],
|
| 361 |
+
rope: Optional[RotaryEmbedding] = None,
|
| 362 |
+
):
|
| 363 |
+
biases = self.scale_shift_table[None] + t.reshape(x.size(0), 6, -1)
|
| 364 |
+
(
|
| 365 |
+
shift_msa,
|
| 366 |
+
scale_msa,
|
| 367 |
+
gate_msa,
|
| 368 |
+
shift_mlp,
|
| 369 |
+
scale_mlp,
|
| 370 |
+
gate_mlp,
|
| 371 |
+
) = biases.chunk(6, dim=1)
|
| 372 |
+
|
| 373 |
+
assert self.attention is not None and self.attention_norm is not None
|
| 374 |
+
h_attn = self.attention(
|
| 375 |
+
modulate(self.attention_norm(x), shift_msa, scale_msa),
|
| 376 |
+
key_padding_mask=padding_mask,
|
| 377 |
+
rope=rope,
|
| 378 |
+
)
|
| 379 |
+
|
| 380 |
+
h = x + gate(h_attn, gate_msa)
|
| 381 |
+
|
| 382 |
+
if self.cross_attention is not None:
|
| 383 |
+
h_cross = self.cross_attention(
|
| 384 |
+
x=h,
|
| 385 |
+
cross_x=cross_x,
|
| 386 |
+
key_padding_mask=memory_padding_mask,
|
| 387 |
+
)
|
| 388 |
+
h = h + h_cross # residual
|
| 389 |
+
h_ff = self.feed_forward(modulate(self.ffn_norm(h), shift_mlp, scale_mlp))
|
| 390 |
+
out = h + gate(h_ff, gate_mlp)
|
| 391 |
+
return out
|
| 392 |
+
|
| 393 |
+
|
| 394 |
+
class DiT(torch.nn.Module):
|
| 395 |
+
def __init__(self, config: TransformerConfig):
|
| 396 |
+
super().__init__()
|
| 397 |
+
self.dropout = config.dropout
|
| 398 |
+
if config.in_channels is not None:
|
| 399 |
+
self.data_proj = torch.nn.Linear(config.in_channels, config.dim)
|
| 400 |
+
|
| 401 |
+
# embeddings
|
| 402 |
+
self.rope_embeddings = None
|
| 403 |
+
# rotary embeddings
|
| 404 |
+
if config.use_rope:
|
| 405 |
+
self.rope_embeddings = RotaryEmbedding(
|
| 406 |
+
theta=max(10000, 2 * config.max_positions),
|
| 407 |
+
head_dim=config.dim // config.n_heads,
|
| 408 |
+
max_seqlen=config.max_positions,
|
| 409 |
+
)
|
| 410 |
+
self.rope_embeddings.reset_parameters()
|
| 411 |
+
|
| 412 |
+
# transformer blocks
|
| 413 |
+
self.layers = nn.ModuleList()
|
| 414 |
+
for _ in range(config.n_layers):
|
| 415 |
+
self.layers.append(
|
| 416 |
+
DiTBlock(
|
| 417 |
+
dim=config.dim,
|
| 418 |
+
n_heads=config.n_heads,
|
| 419 |
+
dropout=config.dropout,
|
| 420 |
+
norm_eps=config.norm_eps,
|
| 421 |
+
qk_norm=config.qk_norm,
|
| 422 |
+
fc_bias=config.fc_bias,
|
| 423 |
+
ffn_exp=config.ffn_exp,
|
| 424 |
+
ffn_dim_multiplier=config.ffn_dim_multiplier,
|
| 425 |
+
multiple_of=config.multiple_of,
|
| 426 |
+
non_linearity=config.non_linearity,
|
| 427 |
+
)
|
| 428 |
+
)
|
| 429 |
+
|
| 430 |
+
self.norm = RMSNorm(config.dim, config.norm_eps)
|
| 431 |
+
|
| 432 |
+
# output layer
|
| 433 |
+
self.output = torch.nn.Linear(
|
| 434 |
+
config.dim, config.out_channels, bias=config.fc_bias
|
| 435 |
+
)
|
| 436 |
+
|
| 437 |
+
self.x_embedder = Patcher(
|
| 438 |
+
in_channels=config.dim,
|
| 439 |
+
out_channels=config.dim,
|
| 440 |
+
patch_size=1,
|
| 441 |
+
)
|
| 442 |
+
|
| 443 |
+
self.y_embedder = ContextEmbedder(
|
| 444 |
+
in_dim=config.context_dim,
|
| 445 |
+
out_dim=config.dim,
|
| 446 |
+
non_linearity=config.context_non_linearity,
|
| 447 |
+
dropout=config.context_embedder_dropout,
|
| 448 |
+
fc_bias=config.fc_bias,
|
| 449 |
+
norm_eps=config.norm_eps,
|
| 450 |
+
context_norm=config.context_norm,
|
| 451 |
+
)
|
| 452 |
+
|
| 453 |
+
self.t_embedder = TimestepEmbedder(
|
| 454 |
+
config.dim,
|
| 455 |
+
config.frequency_embedding_dim,
|
| 456 |
+
non_linearity=config.timestep_non_linearity,
|
| 457 |
+
dropout=config.dropout,
|
| 458 |
+
fc_bias=config.fc_bias,
|
| 459 |
+
max_period=10000,
|
| 460 |
+
)
|
| 461 |
+
|
| 462 |
+
self.t_block_non_linearity = get_nonlinearity(config.t_block_non_linearity)
|
| 463 |
+
self.t_block = torch.nn.Linear(
|
| 464 |
+
config.dim,
|
| 465 |
+
config.dim * 6,
|
| 466 |
+
bias=config.t_block_bias,
|
| 467 |
+
)
|
| 468 |
+
|
| 469 |
+
self.final_layer_scale_shift_table = nn.Parameter(
|
| 470 |
+
torch.randn(2, config.dim) / config.dim**0.5,
|
| 471 |
+
)
|
| 472 |
+
|
| 473 |
+
def forward(
|
| 474 |
+
self,
|
| 475 |
+
x: torch.Tensor,
|
| 476 |
+
time: torch.Tensor,
|
| 477 |
+
*,
|
| 478 |
+
padding_mask: Optional[torch.Tensor] = None,
|
| 479 |
+
memory: Optional[torch.Tensor] = None,
|
| 480 |
+
memory_padding_mask: Optional[torch.Tensor] = None,
|
| 481 |
+
) -> Union[torch.Tensor, List[torch.Tensor]]:
|
| 482 |
+
x = rearrange(x, "b l c-> b c l")
|
| 483 |
+
h = self.x_embedder(x)
|
| 484 |
+
h = rearrange(h, "b c l -> b l c")
|
| 485 |
+
original_N = h.shape[1]
|
| 486 |
+
N = h.shape[1]
|
| 487 |
+
|
| 488 |
+
h = F.dropout(h, p=self.dropout, training=self.training)
|
| 489 |
+
|
| 490 |
+
t = self.t_embedder(time) # B -> B D
|
| 491 |
+
|
| 492 |
+
t0 = self.t_block_non_linearity(t)
|
| 493 |
+
t0 = self.t_block(t0) # B D -> B 6D
|
| 494 |
+
|
| 495 |
+
y = self.y_embedder(memory)
|
| 496 |
+
|
| 497 |
+
for layer in self.layers:
|
| 498 |
+
h = layer(
|
| 499 |
+
x=h,
|
| 500 |
+
cross_x=y,
|
| 501 |
+
t=t0,
|
| 502 |
+
padding_mask=padding_mask,
|
| 503 |
+
memory_padding_mask=memory_padding_mask,
|
| 504 |
+
rope=self.rope_embeddings,
|
| 505 |
+
)
|
| 506 |
+
|
| 507 |
+
shift, scale = (self.final_layer_scale_shift_table[None] + t[:, None]).chunk(
|
| 508 |
+
2, dim=1
|
| 509 |
+
)
|
| 510 |
+
|
| 511 |
+
# output layer
|
| 512 |
+
if self.norm is not None:
|
| 513 |
+
h = self.norm(h)
|
| 514 |
+
|
| 515 |
+
h = modulate(h, shift, scale)
|
| 516 |
+
|
| 517 |
+
h = F.dropout(h, p=self.dropout, training=self.training)
|
| 518 |
+
|
| 519 |
+
output = self.output(h)
|
| 520 |
+
|
| 521 |
+
N = output.shape[1]
|
| 522 |
+
if original_N != N:
|
| 523 |
+
output = output[:, -original_N:]
|
| 524 |
+
return output
|
sam_audio/model/vision_encoder.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n
|
| 2 |
+
|
| 3 |
+
from abc import ABCMeta, abstractmethod
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torchvision
|
| 7 |
+
from core.vision_encoder import pe
|
| 8 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 9 |
+
|
| 10 |
+
from sam_audio.model.config import (
|
| 11 |
+
PerceptionEncoderConfig,
|
| 12 |
+
VisionEncoderConfig,
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class RescaleTransform(object):
|
| 17 |
+
"""Rescale the image in a sample to a given size.
|
| 18 |
+
|
| 19 |
+
Args:
|
| 20 |
+
output_size (tuple or int): Desired output size. If tuple, output is
|
| 21 |
+
matched to output_size. If int, smaller of image edges is matched
|
| 22 |
+
to output_size keeping aspect ratio the same.
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
def __init__(self, output_size, interpolation):
|
| 26 |
+
assert isinstance(output_size, (int, tuple))
|
| 27 |
+
self.output_size = output_size
|
| 28 |
+
if isinstance(output_size, int):
|
| 29 |
+
self.output_size = (output_size, output_size)
|
| 30 |
+
self.interpolation = interpolation
|
| 31 |
+
|
| 32 |
+
def __call__(self, sample):
|
| 33 |
+
# sample: [T, C, H, W]
|
| 34 |
+
sample = torch.nn.functional.interpolate(
|
| 35 |
+
sample.float(), size=self.output_size, mode=self.interpolation.value
|
| 36 |
+
)
|
| 37 |
+
return sample
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class VisionEncoder(torch.nn.Module, metaclass=ABCMeta):
|
| 41 |
+
def __init__(self, cfg: VisionEncoderConfig):
|
| 42 |
+
super().__init__()
|
| 43 |
+
self.batch_size = cfg.batch_size
|
| 44 |
+
self.dim = cfg.dim
|
| 45 |
+
self.transform = self.get_transform()
|
| 46 |
+
|
| 47 |
+
@torch.no_grad()
|
| 48 |
+
def forward(self, videos: list[torch.Tensor]) -> torch.Tensor:
|
| 49 |
+
"""
|
| 50 |
+
Encodes a list of input videos. Each element of the list is a video represented
|
| 51 |
+
as a tensor [T, C, H, W]
|
| 52 |
+
Args:
|
| 53 |
+
videos (list[torch.Tensor]): List of input image tensors to be processed.
|
| 54 |
+
|
| 55 |
+
Returns:
|
| 56 |
+
torch.Tensor: Encoded feature representations of the input tensors.
|
| 57 |
+
The output is padded along the time dimension for variable length videos
|
| 58 |
+
"""
|
| 59 |
+
result = []
|
| 60 |
+
for video in videos:
|
| 61 |
+
video = self.transform(video)
|
| 62 |
+
if self.batch_size > 0 and video.size(0) > self.batch_size:
|
| 63 |
+
res = []
|
| 64 |
+
for i in range(0, video.size(0), self.batch_size):
|
| 65 |
+
res.append(self.encode(video[i : i + self.batch_size]))
|
| 66 |
+
result.append(torch.cat(res, dim=0))
|
| 67 |
+
else:
|
| 68 |
+
result.append(self.encode(video))
|
| 69 |
+
return pad_sequence(result, batch_first=True, padding_value=0.0)
|
| 70 |
+
|
| 71 |
+
@abstractmethod
|
| 72 |
+
def encode(self, x: torch.Tensor) -> torch.Tensor:
|
| 73 |
+
pass
|
| 74 |
+
|
| 75 |
+
@abstractmethod
|
| 76 |
+
def get_transform(self):
|
| 77 |
+
pass
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class PerceptionEncoder(VisionEncoder):
|
| 81 |
+
def __init__(self, cfg: PerceptionEncoderConfig):
|
| 82 |
+
self.normalize_feature = cfg.normalize_feature
|
| 83 |
+
self.interpolation_mode = cfg.interpolation_mode
|
| 84 |
+
self.image_size = cfg.image_size
|
| 85 |
+
super().__init__(cfg)
|
| 86 |
+
self.model = pe.CLIP.from_config(cfg.name)
|
| 87 |
+
|
| 88 |
+
def encode(self, x):
|
| 89 |
+
image_features = self.model.encode_image(x, normalize=self.normalize_feature)
|
| 90 |
+
return image_features
|
| 91 |
+
|
| 92 |
+
def get_transform(self):
|
| 93 |
+
T = torchvision.transforms
|
| 94 |
+
try:
|
| 95 |
+
interp = getattr(T.InterpolationMode, self.interpolation_mode.upper())
|
| 96 |
+
except AttributeError as err:
|
| 97 |
+
raise ValueError(
|
| 98 |
+
f"Unsupported interpolation_mode: {self.interpolation_mode}"
|
| 99 |
+
) from err
|
| 100 |
+
crop = [
|
| 101 |
+
T.Resize(
|
| 102 |
+
(self.image_size, self.image_size),
|
| 103 |
+
interpolation=interp,
|
| 104 |
+
)
|
| 105 |
+
]
|
| 106 |
+
|
| 107 |
+
return T.Compose(
|
| 108 |
+
crop
|
| 109 |
+
+ [
|
| 110 |
+
T.Lambda(lambda x: x.float() / 255.0),
|
| 111 |
+
T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5], inplace=True),
|
| 112 |
+
]
|
| 113 |
+
)
|
sam_audio/processor.py
ADDED
|
@@ -0,0 +1,382 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import logging
|
| 5 |
+
import math
|
| 6 |
+
import os
|
| 7 |
+
from typing import Callable, List, Optional, Tuple
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torchaudio
|
| 11 |
+
from huggingface_hub import hf_hub_download
|
| 12 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 13 |
+
from torchcodec.decoders import AudioDecoder, VideoDecoder
|
| 14 |
+
from transformers import AutoTokenizer, BatchFeature
|
| 15 |
+
|
| 16 |
+
from sam_audio.model.config import SAMAudioConfig, SAMAudioJudgeConfig
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
Anchor = Tuple[str, float, float]
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def batch_audio(
|
| 24 |
+
audios: list[str | torch.Tensor], audio_sampling_rate: int = 48_000
|
| 25 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 26 |
+
wavs = []
|
| 27 |
+
for audio in audios:
|
| 28 |
+
if isinstance(audio, str):
|
| 29 |
+
wav, sr = torchaudio.load(audio)
|
| 30 |
+
if sr != audio_sampling_rate:
|
| 31 |
+
wav = torchaudio.functional.resample(wav, sr, audio_sampling_rate)
|
| 32 |
+
else:
|
| 33 |
+
wav = audio
|
| 34 |
+
wavs.append(wav.mean(0))
|
| 35 |
+
sizes = torch.tensor([wav.size(-1) for wav in wavs])
|
| 36 |
+
return pad_sequence(wavs, batch_first=True).unsqueeze(1), sizes
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class Batch:
|
| 40 |
+
def __init__(
|
| 41 |
+
self,
|
| 42 |
+
audios: torch.Tensor,
|
| 43 |
+
sizes: torch.Tensor,
|
| 44 |
+
wav_sizes: torch.Tensor,
|
| 45 |
+
descriptions: list[str],
|
| 46 |
+
hop_length: int,
|
| 47 |
+
audio_sampling_rate: int,
|
| 48 |
+
anchors: Optional[list[list[Anchor]]] = None,
|
| 49 |
+
audio_pad_mask: Optional[torch.Tensor] = None,
|
| 50 |
+
masked_video: Optional[torch.Tensor] = None,
|
| 51 |
+
):
|
| 52 |
+
self.audios = audios
|
| 53 |
+
self.sizes = sizes
|
| 54 |
+
self.wav_sizes = wav_sizes
|
| 55 |
+
self.descriptions = descriptions
|
| 56 |
+
self.audio_pad_mask = audio_pad_mask
|
| 57 |
+
self.masked_video = masked_video
|
| 58 |
+
self.hop_length = hop_length
|
| 59 |
+
self.audio_sampling_rate = audio_sampling_rate
|
| 60 |
+
self.process_anchors(anchors)
|
| 61 |
+
assert self.audios.size(0) == len(self.descriptions)
|
| 62 |
+
|
| 63 |
+
def _wav_to_feature_idx(self, wav_idx: int):
|
| 64 |
+
return math.ceil(wav_idx / self.hop_length)
|
| 65 |
+
|
| 66 |
+
def to(self, device: torch.device):
|
| 67 |
+
self.audios = self.audios.to(device)
|
| 68 |
+
self.anchor_ids = self.anchor_ids.to(device)
|
| 69 |
+
self.anchor_alignment = self.anchor_alignment.to(device)
|
| 70 |
+
self.sizes = self.sizes.to(device)
|
| 71 |
+
self.wav_sizes = self.wav_sizes.to(device)
|
| 72 |
+
if self.audio_pad_mask is not None:
|
| 73 |
+
self.audio_pad_mask = self.audio_pad_mask.to(device)
|
| 74 |
+
if self.masked_video is not None:
|
| 75 |
+
self.masked_video = [v.to(device) for v in self.masked_video]
|
| 76 |
+
return self
|
| 77 |
+
|
| 78 |
+
def process_anchors(self, anchors: Optional[list[list[Anchor]]]):
|
| 79 |
+
batch_size = len(self.audios)
|
| 80 |
+
anchor_dict = {"<null>": 0, "+": 1, "-": 2, "<pad>": 3}
|
| 81 |
+
if anchors is None:
|
| 82 |
+
anchor_ids = torch.full(
|
| 83 |
+
(batch_size, 2), anchor_dict["<null>"], dtype=torch.long
|
| 84 |
+
)
|
| 85 |
+
anchor_ids[:, 1] = anchor_dict["<pad>"]
|
| 86 |
+
anchor_alignment = torch.full(
|
| 87 |
+
(
|
| 88 |
+
batch_size,
|
| 89 |
+
self.audio_pad_mask.size(-1),
|
| 90 |
+
),
|
| 91 |
+
0,
|
| 92 |
+
dtype=torch.long,
|
| 93 |
+
)
|
| 94 |
+
anchor_alignment[~self.audio_pad_mask] = 1 # point to pad token
|
| 95 |
+
else:
|
| 96 |
+
anchor_alignment = torch.full(
|
| 97 |
+
(
|
| 98 |
+
batch_size,
|
| 99 |
+
self.audio_pad_mask.size(-1),
|
| 100 |
+
),
|
| 101 |
+
0,
|
| 102 |
+
dtype=torch.long,
|
| 103 |
+
)
|
| 104 |
+
anchor_alignment[~self.audio_pad_mask] = 1 # point to pad token
|
| 105 |
+
ids = []
|
| 106 |
+
|
| 107 |
+
for i, anchor_list in enumerate(anchors):
|
| 108 |
+
current = [anchor_dict["<null>"], anchor_dict["<pad>"]]
|
| 109 |
+
for token, start_time, end_time in anchor_list:
|
| 110 |
+
start_idx = self._wav_to_feature_idx(
|
| 111 |
+
start_time * self.audio_sampling_rate
|
| 112 |
+
)
|
| 113 |
+
end_idx = self._wav_to_feature_idx(
|
| 114 |
+
end_time * self.audio_sampling_rate
|
| 115 |
+
)
|
| 116 |
+
anchor_alignment[i, start_idx:end_idx] = len(current)
|
| 117 |
+
current.append(anchor_dict[token])
|
| 118 |
+
ids.append(torch.tensor(current))
|
| 119 |
+
anchor_ids = pad_sequence(
|
| 120 |
+
ids, batch_first=True, padding_value=anchor_dict["<pad>"]
|
| 121 |
+
)
|
| 122 |
+
self.anchor_ids = anchor_ids
|
| 123 |
+
self.anchor_alignment = anchor_alignment
|
| 124 |
+
self.anchors = anchors
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def mask_from_sizes(sizes: torch.Tensor) -> torch.Tensor:
|
| 128 |
+
return torch.arange(sizes.max()).expand(len(sizes), -1) < sizes.unsqueeze(1)
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def load_video(
|
| 132 |
+
sizes: torch.Tensor,
|
| 133 |
+
videos: List[str],
|
| 134 |
+
feature_idx_to_wav_idx: Callable[[torch.Tensor], torch.Tensor],
|
| 135 |
+
audio_sampling_rate: int,
|
| 136 |
+
) -> list[torch.Tensor]:
|
| 137 |
+
all_frames = []
|
| 138 |
+
for size, video in zip(sizes, videos, strict=False):
|
| 139 |
+
audio_timestamps = (
|
| 140 |
+
feature_idx_to_wav_idx(torch.arange(size)) / audio_sampling_rate
|
| 141 |
+
)
|
| 142 |
+
if isinstance(video, str):
|
| 143 |
+
decoder = VideoDecoder(video, dimension_order="NCHW")
|
| 144 |
+
data = decoder.get_frames_in_range(0, len(decoder))
|
| 145 |
+
diffs = (audio_timestamps[None] - data.pts_seconds[:, None]).abs()
|
| 146 |
+
frame_idxs = diffs.argmin(dim=0)
|
| 147 |
+
frames = data.data[frame_idxs]
|
| 148 |
+
else:
|
| 149 |
+
assert video.size(1) == 3, (
|
| 150 |
+
f"Expected video tensor to be in NCHW format, but found {video.size(1)} channels"
|
| 151 |
+
)
|
| 152 |
+
idx = torch.linspace(0, video.size(0) - 1, int(size)).round().long()
|
| 153 |
+
frames = video[idx]
|
| 154 |
+
all_frames.append(frames)
|
| 155 |
+
return all_frames
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
class Processor:
|
| 159 |
+
config_cls: Callable
|
| 160 |
+
|
| 161 |
+
def __init__(self, audio_hop_length: int, audio_sampling_rate: int):
|
| 162 |
+
self.audio_hop_length = audio_hop_length
|
| 163 |
+
self.audio_sampling_rate = audio_sampling_rate
|
| 164 |
+
|
| 165 |
+
@classmethod
|
| 166 |
+
def _get_config(cls, model_name_or_path: str):
|
| 167 |
+
if os.path.exists(model_name_or_path):
|
| 168 |
+
config_path = os.path.join(model_name_or_path, "config.json")
|
| 169 |
+
else:
|
| 170 |
+
config_path = hf_hub_download(
|
| 171 |
+
repo_id=model_name_or_path,
|
| 172 |
+
filename="config.json",
|
| 173 |
+
revision=cls.revision,
|
| 174 |
+
)
|
| 175 |
+
with open(config_path) as fin:
|
| 176 |
+
config = cls.config_cls(**json.load(fin))
|
| 177 |
+
return config
|
| 178 |
+
|
| 179 |
+
@classmethod
|
| 180 |
+
def from_pretrained(cls, model_name_or_path: str) -> "Processor":
|
| 181 |
+
config = cls._get_config(model_name_or_path)
|
| 182 |
+
return cls(
|
| 183 |
+
audio_hop_length=config.audio_codec.hop_length,
|
| 184 |
+
audio_sampling_rate=config.audio_codec.sample_rate,
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
def feature_to_wav_idx(self, feature_idx):
|
| 188 |
+
return feature_idx * self.audio_hop_length
|
| 189 |
+
|
| 190 |
+
def wav_to_feature_idx(self, wav_idx):
|
| 191 |
+
if torch.is_tensor(wav_idx):
|
| 192 |
+
ceil = torch.ceil
|
| 193 |
+
else:
|
| 194 |
+
ceil = math.ceil
|
| 195 |
+
return ceil(wav_idx / self.audio_hop_length)
|
| 196 |
+
|
| 197 |
+
def mask_videos(
|
| 198 |
+
self,
|
| 199 |
+
videos: List[str | torch.Tensor],
|
| 200 |
+
masks: List[str | torch.Tensor],
|
| 201 |
+
) -> list[torch.Tensor]:
|
| 202 |
+
video = [VideoDecoder(v)[:] if isinstance(v, str) else v for v in videos]
|
| 203 |
+
video_mask = [VideoDecoder(v)[:] if isinstance(v, str) else v for v in masks]
|
| 204 |
+
return [v * m.eq(0) for v, m in zip(video, video_mask, strict=False)]
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
class SAMAudioProcessor(Processor):
|
| 208 |
+
config_cls = SAMAudioConfig
|
| 209 |
+
revision = None
|
| 210 |
+
|
| 211 |
+
def __call__(
|
| 212 |
+
self,
|
| 213 |
+
descriptions: list[str],
|
| 214 |
+
audios: list[str | torch.Tensor],
|
| 215 |
+
anchors: Optional[list[list[Anchor]]] = None,
|
| 216 |
+
masked_videos: Optional[list[str | torch.Tensor]] = None,
|
| 217 |
+
):
|
| 218 |
+
"""
|
| 219 |
+
Processes input data for the model.
|
| 220 |
+
|
| 221 |
+
Args:
|
| 222 |
+
descriptions (list[str]): List of text descriptions corresponding to each audio sample.
|
| 223 |
+
audios (list[str]): List of audio file paths or tensors.
|
| 224 |
+
If a tensor:
|
| 225 |
+
- should have shape (channels, time) where channels=1 for mono and 2 for stereo.
|
| 226 |
+
- should be resampled to 48_000 hz
|
| 227 |
+
anchors (Optional[list[list[Anchor]]], optional): List of anchors for each sample,
|
| 228 |
+
where each anchor is a tuple (token, start_time, end_time).
|
| 229 |
+
masked_videos (Optional[list[str | torch.Tensor]], optional): List of masked video file paths or tensors.
|
| 230 |
+
If a tensor, should have shape (N, C, H, W)
|
| 231 |
+
|
| 232 |
+
Returns:
|
| 233 |
+
Batch: A Batch object containing processed audio, sizes, descriptions, anchor ids, anchor alignment, audio pad mask, and optionally masked video.
|
| 234 |
+
"""
|
| 235 |
+
|
| 236 |
+
assert len(descriptions) == len(audios)
|
| 237 |
+
assert anchors is None or len(descriptions) == len(anchors)
|
| 238 |
+
assert masked_videos is None or len(descriptions) == len(masked_videos)
|
| 239 |
+
|
| 240 |
+
audios, wav_sizes = batch_audio(audios, self.audio_sampling_rate)
|
| 241 |
+
|
| 242 |
+
sizes = self.wav_to_feature_idx(wav_sizes)
|
| 243 |
+
audio_pad_mask = mask_from_sizes(sizes)
|
| 244 |
+
masked_video = None
|
| 245 |
+
if masked_videos is not None:
|
| 246 |
+
masked_video = load_video(
|
| 247 |
+
sizes, masked_videos, self.feature_to_wav_idx, self.audio_sampling_rate
|
| 248 |
+
)
|
| 249 |
+
|
| 250 |
+
return Batch(
|
| 251 |
+
audios=audios,
|
| 252 |
+
sizes=sizes,
|
| 253 |
+
descriptions=descriptions,
|
| 254 |
+
audio_pad_mask=audio_pad_mask,
|
| 255 |
+
anchors=anchors,
|
| 256 |
+
masked_video=masked_video,
|
| 257 |
+
hop_length=self.audio_hop_length,
|
| 258 |
+
audio_sampling_rate=self.audio_sampling_rate,
|
| 259 |
+
wav_sizes=wav_sizes,
|
| 260 |
+
)
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
class SAMAudioJudgeProcessor(Processor):
|
| 264 |
+
config_cls = SAMAudioJudgeConfig
|
| 265 |
+
revision = "sam_audio"
|
| 266 |
+
|
| 267 |
+
def __init__(
|
| 268 |
+
self,
|
| 269 |
+
audio_hop_length: int,
|
| 270 |
+
audio_sampling_rate: int,
|
| 271 |
+
tokenizer: AutoTokenizer,
|
| 272 |
+
):
|
| 273 |
+
super().__init__(audio_hop_length, audio_sampling_rate)
|
| 274 |
+
self.tokenizer = tokenizer
|
| 275 |
+
|
| 276 |
+
@classmethod
|
| 277 |
+
def from_pretrained(cls, model_name_or_path: str) -> "SAMAudioJudgeProcessor":
|
| 278 |
+
config = cls._get_config(model_name_or_path)
|
| 279 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
|
| 280 |
+
return cls(
|
| 281 |
+
audio_hop_length=config.audio_codec.hop_length,
|
| 282 |
+
audio_sampling_rate=config.audio_codec.sample_rate,
|
| 283 |
+
tokenizer=tokenizer,
|
| 284 |
+
)
|
| 285 |
+
|
| 286 |
+
def _reflect_pad(self, wav):
|
| 287 |
+
if wav.ndim == 1:
|
| 288 |
+
wav = wav.unsqueeze(0)
|
| 289 |
+
if wav.size(-1) % self.audio_hop_length == 0:
|
| 290 |
+
return wav
|
| 291 |
+
p1d = (0, self.audio_hop_length - (wav.size(-1) % self.audio_hop_length))
|
| 292 |
+
return torch.nn.functional.pad(wav, p1d, mode="reflect")
|
| 293 |
+
|
| 294 |
+
def _load_audio(self, path: str):
|
| 295 |
+
ad = AudioDecoder(path, sample_rate=self.audio_sampling_rate, num_channels=1)
|
| 296 |
+
return ad.get_all_samples().data
|
| 297 |
+
|
| 298 |
+
def _process_audio(
|
| 299 |
+
self,
|
| 300 |
+
raw_audio,
|
| 301 |
+
sampling_rate: Optional[int] = None,
|
| 302 |
+
):
|
| 303 |
+
from_file = False
|
| 304 |
+
if isinstance(raw_audio, str):
|
| 305 |
+
raw_audio = [raw_audio]
|
| 306 |
+
|
| 307 |
+
if isinstance(raw_audio, (list, tuple)) and isinstance(raw_audio[0], str):
|
| 308 |
+
loaded = []
|
| 309 |
+
for audio_file in raw_audio:
|
| 310 |
+
loaded.append(self._load_audio(audio_file))
|
| 311 |
+
raw_audio = loaded
|
| 312 |
+
from_file = True
|
| 313 |
+
|
| 314 |
+
if sampling_rate is not None:
|
| 315 |
+
if sampling_rate != self.audio_sampling_rate:
|
| 316 |
+
raise ValueError(
|
| 317 |
+
f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
|
| 318 |
+
f" {self.audio_sampling_rate}. Please make sure that the provided audio input was sampled with"
|
| 319 |
+
f" {self.audio_sampling_rate} and not {sampling_rate}."
|
| 320 |
+
)
|
| 321 |
+
elif not from_file:
|
| 322 |
+
logger.warning(
|
| 323 |
+
f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. "
|
| 324 |
+
"Failing to do so can result in silent errors that might be hard to debug."
|
| 325 |
+
)
|
| 326 |
+
|
| 327 |
+
if isinstance(raw_audio, list):
|
| 328 |
+
raw_audio = [self._reflect_pad(x).T for x in raw_audio]
|
| 329 |
+
else:
|
| 330 |
+
raw_audio = self._reflect_pad(raw_audio).T
|
| 331 |
+
|
| 332 |
+
# verify inputs are valid
|
| 333 |
+
for example in raw_audio:
|
| 334 |
+
if example.ndim > 2:
|
| 335 |
+
raise ValueError(
|
| 336 |
+
f"Expected input shape (channels, num_samples), but got shape ({example.shape})"
|
| 337 |
+
)
|
| 338 |
+
|
| 339 |
+
lengths = torch.tensor([x.size(0) for x in raw_audio])
|
| 340 |
+
input_values = pad_sequence(raw_audio, batch_first=True).transpose(1, 2)
|
| 341 |
+
padding_mask = torch.arange(lengths.max())[None] < lengths[:, None]
|
| 342 |
+
|
| 343 |
+
return BatchFeature(
|
| 344 |
+
{"input_values": input_values, "padding_mask": padding_mask}
|
| 345 |
+
)
|
| 346 |
+
|
| 347 |
+
def __call__(
|
| 348 |
+
self,
|
| 349 |
+
text: Optional[str] = None,
|
| 350 |
+
input_audio: Optional[
|
| 351 |
+
str | list[str] | torch.Tensor | list[torch.Tensor]
|
| 352 |
+
] = None,
|
| 353 |
+
separated_audio: Optional[
|
| 354 |
+
str | list[str] | torch.Tensor | list[torch.Tensor]
|
| 355 |
+
] = None,
|
| 356 |
+
sampling_rate: Optional[int] = None,
|
| 357 |
+
**kwargs,
|
| 358 |
+
):
|
| 359 |
+
batch = BatchFeature()
|
| 360 |
+
if text is not None:
|
| 361 |
+
batch.update(
|
| 362 |
+
self.tokenizer(
|
| 363 |
+
text,
|
| 364 |
+
return_tensors="pt",
|
| 365 |
+
padding="longest",
|
| 366 |
+
max_length=512,
|
| 367 |
+
truncation=True,
|
| 368 |
+
)
|
| 369 |
+
)
|
| 370 |
+
|
| 371 |
+
if input_audio is not None:
|
| 372 |
+
batch.update(self._process_audio(input_audio, sampling_rate))
|
| 373 |
+
|
| 374 |
+
if separated_audio is not None:
|
| 375 |
+
batch["separated_values"] = self._process_audio(
|
| 376 |
+
separated_audio, sampling_rate
|
| 377 |
+
)["input_values"]
|
| 378 |
+
|
| 379 |
+
return batch
|
| 380 |
+
|
| 381 |
+
|
| 382 |
+
__all__ = ["SAMAudioProcessor", "SAMAudioJudgeProcessor", "Batch"]
|
sam_audio/ranking/__init__.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n
|
| 2 |
+
|
| 3 |
+
from sam_audio.model.config import (
|
| 4 |
+
ClapRankerConfig,
|
| 5 |
+
EnsembleRankerConfig,
|
| 6 |
+
ImageBindRankerConfig,
|
| 7 |
+
JudgeRankerConfig,
|
| 8 |
+
)
|
| 9 |
+
from sam_audio.ranking.clap import ClapRanker
|
| 10 |
+
from sam_audio.ranking.imagebind import ImageBindRanker
|
| 11 |
+
from sam_audio.ranking.judge import JudgeRanker
|
| 12 |
+
from sam_audio.ranking.ranker import EnsembleRanker
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def create_ranker(config):
|
| 16 |
+
if isinstance(config, ImageBindRankerConfig):
|
| 17 |
+
return ImageBindRanker(config)
|
| 18 |
+
elif isinstance(config, ClapRankerConfig):
|
| 19 |
+
return ClapRanker(config)
|
| 20 |
+
elif isinstance(config, JudgeRankerConfig):
|
| 21 |
+
return JudgeRanker(config)
|
| 22 |
+
elif isinstance(config, EnsembleRankerConfig):
|
| 23 |
+
ranker_cfgs, weights = zip(*config.rankers.values(), strict=False)
|
| 24 |
+
return EnsembleRanker(
|
| 25 |
+
rankers=[create_ranker(cfg) for cfg in ranker_cfgs],
|
| 26 |
+
weights=weights,
|
| 27 |
+
)
|
| 28 |
+
else:
|
| 29 |
+
assert config is None
|
| 30 |
+
return None
|
sam_audio/ranking/clap.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torchaudio
|
| 5 |
+
from huggingface_hub import hf_hub_download
|
| 6 |
+
|
| 7 |
+
from sam_audio.model.config import ClapRankerConfig
|
| 8 |
+
from sam_audio.ranking.ranker import Ranker
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def get_model(checkpoint_file=None, device="cpu"):
|
| 12 |
+
import laion_clap
|
| 13 |
+
|
| 14 |
+
model = laion_clap.CLAP_Module(enable_fusion=False, amodel="HTSAT-tiny").to(device)
|
| 15 |
+
|
| 16 |
+
if checkpoint_file is None:
|
| 17 |
+
checkpoint_file = hf_hub_download(
|
| 18 |
+
repo_id="lukewys/laion_clap", filename="630k-best.pt"
|
| 19 |
+
)
|
| 20 |
+
state_dict = torch.load(checkpoint_file, map_location=device, weights_only=False)[
|
| 21 |
+
"state_dict"
|
| 22 |
+
]
|
| 23 |
+
if next(iter(state_dict.items()))[0].startswith("module"):
|
| 24 |
+
state_dict = {k[7:]: v for k, v in state_dict.items()}
|
| 25 |
+
|
| 26 |
+
if "text_branch.embeddings.position_ids" in state_dict:
|
| 27 |
+
del state_dict["text_branch.embeddings.position_ids"]
|
| 28 |
+
|
| 29 |
+
model.model.load_state_dict(state_dict)
|
| 30 |
+
return model.eval()
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class ClapRanker(Ranker):
|
| 34 |
+
def __init__(self, config: ClapRankerConfig):
|
| 35 |
+
from laion_clap.training import data
|
| 36 |
+
|
| 37 |
+
self.laion_data_module = data
|
| 38 |
+
super().__init__()
|
| 39 |
+
self.config = config
|
| 40 |
+
self.model = get_model(checkpoint_file=config.checkpoint)
|
| 41 |
+
|
| 42 |
+
def _prepare_audio(self, audio, sample_rate):
|
| 43 |
+
audio_features = []
|
| 44 |
+
for candidates in audio:
|
| 45 |
+
if sample_rate != 48_000:
|
| 46 |
+
candidates = torchaudio.functional.resample(
|
| 47 |
+
candidates, sample_rate, 48000
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
quantized = self.laion_data_module.int16_to_float32_torch(
|
| 51 |
+
self.laion_data_module.float32_to_int16_torch(candidates)
|
| 52 |
+
).float()
|
| 53 |
+
for sample in quantized:
|
| 54 |
+
temp_dict = {}
|
| 55 |
+
temp_dict = self.laion_data_module.get_audio_features(
|
| 56 |
+
temp_dict,
|
| 57 |
+
sample,
|
| 58 |
+
480000,
|
| 59 |
+
data_truncating=(
|
| 60 |
+
"fusion" if self.model.enable_fusion else "rand_trunc"
|
| 61 |
+
),
|
| 62 |
+
data_filling="repeatpad",
|
| 63 |
+
audio_cfg=self.model.model_cfg["audio_cfg"],
|
| 64 |
+
require_grad=False,
|
| 65 |
+
)
|
| 66 |
+
audio_features.append(temp_dict)
|
| 67 |
+
return audio_features
|
| 68 |
+
|
| 69 |
+
@torch.inference_mode()
|
| 70 |
+
def forward(
|
| 71 |
+
self,
|
| 72 |
+
extracted_audio: list[torch.Tensor],
|
| 73 |
+
descriptions: list[str],
|
| 74 |
+
sample_rate: int = 48_000,
|
| 75 |
+
**kwargs,
|
| 76 |
+
):
|
| 77 |
+
audio_embed = self.model.model.get_audio_embedding(
|
| 78 |
+
self._prepare_audio(extracted_audio, sample_rate)
|
| 79 |
+
)
|
| 80 |
+
text_embed = self.model.get_text_embedding(descriptions, use_tensor=True)
|
| 81 |
+
bsz = len(extracted_audio)
|
| 82 |
+
candidates = len(audio_embed) // bsz
|
| 83 |
+
audio_embed = audio_embed.reshape(bsz, candidates, -1)
|
| 84 |
+
text_embed = text_embed.reshape(bsz, -1, 1)
|
| 85 |
+
scores = audio_embed @ text_embed
|
| 86 |
+
return scores.squeeze(-1)
|
sam_audio/ranking/imagebind.py
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
from typing import List, Union
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torchaudio
|
| 8 |
+
|
| 9 |
+
from sam_audio.model.config import ImageBindRankerConfig
|
| 10 |
+
from sam_audio.ranking.ranker import Ranker
|
| 11 |
+
|
| 12 |
+
try:
|
| 13 |
+
from imagebind.data import (
|
| 14 |
+
ConstantClipsPerVideoSampler,
|
| 15 |
+
NormalizeVideo,
|
| 16 |
+
SpatialCrop,
|
| 17 |
+
get_clip_timepoints,
|
| 18 |
+
load_and_transform_video_data,
|
| 19 |
+
pv_transforms,
|
| 20 |
+
transforms,
|
| 21 |
+
waveform2melspec,
|
| 22 |
+
)
|
| 23 |
+
from imagebind.models.imagebind_model import ModalityType, imagebind_huge
|
| 24 |
+
|
| 25 |
+
__imagebind_exists__ = True
|
| 26 |
+
except ImportError:
|
| 27 |
+
__imagebind_exists__ = False
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def load_and_transform_audio_data(
|
| 31 |
+
audios: List[Union[str, torch.Tensor]],
|
| 32 |
+
input_sample_rate=None,
|
| 33 |
+
num_mel_bins=128,
|
| 34 |
+
target_length=204,
|
| 35 |
+
sample_rate=16000,
|
| 36 |
+
clip_duration=2,
|
| 37 |
+
clips_per_video=3,
|
| 38 |
+
mean=-4.268,
|
| 39 |
+
std=9.138,
|
| 40 |
+
device=None,
|
| 41 |
+
):
|
| 42 |
+
if audios is None:
|
| 43 |
+
return None
|
| 44 |
+
|
| 45 |
+
audio_outputs = []
|
| 46 |
+
clip_sampler = ConstantClipsPerVideoSampler(
|
| 47 |
+
clip_duration=clip_duration, clips_per_video=clips_per_video
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
for audio in audios:
|
| 51 |
+
if isinstance(audio, str):
|
| 52 |
+
waveform, input_sample_rate = torchaudio.load(audio)
|
| 53 |
+
else:
|
| 54 |
+
assert torch.is_tensor(audio)
|
| 55 |
+
assert sample_rate is not None
|
| 56 |
+
# Preprocessing needs to be done in full precision
|
| 57 |
+
waveform = audio.float()
|
| 58 |
+
if waveform.ndim == 1:
|
| 59 |
+
waveform = waveform[None]
|
| 60 |
+
if sample_rate != input_sample_rate:
|
| 61 |
+
waveform = torchaudio.functional.resample(
|
| 62 |
+
waveform, orig_freq=input_sample_rate, new_freq=sample_rate
|
| 63 |
+
)
|
| 64 |
+
all_clips_timepoints = get_clip_timepoints(
|
| 65 |
+
clip_sampler, waveform.size(1) / sample_rate
|
| 66 |
+
)
|
| 67 |
+
all_clips = []
|
| 68 |
+
for clip_timepoints in all_clips_timepoints:
|
| 69 |
+
waveform_clip = waveform[
|
| 70 |
+
:,
|
| 71 |
+
int(clip_timepoints[0] * sample_rate) : int(
|
| 72 |
+
clip_timepoints[1] * sample_rate
|
| 73 |
+
),
|
| 74 |
+
]
|
| 75 |
+
waveform_melspec = waveform2melspec(
|
| 76 |
+
waveform_clip, sample_rate, num_mel_bins, target_length
|
| 77 |
+
)
|
| 78 |
+
all_clips.append(waveform_melspec)
|
| 79 |
+
|
| 80 |
+
normalize = transforms.Normalize(mean=mean, std=std)
|
| 81 |
+
all_clips = [normalize(ac).to(device) for ac in all_clips]
|
| 82 |
+
|
| 83 |
+
all_clips = torch.stack(all_clips, dim=0)
|
| 84 |
+
audio_outputs.append(all_clips)
|
| 85 |
+
|
| 86 |
+
return torch.stack(audio_outputs, dim=0)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class VideoTransform:
|
| 90 |
+
def __init__(self, clip_duration=2, clips_per_video=5):
|
| 91 |
+
self.clip_duration = clip_duration
|
| 92 |
+
self.clips_per_video = clips_per_video
|
| 93 |
+
self.clip_sampler = ConstantClipsPerVideoSampler(
|
| 94 |
+
clip_duration=clip_duration, clips_per_video=clips_per_video
|
| 95 |
+
)
|
| 96 |
+
self.video_transform = transforms.Compose(
|
| 97 |
+
[
|
| 98 |
+
pv_transforms.ShortSideScale(224),
|
| 99 |
+
NormalizeVideo(
|
| 100 |
+
mean=(0.48145466, 0.4578275, 0.40821073),
|
| 101 |
+
std=(0.26862954, 0.26130258, 0.27577711),
|
| 102 |
+
),
|
| 103 |
+
]
|
| 104 |
+
)
|
| 105 |
+
self.spatial_crop = SpatialCrop(224, num_crops=3)
|
| 106 |
+
|
| 107 |
+
def load_video_fast(self, videos, durations, **kwargs):
|
| 108 |
+
result = []
|
| 109 |
+
for video, duration in zip(videos, durations, strict=False):
|
| 110 |
+
nframes = video.size(0)
|
| 111 |
+
fps = video.size(0) / duration
|
| 112 |
+
timepoints = get_clip_timepoints(
|
| 113 |
+
self.clip_sampler,
|
| 114 |
+
duration,
|
| 115 |
+
)
|
| 116 |
+
# Instead of loading 5 2s clips, and then sub-sampling frames, we figure
|
| 117 |
+
# Out the indices of the final clips we want and only decode those.
|
| 118 |
+
all_idxs = []
|
| 119 |
+
for start_time, end_time in timepoints:
|
| 120 |
+
idxs = torch.arange(
|
| 121 |
+
min(int(math.ceil(fps * start_time)), nframes - 1),
|
| 122 |
+
min(int(math.ceil(fps * end_time)), nframes),
|
| 123 |
+
)
|
| 124 |
+
ts = (
|
| 125 |
+
torch.linspace(0, idxs.size(0) - 1, self.clip_duration)
|
| 126 |
+
.clamp(max=idxs.size(0) - 1)
|
| 127 |
+
.long()
|
| 128 |
+
)
|
| 129 |
+
all_idxs.append(idxs[ts])
|
| 130 |
+
all_idxs = torch.cat(all_idxs)
|
| 131 |
+
fast_frames = video[all_idxs].transpose(0, 1)
|
| 132 |
+
result.append(fast_frames.chunk(self.clips_per_video, dim=1))
|
| 133 |
+
return result
|
| 134 |
+
|
| 135 |
+
def transform_video(self, batch, device=None):
|
| 136 |
+
device = device or torch.device("cpu")
|
| 137 |
+
video_outputs = []
|
| 138 |
+
for all_video in batch:
|
| 139 |
+
all_video = [
|
| 140 |
+
self.video_transform(clip.to(device) / 255.0) for clip in all_video
|
| 141 |
+
]
|
| 142 |
+
all_video = self.spatial_crop(all_video)
|
| 143 |
+
all_video = torch.stack(all_video, dim=0)
|
| 144 |
+
video_outputs.append(all_video)
|
| 145 |
+
return torch.stack(video_outputs, dim=0)
|
| 146 |
+
|
| 147 |
+
def __call__(self, videos, durations, device=None):
|
| 148 |
+
return self.transform_video(
|
| 149 |
+
self.load_video_fast(videos, durations), device=device
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
class ImageBindRanker(Ranker):
|
| 154 |
+
def __init__(self, cfg: ImageBindRankerConfig):
|
| 155 |
+
super().__init__()
|
| 156 |
+
assert __imagebind_exists__, (
|
| 157 |
+
"Install ImageBind in order to use this ranker: https://github.com/facebookresearch/ImageBind/tree/main"
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
self.model = imagebind_huge(pretrained=cfg.checkpoint is None)
|
| 161 |
+
if cfg.checkpoint is not None:
|
| 162 |
+
self.model.load_state_dict(torch.load(cfg.checkpoint, map_location="cpu"))
|
| 163 |
+
self.model = self.model.eval()
|
| 164 |
+
self.video_transform = VideoTransform()
|
| 165 |
+
|
| 166 |
+
@torch.inference_mode()
|
| 167 |
+
def forward(
|
| 168 |
+
self,
|
| 169 |
+
extracted_audio: list[torch.Tensor],
|
| 170 |
+
videos: list[torch.Tensor | str],
|
| 171 |
+
sample_rate: int = 48_000,
|
| 172 |
+
**kwargs,
|
| 173 |
+
):
|
| 174 |
+
audio_data = torch.cat(
|
| 175 |
+
[
|
| 176 |
+
load_and_transform_audio_data(x, input_sample_rate=sample_rate)
|
| 177 |
+
for x in extracted_audio
|
| 178 |
+
],
|
| 179 |
+
dim=0,
|
| 180 |
+
)
|
| 181 |
+
if isinstance(videos[0], str):
|
| 182 |
+
video_data = load_and_transform_video_data(videos)
|
| 183 |
+
else:
|
| 184 |
+
durations = [x.size(-1) / sample_rate for x in extracted_audio]
|
| 185 |
+
video_data = self.video_transform(videos, durations, audio_data.device)
|
| 186 |
+
|
| 187 |
+
inputs = {ModalityType.AUDIO: audio_data, ModalityType.VISION: video_data}
|
| 188 |
+
embs = self.model(inputs)
|
| 189 |
+
audio_embs, video_embs = embs[ModalityType.AUDIO], embs[ModalityType.VISION]
|
| 190 |
+
audio_embs, video_embs = (
|
| 191 |
+
audio_embs / ((audio_embs**2).sum(dim=-1, keepdims=True) ** 0.5),
|
| 192 |
+
video_embs / ((video_embs**2).sum(dim=-1, keepdims=True) ** 0.5),
|
| 193 |
+
)
|
| 194 |
+
bsz = len(extracted_audio)
|
| 195 |
+
candidates = len(audio_embs) // bsz
|
| 196 |
+
scores = audio_embs.view(bsz, candidates, -1) @ video_embs.view(bsz, -1, 1)
|
| 197 |
+
return scores
|
sam_audio/ranking/judge.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from ..model.config import JudgeRankerConfig
|
| 6 |
+
from ..model.judge import SAMAudioJudgeModel
|
| 7 |
+
from ..processor import SAMAudioJudgeProcessor
|
| 8 |
+
from .ranker import Ranker
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class JudgeRanker(Ranker):
|
| 12 |
+
def __init__(self, config: JudgeRankerConfig):
|
| 13 |
+
super().__init__()
|
| 14 |
+
self.config = config
|
| 15 |
+
self.model = SAMAudioJudgeModel.from_pretrained(config.checkpoint_or_model_id)
|
| 16 |
+
self.processor = SAMAudioJudgeProcessor.from_pretrained(
|
| 17 |
+
config.checkpoint_or_model_id
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
@torch.inference_mode()
|
| 21 |
+
def forward(
|
| 22 |
+
self,
|
| 23 |
+
input_audio: list[torch.Tensor],
|
| 24 |
+
extracted_audio: list[torch.Tensor],
|
| 25 |
+
descriptions: list[str],
|
| 26 |
+
sample_rate: int = 48_000,
|
| 27 |
+
**kwargs,
|
| 28 |
+
):
|
| 29 |
+
bsz, ncandidates = len(input_audio), len(input_audio[0])
|
| 30 |
+
input_seqs = [x[None] for candidates in input_audio for x in candidates]
|
| 31 |
+
extracted_seqs = [x[None] for candidates in extracted_audio for x in candidates]
|
| 32 |
+
repeated_descriptions = [x for x in descriptions for _ in range(ncandidates)]
|
| 33 |
+
processed = self.processor(
|
| 34 |
+
text=repeated_descriptions,
|
| 35 |
+
input_audio=input_seqs,
|
| 36 |
+
separated_audio=extracted_seqs,
|
| 37 |
+
return_tensors="pt",
|
| 38 |
+
padding=True,
|
| 39 |
+
sampling_rate=sample_rate,
|
| 40 |
+
)
|
| 41 |
+
res = self.model(**processed.to(input_audio[0].device))
|
| 42 |
+
return res.overall.view(bsz, ncandidates)
|
sam_audio/ranking/ranker.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n
|
| 2 |
+
|
| 3 |
+
from abc import ABCMeta, abstractmethod
|
| 4 |
+
from typing import List
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class Ranker(torch.nn.Module, metaclass=ABCMeta):
|
| 10 |
+
@abstractmethod
|
| 11 |
+
def forward(self, audio: list[torch.Tensor], **kwargs) -> torch.Tensor:
|
| 12 |
+
"""
|
| 13 |
+
Args:
|
| 14 |
+
audio: (list[torch.Tensor]) where each element in the list corresponds to
|
| 15 |
+
the candidates for the i'th generation (num_candidates, num_frames)
|
| 16 |
+
Returns:
|
| 17 |
+
(torch.Tensor) of shape (batch_size, num_candidates) correspoding to the ranking scores
|
| 18 |
+
"""
|
| 19 |
+
pass
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class EnsembleRanker(Ranker):
|
| 23 |
+
def __init__(self, rankers: List[Ranker], weights: List[float]):
|
| 24 |
+
super().__init__()
|
| 25 |
+
assert len(rankers) == len(weights)
|
| 26 |
+
self.rankers = torch.nn.ModuleList(rankers)
|
| 27 |
+
self.weights = weights
|
| 28 |
+
|
| 29 |
+
def forward(self, **kwargs) -> torch.Tensor:
|
| 30 |
+
result = None
|
| 31 |
+
for weight, ranker in zip(self.weights, self.rankers, strict=False):
|
| 32 |
+
if result is None:
|
| 33 |
+
result = weight * ranker(**kwargs)
|
| 34 |
+
else:
|
| 35 |
+
result += weight * ranker(**kwargs)
|
| 36 |
+
return result
|
sam_audio/ranking/sound_activity.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n
|
| 2 |
+
|
| 3 |
+
from io import BytesIO
|
| 4 |
+
from typing import Tuple, Union
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from torchcodec.encoders import AudioEncoder
|
| 8 |
+
|
| 9 |
+
from ..model.config import SoundActivityRankerConfig
|
| 10 |
+
from .ranker import Ranker
|
| 11 |
+
|
| 12 |
+
try:
|
| 13 |
+
import pydub
|
| 14 |
+
except ImportError:
|
| 15 |
+
pydub = None
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def get_peak_rms(audio, win_ms=250, hop_ms=100):
|
| 19 |
+
"""
|
| 20 |
+
win_length and hop_length are in ms
|
| 21 |
+
"""
|
| 22 |
+
last_slice_start = len(audio) - win_ms
|
| 23 |
+
slice_starts = range(0, last_slice_start + 1, hop_ms)
|
| 24 |
+
peak_rms = -1
|
| 25 |
+
for i in slice_starts:
|
| 26 |
+
audio_slice = audio[i : i + win_ms]
|
| 27 |
+
peak_rms = max(peak_rms, audio_slice.rms / audio.max_possible_amplitude)
|
| 28 |
+
# Ensure peak_rms is positive
|
| 29 |
+
peak_rms = max(peak_rms, 0)
|
| 30 |
+
return peak_rms
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def torch_tensor_to_pydub(wav: torch.Tensor, sample_rate: int):
|
| 34 |
+
bytesio = BytesIO()
|
| 35 |
+
encoder = AudioEncoder(wav, sample_rate=sample_rate)
|
| 36 |
+
encoder.to_file_like(bytesio, format="wav")
|
| 37 |
+
bytesio.seek(0)
|
| 38 |
+
audio = pydub.AudioSegment.from_file(bytesio, format="wav")
|
| 39 |
+
return audio
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def detect_nonsilent(
|
| 43 |
+
path: Union[str, Tuple[torch.Tensor, int]], # either a file path or pair wav & sr
|
| 44 |
+
min_sil_ms=250,
|
| 45 |
+
sil_threshold=-40,
|
| 46 |
+
threshold_mode="rel_to_max",
|
| 47 |
+
):
|
| 48 |
+
TH_MODES = {"abs", "rel_to_max"}
|
| 49 |
+
SAMPLE_RATE = 24_000
|
| 50 |
+
assert threshold_mode in TH_MODES, f"{threshold_mode=} not in {TH_MODES}"
|
| 51 |
+
if isinstance(path, str):
|
| 52 |
+
audio = pydub.AudioSegment.from_file(path)
|
| 53 |
+
else: # tuple of (tensor, sr)
|
| 54 |
+
audio = torch_tensor_to_pydub(path[0], path[1])
|
| 55 |
+
audio = audio.set_frame_rate(SAMPLE_RATE)
|
| 56 |
+
if threshold_mode == "rel_to_max":
|
| 57 |
+
peak_rms = get_peak_rms(audio)
|
| 58 |
+
sil_threshold = sil_threshold + pydub.utils.ratio_to_db(
|
| 59 |
+
peak_rms
|
| 60 |
+
) # convert to absolute db threshold
|
| 61 |
+
elif threshold_mode == "abs":
|
| 62 |
+
pass
|
| 63 |
+
else:
|
| 64 |
+
raise NotImplementedError(f"Unknown threshold_mode '{threshold_mode}'")
|
| 65 |
+
spans = pydub.silence.detect_nonsilent(
|
| 66 |
+
audio, min_silence_len=min_sil_ms, silence_thresh=sil_threshold, seek_step=10
|
| 67 |
+
)
|
| 68 |
+
spans = [(round(start / 1000, 3), round(end / 1000, 3)) for start, end in spans]
|
| 69 |
+
return spans
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def compute_iou_recall_precision(hyp_spans, ref_spans):
|
| 73 |
+
def span_length(span):
|
| 74 |
+
return span[1] - span[0]
|
| 75 |
+
|
| 76 |
+
def intersection_length(span1, span2):
|
| 77 |
+
return max(0, min(span1[1], span2[1]) - max(span1[0], span2[0]))
|
| 78 |
+
|
| 79 |
+
total_hyp_length = sum(span_length(span) for span in hyp_spans)
|
| 80 |
+
total_ref_length = sum(span_length(span) for span in ref_spans)
|
| 81 |
+
total_intersection = 0
|
| 82 |
+
for hyp_span in hyp_spans:
|
| 83 |
+
for ref_span in ref_spans:
|
| 84 |
+
total_intersection += intersection_length(hyp_span, ref_span)
|
| 85 |
+
|
| 86 |
+
union_spans = hyp_spans + ref_spans # Combine both lists to compute union
|
| 87 |
+
union_length = sum(span_length(span) for span in union_spans) - total_intersection
|
| 88 |
+
|
| 89 |
+
iou = total_intersection / union_length if union_length > 0 else 0
|
| 90 |
+
recall = total_intersection / total_ref_length if total_ref_length > 0 else 0
|
| 91 |
+
precision = total_intersection / total_hyp_length if total_hyp_length > 0 else 0
|
| 92 |
+
|
| 93 |
+
return {"iou": iou, "recall": recall, "precision": precision}
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class SoundActivityRanker(Ranker):
|
| 97 |
+
def __init__(self, config: SoundActivityRankerConfig):
|
| 98 |
+
if pydub is None:
|
| 99 |
+
raise ImportError(
|
| 100 |
+
'Install reranking dependencies: `pip install "sam-audio[reranking]"`'
|
| 101 |
+
)
|
| 102 |
+
super().__init__()
|
| 103 |
+
self.config = config
|
| 104 |
+
|
| 105 |
+
@torch.inference_mode()
|
| 106 |
+
def forward(
|
| 107 |
+
self,
|
| 108 |
+
extracted_audio: list[torch.Tensor],
|
| 109 |
+
spans: list[list[list[float]]],
|
| 110 |
+
sample_rate: int = 48_000,
|
| 111 |
+
**kwargs,
|
| 112 |
+
):
|
| 113 |
+
device = extracted_audio[0].device
|
| 114 |
+
scores = []
|
| 115 |
+
for wav, current_spans in zip(extracted_audio, spans, strict=True):
|
| 116 |
+
wav = wav.to(torch.float32).cpu()
|
| 117 |
+
# get non-silent spans
|
| 118 |
+
hyp_spans = detect_nonsilent(
|
| 119 |
+
(wav, sample_rate),
|
| 120 |
+
sil_threshold=self.config.sil_threshold,
|
| 121 |
+
threshold_mode=self.config.threshold_mode,
|
| 122 |
+
)
|
| 123 |
+
timestamps = [[span[1], span[2]] for span in current_spans]
|
| 124 |
+
result = compute_iou_recall_precision(hyp_spans, timestamps)
|
| 125 |
+
scores.append(result[self.config.metric])
|
| 126 |
+
|
| 127 |
+
# convert to tensor
|
| 128 |
+
scores = torch.tensor(scores, device=device)
|
| 129 |
+
return scores
|