YWMditto committed on
Commit
b06b6e7
·
1 Parent(s): eb293bd

update readme

Browse files
Files changed (1) hide show
  1. README.md +50 -6
README.md CHANGED
@@ -186,8 +186,8 @@ MOSS-TTS provides a convenient `generate` interface for rapid usage. The example
186
  3. Duration control
187
 
188
  ```python
189
- import os
190
  from pathlib import Path
 
191
  import torch
192
  import torchaudio
193
  from transformers import AutoModel, AutoProcessor
@@ -203,6 +203,28 @@ pretrained_model_name_or_path = "OpenMOSS-Team/MOSS-TTS"
203
  device = "cuda" if torch.cuda.is_available() else "cpu"
204
  dtype = torch.bfloat16 if device == "cuda" else torch.float32
205
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
  processor = AutoProcessor.from_pretrained(
207
  pretrained_model_name_or_path,
208
  trust_remote_code=True,
@@ -239,14 +261,14 @@ conversations = [
239
  model = AutoModel.from_pretrained(
240
  pretrained_model_name_or_path,
241
  trust_remote_code=True,
242
- attn_implementation="sdpa",
 
243
  torch_dtype=dtype,
244
  ).to(device)
245
  model.eval()
246
 
247
  batch_size = 1
248
 
249
- messages = []
250
  save_dir = Path("inference_root")
251
  save_dir.mkdir(exist_ok=True, parents=True)
252
  sample_idx = 0
@@ -276,8 +298,8 @@ with torch.no_grad():
276
  MOSS-TTS supports continuation-based cloning: provide a prefix audio clip in the assistant message, and make sure the **prefix transcript** is included in the text. The model continues in the same speaker identity and style.
277
 
278
  ```python
279
- import os
280
  from pathlib import Path
 
281
  import torch
282
  import torchaudio
283
  from transformers import AutoModel, AutoProcessor
@@ -293,6 +315,28 @@ pretrained_model_name_or_path = "OpenMOSS-Team/MOSS-TTS"
293
  device = "cuda" if torch.cuda.is_available() else "cpu"
294
  dtype = torch.bfloat16 if device == "cuda" else torch.float32
295
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
296
  processor = AutoProcessor.from_pretrained(
297
  pretrained_model_name_or_path,
298
  trust_remote_code=True
@@ -322,14 +366,14 @@ conversations = [
322
  model = AutoModel.from_pretrained(
323
  pretrained_model_name_or_path,
324
  trust_remote_code=True,
325
- attn_implementation="sdpa",
 
326
  torch_dtype=dtype,
327
  ).to(device)
328
  model.eval()
329
 
330
  batch_size = 1
331
 
332
- messages = []
333
  save_dir = Path("inference_root")
334
  save_dir.mkdir(exist_ok=True, parents=True)
335
  sample_idx = 0
 
186
  3. Duration control
187
 
188
  ```python
 
189
  from pathlib import Path
190
+ import importlib.util
191
  import torch
192
  import torchaudio
193
  from transformers import AutoModel, AutoProcessor
 
203
  device = "cuda" if torch.cuda.is_available() else "cpu"
204
  dtype = torch.bfloat16 if device == "cuda" else torch.float32
205
 
206
+ def resolve_attn_implementation() -> str:
207
+ # Prefer FlashAttention 2 when package + device conditions are met.
208
+ if (
209
+ device == "cuda"
210
+ and importlib.util.find_spec("flash_attn") is not None
211
+ and dtype in {torch.float16, torch.bfloat16}
212
+ ):
213
+ major, _ = torch.cuda.get_device_capability()
214
+ if major >= 8:
215
+ return "flash_attention_2"
216
+
217
+ # CUDA fallback: use PyTorch SDPA kernels.
218
+ if device == "cuda":
219
+ return "sdpa"
220
+
221
+ # CPU fallback.
222
+ return "eager"
223
+
224
+
225
+ attn_implementation = resolve_attn_implementation()
226
+ print(f"[INFO] Using attn_implementation={attn_implementation}")
227
+
228
  processor = AutoProcessor.from_pretrained(
229
  pretrained_model_name_or_path,
230
  trust_remote_code=True,
 
261
  model = AutoModel.from_pretrained(
262
  pretrained_model_name_or_path,
263
  trust_remote_code=True,
264
+ # If FlashAttention 2 is installed, you can set attn_implementation="flash_attention_2"
265
+ attn_implementation=attn_implementation,
266
  torch_dtype=dtype,
267
  ).to(device)
268
  model.eval()
269
 
270
  batch_size = 1
271
 
 
272
  save_dir = Path("inference_root")
273
  save_dir.mkdir(exist_ok=True, parents=True)
274
  sample_idx = 0
 
298
  MOSS-TTS supports continuation-based cloning: provide a prefix audio clip in the assistant message, and make sure the **prefix transcript** is included in the text. The model continues in the same speaker identity and style.
299
 
300
  ```python
 
301
  from pathlib import Path
302
+ import importlib.util
303
  import torch
304
  import torchaudio
305
  from transformers import AutoModel, AutoProcessor
 
315
  device = "cuda" if torch.cuda.is_available() else "cpu"
316
  dtype = torch.bfloat16 if device == "cuda" else torch.float32
317
 
318
+ def resolve_attn_implementation() -> str:
319
+ # Prefer FlashAttention 2 when package + device conditions are met.
320
+ if (
321
+ device == "cuda"
322
+ and importlib.util.find_spec("flash_attn") is not None
323
+ and dtype in {torch.float16, torch.bfloat16}
324
+ ):
325
+ major, _ = torch.cuda.get_device_capability()
326
+ if major >= 8:
327
+ return "flash_attention_2"
328
+
329
+ # CUDA fallback: use PyTorch SDPA kernels.
330
+ if device == "cuda":
331
+ return "sdpa"
332
+
333
+ # CPU fallback.
334
+ return "eager"
335
+
336
+
337
+ attn_implementation = resolve_attn_implementation()
338
+ print(f"[INFO] Using attn_implementation={attn_implementation}")
339
+
340
  processor = AutoProcessor.from_pretrained(
341
  pretrained_model_name_or_path,
342
  trust_remote_code=True
 
366
  model = AutoModel.from_pretrained(
367
  pretrained_model_name_or_path,
368
  trust_remote_code=True,
369
+ # If FlashAttention 2 is installed, you can set attn_implementation="flash_attention_2"
370
+ attn_implementation=attn_implementation,
371
  torch_dtype=dtype,
372
  ).to(device)
373
  model.eval()
374
 
375
  batch_size = 1
376
 
 
377
  save_dir = Path("inference_root")
378
  save_dir.mkdir(exist_ok=True, parents=True)
379
  sample_idx = 0