update readme
Browse files
README.md
CHANGED
|
@@ -183,7 +183,7 @@ MOSS-TTS provides a convenient `generate` interface for rapid usage. The example
|
|
| 183 |
3. Duration control
|
| 184 |
|
| 185 |
```python
|
| 186 |
-
import
|
| 187 |
from pathlib import Path
|
| 188 |
import torch
|
| 189 |
import torchaudio
|
|
@@ -222,6 +222,28 @@ pretrained_model_name_or_path = "OpenMOSS-Team/MOSS-TTS"
|
|
| 222 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 223 |
dtype = torch.bfloat16 if device == "cuda" else torch.float32
|
| 224 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 225 |
processor = AutoProcessor.from_pretrained(
|
| 226 |
pretrained_model_name_or_path,
|
| 227 |
trust_remote_code=True,
|
|
@@ -286,7 +308,7 @@ conversations = [
|
|
| 286 |
model = AutoModel.from_pretrained(
|
| 287 |
pretrained_model_name_or_path,
|
| 288 |
trust_remote_code=True,
|
| 289 |
-
attn_implementation=
|
| 290 |
torch_dtype=dtype,
|
| 291 |
).to(device)
|
| 292 |
model.eval()
|
|
@@ -312,7 +334,6 @@ generation_config.layers = [
|
|
| 312 |
|
| 313 |
batch_size = 1
|
| 314 |
|
| 315 |
-
messages = []
|
| 316 |
save_dir = Path(f"inference_root_moss_tts_local_transformer_generation")
|
| 317 |
save_dir.mkdir(exist_ok=True, parents=True)
|
| 318 |
sample_idx = 0
|
|
@@ -330,11 +351,10 @@ with torch.no_grad():
|
|
| 330 |
)
|
| 331 |
|
| 332 |
for message in processor.decode(outputs):
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
torchaudio.save(out_path, audio.unsqueeze(0), processor.model_config.sampling_rate)
|
| 338 |
|
| 339 |
```
|
| 340 |
|
|
@@ -343,7 +363,7 @@ with torch.no_grad():
|
|
| 343 |
MOSS-TTS supports continuation-based cloning: provide a prefix audio clip in the assistant message, and make sure the **prefix transcript** is included in the text. The model continues in the same speaker identity and style.
|
| 344 |
|
| 345 |
```python
|
| 346 |
-
import
|
| 347 |
from pathlib import Path
|
| 348 |
import torch
|
| 349 |
import torchaudio
|
|
@@ -380,6 +400,28 @@ pretrained_model_name_or_path = "OpenMOSS-Team/MOSS-TTS"
|
|
| 380 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 381 |
dtype = torch.bfloat16 if device == "cuda" else torch.float32
|
| 382 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 383 |
processor = AutoProcessor.from_pretrained(
|
| 384 |
pretrained_model_name_or_path,
|
| 385 |
trust_remote_code=True,
|
|
@@ -414,7 +456,7 @@ conversations = [
|
|
| 414 |
model = AutoModel.from_pretrained(
|
| 415 |
pretrained_model_name_or_path,
|
| 416 |
trust_remote_code=True,
|
| 417 |
-
attn_implementation=
|
| 418 |
torch_dtype=dtype,
|
| 419 |
).to(device)
|
| 420 |
model.eval()
|
|
@@ -441,7 +483,6 @@ generation_config.layers = [
|
|
| 441 |
|
| 442 |
batch_size = 1
|
| 443 |
|
| 444 |
-
messages = []
|
| 445 |
save_dir = Path("inference_root_moss_tts_local_transformer_continuation")
|
| 446 |
save_dir.mkdir(exist_ok=True, parents=True)
|
| 447 |
sample_idx = 0
|
|
@@ -459,11 +500,10 @@ with torch.no_grad():
|
|
| 459 |
)
|
| 460 |
|
| 461 |
for message in processor.decode(outputs):
|
| 462 |
-
|
| 463 |
-
|
| 464 |
-
|
| 465 |
-
|
| 466 |
-
torchaudio.save(out_path, audio.unsqueeze(0), processor.model_config.sampling_rate)
|
| 467 |
|
| 468 |
```
|
| 469 |
|
|
|
|
| 183 |
3. Duration control
|
| 184 |
|
| 185 |
```python
|
| 186 |
+
import importlib.util
|
| 187 |
from pathlib import Path
|
| 188 |
import torch
|
| 189 |
import torchaudio
|
|
|
|
| 222 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 223 |
dtype = torch.bfloat16 if device == "cuda" else torch.float32
|
| 224 |
|
| 225 |
+
def resolve_attn_implementation() -> str:
|
| 226 |
+
# Prefer FlashAttention 2 when package + device conditions are met.
|
| 227 |
+
if (
|
| 228 |
+
device == "cuda"
|
| 229 |
+
and importlib.util.find_spec("flash_attn") is not None
|
| 230 |
+
and dtype in {torch.float16, torch.bfloat16}
|
| 231 |
+
):
|
| 232 |
+
major, _ = torch.cuda.get_device_capability()
|
| 233 |
+
if major >= 8:
|
| 234 |
+
return "flash_attention_2"
|
| 235 |
+
|
| 236 |
+
# CUDA fallback: use PyTorch SDPA kernels.
|
| 237 |
+
if device == "cuda":
|
| 238 |
+
return "sdpa"
|
| 239 |
+
|
| 240 |
+
# CPU fallback.
|
| 241 |
+
return "eager"
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
attn_implementation = resolve_attn_implementation()
|
| 245 |
+
print(f"[INFO] Using attn_implementation={attn_implementation}")
|
| 246 |
+
|
| 247 |
processor = AutoProcessor.from_pretrained(
|
| 248 |
pretrained_model_name_or_path,
|
| 249 |
trust_remote_code=True,
|
|
|
|
| 308 |
model = AutoModel.from_pretrained(
|
| 309 |
pretrained_model_name_or_path,
|
| 310 |
trust_remote_code=True,
|
| 311 |
+
attn_implementation=attn_implementation,
|
| 312 |
torch_dtype=dtype,
|
| 313 |
).to(device)
|
| 314 |
model.eval()
|
|
|
|
| 334 |
|
| 335 |
batch_size = 1
|
| 336 |
|
|
|
|
| 337 |
save_dir = Path(f"inference_root_moss_tts_local_transformer_generation")
|
| 338 |
save_dir.mkdir(exist_ok=True, parents=True)
|
| 339 |
sample_idx = 0
|
|
|
|
| 351 |
)
|
| 352 |
|
| 353 |
for message in processor.decode(outputs):
|
| 354 |
+
audio = message.audio_codes_list[0]
|
| 355 |
+
out_path = save_dir / f"sample{sample_idx}.wav"
|
| 356 |
+
sample_idx += 1
|
| 357 |
+
torchaudio.save(out_path, audio.unsqueeze(0), processor.model_config.sampling_rate)
|
|
|
|
| 358 |
|
| 359 |
```
|
| 360 |
|
|
|
|
| 363 |
MOSS-TTS supports continuation-based cloning: provide a prefix audio clip in the assistant message, and make sure the **prefix transcript** is included in the text. The model continues in the same speaker identity and style.
|
| 364 |
|
| 365 |
```python
|
| 366 |
+
import importlib.util
|
| 367 |
from pathlib import Path
|
| 368 |
import torch
|
| 369 |
import torchaudio
|
|
|
|
| 400 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 401 |
dtype = torch.bfloat16 if device == "cuda" else torch.float32
|
| 402 |
|
| 403 |
+
def resolve_attn_implementation() -> str:
|
| 404 |
+
# Prefer FlashAttention 2 when package + device conditions are met.
|
| 405 |
+
if (
|
| 406 |
+
device == "cuda"
|
| 407 |
+
and importlib.util.find_spec("flash_attn") is not None
|
| 408 |
+
and dtype in {torch.float16, torch.bfloat16}
|
| 409 |
+
):
|
| 410 |
+
major, _ = torch.cuda.get_device_capability()
|
| 411 |
+
if major >= 8:
|
| 412 |
+
return "flash_attention_2"
|
| 413 |
+
|
| 414 |
+
# CUDA fallback: use PyTorch SDPA kernels.
|
| 415 |
+
if device == "cuda":
|
| 416 |
+
return "sdpa"
|
| 417 |
+
|
| 418 |
+
# CPU fallback.
|
| 419 |
+
return "eager"
|
| 420 |
+
|
| 421 |
+
|
| 422 |
+
attn_implementation = resolve_attn_implementation()
|
| 423 |
+
print(f"[INFO] Using attn_implementation={attn_implementation}")
|
| 424 |
+
|
| 425 |
processor = AutoProcessor.from_pretrained(
|
| 426 |
pretrained_model_name_or_path,
|
| 427 |
trust_remote_code=True,
|
|
|
|
| 456 |
model = AutoModel.from_pretrained(
|
| 457 |
pretrained_model_name_or_path,
|
| 458 |
trust_remote_code=True,
|
| 459 |
+
attn_implementation=attn_implementation,
|
| 460 |
torch_dtype=dtype,
|
| 461 |
).to(device)
|
| 462 |
model.eval()
|
|
|
|
| 483 |
|
| 484 |
batch_size = 1
|
| 485 |
|
|
|
|
| 486 |
save_dir = Path("inference_root_moss_tts_local_transformer_continuation")
|
| 487 |
save_dir.mkdir(exist_ok=True, parents=True)
|
| 488 |
sample_idx = 0
|
|
|
|
| 500 |
)
|
| 501 |
|
| 502 |
for message in processor.decode(outputs):
|
| 503 |
+
audio = message.audio_codes_list[0]
|
| 504 |
+
out_path = save_dir / f"sample{sample_idx}.wav"
|
| 505 |
+
sample_idx += 1
|
| 506 |
+
torchaudio.save(out_path, audio.unsqueeze(0), processor.model_config.sampling_rate)
|
|
|
|
| 507 |
|
| 508 |
```
|
| 509 |
|