update readme
Browse files
README.md
CHANGED
|
@@ -142,8 +142,8 @@ Notes:
|
|
| 142 |
### Basic Usage
|
| 143 |
|
| 144 |
```python
|
| 145 |
-
import os
|
| 146 |
from pathlib import Path
|
|
|
|
| 147 |
import torch
|
| 148 |
import torchaudio
|
| 149 |
from transformers import AutoModel, AutoProcessor
|
|
@@ -159,6 +159,28 @@ pretrained_model_name_or_path = "OpenMOSS-Team/MOSS-SoundEffect"
|
|
| 159 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 160 |
dtype = torch.bfloat16 if device == "cuda" else torch.float32
|
| 161 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 162 |
processor = AutoProcessor.from_pretrained(
|
| 163 |
pretrained_model_name_or_path,
|
| 164 |
trust_remote_code=True,
|
|
@@ -176,14 +198,14 @@ conversations = [
|
|
| 176 |
model = AutoModel.from_pretrained(
|
| 177 |
pretrained_model_name_or_path,
|
| 178 |
trust_remote_code=True,
|
| 179 |
-
attn_implementation="
|
|
|
|
| 180 |
torch_dtype=dtype,
|
| 181 |
).to(device)
|
| 182 |
model.eval()
|
| 183 |
|
| 184 |
batch_size = 1
|
| 185 |
|
| 186 |
-
messages = []
|
| 187 |
save_dir = Path("inference_root")
|
| 188 |
save_dir.mkdir(exist_ok=True, parents=True)
|
| 189 |
sample_idx = 0
|
|
@@ -205,6 +227,7 @@ with torch.no_grad():
|
|
| 205 |
out_path = save_dir / f"sample{sample_idx}.wav"
|
| 206 |
sample_idx += 1
|
| 207 |
torchaudio.save(out_path, audio.unsqueeze(0), processor.model_config.sampling_rate)
|
|
|
|
| 208 |
```
|
| 209 |
|
| 210 |
### Input Types
|
|
|
|
| 142 |
### Basic Usage
|
| 143 |
|
| 144 |
```python
|
|
|
|
| 145 |
from pathlib import Path
|
| 146 |
+
import importlib.util
|
| 147 |
import torch
|
| 148 |
import torchaudio
|
| 149 |
from transformers import AutoModel, AutoProcessor
|
|
|
|
| 159 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 160 |
dtype = torch.bfloat16 if device == "cuda" else torch.float32
|
| 161 |
|
| 162 |
+
def resolve_attn_implementation() -> str:
|
| 163 |
+
# Prefer FlashAttention 2 when package + device conditions are met.
|
| 164 |
+
if (
|
| 165 |
+
device == "cuda"
|
| 166 |
+
and importlib.util.find_spec("flash_attn") is not None
|
| 167 |
+
and dtype in {torch.float16, torch.bfloat16}
|
| 168 |
+
):
|
| 169 |
+
major, _ = torch.cuda.get_device_capability()
|
| 170 |
+
if major >= 8:
|
| 171 |
+
return "flash_attention_2"
|
| 172 |
+
|
| 173 |
+
# CUDA fallback: use PyTorch SDPA kernels.
|
| 174 |
+
if device == "cuda":
|
| 175 |
+
return "sdpa"
|
| 176 |
+
|
| 177 |
+
# CPU fallback.
|
| 178 |
+
return "eager"
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
attn_implementation = resolve_attn_implementation()
|
| 182 |
+
print(f"[INFO] Using attn_implementation={attn_implementation}")
|
| 183 |
+
|
| 184 |
processor = AutoProcessor.from_pretrained(
|
| 185 |
pretrained_model_name_or_path,
|
| 186 |
trust_remote_code=True,
|
|
|
|
| 198 |
model = AutoModel.from_pretrained(
|
| 199 |
pretrained_model_name_or_path,
|
| 200 |
trust_remote_code=True,
|
| 201 |
+
# If FlashAttention 2 is installed, you can set attn_implementation="flash_attention_2"
|
| 202 |
+
attn_implementation=attn_implementation,
|
| 203 |
torch_dtype=dtype,
|
| 204 |
).to(device)
|
| 205 |
model.eval()
|
| 206 |
|
| 207 |
batch_size = 1
|
| 208 |
|
|
|
|
| 209 |
save_dir = Path("inference_root")
|
| 210 |
save_dir.mkdir(exist_ok=True, parents=True)
|
| 211 |
sample_idx = 0
|
|
|
|
| 227 |
out_path = save_dir / f"sample{sample_idx}.wav"
|
| 228 |
sample_idx += 1
|
| 229 |
torchaudio.save(out_path, audio.unsqueeze(0), processor.model_config.sampling_rate)
|
| 230 |
+
|
| 231 |
```
|
| 232 |
|
| 233 |
### Input Types
|