gary-boon Claude Opus 4.5 committed on
Commit
5f122aa
·
1 Parent(s): a2875a2

Add DEVICE env var to force CPU mode on DGX Spark

Browse files

GB10 GPU (sm_121 compute capability) is not yet supported by
PyTorch/NGC containers. This adds a DEVICE environment variable
override to force CPU mode until GPU support is available.

- Add os import to model_service.py
- Check DEVICE env var before auto-detecting device
- Support DEVICE=cpu or DEVICE=cuda to override detection
- Default .env.spark.example to DEVICE=cpu for Spark

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

Files changed (2) hide show
  1. .env.spark.example +4 -0
  2. backend/model_service.py +12 -3
.env.spark.example CHANGED
@@ -19,3 +19,7 @@ MAX_CONTEXT=8192
19
  BATCH_SIZE=1
20
  TORCH_DTYPE=fp16
21
  # TORCH_DTYPE=bf16 # Use bf16 for Devstral (Phase 3)
 
 
 
 
 
19
  BATCH_SIZE=1
20
  TORCH_DTYPE=fp16
21
  # TORCH_DTYPE=bf16 # Use bf16 for Devstral (Phase 3)
22
+
23
+ # Device Override (set to 'cpu' if GPU not supported yet)
24
+ # DEVICE=cuda # Default: auto-detect
25
+ DEVICE=cpu # Force CPU until GB10 GPU support available
backend/model_service.py CHANGED
@@ -8,6 +8,7 @@ from fastapi.middleware.cors import CORSMiddleware
8
  from pydantic import BaseModel
9
  import asyncio
10
  import json
 
11
  import torch
12
  from transformers import AutoModelForCausalLM, AutoTokenizer
13
  from typing import Optional, List, Dict, Any
@@ -119,8 +120,16 @@ class ModelManager:
119
  async def initialize(self):
120
  """Load model on startup"""
121
  try:
122
- # Detect device
123
- if torch.cuda.is_available():
 
 
 
 
 
 
 
 
124
  self.device = torch.device("cuda")
125
  device_name = "CUDA GPU"
126
  elif torch.backends.mps.is_available():
@@ -129,7 +138,7 @@ class ModelManager:
129
  else:
130
  self.device = torch.device("cpu")
131
  device_name = "CPU"
132
-
133
  logger.info(f"Loading model on {device_name}...")
134
 
135
  # Load model
 
8
  from pydantic import BaseModel
9
  import asyncio
10
  import json
11
+ import os
12
  import torch
13
  from transformers import AutoModelForCausalLM, AutoTokenizer
14
  from typing import Optional, List, Dict, Any
 
120
  async def initialize(self):
121
  """Load model on startup"""
122
  try:
123
+ # Check for device override from environment
124
+ device_override = os.environ.get("DEVICE", "").lower()
125
+
126
+ if device_override == "cpu":
127
+ self.device = torch.device("cpu")
128
+ device_name = "CPU (forced via DEVICE env var)"
129
+ elif device_override == "cuda":
130
+ self.device = torch.device("cuda")
131
+ device_name = "CUDA GPU (forced via DEVICE env var)"
132
+ elif torch.cuda.is_available():
133
  self.device = torch.device("cuda")
134
  device_name = "CUDA GPU"
135
  elif torch.backends.mps.is_available():
 
138
  else:
139
  self.device = torch.device("cpu")
140
  device_name = "CPU"
141
+
142
  logger.info(f"Loading model on {device_name}...")
143
 
144
  # Load model