Spaces:
Running
Running
Make P2P in warmup/cooldown stage to sync comm.
Browse files- src/execution_model.py +61 -1
src/execution_model.py
CHANGED
|
@@ -243,6 +243,39 @@ class Schedule:
|
|
| 243 |
return None
|
| 244 |
return self.ops[(batch_id, stage_id, op_type)]
|
| 245 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 246 |
def get_dependencies(self, op: Operation, include_device_dependency=True):
|
| 247 |
deps = []
|
| 248 |
if isinstance(op, OverlappedOperation):
|
|
@@ -327,7 +360,34 @@ class Schedule:
|
|
| 327 |
if include_device_dependency:
|
| 328 |
device_index = self.device_queues[op.device_id].ops.index(op)
|
| 329 |
if device_index > 0:
|
| 330 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 331 |
return deps
|
| 332 |
|
| 333 |
def show(self):
|
|
|
|
| 243 |
return None
|
| 244 |
return self.ops[(batch_id, stage_id, op_type)]
|
| 245 |
|
| 246 |
+
def get_p2p_receiver_op(self, sender_op: Operation) -> Optional[Operation]:
|
| 247 |
+
"""
|
| 248 |
+
Get the operation that receives P2P data from sender_op.
|
| 249 |
+
|
| 250 |
+
For forward ops: sender on stage N sends to receiver on stage N+1
|
| 251 |
+
For backward ops: sender on stage N sends to receiver on stage N-1
|
| 252 |
+
|
| 253 |
+
Returns None if there is no P2P receiver (first/last stage).
|
| 254 |
+
"""
|
| 255 |
+
if isinstance(sender_op, OverlappedOperation):
|
| 256 |
+
# For overlapped ops, return None (P2P is overlapped with computation)
|
| 257 |
+
return None
|
| 258 |
+
|
| 259 |
+
if sender_op.op_type == "forward":
|
| 260 |
+
# Forward sends to next stage
|
| 261 |
+
next_stage = sender_op.stage_id + 1
|
| 262 |
+
if next_stage >= self.config.num_stages:
|
| 263 |
+
return None # Last stage, no P2P
|
| 264 |
+
return self.get_op(sender_op.batch_id, next_stage, "forward", allow_none=True)
|
| 265 |
+
|
| 266 |
+
elif sender_op.op_type in ("backward", "backward_D"):
|
| 267 |
+
# Backward sends to previous stage
|
| 268 |
+
prev_stage = sender_op.stage_id - 1
|
| 269 |
+
if prev_stage < 0:
|
| 270 |
+
return None # First stage, no P2P
|
| 271 |
+
# Try backward_D first, then backward
|
| 272 |
+
receiver = self.get_op(sender_op.batch_id, prev_stage, "backward_D", allow_none=True)
|
| 273 |
+
if receiver is None:
|
| 274 |
+
receiver = self.get_op(sender_op.batch_id, prev_stage, "backward", allow_none=True)
|
| 275 |
+
return receiver
|
| 276 |
+
|
| 277 |
+
return None
|
| 278 |
+
|
| 279 |
def get_dependencies(self, op: Operation, include_device_dependency=True):
|
| 280 |
deps = []
|
| 281 |
if isinstance(op, OverlappedOperation):
|
|
|
|
| 360 |
if include_device_dependency:
|
| 361 |
device_index = self.device_queues[op.device_id].ops.index(op)
|
| 362 |
if device_index > 0:
|
| 363 |
+
prev_op = self.device_queues[op.device_id].ops[device_index - 1]
|
| 364 |
+
|
| 365 |
+
# Check if sync P2P should apply
|
| 366 |
+
# Sync P2P means sender waits for P2P transfer to complete before next op
|
| 367 |
+
# This adds p2p_latency to the device dependency gap
|
| 368 |
+
sync_p2p_gap = 0.0
|
| 369 |
+
if self.config.p2p_latency > 0:
|
| 370 |
+
is_prev_overlapped = isinstance(prev_op, OverlappedOperation)
|
| 371 |
+
is_current_overlapped = isinstance(op, OverlappedOperation)
|
| 372 |
+
|
| 373 |
+
# Only add sync P2P gap when:
|
| 374 |
+
# 1. Both ops are not overlapped (not in overlap schedule's steady state)
|
| 375 |
+
# 2. Both ops have the same base type (both forward or both backward)
|
| 376 |
+
# 3. Both ops are on the same stage (ensures we're in a pure warmup/cooldown
|
| 377 |
+
# sequence for a specific stage, avoiding cycles in interleaved schedules)
|
| 378 |
+
if not is_prev_overlapped and not is_current_overlapped:
|
| 379 |
+
prev_base_type = "backward" if prev_op.op_type.startswith("backward") else prev_op.op_type
|
| 380 |
+
curr_base_type = "backward" if op.op_type.startswith("backward") else op.op_type
|
| 381 |
+
|
| 382 |
+
if prev_base_type == curr_base_type and prev_op.stage_id == op.stage_id:
|
| 383 |
+
receiver_op = self.get_p2p_receiver_op(prev_op)
|
| 384 |
+
if receiver_op is not None and not isinstance(receiver_op, OverlappedOperation):
|
| 385 |
+
# Sync P2P: sender waits for P2P transfer to complete
|
| 386 |
+
# Current op starts after prev_op.end_time + p2p_latency
|
| 387 |
+
# (not after receiver completes, just after transfer completes)
|
| 388 |
+
sync_p2p_gap = self.config.p2p_latency
|
| 389 |
+
|
| 390 |
+
deps.append((prev_op, sync_p2p_gap))
|
| 391 |
return deps
|
| 392 |
|
| 393 |
def show(self):
|