Spaces:
Running
on
Zero
Running
on
Zero
File size: 725 Bytes
1315cad |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 |
from __future__ import annotations
from dataclasses import dataclass
import torch
@dataclass(frozen=True)
class Precision:
    """Immutable pair of dtypes produced by ``resolve_precision``.

    Being a frozen dataclass, instances reject attribute assignment after
    construction and compare/hash by field values.
    """

    # dtype selected for the main computation path.
    compute: torch.dtype
    # dtype selected for logits (resolve_precision always chooses float32 here).
    logits: torch.dtype
def resolve_precision(kind: str | None, device: torch.device | str) -> Precision:
    """Map a precision name to the dtypes used for compute and logits.

    Args:
        kind: Precision name, matched case-insensitively with surrounding
            whitespace ignored. ``None``, empty/blank strings and ``"auto"``
            pick bfloat16 on CUDA and float32 elsewhere. ``"bfloat16"``
            (alias ``"bf16"``) and ``"float32"`` (aliases ``"fp32"``,
            ``"float"``) select a mode explicitly.
        device: Target device, either a ``torch.device`` or any value that
            ``torch.device(...)`` accepts (e.g. ``"cuda"``, ``"cpu"``).

    Returns:
        A ``Precision`` whose ``logits`` dtype is always ``torch.float32``.
        Requesting bfloat16 on a non-CUDA device silently falls back to
        float32 compute rather than raising.

    Raises:
        ValueError: If ``kind`` is not a recognized precision name.
    """
    # Accept plain strings such as "cuda" as well as torch.device objects.
    if not isinstance(device, torch.device):
        device = torch.device(device)
    on_cuda = device.type == "cuda"

    # Blank input behaves like None/"" did originally: fall back to "auto".
    normalized = (kind or "").strip().lower() or "auto"
    aliases = {"bf16": "bfloat16", "fp32": "float32", "float": "float32"}
    normalized = aliases.get(normalized, normalized)

    if normalized == "auto":
        normalized = "bfloat16" if on_cuda else "float32"
    if normalized == "bfloat16":
        # bfloat16 compute is only enabled on CUDA; every other device type
        # degrades to float32 (presumably because bf16 is slow or unsupported
        # there -- NOTE(review): confirm intent).
        compute = torch.bfloat16 if on_cuda else torch.float32
        # Logits stay float32 regardless of compute dtype.
        return Precision(compute=compute, logits=torch.float32)
    if normalized == "float32":
        return Precision(compute=torch.float32, logits=torch.float32)
    raise ValueError(f"Unsupported dtype '{kind}'")
|