update readme
Browse files
README.md
CHANGED
|
@@ -186,8 +186,8 @@ MOSS-TTS provides a convenient `generate` interface for rapid usage. The example
|
|
| 186 |
3. Duration control
|
| 187 |
|
| 188 |
```python
|
| 189 |
-
import os
|
| 190 |
from pathlib import Path
|
|
|
|
| 191 |
import torch
|
| 192 |
import torchaudio
|
| 193 |
from transformers import AutoModel, AutoProcessor
|
|
@@ -203,6 +203,28 @@ pretrained_model_name_or_path = "OpenMOSS-Team/MOSS-TTS"
|
|
| 203 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 204 |
dtype = torch.bfloat16 if device == "cuda" else torch.float32
|
| 205 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 206 |
processor = AutoProcessor.from_pretrained(
|
| 207 |
pretrained_model_name_or_path,
|
| 208 |
trust_remote_code=True,
|
|
@@ -239,14 +261,14 @@ conversations = [
|
|
| 239 |
model = AutoModel.from_pretrained(
|
| 240 |
pretrained_model_name_or_path,
|
| 241 |
trust_remote_code=True,
|
| 242 |
-
attn_implementation="
|
|
|
|
| 243 |
torch_dtype=dtype,
|
| 244 |
).to(device)
|
| 245 |
model.eval()
|
| 246 |
|
| 247 |
batch_size = 1
|
| 248 |
|
| 249 |
-
messages = []
|
| 250 |
save_dir = Path("inference_root")
|
| 251 |
save_dir.mkdir(exist_ok=True, parents=True)
|
| 252 |
sample_idx = 0
|
|
@@ -276,8 +298,8 @@ with torch.no_grad():
|
|
| 276 |
MOSS-TTS supports continuation-based cloning: provide a prefix audio clip in the assistant message, and make sure the **prefix transcript** is included in the text. The model continues in the same speaker identity and style.
|
| 277 |
|
| 278 |
```python
|
| 279 |
-
import os
|
| 280 |
from pathlib import Path
|
|
|
|
| 281 |
import torch
|
| 282 |
import torchaudio
|
| 283 |
from transformers import AutoModel, AutoProcessor
|
|
@@ -293,6 +315,28 @@ pretrained_model_name_or_path = "OpenMOSS-Team/MOSS-TTS"
|
|
| 293 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 294 |
dtype = torch.bfloat16 if device == "cuda" else torch.float32
|
| 295 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 296 |
processor = AutoProcessor.from_pretrained(
|
| 297 |
pretrained_model_name_or_path,
|
| 298 |
trust_remote_code=True
|
|
@@ -322,14 +366,14 @@ conversations = [
|
|
| 322 |
model = AutoModel.from_pretrained(
|
| 323 |
pretrained_model_name_or_path,
|
| 324 |
trust_remote_code=True,
|
| 325 |
-
attn_implementation="
|
|
|
|
| 326 |
torch_dtype=dtype,
|
| 327 |
).to(device)
|
| 328 |
model.eval()
|
| 329 |
|
| 330 |
batch_size = 1
|
| 331 |
|
| 332 |
-
messages = []
|
| 333 |
save_dir = Path("inference_root")
|
| 334 |
save_dir.mkdir(exist_ok=True, parents=True)
|
| 335 |
sample_idx = 0
|
|
|
|
| 186 |
3. Duration control
|
| 187 |
|
| 188 |
```python
|
|
|
|
| 189 |
from pathlib import Path
|
| 190 |
+
import importlib.util
|
| 191 |
import torch
|
| 192 |
import torchaudio
|
| 193 |
from transformers import AutoModel, AutoProcessor
|
|
|
|
| 203 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 204 |
dtype = torch.bfloat16 if device == "cuda" else torch.float32
|
| 205 |
|
| 206 |
+
def resolve_attn_implementation() -> str:
|
| 207 |
+
# Prefer FlashAttention 2 when package + device conditions are met.
|
| 208 |
+
if (
|
| 209 |
+
device == "cuda"
|
| 210 |
+
and importlib.util.find_spec("flash_attn") is not None
|
| 211 |
+
and dtype in {torch.float16, torch.bfloat16}
|
| 212 |
+
):
|
| 213 |
+
major, _ = torch.cuda.get_device_capability()
|
| 214 |
+
if major >= 8:
|
| 215 |
+
return "flash_attention_2"
|
| 216 |
+
|
| 217 |
+
# CUDA fallback: use PyTorch SDPA kernels.
|
| 218 |
+
if device == "cuda":
|
| 219 |
+
return "sdpa"
|
| 220 |
+
|
| 221 |
+
# CPU fallback.
|
| 222 |
+
return "eager"
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
attn_implementation = resolve_attn_implementation()
|
| 226 |
+
print(f"[INFO] Using attn_implementation={attn_implementation}")
|
| 227 |
+
|
| 228 |
processor = AutoProcessor.from_pretrained(
|
| 229 |
pretrained_model_name_or_path,
|
| 230 |
trust_remote_code=True,
|
|
|
|
| 261 |
model = AutoModel.from_pretrained(
|
| 262 |
pretrained_model_name_or_path,
|
| 263 |
trust_remote_code=True,
|
| 264 |
+
# If FlashAttention 2 is installed, you can set attn_implementation="flash_attention_2"
|
| 265 |
+
attn_implementation=attn_implementation,
|
| 266 |
torch_dtype=dtype,
|
| 267 |
).to(device)
|
| 268 |
model.eval()
|
| 269 |
|
| 270 |
batch_size = 1
|
| 271 |
|
|
|
|
| 272 |
save_dir = Path("inference_root")
|
| 273 |
save_dir.mkdir(exist_ok=True, parents=True)
|
| 274 |
sample_idx = 0
|
|
|
|
| 298 |
MOSS-TTS supports continuation-based cloning: provide a prefix audio clip in the assistant message, and make sure the **prefix transcript** is included in the text. The model continues in the same speaker identity and style.
|
| 299 |
|
| 300 |
```python
|
|
|
|
| 301 |
from pathlib import Path
|
| 302 |
+
import importlib.util
|
| 303 |
import torch
|
| 304 |
import torchaudio
|
| 305 |
from transformers import AutoModel, AutoProcessor
|
|
|
|
| 315 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 316 |
dtype = torch.bfloat16 if device == "cuda" else torch.float32
|
| 317 |
|
| 318 |
+
def resolve_attn_implementation() -> str:
|
| 319 |
+
# Prefer FlashAttention 2 when package + device conditions are met.
|
| 320 |
+
if (
|
| 321 |
+
device == "cuda"
|
| 322 |
+
and importlib.util.find_spec("flash_attn") is not None
|
| 323 |
+
and dtype in {torch.float16, torch.bfloat16}
|
| 324 |
+
):
|
| 325 |
+
major, _ = torch.cuda.get_device_capability()
|
| 326 |
+
if major >= 8:
|
| 327 |
+
return "flash_attention_2"
|
| 328 |
+
|
| 329 |
+
# CUDA fallback: use PyTorch SDPA kernels.
|
| 330 |
+
if device == "cuda":
|
| 331 |
+
return "sdpa"
|
| 332 |
+
|
| 333 |
+
# CPU fallback.
|
| 334 |
+
return "eager"
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
attn_implementation = resolve_attn_implementation()
|
| 338 |
+
print(f"[INFO] Using attn_implementation={attn_implementation}")
|
| 339 |
+
|
| 340 |
processor = AutoProcessor.from_pretrained(
|
| 341 |
pretrained_model_name_or_path,
|
| 342 |
trust_remote_code=True
|
|
|
|
| 366 |
model = AutoModel.from_pretrained(
|
| 367 |
pretrained_model_name_or_path,
|
| 368 |
trust_remote_code=True,
|
| 369 |
+
# If FlashAttention 2 is installed, you can set attn_implementation="flash_attention_2"
|
| 370 |
+
attn_implementation=attn_implementation,
|
| 371 |
torch_dtype=dtype,
|
| 372 |
).to(device)
|
| 373 |
model.eval()
|
| 374 |
|
| 375 |
batch_size = 1
|
| 376 |
|
|
|
|
| 377 |
save_dir = Path("inference_root")
|
| 378 |
save_dir.mkdir(exist_ok=True, parents=True)
|
| 379 |
sample_idx = 0
|