raylim committed on
Commit
14f8f4b
·
unverified ·
1 Parent(s): 658b7b2

Implement proper multi-GPU-call chunking for ZeroGPU

Browse files

- Each chunk gets its own @spaces.GPU decorated function call
- Each call gets fresh 300s GPU token (no token sharing/expiry)
- Wrapper function (no decorator) splits work and calls chunk functions
- CTransPath: 2000 tiles per chunk (180s each)
- Optimus: 1500 tiles per chunk (300s each)
- Non-ZeroGPU: processes all tiles at once (no chunking overhead)

This is the correct pattern for ZeroGPU time limits.

Files changed (1) hide show
  1. src/mosaic/analysis.py +57 -62
src/mosaic/analysis.py CHANGED
@@ -40,8 +40,10 @@ from mosaic.inference import run_aeon, run_paladin
40
 
41
 
42
  @spaces.GPU(duration=180)
43
- def _extract_ctranspath_features_chunk(coords_chunk, slide_path, attrs, num_workers, batch_size):
44
- """Extract CTransPath features for a chunk of coordinates on GPU.
 
 
45
 
46
  Args:
47
  coords_chunk: Chunk of tissue tile coordinates
@@ -51,7 +53,7 @@ def _extract_ctranspath_features_chunk(coords_chunk, slide_path, attrs, num_work
51
  batch_size: Batch size for inference
52
 
53
  Returns:
54
- tuple: (ctranspath_features, coords_chunk)
55
  """
56
  features, _ = get_features(
57
  coords_chunk,
@@ -67,7 +69,7 @@ def _extract_ctranspath_features_chunk(coords_chunk, slide_path, attrs, num_work
67
 
68
 
69
  def _extract_ctranspath_features(coords, slide_path, attrs, num_workers):
70
- """Extract CTransPath features on GPU.
71
 
72
  Args:
73
  coords: Tissue tile coordinates
@@ -80,63 +82,61 @@ def _extract_ctranspath_features(coords, slide_path, attrs, num_workers):
80
  """
81
  if IS_ZEROGPU:
82
  num_workers = 0
83
- logger.info("Running CTransPath on ZeroGPU: setting num_workers=0")
84
- # Split into smaller chunks to stay within GPU time limits
85
  chunk_size = 2000
86
- total_tiles = len(coords)
87
- logger.info(f"Processing {total_tiles} tiles in chunks of {chunk_size}")
88
  else:
89
  num_workers = max(num_workers, 8)
90
- logger.info(f"Running CTransPath with num_workers={num_workers}")
91
  chunk_size = len(coords) # Process all at once
 
92
 
93
- # Use larger batch size on H100 for better throughput
94
  batch_size = 128 if IS_ZEROGPU else 64
95
 
96
  start_time = pd.Timestamp.now()
97
 
98
- # Process in chunks
99
  all_features = []
100
  for i in range(0, len(coords), chunk_size):
101
  chunk_coords = coords[i:i+chunk_size]
102
  chunk_num = i // chunk_size + 1
103
  total_chunks = (len(coords) + chunk_size - 1) // chunk_size
104
 
105
- logger.info(f"Extracting CTransPath features for chunk {chunk_num}/{total_chunks} "
106
- f"({len(chunk_coords)} tiles, batch_size={batch_size})")
107
 
108
- chunk_features = _extract_ctranspath_features_chunk(
109
- chunk_coords, slide_path, attrs, num_workers, batch_size
110
- )
111
- all_features.append(chunk_features)
 
 
 
 
 
 
 
 
 
 
 
112
 
 
113
  logger.info(f"Chunk {chunk_num}/{total_chunks} completed")
114
 
115
  # Concatenate all features
116
  import numpy as np
117
- ctranspath_features = np.concatenate(all_features, axis=0)
118
 
119
  end_time = pd.Timestamp.now()
120
-
121
- # Log memory stats if CUDA is available
122
- if torch.cuda.is_available():
123
- try:
124
- max_gpu_memory = torch.cuda.max_memory_allocated() / (1024**3)
125
- logger.info(
126
- f"CTransPath extraction took {end_time - start_time} and used {max_gpu_memory:.2f} GB GPU memory"
127
- )
128
- torch.cuda.reset_peak_memory_stats()
129
- except Exception:
130
- logger.info(f"CTransPath extraction took {end_time - start_time}")
131
- else:
132
- logger.info(f"CTransPath extraction took {end_time - start_time}")
133
 
134
  return ctranspath_features, coords
135
 
136
 
137
  @spaces.GPU(duration=300)
138
- def _extract_optimus_features_chunk(coords_chunk, slide_path, attrs, num_workers, batch_size):
139
- """Extract Optimus features for a chunk of coordinates on GPU.
 
 
140
 
141
  Args:
142
  coords_chunk: Chunk of tissue tile coordinates
@@ -162,7 +162,7 @@ def _extract_optimus_features_chunk(coords_chunk, slide_path, attrs, num_workers
162
 
163
 
164
  def _extract_optimus_features(filtered_coords, slide_path, attrs, num_workers):
165
- """Extract Optimus features on GPU.
166
 
167
  Args:
168
  filtered_coords: Filtered tissue tile coordinates
@@ -175,57 +175,52 @@ def _extract_optimus_features(filtered_coords, slide_path, attrs, num_workers):
175
  """
176
  if IS_ZEROGPU:
177
  num_workers = 0
178
- logger.info("Running Optimus on ZeroGPU: setting num_workers=0")
179
- # Split into very small chunks to stay within GPU time limits
180
- # Even 1881 tiles caused expiry, so use 1500 tiles per chunk
181
  chunk_size = 1500
182
- total_tiles = len(filtered_coords)
183
- logger.info(f"Processing {total_tiles} tiles in chunks of {chunk_size}")
184
  else:
185
  num_workers = max(num_workers, 8)
186
- logger.info(f"Running Optimus with num_workers={num_workers}")
187
  chunk_size = len(filtered_coords) # Process all at once
 
188
 
189
- # Use larger batch size on H100 for better throughput
190
  batch_size = 128 if IS_ZEROGPU else 64
191
 
192
  start_time = pd.Timestamp.now()
193
 
194
- # Process in chunks
195
  all_features = []
196
  for i in range(0, len(filtered_coords), chunk_size):
197
  chunk_coords = filtered_coords[i:i+chunk_size]
198
  chunk_num = i // chunk_size + 1
199
  total_chunks = (len(filtered_coords) + chunk_size - 1) // chunk_size
200
 
201
- logger.info(f"Extracting Optimus features for chunk {chunk_num}/{total_chunks} "
202
- f"({len(chunk_coords)} tiles, batch_size={batch_size})")
203
 
204
- chunk_features = _extract_optimus_features_chunk(
205
- chunk_coords, slide_path, attrs, num_workers, batch_size
206
- )
207
- all_features.append(chunk_features)
 
 
 
 
 
 
 
 
 
 
 
208
 
 
209
  logger.info(f"Chunk {chunk_num}/{total_chunks} completed")
210
 
211
  # Concatenate all features
212
  import numpy as np
213
- features = np.concatenate(all_features, axis=0)
214
 
215
  end_time = pd.Timestamp.now()
216
-
217
- # Log memory stats if CUDA is available
218
- if torch.cuda.is_available():
219
- try:
220
- max_gpu_memory = torch.cuda.max_memory_allocated() / (1024**3)
221
- logger.info(
222
- f"Optimus extraction took {end_time - start_time} and used {max_gpu_memory:.2f} GB GPU memory"
223
- )
224
- torch.cuda.reset_peak_memory_stats()
225
- except Exception:
226
- logger.info(f"Optimus extraction took {end_time - start_time}")
227
- else:
228
- logger.info(f"Optimus extraction took {end_time - start_time}")
229
 
230
  return features
231
 
 
40
 
41
 
42
  @spaces.GPU(duration=180)
43
+ def _extract_ctranspath_features_single_chunk(coords_chunk, slide_path, attrs, num_workers, batch_size):
44
+ """Extract CTransPath features for ONE chunk with its own GPU allocation.
45
+
46
+ This function gets its own GPU token for up to 180 seconds.
47
 
48
  Args:
49
  coords_chunk: Chunk of tissue tile coordinates
 
53
  batch_size: Batch size for inference
54
 
55
  Returns:
56
+ CTransPath features for this chunk
57
  """
58
  features, _ = get_features(
59
  coords_chunk,
 
69
 
70
 
71
  def _extract_ctranspath_features(coords, slide_path, attrs, num_workers):
72
+ """Extract CTransPath features, splitting into multiple GPU calls if needed.
73
 
74
  Args:
75
  coords: Tissue tile coordinates
 
82
  """
83
  if IS_ZEROGPU:
84
  num_workers = 0
85
+ # Split into chunks - each chunk gets its own GPU call/token
 
86
  chunk_size = 2000
87
+ logger.info(f"Running CTransPath on ZeroGPU: splitting {len(coords)} tiles into chunks of {chunk_size}")
 
88
  else:
89
  num_workers = max(num_workers, 8)
 
90
  chunk_size = len(coords) # Process all at once
91
+ logger.info(f"Running CTransPath with {num_workers} workers")
92
 
 
93
  batch_size = 128 if IS_ZEROGPU else 64
94
 
95
  start_time = pd.Timestamp.now()
96
 
97
+ # Process each chunk with separate GPU allocation
98
  all_features = []
99
  for i in range(0, len(coords), chunk_size):
100
  chunk_coords = coords[i:i+chunk_size]
101
  chunk_num = i // chunk_size + 1
102
  total_chunks = (len(coords) + chunk_size - 1) // chunk_size
103
 
104
+ logger.info(f"Processing CTransPath chunk {chunk_num}/{total_chunks} ({len(chunk_coords)} tiles)")
 
105
 
106
+ if IS_ZEROGPU:
107
+ # Each call gets fresh GPU token
108
+ chunk_features = _extract_ctranspath_features_single_chunk(
109
+ chunk_coords, slide_path, attrs, num_workers, batch_size
110
+ )
111
+ else:
112
+ # Non-ZeroGPU: direct call without decorator overhead
113
+ chunk_features, _ = get_features(
114
+ chunk_coords, slide_path, attrs,
115
+ model_type=ModelType.CTRANSPATH,
116
+ model_path="data/ctranspath.pth",
117
+ num_workers=num_workers,
118
+ batch_size=batch_size,
119
+ use_gpu=True,
120
+ )
121
 
122
+ all_features.append(chunk_features)
123
  logger.info(f"Chunk {chunk_num}/{total_chunks} completed")
124
 
125
  # Concatenate all features
126
  import numpy as np
127
+ ctranspath_features = np.concatenate(all_features, axis=0) if len(all_features) > 1 else all_features[0]
128
 
129
  end_time = pd.Timestamp.now()
130
+ logger.info(f"CTransPath extraction took {end_time - start_time} total")
 
 
 
 
 
 
 
 
 
 
 
 
131
 
132
  return ctranspath_features, coords
133
 
134
 
135
  @spaces.GPU(duration=300)
136
+ def _extract_optimus_features_single_chunk(coords_chunk, slide_path, attrs, num_workers, batch_size):
137
+ """Extract Optimus features for ONE chunk with its own GPU allocation.
138
+
139
+ This function gets its own GPU token for up to 300 seconds.
140
 
141
  Args:
142
  coords_chunk: Chunk of tissue tile coordinates
 
162
 
163
 
164
  def _extract_optimus_features(filtered_coords, slide_path, attrs, num_workers):
165
+ """Extract Optimus features, splitting into multiple GPU calls if needed.
166
 
167
  Args:
168
  filtered_coords: Filtered tissue tile coordinates
 
175
  """
176
  if IS_ZEROGPU:
177
  num_workers = 0
178
+ # Split into chunks - each chunk gets its own GPU call/token
 
 
179
  chunk_size = 1500
180
+ logger.info(f"Running Optimus on ZeroGPU: splitting {len(filtered_coords)} tiles into chunks of {chunk_size}")
 
181
  else:
182
  num_workers = max(num_workers, 8)
 
183
  chunk_size = len(filtered_coords) # Process all at once
184
+ logger.info(f"Running Optimus with {num_workers} workers")
185
 
 
186
  batch_size = 128 if IS_ZEROGPU else 64
187
 
188
  start_time = pd.Timestamp.now()
189
 
190
+ # Process each chunk with separate GPU allocation
191
  all_features = []
192
  for i in range(0, len(filtered_coords), chunk_size):
193
  chunk_coords = filtered_coords[i:i+chunk_size]
194
  chunk_num = i // chunk_size + 1
195
  total_chunks = (len(filtered_coords) + chunk_size - 1) // chunk_size
196
 
197
+ logger.info(f"Processing Optimus chunk {chunk_num}/{total_chunks} ({len(chunk_coords)} tiles)")
 
198
 
199
+ if IS_ZEROGPU:
200
+ # Each call gets fresh GPU token
201
+ chunk_features = _extract_optimus_features_single_chunk(
202
+ chunk_coords, slide_path, attrs, num_workers, batch_size
203
+ )
204
+ else:
205
+ # Non-ZeroGPU: direct call without decorator overhead
206
+ chunk_features, _ = get_features(
207
+ chunk_coords, slide_path, attrs,
208
+ model_type=ModelType.OPTIMUS,
209
+ model_path="data/optimus.pkl",
210
+ num_workers=num_workers,
211
+ batch_size=batch_size,
212
+ use_gpu=True,
213
+ )
214
 
215
+ all_features.append(chunk_features)
216
  logger.info(f"Chunk {chunk_num}/{total_chunks} completed")
217
 
218
  # Concatenate all features
219
  import numpy as np
220
+ features = np.concatenate(all_features, axis=0) if len(all_features) > 1 else all_features[0]
221
 
222
  end_time = pd.Timestamp.now()
223
+ logger.info(f"Optimus extraction took {end_time - start_time} total")
 
 
 
 
 
 
 
 
 
 
 
 
224
 
225
  return features
226