YWMditto committed on
Commit
3fa84fb
·
1 Parent(s): ea29ade

update readme

Browse files
Files changed (1) hide show
  1. README.md +56 -16
README.md CHANGED
@@ -183,7 +183,7 @@ MOSS-TTS provides a convenient `generate` interface for rapid usage. The example
183
  3. Duration control
184
 
185
  ```python
186
- import os
187
  from pathlib import Path
188
  import torch
189
  import torchaudio
@@ -222,6 +222,28 @@ pretrained_model_name_or_path = "OpenMOSS-Team/MOSS-TTS"
222
  device = "cuda" if torch.cuda.is_available() else "cpu"
223
  dtype = torch.bfloat16 if device == "cuda" else torch.float32
224
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225
  processor = AutoProcessor.from_pretrained(
226
  pretrained_model_name_or_path,
227
  trust_remote_code=True,
@@ -286,7 +308,7 @@ conversations = [
286
  model = AutoModel.from_pretrained(
287
  pretrained_model_name_or_path,
288
  trust_remote_code=True,
289
- attn_implementation="sdpa",
290
  torch_dtype=dtype,
291
  ).to(device)
292
  model.eval()
@@ -312,7 +334,6 @@ generation_config.layers = [
312
 
313
  batch_size = 1
314
 
315
- messages = []
316
  save_dir = Path("inference_root_moss_tts_local_transformer_generation")
317
  save_dir.mkdir(exist_ok=True, parents=True)
318
  sample_idx = 0
@@ -330,11 +351,10 @@ with torch.no_grad():
330
  )
331
 
332
  for message in processor.decode(outputs):
333
- for seg_idx, audio in enumerate(message.audio_codes_list):
334
- # audio is a waveform tensor after decode_audio_codes
335
- out_path = save_dir / f"sample{sample_idx}_seg{seg_idx}.wav"
336
- sample_idx += 1
337
- torchaudio.save(out_path, audio.unsqueeze(0), processor.model_config.sampling_rate)
338
 
339
  ```
340
 
@@ -343,7 +363,7 @@ with torch.no_grad():
343
  MOSS-TTS supports continuation-based cloning: provide a prefix audio clip in the assistant message, and make sure the **prefix transcript** is included in the text. The model continues in the same speaker identity and style.
344
 
345
  ```python
346
- import os
347
  from pathlib import Path
348
  import torch
349
  import torchaudio
@@ -380,6 +400,28 @@ pretrained_model_name_or_path = "OpenMOSS-Team/MOSS-TTS"
380
  device = "cuda" if torch.cuda.is_available() else "cpu"
381
  dtype = torch.bfloat16 if device == "cuda" else torch.float32
382
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
383
  processor = AutoProcessor.from_pretrained(
384
  pretrained_model_name_or_path,
385
  trust_remote_code=True,
@@ -414,7 +456,7 @@ conversations = [
414
  model = AutoModel.from_pretrained(
415
  pretrained_model_name_or_path,
416
  trust_remote_code=True,
417
- attn_implementation="sdpa",
418
  torch_dtype=dtype,
419
  ).to(device)
420
  model.eval()
@@ -441,7 +483,6 @@ generation_config.layers = [
441
 
442
  batch_size = 1
443
 
444
- messages = []
445
  save_dir = Path("inference_root_moss_tts_local_transformer_continuation")
446
  save_dir.mkdir(exist_ok=True, parents=True)
447
  sample_idx = 0
@@ -459,11 +500,10 @@ with torch.no_grad():
459
  )
460
 
461
  for message in processor.decode(outputs):
462
- for seg_idx, audio in enumerate(message.audio_codes_list):
463
- # audio is a waveform tensor after decode_audio_codes
464
- out_path = save_dir / f"sample{sample_idx}_seg{seg_idx}.wav"
465
- sample_idx += 1
466
- torchaudio.save(out_path, audio.unsqueeze(0), processor.model_config.sampling_rate)
467
 
468
  ```
469
 
 
183
  3. Duration control
184
 
185
  ```python
186
+ import importlib.util
187
  from pathlib import Path
188
  import torch
189
  import torchaudio
 
222
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if device == "cuda" else torch.float32


def resolve_attn_implementation() -> str:
    """Pick the best attention backend for the current device and dtype.

    Preference order: FlashAttention 2 (CUDA + half precision + Ampere or
    newer + ``flash_attn`` installed), then PyTorch SDPA on any CUDA device,
    then eager attention on CPU.
    """
    on_cuda = device == "cuda"
    half_precision = dtype in (torch.float16, torch.bfloat16)
    flash_attn_available = importlib.util.find_spec("flash_attn") is not None

    if on_cuda and flash_attn_available and half_precision:
        # FlashAttention 2 requires compute capability >= 8.0 (Ampere+).
        compute_major = torch.cuda.get_device_capability()[0]
        if compute_major >= 8:
            return "flash_attention_2"

    # SDPA on CUDA without FA2; eager as the CPU fallback.
    return "sdpa" if on_cuda else "eager"


attn_implementation = resolve_attn_implementation()
print(f"[INFO] Using attn_implementation={attn_implementation}")
246
+
247
  processor = AutoProcessor.from_pretrained(
248
  pretrained_model_name_or_path,
249
  trust_remote_code=True,
 
308
  model = AutoModel.from_pretrained(
309
  pretrained_model_name_or_path,
310
  trust_remote_code=True,
311
+ attn_implementation=attn_implementation,
312
  torch_dtype=dtype,
313
  ).to(device)
314
  model.eval()
 
334
 
335
  batch_size = 1
336
 
 
337
  save_dir = Path("inference_root_moss_tts_local_transformer_generation")
338
  save_dir.mkdir(exist_ok=True, parents=True)
339
  sample_idx = 0
 
351
  )
352
 
353
  for message in processor.decode(outputs):
354
+ audio = message.audio_codes_list[0]
355
+ out_path = save_dir / f"sample{sample_idx}.wav"
356
+ sample_idx += 1
357
+ torchaudio.save(out_path, audio.unsqueeze(0), processor.model_config.sampling_rate)
 
358
 
359
  ```
360
 
 
363
  MOSS-TTS supports continuation-based cloning: provide a prefix audio clip in the assistant message, and make sure the **prefix transcript** is included in the text. The model continues in the same speaker identity and style.
364
 
365
  ```python
366
+ import importlib.util
367
  from pathlib import Path
368
  import torch
369
  import torchaudio
 
400
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if device == "cuda" else torch.float32


def resolve_attn_implementation() -> str:
    """Choose the attention backend supported by this runtime.

    Tries FlashAttention 2 first (needs CUDA, fp16/bf16, compute capability
    8.0+, and the ``flash_attn`` package); otherwise falls back to PyTorch
    SDPA on CUDA, or eager attention on CPU.
    """
    cuda_active = device == "cuda"
    dtype_ok = dtype in (torch.float16, torch.bfloat16)
    has_flash_attn = importlib.util.find_spec("flash_attn") is not None

    if cuda_active and has_flash_attn and dtype_ok:
        # Ampere (SM 8.x) or newer is required for FlashAttention 2.
        if torch.cuda.get_device_capability()[0] >= 8:
            return "flash_attention_2"

    if cuda_active:
        return "sdpa"
    return "eager"


attn_implementation = resolve_attn_implementation()
print(f"[INFO] Using attn_implementation={attn_implementation}")
424
+
425
  processor = AutoProcessor.from_pretrained(
426
  pretrained_model_name_or_path,
427
  trust_remote_code=True,
 
456
  model = AutoModel.from_pretrained(
457
  pretrained_model_name_or_path,
458
  trust_remote_code=True,
459
+ attn_implementation=attn_implementation,
460
  torch_dtype=dtype,
461
  ).to(device)
462
  model.eval()
 
483
 
484
  batch_size = 1
485
 
 
486
  save_dir = Path("inference_root_moss_tts_local_transformer_continuation")
487
  save_dir.mkdir(exist_ok=True, parents=True)
488
  sample_idx = 0
 
500
  )
501
 
502
  for message in processor.decode(outputs):
503
+ audio = message.audio_codes_list[0]
504
+ out_path = save_dir / f"sample{sample_idx}.wav"
505
+ sample_idx += 1
506
+ torchaudio.save(out_path, audio.unsqueeze(0), processor.model_config.sampling_rate)
 
507
 
508
  ```
509