gary-boon Claude Opus 4.5 committed on
Commit
62525b2
·
1 Parent(s): 9080f28

Add recommended_dtype to model configs

Browse files

Each model now specifies its recommended dtype:
- codegen-350m: fp16
- code-llama-7b: fp16
- devstral-small: bf16 (required for Mistral models)

Model loader now uses recommended_dtype from config when
TORCH_DTYPE env var is not explicitly set. This ensures
Devstral automatically loads with bf16 without requiring
manual configuration.

Priority: TORCH_DTYPE env > model config > device default

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

backend/model_config.py CHANGED
@@ -22,6 +22,7 @@ class ModelConfig(TypedDict):
22
  requires_gpu: bool
23
  min_vram_gb: float
24
  min_ram_gb: float
 
25
 
26
 
27
  # Supported models registry
@@ -39,7 +40,8 @@ SUPPORTED_MODELS: Dict[str, ModelConfig] = {
39
  "attention_type": "multi_head",
40
  "requires_gpu": False,
41
  "min_vram_gb": 2.0,
42
- "min_ram_gb": 4.0
 
43
  },
44
  "code-llama-7b": {
45
  "hf_path": "codellama/CodeLlama-7b-hf",
@@ -54,7 +56,8 @@ SUPPORTED_MODELS: Dict[str, ModelConfig] = {
54
  "attention_type": "grouped_query",
55
  "requires_gpu": True, # Strongly recommended for usable performance
56
  "min_vram_gb": 14.0, # FP16 requires ~14GB VRAM
57
- "min_ram_gb": 18.0 # FP16 requires ~18GB RAM for CPU fallback
 
58
  },
59
  "devstral-small": {
60
  "hf_path": "mistralai/Devstral-Small-2507",
@@ -69,7 +72,8 @@ SUPPORTED_MODELS: Dict[str, ModelConfig] = {
69
  "attention_type": "grouped_query",
70
  "requires_gpu": True, # BF16 required, GPU strongly recommended
71
  "min_vram_gb": 48.0, # BF16 requires ~48GB VRAM
72
- "min_ram_gb": 96.0 # BF16 requires ~96GB RAM for CPU fallback
 
73
  }
74
  }
75
 
 
22
  requires_gpu: bool
23
  min_vram_gb: float
24
  min_ram_gb: float
25
+ recommended_dtype: str # "fp16", "bf16", or "fp32"
26
 
27
 
28
  # Supported models registry
 
40
  "attention_type": "multi_head",
41
  "requires_gpu": False,
42
  "min_vram_gb": 2.0,
43
+ "min_ram_gb": 4.0,
44
+ "recommended_dtype": "fp16" # fp16 for GPU, fp32 for CPU
45
  },
46
  "code-llama-7b": {
47
  "hf_path": "codellama/CodeLlama-7b-hf",
 
56
  "attention_type": "grouped_query",
57
  "requires_gpu": True, # Strongly recommended for usable performance
58
  "min_vram_gb": 14.0, # FP16 requires ~14GB VRAM
59
+ "min_ram_gb": 18.0, # FP16 requires ~18GB RAM for CPU fallback
60
+ "recommended_dtype": "fp16"
61
  },
62
  "devstral-small": {
63
  "hf_path": "mistralai/Devstral-Small-2507",
 
72
  "attention_type": "grouped_query",
73
  "requires_gpu": True, # BF16 required, GPU strongly recommended
74
  "min_vram_gb": 48.0, # BF16 requires ~48GB VRAM
75
+ "min_ram_gb": 96.0, # BF16 requires ~96GB RAM for CPU fallback
76
+ "recommended_dtype": "bf16" # Devstral requires bfloat16
77
  }
78
  }
79
 
backend/model_service.py CHANGED
@@ -154,8 +154,18 @@ class ModelManager:
154
  self.device = torch.device("cpu")
155
  device_name = "CPU"
156
 
157
- # Determine dtype from environment or defaults
158
  dtype_str = os.environ.get("TORCH_DTYPE", "").lower()
 
 
 
 
 
 
 
 
 
 
159
  if dtype_str == "bf16" or dtype_str == "bfloat16":
160
  self.dtype = torch.bfloat16
161
  dtype_name = "bfloat16"
@@ -166,7 +176,7 @@ class ModelManager:
166
  self.dtype = torch.float32
167
  dtype_name = "float32"
168
  elif self.device.type == "cpu":
169
- # Default to float32 for CPU
170
  self.dtype = torch.float32
171
  dtype_name = "float32 (CPU default)"
172
  else:
 
154
  self.device = torch.device("cpu")
155
  device_name = "CPU"
156
 
157
+ # Determine dtype from environment, model config, or defaults
158
  dtype_str = os.environ.get("TORCH_DTYPE", "").lower()
159
+
160
+ # If not set in env, use model's recommended dtype
161
+ if not dtype_str:
162
+ from .model_config import get_model_config
163
+ model_config = get_model_config(self.model_id)
164
+ if model_config and "recommended_dtype" in model_config:
165
+ dtype_str = model_config["recommended_dtype"]
166
+ logger.info(f"Using model's recommended dtype: {dtype_str}")
167
+
168
+ # Parse dtype string to torch dtype
169
  if dtype_str == "bf16" or dtype_str == "bfloat16":
170
  self.dtype = torch.bfloat16
171
  dtype_name = "bfloat16"
 
176
  self.dtype = torch.float32
177
  dtype_name = "float32"
178
  elif self.device.type == "cpu":
179
+ # Default to float32 for CPU (safest)
180
  self.dtype = torch.float32
181
  dtype_name = "float32 (CPU default)"
182
  else: