Spaces:
Sleeping
Sleeping
File size: 1,135 Bytes
5ec9e9d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 |
from typing import Optional, Tuple, Type
import torch
def _apply_torch_amp_shims() -> None:
"""
AST expects torch.amp.GradScaler/autocast (torch 2.3+); shim from torch.cuda.amp for 2.2.
"""
if not hasattr(torch.amp, "GradScaler") and hasattr(torch.cuda, "amp"):
torch.amp.GradScaler = torch.cuda.amp.GradScaler # type: ignore[attr-defined]
if not hasattr(torch.amp, "autocast") and hasattr(torch.cuda, "amp"):
torch.amp.autocast = torch.cuda.amp.autocast # type: ignore[attr-defined]
def load_ast_trainer() -> Tuple[Optional[Type[object]], Optional[Type[object]], Optional[Exception]]:
    """
    Import the optional ``adaptive_sparse_training`` trainer and config classes.

    The torch.amp compatibility shims are applied before importing, and any
    exception raised while shimming or importing is captured rather than
    propagated.

    Returns:
        ``(AdaptiveSparseTrainer, ASTConfig, None)`` when the imports succeed,
        otherwise ``(None, None, error)`` with the exception that was raised.
    """
    try:
        _apply_torch_amp_shims()
        from adaptive_sparse_training import AdaptiveSparseTrainer  # type: ignore
        from adaptive_sparse_training.config import ASTConfig  # type: ignore
    except Exception as load_error:  # pragma: no cover - optional dependency
        # Optional dependency: report the failure to the caller instead of raising.
        return None, None, load_error
    else:
        return AdaptiveSparseTrainer, ASTConfig, None
|