cgoodmaker Claude Opus 4.6 committed on
Commit
1a97904
·
1 Parent(s): 5157ba3

Force MCP tool models to CPU to avoid GPU VRAM contention with MedGemma

Browse files

- MCP subprocess passes SKINPRO_TOOL_DEVICE=cpu env var
- MONET and ConvNeXt respect this override to stay on CPU
- Prevents OOM/hang on T4 (16GB) where MedGemma already uses ~8GB
- Fix max_length warning: set both max_length=None and max_new_tokens=400 on generation_config

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

models/convnext_classifier.py CHANGED
@@ -3,6 +3,7 @@ ConvNeXt Classifier Tool - Skin lesion classification using ConvNeXt + MONET fea
3
  Loads seed42_fold0.pt checkpoint and performs classification.
4
  """
5
 
 
6
  import torch
7
  import torch.nn as nn
8
  import numpy as np
@@ -173,8 +174,11 @@ class ConvNeXtClassifier:
173
  if self.loaded:
174
  return
175
 
176
- # Determine device
177
- if self.device is None:
 
 
 
178
  if torch.cuda.is_available():
179
  self.device = "cuda"
180
  elif torch.backends.mps.is_available():
 
3
  Loads seed42_fold0.pt checkpoint and performs classification.
4
  """
5
 
6
+ import os
7
  import torch
8
  import torch.nn as nn
9
  import numpy as np
 
174
  if self.loaded:
175
  return
176
 
177
+ # Determine device (respect SKINPRO_TOOL_DEVICE override for GPU sharing)
178
+ forced = os.environ.get("SKINPRO_TOOL_DEVICE")
179
+ if forced:
180
+ self.device = forced
181
+ elif self.device is None:
182
  if torch.cuda.is_available():
183
  self.device = "cuda"
184
  elif torch.backends.mps.is_available():
models/medgemma_agent.py CHANGED
@@ -76,6 +76,10 @@ class MCPClient:
76
  """Spawn the MCP server subprocess and complete the handshake."""
77
  root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
78
  server_script = os.path.join(root, "mcp_server", "server.py")
 
 
 
 
79
  self._process = subprocess.Popen(
80
  [sys.executable, server_script], # use same venv Python (has all ML packages)
81
  stdin=subprocess.PIPE,
@@ -83,6 +87,7 @@ class MCPClient:
83
  stderr=subprocess.PIPE,
84
  text=True,
85
  bufsize=1, # line-buffered
 
86
  )
87
  self._initialize()
88
 
@@ -246,10 +251,12 @@ class MedGemmaAgent:
246
  tokenizer=processor.tokenizer,
247
  )
248
 
249
- # Clear default max_length from generation_config to avoid conflict
250
  # with max_new_tokens passed at inference time
251
  if hasattr(self.pipe.model, "generation_config"):
252
- self.pipe.model.generation_config.max_length = None
 
 
253
 
254
  self._print(f"Model loaded in {time.time() - start:.1f}s")
255
  self.loaded = True
 
76
  """Spawn the MCP server subprocess and complete the handshake."""
77
  root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
78
  server_script = os.path.join(root, "mcp_server", "server.py")
79
+ # Force MCP tool models (MONET, ConvNeXt) onto CPU so they don't
80
+ # compete with MedGemma for GPU VRAM (T4 has only 16 GB).
81
+ env = os.environ.copy()
82
+ env["SKINPRO_TOOL_DEVICE"] = "cpu"
83
  self._process = subprocess.Popen(
84
  [sys.executable, server_script], # use same venv Python (has all ML packages)
85
  stdin=subprocess.PIPE,
 
87
  stderr=subprocess.PIPE,
88
  text=True,
89
  bufsize=1, # line-buffered
90
+ env=env,
91
  )
92
  self._initialize()
93
 
 
251
  tokenizer=processor.tokenizer,
252
  )
253
 
254
+ # Clear default max_length (20) from generation_config to avoid conflict
255
  # with max_new_tokens passed at inference time
256
  if hasattr(self.pipe.model, "generation_config"):
257
+ gc = self.pipe.model.generation_config
258
+ gc.max_length = None
259
+ gc.max_new_tokens = 400
260
 
261
  self._print(f"Model loaded in {time.time() - start:.1f}s")
262
  self.loaded = True
models/monet_tool.py CHANGED
@@ -3,6 +3,7 @@ MONET Tool - Skin lesion feature extraction using MONET model
3
  Correct implementation based on MONET tutorial: automatic_concept_annotation.ipynb
4
  """
5
 
 
6
  import torch
7
  import torch.nn.functional as F
8
  import numpy as np
@@ -90,8 +91,11 @@ class MonetTool:
90
  if self.loaded:
91
  return
92
 
93
- # Determine device
94
- if self.device is None:
 
 
 
95
  if torch.cuda.is_available():
96
  self.device = "cuda:0"
97
  elif torch.backends.mps.is_available():
 
3
  Correct implementation based on MONET tutorial: automatic_concept_annotation.ipynb
4
  """
5
 
6
+ import os
7
  import torch
8
  import torch.nn.functional as F
9
  import numpy as np
 
91
  if self.loaded:
92
  return
93
 
94
+ # Determine device (respect SKINPRO_TOOL_DEVICE override for GPU sharing)
95
+ forced = os.environ.get("SKINPRO_TOOL_DEVICE")
96
+ if forced:
97
+ self.device = forced
98
+ elif self.device is None:
99
  if torch.cuda.is_available():
100
  self.device = "cuda:0"
101
  elif torch.backends.mps.is_available():