File size: 1,135 Bytes
5ec9e9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from typing import Optional, Tuple, Type

import torch


def _apply_torch_amp_shims() -> None:
    """
    Expose ``torch.amp.GradScaler`` / ``torch.amp.autocast`` on older torch.

    AST expects the device-generic ``torch.amp`` API (``GradScaler`` moved there
    in torch 2.3); on torch 2.2 the scaler only exists as
    ``torch.cuda.amp.GradScaler``. The legacy class takes ``init_scale`` as its
    first positional parameter instead of ``device``, so aliasing it directly
    would misroute ``torch.amp.GradScaler("cuda", ...)`` calls. It is therefore
    wrapped in a subclass that accepts (and drops) the leading ``device``.

    Safe to call repeatedly: attributes that already exist are never replaced.
    """
    amp = getattr(torch, "amp", None)
    cuda_amp = getattr(torch.cuda, "amp", None)
    if amp is None or cuda_amp is None:
        # Nothing to shim onto / from; the downstream import will report it.
        return

    if not hasattr(amp, "GradScaler"):
        legacy_scaler = cuda_amp.GradScaler

        class GradScaler(legacy_scaler):  # type: ignore[misc, valid-type]
            """``torch.cuda.amp.GradScaler`` with the torch>=2.3 ``(device, ...)`` signature."""

            def __init__(self, device="cuda", *args, **kwargs) -> None:
                # The legacy scaler is CUDA-only; ``device`` is accepted purely so
                # call sites written against torch.amp.GradScaler keep working.
                super().__init__(*args, **kwargs)

        amp.GradScaler = GradScaler  # type: ignore[attr-defined]

    if not hasattr(amp, "autocast"):
        # torch.autocast already takes ``device_type`` first, matching
        # torch.amp.autocast; the CUDA-only context manager is a last resort.
        amp.autocast = getattr(torch, "autocast", cuda_amp.autocast)  # type: ignore[attr-defined]


def load_ast_trainer() -> Tuple[Optional[Type[object]], Optional[Type[object]], Optional[Exception]]:
    """
    Import the optional adaptive-sparse-training entry points.

    The torch AMP shims are installed first, then ``AdaptiveSparseTrainer`` and
    ``ASTConfig`` are imported. Any exception raised along the way (including by
    the shims) is captured and handed back instead of propagating.

    Returns:
        ``(AdaptiveSparseTrainer, ASTConfig, None)`` when both imports succeed,
        otherwise ``(None, None, exc)`` with the exception that was raised.
    """
    try:
        _apply_torch_amp_shims()
        from adaptive_sparse_training import AdaptiveSparseTrainer  # type: ignore
        from adaptive_sparse_training.config import ASTConfig  # type: ignore
    except Exception as exc:  # pragma: no cover - optional dependency
        # Deliberately broad: importing an optional third-party package can fail
        # in more ways than ImportError, and callers decide how to react.
        return None, None, exc

    return AdaptiveSparseTrainer, ASTConfig, None