Nekochu commited on
Commit
13f9406
·
1 Parent(s): 6cee8bd

disable flash_sdp on CPU, force attn_implementation=sdpa for training

Browse files
Files changed (1) hide show
  1. app.py +9 -0
app.py CHANGED
@@ -335,6 +335,10 @@ time.sleep(2)
335
  gc.collect()
336
 
337
  try:
 
 
 
 
338
  import torchaudio
339
  _orig = torchaudio.load
340
  def _sf(p, *a, **kw):
@@ -371,10 +375,15 @@ try:
371
  from acestep.training_v2.trainer_fixed import FixedLoRATrainer
372
  from acestep.training_v2.configs import TrainingConfigV2, LoRAConfigV2
373
 
 
374
  model = load_decoder_for_training(
375
  checkpoint_dir="{ACE_CHECKPOINT_DIR}", variant="turbo",
376
  device="cpu", precision="float32",
377
  ).float()
 
 
 
 
378
 
379
  trainer = FixedLoRATrainer(model,
380
  LoRAConfigV2(r={rank}, alpha={rank}, dropout=0.0),
 
335
  gc.collect()
336
 
337
  try:
338
+ import torch
339
+ torch.backends.cuda.enable_flash_sdp(False)
340
+ os.environ["ATTN_BACKEND"] = "sdpa"
341
+
342
  import torchaudio
343
  _orig = torchaudio.load
344
  def _sf(p, *a, **kw):
 
375
  from acestep.training_v2.trainer_fixed import FixedLoRATrainer
376
  from acestep.training_v2.configs import TrainingConfigV2, LoRAConfigV2
377
 
378
+ log(" Loading decoder (attn_implementation=sdpa)...")
379
  model = load_decoder_for_training(
380
  checkpoint_dir="{ACE_CHECKPOINT_DIR}", variant="turbo",
381
  device="cpu", precision="float32",
382
  ).float()
383
+ for m in model.modules():
384
+ if hasattr(m, 'config') and hasattr(m.config, '_attn_implementation'):
385
+ m.config._attn_implementation = "sdpa"
386
+ log(" Decoder loaded, applying LoRA...")
387
 
388
  trainer = FixedLoRATrainer(model,
389
  LoRAConfigV2(r={rank}, alpha={rank}, dropout=0.0),