Spaces:
Running
on
Zero
Running
on
Zero
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| import torch | |
@dataclass(frozen=True)
class Precision:
    """Pair of dtypes used for a model run.

    Instances are immutable; build a new one instead of mutating fields.
    """

    # dtype used for the model's main computation.
    compute: torch.dtype
    # dtype used for the logits.
    logits: torch.dtype
def resolve_precision(kind: str | None, device: torch.device | str) -> Precision:
    """Choose the compute and logits dtypes for running on ``device``.

    Args:
        kind: Requested precision, matched case-insensitively with surrounding
            whitespace ignored. ``None``, ``""`` or ``"auto"`` select bfloat16
            on CUDA devices and float32 elsewhere. Also accepted:
            ``"bfloat16"`` / ``"bf16"`` and ``"float32"`` / ``"fp32"`` /
            ``"float"``.
        device: Target device, either a ``torch.device`` or a device string
            such as ``"cuda:0"`` or ``"cpu"``.

    Returns:
        A ``Precision`` whose ``compute`` dtype is bfloat16 only when bfloat16
        is requested (explicitly or via ``"auto"``) on a CUDA device, and
        float32 otherwise. ``logits`` is always float32.

    Raises:
        ValueError: If ``kind`` is not one of the accepted names.
    """
    # Map every accepted spelling onto one canonical name.
    aliases = {
        "auto": "auto",
        "bfloat16": "bfloat16",
        "bf16": "bfloat16",
        "float32": "float32",
        "fp32": "float32",
        "float": "float32",
    }
    # torch.device() accepts both strings and existing torch.device objects.
    device = torch.device(device)
    normalized = aliases.get((kind or "auto").strip().lower())
    if normalized is None:
        raise ValueError(f"Unsupported dtype '{kind}'")

    on_cuda = device.type == "cuda"
    if normalized == "auto":
        normalized = "bfloat16" if on_cuda else "float32"

    # bfloat16 is only honoured on CUDA; every other device silently falls
    # back to float32, even when bfloat16 was requested explicitly.
    compute = (
        torch.bfloat16 if normalized == "bfloat16" and on_cuda else torch.float32
    )
    # Logits are kept in float32 regardless of the compute dtype
    # (presumably for numerical stability downstream — confirm with callers).
    return Precision(compute=compute, logits=torch.float32)