from typing import Optional, Tuple, Type

import torch
| def _apply_torch_amp_shims() -> None: | |
| """ | |
| AST expects torch.amp.GradScaler/autocast (torch 2.3+); shim from torch.cuda.amp for 2.2. | |
| """ | |
| if not hasattr(torch.amp, "GradScaler") and hasattr(torch.cuda, "amp"): | |
| torch.amp.GradScaler = torch.cuda.amp.GradScaler # type: ignore[attr-defined] | |
| if not hasattr(torch.amp, "autocast") and hasattr(torch.cuda, "amp"): | |
| torch.amp.autocast = torch.cuda.amp.autocast # type: ignore[attr-defined] | |
def load_ast_trainer() -> Tuple[Optional[Type[object]], Optional[Type[object]], Optional[Exception]]:
    """
    Try to import AdaptiveSparseTrainer and ASTConfig from adaptive-sparse-training.

    Returns:
        (trainer_cls, config_cls, None) on success, otherwise
        (None, None, error) — any failure (missing package, shim error)
        is captured rather than raised, since the dependency is optional.
    """
    trainer_cls: Optional[Type[object]] = None
    config_cls: Optional[Type[object]] = None
    error: Optional[Exception] = None
    try:
        _apply_torch_amp_shims()
        from adaptive_sparse_training import AdaptiveSparseTrainer  # type: ignore
        from adaptive_sparse_training.config import ASTConfig  # type: ignore
        trainer_cls, config_cls = AdaptiveSparseTrainer, ASTConfig
    except Exception as exc:  # pragma: no cover - optional dependency
        error = exc
    return trainer_cls, config_cls, error