Victarry committed on
Commit
fa7e466
·
1 Parent(s): 0a0d256

Make P2P transfers in the warmup/cooldown stages use synchronous communication.

Browse files
Files changed (1) hide show
  1. src/execution_model.py +61 -1
src/execution_model.py CHANGED
@@ -243,6 +243,39 @@ class Schedule:
243
  return None
244
  return self.ops[(batch_id, stage_id, op_type)]
245
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246
  def get_dependencies(self, op: Operation, include_device_dependency=True):
247
  deps = []
248
  if isinstance(op, OverlappedOperation):
@@ -327,7 +360,34 @@ class Schedule:
327
  if include_device_dependency:
328
  device_index = self.device_queues[op.device_id].ops.index(op)
329
  if device_index > 0:
330
- deps.append((self.device_queues[op.device_id].ops[device_index - 1], 0.0))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
331
  return deps
332
 
333
  def show(self):
 
243
  return None
244
  return self.ops[(batch_id, stage_id, op_type)]
245
 
246
def get_p2p_receiver_op(self, sender_op: Operation) -> Optional[Operation]:
    """
    Look up the operation that consumes the P2P transfer issued by sender_op.

    Direction of the transfer depends on the sender's op type:
      * "forward": the receiver is the "forward" op of the same batch on
        stage_id + 1.
      * "backward" / "backward_D": the receiver is on stage_id - 1, preferring
        that stage's "backward_D" op and falling back to its "backward" op.

    Returns None when the sender is an OverlappedOperation, when the target
    stage would fall outside [0, num_stages), when the op type is none of the
    above, or when the lookup itself finds no matching op.
    """
    # Overlapped senders have their P2P overlapped with computation, so no
    # receiver is reported for them.
    if isinstance(sender_op, OverlappedOperation):
        return None

    kind = sender_op.op_type

    if kind == "forward":
        # Activations travel downstream to the following stage.
        target_stage = sender_op.stage_id + 1
        if target_stage >= self.config.num_stages:
            return None  # Sender is on the last stage: nothing to send.
        return self.get_op(sender_op.batch_id, target_stage, "forward", allow_none=True)

    if kind not in ("backward", "backward_D"):
        return None

    # Gradients travel upstream to the preceding stage.
    target_stage = sender_op.stage_id - 1
    if target_stage < 0:
        return None  # Sender is on the first stage: nothing to send.

    # Probe the candidate receiver types in priority order.
    for candidate_type in ("backward_D", "backward"):
        found = self.get_op(sender_op.batch_id, target_stage, candidate_type, allow_none=True)
        if found is not None:
            return found
    return None
278
+
279
  def get_dependencies(self, op: Operation, include_device_dependency=True):
280
  deps = []
281
  if isinstance(op, OverlappedOperation):
 
360
  if include_device_dependency:
361
  device_index = self.device_queues[op.device_id].ops.index(op)
362
  if device_index > 0:
363
+ prev_op = self.device_queues[op.device_id].ops[device_index - 1]
364
+
365
+ # Check if sync P2P should apply
366
+ # Sync P2P means sender waits for P2P transfer to complete before next op
367
+ # This adds p2p_latency to the device dependency gap
368
+ sync_p2p_gap = 0.0
369
+ if self.config.p2p_latency > 0:
370
+ is_prev_overlapped = isinstance(prev_op, OverlappedOperation)
371
+ is_current_overlapped = isinstance(op, OverlappedOperation)
372
+
373
+ # Only add sync P2P gap when:
374
+ # 1. Both ops are not overlapped (not in overlap schedule's steady state)
375
+ # 2. Both ops have the same base type (both forward or both backward)
376
+ # 3. Both ops are on the same stage (ensures we're in a pure warmup/cooldown
377
+ # sequence for a specific stage, avoiding cycles in interleaved schedules)
378
+ if not is_prev_overlapped and not is_current_overlapped:
379
+ prev_base_type = "backward" if prev_op.op_type.startswith("backward") else prev_op.op_type
380
+ curr_base_type = "backward" if op.op_type.startswith("backward") else op.op_type
381
+
382
+ if prev_base_type == curr_base_type and prev_op.stage_id == op.stage_id:
383
+ receiver_op = self.get_p2p_receiver_op(prev_op)
384
+ if receiver_op is not None and not isinstance(receiver_op, OverlappedOperation):
385
+ # Sync P2P: sender waits for P2P transfer to complete
386
+ # Current op starts after prev_op.end_time + p2p_latency
387
+ # (not after receiver completes, just after transfer completes)
388
+ sync_p2p_gap = self.config.p2p_latency
389
+
390
+ deps.append((prev_op, sync_p2p_gap))
391
  return deps
392
 
393
  def show(self):