| import torch |
| from comfy import model_management |
|
|
| def string_to_dtype(s="none", mode=None): |
| s = s.lower().strip() |
| if s in ["default", "as-is"]: |
| return None |
| elif s in ["auto", "auto (comfy)"]: |
| if mode == "vae": |
| return model_management.vae_device() |
| elif mode == "text_encoder": |
| return model_management.text_encoder_dtype() |
| elif mode == "unet": |
| return model_management.unet_dtype() |
| else: |
| raise NotImplementedError(f"Unknown dtype mode '{mode}'") |
| elif s in ["none", "auto (hf)", "auto (hf/bnb)"]: |
| return None |
| elif s in ["fp32", "float32", "float"]: |
| return torch.float32 |
| elif s in ["bf16", "bfloat16"]: |
| return torch.bfloat16 |
| elif s in ["fp16", "float16", "half"]: |
| return torch.float16 |
| elif "fp8" in s or "float8" in s: |
| if "e5m2" in s: |
| return torch.float8_e5m2 |
| elif "e4m3" in s: |
| return torch.float8_e4m3fn |
| else: |
| raise NotImplementedError(f"Unknown 8bit dtype '{s}'") |
| elif "bnb" in s: |
| assert s in ["bnb8bit", "bnb4bit"], f"Unknown bnb mode '{s}'" |
| return s |
| elif s is None: |
| return None |
| else: |
| raise NotImplementedError(f"Unknown dtype '{s}'") |