YWMditto committed on
Commit
d2d6a31
·
1 Parent(s): a25e3ca

update readme

Browse files
Files changed (1) hide show
  1. README.md +24 -1
README.md CHANGED
@@ -152,6 +152,7 @@ MOSS-TTSD uses a **continuation** workflow: provide reference audio for each spe
152
 
153
  ```python
154
  from pathlib import Path
 
155
  import torch
156
  import torchaudio
157
  from transformers import AutoModel, AutoProcessor
@@ -166,6 +167,28 @@ pretrained_model_name_or_path = "OpenMOSS-Team/MOSS-TTSD-v1.0"
166
  device = "cuda" if torch.cuda.is_available() else "cpu"
167
  dtype = torch.bfloat16 if device == "cuda" else torch.float32
168
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
  processor = AutoProcessor.from_pretrained(
170
  pretrained_model_name_or_path,
171
  trust_remote_code=True,
@@ -176,7 +199,7 @@ model = AutoModel.from_pretrained(
176
  pretrained_model_name_or_path,
177
  trust_remote_code=True,
178
  # If FlashAttention 2 is installed, you can set attn_implementation="flash_attention_2"
179
- attn_implementation="sdpa",
180
  torch_dtype=dtype,
181
  ).to(device)
182
  model.eval()
 
152
 
153
  ```python
154
  from pathlib import Path
155
+ import importlib.util
156
  import torch
157
  import torchaudio
158
  from transformers import AutoModel, AutoProcessor
 
167
  device = "cuda" if torch.cuda.is_available() else "cpu"
168
  dtype = torch.bfloat16 if device == "cuda" else torch.float32
169
 
170
+ def resolve_attn_implementation() -> str:
171
+ # Prefer FlashAttention 2 when package + device conditions are met.
172
+ if (
173
+ device == "cuda"
174
+ and importlib.util.find_spec("flash_attn") is not None
175
+ and dtype in {torch.float16, torch.bfloat16}
176
+ ):
177
+ major, _ = torch.cuda.get_device_capability()
178
+ if major >= 8:
179
+ return "flash_attention_2"
180
+
181
+ # CUDA fallback: use PyTorch SDPA kernels.
182
+ if device == "cuda":
183
+ return "sdpa"
184
+
185
+ # CPU fallback.
186
+ return "eager"
187
+
188
+
189
+ attn_implementation = resolve_attn_implementation()
190
+ print(f"[INFO] Using attn_implementation={attn_implementation}")
191
+
192
  processor = AutoProcessor.from_pretrained(
193
  pretrained_model_name_or_path,
194
  trust_remote_code=True,
 
199
  pretrained_model_name_or_path,
200
  trust_remote_code=True,
201
  # If FlashAttention 2 is installed, you can set attn_implementation="flash_attention_2"
202
+ attn_implementation=attn_implementation,
203
  torch_dtype=dtype,
204
  ).to(device)
205
  model.eval()