YWMditto committed on
Commit
1b228e9
·
1 Parent(s): 49ca09f

update readme

Browse files
Files changed (1) hide show
  1. README.md +26 -3
README.md CHANGED
@@ -142,8 +142,8 @@ Notes:
142
  ### Basic Usage
143
 
144
  ```python
145
- import os
146
  from pathlib import Path
 
147
  import torch
148
  import torchaudio
149
  from transformers import AutoModel, AutoProcessor
@@ -159,6 +159,28 @@ pretrained_model_name_or_path = "OpenMOSS-Team/MOSS-SoundEffect"
159
  device = "cuda" if torch.cuda.is_available() else "cpu"
160
  dtype = torch.bfloat16 if device == "cuda" else torch.float32
161
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
  processor = AutoProcessor.from_pretrained(
163
  pretrained_model_name_or_path,
164
  trust_remote_code=True,
@@ -176,14 +198,14 @@ conversations = [
176
  model = AutoModel.from_pretrained(
177
  pretrained_model_name_or_path,
178
  trust_remote_code=True,
179
- attn_implementation="sdpa",
 
180
  torch_dtype=dtype,
181
  ).to(device)
182
  model.eval()
183
 
184
  batch_size = 1
185
 
186
- messages = []
187
  save_dir = Path("inference_root")
188
  save_dir.mkdir(exist_ok=True, parents=True)
189
  sample_idx = 0
@@ -205,6 +227,7 @@ with torch.no_grad():
205
  out_path = save_dir / f"sample{sample_idx}.wav"
206
  sample_idx += 1
207
  torchaudio.save(out_path, audio.unsqueeze(0), processor.model_config.sampling_rate)
 
208
  ```
209
 
210
  ### Input Types
 
142
  ### Basic Usage
143
 
144
  ```python
 
145
  from pathlib import Path
146
+ import importlib.util
147
  import torch
148
  import torchaudio
149
  from transformers import AutoModel, AutoProcessor
 
159
  device = "cuda" if torch.cuda.is_available() else "cpu"
160
  dtype = torch.bfloat16 if device == "cuda" else torch.float32
161
 
162
+ def resolve_attn_implementation() -> str:
163
+ # Prefer FlashAttention 2 when package + device conditions are met.
164
+ if (
165
+ device == "cuda"
166
+ and importlib.util.find_spec("flash_attn") is not None
167
+ and dtype in {torch.float16, torch.bfloat16}
168
+ ):
169
+ major, _ = torch.cuda.get_device_capability()
170
+ if major >= 8:
171
+ return "flash_attention_2"
172
+
173
+ # CUDA fallback: use PyTorch SDPA kernels.
174
+ if device == "cuda":
175
+ return "sdpa"
176
+
177
+ # CPU fallback.
178
+ return "eager"
179
+
180
+
181
+ attn_implementation = resolve_attn_implementation()
182
+ print(f"[INFO] Using attn_implementation={attn_implementation}")
183
+
184
  processor = AutoProcessor.from_pretrained(
185
  pretrained_model_name_or_path,
186
  trust_remote_code=True,
 
198
  model = AutoModel.from_pretrained(
199
  pretrained_model_name_or_path,
200
  trust_remote_code=True,
201
+ # If FlashAttention 2 is installed, you can set attn_implementation="flash_attention_2"
202
+ attn_implementation=attn_implementation,
203
  torch_dtype=dtype,
204
  ).to(device)
205
  model.eval()
206
 
207
  batch_size = 1
208
 
 
209
  save_dir = Path("inference_root")
210
  save_dir.mkdir(exist_ok=True, parents=True)
211
  sample_idx = 0
 
227
  out_path = save_dir / f"sample{sample_idx}.wav"
228
  sample_idx += 1
229
  torchaudio.save(out_path, audio.unsqueeze(0), processor.model_config.sampling_rate)
230
+
231
  ```
232
 
233
  ### Input Types