Nekochu committed
Commit 04ccf32 · 1 Parent(s): 917e4ed

SDPA first on Blackwell, FA2 only for Ampere/Hopper, txt caption support

Files changed (1): train_engine.py (+14 -16)
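
Before the diff, a quick self-contained sketch of the candidate orderings this change is meant to produce. expected_order is a hypothetical stand-in that mirrors the policy in the diff below (the real function queries torch.cuda and flash_attn); the printed results assume flash_attn is installed on the CUDA paths.

from typing import List, Optional

def expected_order(sm_major: Optional[int], has_flash_attn: bool) -> List[str]:
    # Stand-in for the new _attn_candidates() policy; sm_major=None means CPU.
    order = ["sdpa"]
    if sm_major is not None and has_flash_attn and 8 <= sm_major < 12:
        order.insert(0, "flash_attention_2")  # Ampere/Hopper: FA2 stays first
    order.append("eager")
    return order

print(expected_order(9, True))      # Hopper    -> ['flash_attention_2', 'sdpa', 'eager']
print(expected_order(12, True))     # Blackwell -> ['sdpa', 'eager']
print(expected_order(None, False))  # CPU       -> ['sdpa', 'eager']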
train_engine.py CHANGED
@@ -541,37 +541,35 @@ def _ensure_acestep_imports():
 
 
 def _attn_candidates(device: str) -> List[str]:
-    """FA2 -> SDPA -> eager, filtered by availability.
+    """SDPA -> FA2 -> eager, filtered by availability.
 
-    On CUDA with flash_attn installed and compute capability >= 8.0,
-    flash_attention_2 is tried first. On CPU, flash_attention_2 is
-    always skipped (it requires CUDA).
+    SDPA is preferred (faster on Blackwell SM12.0, native cuDNN).
+    FA2 is fallback for older GPUs where SDPA is slower.
+    On CPU, only SDPA and eager are tried.
     """
     candidates = []
     if device.startswith("cuda"):
+        candidates.append("sdpa")
         try:
             import flash_attn  # noqa: F401
             dev_idx = int(device.split(":")[1]) if ":" in device else 0
             props = torch.cuda.get_device_properties(dev_idx)
-            if props.major >= 8:
-                candidates.append("flash_attention_2")
-                logger.info(
-                    "flash_attention_2 available (compute %d.%d, flash_attn installed)",
-                    props.major, props.minor,
-                )
+            if props.major >= 8 and props.major < 12:
+                # FA2 is faster on Ampere/Hopper (SM 8.x-9.x), slower on Blackwell (SM 12.x)
+                candidates.insert(0, "flash_attention_2")
+                logger.info("FA2 prioritized (compute %d.%d, Ampere/Hopper)", props.major, props.minor)
             else:
-                logger.info(
-                    "flash_attention_2 skipped: compute %d.%d < 8.0 (need Ampere+)",
-                    props.major, props.minor,
-                )
+                logger.info("FA2 available but SDPA preferred (compute %d.%d)", props.major, props.minor)
         except ImportError:
             logger.info("flash_attention_2 skipped: flash_attn package not installed")
         except Exception as exc:
             logger.info("flash_attention_2 skipped: %s", exc)
     else:
+        candidates.append("sdpa")
         logger.info("flash_attention_2 skipped: device is %s (not CUDA)", device)
-    candidates.extend(["sdpa", "eager"])
-    return candidates
+    if "eager" not in candidates:
+        candidates.append("eager")
+    return list(dict.fromkeys(candidates))
 
 
 def load_model_for_training(
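
This hunk does not show how load_model_for_training consumes the list. A minimal consumption sketch, assuming a Transformers-style loader that accepts attn_implementation; the loader call, model name, and dtype choice here are illustrative placeholders, not taken from this commit.

import logging

import torch
from transformers import AutoModel

logger = logging.getLogger(__name__)


def load_with_fallback(model_name: str, device: str):
    # Try each backend in the order _attn_candidates() suggests, falling back
    # to the next candidate if the load fails (e.g. flash_attn kernel errors).
    for impl in _attn_candidates(device):
        try:
            model = AutoModel.from_pretrained(
                model_name,
                attn_implementation=impl,
                torch_dtype=torch.bfloat16 if device.startswith("cuda") else torch.float32,
            )
            logger.info("loaded %s with attn_implementation=%s", model_name, impl)
            return model.to(device)
        except Exception as exc:
            logger.warning("attn_implementation=%s failed: %s", impl, exc)
    raise RuntimeError("no attention implementation could be loaded")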