update readme
Browse files
README.md
CHANGED
|
@@ -151,29 +151,32 @@ Notes:
|
|
| 151 |
MOSS-TTSD uses a **continuation** workflow: provide reference audio for each speaker, their transcripts as a prefix, and the dialogue text to generate. The model continues in each speaker's identity.
|
| 152 |
|
| 153 |
```python
|
| 154 |
-
import os
|
| 155 |
from pathlib import Path
|
| 156 |
import torch
|
| 157 |
import torchaudio
|
| 158 |
from transformers import AutoModel, AutoProcessor
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 159 |
|
| 160 |
pretrained_model_name_or_path = "OpenMOSS-Team/MOSS-TTSD-v1.0"
|
| 161 |
-
audio_tokenizer_name_or_path = "OpenMOSS-Team/MOSS-Audio-Tokenizer"
|
| 162 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 163 |
dtype = torch.bfloat16 if device == "cuda" else torch.float32
|
| 164 |
|
| 165 |
processor = AutoProcessor.from_pretrained(
|
| 166 |
pretrained_model_name_or_path,
|
| 167 |
trust_remote_code=True,
|
| 168 |
-
codec_path=audio_tokenizer_name_or_path,
|
| 169 |
)
|
| 170 |
processor.audio_tokenizer = processor.audio_tokenizer.to(device)
|
| 171 |
-
processor.audio_tokenizer.eval()
|
| 172 |
|
| 173 |
model = AutoModel.from_pretrained(
|
| 174 |
pretrained_model_name_or_path,
|
| 175 |
trust_remote_code=True,
|
| 176 |
-
attn_implementation="flash_attention_2"
|
|
|
|
| 177 |
torch_dtype=dtype,
|
| 178 |
).to(device)
|
| 179 |
model.eval()
|
|
@@ -226,7 +229,7 @@ conversations = [
|
|
| 226 |
|
| 227 |
batch_size = 1
|
| 228 |
|
| 229 |
-
save_dir = Path("
|
| 230 |
save_dir.mkdir(exist_ok=True, parents=True)
|
| 231 |
sample_idx = 0
|
| 232 |
with torch.no_grad():
|
|
|
|
| 151 |
MOSS-TTSD uses a **continuation** workflow: provide reference audio for each speaker, their transcripts as a prefix, and the dialogue text to generate. The model continues in each speaker's identity.
|
| 152 |
|
| 153 |
```python
|
|
|
|
| 154 |
from pathlib import Path
|
| 155 |
import torch
|
| 156 |
import torchaudio
|
| 157 |
from transformers import AutoModel, AutoProcessor
|
| 158 |
+
# Disable the broken cuDNN SDPA backend
|
| 159 |
+
torch.backends.cuda.enable_cudnn_sdp(False)
|
| 160 |
+
# Keep these enabled as fallbacks
|
| 161 |
+
torch.backends.cuda.enable_flash_sdp(True)
|
| 162 |
+
torch.backends.cuda.enable_mem_efficient_sdp(True)
|
| 163 |
+
torch.backends.cuda.enable_math_sdp(True)
|
| 164 |
|
| 165 |
pretrained_model_name_or_path = "OpenMOSS-Team/MOSS-TTSD-v1.0"
|
|
|
|
| 166 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 167 |
dtype = torch.bfloat16 if device == "cuda" else torch.float32
|
| 168 |
|
| 169 |
processor = AutoProcessor.from_pretrained(
|
| 170 |
pretrained_model_name_or_path,
|
| 171 |
trust_remote_code=True,
|
|
|
|
| 172 |
)
|
| 173 |
processor.audio_tokenizer = processor.audio_tokenizer.to(device)
|
|
|
|
| 174 |
|
| 175 |
model = AutoModel.from_pretrained(
|
| 176 |
pretrained_model_name_or_path,
|
| 177 |
trust_remote_code=True,
|
| 178 |
+
# If FlashAttention 2 is installed, you can set attn_implementation="flash_attention_2"
|
| 179 |
+
attn_implementation="sdpa",
|
| 180 |
torch_dtype=dtype,
|
| 181 |
).to(device)
|
| 182 |
model.eval()
|
|
|
|
| 229 |
|
| 230 |
batch_size = 1
|
| 231 |
|
| 232 |
+
save_dir = Path("inference_root")
|
| 233 |
save_dir.mkdir(exist_ok=True, parents=True)
|
| 234 |
sample_idx = 0
|
| 235 |
with torch.no_grad():
|