File size: 725 Bytes
1315cad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
from __future__ import annotations

from dataclasses import dataclass

import torch


@dataclass(frozen=True)
class Precision:
    """Immutable pair of torch dtypes describing a precision configuration.

    The dataclass is frozen, so fields cannot be reassigned after
    construction. Instances are produced by ``resolve_precision``.
    """

    # NOTE(review): presumably the dtype used for the main model computation
    # (weights/activations) -- confirm against callers.
    compute: torch.dtype
    # NOTE(review): presumably the dtype logits are held in; the resolver in
    # this file always sets it to torch.float32 -- confirm intended use.
    logits: torch.dtype


def resolve_precision(kind: str | None, device: torch.device) -> Precision:
    """Translate a precision name into a concrete ``Precision`` for *device*.

    Args:
        kind: Requested precision name, matched case-insensitively. Accepted
            values are ``"auto"``, ``"bfloat16"`` and ``"float32"``. ``None``
            or an empty string is treated as ``"auto"``.
        device: Target device. Only ``device.type`` is inspected, and only
            for the ``"auto"`` and ``"bfloat16"`` requests.

    Returns:
        A ``Precision`` whose ``logits`` dtype is always ``torch.float32``.
        The ``compute`` dtype is ``torch.bfloat16`` only when bfloat16 was
        requested (directly or through ``"auto"``) and the device type is
        ``"cuda"``; in every other accepted case it is ``torch.float32``.

    Raises:
        ValueError: If *kind* is not one of the accepted names.
    """
    requested = kind.lower() if kind else "auto"

    # "auto" prefers bfloat16 on CUDA devices and float32 everywhere else.
    if requested == "auto":
        requested = "bfloat16" if device.type == "cuda" else "float32"

    if requested == "float32":
        return Precision(compute=torch.float32, logits=torch.float32)

    if requested != "bfloat16":
        raise ValueError(f"Unsupported dtype '{kind}'")

    # An explicit bfloat16 request on a non-CUDA device falls back to
    # float32 compute instead of raising.
    if device.type == "cuda":
        compute_dtype = torch.bfloat16
    else:
        compute_dtype = torch.float32
    return Precision(compute=compute_dtype, logits=torch.float32)