Update README.md
Browse files
README.md
CHANGED
|
@@ -166,7 +166,7 @@ The code for XEUS is still in progress of being merged into the main ESPnet repo
|
|
| 166 |
pip install -e git+https://github.com/wanchichen/espnet.git@ssl#egg=espnet
|
| 167 |
```
|
| 168 |
|
| 169 |
-
XEUS supports [Flash Attention](https://github.com/Dao-AILab/flash-attention), which can be installed as follows:
|
| 170 |
|
| 171 |
```
|
| 172 |
pip install flash-attn --no-build-isolation
|
|
@@ -174,6 +174,9 @@ pip install flash-attn --no-build-isolation
|
|
| 174 |
|
| 175 |
## Usage
|
| 176 |
|
|
|
|
|
|
|
|
|
|
| 177 |
```python
|
| 178 |
from torch.nn.utils.rnn import pad_sequence
|
| 179 |
from espnet2.tasks.ssl import SSLTask
|
|
@@ -187,6 +190,10 @@ xeus_model, xeus_train_args = SSLTask.build_model_from_file(
|
|
| 187 |
device,
|
| 188 |
)
|
| 189 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 190 |
wavs, sampling_rate = sf.read('/path/to/audio.wav') # sampling rate should be 16000
|
| 191 |
wav_lengths = torch.LongTensor([len(wav) for wav in [wavs]]).to(device)
|
| 192 |
wavs = pad_sequence([torch.from_numpy(wavs).float()], batch_first=True).to(device)
|
|
@@ -195,6 +202,25 @@ wavs = pad_sequence([wavs], batch_first=True).to(device)
|
|
| 195 |
feats = xeus_model.encode(wavs, wav_lengths, use_mask=False, use_final_output=False)[0][-1] # take the output of the last layer
|
| 196 |
```
|
| 197 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 198 |
## Results
|
| 199 |
|
| 200 |

|
|
|
|
| 166 |
pip install -e git+https://github.com/wanchichen/espnet.git@ssl#egg=espnet
|
| 167 |
```
|
| 168 |
|
| 169 |
+
XEUS supports [Flash Attention](https://github.com/Dao-AILab/flash-attention), which can be installed as follows:
|
| 170 |
|
| 171 |
```
|
| 172 |
pip install flash-attn --no-build-isolation
|
|
|
|
| 174 |
|
| 175 |
## Usage
|
| 176 |
|
| 177 |
+
|
| 178 |
+
Default Usage:
|
| 179 |
+
|
| 180 |
```python
|
| 181 |
from torch.nn.utils.rnn import pad_sequence
|
| 182 |
from espnet2.tasks.ssl import SSLTask
|
|
|
|
| 190 |
device,
|
| 191 |
)
|
| 192 |
|
| 193 |
+
# Flash Attention is optional; disable it per encoder layer if it is not installed
|
| 194 |
+
for layer in xeus_model.encoder.encoders:
|
| 195 |
+
    layer.use_flash_attn = False
|
| 196 |
+
|
| 197 |
wavs, sampling_rate = sf.read('/path/to/audio.wav') # sampling rate should be 16000
|
| 198 |
wav_lengths = torch.LongTensor([len(wav) for wav in [wavs]]).to(device)
|
| 199 |
wavs = pad_sequence([torch.from_numpy(wavs).float()], batch_first=True).to(device)
|
|
|
|
| 202 |
feats = xeus_model.encode(wavs, wav_lengths, use_mask=False, use_final_output=False)[0][-1] # take the output of the last layer
|
| 203 |
```
|
| 204 |
|
| 205 |
+
With Flash Attention:
|
| 206 |
+
|
| 207 |
+
```python
|
| 208 |
+
for layer in xeus_model.encoder.encoders:
    layer.use_flash_attn = True
|
| 209 |
+
|
| 210 |
+
with torch.autocast("cuda", dtype=torch.bfloat16):
|
| 211 |
+
feats = xeus_model.encode(wavs, wav_lengths, use_mask=False, use_final_output=False)[0][-1]
|
| 212 |
+
```
|
| 213 |
+
|
| 214 |
+
Tune the masking settings:
|
| 215 |
+
|
| 216 |
+
```python
|
| 217 |
+
|
| 218 |
+
xeus_model.masker.mask_prob = 0.65 # default 0.8
|
| 219 |
+
xeus_model.masker.mask_length = 20 # default 10
|
| 220 |
+
xeus_model.masker.mask_selection = 'static' # default uniform
|
| 221 |
+
xeus_model.train()
|
| 222 |
+
```
|
| 223 |
+
|
| 224 |
## Results
|
| 225 |
|
| 226 |

|