SDPA first on Blackwell, FA2 only for Ampere/Hopper, txt caption support
train_engine.py  CHANGED  +14 -16
```diff
@@ -541,37 +541,35 @@ def _ensure_acestep_imports():
 
 
 def _attn_candidates(device: str) -> List[str]:
-    """
+    """SDPA -> FA2 -> eager, filtered by availability.
 
-    ...
+    SDPA is preferred (faster on Blackwell SM12.0, native cuDNN).
+    FA2 is fallback for older GPUs where SDPA is slower.
+    On CPU, only SDPA and eager are tried.
     """
     candidates = []
     if device.startswith("cuda"):
+        candidates.append("sdpa")
         try:
             import flash_attn  # noqa: F401
             dev_idx = int(device.split(":")[1]) if ":" in device else 0
             props = torch.cuda.get_device_properties(dev_idx)
-            if props.major >= 8:
-                ...
-                    props.major, props.minor,
-                )
+            if props.major >= 8 and props.major < 12:
+                # FA2 is faster on Ampere/Hopper (SM 8.x-9.x), slower on Blackwell (SM 12.x)
+                candidates.insert(0, "flash_attention_2")
+                logger.info("FA2 prioritized (compute %d.%d, Ampere/Hopper)", props.major, props.minor)
             else:
-                logger.info(
-                    "flash_attention_2 skipped: compute %d.%d < 8.0 (need Ampere+)",
-                    props.major, props.minor,
-                )
+                logger.info("FA2 available but SDPA preferred (compute %d.%d)", props.major, props.minor)
         except ImportError:
             logger.info("flash_attention_2 skipped: flash_attn package not installed")
         except Exception as exc:
             logger.info("flash_attention_2 skipped: %s", exc)
     else:
+        candidates.append("sdpa")
         logger.info("flash_attention_2 skipped: device is %s (not CUDA)", device)
-    ...
+    if "eager" not in candidates:
+        candidates.append("eager")
+    return list(dict.fromkeys(candidates))
 
 
 def load_model_for_training(
```
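For orientation, the ordered candidate list is meant to be consumed by a loader that tries each `attn_implementation` in turn and falls back when a backend is rejected. Below is a minimal sketch of that pattern, assuming a transformers-style `from_pretrained` API; `load_with_best_attn`, the exception choices, and the model loading call are illustrative, not code from this commit:

```python
# Hypothetical consumer of _attn_candidates (not part of this diff).
# Tries each attention backend in priority order; transformers raises
# ValueError/ImportError when a backend is unsupported or not installed.
from transformers import AutoModel

def load_with_best_attn(model_id: str, device: str):
    last_exc = None
    for impl in _attn_candidates(device):  # e.g. ["sdpa", "eager"] on Blackwell
        try:
            model = AutoModel.from_pretrained(model_id, attn_implementation=impl)
            return model.to(device)
        except (ValueError, ImportError) as exc:
            last_exc = exc  # backend rejected; try the next candidate
    raise RuntimeError(f"no usable attention backend: {last_exc}")
```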
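The new docstring credits SDPA's speed on Blackwell to the native cuDNN kernel. Whether that kernel can actually serve a call on the current GPU can be probed directly; a sketch assuming a recent PyTorch (the `SDPBackend.CUDNN_ATTENTION` enum appeared around 2.4), again not part of the commit:

```python
# Illustrative probe: restrict SDPA to the cuDNN backend and see whether a
# tiny attention call succeeds. PyTorch raises RuntimeError when no allowed
# kernel can handle the inputs.
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

def cudnn_sdpa_available(device: str = "cuda") -> bool:
    if not torch.cuda.is_available():
        return False
    # cuDNN attention wants half precision; shape is (batch, heads, seq, dim)
    q = torch.randn(1, 8, 128, 64, device=device, dtype=torch.float16)
    try:
        with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
            F.scaled_dot_product_attention(q, q, q)
        return True
    except RuntimeError:
        return False
```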