Dia2-2B / dia2 /core /precision.py
NariLabs's picture
Upload folder using huggingface_hub
1315cad verified
raw
history blame
725 Bytes
from __future__ import annotations
from dataclasses import dataclass
import torch
@dataclass(frozen=True)
class Precision:
    """Immutable pair of torch dtypes selected for a run.

    Instances are frozen: fields cannot be reassigned after construction.
    """

    # dtype used for the main computation.
    compute: torch.dtype
    # dtype used for logits.
    logits: torch.dtype
def resolve_precision(kind: str | None, device: torch.device) -> Precision:
    """Map a precision name and a target device onto a ``Precision``.

    Args:
        kind: One of ``"auto"``, ``"bfloat16"`` or ``"float32"``, matched
            case-insensitively. ``None`` or an empty string means ``"auto"``.
        device: Device the computation will run on; only ``device.type``
            is inspected, and only when ``kind`` is not ``"float32"``.

    Returns:
        A ``Precision`` whose ``logits`` dtype is always ``torch.float32``.
        ``compute`` is ``torch.bfloat16`` only when bfloat16 is requested
        (explicitly or through ``"auto"``) and the device type is
        ``"cuda"``; in every other accepted case it is ``torch.float32``.

    Raises:
        ValueError: If ``kind`` is not one of the accepted names.
    """
    requested = kind.lower() if kind else "auto"
    if requested not in ("auto", "bfloat16", "float32"):
        raise ValueError(f"Unsupported dtype '{kind}'")

    # "auto" and "bfloat16" resolve identically: bfloat16 on CUDA, float32
    # elsewhere. The short-circuit keeps `device` untouched for "float32".
    use_bfloat16 = requested != "float32" and device.type == "cuda"
    compute_dtype = torch.bfloat16 if use_bfloat16 else torch.float32
    return Precision(compute=compute_dtype, logits=torch.float32)