Spaces:
Running
Running
disable flash_sdp on CPU, force attn_implementation=sdpa for training
Browse files
app.py
CHANGED
|
@@ -335,6 +335,10 @@ time.sleep(2)
|
|
| 335 |
gc.collect()
|
| 336 |
|
| 337 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 338 |
import torchaudio
|
| 339 |
_orig = torchaudio.load
|
| 340 |
def _sf(p, *a, **kw):
|
|
@@ -371,10 +375,15 @@ try:
|
|
| 371 |
from acestep.training_v2.trainer_fixed import FixedLoRATrainer
|
| 372 |
from acestep.training_v2.configs import TrainingConfigV2, LoRAConfigV2
|
| 373 |
|
|
|
|
| 374 |
model = load_decoder_for_training(
|
| 375 |
checkpoint_dir="{ACE_CHECKPOINT_DIR}", variant="turbo",
|
| 376 |
device="cpu", precision="float32",
|
| 377 |
).float()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 378 |
|
| 379 |
trainer = FixedLoRATrainer(model,
|
| 380 |
LoRAConfigV2(r={rank}, alpha={rank}, dropout=0.0),
|
|
|
|
| 335 |
gc.collect()
|
| 336 |
|
| 337 |
try:
|
| 338 |
+
import torch
|
| 339 |
+
torch.backends.cuda.enable_flash_sdp(False)
|
| 340 |
+
os.environ["ATTN_BACKEND"] = "sdpa"
|
| 341 |
+
|
| 342 |
import torchaudio
|
| 343 |
_orig = torchaudio.load
|
| 344 |
def _sf(p, *a, **kw):
|
|
|
|
| 375 |
from acestep.training_v2.trainer_fixed import FixedLoRATrainer
|
| 376 |
from acestep.training_v2.configs import TrainingConfigV2, LoRAConfigV2
|
| 377 |
|
| 378 |
+
log(" Loading decoder (attn_implementation=sdpa)...")
|
| 379 |
model = load_decoder_for_training(
|
| 380 |
checkpoint_dir="{ACE_CHECKPOINT_DIR}", variant="turbo",
|
| 381 |
device="cpu", precision="float32",
|
| 382 |
).float()
|
| 383 |
+
for m in model.modules():
|
| 384 |
+
if hasattr(m, 'config') and hasattr(m.config, '_attn_implementation'):
|
| 385 |
+
m.config._attn_implementation = "sdpa"
|
| 386 |
+
log(" Decoder loaded, applying LoRA...")
|
| 387 |
|
| 388 |
trainer = FixedLoRATrainer(model,
|
| 389 |
LoRAConfigV2(r={rank}, alpha={rank}, dropout=0.0),
|