gary-boon (Claude Opus 4.5) committed
Commit 9080f28 · 1 Parent(s): e694533

Phase 2: Add Devstral backend support


Add support for Mistral-based models (Devstral) with:

- MistralAdapter in model_adapter.py for the Mistral architecture
- devstral-small config (40 layers, 32 Q heads, 8 KV heads, 131K vocab)
- Percentage-based layer classification (works for any layer count)
- Environment variable support (see the launch sketch below):
  - DEFAULT_MODEL: which model to load (default: codegen-350m)
  - TORCH_DTYPE: bf16/fp16/fp32 (default: auto based on device)
  - MAX_CONTEXT: context-length limit (default: 8192)
  - BATCH_SIZE: batch size (default: 1)
- GET /models endpoint: lists available models with hardware availability
- GET /models/current endpoint: returns currently loaded model info

Devstral requires TORCH_DTYPE=bf16 when deployed.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
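
For illustration, a Devstral deployment can be pinned down entirely through these variables before the service starts. A minimal launch sketch, assuming only that the backend reads the variables at startup as in model_service.py below; the MAX_CONTEXT value is an arbitrary example:

import os

# Hypothetical launch wrapper: set the documented variables, then start
# the FastAPI app (e.g. via uvicorn); ModelManager.__init__ reads them.
os.environ["DEFAULT_MODEL"] = "devstral-small"  # model ID from model_config.py
os.environ["TORCH_DTYPE"] = "bf16"              # required for Devstral per this commit
os.environ["MAX_CONTEXT"] = "32768"             # example cap below the model's 131K max
os.environ["BATCH_SIZE"] = "1"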

backend/model_adapter.py CHANGED
@@ -240,6 +240,63 @@ class CodeLlamaAdapter(ModelAdapter):
         return (attn.q_proj, attn.k_proj, attn.v_proj)
 
 
+class MistralAdapter(ModelAdapter):
+    """
+    Adapter for Mistral-based models (Devstral, Mistral, Codestral, etc.)
+    Uses Grouped Query Attention (GQA) similar to LLaMA but with sliding window attention
+    """
+
+    def _get_layers(self):
+        """
+        Defensive access: Mistral layers may be nested differently depending on model variant.
+        Handles both model.model.layers and model.layers structures.
+        """
+        if hasattr(self.model, 'model') and hasattr(self.model.model, 'layers'):
+            return self.model.model.layers
+        elif hasattr(self.model, 'layers'):
+            return self.model.layers
+        raise AttributeError("Cannot find transformer layers in Mistral model")
+
+    def get_num_layers(self) -> int:
+        return self.model.config.num_hidden_layers
+
+    def get_num_heads(self) -> int:
+        return self.model.config.num_attention_heads
+
+    def get_num_kv_heads(self) -> Optional[int]:
+        """
+        Mistral/Devstral uses GQA - typically 8 KV heads for 32 Q heads
+        """
+        return getattr(self.model.config, 'num_key_value_heads', None)
+
+    def get_layer_module(self, layer_idx: int):
+        """
+        Mistral structure: model.model.layers[layer_idx]
+        """
+        return self._get_layers()[layer_idx]
+
+    def get_attention_module(self, layer_idx: int):
+        """
+        Mistral attention: layers[layer_idx].self_attn
+        """
+        return self._get_layers()[layer_idx].self_attn
+
+    def get_ffn_module(self, layer_idx: int):
+        """
+        Mistral FFN: layers[layer_idx].mlp
+        """
+        return self._get_layers()[layer_idx].mlp
+
+    def get_qkv_projections(self, layer_idx: int):
+        """
+        Mistral Q, K, V projections
+        Mistral has separate q_proj, k_proj, v_proj modules
+        Note: K and V use GQA (8 KV heads vs 32 Q heads for Devstral)
+        """
+        attn = self.get_attention_module(layer_idx)
+        return (attn.q_proj, attn.k_proj, attn.v_proj)
+
+
 def create_adapter(model: Any, tokenizer: Any, model_id: str) -> ModelAdapter:
     """
     Factory function to create appropriate adapter for a model
@@ -267,6 +324,9 @@ def create_adapter(model: Any, tokenizer: Any, model_id: str) -> ModelAdapter:
     elif architecture == "llama":
         logger.info(f"Creating Code-Llama adapter for {model_id}")
         adapter = CodeLlamaAdapter(model, tokenizer, config)
+    elif architecture == "mistral":
+        logger.info(f"Creating Mistral adapter for {model_id}")
+        adapter = MistralAdapter(model, tokenizer, config)
     else:
         raise ValueError(f"Unsupported architecture: {architecture}")
 
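To make the adapter contract concrete, a minimal usage sketch. It assumes the transformers library is installed, that this module is importable as model_adapter, and that there is enough memory for the checkpoint; the method names and the create_adapter signature come from this diff:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from model_adapter import create_adapter  # import path assumed

hf_path = "mistralai/Devstral-Small-2507"
tokenizer = AutoTokenizer.from_pretrained(hf_path)
model = AutoModelForCausalLM.from_pretrained(hf_path, torch_dtype=torch.bfloat16)

adapter = create_adapter(model, tokenizer, "devstral-small")
print(adapter.get_num_layers())    # 40
print(adapter.get_num_heads())     # 32
print(adapter.get_num_kv_heads())  # 8 (GQA)
q_proj, k_proj, v_proj = adapter.get_qkv_projections(0)  # layer 0 projections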
backend/model_config.py CHANGED
@@ -55,6 +55,21 @@ SUPPORTED_MODELS: Dict[str, ModelConfig] = {
         "requires_gpu": True,  # Strongly recommended for usable performance
         "min_vram_gb": 14.0,  # FP16 requires ~14GB VRAM
         "min_ram_gb": 18.0  # FP16 requires ~18GB RAM for CPU fallback
+    },
+    "devstral-small": {
+        "hf_path": "mistralai/Devstral-Small-2507",
+        "display_name": "Devstral Small 24B",
+        "architecture": "mistral",
+        "size": "24B",
+        "num_layers": 40,
+        "num_heads": 32,
+        "num_kv_heads": 8,  # GQA: 32 Q heads, 8 KV heads
+        "vocab_size": 131072,
+        "context_length": 131072,
+        "attention_type": "grouped_query",
+        "requires_gpu": True,  # BF16 required, GPU strongly recommended
+        "min_vram_gb": 48.0,  # BF16 requires ~48GB VRAM
+        "min_ram_gb": 96.0  # BF16 requires ~96GB RAM for CPU fallback
     }
 }
 
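The min_vram_gb and min_ram_gb floors for devstral-small follow from simple parameter arithmetic. A rough back-of-the-envelope check, counting weights only (2 bytes per parameter in BF16, ignoring activations and KV cache):

params = 24e9           # Devstral Small: 24B parameters
bytes_per_param = 2     # BF16
weights_gib = params * bytes_per_param / (1024**3)
print(f"{weights_gib:.1f} GiB")  # ~44.7 GiB for weights alone, hence the 48 GB floor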
backend/model_service.py CHANGED
@@ -106,16 +106,31 @@ class TraceData(BaseModel):
 
 class ModelManager:
     """Manages model loading and generation with trace extraction"""
-
+
     def __init__(self):
         self.model = None
         self.tokenizer = None
         self.adapter = None  # ModelAdapter for multi-model support
         self.device = None
-        self.model_name = "Salesforce/codegen-350M-mono"
-        self.model_id = "codegen-350m"  # Model ID for adapter lookup
+        self.dtype = None  # Will be set from TORCH_DTYPE env var
         self.websocket_clients: List[WebSocket] = []
         self.trace_buffer: List[TraceData] = []
+
+        # Read configuration from environment variables
+        self.model_id = os.environ.get("DEFAULT_MODEL", "codegen-350m")
+        self.max_context = int(os.environ.get("MAX_CONTEXT", "8192"))
+        self.batch_size = int(os.environ.get("BATCH_SIZE", "1"))
+
+        # Get model config and HF path
+        from .model_config import get_model_config
+        config = get_model_config(self.model_id)
+        if config:
+            self.model_name = config["hf_path"]
+        else:
+            # Fall back to the default if the model ID is unknown
+            logger.warning(f"Unknown model ID '{self.model_id}', falling back to codegen-350m")
+            self.model_id = "codegen-350m"
+            self.model_name = "Salesforce/codegen-350M-mono"
 
     async def initialize(self):
         """Load model on startup"""
@@ -139,12 +154,34 @@ class ModelManager:
             self.device = torch.device("cpu")
             device_name = "CPU"
 
-        logger.info(f"Loading model on {device_name}...")
-
-        # Load model
+        # Determine dtype from environment or defaults
+        dtype_str = os.environ.get("TORCH_DTYPE", "").lower()
+        if dtype_str == "bf16" or dtype_str == "bfloat16":
+            self.dtype = torch.bfloat16
+            dtype_name = "bfloat16"
+        elif dtype_str == "fp16" or dtype_str == "float16":
+            self.dtype = torch.float16
+            dtype_name = "float16"
+        elif dtype_str == "fp32" or dtype_str == "float32":
+            self.dtype = torch.float32
+            dtype_name = "float32"
+        elif self.device.type == "cpu":
+            # Default to float32 for CPU
+            self.dtype = torch.float32
+            dtype_name = "float32 (CPU default)"
+        else:
+            # Default to float16 for GPU
+            self.dtype = torch.float16
+            dtype_name = "float16 (GPU default)"
+
+        logger.info(f"Loading model '{self.model_id}' on {device_name} with dtype {dtype_name}...")
+        logger.info(f"  HuggingFace path: {self.model_name}")
+        logger.info(f"  Max context: {self.max_context}, Batch size: {self.batch_size}")
+
+        # Load model with configured dtype
         self.model = AutoModelForCausalLM.from_pretrained(
             self.model_name,
-            torch_dtype=torch.float32 if self.device.type == "cpu" else torch.float16,
+            torch_dtype=self.dtype,
             low_cpu_mem_usage=True,
             trust_remote_code=True
         ).to(self.device)
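
The dtype precedence above (explicit TORCH_DTYPE first, otherwise a per-device default) is easy to factor out and test in isolation. A standalone sketch of the same mapping, not the committed code itself:

import os
import torch

def resolve_dtype(device_type: str) -> torch.dtype:
    # Same precedence as ModelManager.initialize: env var wins, else device default.
    table = {
        "bf16": torch.bfloat16, "bfloat16": torch.bfloat16,
        "fp16": torch.float16, "float16": torch.float16,
        "fp32": torch.float32, "float32": torch.float32,
    }
    dtype_str = os.environ.get("TORCH_DTYPE", "").lower()
    if dtype_str in table:
        return table[dtype_str]
    return torch.float32 if device_type == "cpu" else torch.float16

os.environ.pop("TORCH_DTYPE", None)
assert resolve_dtype("cpu") == torch.float32    # CPU default
assert resolve_dtype("cuda") == torch.float16   # GPU default
os.environ["TORCH_DTYPE"] = "bf16"
assert resolve_dtype("cuda") == torch.bfloat16  # explicit setting wins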
@@ -922,6 +959,87 @@ async def debug_device():
         "timestamp": datetime.now().isoformat()
     }
 
+
+@app.get("/models")
+async def list_models():
+    """List all available models this backend can serve.
+
+    Returns model metadata including availability based on current hardware.
+    Used by frontend to populate model selector dynamically.
+    """
+    from .model_config import SUPPORTED_MODELS
+
+    # Check current device capabilities
+    has_gpu = manager.device is not None and manager.device.type in ["cuda", "mps"]
+    available_vram = 0
+    if has_gpu and torch.cuda.is_available():
+        available_vram = torch.cuda.get_device_properties(0).total_memory / (1024**3)  # GB
+
+    models = []
+    for model_id, config in SUPPORTED_MODELS.items():
+        # Determine if model is available on current hardware
+        is_available = True
+        if config["requires_gpu"] and not has_gpu:
+            is_available = False
+        elif has_gpu and available_vram < config["min_vram_gb"]:
+            is_available = False
+
+        models.append({
+            "id": model_id,
+            "name": config["display_name"],
+            "size": config["size"],
+            "architecture": config["architecture"],
+            "num_layers": config["num_layers"],
+            "num_heads": config["num_heads"],
+            "vocab_size": config["vocab_size"],
+            "context_length": config["context_length"],
+            "attention_type": config["attention_type"],
+            "requires_gpu": config["requires_gpu"],
+            "available": is_available
+        })
+
+    return {"models": models}
+
+
+@app.get("/models/current")
+async def current_model():
+    """Return info about the currently loaded model.
+
+    Used by frontend to verify which model is active and its configuration.
+    Returns null fields if no model is loaded.
+    """
+    if manager.model is None:
+        return {
+            "id": None,
+            "name": None,
+            "device": None,
+            "dtype": None,
+            "loaded": False
+        }
+
+    # Get dtype string
+    dtype_str = None
+    if manager.dtype is not None:
+        if manager.dtype == torch.bfloat16:
+            dtype_str = "bf16"
+        elif manager.dtype == torch.float16:
+            dtype_str = "fp16"
+        elif manager.dtype == torch.float32:
+            dtype_str = "fp32"
+        else:
+            dtype_str = str(manager.dtype)
+
+    return {
+        "id": manager.model_id,
+        "name": manager.model_name,
+        "device": str(manager.device) if manager.device else None,
+        "dtype": dtype_str,
+        "loaded": True,
+        "max_context": manager.max_context,
+        "batch_size": manager.batch_size
+    }
+
+
 @app.get("/model/info")
 async def model_info(authenticated: bool = Depends(verify_api_key)):
     """Get detailed information about the loaded model"""
@@ -1549,15 +1667,16 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: bool
     # Sort by max_weight (return all heads, frontend will decide how many to display)
     critical_heads.sort(key=lambda h: h["max_weight"], reverse=True)
 
-    # Detect layer-level pattern
+    # Detect layer-level pattern (percentage-based for any layer count)
     layer_pattern = None
+    layer_fraction = (layer_idx + 1) / n_layers  # 1-indexed fraction
     if layer_idx == 0:
         layer_pattern = {"type": "positional", "confidence": 0.78}
-    elif layer_idx <= 5 and step > 0:
+    elif layer_fraction <= 0.25 and step > 0:
         layer_pattern = {"type": "previous_token", "confidence": 0.65}
-    elif 5 <= layer_idx <= 15:
+    elif layer_fraction <= 0.75:
         layer_pattern = {"type": "induction", "confidence": 0.87}
-    elif layer_idx > 15:
+    else:
         layer_pattern = {"type": "semantic", "confidence": 0.92}
 
     layer_data_this_token.append({
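
To see how the percentage bands generalize across depths, the band logic can be lifted into a standalone helper. The thresholds are from the diff; the layer counts below are illustrative:

def classify_layer(layer_idx: int, n_layers: int, step: int) -> str:
    # Same bands as above: first quarter, middle half, final quarter.
    fraction = (layer_idx + 1) / n_layers
    if layer_idx == 0:
        return "positional"
    if fraction <= 0.25 and step > 0:
        return "previous_token"
    if fraction <= 0.75:
        return "induction"
    return "semantic"

# For Devstral's 40 layers (step > 0): layers 1-9 previous_token,
# 10-29 induction, 30-39 semantic.
assert classify_layer(9, 40, step=1) == "previous_token"  # 10/40 = 0.25
assert classify_layer(29, 40, step=1) == "induction"      # 30/40 = 0.75
assert classify_layer(30, 40, step=1) == "semantic"       # 31/40 = 0.775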
 