ray-006 committed on
Commit fc605f9 · verified · 1 Parent(s): e1c7597

Upload 43 files

.gitattributes CHANGED
@@ -35,3 +35,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  Sam_audio/assets/sam_audio_main_model.png filter=lfs diff=lfs merge=lfs -text
37
  Sam_audio/examples/assets/office.mp4 filter=lfs diff=lfs merge=lfs -text
38
+ assets/sam_audio_main_model.png filter=lfs diff=lfs merge=lfs -text
39
+ examples/assets/office.mp4 filter=lfs diff=lfs merge=lfs -text
.github/workflows/ci.yaml ADDED
@@ -0,0 +1,44 @@
1
+ name: CI
2
+ on:
3
+ push:
4
+ pull_request:
5
+
6
+ env:
7
+ CACHE_NUMBER: 0
8
+
9
+ jobs:
10
+ build:
11
+ runs-on: 32-core-ubuntu
12
+ steps:
13
+ - uses: actions/checkout@v4
14
+
15
+ - uses: mamba-org/setup-micromamba@v1.8.1
16
+ with:
17
+ environment-name: sam-audio
18
+ init-shell: bash
19
+ create-args: >-
20
+ python=3.11
21
+ pip=24.2
22
+ ruff
23
+
24
+ - uses: actions/cache@v4
25
+ with:
26
+ path: /home/runner/micromamba/envs/sam-audio
27
+ key: ${{ hashFiles('pyproject.toml') }}-${{ env.CACHE_NUMBER }}
28
+ id: cache
29
+
30
+ - name: Update environment
31
+ shell: bash -l {0}
32
+ run: |
33
+ pip install .
34
+ if: steps.cache.outputs.cache-hit != 'true'
35
+
36
+ - name: Check formatting
37
+ shell: bash -l {0}
38
+ run: |
39
+ ruff format --check .
40
+
41
+ - name: Lint
42
+ shell: bash -l {0}
43
+ run: |
44
+ ruff check .
.gitignore ADDED
@@ -0,0 +1,8 @@
1
+ __pycache__
2
+ *.egg-info
3
+ *.pyc
4
+ *.so
5
+ build
6
+ dist
7
+ .checkpoints
8
+ .ipynb_checkpoints
.pre-commit-config.yaml ADDED
@@ -0,0 +1,15 @@
1
+ repos:
2
+ - repo: https://github.com/pre-commit/pre-commit-hooks
3
+ rev: v4.6.0
4
+ hooks:
5
+ - id: trailing-whitespace
6
+ - id: check-yaml
7
+ args:
8
+ - --allow-multiple-documents
9
+ - id: end-of-file-fixer
10
+ - repo: https://github.com/astral-sh/ruff-pre-commit
11
+ rev: v0.12.0
12
+ hooks:
13
+ - id: ruff
14
+ args: [ --fix ]
15
+ - id: ruff-format
CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,80 @@
1
+ # Code of Conduct
2
+
3
+ ## Our Pledge
4
+
5
+ In the interest of fostering an open and welcoming environment, we as
6
+ contributors and maintainers pledge to make participation in our project and
7
+ our community a harassment-free experience for everyone, regardless of age, body
8
+ size, disability, ethnicity, sex characteristics, gender identity and expression,
9
+ level of experience, education, socio-economic status, nationality, personal
10
+ appearance, race, religion, or sexual identity and orientation.
11
+
12
+ ## Our Standards
13
+
14
+ Examples of behavior that contributes to creating a positive environment
15
+ include:
16
+
17
+ * Using welcoming and inclusive language
18
+ * Being respectful of differing viewpoints and experiences
19
+ * Gracefully accepting constructive criticism
20
+ * Focusing on what is best for the community
21
+ * Showing empathy towards other community members
22
+
23
+ Examples of unacceptable behavior by participants include:
24
+
25
+ * The use of sexualized language or imagery and unwelcome sexual attention or
26
+ advances
27
+ * Trolling, insulting/derogatory comments, and personal or political attacks
28
+ * Public or private harassment
29
+ * Publishing others' private information, such as a physical or electronic
30
+ address, without explicit permission
31
+ * Other conduct which could reasonably be considered inappropriate in a
32
+ professional setting
33
+
34
+ ## Our Responsibilities
35
+
36
+ Project maintainers are responsible for clarifying the standards of acceptable
37
+ behavior and are expected to take appropriate and fair corrective action in
38
+ response to any instances of unacceptable behavior.
39
+
40
+ Project maintainers have the right and responsibility to remove, edit, or
41
+ reject comments, commits, code, wiki edits, issues, and other contributions
42
+ that are not aligned to this Code of Conduct, or to ban temporarily or
43
+ permanently any contributor for other behaviors that they deem inappropriate,
44
+ threatening, offensive, or harmful.
45
+
46
+ ## Scope
47
+
48
+ This Code of Conduct applies within all project spaces, and it also applies when
49
+ an individual is representing the project or its community in public spaces.
50
+ Examples of representing a project or community include using an official
51
+ project e-mail address, posting via an official social media account, or acting
52
+ as an appointed representative at an online or offline event. Representation of
53
+ a project may be further defined and clarified by project maintainers.
54
+
55
+ This Code of Conduct also applies outside the project spaces when there is a
56
+ reasonable belief that an individual's behavior may have a negative impact on
57
+ the project or its community.
58
+
59
+ ## Enforcement
60
+
61
+ Instances of abusive, harassing, or otherwise unacceptable behavior may be
62
+ reported by contacting the project team at <opensource-conduct@meta.com>. All
63
+ complaints will be reviewed and investigated and will result in a response that
64
+ is deemed necessary and appropriate to the circumstances. The project team is
65
+ obligated to maintain confidentiality with regard to the reporter of an incident.
66
+ Further details of specific enforcement policies may be posted separately.
67
+
68
+ Project maintainers who do not follow or enforce the Code of Conduct in good
69
+ faith may face temporary or permanent repercussions as determined by other
70
+ members of the project's leadership.
71
+
72
+ ## Attribution
73
+
74
+ This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
75
+ available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
76
+
77
+ [homepage]: https://www.contributor-covenant.org
78
+
79
+ For answers to common questions about this code of conduct, see
80
+ https://www.contributor-covenant.org/faq
CONTRIBUTING.md ADDED
@@ -0,0 +1,31 @@
1
+ # Contributing to segment-anything-model-audio
2
+ We want to make contributing to this project as easy and transparent as
3
+ possible.
4
+
5
+ ## Pull Requests
6
+ We actively welcome your pull requests.
7
+
8
+ 1. Fork the repo and create your branch from `main`.
9
+ 2. If you've added code that should be tested, add tests.
10
+ 3. If you've changed APIs, update the documentation.
11
+ 4. Ensure the test suite passes.
12
+ 5. Make sure your code lints.
13
+ 6. If you haven't already, complete the Contributor License Agreement ("CLA").
14
+
15
+ ## Contributor License Agreement ("CLA")
16
+ In order to accept your pull request, we need you to submit a CLA. You only need
17
+ to do this once to work on any of Facebook's open source projects.
18
+
19
+ Complete your CLA here: <https://code.facebook.com/cla>
20
+
21
+ ## Issues
22
+ We use GitHub issues to track public bugs. Please ensure your description is
23
+ clear and has sufficient instructions to be able to reproduce the issue.
24
+
25
+ Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
26
+ disclosure of security bugs. In those cases, please go through the process
27
+ outlined on that page and do not file a public issue.
28
+
29
+ ## License
30
+ By contributing to segment-anything-model-audio, you agree that your contributions will be licensed
31
+ under the LICENSE file in the root directory of this source tree.
LICENSE ADDED
@@ -0,0 +1,61 @@
1
+ SAM License
2
+ Last Updated: November 19, 2025
3
+
4
+ “Agreement” means the terms and conditions for use, reproduction, distribution and modification of the SAM Materials set forth herein.
5
+
6
+
7
+ “SAM Materials” means, collectively, Documentation and the models, software and algorithms, including machine-learning model code, trained model weights, inference-enabling code, training-enabling code, fine-tuning enabling code, and other elements of the foregoing distributed by Meta and made available under this Agreement.
8
+
9
+ “Documentation” means the specifications, manuals and documentation accompanying
10
+ SAM Materials distributed by Meta.
11
+
12
+
13
+ “Licensee” or “you” means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity’s behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf.
14
+
15
+
16
+ “Meta” or “we” means Meta Platforms Ireland Limited (if you are located in or, if you are an entity, your principal place of business is in the EEA or Switzerland) or Meta Platforms, Inc. (if you are located outside of the EEA or Switzerland).
17
+
18
+
19
+ “Sanctions” means any economic or trade sanctions or restrictions administered or enforced by the United States (including the Office of Foreign Assets Control of the U.S. Department of the Treasury (“OFAC”), the U.S. Department of State and the U.S. Department of Commerce), the United Nations, the European Union, or the United Kingdom.
20
+
21
+
22
+ “Trade Controls” means any of the following: Sanctions and applicable export and import controls.
23
+
24
+ By using or distributing any portion or element of the SAM Materials, you agree to be bound by this Agreement.
25
+
26
+
27
+ 1. License Rights and Redistribution.
28
+
29
+
30
+ a. Grant of Rights. You are granted a non-exclusive, worldwide, non-transferable and royalty-free limited license under Meta’s intellectual property or other rights owned by Meta embodied in the SAM Materials to use, reproduce, distribute, copy, create derivative works of, and make modifications to the SAM Materials.
31
+
32
+ b. Redistribution and Use.
33
+ i. Distribution of SAM Materials, and any derivative works thereof, are subject to the terms of this Agreement. If you distribute or make the SAM Materials, or any derivative works thereof, available to a third party, you may only do so under the terms of this Agreement and you shall provide a copy of this Agreement with any such SAM Materials.
34
+
35
+
36
+ ii. If you submit for publication the results of research you perform on, using, or otherwise in connection with SAM Materials, you must acknowledge the use of SAM Materials in your publication.
37
+
38
+
39
+ iii. Your use of the SAM Materials must comply with applicable laws and regulations, including Trade Control Laws and applicable privacy and data protection laws.
40
+ iv. Your use of the SAM Materials will not involve or encourage others to reverse engineer, decompile or discover the underlying components of the SAM Materials.
41
+ v. You are not the target of Trade Controls and your use of SAM Materials must comply with Trade Controls. You agree not to use, or permit others to use, SAM Materials for any activities subject to the International Traffic in Arms Regulations (ITAR) or end uses prohibited by Trade Controls, including those related to military or warfare purposes, nuclear industries or applications, espionage, or the development or use of guns or illegal weapons.
42
+ 2. User Support. Your use of the SAM Materials is done at your own discretion; Meta does not process any information nor provide any service in relation to such use. Meta is under no obligation to provide any support services for the SAM Materials. Any support provided is “as is”, “with all faults”, and without warranty of any kind.
43
+
44
+
45
+ 3. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE SAM MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN “AS IS” BASIS, WITHOUT WARRANTIES OF ANY KIND, AND META DISCLAIMS ALL WARRANTIES OF ANY KIND, BOTH EXPRESS AND IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE SAM MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE SAM MATERIALS AND ANY OUTPUT AND RESULTS.
46
+
47
+ 4. Limitation of Liability. IN NO EVENT WILL META OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT OR INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
48
+
49
+ 5. Intellectual Property.
50
+
51
+
52
+ a. Subject to Meta’s ownership of SAM Materials and derivatives made by or for Meta, with respect to any derivative works and modifications of the SAM Materials that are made by you, as between you and Meta, you are and will be the owner of such derivative works and modifications.
53
+
54
+ b. If you institute litigation or other proceedings against Meta or any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the SAM Materials, outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Meta from and against any claim by any third party arising out of or related to your use or distribution of the SAM Materials.
55
+
56
+ 6. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the SAM Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Meta may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of the SAM Materials. Sections 3, 4 and 7 shall survive the termination of this Agreement.
57
+
58
+ 7. Governing Law and Jurisdiction. This Agreement will be governed and construed under the laws of the State of California without regard to choice of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement. The courts of California shall have exclusive jurisdiction of any dispute arising out of this Agreement.
59
+
60
+
61
+ 8. Modifications and Amendments. Meta may modify this Agreement from time to time; provided that they are similar in spirit to the current version of the Agreement, but may differ in detail to address new problems or concerns. All such changes will be effective immediately. Your continued use of the SAM Materials after any modification to this Agreement constitutes your agreement to such modification. Except as provided in this Agreement, no modification or addition to any provision of this Agreement will be binding unless it is in writing and signed by an authorized representative of both you and Meta.
README.md CHANGED
@@ -1,14 +1,137 @@
1
- ---
2
- title: Sample Audio
3
- emoji: 📚
4
- colorFrom: indigo
5
- colorTo: red
6
- sdk: gradio
7
- sdk_version: 6.2.0
8
- app_file: app.py
9
- pinned: false
10
- license: mit
11
- short_description: Sample-Audio
12
- ---
13
-
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ <div align="center">
2
+
3
+ # SAM-Audio
4
+
5
+ ![CI](https://github.com/facebookresearch/sam-audio/actions/workflows/ci.yaml/badge.svg)
6
+
7
+ ![model_image](assets/sam_audio_main_model.png)
8
+
9
+ </div>
10
+
11
+ Segment Anything Model for Audio [[**Blog**](https://ai.meta.com/blog/sam-audio/)] [[**Paper**](https://ai.meta.com/research/publications/sam-audio-segment-anything-in-audio/)] [[**Demo**](https://aidemos.meta.com/segment-anything/editor/segment-audio)]
12
+
13
+ SAM-Audio is a foundation model for isolating any sound in audio using text, visual, or temporal prompts. It can separate specific sounds from complex audio mixtures based on natural language descriptions, visual cues from video, or time spans.
14
+
15
+ SAM-Audio and the Judge model crucially rely on [Perception-Encoder Audio-Visual (PE-AV)](https://huggingface.co/facebook/pe-av-large), which you can read more about [here](https://ai.meta.com/research/publications/pushing-the-frontier-of-audiovisual-perception-with-large-scale-multimodal-correspondence-learning/).
16
+
17
+ ## Setup
18
+
19
+ **Requirements:**
20
+ - Python >= 3.10
21
+ - CUDA-compatible GPU (recommended)
22
+
23
+ Install dependencies:
24
+
25
+ ```bash
26
+ pip install .
27
+ ```
28
+
29
+ ## Usage
30
+
31
+ ⚠️ Before using SAM Audio, please request access to the checkpoints on the SAM Audio
32
+ Hugging Face [repo](https://huggingface.co/facebook/sam-audio-large). Once accepted, you
33
+ need to be authenticated to download the checkpoints. You can do this by running
34
+ the following [steps](https://huggingface.co/docs/huggingface_hub/en/quick-start#authentication)
35
+ (e.g., `hf auth login` after generating an access token).
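+
+ If you prefer to authenticate from Python rather than the CLI, a minimal equivalent using the standard `huggingface_hub` helper (not a SAM-Audio-specific API) is:
+
+ ```python
+ from huggingface_hub import login
+
+ # Prompts for the access token generated on the Hugging Face website
+ login()
+ ```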
36
+
37
+ ### Basic Text Prompting
38
+
39
+ ```python
40
+ from sam_audio import SAMAudio, SAMAudioProcessor
41
+ import torchaudio
42
+ import torch
43
+
44
+ model = SAMAudio.from_pretrained("facebook/sam-audio-large")
45
+ processor = SAMAudioProcessor.from_pretrained("facebook/sam-audio-large")
46
+ model = model.eval().cuda()
47
+
48
+ file = "<audio file>" # audio file path or torch tensor
49
+ description = "<description>"
50
+
51
+ batch = processor(
52
+ audios=[file],
53
+ descriptions=[description],
54
+ ).to("cuda")
55
+
56
+ with torch.inference_mode():
57
+ # NOTE: `predict_spans` and `reranking_candidates` have a large impact on performance.
58
+ # Setting `predict_spans=True` and `reranking_candidates=8` will give you better results at the cost of
59
+ # latency and memory. See the "Span Prediction" section below for more details
60
+ result = model.separate(batch, predict_spans=False, reranking_candidates=1)
61
+
62
+ # Save separated audio
63
+ sample_rate = processor.audio_sampling_rate
64
+ torchaudio.save("target.wav", result.target.cpu(), sample_rate) # The isolated sound
65
+ torchaudio.save("residual.wav", result.residual.cpu(), sample_rate) # Everything else
66
+ ```
67
+
68
+ ### Prompting Methods
69
+
70
+ SAM-Audio supports three types of prompts:
71
+
72
+ 1. **Text Prompting**: Describe the sound you want to isolate using natural language
73
+ ```python
74
+ processor(audios=[audio], descriptions=["A man speaking"])
75
+ ```
76
+
77
+ 2. **Visual Prompting**: Use video frames and masks to isolate sounds associated with visual objects
78
+ ```python
79
+ processor(audios=[video], descriptions=[""], masked_videos=processor.mask_videos([frames], [mask]))
80
+ ```
81
+
82
+ 3. **Span Prompting**: Specify time ranges where the target sound occurs
83
+ ```python
84
+ processor(audios=[audio], descriptions=["A horn honking"], anchors=[[["+", 6.3, 7.0]]])
85
+ ```
86
+
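+ As a rough end-to-end sketch of visual prompting, the snippet below decodes frames with `torchcodec` and applies an illustrative mask. In practice the per-frame mask would come from a video segmentation model such as SAM; the mask shape and the "left half of the frame" selection here are assumptions for illustration, not the official workflow (see `examples/visual_prompting.ipynb` for that).
+
+ ```python
+ import torch
+ from torchcodec.decoders import VideoDecoder
+ from sam_audio import SAMAudio, SAMAudioProcessor
+
+ model = SAMAudio.from_pretrained("facebook/sam-audio-large").eval().cuda()
+ processor = SAMAudioProcessor.from_pretrained("facebook/sam-audio-large")
+
+ video_path = "examples/assets/office.mp4"
+ frames = VideoDecoder(video_path)[:].data  # (T, C, H, W) video frames
+
+ # Illustrative binary mask: 1 where the sounding object is, 0 elsewhere.
+ # Here we simply select the left half of every frame.
+ mask = torch.zeros(frames.size(0), 1, frames.size(2), frames.size(3))
+ mask[..., : frames.size(3) // 2] = 1.0
+
+ batch = processor(
+     audios=[video_path],  # the audio track is read from the video file
+     descriptions=[""],
+     masked_videos=processor.mask_videos([frames], [mask]),
+ ).to("cuda")
+
+ with torch.inference_mode():
+     result = model.separate(batch)
+ ```
+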
87
+ See the [examples](examples) directory for more detailed examples.
88
+
89
+ ### Span Prediction (Optional for Text Prompting)
90
+
91
+ We also support automatically predicting the spans from the text description, which is especially helpful for separating non-ambience sound events. You can enable this by passing `predict_spans=True` in your call to `separate`.
92
+
93
+ ```python
94
+ with torch.inference_mode():
95
+ outputs = model.separate(batch, predict_spans=True)
96
+
97
+ # To further improve performance (at the expense of latency), you can add candidate re-ranking
98
+ with torch.inference_mode():
99
+ outputs = model.separate(batch, predict_spans=True, reranking_candidates=8)
100
+ ```
101
+
102
+ ### Re-Ranking
103
+
104
+ We provide the following models to assess the quality of the separated audio:
105
+
106
+ - [CLAP](https://github.com/LAION-AI/CLAP): measures the similarity between the target audio and text description
107
+ - [Judge](https://huggingface.co/facebook/sam-audio-judge): measures the overall separation quality across 3 axes: precision, recall, and faithfulness (see the [model card](https://huggingface.co/facebook/sam-audio-judge#output-format) for more details)
108
+ - [ImageBind](https://github.com/facebookresearch/ImageBind): for visual prompting, we measure the imagebind embedding similarity between the separated audio and the masked input video
109
+
110
+ We support generating multiple candidates (by setting `reranking_candidates=<k>` in your call to `separate`): the model generates `k` candidate separations and keeps the one ranked best by the models listed above.
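+
+ For reference, a minimal sketch of scoring a finished separation with the Judge model yourself, mirroring the call pattern in `eval/metrics/judge.py` later in this commit (the file names are placeholders):
+
+ ```python
+ import torch
+ import torchaudio
+ from sam_audio import SAMAudioJudgeModel, SAMAudioJudgeProcessor
+
+ judge = SAMAudioJudgeModel.from_pretrained("facebook/sam-audio-judge").eval().cuda()
+ judge_processor = SAMAudioJudgeProcessor.from_pretrained("facebook/sam-audio-judge")
+
+ mixture, sr = torchaudio.load("mixture.wav")    # original input audio
+ separated, _ = torchaudio.load("target.wav")    # output of model.separate
+
+ inputs = judge_processor(
+     text=["A man speaking"],
+     input_audio=[mixture],
+     separated_audio=[separated],
+     sampling_rate=sr,
+ ).to("cuda")
+
+ with torch.inference_mode():
+     scores = judge(**inputs)
+
+ print(scores.overall, scores.precision, scores.recall, scores.faithfulness)
+ ```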
111
+
112
+ # Models
113
+
114
+ Below is a table of the models we released, along with their overall subjective evaluation scores.
115
+
116
+ | Model | General SFX | Speech | Speaker | Music | Instr(wild) | Instr(pro) |
117
+ |----------|-------------|--------|---------|-------|-------------|------------|
118
+ | [`sam-audio-small`](https://huggingface.co/facebook/sam-audio-small) | 3.62 | 3.99 | 3.12 | 4.11 | 3.56 | 4.24 |
119
+ | [`sam-audio-base`](https://huggingface.co/facebook/sam-audio-base) | 3.28 | 4.25 | 3.57 | 3.87 | 3.66 | 4.27 |
120
+ | [`sam-audio-large`](https://huggingface.co/facebook/sam-audio-large) | 3.50 | 4.03 | 3.60 | 4.22 | 3.66 | 4.49 |
121
+
122
+ We additionally release another variant (in each size) that is specifically better at target-sound correctness and visual prompting:
123
+ - [`sam-audio-small-tv`](https://huggingface.co/facebook/sam-audio-small-tv)
124
+ - [`sam-audio-base-tv`](https://huggingface.co/facebook/sam-audio-base-tv)
125
+ - [`sam-audio-large-tv`](https://huggingface.co/facebook/sam-audio-large-tv)
126
+
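+ Any of these checkpoints can be dropped into the usage example above by swapping the model ID, e.g.:
+
+ ```python
+ from sam_audio import SAMAudio, SAMAudioProcessor
+
+ model = SAMAudio.from_pretrained("facebook/sam-audio-large-tv")
+ processor = SAMAudioProcessor.from_pretrained("facebook/sam-audio-large-tv")
+ ```
+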
127
+ ## Evaluation
128
+
129
+ See the [eval](eval) directory for instructions and scripts to reproduce results from the paper.
130
+
131
+ ## Contributing
132
+
133
+ See [contributing](CONTRIBUTING.md) and [code of conduct](CODE_OF_CONDUCT.md) for more information.
134
+
135
+ ## License
136
+
137
+ This project is licensed under the SAM License - see the [LICENSE](LICENSE) file for details.
assets/sam_audio_main_model.png ADDED

Git LFS Details

  • SHA256: 8dc7bda3f7ad3a910cdcd5d137b3e7e721f2381736ac19e1155732422b818bbd
  • Pointer size: 132 Bytes
  • Size of remote file: 1.55 MB
eval/README.md ADDED
@@ -0,0 +1,100 @@
1
+ # Evaluation
2
+
3
+ This directory contains the evaluation code to reproduce the results from the SAM-Audio paper. The evaluation framework supports multiple datasets, prompting modes (text-only, span, visual), and metrics.
4
+
5
+ ## Setup
6
+
7
+ Before running evaluation, ensure you have:
8
+
9
+ 1. Installed the SAM-Audio package and its dependencies
10
+ 2. Authenticated with Hugging Face to access the model checkpoints (see main [README](../README.md))
11
+
12
+ ## Quick Start
13
+
14
+ Run evaluation on the default setting (instr-pro):
15
+
16
+ ```bash
17
+ python main.py
18
+ ```
19
+
20
+ You can also use multiple GPUs to speed up evaluation:
21
+
22
+ ```bash
23
+ torchrun --nproc_per_node=<ngpus> main.py
24
+ ```
25
+
26
+ Evaluate on a specific setting:
27
+
28
+ ```bash
29
+ python main.py --setting sfx
30
+ ```
31
+
32
+ Evaluate on multiple settings:
33
+
34
+ ```bash
35
+ python main.py --setting sfx speech music
36
+ ```
37
+
38
+ ## Available Evaluation Settings
39
+
40
+ Run `python main.py --help` to see all available settings.
41
+
42
+ ## Command Line Options
43
+
44
+ ```bash
45
+ python main.py [OPTIONS]
46
+ ```
47
+
48
+ ### Options:
49
+
50
+ - `-s, --setting` - Which setting(s) to evaluate (default: `instr-pro`)
51
+ - Choices: See available settings above
52
+ - Can specify multiple settings: `--setting sfx speech music`
53
+
54
+ - `--cache-path` - Where to cache downloaded datasets (default: `~/.cache/sam_audio`)
55
+
56
+ - `-p, --checkpoint-path` - Model checkpoint to evaluate (default: `facebook/sam-audio-large`)
57
+ - Can use local path or Hugging Face model ID
58
+
59
+ - `-b, --batch-size` - Batch size for evaluation (default: `1`)
60
+
61
+ - `-w, --num-workers` - Number of data loading workers (default: `4`)
62
+
63
+ - `-c, --candidates` - Number of reranking candidates (default: `8`)
64
+
65
+ ## Evaluation Metrics
66
+
67
+ The evaluation framework computes the following metrics:
68
+
69
+ - **Judge** - SAM Audio Judge quality assessment metric
70
+ - **Aesthetic** - Aesthetic quality metric
71
+ - **CLAP** - Audio-text alignment metric (CLAP similarity)
72
+ - **ImageBind** - Audio-video alignment metric (for visual settings only)
73
+
74
+ ## Output
75
+
76
+ Results are saved to the `results/` directory as JSON files, one per setting:
77
+
78
+ ```
79
+ results/
80
+ ├── sfx.json
81
+ ├── speech.json
82
+ └── music.json
83
+ ```
84
+
85
+ Each JSON file contains the averaged metric scores across all samples in that setting.
86
+
87
+ Example output:
88
+ ```json
89
+ {
90
+ "JudgeOverall": "4.386",
91
+ "JudgeFaithfulness": "4.708",
92
+ "JudgeRecall": "4.934",
93
+ "JudgePrecision": "4.451",
94
+ "ContentEnjoyment": "5.296",
95
+ "ContentUsefulness": "6.903",
96
+ "ProductionComplexity": "4.301",
97
+ "ProductionQuality": "7.100",
98
+ "CLAPSimilarity": "0.271"
99
+ }
100
+ ```
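+
+ To compare settings side by side, the per-setting JSON files can be loaded into a small table; a sketch (assuming the `results/` layout above and using `pandas`, which the evaluation code already depends on):
+
+ ```python
+ import glob
+ import json
+ import os
+
+ import pandas as pd
+
+ rows = {}
+ for path in glob.glob("results/*.json"):
+     setting = os.path.splitext(os.path.basename(path))[0]
+     with open(path) as fin:
+         rows[setting] = json.load(fin)
+
+ # One row per setting, one column per metric
+ print(pd.DataFrame.from_dict(rows, orient="index"))
+ ```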
eval/dataset/__init__.py ADDED
@@ -0,0 +1,70 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n
2
+
3
+ from typing import Callable
4
+
5
+ from .musdb import MUSDB
6
+ from .sam_audio_bench import SAMAudioBench
7
+
8
+ SETTINGS = {
9
+ # Text-only settings
10
+ "sfx": (
11
+ SAMAudioBench,
12
+ {"span": False, "visual": False, "subset": "others-50:text-only"},
13
+ ),
14
+ "speech": (
15
+ SAMAudioBench,
16
+ {"span": False, "visual": False, "subset": "speech-clean-50:text-only"},
17
+ ),
18
+ "speaker": (
19
+ SAMAudioBench,
20
+ {"span": False, "visual": False, "subset": "spk-50:text-only"},
21
+ ),
22
+ "music": (
23
+ SAMAudioBench,
24
+ {"span": False, "visual": False, "subset": "music-clean-50:text-only"},
25
+ ),
26
+ "instr-wild": (
27
+ SAMAudioBench,
28
+ {"span": False, "visual": False, "subset": "instr-50:text-only"},
29
+ ),
30
+ "instr-pro": (MUSDB, {}),
31
+ # Span settings
32
+ "sfx-span": (
33
+ SAMAudioBench,
34
+ {"span": True, "visual": False, "subset": "others-50:text+span"},
35
+ ),
36
+ "speech-span": (
37
+ SAMAudioBench,
38
+ {"span": True, "visual": False, "subset": "speech-clean-50:text+span"},
39
+ ),
40
+ "speaker-span": (
41
+ SAMAudioBench,
42
+ {"span": True, "visual": False, "subset": "spk-50:text+span"},
43
+ ),
44
+ "music-span": (
45
+ SAMAudioBench,
46
+ {"span": True, "visual": False, "subset": "music-clean-50:text+span"},
47
+ ),
48
+ "instr-wild-span": (
49
+ SAMAudioBench,
50
+ {"span": True, "visual": False, "subset": "instr-50:text+span"},
51
+ ),
52
+ # Visual settings
53
+ "sfx-visual": (
54
+ SAMAudioBench,
55
+ {"span": False, "visual": True, "subset": "others-onscreen-50:visual-only"},
56
+ ),
57
+ "speaker-visual": (
58
+ SAMAudioBench,
59
+ {"span": False, "visual": True, "subset": "spk-onscreen-50:visual-only"},
60
+ ),
61
+ "instr-wild-visual": (
62
+ SAMAudioBench,
63
+ {"span": False, "visual": True, "subset": "instr-onscreen-50:visual-only"},
64
+ ),
65
+ }
66
+
67
+
68
+ def make_dataset(setting: str, cache_path: str, collate_fn: Callable):
69
+ dataset, kwargs = SETTINGS[setting]
70
+ return dataset(cache_path=cache_path, collate_fn=collate_fn, **kwargs)
eval/dataset/musdb.py ADDED
@@ -0,0 +1,75 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n
2
+
3
+ import os
4
+ from subprocess import check_call
5
+
6
+ import torchaudio
7
+ from datasets import load_dataset
8
+ from torch.utils.data import Dataset
9
+ from torchcodec.decoders import AudioDecoder
10
+
11
+
12
+ def cache_file(url, outfile):
13
+ if not os.path.exists(outfile):
14
+ print("Downloading musdb18hq dataset...")
15
+ os.makedirs(os.path.dirname(outfile), exist_ok=True)
16
+ check_call(["curl", "--url", url, "--output", outfile + ".tmp"])
17
+ os.rename(outfile + ".tmp", outfile)
18
+
19
+
20
+ class MUSDB(Dataset):
21
+ def __init__(
22
+ self,
23
+ collate_fn,
24
+ sample_rate: int = 48_000,
25
+ cache_path: str = os.path.expanduser("~/.cache/sam_audio"),
26
+ ):
27
+ self.cache_path = os.path.join(cache_path, "musdb18hq")
28
+ self.ds = self.get_dataset(cache_path)
29
+ self.captions = ["bass", "drums", "vocals"]
30
+ self.collate_fn = collate_fn
31
+ self.sample_rate = sample_rate
32
+
33
+ @property
34
+ def visual(self):
35
+ return False
36
+
37
+ def get_dataset(self, cache_path):
38
+ zip_file = os.path.join(cache_path, "musdb18hq.zip")
39
+ url = "https://zenodo.org/records/3338373/files/musdb18hq.zip?download=1"
40
+ cache_file(url, zip_file)
41
+ extracted_dir = os.path.join(cache_path, "musdb18hq")
42
+ if not os.path.exists(extracted_dir):
43
+ check_call(["unzip", zip_file, "-d", extracted_dir + ".tmp"])
44
+ os.rename(extracted_dir + ".tmp", extracted_dir)
45
+ return load_dataset("facebook/sam-audio-musdb18hq-test")["test"]
46
+
47
+ def __len__(self):
48
+ return len(self.ds)
49
+
50
+ def collate(self, items):
51
+ audios, descriptions = zip(*items, strict=False)
52
+ return self.collate_fn(
53
+ audios=audios,
54
+ descriptions=descriptions,
55
+ )
56
+
57
+ def __getitem__(self, idx):
58
+ item = self.ds[idx]
59
+ path = os.path.join(self.cache_path, "test", item["id"], "mixture.wav")
60
+ assert os.path.exists(path), f"{path} does not exist!"
61
+ decoder = AudioDecoder(path)
62
+ data = decoder.get_samples_played_in_range(item["start_time"], item["end_time"])
63
+ wav = data.data
64
+ if data.sample_rate != self.sample_rate:
65
+ wav = torchaudio.functional.resample(
66
+ wav, data.sample_rate, self.sample_rate
67
+ )
68
+ wav = wav.mean(0, keepdim=True)
69
+ return wav, item["description"]
70
+
71
+
72
+ if __name__ == "__main__":
73
+ dataset = MUSDB(lambda **kwargs: None)
74
+ print(len(dataset))
75
+ print(dataset[0])
eval/dataset/sam_audio_bench.py ADDED
@@ -0,0 +1,153 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n
2
+
3
+ import os
4
+ from dataclasses import dataclass
5
+ from io import BytesIO
6
+ from typing import Optional, Tuple
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn.functional as F
11
+ import torchaudio
12
+ from datasets import load_dataset
13
+ from torchcodec.decoders import AudioDecoder, VideoDecoder
14
+
15
+
16
+ @dataclass
17
+ class Item:
18
+ anchors: list[Tuple[str, float, float]]
19
+ masked_video_frames: torch.Tensor
20
+ audio_samples: torch.Tensor
21
+ description: str
22
+
23
+
24
+ class SAMAudioBench(torch.utils.data.Dataset):
25
+ def __init__(
26
+ self,
27
+ cache_path,
28
+ collate_fn,
29
+ span: bool = True,
30
+ visual: bool = True,
31
+ subset: Optional[str] = None,
32
+ ):
33
+ self.dataset = load_dataset("facebook/sam-audio-bench")["test"]
34
+ self.subset = subset
35
+ self._span = span
36
+ self._visual = visual
37
+ if subset is not None:
38
+ self.dataset = self.dataset.filter(lambda x: subset in x["paper_eval_sets"])
39
+
40
+ self.cache_path = os.path.join(cache_path, "sam_audio_bench")
41
+ self.collate_fn = collate_fn
42
+ DATA_MSG = (
43
+ f"`SAMAudioBench` requires the user to create a directory named {self.cache_path} "
44
+ "see the README.md file for how to prepare"
45
+ )
46
+ assert os.path.exists(self.cache_path), DATA_MSG
47
+
48
+ @property
49
+ def visual(self):
50
+ return self._visual
51
+
52
+ def __len__(self):
53
+ return len(self.dataset)
54
+
55
+ def _get_path(
56
+ self, video_id: str, source_dataset: str, start_offset: float, end_offset: float
57
+ ) -> Tuple[str, bool]:
58
+ path = f"{self.cache_path}/{source_dataset}/{video_id}.mp4"
59
+ select_frames = True
60
+
61
+ if not os.path.exists(path):
62
+ path = f"{self.cache_path}/{source_dataset}/{video_id}_{int(start_offset * 1000)}_{int(end_offset * 1000)}.mp4"
63
+ select_frames = False
64
+
65
+ if not os.path.exists(path):
66
+ path = f"{self.cache_path}/{source_dataset}/{video_id}_{int(start_offset)}_{int(end_offset)}.mp4"
67
+
68
+ if not os.path.exists(path):
69
+ path = f"{self.cache_path}/{source_dataset}/{video_id}.{int(start_offset * 1000):08d}_{int(end_offset * 1000):08d}.mp4"
70
+
71
+ return path, select_frames
72
+
73
+ def collate(self, items: list[Item]):
74
+ has_video = any(item.masked_video_frames is not None for item in items)
75
+ return self.collate_fn(
76
+ descriptions=[item.description for item in items],
77
+ audios=[item.audio_samples for item in items],
78
+ anchors=[item.anchors for item in items] if self._span else None,
79
+ masked_videos=[item.masked_video_frames for item in items]
80
+ if has_video and self._visual
81
+ else None,
82
+ )
83
+
84
+ def _get_masked_video(self, item, video_path, select_frames):
85
+ if item["mask_bytes"] is None:
86
+ return None
87
+
88
+ mask = torch.from_numpy(np.load(BytesIO(item["mask_bytes"]))["video_masklet"])
89
+
90
+ video_decoder = VideoDecoder(video_path)
91
+ if select_frames:
92
+ video_frames = video_decoder.get_frames_played_in_range(
93
+ item["start_offset"], item["end_offset"]
94
+ ).data
95
+ else:
96
+ video_frames = video_decoder[:].data
97
+
98
+ if mask.size(0) != video_frames.size(0):
99
+ # It's possible that the mask and the video frames differ by a small amount
100
+ # we interpolate the mask frame to match
101
+ idxs = (
102
+ torch.linspace(0, mask.size(0) - 1, video_frames.size(0)).round().long()
103
+ )
104
+ mask = mask[idxs]
105
+
106
+ mask = mask.unsqueeze(1)
107
+
108
+ if mask.shape[-2:] != video_frames.shape[-2:]:
109
+ mask = F.interpolate(mask, size=video_frames.shape[-2:])
110
+
111
+ import torchvision
112
+
113
+ torchvision.io.write_video("test.mp4", video_frames.permute(0, 2, 3, 1), 30)
114
+ torchvision.io.write_video(
115
+ "test_mask.mp4", mask.unsqueeze(-1).expand(-1, -1, -1, 3) * 255, 30
116
+ )
117
+
118
+ return video_frames * mask
119
+
120
+ def __getitem__(self, idx) -> Item:
121
+ item = self.dataset[idx]
122
+
123
+ video_path, select_frames = self._get_path(
124
+ item["video_id"],
125
+ item["source_dataset"],
126
+ item["start_offset"],
127
+ item["end_offset"],
128
+ )
129
+ assert os.path.exists(video_path), f"{video_path} does not exist!"
130
+
131
+ audio_decoder = AudioDecoder(video_path)
132
+ audio_samples = audio_decoder.get_samples_played_in_range(
133
+ start_seconds=item["start_offset"] if select_frames else 0,
134
+ stop_seconds=item["end_offset"] if select_frames else None,
135
+ )
136
+
137
+ if audio_samples.sample_rate != self.collate_fn.audio_sampling_rate:
138
+ resampled_audio = torchaudio.functional.resample(
139
+ audio_samples.data,
140
+ audio_samples.sample_rate,
141
+ self.collate_fn.audio_sampling_rate,
142
+ )
143
+ else:
144
+ resampled_audio = audio_samples.data
145
+
146
+ masked_video_frames = self._get_masked_video(item, video_path, select_frames)
147
+
148
+ return Item(
149
+ description=item["description"],
150
+ anchors=[("+", start, end) for start, end in item["spans"]],
151
+ masked_video_frames=masked_video_frames,
152
+ audio_samples=resampled_audio.mean(0, keepdim=True),
153
+ )
eval/main.py ADDED
@@ -0,0 +1,162 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n
2
+
3
+ import argparse
4
+ import json
5
+ import os
6
+
7
+ import pandas as pd
8
+ import torch
9
+ import torch.distributed as dist
10
+ from dataset import SETTINGS, make_dataset
11
+ from metrics import CLAP, Aesthetic, ImageBind, Judge
12
+ from torch.utils.data import DataLoader
13
+ from torch.utils.data.distributed import DistributedSampler
14
+ from tqdm import tqdm
15
+
16
+ from sam_audio import SAMAudio, SAMAudioProcessor
17
+
18
+
19
+ def gather_and_average_results(results, world_size):
20
+ if world_size == 1:
21
+ return json.loads(results.mean().to_json())
22
+
23
+ # 1. Gather all dictionaries to all ranks
24
+ all_results = [None for _ in range(world_size)]
25
+ dist.all_gather_object(
26
+ all_results, {"sum": results.sum().to_json(), "count": len(results)}
27
+ )
28
+
29
+ summed = {}
30
+ counts = 0
31
+
32
+ for res in all_results:
33
+ for k, v in json.loads(res["sum"]).items():
34
+ if k not in summed:
35
+ summed[k] = 0.0
36
+ summed[k] += v
37
+ counts += res["count"]
38
+
39
+ # 3. Compute average for keys that appeared at least once
40
+ averaged = {k: summed[k] / counts for k in summed}
41
+
42
+ return averaged
43
+
44
+
45
+ def main(
46
+ settings: list[str],
47
+ cache_path: str,
48
+ batch_size: int,
49
+ checkpoint_path: str,
50
+ num_workers: int = 4,
51
+ reranking_candidates: int = 8,
52
+ ):
53
+ world_size = int(os.environ.get("WORLD_SIZE", 1))
54
+ rank = int(os.environ.get("RANK", 0))
55
+
56
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
57
+ if world_size > 1:
58
+ torch.distributed.init_process_group(backend="nccl")
59
+ device = torch.device(f"cuda:{rank}")
60
+ torch.cuda.set_device(device)
61
+
62
+ model = SAMAudio.from_pretrained(checkpoint_path)
63
+ model = model.eval().to(device)
64
+ processor = SAMAudioProcessor.from_pretrained(checkpoint_path)
65
+
66
+ judge_metric = Judge(device=device)
67
+ aes_metric = Aesthetic(device=device)
68
+ clap_metric = CLAP(device=device)
69
+ imagebind_metric = ImageBind(device=device)
70
+
71
+ for setting in settings:
72
+ print(f"Evaluating: {setting}")
73
+ dset = make_dataset(setting, cache_path=cache_path, collate_fn=processor)
74
+ sampler = None
75
+ if world_size > 1:
76
+ sampler = DistributedSampler(dset)
77
+
78
+ dl = DataLoader(
79
+ dset,
80
+ batch_size=batch_size,
81
+ shuffle=False,
82
+ collate_fn=dset.collate,
83
+ num_workers=num_workers,
84
+ sampler=sampler,
85
+ )
86
+
87
+ all_metrics = [
88
+ judge_metric,
89
+ aes_metric,
90
+ clap_metric,
91
+ ]
92
+
93
+ if dset.visual:
94
+ all_metrics.append(imagebind_metric)
95
+
96
+ dfs = []
97
+ with torch.inference_mode():
98
+ for batch in tqdm(dl, disable=rank != 0):
99
+ batch = batch.to(device)
100
+ result = model.separate(
101
+ batch, reranking_candidates=reranking_candidates
102
+ )
103
+ mets = {}
104
+ for metric in all_metrics:
105
+ input_wavs = model.unbatch(batch.audios.squeeze(1), batch.wav_sizes)
106
+
107
+ mets.update(
108
+ metric(
109
+ target_wavs=result.target,
110
+ target_wavs_sample_rate=model.sample_rate,
111
+ descriptions=batch.descriptions,
112
+ input_wavs=input_wavs,
113
+ videos=batch.masked_video,
114
+ )
115
+ )
116
+
117
+ dfs.append(pd.DataFrame.from_dict(mets))
118
+
119
+ df = pd.concat(dfs)
120
+ averaged_results = gather_and_average_results(df, world_size)
121
+ if rank == 0:
122
+ results_dict = {k: f"{v:.3f}" for k, v in averaged_results.items()}
123
+ print(json.dumps(results_dict, indent=4))
124
+ os.makedirs("results", exist_ok=True)
125
+ outfile = f"results/{setting}.json"
126
+ with open(outfile, "w") as fout:
127
+ print(json.dumps(results_dict), file=fout)
128
+
129
+
130
+ if __name__ == "__main__":
131
+ parser = argparse.ArgumentParser()
132
+ parser.add_argument(
133
+ "--setting",
134
+ "-s",
135
+ choices=SETTINGS.keys(),
136
+ help=f"Which setting to evaluate. Choices: {SETTINGS.keys()}",
137
+ default=["instr-pro"],
138
+ nargs="+",
139
+ )
140
+ parser.add_argument(
141
+ "--cache-path",
142
+ type=str,
143
+ default=os.path.expanduser("~/.cache/sam_audio"),
144
+ help="Where to cache downloaded datasets",
145
+ )
146
+ parser.add_argument(
147
+ "--checkpoint-path", "-p", type=str, default="facebook/sam-audio-large"
148
+ )
149
+ parser.add_argument("--batch-size", "-b", type=int, default=1, help="Batch size")
150
+ parser.add_argument(
151
+ "--num-workers", "-w", type=int, default=4, help="Number of workers"
152
+ )
153
+ parser.add_argument("--candidates", "-c", type=int, default=8)
154
+ opt = parser.parse_args()
155
+ main(
156
+ settings=opt.setting,
157
+ cache_path=opt.cache_path,
158
+ batch_size=opt.batch_size,
159
+ checkpoint_path=opt.checkpoint_path,
160
+ num_workers=opt.num_workers,
161
+ reranking_candidates=opt.candidates,
162
+ )
eval/metrics/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n
2
+
3
+ from metrics.aes import Aesthetic
4
+ from metrics.clap import CLAP
5
+ from metrics.imagebind import ImageBind
6
+ from metrics.judge import Judge
7
+
8
+ __all__ = [
9
+ "Aesthetic",
10
+ "CLAP",
11
+ "ImageBind",
12
+ "Judge",
13
+ ]
eval/metrics/aes.py ADDED
@@ -0,0 +1,49 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n
2
+
3
+ from typing import Optional
4
+
5
+ import torch
6
+ from audiobox_aesthetics.infer import AesPredictor
7
+
8
+ COLUMN_MAP = {
9
+ "CE": "ContentEnjoyment",
10
+ "CU": "ContentUsefulness",
11
+ "PC": "ProductionComplexity",
12
+ "PQ": "ProductionQuality",
13
+ }
14
+
15
+
16
+ class Aesthetic(torch.nn.Module):
17
+ def __init__(
18
+ self,
19
+ checkpoint: Optional[str] = None,
20
+ device: Optional[torch.device] = None,
21
+ ):
22
+ super().__init__()
23
+ self.model = AesPredictor(
24
+ checkpoint_pth=checkpoint,
25
+ data_col="wav",
26
+ )
27
+ self.device = device or torch.device(
28
+ "cuda" if torch.cuda.is_available() else "cpu"
29
+ )
30
+
31
+ def __call__(
32
+ self,
33
+ target_wavs: list[torch.Tensor],
34
+ target_wavs_sample_rate: int = 48_000,
35
+ **kwargs,
36
+ ) -> dict[str, list[float]]:
37
+ result = self.model.forward(
38
+ [
39
+ {
40
+ "wav": wav[None] if wav.ndim == 1 else wav,
41
+ "sample_rate": target_wavs_sample_rate,
42
+ }
43
+ for wav in target_wavs
44
+ ]
45
+ )
46
+ return {
47
+ long_name: [x[shortname] for x in result]
48
+ for shortname, long_name in COLUMN_MAP.items()
49
+ }
eval/metrics/clap.py ADDED
@@ -0,0 +1,46 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n
2
+
3
+ from tempfile import TemporaryDirectory
4
+ from typing import Optional
5
+
6
+ import torch
7
+ from torchcodec.encoders import AudioEncoder
8
+
9
+ from sam_audio.ranking.clap import get_model
10
+
11
+
12
+ class CLAP(torch.nn.Module):
13
+ def __init__(
14
+ self,
15
+ checkpoint: Optional[str] = None,
16
+ device: Optional[torch.device] = None,
17
+ ):
18
+ super().__init__()
19
+ self.model = get_model(device)
20
+ self.device = device or torch.device(
21
+ "cuda" if torch.cuda.is_available() else "cpu"
22
+ )
23
+
24
+ def __call__(
25
+ self,
26
+ target_wavs: list[torch.Tensor],
27
+ descriptions: list[str],
28
+ target_wavs_sample_rate: int = 48_000,
29
+ **kwargs,
30
+ ) -> list[dict[str, float]]:
31
+ with TemporaryDirectory() as tdir, torch.inference_mode():
32
+ file_list = []
33
+ for i, wav in enumerate(target_wavs):
34
+ file_list.append(f"{tdir}/hyp_{i}.wav")
35
+ encoder = AudioEncoder(
36
+ samples=wav.cpu()[None] if wav.ndim == 1 else wav.cpu(),
37
+ sample_rate=target_wavs_sample_rate,
38
+ )
39
+ encoder.to_file(file_list[-1])
40
+ audio_embs = self.model.get_audio_embedding_from_filelist(
41
+ file_list, use_tensor=True
42
+ )
43
+
44
+ text_embs = self.model.get_text_embedding(descriptions, use_tensor=True)
45
+ sims = audio_embs.unsqueeze(1) @ text_embs.unsqueeze(2)
46
+ return {"CLAPSimilarity": sims.cpu()[:, 0, 0].tolist()}
eval/metrics/imagebind.py ADDED
@@ -0,0 +1,52 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n
2
+
3
+ from typing import Optional
4
+
5
+ import torch
6
+ from imagebind.models.imagebind_model import ModalityType, imagebind_huge
7
+
8
+ from sam_audio.ranking.imagebind import VideoTransform, load_and_transform_audio_data
9
+
10
+
11
+ class ImageBind(torch.nn.Module):
12
+ def __init__(
13
+ self,
14
+ checkpoint: Optional[str] = None,
15
+ device: Optional[torch.device] = None,
16
+ ):
17
+ super().__init__()
18
+
19
+ self.model = imagebind_huge(pretrained=checkpoint is None)
20
+ if checkpoint is not None:
21
+ self.model.load_state_dict(torch.load(checkpoint, map_location="cpu"))
22
+ self.model = self.model.eval()
23
+ self.video_transform = VideoTransform()
24
+ self.device = device or torch.device(
25
+ "cuda" if torch.cuda.is_available() else "cpu"
26
+ )
27
+ self.model = self.model.to(self.device)
28
+
29
+ def __call__(
30
+ self,
31
+ target_wavs: list[torch.Tensor],
32
+ videos: list[torch.Tensor],
33
+ target_wavs_sample_rate: int = 48_000,
34
+ **kwargs,
35
+ ) -> dict[str, list[float]]:
36
+ audio_data = load_and_transform_audio_data(
37
+ target_wavs, input_sample_rate=target_wavs_sample_rate
38
+ )
39
+ durations = [x.size(-1) / target_wavs_sample_rate for x in target_wavs]
40
+ video_data = self.video_transform(videos, durations, audio_data.device)
41
+
42
+ inputs = {ModalityType.AUDIO: audio_data, ModalityType.VISION: video_data}
43
+ embs = self.model(inputs)
44
+ audio_embs, video_embs = embs[ModalityType.AUDIO], embs[ModalityType.VISION]
45
+ audio_embs, video_embs = (
46
+ audio_embs / ((audio_embs**2).sum(dim=-1, keepdims=True) ** 0.5),
47
+ video_embs / ((video_embs**2).sum(dim=-1, keepdims=True) ** 0.5),
48
+ )
49
+ bsz = len(target_wavs)
50
+ candidates = len(audio_embs) // bsz
51
+ scores = audio_embs.view(bsz, candidates, -1) @ video_embs.view(bsz, -1, 1)
52
+ return {"ImageBind": scores.squeeze(1, 2).cpu().tolist()}
eval/metrics/judge.py ADDED
@@ -0,0 +1,44 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n
2
+
3
+ from typing import Optional
4
+
5
+ import torch
6
+
7
+ from sam_audio import SAMAudioJudgeModel, SAMAudioJudgeProcessor
8
+
9
+
10
+ class Judge(torch.nn.Module):
11
+ def __init__(
12
+ self,
13
+ checkpoint: str = "facebook/sam-audio-judge",
14
+ device: Optional[torch.device] = None,
15
+ ):
16
+ super().__init__()
17
+ self.model = SAMAudioJudgeModel.from_pretrained(checkpoint).to(device)
18
+ self.processor = SAMAudioJudgeProcessor.from_pretrained(checkpoint)
19
+ self.device = device or torch.device(
20
+ "cuda" if torch.cuda.is_available() else "cpu"
21
+ )
22
+
23
+ def forward(
24
+ self,
25
+ input_wavs: list[torch.Tensor],
26
+ target_wavs: list[torch.Tensor],
27
+ descriptions: list[str],
28
+ target_wavs_sample_rate: int = 48_000,
29
+ **kwargs,
30
+ ) -> torch.Tensor:
31
+ with torch.inference_mode():
32
+ processed = self.processor(
33
+ text=descriptions,
34
+ input_audio=[x.cpu() for x in input_wavs],
35
+ separated_audio=[x.cpu() for x in target_wavs],
36
+ sampling_rate=target_wavs_sample_rate,
37
+ ).to(self.device)
38
+ result = self.model(**processed)
39
+ return {
40
+ "JudgeOverall": result.overall.squeeze(-1).cpu().tolist(),
41
+ "JudgeFaithfulness": result.faithfulness.squeeze(-1).cpu().tolist(),
42
+ "JudgeRecall": result.recall.squeeze(-1).cpu().tolist(),
43
+ "JudgePrecision": result.precision.squeeze(-1).cpu().tolist(),
44
+ }
examples/assets/office.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c0f583ff34c5fd9d1a83d640e7c0131ad339755bd69e54f104723b707f213c21
3
+ size 4551702
examples/span_prompting.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
examples/text_prompting.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
examples/visual_prompting.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
pyproject.toml ADDED
@@ -0,0 +1,62 @@
1
+ [project]
2
+ name = "sam_audio"
3
+ version = "0.1.0"
4
+ description = "Segment Anything Audio"
5
+ authors = [
6
+ { name="Andros Tjandra", email="androstj@meta.com" },
7
+ { name="Ann Lee", email="annl@meta.com" },
8
+ { name="Bowen Shi", email="bshi@meta.com" },
9
+ { name="Julius Richter", email="jrichter@meta.com" },
10
+ { name="Matt Le", email="mattle@meta.com" },
11
+ { name="Yi-Chiao Wu", email="yichiaowu@meta.com" },
12
+ ]
13
+
14
+ readme = "README.md"
15
+ license = { file="LICENSE" }
16
+ requires-python = ">=3.10"
17
+ dependencies = [
18
+ "dacvae@git+https://github.com/facebookresearch/dacvae.git",
19
+ "audiobox_aesthetics",
20
+ "einops",
21
+ "imagebind@git+https://github.com/facebookresearch/ImageBind.git",
22
+ "laion-clap@git+https://github.com/lematt1991/CLAP.git",
23
+ "numpy",
24
+ "perception-models@git+https://github.com/facebookresearch/perception_models@unpin-deps",
25
+ "pydub",
26
+ "torch",
27
+ "torchaudio",
28
+ "torchcodec",
29
+ "torchdiffeq",
30
+ "torchvision",
31
+ "transformers>=4.54.0",
32
+ ]
33
+
34
+ [tool.setuptools.packages.find]
35
+ include = ["sam_audio*"]
36
+
37
+ [tool.ruff]
38
+
39
+ target-version = "py310"
40
+
41
+ lint.select=[
42
+ "B",
43
+ "C",
44
+ "E",
45
+ "W",
46
+ "F",
47
+ "I",
48
+ ]
49
+ lint.ignore = [
50
+ "E501",
51
+ "E731",
52
+ "C901",
53
+ "B006",
54
+ ]
55
+
56
+ [project.urls]
57
+ Homepage = "https://github.com/facebookresearch/sam-audio"
58
+ Repository = "https://github.com/facebookresearch/sam-audio"
59
+
60
+ [build-system]
61
+ requires = ["setuptools>=61.0", "wheel"]
62
+ build-backend = "setuptools.build_meta"
sam_audio/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n
2
+
3
+ from .model import * # noqa
4
+ from .processor import * # noqa
sam_audio/model/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n
2
+
3
+ from .model import * # noqa
4
+ from .judge import * # noqa
sam_audio/model/align.py ADDED
@@ -0,0 +1,50 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n
2
+
3
+ from typing import Optional
4
+
5
+ import torch
6
+
7
+
8
+ class AlignModalities(torch.nn.Module):
9
+ def __init__(
10
+ self,
11
+ in_channels: int,
12
+ out_channels: int,
13
+ normalize: bool = True,
14
+ with_gate: bool = True,
15
+ ):
16
+ super().__init__()
17
+ self.conv = torch.nn.Conv1d(
18
+ in_channels=in_channels, out_channels=out_channels, kernel_size=1
19
+ )
20
+ self.normalize = normalize
21
+ if self.normalize:
22
+ self.layer_norm = torch.nn.LayerNorm(out_channels)
23
+
24
+ self.gate = None
25
+ if with_gate:
26
+ self.gate = torch.nn.Parameter(torch.tensor([0.0]))
27
+
28
+ self.out_channels = out_channels
29
+
30
+ def forward(self, anchor: torch.Tensor, tgt: Optional[torch.Tensor] = None):
31
+ """
32
+ Align video features to the input audio features
33
+
34
+ Args:
35
+ anchor (torch.Tensor): Input anchor tensor of shape (B, T, C), where B is batch size, C is channel size, and T is sequence length.
36
+ tgt (Optional[torch.Tensor]): Optional features tensor to be aligned to anchor, expected shape (B, in_channels, T).
37
+ """
38
+ if tgt is None:
39
+ return anchor
40
+
41
+ post_conv = self.conv(tgt)
42
+ post_conv = post_conv.permute(0, 2, 1) # BCT -> BTC
43
+
44
+ if self.normalize:
45
+ post_conv = self.layer_norm(post_conv)
46
+
47
+ if self.gate is None:
48
+ return post_conv
49
+ else:
50
+ return anchor + self.gate.tanh() * post_conv
sam_audio/model/base.py ADDED
@@ -0,0 +1,62 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n
2
+
3
+ import json
4
+ import os
5
+ from typing import Callable, Dict, Optional, Union
6
+
7
+ import torch
8
+ from huggingface_hub import ModelHubMixin, snapshot_download
9
+
10
+
11
+ class BaseModel(torch.nn.Module, ModelHubMixin):
12
+ config_cls: Callable
13
+
14
+ def device(self):
15
+ return next(self.parameters()).device
16
+
17
+ @classmethod
18
+ def _from_pretrained(
19
+ cls,
20
+ *,
21
+ model_id: str,
22
+ cache_dir: str,
23
+ force_download: bool,
24
+ proxies: Optional[Dict],
25
+ resume_download: bool,
26
+ local_files_only: bool,
27
+ token: Union[str, bool, None],
28
+ map_location: str = "cpu",
29
+ strict: bool = True,
30
+ revision: Optional[str] = None,
31
+ **model_kwargs,
32
+ ):
33
+ if os.path.isdir(model_id):
34
+ cached_model_dir = model_id
35
+ else:
36
+ cached_model_dir = snapshot_download(
37
+ repo_id=model_id,
38
+ revision=cls.revision,
39
+ cache_dir=cache_dir,
40
+ force_download=force_download,
41
+ proxies=proxies,
42
+ resume_download=resume_download,
43
+ token=token,
44
+ local_files_only=local_files_only,
45
+ )
46
+
47
+ with open(os.path.join(cached_model_dir, "config.json")) as fin:
48
+ config = json.load(fin)
49
+
50
+ for key, value in model_kwargs.items():
51
+ if key in config:
52
+ config[key] = value
53
+
54
+ config = cls.config_cls(**config)
55
+ model = cls(config)
56
+ state_dict = torch.load(
57
+ os.path.join(cached_model_dir, "checkpoint.pt"),
58
+ weights_only=True,
59
+ map_location=map_location,
60
+ )
61
+ model.load_state_dict(state_dict, strict=strict)
62
+ return model
sam_audio/model/codec.py ADDED
@@ -0,0 +1,109 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved\n
2
+
3
+ import math
4
+ from abc import ABCMeta, abstractmethod
5
+ from typing import Union
6
+
7
+ import dacvae
8
+ import torch
9
+
10
+ from sam_audio.model.config import DACVAEConfig
11
+
12
+
13
+ class Encoder(torch.nn.Module, metaclass=ABCMeta):
14
+ @abstractmethod
15
+ def forward(self, waveform: torch.Tensor) -> torch.Tensor: ...
16
+
17
+
18
+ class Codec(Encoder):
19
+ @abstractmethod
20
+ def decode(self, encoded_frames: torch.Tensor) -> torch.Tensor: ...
21
+
22
+ @abstractmethod
23
+ def wav_idx_to_feature_idx(
24
+ self, wav_idx: Union[torch.Tensor, int], sample_rate=None
25
+ ) -> Union[torch.Tensor, int]: ...
26
+
27
+ @abstractmethod
28
+ def feature_idx_to_wav_idx(
29
+ self, feature_idx: Union[torch.Tensor, int], sample_rate=None
30
+ ) -> Union[torch.Tensor, int]: ...
31
+
32
+ @staticmethod
33
+ def cast_to_int(
34
+ x: Union[int, torch.Tensor],
35
+ ) -> Union[int, torch.Tensor]:
36
+ if isinstance(x, torch.Tensor):
37
+ return x.int()
38
+ else:
39
+ return int(x)
40
+
41
+
42
+ class DACVAEEncoder(Encoder):
43
+ def __init__(self, config: DACVAEConfig) -> None:
44
+ super().__init__()
45
+ model = dacvae.DACVAE(
46
+ encoder_dim=config.encoder_dim,
47
+ encoder_rates=config.encoder_rates,
48
+ latent_dim=config.latent_dim,
49
+ decoder_dim=config.decoder_dim,
50
+ decoder_rates=config.decoder_rates,
51
+ n_codebooks=config.n_codebooks,
52
+ codebook_size=config.codebook_size,
53
+ codebook_dim=config.codebook_dim,
54
+ quantizer_dropout=config.quantizer_dropout,
55
+ sample_rate=config.sample_rate,
56
+ ).eval()
57
+ self._setup_model(model)
58
+ self.hop_length = config.hop_length
59
+ self.sample_rate = config.sample_rate
60
+
61
+ def _setup_model(self, model):
62
+ self.encoder = model.encoder
63
+ self.quantizer = model.quantizer
64
+
65
+ def forward(self, waveform: torch.Tensor) -> torch.Tensor:
66
+ with torch.no_grad(), torch.backends.cudnn.flags(enabled=False):
67
+ z = self.encoder(self._pad(waveform))
68
+ mean, _ = self.quantizer.in_proj(z).chunk(2, dim=1)
69
+ encoded_frames = mean
70
+ return encoded_frames
71
+
72
+ def _pad(self, wavs):
73
+ length = wavs.size(-1)
74
+ if length % self.hop_length:
75
+ p1d = (0, self.hop_length - (length % self.hop_length))
76
+ return torch.nn.functional.pad(wavs, p1d, "reflect")
77
+ else:
78
+ return wavs
79
+
80
+
81
+ class DACVAE(DACVAEEncoder, Codec):
82
+ def _setup_model(self, model):
83
+ super()._setup_model(model)
84
+ self.decoder = model.decoder
85
+
86
+ def decode(self, encoded_frames: torch.Tensor) -> torch.Tensor:
87
+ with torch.backends.cudnn.flags(enabled=False):
88
+ emb = self.quantizer.out_proj(encoded_frames)
89
+ return self.decoder(emb)
90
+
91
+ def feature_idx_to_wav_idx(self, feature_idx, sample_rate=None):
92
+ if sample_rate is None:
93
+ sample_rate = self.sample_rate
94
+ orig_freq = sample_rate
95
+ new_freq = self.sample_rate
96
+ wav_chunklen = feature_idx * self.hop_length * (orig_freq / new_freq)
97
+ return self.cast_to_int(wav_chunklen)
98
+
99
+ def wav_idx_to_feature_idx(self, wav_idx, sample_rate=None):
100
+ ceil = math.ceil
101
+ if torch.is_tensor(wav_idx):
102
+ ceil = torch.ceil
103
+ if sample_rate is None:
104
+ sample_rate = self.sample_rate
105
+ orig_freq = sample_rate
106
+ new_freq = self.sample_rate
107
+ target_length = ceil(new_freq * wav_idx / orig_freq)
108
+ res = ceil(target_length / self.hop_length)
109
+ return self.cast_to_int(res)
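
The two index helpers map between waveform samples and latent frames at `hop_length` granularity (ceiling on the way in, an exact multiple on the way out). A worked check under the default DACVAEConfig, where hop_length = 2*8*10*12 = 1920 and the sample rates match so the resampling ratio is 1:

import math

hop_length = 1920                   # prod(encoder_rates) from DACVAEConfig
one_second = 48_000                 # samples at the default 48 kHz

frames = math.ceil(one_second / hop_length)   # wav_idx_to_feature_idx -> 25 frames
samples = frames * hop_length                 # feature_idx_to_wav_idx -> 48_000 samples
assert (frames, samples) == (25, 48_000)
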
sam_audio/model/config.py ADDED
@@ -0,0 +1,251 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ from typing import Optional, Tuple
4
+
5
+ import numpy as np
6
+ from core.audio_visual_encoder.config import TransformerConfig as PEAVTransformerConfig
7
+ from transformers import ModernBertConfig
8
+
9
+
10
+ class DACVAEConfig:
11
+ def __init__(
12
+ self,
13
+ encoder_dim: int = 64,
14
+ encoder_rates: list[int] = [2, 8, 10, 12],
15
+ latent_dim: int = 1024,
16
+ decoder_dim: int = 1536,
17
+ decoder_rates: list[int] = [12, 10, 8, 2],
18
+ n_codebooks: int = 16,
19
+ codebook_size: int = 1024,
20
+ codebook_dim: int = 128,
21
+ quantizer_dropout: bool = False,
22
+ sample_rate: int = 48_000,
23
+ mean: float = 0.0,
24
+ std: float = 1.0,
25
+ ):
26
+ self.encoder_dim = encoder_dim
27
+ self.encoder_rates = encoder_rates
28
+ self.latent_dim = latent_dim
29
+ self.decoder_dim = decoder_dim
30
+ self.decoder_rates = decoder_rates
31
+ self.n_codebooks = n_codebooks
32
+ self.codebook_size = codebook_size
33
+ self.codebook_dim = codebook_dim
34
+ self.quantizer_dropout = quantizer_dropout
35
+ self.sample_rate = sample_rate
36
+ self.mean = mean
37
+ self.std = std
38
+
39
+ @property
40
+ def hop_length(self):
41
+ return int(np.prod(self.encoder_rates))
42
+
43
+
44
+ class TextEncoderConfig:
45
+ def __init__(self, dim: int = 768):
46
+ self.dim = dim
47
+
48
+
49
+ class T5EncoderConfig(TextEncoderConfig):
50
+ def __init__(
51
+ self,
52
+ name: str = "t5-base",
53
+ max_length: Optional[int] = 512,
54
+ pad_mode: str = "longest",
55
+ dim: int = 768,
56
+ ):
57
+ super().__init__(dim=dim)
58
+ self.name = name
59
+ self.max_length = max_length
60
+ self.pad_mode = pad_mode
61
+
62
+
63
+ class VisionEncoderConfig:
64
+ def __init__(self, dim: int = 1024, batch_size: int = 300):
65
+ self.dim = dim
66
+ self.batch_size = batch_size
67
+
68
+
69
+ class PerceptionEncoderConfig(VisionEncoderConfig):
70
+ def __init__(
71
+ self,
72
+ dim: int = 1024,
73
+ batch_size: int = 300,
74
+ name: str = "PE-Core-L14-336",
75
+ normalize_feature: bool = True,
76
+ interpolation_mode: str = "BICUBIC",
77
+ image_size: int = 336,
78
+ ):
79
+ super().__init__(dim=dim, batch_size=batch_size)
80
+ self.name = name
81
+ self.normalize_feature = normalize_feature
82
+ self.interpolation_mode = interpolation_mode
83
+ self.image_size = image_size
84
+
85
+
86
+ class TransformerConfig:
87
+ def __init__(
88
+ self,
89
+ dim: int = 2048,
90
+ n_heads: int = 16,
91
+ n_layers: int = 16,
92
+ dropout: float = 0.1,
93
+ norm_eps: float = 1.0e-05,
94
+ qk_norm: bool = True,
95
+ fc_bias: bool = False,
96
+ ffn_exp: int = 4,
97
+ ffn_dim_multiplier: int = 1,
98
+ multiple_of: int = 64,
99
+ non_linearity: str = "swiglu",
100
+ use_rope: bool = True,
101
+ max_positions: int = 10000,
102
+ frequency_embedding_dim: int = 256,
103
+ timestep_non_linearity: str = "swiglu",
104
+ t_block_non_linearity: str = "silu",
105
+ t_block_bias: bool = True,
106
+ context_dim: int = 2048,
107
+ context_non_linearity: str = "swiglu",
108
+ context_embedder_dropout: float = 0.0,
109
+ context_norm: bool = False,
110
+ out_channels: int = 256,
111
+ in_channels: Optional[int] = None,
112
+ ):
113
+ self.dim = dim
114
+ self.n_heads = n_heads
115
+ self.n_layers = n_layers
116
+ self.dropout = dropout
117
+ self.norm_eps = norm_eps
118
+ self.qk_norm = qk_norm
119
+ self.fc_bias = fc_bias
120
+ self.ffn_exp = ffn_exp
121
+ self.ffn_dim_multiplier = ffn_dim_multiplier
122
+ self.multiple_of = multiple_of
123
+ self.non_linearity = non_linearity
124
+ self.use_rope = use_rope
125
+ self.max_positions = max_positions
126
+ self.frequency_embedding_dim = frequency_embedding_dim
127
+ self.timestep_non_linearity = timestep_non_linearity
128
+ self.t_block_non_linearity = t_block_non_linearity
129
+ self.t_block_bias = t_block_bias
130
+ self.context_dim = context_dim
131
+ self.context_non_linearity = context_non_linearity
132
+ self.context_embedder_dropout = context_embedder_dropout
133
+ self.context_norm = context_norm
134
+ self.out_channels = out_channels
135
+ self.in_channels = in_channels
136
+
137
+
138
+ class RankerConfig:
139
+ kind: str
140
+
141
+
142
+ class ImageBindRankerConfig(RankerConfig):
143
+ kind: str = "imagebind"
144
+
145
+ def __init__(self, checkpoint: Optional[str] = None):
146
+ self.checkpoint = checkpoint
147
+
148
+
149
+ class ClapRankerConfig(RankerConfig):
150
+ kind: str = "clap"
151
+
152
+ def __init__(self, checkpoint: Optional[str] = None):
153
+ self.checkpoint = checkpoint
154
+
155
+
156
+ class JudgeRankerConfig(RankerConfig):
157
+ kind: str = "judge"
158
+
159
+ def __init__(self, checkpoint_or_model_id: str = "facebook/sam-audio-judge"):
160
+ self.checkpoint_or_model_id = checkpoint_or_model_id
161
+
162
+
163
+ class SoundActivityRankerConfig(RankerConfig):
164
+ kind: str = "sound_activity"
165
+
166
+ def __init__(
167
+ self,
168
+ threshold_mode: str = "rel_to_max",
169
+ sil_threshold: float = -40,
170
+ metric: str = "iou",
171
+ ):
172
+ self.threshold_mode = threshold_mode
173
+ self.sil_threshold = sil_threshold
174
+ self.metric = metric
175
+
176
+
177
+ class EnsembleRankerConfig(RankerConfig):
178
+ kind: str = "ensemble"
179
+
180
+ def __init__(self, rankers: dict[str, Tuple[RankerConfig, float]]):
181
+ self.rankers = rankers
182
+
183
+
184
+ def parse_ranker_config(config_dict: dict):
185
+ kind = config_dict.pop("kind")
186
+ match kind:
187
+ case ImageBindRankerConfig.kind:
188
+ return ImageBindRankerConfig(**config_dict)
189
+ case ClapRankerConfig.kind:
190
+ return ClapRankerConfig(**config_dict)
191
+ case JudgeRankerConfig.kind:
192
+ return JudgeRankerConfig(**config_dict)
193
+ case SoundActivityRankerConfig.kind:
194
+ return SoundActivityRankerConfig(**config_dict)
195
+ case EnsembleRankerConfig.kind:
196
+ return EnsembleRankerConfig(
197
+ {
198
+ k: (parse_ranker_config(v), w)
199
+ for k, (v, w) in config_dict["rankers"].items()
200
+ }
201
+ )
202
+
203
+
204
+ class SAMAudioConfig:
205
+ def __init__(
206
+ self,
207
+ in_channels: int = 768,
208
+ audio_codec=None,
209
+ text_encoder=None,
210
+ vision_encoder=None,
211
+ transformer=None,
212
+ num_anchors: int = 3,
213
+ anchor_embedding_dim: int = 128,
214
+ visual_ranker=None,
215
+ text_ranker=None,
216
+ span_predictor: Optional[str] = "pe-a-frame-large",
217
+ ):
218
+ self.in_channels = in_channels
219
+ self.audio_codec = DACVAEConfig(**(audio_codec or {}))
220
+ self.text_encoder = T5EncoderConfig(**(text_encoder or {}))
221
+ self.vision_encoder = PerceptionEncoderConfig(**(vision_encoder or {}))
222
+ self.transformer = TransformerConfig(**(transformer or {}))
223
+ self.num_anchors = num_anchors
224
+ self.anchor_embedding_dim = anchor_embedding_dim
225
+ self.visual_ranker = (
226
+ None if visual_ranker is None else parse_ranker_config(visual_ranker)
227
+ )
228
+ self.text_ranker = (
229
+ None if text_ranker is None else parse_ranker_config(text_ranker)
230
+ )
231
+ self.span_predictor = span_predictor
232
+
233
+
234
+ class SAMAudioJudgeConfig:
235
+ def __init__(
236
+ self,
237
+ audio_codec: DACVAEConfig = None,
238
+ transformer: PEAVTransformerConfig = None,
239
+ text_model: ModernBertConfig = None,
240
+ finetune_transformer: PEAVTransformerConfig = None,
241
+ nth_text_layer: int = 22,
242
+ bottleneck_dim: int = 256,
243
+ ):
244
+ self.audio_codec = DACVAEConfig(**(audio_codec or {}))
245
+ self.transformer = PEAVTransformerConfig(**(transformer or {}))
246
+ self.text_model = ModernBertConfig(**(text_model or {}))
247
+ self.finetune_transformer = PEAVTransformerConfig(
248
+ **(finetune_transformer or {})
249
+ )
250
+ self.nth_text_layer = nth_text_layer
251
+ self.bottleneck_dim = bottleneck_dim
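
`parse_ranker_config` dispatches on the `kind` field, and an ensemble nests (config, weight) pairs that are parsed recursively. An illustrative input; the weights are made-up example values:

ensemble = parse_ranker_config(
    {
        "kind": "ensemble",
        "rankers": {
            "judge": ({"kind": "judge"}, 0.7),
            "activity": ({"kind": "sound_activity", "metric": "iou"}, 0.3),
        },
    }
)
# ensemble.rankers["judge"] is (JudgeRankerConfig(...), 0.7), and so on.
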
sam_audio/model/judge.py ADDED
@@ -0,0 +1,135 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Optional
5
+
6
+ import torch
7
+ from core.audio_visual_encoder.transformer import BaseModelOutputWithPooling
8
+ from core.audio_visual_encoder.transformer import Transformer as PEAVTransformer
9
+ from transformers import AutoModel
10
+
11
+ from .base import BaseModel
12
+ from .codec import DACVAEEncoder
13
+ from .config import SAMAudioJudgeConfig
14
+
15
+
16
+ @dataclass
17
+ class SAMAudioJudgeOutput:
18
+ r"""
19
+ overall (torch.Tensor, optional): Overall score tensor of shape (batch_size, 1).
20
+ recall (torch.Tensor, optional): Recall score tensor of shape (batch_size, 1).
21
+ precision (torch.Tensor, optional): Precision score tensor of shape (batch_size, 1).
22
+ faithfulness (torch.Tensor, optional): Faithfulness score tensor of shape (batch_size, 1).
23
+ text_model_output (BaseModelOutputWithPooling): Output from the text model.
24
+ audio_model_output (BaseModelOutputWithPooling): Output from the audio model.
25
+ """
26
+
27
+ overall: Optional[torch.Tensor] = None
28
+ recall: Optional[torch.Tensor] = None
29
+ precision: Optional[torch.Tensor] = None
30
+ faithfulness: Optional[torch.Tensor] = None
31
+ text_model_output: BaseModelOutputWithPooling = None
32
+ audio_model_output: BaseModelOutputWithPooling = None
33
+
34
+
35
+ class SAMAudioJudgeModel(BaseModel):
36
+ config_cls = SAMAudioJudgeConfig
37
+ revision = "sam_audio"
38
+
39
+ def __init__(self, config: SAMAudioJudgeConfig):
40
+ super().__init__()
41
+ self.config = config
42
+ self.data_proj = torch.nn.Linear(
43
+ config.audio_codec.codebook_dim, config.transformer.hidden_size
44
+ )
45
+ self.audio_codec = DACVAEEncoder(config.audio_codec)
46
+ self.transformer = PEAVTransformer(config.transformer)
47
+ self.finetune_transformer = PEAVTransformer(config.finetune_transformer)
48
+ self.text_model = AutoModel.from_config(config.text_model)
49
+ self.cat_audio_proj = torch.nn.Linear(
50
+ 2 * config.transformer.hidden_size, config.bottleneck_dim
51
+ )
52
+ self.text_proj1 = torch.nn.Linear(
53
+ in_features=config.text_model.hidden_size,
54
+ out_features=config.transformer.hidden_size,
55
+ bias=False,
56
+ )
57
+ self.text_proj2 = torch.nn.Linear(
58
+ in_features=config.transformer.hidden_size,
59
+ out_features=config.bottleneck_dim,
60
+ )
61
+ self.layer_norm = torch.nn.LayerNorm(config.bottleneck_dim)
62
+ self.proj_audio_and_text = torch.nn.Linear(
63
+ 2 * config.bottleneck_dim, config.bottleneck_dim
64
+ )
65
+ self.finetune_data_proj = torch.nn.Linear(
66
+ config.bottleneck_dim, config.finetune_transformer.hidden_size
67
+ )
68
+ self.head = torch.nn.Linear(
69
+ config.finetune_transformer.hidden_size, 4, bias=False
70
+ )
71
+ self.mean = torch.nn.Parameter(torch.zeros(4, requires_grad=False))
72
+ self.std = torch.nn.Parameter(torch.ones(4, requires_grad=False))
73
+
74
+ def _get_text_output(self, input_ids, attention_mask):
75
+ nth_layer = self.config.nth_text_layer
76
+ output = self.text_model(
77
+ input_ids=input_ids,
78
+ attention_mask=attention_mask,
79
+ output_hidden_states=nth_layer is not None,
80
+ )
81
+ if nth_layer is None:
82
+ text_model_output = output.last_hidden_state
83
+ else:
84
+ text_model_output = output.hidden_states[nth_layer]
85
+
86
+ return BaseModelOutputWithPooling(
87
+ last_hidden_state=text_model_output, pooler_output=text_model_output[:, 0]
88
+ )
89
+
90
+ def forward(
91
+ self,
92
+ input_ids: torch.Tensor, # tokenized text
93
+ input_values: torch.Tensor, # input audio waveform
94
+ separated_values: torch.Tensor, # separated audio waveform
95
+ attention_mask: Optional[torch.Tensor] = None, # text attention mask
96
+ padding_mask: Optional[torch.Tensor] = None, # audio padding mask
97
+ ) -> SAMAudioJudgeOutput:
98
+ text_features = self.text_proj1(
99
+ self._get_text_output(input_ids, attention_mask).pooler_output
100
+ )
101
+ stacked_audios = torch.cat([input_values, separated_values], dim=0)
102
+ stacked_codec_features = self.audio_codec(stacked_audios)
103
+ feature_padding_mask = None
104
+ if padding_mask is not None:
105
+ feature_padding_mask = padding_mask[
106
+ :, :: self.config.audio_codec.hop_length
107
+ ]
108
+ stacked_features = self.transformer(
109
+ self.data_proj(stacked_codec_features.transpose(1, 2)),
110
+ padding_mask=feature_padding_mask,
111
+ )
112
+ input_features, hyp_features = stacked_features.last_hidden_state.chunk(2, 0)
113
+ audio_features = self.cat_audio_proj(
114
+ torch.cat([hyp_features, input_features], dim=2)
115
+ )
116
+ expanded_text = (
117
+ self.layer_norm(self.text_proj2(text_features))
118
+ .unsqueeze(1)
119
+ .expand_as(audio_features)
120
+ )
121
+ audio_and_text = self.proj_audio_and_text(
122
+ torch.cat([audio_features, expanded_text], dim=2)
123
+ )
124
+ finetune_transformer_output = self.finetune_transformer(
125
+ self.finetune_data_proj(audio_and_text), padding_mask=feature_padding_mask
126
+ )
127
+ result = self.head(finetune_transformer_output.last_hidden_state)
128
+ if feature_padding_mask is not None:
129
+ feature_padding_mask = feature_padding_mask.unsqueeze(-1)
130
+ pooled = torch.masked.mean(result, mask=feature_padding_mask, dim=1)
131
+ de_normalized = pooled * self.std + self.mean
132
+ return SAMAudioJudgeOutput(*de_normalized.chunk(4, dim=1))
133
+
134
+
135
+ __all__ = ["SAMAudioJudgeModel", "SAMAudioJudgeOutput"]
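
As a shape-level sketch, the judge takes a tokenized description plus the input mixture and a separated candidate, and returns four (batch, 1) scores. The tokenizer id and the mono (B, 1, T) waveform layout below are assumptions for illustration, not something this file pins down:

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")  # assumed tokenizer
text = tokenizer(["a dog barking"], return_tensors="pt")
mixture = torch.randn(1, 1, 48_000)        # placeholder 1 s of 48 kHz audio
candidate = torch.randn(1, 1, 48_000)      # placeholder separated output

out = judge(                               # judge: a loaded SAMAudioJudgeModel
    input_ids=text["input_ids"],
    input_values=mixture,
    separated_values=candidate,
    attention_mask=text["attention_mask"],
)
# out.overall, out.recall, out.precision, out.faithfulness are each (1, 1) tensors.
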
sam_audio/model/model.py ADDED
@@ -0,0 +1,362 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ import math
4
+ import re
5
+ from dataclasses import dataclass
6
+ from typing import Any, Dict, Optional
7
+
8
+ import torch
9
+ from core.audio_visual_encoder import PEAudioFrame, PEAudioFrameTransform
10
+ from torchdiffeq import odeint
11
+
12
+ from sam_audio.model.align import AlignModalities
13
+ from sam_audio.model.base import BaseModel
14
+ from sam_audio.model.codec import DACVAE
15
+ from sam_audio.model.config import SAMAudioConfig
16
+ from sam_audio.model.text_encoder import T5TextEncoder
17
+ from sam_audio.model.transformer import DiT
18
+ from sam_audio.model.vision_encoder import PerceptionEncoder
19
+ from sam_audio.processor import Batch
20
+ from sam_audio.ranking import create_ranker
21
+
22
+ DFLT_ODE_OPT = {"method": "midpoint", "options": {"step_size": 2 / 32}}
23
+
24
+
25
+ class SinusoidalEmbedding(torch.nn.Module):
26
+ def __init__(self, dim, theta=10000):
27
+ super().__init__()
28
+ assert (dim % 2) == 0
29
+ half_dim = dim // 2
30
+ inv_freq = torch.exp(
31
+ -math.log(theta) * torch.arange(half_dim).float() / half_dim
32
+ )
33
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
34
+
35
+ def forward(self, x, pos=None):
36
+ if pos is None:
37
+ seq_len, device = x.shape[1], x.device
38
+ pos = torch.arange(seq_len, device=device)
39
+
40
+ emb = torch.einsum("i, j -> i j", pos, self.inv_freq)
41
+ emb = torch.cat((emb.cos(), emb.sin()), dim=-1)
42
+ return emb
43
+
44
+
45
+ class EmbedAnchors(torch.nn.Module):
46
+ def __init__(self, num_embeddings: int, embedding_dim: int, out_dim: int):
47
+ super().__init__()
48
+ self.embed = torch.nn.Embedding(
49
+ num_embeddings + 1, embedding_dim, padding_idx=num_embeddings
50
+ )
51
+ self.gate = torch.nn.Parameter(torch.tensor([0.0]))
52
+ self.proj = torch.nn.Linear(embedding_dim, out_dim, bias=False)
53
+
54
+ def forward(
55
+ self,
56
+ x: torch.Tensor,
57
+ anchor_ids: Optional[torch.Tensor] = None,
58
+ anchor_alignment: Optional[torch.Tensor] = None,
59
+ ):
60
+ if anchor_ids is None:
61
+ return x
62
+
63
+ embs = self.embed(anchor_ids.gather(1, anchor_alignment))
64
+ proj = self.proj(embs)
65
+ return x + self.gate.tanh() * proj
66
+
67
+
68
+ @dataclass
69
+ class SeparationResult:
70
+ target: torch.Tensor
71
+ residual: torch.Tensor
72
+ noise: torch.Tensor
73
+
74
+
75
+ class SAMAudio(BaseModel):
76
+ config_cls = SAMAudioConfig
77
+ revision = None
78
+
79
+ def __init__(self, cfg: SAMAudioConfig):
80
+ super().__init__()
81
+ self.audio_codec = DACVAE(cfg.audio_codec)
82
+ self.text_encoder = T5TextEncoder(cfg.text_encoder)
83
+ self.vision_encoder = PerceptionEncoder(cfg.vision_encoder)
84
+ self.transformer = DiT(cfg.transformer)
85
+ self.proj = torch.nn.Linear(cfg.in_channels, cfg.transformer.dim)
86
+ self.align_masked_video = AlignModalities(
87
+ cfg.vision_encoder.dim, cfg.transformer.dim
88
+ )
89
+ self.embed_anchors = EmbedAnchors(
90
+ cfg.num_anchors, cfg.anchor_embedding_dim, cfg.transformer.dim
91
+ )
92
+ self.memory_proj = torch.nn.Linear(cfg.text_encoder.dim, cfg.transformer.dim)
93
+ self.timestep_emb = SinusoidalEmbedding(cfg.transformer.dim)
94
+ self.visual_ranker = create_ranker(cfg.visual_ranker)
95
+ self.text_ranker = create_ranker(cfg.text_ranker)
96
+ if cfg.span_predictor is not None:
97
+ self.span_predictor = PEAudioFrame.from_config(
98
+ cfg.span_predictor, pretrained=True
99
+ )
100
+ self.span_predictor_transform = PEAudioFrameTransform.from_config(
101
+ cfg.span_predictor
102
+ )
103
+
104
+ @property
105
+ def sample_rate(self):
106
+ return self.audio_codec.sample_rate
107
+
108
+ def align_inputs(
109
+ self,
110
+ noisy_audio,
111
+ audio_features: torch.Tensor,
112
+ masked_video_features: Optional[torch.Tensor] = None,
113
+ anchor_ids: Optional[torch.Tensor] = None,
114
+ anchor_alignment: Optional[torch.Tensor] = None,
115
+ ):
116
+ x = torch.cat(
117
+ [
118
+ noisy_audio,
119
+ torch.zeros_like(audio_features),
120
+ audio_features,
121
+ ],
122
+ dim=2,
123
+ )
124
+
125
+ projected = self.proj(x)
126
+ aligned = self.align_masked_video(projected, masked_video_features)
127
+ aligned = self.embed_anchors(aligned, anchor_ids, anchor_alignment)
128
+ return aligned
129
+
130
+ def forward(
131
+ self,
132
+ noisy_audio: torch.Tensor,
133
+ audio_features: torch.Tensor,
134
+ text_features: torch.Tensor,
135
+ time: torch.Tensor,
136
+ masked_video_features: Optional[torch.Tensor] = None,
137
+ text_mask: Optional[torch.Tensor] = None,
138
+ anchor_ids: Optional[torch.Tensor] = None,
139
+ anchor_alignment: Optional[torch.Tensor] = None,
140
+ audio_pad_mask: Optional[torch.Tensor] = None,
141
+ ):
142
+ """
143
+ Forward pass for the model. Represents one function evaluation of the ODE.
144
+ In the below descriptions, B is batch size, T is sequence length, C is channel size.
145
+ Note that the size of C and T may vary across arguments (ex. text_features vs. audio_features),
146
+ it is used only to designate a Channel or time/sequence-length dimension respectively.
147
+
148
+ Args:
149
+ noisy_audio (torch.Tensor): Noisy audio input tensor (being denoised).
150
+ audio_features (torch.Tensor): Clean audio features [B x T x C].
151
+ text_features (torch.Tensor): Encoded text features tensor [B x T x C].
152
+ time (torch.Tensor): Timestep tensor for positional encoding [B].
153
+ masked_video_features (Optional[torch.Tensor], optional): Masked video features tensor. [B x C x T].
154
+ text_mask (Optional[torch.Tensor], optional): Padding mask for text features. [B x T].
155
+ anchor_ids (Optional[torch.Tensor], optional): Anchor IDs tensor. Defaults to None [B x T].
156
+ anchor_alignment (Optional[torch.Tensor], optional): Anchor alignment tensor. B x T.
157
+ audio_pad_mask (Optional[torch.Tensor], optional): Padding mask for audio input. [B x T].
158
+
159
+ Returns:
160
+ torch.Tensor
161
+ """
162
+ aligned_inputs = self.align_inputs(
163
+ noisy_audio,
164
+ audio_features,
165
+ masked_video_features=masked_video_features,
166
+ anchor_ids=anchor_ids,
167
+ anchor_alignment=anchor_alignment,
168
+ )
169
+
170
+ memory = timestep_emb = self.timestep_emb(time, pos=time).unsqueeze(1)
171
+ if text_features is not None:
172
+ memory = self.memory_proj(text_features) + timestep_emb
173
+
174
+ return self.transformer(
175
+ aligned_inputs,
176
+ time,
177
+ padding_mask=audio_pad_mask,
178
+ memory=memory,
179
+ memory_padding_mask=text_mask,
180
+ )
181
+
182
+ def _get_audio_features(self, audios: torch.Tensor):
183
+ audio_features = self.audio_codec(audios).transpose(1, 2)
184
+ return torch.cat([audio_features, audio_features], dim=2)
185
+
186
+ def _get_video_features(self, video, audio_features):
187
+ B, T, _ = audio_features.shape
188
+ if video is None:
189
+ return audio_features.new_zeros(B, self.vision_encoder.dim, T)
190
+ else:
191
+ return self.vision_encoder(video).transpose(1, 2)
192
+
193
+ def _repeat_for_reranking(self, tensor, candidates):
194
+ if candidates > 1:
195
+ B = tensor.size(0)
196
+ rest = tensor.shape[1:]
197
+ return (
198
+ tensor.unsqueeze(1)
199
+ .expand(B, candidates, *rest)
200
+ .reshape(B * candidates, *rest)
201
+ )
202
+ else:
203
+ return tensor
204
+
205
+ def _unrepeat_from_reranking(self, tensor, candidates):
206
+ return tensor[::candidates]
207
+
208
+ def _get_forward_args(self, batch: Batch, candidates: int = 1):
209
+ audio_features = self._get_audio_features(batch.audios)
210
+ text_features, text_mask = self.text_encoder(batch.descriptions)
211
+ masked_video_features = self._get_video_features(
212
+ batch.masked_video, audio_features
213
+ )
214
+
215
+ return {
216
+ "audio_features": self._repeat_for_reranking(audio_features, candidates),
217
+ "text_features": self._repeat_for_reranking(text_features, candidates),
218
+ "text_mask": self._repeat_for_reranking(text_mask, candidates),
219
+ "masked_video_features": self._repeat_for_reranking(
220
+ masked_video_features, candidates
221
+ ),
222
+ "anchor_ids": self._repeat_for_reranking(batch.anchor_ids, candidates),
223
+ "anchor_alignment": self._repeat_for_reranking(
224
+ batch.anchor_alignment, candidates
225
+ ),
226
+ "audio_pad_mask": self._repeat_for_reranking(
227
+ batch.audio_pad_mask, candidates
228
+ ),
229
+ }
230
+
231
+ def predict_spans(
232
+ self, batch: Batch, audio_features: torch.Tensor, audio_pad_mask: torch.Tensor
233
+ ) -> Batch:
234
+ input = self.span_predictor_transform(text=batch.descriptions).to(
235
+ audio_features.device
236
+ )
237
+ output = self.span_predictor(
238
+ input_features=audio_features[:, :, :128],
239
+ padding_mask=audio_pad_mask,
240
+ return_spans=True,
241
+ **input,
242
+ )
243
+ anchors = [[["+"] + anchor for anchor in anchors] for anchors in output.spans]
244
+ batch.process_anchors(anchors)
245
+ return batch
246
+
247
+ @torch.inference_mode()
248
+ def separate(
249
+ self,
250
+ batch: Batch,
251
+ noise: Optional[torch.Tensor] = None,
252
+ ode_opt: Dict[str, Any] = DFLT_ODE_OPT,
253
+ reranking_candidates: int = 1,
254
+ predict_spans: bool = False,
255
+ ) -> SeparationResult:
256
+ # Encode audio
257
+ forward_args = self._get_forward_args(batch, candidates=reranking_candidates)
258
+
259
+ if predict_spans and hasattr(self, "span_predictor") and batch.anchors is None:
260
+ batch = self.predict_spans(
261
+ batch=batch,
262
+ audio_features=self._unrepeat_from_reranking(
263
+ forward_args["audio_features"], reranking_candidates
264
+ ),
265
+ audio_pad_mask=self._unrepeat_from_reranking(
266
+ forward_args["audio_pad_mask"], reranking_candidates
267
+ ),
268
+ )
269
+
270
+ audio_features = forward_args["audio_features"]
271
+ B, T, C = audio_features.shape
272
+ C = C // 2 # we stack audio_features, so the actual channels is half
273
+
274
+ if noise is None:
275
+ noise = torch.randn_like(audio_features)
276
+
277
+ def vector_field(t, noisy_audio):
278
+ res = self.forward(
279
+ noisy_audio=noisy_audio,
280
+ time=t.expand(noisy_audio.size(0)),
281
+ **forward_args,
282
+ )
283
+ return res
284
+
285
+ states = odeint(
286
+ vector_field,
287
+ noise,
288
+ torch.tensor([0.0, 1.0], device=noise.device),
289
+ **ode_opt,
290
+ )
291
+ generated_features = states[-1].transpose(1, 2)
292
+ # generated_features has shape [B, 2C, T]. Reshape to stack along the batch dimension
293
+ wavs = self.audio_codec.decode(generated_features.reshape(2 * B, C, T)).view(
294
+ B, 2, -1
295
+ )
296
+
297
+ bsz = wavs.size(0) // reranking_candidates
298
+ sizes = self.audio_codec.feature_idx_to_wav_idx(batch.sizes)
299
+ target_wavs = self.unbatch(
300
+ wavs[:, 0].view(bsz, reranking_candidates, -1), sizes
301
+ )
302
+ residual_wavs = self.unbatch(
303
+ wavs[:, 1].view(bsz, reranking_candidates, -1), sizes
304
+ )
305
+
306
+ if (
307
+ reranking_candidates > 1
308
+ and batch.masked_video is not None
309
+ and self.visual_ranker is not None
310
+ ):
311
+ scores = self.visual_ranker(
312
+ extracted_audio=target_wavs,
313
+ videos=batch.masked_video,
314
+ sample_rate=self.audio_codec.sample_rate,
315
+ )
316
+ idxs = scores.argmax(dim=1)
317
+ elif reranking_candidates > 1 and self.text_ranker is not None:
318
+ input_audio = [
319
+ audio[:, :size].expand(reranking_candidates, -1)
320
+ for audio, size in zip(batch.audios, sizes, strict=False)
321
+ ]
322
+ scores = self.text_ranker(
323
+ extracted_audio=target_wavs,
324
+ input_audio=input_audio,
325
+ descriptions=batch.descriptions,
326
+ sample_rate=self.audio_codec.sample_rate,
327
+ )
328
+ idxs = scores.argmax(dim=1)
329
+ else:
330
+ idxs = torch.zeros(bsz, dtype=torch.long, device=noise.device)
331
+
332
+ return SeparationResult(
333
+ target=[wav[idx] for wav, idx in zip(target_wavs, idxs, strict=False)],
334
+ residual=[
335
+ wavs[idx] for wavs, idx in zip(residual_wavs, idxs, strict=False)
336
+ ],
337
+ noise=noise,
338
+ )
339
+
340
+ def unbatch(self, wavs: torch.Tensor, sizes: torch.Tensor, time_dim: int = -1):
341
+ result = []
342
+ for row, size in zip(wavs, sizes, strict=False):
343
+ result.append(row.narrow(dim=time_dim, start=0, length=size))
344
+ return result
345
+
346
+ def load_state_dict(self, state_dict, strict=True):
347
+ if strict:
348
+ missing_keys, unexpected_keys = super().load_state_dict(
349
+ state_dict, strict=False
350
+ )
351
+ # We load this directly from HF, not in checkpoint
352
+ skip_regex = re.compile(
353
+ "(^text_encoder|^visual_ranker|^text_ranker|^span_predictor)"
354
+ )
355
+ missing_keys = [x for x in missing_keys if not re.search(skip_regex, x)]
356
+ if len(missing_keys) > 0 or len(unexpected_keys) > 0:
357
+ raise RuntimeError(
358
+ f"Missing keys: {missing_keys}, unexpected_keys: {unexpected_keys}"
359
+ )
360
+
361
+
362
+ __all__ = ["SAMAudio"]
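
`separate` integrates the learned vector field from t = 0 (noise) to t = 1 with a fixed-step solver; DFLT_ODE_OPT uses the midpoint method with step_size 2/32, and a larger step_size trades integration accuracy for fewer function evaluations. A hedged sketch, assuming a SAMAudio instance `model` and a `Batch` prepared by sam_audio.processor:

fast_opt = {"method": "midpoint", "options": {"step_size": 2 / 8}}   # coarser, faster integration
result = model.separate(batch, ode_opt=fast_opt, reranking_candidates=4)
target_wavs = result.target       # list of per-item target waveforms (best candidate per item)
residual_wavs = result.residual   # list of per-item residual waveforms
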
sam_audio/model/patcher.py ADDED
@@ -0,0 +1,164 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ import math
4
+ from typing import Tuple
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from einops import rearrange
9
+
10
+
11
+ def pad1d(
12
+ x: torch.Tensor,
13
+ paddings: Tuple[int, int],
14
+ mode: str = "constant",
15
+ value: float = 0.0,
16
+ ):
17
+ # Copied from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/modules/conv.py
18
+ """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
19
+ If this is the case, we insert extra 0 padding to the right before the reflection happen.
20
+ """
21
+ length = x.shape[-1]
22
+ padding_left, padding_right = paddings
23
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
24
+ if mode == "reflect":
25
+ max_pad = max(padding_left, padding_right)
26
+ extra_pad = 0
27
+ if length <= max_pad:
28
+ extra_pad = max_pad - length + 1
29
+ x = F.pad(x, (0, extra_pad))
30
+ padded = F.pad(x, paddings, mode, value)
31
+ end = padded.shape[-1] - extra_pad
32
+ return padded[..., :end]
33
+ else:
34
+ return F.pad(x, paddings, mode, value)
35
+
36
+
37
+ def get_extra_padding_for_conv1d(
38
+ x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
39
+ ) -> int:
40
+ # Copied from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/modules/conv.py
41
+ """See `pad_for_conv1d`."""
42
+ length = x.shape[-1]
43
+ n_frames = (length - kernel_size + padding_total) / stride + 1
44
+ ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
45
+ return ideal_length - length
46
+
47
+
48
+ class Conv1d(torch.nn.Conv1d):
49
+ def __init__(self, *args, **kwargs):
50
+ super().__init__(*args, **kwargs)
51
+
52
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
53
+ kernel_size = self.kernel_size[0]
54
+ stride = self.stride[0]
55
+ dilation = self.dilation[0]
56
+ kernel_size = (
57
+ kernel_size - 1
58
+ ) * dilation + 1 # effective kernel size with dilations
59
+ padding_total = kernel_size - stride
60
+ extra_padding = get_extra_padding_for_conv1d(
61
+ x, kernel_size, stride, padding_total
62
+ )
63
+ # Asymmetric padding required for odd strides
64
+ padding_right = padding_total // 2
65
+ padding_left = padding_total - padding_right
66
+ x = pad1d(x, (padding_left, padding_right + extra_padding))
67
+ return super().forward(x)
68
+
69
+
70
+ class ConvBlock1d(torch.nn.Module):
71
+ def __init__(
72
+ self,
73
+ in_channels: int,
74
+ out_channels: int,
75
+ *,
76
+ kernel_size: int = 3,
77
+ stride: int = 1,
78
+ dilation: int = 1,
79
+ num_groups: int = 8,
80
+ ) -> None:
81
+ super().__init__()
82
+
83
+ self.groupnorm = torch.nn.GroupNorm(
84
+ num_groups=num_groups, num_channels=in_channels
85
+ )
86
+ self.activation = torch.nn.SiLU()
87
+ self.project = Conv1d(
88
+ in_channels=in_channels,
89
+ out_channels=out_channels,
90
+ kernel_size=kernel_size,
91
+ stride=stride,
92
+ dilation=dilation,
93
+ )
94
+
95
+ def forward(
96
+ self,
97
+ x: torch.Tensor,
98
+ ) -> torch.Tensor:
99
+ x = self.groupnorm(x)
100
+ x = self.activation(x)
101
+ return self.project(x)
102
+
103
+
104
+ class ResnetBlock1d(torch.nn.Module):
105
+ def __init__(
106
+ self,
107
+ in_channels: int,
108
+ out_channels: int,
109
+ *,
110
+ kernel_size: int = 3,
111
+ stride: int = 1,
112
+ dilation: int = 1,
113
+ num_groups: int = 8,
114
+ ) -> None:
115
+ super().__init__()
116
+
117
+ self.block1 = ConvBlock1d(
118
+ in_channels=in_channels,
119
+ out_channels=out_channels,
120
+ kernel_size=kernel_size,
121
+ stride=stride,
122
+ dilation=dilation,
123
+ num_groups=num_groups,
124
+ )
125
+
126
+ self.block2 = ConvBlock1d(
127
+ in_channels=out_channels,
128
+ out_channels=out_channels,
129
+ num_groups=num_groups,
130
+ )
131
+
132
+ self.to_out = (
133
+ Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)
134
+ if in_channels != out_channels
135
+ else torch.nn.Identity()
136
+ )
137
+
138
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
139
+ h = self.block1(x)
140
+ h = self.block2(h)
141
+ return h + self.to_out(x)
142
+
143
+
144
+ class Patcher(torch.nn.Module):
145
+ def __init__(
146
+ self,
147
+ in_channels: int,
148
+ out_channels: int,
149
+ patch_size: int,
150
+ ):
151
+ super().__init__()
152
+ assert_message = f"out_channels must be divisible by patch_size ({patch_size})"
153
+ assert out_channels % patch_size == 0, assert_message
154
+ self.patch_size = patch_size
155
+ self.block = ResnetBlock1d(
156
+ in_channels=in_channels,
157
+ out_channels=out_channels // patch_size,
158
+ num_groups=1,
159
+ )
160
+
161
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
162
+ x = self.block(x)
163
+ x = rearrange(x, "b c (l p) -> b (c p) l", p=self.patch_size)
164
+ return x
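
The Conv1d wrapper pads asymmetrically so that the last, possibly partial, window is still produced (the output length works out to ceil(length / stride)). A worked check of the helpers above:

import torch
from sam_audio.model.patcher import Conv1d, get_extra_padding_for_conv1d

x = torch.randn(1, 8, 100)                  # (batch, channels, time)
extra = get_extra_padding_for_conv1d(x, kernel_size=4, stride=3, padding_total=1)
assert extra == 2                           # 100 -> 102 so the final window is complete

conv = Conv1d(8, 16, kernel_size=4, stride=3)
assert conv(x).shape == (1, 16, 34)         # 34 == ceil(100 / 3) frames
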
sam_audio/model/rope.py ADDED
@@ -0,0 +1,155 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ import math
4
+ from typing import Tuple
5
+
6
+ import torch
7
+
8
+
9
+ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor, seq_dim: int):
10
+ """
11
+ Reshape frequency tensor for broadcasting it with another tensor.
12
+
13
+ This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
14
+ for the purpose of broadcasting the frequency tensor during element-wise operations.
15
+
16
+ Args:
17
+ freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
18
+ x (torch.Tensor): Target tensor for broadcasting compatibility.
19
+ seq_dim (int): Sequence dimension index.
20
+
21
+ Returns:
22
+ torch.Tensor: Reshaped frequency tensor.
23
+ """
24
+ ndim = x.ndim
25
+ assert 0 <= seq_dim < ndim
26
+ assert freqs_cis.shape == (
27
+ x.shape[seq_dim],
28
+ x.shape[-3],
29
+ 2,
30
+ 2,
31
+ ), f"freqs_cis vs x: {(freqs_cis.shape, x.shape)}"
32
+ shape = [
33
+ d if i == seq_dim or i == ndim - 3 else 1 for i, d in enumerate(x.shape[:-2])
34
+ ] + [2, 2]
35
+ return freqs_cis.view(*shape)
36
+
37
+
38
+ def apply_rotary_emb(
39
+ xq: torch.Tensor,
40
+ xk: torch.Tensor,
41
+ seq_dim: int,
42
+ freqs_cis: torch.Tensor,
43
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
44
+ xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2) # B S H D -> B S H D/2 1 2
45
+ xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2) # B S H D -> B S H D/2 1 2
46
+ freqs_cis = reshape_for_broadcast(
47
+ freqs_cis, xq_, seq_dim
48
+ ).float() # S D/2 2 2 -> 1 S 1 D/2 2 2
49
+ xq_out = (xq_ * freqs_cis).sum(5).flatten(3)
50
+ xk_out = (xk_ * freqs_cis).sum(5).flatten(3)
51
+ return xq_out.type_as(xq), xk_out.type_as(xk)
52
+
53
+
54
+ class RotaryEmbedding(torch.nn.Module):
55
+ """
56
+ RotaryEmbedding Module
57
+ """
58
+
59
+ def __init__(
60
+ self,
61
+ theta: float,
62
+ head_dim: int,
63
+ max_seqlen: int = 1024,
64
+ scale_factor: int = 1,
65
+ low_freq_factor: int = 1,
66
+ high_freq_factor: int = 32,
67
+ old_context_len: int = 8192,
68
+ ):
69
+ super().__init__()
70
+
71
+ self.theta = theta
72
+ self.head_dim = head_dim
73
+ self.max_seqlen = max_seqlen
74
+ self.scale_factor = scale_factor
75
+ self.low_freq_factor = low_freq_factor
76
+ self.high_freq_factor = high_freq_factor
77
+ self.old_context_len = old_context_len
78
+ if scale_factor != 1:
79
+ self.low_freq_wavelen = old_context_len / low_freq_factor
80
+ self.high_freq_wavelen = old_context_len / high_freq_factor
81
+ assert self.low_freq_wavelen >= self.high_freq_wavelen
82
+
83
+ def reset_parameters(self):
84
+ freqs_cis = self.precompute_freqs_cis(
85
+ dim=self.head_dim, end=self.max_seqlen, theta=self.theta
86
+ )
87
+ S, D, _, _ = freqs_cis.shape
88
+ # S D 2 2 -> 1 S 1 D 2 2
89
+ freqs_cis = freqs_cis.view(1, S, 1, D, 2, 2)
90
+ self.register_buffer(
91
+ "freqs_cis",
92
+ freqs_cis,
93
+ persistent=False,
94
+ )
95
+
96
+ def apply_scaling(self, freqs):
97
+ if self.scale_factor == 1:
98
+ return freqs
99
+ new_freqs = []
100
+ for freq in freqs:
101
+ wavelen = 2 * math.pi / freq
102
+ if wavelen < self.high_freq_wavelen:
103
+ new_freqs.append(freq)
104
+ elif wavelen > self.low_freq_wavelen:
105
+ new_freqs.append(freq / self.scale_factor)
106
+ else:
107
+ assert self.low_freq_wavelen != self.high_freq_wavelen
108
+ smooth = (self.old_context_len / wavelen - self.low_freq_factor) / (
109
+ self.high_freq_factor - self.low_freq_factor
110
+ )
111
+ new_freqs.append(
112
+ (1 - smooth) * freq / self.scale_factor + smooth * freq
113
+ )
114
+ return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
115
+
116
+ def precompute_freqs_cis(
117
+ self,
118
+ dim: int,
119
+ end: int,
120
+ theta: float = 10000.0,
121
+ ):
122
+ """
123
+ Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
124
+
125
+ This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
126
+ and the end index 'end'. The 'theta' parameter scales the frequencies.
127
+ The returned tensor contains complex values in complex64 data type.
128
+
129
+ Args:
130
+ dim (int): Dimension of the frequency tensor.
131
+ end (int): End index for precomputing frequencies.
132
+ theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
133
+
134
+ Returns:
135
+ torch.Tensor: Precomputed frequency tensor with complex exponentials.
136
+ """
137
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
138
+ freqs = self.apply_scaling(freqs)
139
+
140
+ t = torch.arange(end, device=freqs.device)
141
+ freqs = torch.outer(t, freqs).float()
142
+
143
+ cos, sin = freqs.cos(), freqs.sin()
144
+
145
+ return torch.stack((cos, -sin, sin, cos), dim=-1).view(*freqs.size(), 2, 2)
146
+
147
+ def forward(self, x: torch.Tensor, bhle: bool = False, **kwargs):
148
+ if bhle:
149
+ x = x.transpose(1, 2) # (B H L E) -> (B L H E)
150
+ seqlen = x.size(1)
151
+ x_ = x.reshape(*x.shape[:-1], -1, 1, 2) # B L H E -> B L H E/2 1 2
152
+ x_out = (x_ * self.freqs_cis[:, :seqlen]).sum(5).flatten(3)
153
+ if bhle:
154
+ x_out = x_out.transpose(1, 2)
155
+ return x_out.type_as(x)
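
For reference, a small shape sketch of the rotary embedding above: the cos/sin pairs are stored as 2x2 rotation blocks and applied per attention head, leaving the tensor shape unchanged.

import torch
from sam_audio.model.rope import RotaryEmbedding

rope = RotaryEmbedding(theta=10_000, head_dim=64, max_seqlen=1024)
rope.reset_parameters()                  # builds freqs_cis with shape (1, 1024, 1, 32, 2, 2)

q = torch.randn(2, 16, 100, 64)          # (batch, heads, seq, head_dim)
q_rot = rope(q, bhle=True)               # rotated queries
assert q_rot.shape == q.shape
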
sam_audio/model/text_encoder.py ADDED
@@ -0,0 +1,37 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ from typing import Tuple
4
+
5
+ import torch
6
+ import transformers
7
+
8
+ from sam_audio.model.config import T5EncoderConfig
9
+
10
+
11
+ class T5TextEncoder(torch.nn.Module):
12
+ def __init__(self, cfg: T5EncoderConfig):
13
+ super().__init__()
14
+ self.model = transformers.T5EncoderModel.from_pretrained(cfg.name)
15
+ self.tokenizer = transformers.AutoTokenizer.from_pretrained(cfg.name)
16
+ self.pad_mode = cfg.pad_mode
17
+ self.max_length = cfg.max_length
18
+
19
+ def forward(self, texts: list[str]) -> Tuple[torch.Tensor, torch.Tensor]:
20
+ device = next(self.model.parameters()).device
21
+ encoded = self.tokenizer(
22
+ texts,
23
+ truncation=True,
24
+ max_length=self.max_length,
25
+ padding=self.pad_mode,
26
+ return_tensors="pt",
27
+ )
28
+
29
+ input_ids = encoded["input_ids"].to(device)
30
+ attention_mask = encoded["attention_mask"].to(device)
31
+ res = self.model(
32
+ input_ids=input_ids,
33
+ attention_mask=attention_mask,
34
+ output_hidden_states=True,
35
+ )["last_hidden_state"]
36
+
37
+ return res, attention_mask.bool()
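
A short usage sketch of the text encoder above; it instantiates t5-base from the Hugging Face Hub with the defaults in T5EncoderConfig, so the first call downloads weights:

from sam_audio.model.config import T5EncoderConfig
from sam_audio.model.text_encoder import T5TextEncoder

encoder = T5TextEncoder(T5EncoderConfig())
features, mask = encoder(["a dog barking", "rain on a tin roof"])
# features: (2, T, 768) float tensor, with T set by the longest tokenized description
# mask:     (2, T) boolean attention mask (False on padding)
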
sam_audio/model/transformer.py ADDED
@@ -0,0 +1,524 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ import math
4
+ from functools import partial
5
+ from typing import List, Optional, Union
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from einops import rearrange
11
+
12
+ from .config import TransformerConfig
13
+ from .patcher import Patcher
14
+ from .rope import RotaryEmbedding
15
+
16
+
17
+ def gate(x, gate):
18
+ return x * gate
19
+
20
+
21
+ def modulate(x, shift, scale):
22
+ return x * (1 + scale) + shift
23
+
24
+
25
+ def get_nonlinearity(kind: str):
26
+ return {
27
+ "relu": F.relu,
28
+ "gelu": F.gelu,
29
+ "swiglu": None,
30
+ "approx_gelu": partial(F.gelu, approximate="tanh"),
31
+ "srelu": lambda x: F.relu(x) ** 2,
32
+ "silu": F.silu,
33
+ }[kind]
34
+
35
+
36
+ class RMSNorm(torch.nn.Module):
37
+ def __init__(self, dim: int, eps: float = 1e-5):
38
+ super().__init__()
39
+ self.eps = eps
40
+ self.weight = torch.nn.Parameter(torch.ones(dim))
41
+
42
+ def _norm(self, x):
43
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
44
+
45
+ def forward(self, x):
46
+ output = self._norm(x.float())
47
+ return (output * self.weight).type_as(x)
48
+
49
+
50
+ class ProjectionLayer(torch.nn.Module):
51
+ def __init__(
52
+ self,
53
+ in_dim: int,
54
+ out_dim: int,
55
+ non_linearity: str,
56
+ dropout: float,
57
+ fc_bias: bool = False,
58
+ ):
59
+ super().__init__()
60
+
61
+ self.swiglu = non_linearity == "swiglu"
62
+ self.dropout = dropout
63
+ self.w1 = torch.nn.Linear(in_dim, out_dim, bias=fc_bias)
64
+
65
+ self.w2 = torch.nn.Linear(out_dim, out_dim, bias=fc_bias)
66
+ if self.swiglu:
67
+ self.w3 = torch.nn.Linear(in_dim, out_dim, bias=fc_bias)
68
+
69
+ # non-linearity
70
+ self.non_linearity = get_nonlinearity(non_linearity)
71
+
72
+ def forward(self, x):
73
+ hidden1 = self.w1(x)
74
+ if self.swiglu:
75
+ hidden3 = self.w3(x)
76
+ hidden = F.silu(hidden1) * hidden3
77
+ else:
78
+ hidden = self.non_linearity(hidden1)
79
+ hidden = F.dropout(hidden, p=self.dropout, training=self.training)
80
+ return self.w2(hidden)
81
+
82
+
83
+ class Attention(nn.Module):
84
+ def __init__(
85
+ self,
86
+ dim: int,
87
+ head_dim: int,
88
+ n_heads: int,
89
+ n_kv_heads: int,
90
+ norm_eps: float = 1e-5,
91
+ use_qk_norm: bool = False,
92
+ fc_bias: bool = False,
93
+ ):
94
+ super().__init__()
95
+ assert n_heads % n_kv_heads == 0
96
+
97
+ self.head_dim = head_dim
98
+ self.n_heads = n_heads
99
+ self.n_kv_heads = n_kv_heads
100
+ self.use_qk_norm = use_qk_norm
101
+
102
+ self.wq = torch.nn.Linear(dim, n_heads * head_dim, bias=fc_bias)
103
+ self.wk, self.wv = [
104
+ torch.nn.Linear(
105
+ dim,
106
+ n_kv_heads * head_dim,
107
+ bias=fc_bias,
108
+ )
109
+ for _ in range(2)
110
+ ]
111
+ self.wo = torch.nn.Linear(
112
+ n_heads * head_dim,
113
+ dim,
114
+ bias=fc_bias,
115
+ )
116
+
117
+ if self.use_qk_norm is True:
118
+ self.q_norm = RMSNorm(head_dim, eps=norm_eps)
119
+ self.k_norm = RMSNorm(head_dim, eps=norm_eps)
120
+
121
+ def reshape_heads(self, x: torch.Tensor, heads: int) -> torch.Tensor:
122
+ B, T, C = x.shape
123
+ # B x T x C -> B x T x C/H x H
124
+ x = x.reshape(B, T, C // heads, heads)
125
+ # B x T x C/H x H -> B x H x T x C/H
126
+ return x.permute(0, 3, 1, 2)
127
+
128
+ def forward(
129
+ self,
130
+ x: torch.Tensor,
131
+ cross_x: Optional[torch.Tensor] = None,
132
+ key_padding_mask: Optional[torch.Tensor] = None,
133
+ rope: Optional[RotaryEmbedding] = None,
134
+ ):
135
+ # x: B, T, E
136
+ xq = self.wq(x)
137
+ if cross_x is not None:
138
+ xk, xv = self.wk(cross_x), self.wv(cross_x)
139
+ else:
140
+ xk, xv = self.wk(x), self.wv(x)
141
+
142
+ xk = self.reshape_heads(xk, self.n_kv_heads)
143
+ xv = self.reshape_heads(xv, self.n_kv_heads)
144
+ xq = self.reshape_heads(xq, self.n_heads)
145
+ if self.use_qk_norm:
146
+ xq = self.q_norm(xq)
147
+ xk = self.k_norm(xk)
148
+
149
+ if rope is not None:
150
+ xq = rope(xq, bhle=True)
151
+ xk = rope(xk, bhle=True)
152
+
153
+ attn_mask = None
154
+
155
+ if key_padding_mask is not None:
156
+ attn_mask = key_padding_mask[:, None, None, :]
157
+
158
+ output = F.scaled_dot_product_attention(xq, xk, xv, attn_mask=attn_mask)
159
+
160
+ output = rearrange(output, "b h n d -> b n (h d)")
161
+ return self.wo(output)
162
+
163
+
164
+ class FeedForward(torch.nn.Module):
165
+ def __init__(
166
+ self,
167
+ dim: int,
168
+ hidden_dim: int,
169
+ ffn_dim_multiplier: float,
170
+ multiple_of: int,
171
+ dropout: float,
172
+ non_linearity: str = "swiglu",
173
+ fc_bias: bool = False,
174
+ ):
175
+ super().__init__()
176
+ self.dropout = dropout
177
+ self.swiglu = non_linearity == "swiglu"
178
+ # swiglu hidden dim factor multiplier (same #params as relu / gelu)
179
+ if self.swiglu:
180
+ hidden_dim = int(2 * hidden_dim / 3)
181
+
182
+ # custom dim factor multiplier
183
+ hidden_dim = int(ffn_dim_multiplier * hidden_dim)
184
+ # round hidden dimension to `multiple_of`
185
+ hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
186
+ # layers
187
+ self.w1 = torch.nn.Linear(dim, hidden_dim, bias=fc_bias)
188
+ self.w2 = torch.nn.Linear(hidden_dim, dim, bias=fc_bias)
189
+ if self.swiglu:
190
+ self.w3 = torch.nn.Linear(dim, hidden_dim, bias=fc_bias)
191
+
192
+ # non-linearity
193
+ self.non_linearity = get_nonlinearity(non_linearity)
194
+
195
+ def forward(
196
+ self,
197
+ x,
198
+ ):
199
+ hidden1 = self.w1(x)
200
+ if self.swiglu:
201
+ hidden3 = self.w3(x)
202
+ hidden = F.silu(hidden1) * hidden3
203
+ else:
204
+ hidden = self.non_linearity(hidden1)
205
+ hidden = F.dropout(hidden, p=self.dropout, training=self.training)
206
+ return self.w2(hidden)
207
+
208
+
209
+ class TimestepEmbedder(torch.nn.Module):
210
+ def __init__(
211
+ self,
212
+ dim: int,
213
+ frequency_embedding_dim: int,
214
+ non_linearity: str,
215
+ dropout: float,
216
+ fc_bias: bool,
217
+ max_period: int = 10000,
218
+ ):
219
+ super().__init__()
220
+ self.frequency_embedding_size = frequency_embedding_dim
221
+ self.projection = ProjectionLayer(
222
+ in_dim=frequency_embedding_dim,
223
+ out_dim=dim,
224
+ non_linearity=non_linearity,
225
+ dropout=dropout,
226
+ fc_bias=fc_bias,
227
+ )
228
+ half = frequency_embedding_dim // 2
229
+ freqs = torch.exp(
230
+ -math.log(max_period)
231
+ * torch.arange(start=0, end=half, dtype=torch.float32)
232
+ / half
233
+ )
234
+ self.register_buffer("freqs", freqs, persistent=False)
235
+
236
+ def timestep_embedding(self, t, dim):
237
+ """
238
+ Create sinusoidal timestep embeddings.
239
+ :param t: a 1-D Tensor of N indices, one per batch element.
240
+ These may be fractional.
241
+ :param dim: the dimension of the output.
242
+ :param max_period: controls the minimum frequency of the embeddings.
243
+ :return: an (N, D) Tensor of positional embeddings.
244
+ """
245
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
246
+ self.freqs = self.freqs.to(device=t.device)
247
+ args = t[:, None].float() * self.freqs[None]
248
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
249
+ if dim % 2:
250
+ embedding = torch.cat(
251
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
252
+ )
253
+ return embedding.to(t)
254
+
255
+ def forward(self, t):
256
+ x = self.timestep_embedding(t, self.frequency_embedding_size)
257
+ return self.projection(x)
258
+
259
+
260
+ class ContextEmbedder(torch.nn.Module):
261
+ def __init__(
262
+ self,
263
+ in_dim: int,
264
+ out_dim: int,
265
+ non_linearity: str,
266
+ dropout: float,
267
+ fc_bias: bool,
268
+ norm_eps: float = 1e-5,
269
+ context_norm: bool = False,
270
+ ):
271
+ super().__init__()
272
+ self.context_norm = context_norm
273
+ if context_norm:
274
+ self.norm = RMSNorm(in_dim, norm_eps)
275
+
276
+ self.projection = ProjectionLayer(
277
+ in_dim=in_dim,
278
+ out_dim=out_dim,
279
+ non_linearity=non_linearity,
280
+ dropout=dropout,
281
+ fc_bias=fc_bias,
282
+ )
283
+
284
+ def forward(self, x):
285
+ if self.context_norm:
286
+ x = self.norm(x)
287
+ h = self.projection(x)
288
+ return h
289
+
290
+
291
+ class DiTBlock(torch.nn.Module):
292
+ def __init__(
293
+ self,
294
+ dim: int,
295
+ n_heads: int,
296
+ n_kv_heads: Optional[int] = None,
297
+ dropout: float = 0.0,
298
+ norm_eps: float = 1e-5,
299
+ qk_norm: bool = False,
300
+ fc_bias: bool = False,
301
+ ffn_exp: int = 1,
302
+ ffn_dim_multiplier: int = 4,
303
+ multiple_of: int = 64,
304
+ non_linearity: str = "silu",
305
+ no_cross_attention: bool = False,
306
+ ):
307
+ super().__init__()
308
+ assert dim % n_heads == 0
309
+ self.n_heads = n_heads
310
+ self.n_kv_heads = n_heads if n_kv_heads is None else n_kv_heads
311
+ self.dim = dim
312
+ self.dropout = dropout
313
+ self.head_dim = dim // n_heads
314
+
315
+ assert self.n_heads % self.n_kv_heads == 0
316
+
317
+ self.attention = Attention(
318
+ dim=dim,
319
+ head_dim=self.head_dim,
320
+ n_heads=self.n_heads,
321
+ n_kv_heads=self.n_kv_heads,
322
+ norm_eps=norm_eps,
323
+ use_qk_norm=qk_norm,
324
+ fc_bias=fc_bias,
325
+ )
326
+ self.feed_forward = FeedForward(
327
+ dim=dim,
328
+ hidden_dim=int(ffn_exp * dim),
329
+ ffn_dim_multiplier=ffn_dim_multiplier,
330
+ multiple_of=multiple_of,
331
+ dropout=dropout,
332
+ non_linearity=non_linearity,
333
+ fc_bias=fc_bias,
334
+ )
335
+
336
+ self.attention_norm, self.ffn_norm = [RMSNorm(dim, norm_eps) for _ in range(2)]
337
+
338
+ self.cross_attention = None
339
+ if not no_cross_attention:
340
+ self.cross_attention = Attention(
341
+ dim=dim,
342
+ head_dim=self.head_dim,
343
+ n_heads=self.n_heads,
344
+ n_kv_heads=self.n_heads,
345
+ norm_eps=norm_eps,
346
+ use_qk_norm=qk_norm,
347
+ fc_bias=fc_bias,
348
+ )
349
+
350
+ self.scale_shift_table = nn.Parameter(
351
+ torch.randn(6, self.dim) / self.dim**0.5,
352
+ )
353
+
354
+ def forward(
355
+ self,
356
+ x: torch.Tensor,
357
+ cross_x: Optional[torch.Tensor],
358
+ t: torch.Tensor,
359
+ padding_mask: Optional[torch.Tensor],
360
+ memory_padding_mask: Optional[torch.Tensor],
361
+ rope: Optional[RotaryEmbedding] = None,
362
+ ):
363
+ biases = self.scale_shift_table[None] + t.reshape(x.size(0), 6, -1)
364
+ (
365
+ shift_msa,
366
+ scale_msa,
367
+ gate_msa,
368
+ shift_mlp,
369
+ scale_mlp,
370
+ gate_mlp,
371
+ ) = biases.chunk(6, dim=1)
372
+
373
+ assert self.attention is not None and self.attention_norm is not None
374
+ h_attn = self.attention(
375
+ modulate(self.attention_norm(x), shift_msa, scale_msa),
376
+ key_padding_mask=padding_mask,
377
+ rope=rope,
378
+ )
379
+
380
+ h = x + gate(h_attn, gate_msa)
381
+
382
+ if self.cross_attention is not None:
383
+ h_cross = self.cross_attention(
384
+ x=h,
385
+ cross_x=cross_x,
386
+ key_padding_mask=memory_padding_mask,
387
+ )
388
+ h = h + h_cross # residual
389
+ h_ff = self.feed_forward(modulate(self.ffn_norm(h), shift_mlp, scale_mlp))
390
+ out = h + gate(h_ff, gate_mlp)
391
+ return out
392
+
393
+
394
+ class DiT(torch.nn.Module):
395
+ def __init__(self, config: TransformerConfig):
396
+ super().__init__()
397
+ self.dropout = config.dropout
398
+ if config.in_channels is not None:
399
+ self.data_proj = torch.nn.Linear(config.in_channels, config.dim)
400
+
401
+ # embeddings
402
+ self.rope_embeddings = None
403
+ # rotary embeddings
404
+ if config.use_rope:
405
+ self.rope_embeddings = RotaryEmbedding(
406
+ theta=max(10000, 2 * config.max_positions),
407
+ head_dim=config.dim // config.n_heads,
408
+ max_seqlen=config.max_positions,
409
+ )
410
+ self.rope_embeddings.reset_parameters()
411
+
412
+ # transformer blocks
413
+ self.layers = nn.ModuleList()
414
+ for _ in range(config.n_layers):
415
+ self.layers.append(
416
+ DiTBlock(
417
+ dim=config.dim,
418
+ n_heads=config.n_heads,
419
+ dropout=config.dropout,
420
+ norm_eps=config.norm_eps,
421
+ qk_norm=config.qk_norm,
422
+ fc_bias=config.fc_bias,
423
+ ffn_exp=config.ffn_exp,
424
+ ffn_dim_multiplier=config.ffn_dim_multiplier,
425
+ multiple_of=config.multiple_of,
426
+ non_linearity=config.non_linearity,
427
+ )
428
+ )
429
+
430
+ self.norm = RMSNorm(config.dim, config.norm_eps)
431
+
432
+ # output layer
433
+ self.output = torch.nn.Linear(
434
+ config.dim, config.out_channels, bias=config.fc_bias
435
+ )
436
+
437
+ self.x_embedder = Patcher(
438
+ in_channels=config.dim,
439
+ out_channels=config.dim,
440
+ patch_size=1,
441
+ )
442
+
443
+ self.y_embedder = ContextEmbedder(
444
+ in_dim=config.context_dim,
445
+ out_dim=config.dim,
446
+ non_linearity=config.context_non_linearity,
447
+ dropout=config.context_embedder_dropout,
448
+ fc_bias=config.fc_bias,
449
+ norm_eps=config.norm_eps,
450
+ context_norm=config.context_norm,
451
+ )
452
+
453
+ self.t_embedder = TimestepEmbedder(
454
+ config.dim,
455
+ config.frequency_embedding_dim,
456
+ non_linearity=config.timestep_non_linearity,
457
+ dropout=config.dropout,
458
+ fc_bias=config.fc_bias,
459
+ max_period=10000,
460
+ )
461
+
462
+ self.t_block_non_linearity = get_nonlinearity(config.t_block_non_linearity)
463
+ self.t_block = torch.nn.Linear(
464
+ config.dim,
465
+ config.dim * 6,
466
+ bias=config.t_block_bias,
467
+ )
468
+
469
+ self.final_layer_scale_shift_table = nn.Parameter(
470
+ torch.randn(2, config.dim) / config.dim**0.5,
471
+ )
472
+
473
+ def forward(
474
+ self,
475
+ x: torch.Tensor,
476
+ time: torch.Tensor,
477
+ *,
478
+ padding_mask: Optional[torch.Tensor] = None,
479
+ memory: Optional[torch.Tensor] = None,
480
+ memory_padding_mask: Optional[torch.Tensor] = None,
481
+ ) -> Union[torch.Tensor, List[torch.Tensor]]:
482
+ x = rearrange(x, "b l c -> b c l")
483
+ h = self.x_embedder(x)
484
+ h = rearrange(h, "b c l -> b l c")
485
+ original_N = h.shape[1]
486
+ N = h.shape[1]
487
+
488
+ h = F.dropout(h, p=self.dropout, training=self.training)
489
+
490
+ t = self.t_embedder(time) # B -> B D
491
+
492
+ t0 = self.t_block_non_linearity(t)
493
+ t0 = self.t_block(t0) # B D -> B 6D
494
+
495
+ y = self.y_embedder(memory)
496
+
497
+ for layer in self.layers:
498
+ h = layer(
499
+ x=h,
500
+ cross_x=y,
501
+ t=t0,
502
+ padding_mask=padding_mask,
503
+ memory_padding_mask=memory_padding_mask,
504
+ rope=self.rope_embeddings,
505
+ )
506
+
507
+ shift, scale = (self.final_layer_scale_shift_table[None] + t[:, None]).chunk(
508
+ 2, dim=1
509
+ )
510
+
511
+ # output layer
512
+ if self.norm is not None:
513
+ h = self.norm(h)
514
+
515
+ h = modulate(h, shift, scale)
516
+
517
+ h = F.dropout(h, p=self.dropout, training=self.training)
518
+
519
+ output = self.output(h)
520
+
521
+ N = output.shape[1]
522
+ if original_N != N:
523
+ output = output[:, -original_N:]
524
+ return output
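For context, the conditioning above follows the standard DiT-style adaLN recipe: t_block projects the timestep embedding into six per-channel vectors (shift/scale/gate for the attention branch and for the MLP branch), and a learned scale/shift table modulates the final layer. Below is a minimal standalone sketch of that modulation, assuming modulate(x, shift, scale) = x * (1 + scale) + shift and gate(h, g) = g * h with broadcasting over the sequence dimension; the actual helpers are defined elsewhere in this upload and may differ in detail.

import torch

def modulate(x, shift, scale):
    # x: (B, L, D); shift/scale: (B, 1, D), broadcast over the sequence length L
    return x * (1 + scale) + shift

def gate(h, g):
    # g: (B, 1, D) gates the residual branch h: (B, L, D)
    return g * h

B, L, D = 2, 16, 32
t0 = torch.randn(B, 6 * D)                      # stand-in for the t_block output (B -> B, 6D)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [
    c.unsqueeze(1) for c in t0.chunk(6, dim=1)  # six (B, 1, D) conditioning vectors
]
x = torch.randn(B, L, D)
attn_out = torch.randn(B, L, D)                 # stand-in for the self-attention output
h = x + gate(attn_out, gate_msa)                # gated attention residual
# In the real block the modulated h goes through ffn_norm + feed_forward before gating.
out = h + gate(modulate(h, shift_mlp, scale_mlp), gate_mlp)
print(out.shape)                                # torch.Size([2, 16, 32])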
sam_audio/model/vision_encoder.py ADDED
@@ -0,0 +1,113 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ from abc import ABCMeta, abstractmethod
4
+
5
+ import torch
6
+ import torchvision
7
+ from core.vision_encoder import pe
8
+ from torch.nn.utils.rnn import pad_sequence
9
+
10
+ from sam_audio.model.config import (
11
+ PerceptionEncoderConfig,
12
+ VisionEncoderConfig,
13
+ )
14
+
15
+
16
+ class RescaleTransform(object):
17
+ """Rescale the image in a sample to a given size.
18
+
19
+ Args:
20
+ output_size (tuple or int): Desired output size. If tuple, output is
21
+ matched to output_size. If int, the output is resized to
22
+ (output_size, output_size); the aspect ratio is not preserved.
23
+ """
24
+
25
+ def __init__(self, output_size, interpolation):
26
+ assert isinstance(output_size, (int, tuple))
27
+ self.output_size = output_size
28
+ if isinstance(output_size, int):
29
+ self.output_size = (output_size, output_size)
30
+ self.interpolation = interpolation
31
+
32
+ def __call__(self, sample):
33
+ # sample: [T, C, H, W]
34
+ sample = torch.nn.functional.interpolate(
35
+ sample.float(), size=self.output_size, mode=self.interpolation.value
36
+ )
37
+ return sample
38
+
39
+
40
+ class VisionEncoder(torch.nn.Module, metaclass=ABCMeta):
41
+ def __init__(self, cfg: VisionEncoderConfig):
42
+ super().__init__()
43
+ self.batch_size = cfg.batch_size
44
+ self.dim = cfg.dim
45
+ self.transform = self.get_transform()
46
+
47
+ @torch.no_grad()
48
+ def forward(self, videos: list[torch.Tensor]) -> torch.Tensor:
49
+ """
50
+ Encodes a list of input videos. Each element of the list is a video represented
51
+ as a tensor [T, C, H, W]
52
+ Args:
53
+ videos (list[torch.Tensor]): List of input video tensors to be processed.
54
+
55
+ Returns:
56
+ torch.Tensor: Encoded feature representations of the input tensors.
57
+ The output is padded along the time dimension for variable length videos
58
+ """
59
+ result = []
60
+ for video in videos:
61
+ video = self.transform(video)
62
+ if self.batch_size > 0 and video.size(0) > self.batch_size:
63
+ res = []
64
+ for i in range(0, video.size(0), self.batch_size):
65
+ res.append(self.encode(video[i : i + self.batch_size]))
66
+ result.append(torch.cat(res, dim=0))
67
+ else:
68
+ result.append(self.encode(video))
69
+ return pad_sequence(result, batch_first=True, padding_value=0.0)
70
+
71
+ @abstractmethod
72
+ def encode(self, x: torch.Tensor) -> torch.Tensor:
73
+ pass
74
+
75
+ @abstractmethod
76
+ def get_transform(self):
77
+ pass
78
+
79
+
80
+ class PerceptionEncoder(VisionEncoder):
81
+ def __init__(self, cfg: PerceptionEncoderConfig):
82
+ self.normalize_feature = cfg.normalize_feature
83
+ self.interpolation_mode = cfg.interpolation_mode
84
+ self.image_size = cfg.image_size
85
+ super().__init__(cfg)
86
+ self.model = pe.CLIP.from_config(cfg.name)
87
+
88
+ def encode(self, x):
89
+ image_features = self.model.encode_image(x, normalize=self.normalize_feature)
90
+ return image_features
91
+
92
+ def get_transform(self):
93
+ T = torchvision.transforms
94
+ try:
95
+ interp = getattr(T.InterpolationMode, self.interpolation_mode.upper())
96
+ except AttributeError as err:
97
+ raise ValueError(
98
+ f"Unsupported interpolation_mode: {self.interpolation_mode}"
99
+ ) from err
100
+ crop = [
101
+ T.Resize(
102
+ (self.image_size, self.image_size),
103
+ interpolation=interp,
104
+ )
105
+ ]
106
+
107
+ return T.Compose(
108
+ crop
109
+ + [
110
+ T.Lambda(lambda x: x.float() / 255.0),
111
+ T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5], inplace=True),
112
+ ]
113
+ )
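A quick, self-contained check of the batching contract in VisionEncoder.forward above: variable-length videos are encoded chunk-by-chunk (honoring batch_size) and padded along the time dimension. The SimpleNamespace below stands in for VisionEncoderConfig (defined in sam_audio/model/config.py, not part of this hunk); only batch_size and dim are read by the base class, and the toy encode/transform are placeholders.

from types import SimpleNamespace

import torch

from sam_audio.model.vision_encoder import VisionEncoder

class ToyEncoder(VisionEncoder):
    def encode(self, x):
        # [T, C, H, W] -> [T, dim]: a trivial per-frame feature, for illustration only
        return x.flatten(1).mean(dim=1, keepdim=True).repeat(1, self.dim)

    def get_transform(self):
        return lambda video: video  # identity transform for the toy case

enc = ToyEncoder(SimpleNamespace(batch_size=2, dim=4))
videos = [torch.rand(5, 3, 8, 8), torch.rand(3, 3, 8, 8)]  # two videos, 5 and 3 frames
features = enc(videos)
print(features.shape)  # torch.Size([2, 5, 4]): padded to the longest video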
sam_audio/processor.py ADDED
@@ -0,0 +1,382 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ import json
4
+ import logging
5
+ import math
6
+ import os
7
+ from typing import Callable, List, Optional, Tuple
8
+
9
+ import torch
10
+ import torchaudio
11
+ from huggingface_hub import hf_hub_download
12
+ from torch.nn.utils.rnn import pad_sequence
13
+ from torchcodec.decoders import AudioDecoder, VideoDecoder
14
+ from transformers import AutoTokenizer, BatchFeature
15
+
16
+ from sam_audio.model.config import SAMAudioConfig, SAMAudioJudgeConfig
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+ Anchor = Tuple[str, float, float]
21
+
22
+
23
+ def batch_audio(
24
+ audios: list[str | torch.Tensor], audio_sampling_rate: int = 48_000
25
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
26
+ wavs = []
27
+ for audio in audios:
28
+ if isinstance(audio, str):
29
+ wav, sr = torchaudio.load(audio)
30
+ if sr != audio_sampling_rate:
31
+ wav = torchaudio.functional.resample(wav, sr, audio_sampling_rate)
32
+ else:
33
+ wav = audio
34
+ wavs.append(wav.mean(0))
35
+ sizes = torch.tensor([wav.size(-1) for wav in wavs])
36
+ return pad_sequence(wavs, batch_first=True).unsqueeze(1), sizes
37
+
38
+
39
+ class Batch:
40
+ def __init__(
41
+ self,
42
+ audios: torch.Tensor,
43
+ sizes: torch.Tensor,
44
+ wav_sizes: torch.Tensor,
45
+ descriptions: list[str],
46
+ hop_length: int,
47
+ audio_sampling_rate: int,
48
+ anchors: Optional[list[list[Anchor]]] = None,
49
+ audio_pad_mask: Optional[torch.Tensor] = None,
50
+ masked_video: Optional[torch.Tensor] = None,
51
+ ):
52
+ self.audios = audios
53
+ self.sizes = sizes
54
+ self.wav_sizes = wav_sizes
55
+ self.descriptions = descriptions
56
+ self.audio_pad_mask = audio_pad_mask
57
+ self.masked_video = masked_video
58
+ self.hop_length = hop_length
59
+ self.audio_sampling_rate = audio_sampling_rate
60
+ self.process_anchors(anchors)
61
+ assert self.audios.size(0) == len(self.descriptions)
62
+
63
+ def _wav_to_feature_idx(self, wav_idx: int):
64
+ return math.ceil(wav_idx / self.hop_length)
65
+
66
+ def to(self, device: torch.device):
67
+ self.audios = self.audios.to(device)
68
+ self.anchor_ids = self.anchor_ids.to(device)
69
+ self.anchor_alignment = self.anchor_alignment.to(device)
70
+ self.sizes = self.sizes.to(device)
71
+ self.wav_sizes = self.wav_sizes.to(device)
72
+ if self.audio_pad_mask is not None:
73
+ self.audio_pad_mask = self.audio_pad_mask.to(device)
74
+ if self.masked_video is not None:
75
+ self.masked_video = [v.to(device) for v in self.masked_video]
76
+ return self
77
+
78
+ def process_anchors(self, anchors: Optional[list[list[Anchor]]]):
79
+ batch_size = len(self.audios)
80
+ anchor_dict = {"<null>": 0, "+": 1, "-": 2, "<pad>": 3}
81
+ if anchors is None:
82
+ anchor_ids = torch.full(
83
+ (batch_size, 2), anchor_dict["<null>"], dtype=torch.long
84
+ )
85
+ anchor_ids[:, 1] = anchor_dict["<pad>"]
86
+ anchor_alignment = torch.full(
87
+ (
88
+ batch_size,
89
+ self.audio_pad_mask.size(-1),
90
+ ),
91
+ 0,
92
+ dtype=torch.long,
93
+ )
94
+ anchor_alignment[~self.audio_pad_mask] = 1 # point to pad token
95
+ else:
96
+ anchor_alignment = torch.full(
97
+ (
98
+ batch_size,
99
+ self.audio_pad_mask.size(-1),
100
+ ),
101
+ 0,
102
+ dtype=torch.long,
103
+ )
104
+ anchor_alignment[~self.audio_pad_mask] = 1 # point to pad token
105
+ ids = []
106
+
107
+ for i, anchor_list in enumerate(anchors):
108
+ current = [anchor_dict["<null>"], anchor_dict["<pad>"]]
109
+ for token, start_time, end_time in anchor_list:
110
+ start_idx = self._wav_to_feature_idx(
111
+ start_time * self.audio_sampling_rate
112
+ )
113
+ end_idx = self._wav_to_feature_idx(
114
+ end_time * self.audio_sampling_rate
115
+ )
116
+ anchor_alignment[i, start_idx:end_idx] = len(current)
117
+ current.append(anchor_dict[token])
118
+ ids.append(torch.tensor(current))
119
+ anchor_ids = pad_sequence(
120
+ ids, batch_first=True, padding_value=anchor_dict["<pad>"]
121
+ )
122
+ self.anchor_ids = anchor_ids
123
+ self.anchor_alignment = anchor_alignment
124
+ self.anchors = anchors
125
+
126
+
127
+ def mask_from_sizes(sizes: torch.Tensor) -> torch.Tensor:
128
+ return torch.arange(sizes.max()).expand(len(sizes), -1) < sizes.unsqueeze(1)
129
+
130
+
131
+ def load_video(
132
+ sizes: torch.Tensor,
133
+ videos: List[str],
134
+ feature_idx_to_wav_idx: Callable[[torch.Tensor], torch.Tensor],
135
+ audio_sampling_rate: int,
136
+ ) -> list[torch.Tensor]:
137
+ all_frames = []
138
+ for size, video in zip(sizes, videos, strict=False):
139
+ audio_timestamps = (
140
+ feature_idx_to_wav_idx(torch.arange(size)) / audio_sampling_rate
141
+ )
142
+ if isinstance(video, str):
143
+ decoder = VideoDecoder(video, dimension_order="NCHW")
144
+ data = decoder.get_frames_in_range(0, len(decoder))
145
+ diffs = (audio_timestamps[None] - data.pts_seconds[:, None]).abs()
146
+ frame_idxs = diffs.argmin(dim=0)
147
+ frames = data.data[frame_idxs]
148
+ else:
149
+ assert video.size(1) == 3, (
150
+ f"Expected video tensor to be in NCHW format, but found {video.size(1)} channels"
151
+ )
152
+ idx = torch.linspace(0, video.size(0) - 1, int(size)).round().long()
153
+ frames = video[idx]
154
+ all_frames.append(frames)
155
+ return all_frames
156
+
157
+
158
+ class Processor:
159
+ config_cls: Callable
160
+
161
+ def __init__(self, audio_hop_length: int, audio_sampling_rate: int):
162
+ self.audio_hop_length = audio_hop_length
163
+ self.audio_sampling_rate = audio_sampling_rate
164
+
165
+ @classmethod
166
+ def _get_config(cls, model_name_or_path: str):
167
+ if os.path.exists(model_name_or_path):
168
+ config_path = os.path.join(model_name_or_path, "config.json")
169
+ else:
170
+ config_path = hf_hub_download(
171
+ repo_id=model_name_or_path,
172
+ filename="config.json",
173
+ revision=cls.revision,
174
+ )
175
+ with open(config_path) as fin:
176
+ config = cls.config_cls(**json.load(fin))
177
+ return config
178
+
179
+ @classmethod
180
+ def from_pretrained(cls, model_name_or_path: str) -> "Processor":
181
+ config = cls._get_config(model_name_or_path)
182
+ return cls(
183
+ audio_hop_length=config.audio_codec.hop_length,
184
+ audio_sampling_rate=config.audio_codec.sample_rate,
185
+ )
186
+
187
+ def feature_to_wav_idx(self, feature_idx):
188
+ return feature_idx * self.audio_hop_length
189
+
190
+ def wav_to_feature_idx(self, wav_idx):
191
+ if torch.is_tensor(wav_idx):
192
+ ceil = torch.ceil
193
+ else:
194
+ ceil = math.ceil
195
+ return ceil(wav_idx / self.audio_hop_length)
196
+
197
+ def mask_videos(
198
+ self,
199
+ videos: List[str | torch.Tensor],
200
+ masks: List[str | torch.Tensor],
201
+ ) -> list[torch.Tensor]:
202
+ video = [VideoDecoder(v)[:] if isinstance(v, str) else v for v in videos]
203
+ video_mask = [VideoDecoder(v)[:] if isinstance(v, str) else v for v in masks]
204
+ return [v * m.eq(0) for v, m in zip(video, video_mask, strict=False)]
205
+
206
+
207
+ class SAMAudioProcessor(Processor):
208
+ config_cls = SAMAudioConfig
209
+ revision = None
210
+
211
+ def __call__(
212
+ self,
213
+ descriptions: list[str],
214
+ audios: list[str | torch.Tensor],
215
+ anchors: Optional[list[list[Anchor]]] = None,
216
+ masked_videos: Optional[list[str | torch.Tensor]] = None,
217
+ ):
218
+ """
219
+ Processes input data for the model.
220
+
221
+ Args:
222
+ descriptions (list[str]): List of text descriptions corresponding to each audio sample.
223
+ audios (list[str | torch.Tensor]): List of audio file paths or tensors.
224
+ If a tensor:
225
+ - should have shape (channels, time) where channels=1 for mono and 2 for stereo.
226
+ - should be resampled to 48_000 Hz
227
+ anchors (Optional[list[list[Anchor]]], optional): List of anchors for each sample,
228
+ where each anchor is a tuple (token, start_time, end_time).
229
+ masked_videos (Optional[list[str | torch.Tensor]], optional): List of masked video file paths or tensors.
230
+ If a tensor, should have shape (N, C, H, W)
231
+
232
+ Returns:
233
+ Batch: A Batch object containing processed audio, sizes, descriptions, anchor ids, anchor alignment, audio pad mask, and optionally masked video.
234
+ """
235
+
236
+ assert len(descriptions) == len(audios)
237
+ assert anchors is None or len(descriptions) == len(anchors)
238
+ assert masked_videos is None or len(descriptions) == len(masked_videos)
239
+
240
+ audios, wav_sizes = batch_audio(audios, self.audio_sampling_rate)
241
+
242
+ sizes = self.wav_to_feature_idx(wav_sizes)
243
+ audio_pad_mask = mask_from_sizes(sizes)
244
+ masked_video = None
245
+ if masked_videos is not None:
246
+ masked_video = load_video(
247
+ sizes, masked_videos, self.feature_to_wav_idx, self.audio_sampling_rate
248
+ )
249
+
250
+ return Batch(
251
+ audios=audios,
252
+ sizes=sizes,
253
+ descriptions=descriptions,
254
+ audio_pad_mask=audio_pad_mask,
255
+ anchors=anchors,
256
+ masked_video=masked_video,
257
+ hop_length=self.audio_hop_length,
258
+ audio_sampling_rate=self.audio_sampling_rate,
259
+ wav_sizes=wav_sizes,
260
+ )
261
+
262
+
263
+ class SAMAudioJudgeProcessor(Processor):
264
+ config_cls = SAMAudioJudgeConfig
265
+ revision = "sam_audio"
266
+
267
+ def __init__(
268
+ self,
269
+ audio_hop_length: int,
270
+ audio_sampling_rate: int,
271
+ tokenizer: AutoTokenizer,
272
+ ):
273
+ super().__init__(audio_hop_length, audio_sampling_rate)
274
+ self.tokenizer = tokenizer
275
+
276
+ @classmethod
277
+ def from_pretrained(cls, model_name_or_path: str) -> "SAMAudioJudgeProcessor":
278
+ config = cls._get_config(model_name_or_path)
279
+ tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
280
+ return cls(
281
+ audio_hop_length=config.audio_codec.hop_length,
282
+ audio_sampling_rate=config.audio_codec.sample_rate,
283
+ tokenizer=tokenizer,
284
+ )
285
+
286
+ def _reflect_pad(self, wav):
287
+ if wav.ndim == 1:
288
+ wav = wav.unsqueeze(0)
289
+ if wav.size(-1) % self.audio_hop_length == 0:
290
+ return wav
291
+ p1d = (0, self.audio_hop_length - (wav.size(-1) % self.audio_hop_length))
292
+ return torch.nn.functional.pad(wav, p1d, mode="reflect")
293
+
294
+ def _load_audio(self, path: str):
295
+ ad = AudioDecoder(path, sample_rate=self.audio_sampling_rate, num_channels=1)
296
+ return ad.get_all_samples().data
297
+
298
+ def _process_audio(
299
+ self,
300
+ raw_audio,
301
+ sampling_rate: Optional[int] = None,
302
+ ):
303
+ from_file = False
304
+ if isinstance(raw_audio, str):
305
+ raw_audio = [raw_audio]
306
+
307
+ if isinstance(raw_audio, (list, tuple)) and isinstance(raw_audio[0], str):
308
+ loaded = []
309
+ for audio_file in raw_audio:
310
+ loaded.append(self._load_audio(audio_file))
311
+ raw_audio = loaded
312
+ from_file = True
313
+
314
+ if sampling_rate is not None:
315
+ if sampling_rate != self.audio_sampling_rate:
316
+ raise ValueError(
317
+ f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
318
+ f" {self.audio_sampling_rate}. Please make sure that the provided audio input was sampled with"
319
+ f" {self.audio_sampling_rate} and not {sampling_rate}."
320
+ )
321
+ elif not from_file:
322
+ logger.warning(
323
+ f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. "
324
+ "Failing to do so can result in silent errors that might be hard to debug."
325
+ )
326
+
327
+ if isinstance(raw_audio, list):
328
+ raw_audio = [self._reflect_pad(x).T for x in raw_audio]
329
+ else:
330
+ raw_audio = self._reflect_pad(raw_audio).T
331
+
332
+ # verify inputs are valid
333
+ for example in raw_audio:
334
+ if example.ndim > 2:
335
+ raise ValueError(
336
+ f"Expected input shape (channels, num_samples), but got shape ({example.shape})"
337
+ )
338
+
339
+ lengths = torch.tensor([x.size(0) for x in raw_audio])
340
+ input_values = pad_sequence(raw_audio, batch_first=True).transpose(1, 2)
341
+ padding_mask = torch.arange(lengths.max())[None] < lengths[:, None]
342
+
343
+ return BatchFeature(
344
+ {"input_values": input_values, "padding_mask": padding_mask}
345
+ )
346
+
347
+ def __call__(
348
+ self,
349
+ text: Optional[str] = None,
350
+ input_audio: Optional[
351
+ str | list[str] | torch.Tensor | list[torch.Tensor]
352
+ ] = None,
353
+ separated_audio: Optional[
354
+ str | list[str] | torch.Tensor | list[torch.Tensor]
355
+ ] = None,
356
+ sampling_rate: Optional[int] = None,
357
+ **kwargs,
358
+ ):
359
+ batch = BatchFeature()
360
+ if text is not None:
361
+ batch.update(
362
+ self.tokenizer(
363
+ text,
364
+ return_tensors="pt",
365
+ padding="longest",
366
+ max_length=512,
367
+ truncation=True,
368
+ )
369
+ )
370
+
371
+ if input_audio is not None:
372
+ batch.update(self._process_audio(input_audio, sampling_rate))
373
+
374
+ if separated_audio is not None:
375
+ batch["separated_values"] = self._process_audio(
376
+ separated_audio, sampling_rate
377
+ )["input_values"]
378
+
379
+ return batch
380
+
381
+
382
+ __all__ = ["SAMAudioProcessor", "SAMAudioJudgeProcessor", "Batch"]
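A minimal usage sketch for SAMAudioProcessor above. The hop length and sample rate normally come from the model's config.json via from_pretrained; the values used here (hop 960 at 48 kHz) and the anchor timings are assumptions for illustration only.

import torch

from sam_audio.processor import SAMAudioProcessor

# Assumed values; real ones are read from config.json by SAMAudioProcessor.from_pretrained(...).
proc = SAMAudioProcessor(audio_hop_length=960, audio_sampling_rate=48_000)

audios = [torch.randn(1, 48_000 * 2), torch.randn(2, 48_000 * 3)]  # mono 2 s, stereo 3 s, 48 kHz
batch = proc(
    descriptions=["a dog barking", "rain on a window"],
    audios=audios,
    anchors=[[("+", 0.0, 1.0)], [("-", 1.0, 2.0)]],  # (token, start_seconds, end_seconds)
)
print(batch.audios.shape)            # (2, 1, padded_num_samples)
print(batch.anchor_ids)              # per-sample anchor token ids, padded
print(batch.anchor_alignment.shape)  # (2, num_audio_features)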
sam_audio/ranking/__init__.py ADDED
@@ -0,0 +1,30 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ from sam_audio.model.config import (
4
+ ClapRankerConfig,
5
+ EnsembleRankerConfig,
6
+ ImageBindRankerConfig,
7
+ JudgeRankerConfig,
8
+ )
9
+ from sam_audio.ranking.clap import ClapRanker
10
+ from sam_audio.ranking.imagebind import ImageBindRanker
11
+ from sam_audio.ranking.judge import JudgeRanker
12
+ from sam_audio.ranking.ranker import EnsembleRanker
13
+
14
+
15
+ def create_ranker(config):
16
+ if isinstance(config, ImageBindRankerConfig):
17
+ return ImageBindRanker(config)
18
+ elif isinstance(config, ClapRankerConfig):
19
+ return ClapRanker(config)
20
+ elif isinstance(config, JudgeRankerConfig):
21
+ return JudgeRanker(config)
22
+ elif isinstance(config, EnsembleRankerConfig):
23
+ ranker_cfgs, weights = zip(*config.rankers.values(), strict=False)
24
+ return EnsembleRanker(
25
+ rankers=[create_ranker(cfg) for cfg in ranker_cfgs],
26
+ weights=weights,
27
+ )
28
+ else:
29
+ assert config is None
30
+ return None
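A note on the EnsembleRankerConfig branch above: config.rankers is assumed to map ranker names to (sub_config, weight) pairs, which the zip(*...) call unpacks into parallel tuples. In isolation, with placeholder strings instead of real config objects:

rankers = {"clap": ("<ClapRankerConfig>", 0.6), "judge": ("<JudgeRankerConfig>", 0.4)}
ranker_cfgs, weights = zip(*rankers.values(), strict=False)
print(ranker_cfgs)  # ('<ClapRankerConfig>', '<JudgeRankerConfig>')
print(weights)      # (0.6, 0.4)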
sam_audio/ranking/clap.py ADDED
@@ -0,0 +1,86 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ import torch
4
+ import torchaudio
5
+ from huggingface_hub import hf_hub_download
6
+
7
+ from sam_audio.model.config import ClapRankerConfig
8
+ from sam_audio.ranking.ranker import Ranker
9
+
10
+
11
+ def get_model(checkpoint_file=None, device="cpu"):
12
+ import laion_clap
13
+
14
+ model = laion_clap.CLAP_Module(enable_fusion=False, amodel="HTSAT-tiny").to(device)
15
+
16
+ if checkpoint_file is None:
17
+ checkpoint_file = hf_hub_download(
18
+ repo_id="lukewys/laion_clap", filename="630k-best.pt"
19
+ )
20
+ state_dict = torch.load(checkpoint_file, map_location=device, weights_only=False)[
21
+ "state_dict"
22
+ ]
23
+ if next(iter(state_dict.items()))[0].startswith("module"):
24
+ state_dict = {k[7:]: v for k, v in state_dict.items()}
25
+
26
+ if "text_branch.embeddings.position_ids" in state_dict:
27
+ del state_dict["text_branch.embeddings.position_ids"]
28
+
29
+ model.model.load_state_dict(state_dict)
30
+ return model.eval()
31
+
32
+
33
+ class ClapRanker(Ranker):
34
+ def __init__(self, config: ClapRankerConfig):
35
+ from laion_clap.training import data
36
+
37
+ self.laion_data_module = data
38
+ super().__init__()
39
+ self.config = config
40
+ self.model = get_model(checkpoint_file=config.checkpoint)
41
+
42
+ def _prepare_audio(self, audio, sample_rate):
43
+ audio_features = []
44
+ for candidates in audio:
45
+ if sample_rate != 48_000:
46
+ candidates = torchaudio.functional.resample(
47
+ candidates, sample_rate, 48000
48
+ )
49
+
50
+ quantized = self.laion_data_module.int16_to_float32_torch(
51
+ self.laion_data_module.float32_to_int16_torch(candidates)
52
+ ).float()
53
+ for sample in quantized:
54
+ temp_dict = {}
55
+ temp_dict = self.laion_data_module.get_audio_features(
56
+ temp_dict,
57
+ sample,
58
+ 480000,
59
+ data_truncating=(
60
+ "fusion" if self.model.enable_fusion else "rand_trunc"
61
+ ),
62
+ data_filling="repeatpad",
63
+ audio_cfg=self.model.model_cfg["audio_cfg"],
64
+ require_grad=False,
65
+ )
66
+ audio_features.append(temp_dict)
67
+ return audio_features
68
+
69
+ @torch.inference_mode()
70
+ def forward(
71
+ self,
72
+ extracted_audio: list[torch.Tensor],
73
+ descriptions: list[str],
74
+ sample_rate: int = 48_000,
75
+ **kwargs,
76
+ ):
77
+ audio_embed = self.model.model.get_audio_embedding(
78
+ self._prepare_audio(extracted_audio, sample_rate)
79
+ )
80
+ text_embed = self.model.get_text_embedding(descriptions, use_tensor=True)
81
+ bsz = len(extracted_audio)
82
+ candidates = len(audio_embed) // bsz
83
+ audio_embed = audio_embed.reshape(bsz, candidates, -1)
84
+ text_embed = text_embed.reshape(bsz, -1, 1)
85
+ scores = audio_embed @ text_embed
86
+ return scores.squeeze(-1)
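A shape walk-through of the scoring at the end of ClapRanker.forward: CLAP produces one embedding per candidate clip (flattened across the batch) and one per description, and the reshape plus batched matmul yields one score per (sample, candidate). Toy tensors below, assuming a 512-dimensional CLAP embedding:

import torch

bsz, candidates, dim = 2, 4, 512
audio_embed = torch.randn(bsz * candidates, dim)  # all candidates, flattened
text_embed = torch.randn(bsz, dim)                # one embedding per description

scores = audio_embed.reshape(bsz, candidates, -1) @ text_embed.reshape(bsz, -1, 1)
print(scores.squeeze(-1).shape)  # torch.Size([2, 4]): one score per (sample, candidate)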
sam_audio/ranking/imagebind.py ADDED
@@ -0,0 +1,197 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ import math
4
+ from typing import List, Union
5
+
6
+ import torch
7
+ import torchaudio
8
+
9
+ from sam_audio.model.config import ImageBindRankerConfig
10
+ from sam_audio.ranking.ranker import Ranker
11
+
12
+ try:
13
+ from imagebind.data import (
14
+ ConstantClipsPerVideoSampler,
15
+ NormalizeVideo,
16
+ SpatialCrop,
17
+ get_clip_timepoints,
18
+ load_and_transform_video_data,
19
+ pv_transforms,
20
+ transforms,
21
+ waveform2melspec,
22
+ )
23
+ from imagebind.models.imagebind_model import ModalityType, imagebind_huge
24
+
25
+ __imagebind_exists__ = True
26
+ except ImportError:
27
+ __imagebind_exists__ = False
28
+
29
+
30
+ def load_and_transform_audio_data(
31
+ audios: List[Union[str, torch.Tensor]],
32
+ input_sample_rate=None,
33
+ num_mel_bins=128,
34
+ target_length=204,
35
+ sample_rate=16000,
36
+ clip_duration=2,
37
+ clips_per_video=3,
38
+ mean=-4.268,
39
+ std=9.138,
40
+ device=None,
41
+ ):
42
+ if audios is None:
43
+ return None
44
+
45
+ audio_outputs = []
46
+ clip_sampler = ConstantClipsPerVideoSampler(
47
+ clip_duration=clip_duration, clips_per_video=clips_per_video
48
+ )
49
+
50
+ for audio in audios:
51
+ if isinstance(audio, str):
52
+ waveform, input_sample_rate = torchaudio.load(audio)
53
+ else:
54
+ assert torch.is_tensor(audio)
55
+ assert sample_rate is not None
56
+ # Preprocessing needs to be done in full precision
57
+ waveform = audio.float()
58
+ if waveform.ndim == 1:
59
+ waveform = waveform[None]
60
+ if sample_rate != input_sample_rate:
61
+ waveform = torchaudio.functional.resample(
62
+ waveform, orig_freq=input_sample_rate, new_freq=sample_rate
63
+ )
64
+ all_clips_timepoints = get_clip_timepoints(
65
+ clip_sampler, waveform.size(1) / sample_rate
66
+ )
67
+ all_clips = []
68
+ for clip_timepoints in all_clips_timepoints:
69
+ waveform_clip = waveform[
70
+ :,
71
+ int(clip_timepoints[0] * sample_rate) : int(
72
+ clip_timepoints[1] * sample_rate
73
+ ),
74
+ ]
75
+ waveform_melspec = waveform2melspec(
76
+ waveform_clip, sample_rate, num_mel_bins, target_length
77
+ )
78
+ all_clips.append(waveform_melspec)
79
+
80
+ normalize = transforms.Normalize(mean=mean, std=std)
81
+ all_clips = [normalize(ac).to(device) for ac in all_clips]
82
+
83
+ all_clips = torch.stack(all_clips, dim=0)
84
+ audio_outputs.append(all_clips)
85
+
86
+ return torch.stack(audio_outputs, dim=0)
87
+
88
+
89
+ class VideoTransform:
90
+ def __init__(self, clip_duration=2, clips_per_video=5):
91
+ self.clip_duration = clip_duration
92
+ self.clips_per_video = clips_per_video
93
+ self.clip_sampler = ConstantClipsPerVideoSampler(
94
+ clip_duration=clip_duration, clips_per_video=clips_per_video
95
+ )
96
+ self.video_transform = transforms.Compose(
97
+ [
98
+ pv_transforms.ShortSideScale(224),
99
+ NormalizeVideo(
100
+ mean=(0.48145466, 0.4578275, 0.40821073),
101
+ std=(0.26862954, 0.26130258, 0.27577711),
102
+ ),
103
+ ]
104
+ )
105
+ self.spatial_crop = SpatialCrop(224, num_crops=3)
106
+
107
+ def load_video_fast(self, videos, durations, **kwargs):
108
+ result = []
109
+ for video, duration in zip(videos, durations, strict=False):
110
+ nframes = video.size(0)
111
+ fps = video.size(0) / duration
112
+ timepoints = get_clip_timepoints(
113
+ self.clip_sampler,
114
+ duration,
115
+ )
116
+ # Instead of loading 5 2s clips, and then sub-sampling frames, we figure
117
+ # Out the indices of the final clips we want and only decode those.
118
+ all_idxs = []
119
+ for start_time, end_time in timepoints:
120
+ idxs = torch.arange(
121
+ min(int(math.ceil(fps * start_time)), nframes - 1),
122
+ min(int(math.ceil(fps * end_time)), nframes),
123
+ )
124
+ ts = (
125
+ torch.linspace(0, idxs.size(0) - 1, self.clip_duration)
126
+ .clamp(max=idxs.size(0) - 1)
127
+ .long()
128
+ )
129
+ all_idxs.append(idxs[ts])
130
+ all_idxs = torch.cat(all_idxs)
131
+ fast_frames = video[all_idxs].transpose(0, 1)
132
+ result.append(fast_frames.chunk(self.clips_per_video, dim=1))
133
+ return result
134
+
135
+ def transform_video(self, batch, device=None):
136
+ device = device or torch.device("cpu")
137
+ video_outputs = []
138
+ for all_video in batch:
139
+ all_video = [
140
+ self.video_transform(clip.to(device) / 255.0) for clip in all_video
141
+ ]
142
+ all_video = self.spatial_crop(all_video)
143
+ all_video = torch.stack(all_video, dim=0)
144
+ video_outputs.append(all_video)
145
+ return torch.stack(video_outputs, dim=0)
146
+
147
+ def __call__(self, videos, durations, device=None):
148
+ return self.transform_video(
149
+ self.load_video_fast(videos, durations), device=device
150
+ )
151
+
152
+
153
+ class ImageBindRanker(Ranker):
154
+ def __init__(self, cfg: ImageBindRankerConfig):
155
+ super().__init__()
156
+ assert __imagebind_exists__, (
157
+ "Install ImageBind in order to use this ranker: https://github.com/facebookresearch/ImageBind/tree/main"
158
+ )
159
+
160
+ self.model = imagebind_huge(pretrained=cfg.checkpoint is None)
161
+ if cfg.checkpoint is not None:
162
+ self.model.load_state_dict(torch.load(cfg.checkpoint, map_location="cpu"))
163
+ self.model = self.model.eval()
164
+ self.video_transform = VideoTransform()
165
+
166
+ @torch.inference_mode()
167
+ def forward(
168
+ self,
169
+ extracted_audio: list[torch.Tensor],
170
+ videos: list[torch.Tensor | str],
171
+ sample_rate: int = 48_000,
172
+ **kwargs,
173
+ ):
174
+ audio_data = torch.cat(
175
+ [
176
+ load_and_transform_audio_data(x, input_sample_rate=sample_rate)
177
+ for x in extracted_audio
178
+ ],
179
+ dim=0,
180
+ )
181
+ if isinstance(videos[0], str):
182
+ video_data = load_and_transform_video_data(videos)
183
+ else:
184
+ durations = [x.size(-1) / sample_rate for x in extracted_audio]
185
+ video_data = self.video_transform(videos, durations, audio_data.device)
186
+
187
+ inputs = {ModalityType.AUDIO: audio_data, ModalityType.VISION: video_data}
188
+ embs = self.model(inputs)
189
+ audio_embs, video_embs = embs[ModalityType.AUDIO], embs[ModalityType.VISION]
190
+ audio_embs, video_embs = (
191
+ audio_embs / ((audio_embs**2).sum(dim=-1, keepdims=True) ** 0.5),
192
+ video_embs / ((video_embs**2).sum(dim=-1, keepdims=True) ** 0.5),
193
+ )
194
+ bsz = len(extracted_audio)
195
+ candidates = len(audio_embs) // bsz
196
+ scores = audio_embs.view(bsz, candidates, -1) @ video_embs.view(bsz, -1, 1)
197
+ return scores
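The frame-selection trick in VideoTransform.load_video_fast above, shown in isolation: for each (start, end) clip the frame indices are computed up front and only clip_duration frames per clip are gathered, instead of materializing whole 2-second clips and subsampling afterwards. The video length, frame rate, and timepoints below are made-up values (timepoints normally come from get_clip_timepoints):

import math

import torch

nframes, duration, clip_duration = 90, 6.0, 2       # assumed 15 fps, 6 s video
fps = nframes / duration
timepoints = [(0.0, 2.0), (2.0, 4.0), (4.0, 6.0)]   # stand-in for get_clip_timepoints(...)

all_idxs = []
for start_time, end_time in timepoints:
    idxs = torch.arange(
        min(int(math.ceil(fps * start_time)), nframes - 1),
        min(int(math.ceil(fps * end_time)), nframes),
    )
    ts = (
        torch.linspace(0, idxs.size(0) - 1, clip_duration)
        .clamp(max=idxs.size(0) - 1)
        .long()
    )
    all_idxs.append(idxs[ts])
print(torch.cat(all_idxs))  # the only frame indices actually gathered from the decoded video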
sam_audio/ranking/judge.py ADDED
@@ -0,0 +1,42 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ import torch
4
+
5
+ from ..model.config import JudgeRankerConfig
6
+ from ..model.judge import SAMAudioJudgeModel
7
+ from ..processor import SAMAudioJudgeProcessor
8
+ from .ranker import Ranker
9
+
10
+
11
+ class JudgeRanker(Ranker):
12
+ def __init__(self, config: JudgeRankerConfig):
13
+ super().__init__()
14
+ self.config = config
15
+ self.model = SAMAudioJudgeModel.from_pretrained(config.checkpoint_or_model_id)
16
+ self.processor = SAMAudioJudgeProcessor.from_pretrained(
17
+ config.checkpoint_or_model_id
18
+ )
19
+
20
+ @torch.inference_mode()
21
+ def forward(
22
+ self,
23
+ input_audio: list[torch.Tensor],
24
+ extracted_audio: list[torch.Tensor],
25
+ descriptions: list[str],
26
+ sample_rate: int = 48_000,
27
+ **kwargs,
28
+ ):
29
+ bsz, ncandidates = len(input_audio), len(input_audio[0])
30
+ input_seqs = [x[None] for candidates in input_audio for x in candidates]
31
+ extracted_seqs = [x[None] for candidates in extracted_audio for x in candidates]
32
+ repeated_descriptions = [x for x in descriptions for _ in range(ncandidates)]
33
+ processed = self.processor(
34
+ text=repeated_descriptions,
35
+ input_audio=input_seqs,
36
+ separated_audio=extracted_seqs,
37
+ return_tensors="pt",
38
+ padding=True,
39
+ sampling_rate=sample_rate,
40
+ )
41
+ res = self.model(**processed.to(input_audio[0].device))
42
+ return res.overall.view(bsz, ncandidates)
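The candidate flattening done by JudgeRanker.forward, in isolation: each sample's candidate separations are flattened into one judge batch with the description repeated per candidate, and the scalar scores are reshaped back to (batch, num_candidates). The random tensor below stands in for the judge model's overall score:

import torch

input_audio = [torch.randn(3, 48_000), torch.randn(3, 48_000)]  # 2 samples x 3 candidates
descriptions = ["a dog barking", "rain on a window"]

bsz, ncandidates = len(input_audio), len(input_audio[0])
input_seqs = [x[None] for candidates in input_audio for x in candidates]
repeated_descriptions = [d for d in descriptions for _ in range(ncandidates)]
assert len(input_seqs) == len(repeated_descriptions) == bsz * ncandidates

overall = torch.randn(bsz * ncandidates)     # stand-in for res.overall from the judge model
print(overall.view(bsz, ncandidates).shape)  # torch.Size([2, 3])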
sam_audio/ranking/ranker.py ADDED
@@ -0,0 +1,36 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ from abc import ABCMeta, abstractmethod
4
+ from typing import List
5
+
6
+ import torch
7
+
8
+
9
+ class Ranker(torch.nn.Module, metaclass=ABCMeta):
10
+ @abstractmethod
11
+ def forward(self, audio: list[torch.Tensor], **kwargs) -> torch.Tensor:
12
+ """
13
+ Args:
14
+ audio: (list[torch.Tensor]) where each element in the list corresponds to
15
+ the candidates for the i'th generation (num_candidates, num_frames)
16
+ Returns:
17
+ (torch.Tensor) of shape (batch_size, num_candidates) corresponding to the ranking scores
18
+ """
19
+ pass
20
+
21
+
22
+ class EnsembleRanker(Ranker):
23
+ def __init__(self, rankers: List[Ranker], weights: List[float]):
24
+ super().__init__()
25
+ assert len(rankers) == len(weights)
26
+ self.rankers = torch.nn.ModuleList(rankers)
27
+ self.weights = weights
28
+
29
+ def forward(self, **kwargs) -> torch.Tensor:
30
+ result = None
31
+ for weight, ranker in zip(self.weights, self.rankers, strict=False):
32
+ if result is None:
33
+ result = weight * ranker(**kwargs)
34
+ else:
35
+ result += weight * ranker(**kwargs)
36
+ return result
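A toy check of EnsembleRanker's weighted sum, using two stub rankers that return constant scores; real rankers (CLAP, ImageBind, judge) plug in the same way through the keyword-only forward(**kwargs) interface. ConstantRanker here is purely illustrative:

import torch

from sam_audio.ranking.ranker import EnsembleRanker, Ranker

class ConstantRanker(Ranker):
    """Stub ranker returning the same score for every candidate (testing only)."""

    def __init__(self, value: float):
        super().__init__()
        self.value = value

    def forward(self, extracted_audio, **kwargs):
        num_candidates = extracted_audio[0].size(0)
        return torch.full((len(extracted_audio), num_candidates), self.value)

ensemble = EnsembleRanker(
    rankers=[ConstantRanker(1.0), ConstantRanker(3.0)], weights=[0.25, 0.75]
)
scores = ensemble(extracted_audio=[torch.randn(4, 48_000)])  # 1 sample, 4 candidates
print(scores)  # 0.25 * 1.0 + 0.75 * 3.0 = 2.5 for every candidate, shape (1, 4)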
sam_audio/ranking/sound_activity.py ADDED
@@ -0,0 +1,129 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ from io import BytesIO
4
+ from typing import Tuple, Union
5
+
6
+ import torch
7
+ from torchcodec.encoders import AudioEncoder
8
+
9
+ from ..model.config import SoundActivityRankerConfig
10
+ from .ranker import Ranker
11
+
12
+ try:
13
+ import pydub
14
+ except ImportError:
15
+ pydub = None
16
+
17
+
18
+ def get_peak_rms(audio, win_ms=250, hop_ms=100):
19
+ """
20
+ win_ms and hop_ms are in milliseconds
21
+ """
22
+ last_slice_start = len(audio) - win_ms
23
+ slice_starts = range(0, last_slice_start + 1, hop_ms)
24
+ peak_rms = -1
25
+ for i in slice_starts:
26
+ audio_slice = audio[i : i + win_ms]
27
+ peak_rms = max(peak_rms, audio_slice.rms / audio.max_possible_amplitude)
28
+ # Ensure peak_rms is positive
29
+ peak_rms = max(peak_rms, 0)
30
+ return peak_rms
31
+
32
+
33
+ def torch_tensor_to_pydub(wav: torch.Tensor, sample_rate: int):
34
+ bytesio = BytesIO()
35
+ encoder = AudioEncoder(wav, sample_rate=sample_rate)
36
+ encoder.to_file_like(bytesio, format="wav")
37
+ bytesio.seek(0)
38
+ audio = pydub.AudioSegment.from_file(bytesio, format="wav")
39
+ return audio
40
+
41
+
42
+ def detect_nonsilent(
43
+ path: Union[str, Tuple[torch.Tensor, int]], # either a file path or pair wav & sr
44
+ min_sil_ms=250,
45
+ sil_threshold=-40,
46
+ threshold_mode="rel_to_max",
47
+ ):
48
+ TH_MODES = {"abs", "rel_to_max"}
49
+ SAMPLE_RATE = 24_000
50
+ assert threshold_mode in TH_MODES, f"{threshold_mode=} not in {TH_MODES}"
51
+ if isinstance(path, str):
52
+ audio = pydub.AudioSegment.from_file(path)
53
+ else: # tuple of (tensor, sr)
54
+ audio = torch_tensor_to_pydub(path[0], path[1])
55
+ audio = audio.set_frame_rate(SAMPLE_RATE)
56
+ if threshold_mode == "rel_to_max":
57
+ peak_rms = get_peak_rms(audio)
58
+ sil_threshold = sil_threshold + pydub.utils.ratio_to_db(
59
+ peak_rms
60
+ ) # convert to absolute db threshold
61
+ elif threshold_mode == "abs":
62
+ pass
63
+ else:
64
+ raise NotImplementedError(f"Unknown threshold_mode '{threshold_mode}'")
65
+ spans = pydub.silence.detect_nonsilent(
66
+ audio, min_silence_len=min_sil_ms, silence_thresh=sil_threshold, seek_step=10
67
+ )
68
+ spans = [(round(start / 1000, 3), round(end / 1000, 3)) for start, end in spans]
69
+ return spans
70
+
71
+
72
+ def compute_iou_recall_precision(hyp_spans, ref_spans):
73
+ def span_length(span):
74
+ return span[1] - span[0]
75
+
76
+ def intersection_length(span1, span2):
77
+ return max(0, min(span1[1], span2[1]) - max(span1[0], span2[0]))
78
+
79
+ total_hyp_length = sum(span_length(span) for span in hyp_spans)
80
+ total_ref_length = sum(span_length(span) for span in ref_spans)
81
+ total_intersection = 0
82
+ for hyp_span in hyp_spans:
83
+ for ref_span in ref_spans:
84
+ total_intersection += intersection_length(hyp_span, ref_span)
85
+
86
+ union_spans = hyp_spans + ref_spans # Combine both lists to compute union
87
+ union_length = sum(span_length(span) for span in union_spans) - total_intersection
88
+
89
+ iou = total_intersection / union_length if union_length > 0 else 0
90
+ recall = total_intersection / total_ref_length if total_ref_length > 0 else 0
91
+ precision = total_intersection / total_hyp_length if total_hyp_length > 0 else 0
92
+
93
+ return {"iou": iou, "recall": recall, "precision": precision}
94
+
95
+
96
+ class SoundActivityRanker(Ranker):
97
+ def __init__(self, config: SoundActivityRankerConfig):
98
+ if pydub is None:
99
+ raise ImportError(
100
+ 'Install reranking dependencies: `pip install "sam-audio[reranking]"`'
101
+ )
102
+ super().__init__()
103
+ self.config = config
104
+
105
+ @torch.inference_mode()
106
+ def forward(
107
+ self,
108
+ extracted_audio: list[torch.Tensor],
109
+ spans: list[list[list[float]]],
110
+ sample_rate: int = 48_000,
111
+ **kwargs,
112
+ ):
113
+ device = extracted_audio[0].device
114
+ scores = []
115
+ for wav, current_spans in zip(extracted_audio, spans, strict=True):
116
+ wav = wav.to(torch.float32).cpu()
117
+ # get non-silent spans
118
+ hyp_spans = detect_nonsilent(
119
+ (wav, sample_rate),
120
+ sil_threshold=self.config.sil_threshold,
121
+ threshold_mode=self.config.threshold_mode,
122
+ )
123
+ timestamps = [[span[1], span[2]] for span in current_spans]
124
+ result = compute_iou_recall_precision(hyp_spans, timestamps)
125
+ scores.append(result[self.config.metric])
126
+
127
+ # convert to tensor
128
+ scores = torch.tensor(scores, device=device)
129
+ return scores
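A worked example for compute_iou_recall_precision above, with a single hypothesis span overlapping one of two reference spans:

from sam_audio.ranking.sound_activity import compute_iou_recall_precision

hyp = [(0.0, 2.0)]                 # detected non-silent span
ref = [(1.0, 3.0), (5.0, 6.0)]     # reference activity spans
# intersection = 1.0 s, hyp total = 2.0 s, ref total = 3.0 s, union = 2 + 3 - 1 = 4.0 s
print(compute_iou_recall_precision(hyp, ref))
# {'iou': 0.25, 'recall': 0.3333333333333333, 'precision': 0.5}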