Spaces:
Configuration error
Configuration error
| import torch | |
| from comfy import model_management | |
def string_to_dtype(s="none", mode=None):
    """Translate a user-facing dtype name into a torch dtype or marker value.

    Args:
        s: Dtype name. Matched case-insensitively with surrounding whitespace
            stripped. ``None`` is accepted and yields ``None``.
        mode: Component the dtype is requested for. Only consulted when ``s``
            is ``"auto"`` / ``"auto (comfy)"``; must then be ``"vae"``,
            ``"text_encoder"`` or ``"unet"``.

    Returns:
        ``None`` for the "leave as-is" names (``"default"``, ``"as-is"``,
        ``"none"``, ``"auto (hf)"``, ``"auto (hf/bnb)"``) and for ``s=None``;
        a ``torch.dtype`` for the fp32/bf16/fp16/fp8 names; whatever the
        matching ``model_management.*_dtype()`` call returns for the auto
        names; or the normalized string ``"bnb8bit"`` / ``"bnb4bit"``.

    Raises:
        NotImplementedError: Unknown dtype name, unknown fp8 variant, or an
            unknown ``mode`` for the auto names.
        AssertionError: ``s`` contains ``"bnb"`` but is not a supported mode.
    """
    # Must be checked before normalizing: None has no .lower().
    if s is None:
        return None

    s = s.lower().strip()

    # Names that mean "do not force a dtype".
    if s in ("default", "as-is", "none", "auto (hf)", "auto (hf/bnb)"):
        return None

    if s in ("auto", "auto (comfy)"):
        if mode == "vae":
            # Return the VAE *dtype*; vae_device() would hand back a device.
            return model_management.vae_dtype()
        if mode == "text_encoder":
            return model_management.text_encoder_dtype()
        if mode == "unet":
            return model_management.unet_dtype()
        raise NotImplementedError(f"Unknown dtype mode '{mode}'")

    if s in ("fp32", "float32", "float"):
        return torch.float32
    if s in ("bf16", "bfloat16"):
        return torch.bfloat16
    if s in ("fp16", "float16", "half"):
        return torch.float16

    # fp8 names are matched by substring so variants such as
    # "fp8_e4m3fn" or "float8_e5m2" all resolve.
    if "fp8" in s or "float8" in s:
        if "e5m2" in s:
            return torch.float8_e5m2
        if "e4m3" in s:
            return torch.float8_e4m3fn
        raise NotImplementedError(f"Unknown 8bit dtype '{s}'")

    if "bnb" in s:
        # Raised explicitly instead of via `assert` so the check survives
        # `python -O`; AssertionError is kept for backward compatibility.
        if s not in ("bnb8bit", "bnb4bit"):
            raise AssertionError(f"Unknown bnb mode '{s}'")
        return s

    raise NotImplementedError(f"Unknown dtype '{s}'")