download
raw
17.7 kB
# Copyright (c) 2026 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import collections
from dataclasses import dataclass
from typing import Any, Dict, List, Set, Type
import torch
from torch.fx import Node
from magi_compiler.config import CompileConfig
from magi_compiler.utils import magi_logger
@dataclass
class OffloadRuntimeContext:
"""
Wrap the dynamic runtime context needed during Offload
"""
env: Dict[Node, Any] # Tensor environment of the current computation graph
h2d_stream: torch.cuda.Stream # CUDA stream for data transfer
compute_stream: torch.cuda.Stream # CUDA stream for computation
buffers: Dict[str, torch.Tensor] # GPU copies stored during Warmup
submod_0_handoff: Dict[Node, torch.Tensor] # Container for weights used by Submod 0 for the next iteration
need_profile: bool = False
SchedulerType = Type["OffloadScheduler"]
class SchedulerFactory:
_REGISTRY: Dict[str, SchedulerType] = {}
@classmethod
def register(cls, name: str):
def decorator(scheduler_cls: SchedulerType):
cls._REGISTRY[name] = scheduler_cls
return scheduler_cls
return decorator
@classmethod
def create(cls, compile_config: CompileConfig, common_args: Dict[str, Any]) -> "OffloadScheduler":
offload_policy_name = compile_config.offload_config.offload_policy.value
scheduler_cls = cls._REGISTRY.get(offload_policy_name)
if not scheduler_cls:
raise ValueError(f"Unknown offload policy: {offload_policy_name}. " f"Available: {list(cls._REGISTRY.keys())}")
return scheduler_cls(compile_config=compile_config, **common_args)
class OffloadScheduler(abc.ABC):
def __init__(
self,
compile_config: CompileConfig,
submod_nodes: List[Node],
submod_weights_map: Dict[str, List[str]],
name_node_map: Dict[str, Node],
weight_sizes: Dict[str, int],
):
"""
:param config: Compile configuration
:param submod_nodes: List of submodule nodes in execution order
:param submod_weights_map: {submod_name: [weight_names...]}
:param name_node_map: {weight_name: weight_node} for looking up Node objects
"""
self.compile_config = compile_config
self.submod_nodes = submod_nodes
self.submod_num = len(submod_nodes)
self.submod_weights_map = submod_weights_map
self.name_node_map = name_node_map
self.weight_sizes = weight_sizes
self.kept_weights: Set[str] = set()
@abc.abstractmethod
def schedule_kept_weights(self, profile_data: Dict[str, float]):
"""
Static decision: determine which weights to keep on GPU based on profile data
"""
pass
@abc.abstractmethod
def prefetch(self, current_node_name: str, ctx: OffloadRuntimeContext):
"""
:param env: Tensor environment of the current execution environment
:param h2d_stream: CUDA stream for data transfer
:param buffers: GPU copies stored during Warmup
:param submod_0_next_iter_weights: Container for weights used by Submod 0 for the next iteration
"""
pass
def is_kept(self, weight_name: str) -> bool:
return weight_name in self.kept_weights
def get_keep_weight_size(self) -> int:
return sum(
[
self.name_node_map[w_name].meta.get("example_value").numel()
* self.name_node_map[w_name].meta.get("example_value").element_size()
for w_name in self.kept_weights
]
)
@SchedulerFactory.register("BASE")
class BaseScheduler(OffloadScheduler):
"""
Basic strategy implementation:
1. schedule_kept_weights: Default to not keep any weights on GPU (gpu_resident_weight_ratio=0.0)
2. prefetch: Inherit the original "Odd/Even + Next-Next" logic
"""
def __init__(
self,
compile_config: CompileConfig,
submod_nodes: List[Node],
submod_weights_map: Dict[str, List[str]],
name_node_map: Dict[str, Node],
weight_sizes: Dict[str, int],
):
super().__init__(compile_config, submod_nodes, submod_weights_map, name_node_map, weight_sizes)
self.submod_cpu_names = set()
def schedule_kept_weights(self, profile_data: Dict[str, float]):
self.kept_weights.clear()
# offload all submodules to CPU
self.submod_cpu_names = set(self.submod_nodes[i].name for i in range(len(self.submod_nodes)))
def prefetch(self, current_node_name: str, ctx: OffloadRuntimeContext):
env = ctx.env
h2d_stream = ctx.h2d_stream
buffers = ctx.buffers
submod_0_next_iter_weights = ctx.submod_0_handoff
compute_stream = ctx.compute_stream
try:
idx = int(current_node_name.split('_')[-1])
except ValueError:
return
max_lookahead = 2
target_node = None
is_next_iter = False
for offset in range(1, max_lookahead + 1):
candidate_idx = (idx + offset) % self.submod_num
candidate_node = self.submod_nodes[candidate_idx]
if self.weight_sizes.get(candidate_node.name, 0) > 0:
target_node = candidate_node
if (idx + offset) >= self.submod_num:
is_next_iter = True
else:
is_next_iter = False
break
if target_node is None:
return
last_target = getattr(self, "last_prefetched_target", None)
if idx == 0:
last_target = None
if target_node.name == last_target:
return
self.last_prefetched_target = target_node.name
compute_stream.wait_stream(h2d_stream)
weight_names = self.submod_weights_map.get(target_node.name, [])
with torch.cuda.stream(h2d_stream):
for w_name in weight_names:
if w_name in self.kept_weights:
continue
w_node = self.name_node_map.get(w_name)
if not w_node:
continue
if is_next_iter:
if w_name in buffers:
submod_0_next_iter_weights[w_node] = buffers[w_name].to("cuda", non_blocking=True)
elif w_node in env and hasattr(env[w_node], 'device') and env[w_node].device.type == "cpu":
env[w_node] = env[w_node].to("cuda", non_blocking=True)
@SchedulerFactory.register("COST_EFFECTIVE")
class CostEffectiveScheduler(BaseScheduler):
"""
Cost-effective strategy implementation:
Determine which weights to keep on GPU based on the ratio of time to size
"""
def schedule_kept_weights(self, profile_data: Dict[str, float]):
weight_sizes = self.weight_sizes
gpu_resident_weight_ratio = self.compile_config.offload_config.gpu_resident_weight_ratio
if gpu_resident_weight_ratio <= 0.0:
return
candidates = []
total_weight_size = 0
for name, timing in profile_data.items():
if name == "h2d_bandwidth":
continue
duration = timing.get("compute", 0.0)
size = weight_sizes.get(name, 0)
ratio = duration / size if size > 0 else float('inf')
candidates.append((name, ratio, size))
total_weight_size += size
candidates.sort(key=lambda x: x[1])
current_kept_size = 0
limit = total_weight_size * gpu_resident_weight_ratio
self.kept_weights.clear()
for name, ratio, s in candidates:
if name == "submod_0":
continue
if current_kept_size + s <= limit:
current_kept_size += s
for w_name in self.submod_weights_map.get(name, []):
self.kept_weights.add(w_name)
self.submod_cpu_names.add(name)
else:
break
magi_logger.info(f"schedule_kept_weights size {current_kept_size} keep_submod_names: {self.submod_cpu_names}")
@SchedulerFactory.register("HEURISTIC")
class HeuristicScheduler(BaseScheduler):
def __init__(
self,
compile_config: CompileConfig,
submod_nodes: List[Node],
submod_weights_map: Dict[str, List[str]],
name_node_map: Dict[str, Node],
weight_sizes: Dict[str, int],
):
super().__init__(compile_config, submod_nodes, submod_weights_map, name_node_map, weight_sizes)
# safety margin (ms), default 0.1ms, to prevent pipeline bubbles due to profile fluctuations
self.safety_margin = getattr(compile_config, "prefetch_margin_ms", 0.1)
# prefetch schedule: { trigger_node_name : target_submod_index }
# when execute the trigger_node_name, prefetch the target_submod_index
self.prefetch_schedule = collections.defaultdict()
self.submod_load_events: Dict[str, torch.cuda.Event] = collections.defaultdict(torch.cuda.Event)
for i in range(len(self.submod_nodes)):
self.submod_load_events[self.submod_nodes[i].name] = torch.cuda.Event()
def schedule_kept_weights(self, profile_data: Dict[str, Any]):
"""
Core logic: build the prefetch schedule based on the profile data.
profile_data format:
{
"submod_name": {
"compute": 10.5,
"h2d": 5.2
},
"h2d_bandwidth": 5.2
}
"""
self.kept_weights.clear()
self.prefetch_schedule.clear()
weight_sizes = self.weight_sizes
len(self.submod_nodes)
compute_times = []
h2d_times = []
# GiB/s
EST_BANDWIDTH = profile_data.get("h2d_bandwidth") * self.compile_config.offload_config.bandwidth_safety_factor
for i, node in enumerate(self.submod_nodes):
data = profile_data.get(node.name, {})
c_time = data.get("compute", 0.0)
w_size = weight_sizes.get(node.name, 0)
# w_size is in bytes, EST_BANDWIDTH is in GB/s (1e9 bytes/s)
# h_time should be in milliseconds
h_time = (w_size / (EST_BANDWIDTH * 1e6)) if EST_BANDWIDTH > 0 else 0
compute_times.append(c_time)
h2d_times.append(h_time)
keep_submod_names = collections.defaultdict(list)
schedule_nodes = collections.defaultdict(dict)
submod_cpu_name = set()
cpu_weight_size = 0
for i in range(len(self.submod_nodes) - 1, -1, -1):
if weight_sizes.get(self.submod_nodes[i].name, 0.0) > 0.0:
offload_idx = i
break
# traverse submod_nodes in reverse order to ensure the latest start of transmission
i = offload_idx - 1
schedule_node_compute_time = 0
while i >= 0:
end_idx = i
schedule_node_compute_time = 0
while schedule_node_compute_time < self.safety_margin + h2d_times[offload_idx] and i >= 0:
schedule_node_compute_time += compute_times[i]
i -= 1
if i < 0 and schedule_node_compute_time < self.safety_margin + h2d_times[offload_idx]:
break
submod_cpu_name.add(self.submod_nodes[offload_idx].name)
cpu_weight_size += weight_sizes.get(self.submod_nodes[offload_idx].name, 0)
schedule_nodes[offload_idx].update(
{
"compute_start_idx": i + 1,
"compute_end_idx": end_idx,
"h2d_idx": offload_idx,
"compute_time": schedule_node_compute_time,
"h2d_time": h2d_times[offload_idx],
"ratio": schedule_node_compute_time / (self.safety_margin + h2d_times[offload_idx]),
}
)
offload_idx = i + 1
while weight_sizes.get(self.submod_nodes[offload_idx].name, 0.0) == 0.0 and offload_idx <= end_idx:
offload_idx += 1
gpu_resident_weight_ratio = self.compile_config.offload_config.gpu_resident_weight_ratio
total_weight_size = sum(weight_sizes.values())
limit = total_weight_size * gpu_resident_weight_ratio
keep_gpu_size = total_weight_size - cpu_weight_size
while limit < keep_gpu_size:
schedule_nodes_list = [(i, s_node["ratio"]) for i, s_node in schedule_nodes.items()]
schedule_nodes_list.sort(key=lambda x: x[1], reverse=True)
changed = False
for s_node_idx, _ in schedule_nodes_list:
s_node = schedule_nodes[s_node_idx]
compute_start_idx = s_node["compute_start_idx"]
for i in range(s_node["compute_start_idx"] + 1, s_node["compute_end_idx"] + 1):
if self.submod_nodes[i].name in submod_cpu_name or weight_sizes.get(self.submod_nodes[i].name, 0.0) == 0.0:
continue
offload_idx = i
schedule_node_compute_time = sum(compute_times[compute_start_idx:i])
schedule_nodes[offload_idx].update(
{
"compute_start_idx": compute_start_idx,
"compute_end_idx": max(i - 1, compute_start_idx),
"h2d_idx": offload_idx,
"compute_time": schedule_node_compute_time,
"h2d_time": h2d_times[offload_idx],
"ratio": schedule_node_compute_time / (self.safety_margin + h2d_times[offload_idx]),
}
)
submod_cpu_name.add(self.submod_nodes[i].name)
keep_gpu_size -= weight_sizes.get(self.submod_nodes[i].name, 0)
schedule_nodes[s_node_idx].update(
{
"compute_start_idx": i,
"compute_end_idx": s_node["compute_end_idx"],
"h2d_idx": s_node_idx,
"compute_time": s_node["compute_time"] - schedule_node_compute_time,
"h2d_time": h2d_times[s_node_idx],
"ratio": (s_node["compute_time"] - schedule_node_compute_time)
/ (self.safety_margin + h2d_times[s_node_idx]),
}
)
changed = True
break
if changed:
break
if not changed:
break
self.submod_cpu_names = submod_cpu_name
for i, s_node in schedule_nodes.items():
self.prefetch_schedule[self.submod_nodes[s_node["compute_start_idx"]].name] = s_node["h2d_idx"]
# add weights of submod that are not on CPU
keep_submod_names = set()
for i in range(len(self.submod_nodes)):
if self.submod_nodes[i].name not in submod_cpu_name:
for w_name in self.submod_weights_map.get(self.submod_nodes[i].name, []):
self.kept_weights.add(w_name)
keep_submod_names.add(self.submod_nodes[i].name)
def prefetch(self, current_node_name: str, ctx: OffloadRuntimeContext):
"""
1. check if the current node is on CPU, if yes, wait for the submod load event
2. if not on CPU, skip
3. prefetch the weights of the submod that the current node is pointing to
"""
if ctx.need_profile:
return super().prefetch(current_node_name, ctx)
env = ctx.env
h2d_stream = ctx.h2d_stream
buffers = ctx.buffers
submod_0_next_iter_weights = ctx.submod_0_handoff
compute_stream = ctx.compute_stream
if current_node_name in self.submod_cpu_names:
compute_stream.wait_event(self.submod_load_events[current_node_name])
h2d_idx = self.prefetch_schedule.get(current_node_name, None)
if h2d_idx is None:
return
compute_complete_event = torch.cuda.Event()
compute_complete_event.record(compute_stream)
h2d_stream.wait_event(compute_complete_event)
h2d_node = self.submod_nodes[h2d_idx]
is_last_step = h2d_idx == len(self.submod_nodes) - 1
weight_names = self.submod_weights_map.get(h2d_node.name, [])
with torch.cuda.stream(h2d_stream):
for w_name in weight_names:
if w_name in self.kept_weights:
continue
w_node = self.name_node_map.get(w_name)
if w_node in env and hasattr(env[w_node], 'device') and env[w_node].device.type == "cpu":
env[w_node] = env[w_node].to("cuda", non_blocking=True)
elif is_last_step and w_name in buffers:
submod_0_next_iter_weights[w_node] = buffers[w_name].to("cuda", non_blocking=True)
self.submod_load_events[h2d_node.name].record(h2d_stream)

Xet Storage Details

Size:
17.7 kB
·
Xet hash:
0bf328e052b35f5b19d7063c7a0e3d7b61cb8ad5dcad504c10c96f0c78dcc9c4

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.