gary-boon (Claude Opus 4.5) committed
Commit 62525b2 · 1 Parent(s): 9080f28
Add recommended_dtype to model configs
Each model now specifies its recommended dtype:
- codegen-350m: fp16
- code-llama-7b: fp16
- devstral-small: bf16 (required for Mistral models)
Model loader now uses recommended_dtype from config when
TORCH_DTYPE env var is not explicitly set. This ensures
Devstral automatically loads with bf16 without requiring
manual configuration.
Priority: TORCH_DTYPE env > model config > device default
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
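
The priority chain is simple to state as standalone code. A minimal sketch of the resolution order (the function name resolve_dtype and the device-default fallback values are illustrative, not part of this commit; get_model_config appears in the diff below):

import os

def resolve_dtype(model_config: dict, device_type: str) -> str:
    # 1. An explicit TORCH_DTYPE env var always wins
    env_dtype = os.environ.get("TORCH_DTYPE", "").lower()
    if env_dtype:
        return env_dtype
    # 2. Otherwise fall back to the model's recommended_dtype
    if model_config and "recommended_dtype" in model_config:
        return model_config["recommended_dtype"]
    # 3. Last resort: device default (float32 is the safe CPU choice)
    return "fp16" if device_type == "cuda" else "fp32"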
- backend/model_config.py +7 -3
- backend/model_service.py +12 -2
backend/model_config.py
CHANGED
@@ -22,6 +22,7 @@ class ModelConfig(TypedDict):
     requires_gpu: bool
     min_vram_gb: float
     min_ram_gb: float
+    recommended_dtype: str  # "fp16", "bf16", or "fp32"
 
 
 # Supported models registry
@@ -39,7 +40,8 @@ SUPPORTED_MODELS: Dict[str, ModelConfig] = {
         "attention_type": "multi_head",
         "requires_gpu": False,
         "min_vram_gb": 2.0,
-        "min_ram_gb": 4.0
+        "min_ram_gb": 4.0,
+        "recommended_dtype": "fp16"  # fp16 for GPU, fp32 for CPU
     },
     "code-llama-7b": {
         "hf_path": "codellama/CodeLlama-7b-hf",
@@ -54,7 +56,8 @@ SUPPORTED_MODELS: Dict[str, ModelConfig] = {
         "attention_type": "grouped_query",
         "requires_gpu": True,  # Strongly recommended for usable performance
         "min_vram_gb": 14.0,  # FP16 requires ~14GB VRAM
-        "min_ram_gb": 18.0
+        "min_ram_gb": 18.0,  # FP16 requires ~18GB RAM for CPU fallback
+        "recommended_dtype": "fp16"
     },
     "devstral-small": {
         "hf_path": "mistralai/Devstral-Small-2507",
@@ -69,7 +72,8 @@ SUPPORTED_MODELS: Dict[str, ModelConfig] = {
         "attention_type": "grouped_query",
         "requires_gpu": True,  # BF16 required, GPU strongly recommended
         "min_vram_gb": 48.0,  # BF16 requires ~48GB VRAM
-        "min_ram_gb": 96.0
+        "min_ram_gb": 96.0,  # BF16 requires ~96GB RAM for CPU fallback
+        "recommended_dtype": "bf16"  # Devstral requires bfloat16
     }
 }
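
The loader change below imports get_model_config from this module. That helper isn't shown in this diff; a minimal version consistent with the registry above would be a plain dictionary lookup (a sketch, assuming SUPPORTED_MODELS is keyed by model id):

from typing import Optional

def get_model_config(model_id: str) -> Optional[ModelConfig]:
    # Return the registry entry for a model id, or None if unknown
    return SUPPORTED_MODELS.get(model_id)
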
backend/model_service.py
CHANGED
@@ -154,8 +154,18 @@ class ModelManager:
         self.device = torch.device("cpu")
         device_name = "CPU"
 
-        # Determine dtype from environment or defaults
+        # Determine dtype from environment, model config, or defaults
         dtype_str = os.environ.get("TORCH_DTYPE", "").lower()
+
+        # If not set in env, use model's recommended dtype
+        if not dtype_str:
+            from .model_config import get_model_config
+            model_config = get_model_config(self.model_id)
+            if model_config and "recommended_dtype" in model_config:
+                dtype_str = model_config["recommended_dtype"]
+                logger.info(f"Using model's recommended dtype: {dtype_str}")
+
+        # Parse dtype string to torch dtype
         if dtype_str == "bf16" or dtype_str == "bfloat16":
             self.dtype = torch.bfloat16
             dtype_name = "bfloat16"
@@ -166,7 +176,7 @@ class ModelManager:
             self.dtype = torch.float32
             dtype_name = "float32"
         elif self.device.type == "cpu":
-            # Default to float32 for CPU
+            # Default to float32 for CPU (safest)
             self.dtype = torch.float32
             dtype_name = "float32 (CPU default)"
         else:
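
With both pieces in place, devstral-small comes up in bfloat16 without any environment setup, while an explicit TORCH_DTYPE still overrides the config. A usage sketch (the ModelManager constructor signature and the "fp32" spelling for the float32 branch are assumed here, not shown in the diff):

import os
import torch

# TORCH_DTYPE unset: falls through to recommended_dtype from the config
os.environ.pop("TORCH_DTYPE", None)
manager = ModelManager("devstral-small")  # assumed constructor signature
assert manager.dtype == torch.bfloat16

# Explicit env var takes priority over the model config
os.environ["TORCH_DTYPE"] = "fp32"
manager = ModelManager("devstral-small")
assert manager.dtype == torch.float32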