Buckets:
| # Copyright (c) 2026 SandAI. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import abc | |
| import collections | |
| from dataclasses import dataclass | |
| from typing import Any, Dict, List, Set, Type | |
| import torch | |
| from torch.fx import Node | |
| from magi_compiler.config import CompileConfig | |
| from magi_compiler.utils import magi_logger | |
| class OffloadRuntimeContext: | |
| """ | |
| Wrap the dynamic runtime context needed during Offload | |
| """ | |
| env: Dict[Node, Any] # Tensor environment of the current computation graph | |
| h2d_stream: torch.cuda.Stream # CUDA stream for data transfer | |
| compute_stream: torch.cuda.Stream # CUDA stream for computation | |
| buffers: Dict[str, torch.Tensor] # GPU copies stored during Warmup | |
| submod_0_handoff: Dict[Node, torch.Tensor] # Container for weights used by Submod 0 for the next iteration | |
| need_profile: bool = False | |
| SchedulerType = Type["OffloadScheduler"] | |
| class SchedulerFactory: | |
| _REGISTRY: Dict[str, SchedulerType] = {} | |
| def register(cls, name: str): | |
| def decorator(scheduler_cls: SchedulerType): | |
| cls._REGISTRY[name] = scheduler_cls | |
| return scheduler_cls | |
| return decorator | |
| def create(cls, compile_config: CompileConfig, common_args: Dict[str, Any]) -> "OffloadScheduler": | |
| offload_policy_name = compile_config.offload_config.offload_policy.value | |
| scheduler_cls = cls._REGISTRY.get(offload_policy_name) | |
| if not scheduler_cls: | |
| raise ValueError(f"Unknown offload policy: {offload_policy_name}. " f"Available: {list(cls._REGISTRY.keys())}") | |
| return scheduler_cls(compile_config=compile_config, **common_args) | |
| class OffloadScheduler(abc.ABC): | |
| def __init__( | |
| self, | |
| compile_config: CompileConfig, | |
| submod_nodes: List[Node], | |
| submod_weights_map: Dict[str, List[str]], | |
| name_node_map: Dict[str, Node], | |
| weight_sizes: Dict[str, int], | |
| ): | |
| """ | |
| :param config: Compile configuration | |
| :param submod_nodes: List of submodule nodes in execution order | |
| :param submod_weights_map: {submod_name: [weight_names...]} | |
| :param name_node_map: {weight_name: weight_node} for looking up Node objects | |
| """ | |
| self.compile_config = compile_config | |
| self.submod_nodes = submod_nodes | |
| self.submod_num = len(submod_nodes) | |
| self.submod_weights_map = submod_weights_map | |
| self.name_node_map = name_node_map | |
| self.weight_sizes = weight_sizes | |
| self.kept_weights: Set[str] = set() | |
| def schedule_kept_weights(self, profile_data: Dict[str, float]): | |
| """ | |
| Static decision: determine which weights to keep on GPU based on profile data | |
| """ | |
| pass | |
| def prefetch(self, current_node_name: str, ctx: OffloadRuntimeContext): | |
| """ | |
| :param env: Tensor environment of the current execution environment | |
| :param h2d_stream: CUDA stream for data transfer | |
| :param buffers: GPU copies stored during Warmup | |
| :param submod_0_next_iter_weights: Container for weights used by Submod 0 for the next iteration | |
| """ | |
| pass | |
| def is_kept(self, weight_name: str) -> bool: | |
| return weight_name in self.kept_weights | |
| def get_keep_weight_size(self) -> int: | |
| return sum( | |
| [ | |
| self.name_node_map[w_name].meta.get("example_value").numel() | |
| * self.name_node_map[w_name].meta.get("example_value").element_size() | |
| for w_name in self.kept_weights | |
| ] | |
| ) | |
| class BaseScheduler(OffloadScheduler): | |
| """ | |
| Basic strategy implementation: | |
| 1. schedule_kept_weights: Default to not keep any weights on GPU (gpu_resident_weight_ratio=0.0) | |
| 2. prefetch: Inherit the original "Odd/Even + Next-Next" logic | |
| """ | |
| def __init__( | |
| self, | |
| compile_config: CompileConfig, | |
| submod_nodes: List[Node], | |
| submod_weights_map: Dict[str, List[str]], | |
| name_node_map: Dict[str, Node], | |
| weight_sizes: Dict[str, int], | |
| ): | |
| super().__init__(compile_config, submod_nodes, submod_weights_map, name_node_map, weight_sizes) | |
| self.submod_cpu_names = set() | |
| def schedule_kept_weights(self, profile_data: Dict[str, float]): | |
| self.kept_weights.clear() | |
| # offload all submodules to CPU | |
| self.submod_cpu_names = set(self.submod_nodes[i].name for i in range(len(self.submod_nodes))) | |
| def prefetch(self, current_node_name: str, ctx: OffloadRuntimeContext): | |
| env = ctx.env | |
| h2d_stream = ctx.h2d_stream | |
| buffers = ctx.buffers | |
| submod_0_next_iter_weights = ctx.submod_0_handoff | |
| compute_stream = ctx.compute_stream | |
| try: | |
| idx = int(current_node_name.split('_')[-1]) | |
| except ValueError: | |
| return | |
| max_lookahead = 2 | |
| target_node = None | |
| is_next_iter = False | |
| for offset in range(1, max_lookahead + 1): | |
| candidate_idx = (idx + offset) % self.submod_num | |
| candidate_node = self.submod_nodes[candidate_idx] | |
| if self.weight_sizes.get(candidate_node.name, 0) > 0: | |
| target_node = candidate_node | |
| if (idx + offset) >= self.submod_num: | |
| is_next_iter = True | |
| else: | |
| is_next_iter = False | |
| break | |
| if target_node is None: | |
| return | |
| last_target = getattr(self, "last_prefetched_target", None) | |
| if idx == 0: | |
| last_target = None | |
| if target_node.name == last_target: | |
| return | |
| self.last_prefetched_target = target_node.name | |
| compute_stream.wait_stream(h2d_stream) | |
| weight_names = self.submod_weights_map.get(target_node.name, []) | |
| with torch.cuda.stream(h2d_stream): | |
| for w_name in weight_names: | |
| if w_name in self.kept_weights: | |
| continue | |
| w_node = self.name_node_map.get(w_name) | |
| if not w_node: | |
| continue | |
| if is_next_iter: | |
| if w_name in buffers: | |
| submod_0_next_iter_weights[w_node] = buffers[w_name].to("cuda", non_blocking=True) | |
| elif w_node in env and hasattr(env[w_node], 'device') and env[w_node].device.type == "cpu": | |
| env[w_node] = env[w_node].to("cuda", non_blocking=True) | |
| class CostEffectiveScheduler(BaseScheduler): | |
| """ | |
| Cost-effective strategy implementation: | |
| Determine which weights to keep on GPU based on the ratio of time to size | |
| """ | |
| def schedule_kept_weights(self, profile_data: Dict[str, float]): | |
| weight_sizes = self.weight_sizes | |
| gpu_resident_weight_ratio = self.compile_config.offload_config.gpu_resident_weight_ratio | |
| if gpu_resident_weight_ratio <= 0.0: | |
| return | |
| candidates = [] | |
| total_weight_size = 0 | |
| for name, timing in profile_data.items(): | |
| if name == "h2d_bandwidth": | |
| continue | |
| duration = timing.get("compute", 0.0) | |
| size = weight_sizes.get(name, 0) | |
| ratio = duration / size if size > 0 else float('inf') | |
| candidates.append((name, ratio, size)) | |
| total_weight_size += size | |
| candidates.sort(key=lambda x: x[1]) | |
| current_kept_size = 0 | |
| limit = total_weight_size * gpu_resident_weight_ratio | |
| self.kept_weights.clear() | |
| for name, ratio, s in candidates: | |
| if name == "submod_0": | |
| continue | |
| if current_kept_size + s <= limit: | |
| current_kept_size += s | |
| for w_name in self.submod_weights_map.get(name, []): | |
| self.kept_weights.add(w_name) | |
| self.submod_cpu_names.add(name) | |
| else: | |
| break | |
| magi_logger.info(f"schedule_kept_weights size {current_kept_size} keep_submod_names: {self.submod_cpu_names}") | |
| class HeuristicScheduler(BaseScheduler): | |
| def __init__( | |
| self, | |
| compile_config: CompileConfig, | |
| submod_nodes: List[Node], | |
| submod_weights_map: Dict[str, List[str]], | |
| name_node_map: Dict[str, Node], | |
| weight_sizes: Dict[str, int], | |
| ): | |
| super().__init__(compile_config, submod_nodes, submod_weights_map, name_node_map, weight_sizes) | |
| # safety margin (ms), default 0.1ms, to prevent pipeline bubbles due to profile fluctuations | |
| self.safety_margin = getattr(compile_config, "prefetch_margin_ms", 0.1) | |
| # prefetch schedule: { trigger_node_name : target_submod_index } | |
| # when execute the trigger_node_name, prefetch the target_submod_index | |
| self.prefetch_schedule = collections.defaultdict() | |
| self.submod_load_events: Dict[str, torch.cuda.Event] = collections.defaultdict(torch.cuda.Event) | |
| for i in range(len(self.submod_nodes)): | |
| self.submod_load_events[self.submod_nodes[i].name] = torch.cuda.Event() | |
| def schedule_kept_weights(self, profile_data: Dict[str, Any]): | |
| """ | |
| Core logic: build the prefetch schedule based on the profile data. | |
| profile_data format: | |
| { | |
| "submod_name": { | |
| "compute": 10.5, | |
| "h2d": 5.2 | |
| }, | |
| "h2d_bandwidth": 5.2 | |
| } | |
| """ | |
| self.kept_weights.clear() | |
| self.prefetch_schedule.clear() | |
| weight_sizes = self.weight_sizes | |
| len(self.submod_nodes) | |
| compute_times = [] | |
| h2d_times = [] | |
| # GiB/s | |
| EST_BANDWIDTH = profile_data.get("h2d_bandwidth") * self.compile_config.offload_config.bandwidth_safety_factor | |
| for i, node in enumerate(self.submod_nodes): | |
| data = profile_data.get(node.name, {}) | |
| c_time = data.get("compute", 0.0) | |
| w_size = weight_sizes.get(node.name, 0) | |
| # w_size is in bytes, EST_BANDWIDTH is in GB/s (1e9 bytes/s) | |
| # h_time should be in milliseconds | |
| h_time = (w_size / (EST_BANDWIDTH * 1e6)) if EST_BANDWIDTH > 0 else 0 | |
| compute_times.append(c_time) | |
| h2d_times.append(h_time) | |
| keep_submod_names = collections.defaultdict(list) | |
| schedule_nodes = collections.defaultdict(dict) | |
| submod_cpu_name = set() | |
| cpu_weight_size = 0 | |
| for i in range(len(self.submod_nodes) - 1, -1, -1): | |
| if weight_sizes.get(self.submod_nodes[i].name, 0.0) > 0.0: | |
| offload_idx = i | |
| break | |
| # traverse submod_nodes in reverse order to ensure the latest start of transmission | |
| i = offload_idx - 1 | |
| schedule_node_compute_time = 0 | |
| while i >= 0: | |
| end_idx = i | |
| schedule_node_compute_time = 0 | |
| while schedule_node_compute_time < self.safety_margin + h2d_times[offload_idx] and i >= 0: | |
| schedule_node_compute_time += compute_times[i] | |
| i -= 1 | |
| if i < 0 and schedule_node_compute_time < self.safety_margin + h2d_times[offload_idx]: | |
| break | |
| submod_cpu_name.add(self.submod_nodes[offload_idx].name) | |
| cpu_weight_size += weight_sizes.get(self.submod_nodes[offload_idx].name, 0) | |
| schedule_nodes[offload_idx].update( | |
| { | |
| "compute_start_idx": i + 1, | |
| "compute_end_idx": end_idx, | |
| "h2d_idx": offload_idx, | |
| "compute_time": schedule_node_compute_time, | |
| "h2d_time": h2d_times[offload_idx], | |
| "ratio": schedule_node_compute_time / (self.safety_margin + h2d_times[offload_idx]), | |
| } | |
| ) | |
| offload_idx = i + 1 | |
| while weight_sizes.get(self.submod_nodes[offload_idx].name, 0.0) == 0.0 and offload_idx <= end_idx: | |
| offload_idx += 1 | |
| gpu_resident_weight_ratio = self.compile_config.offload_config.gpu_resident_weight_ratio | |
| total_weight_size = sum(weight_sizes.values()) | |
| limit = total_weight_size * gpu_resident_weight_ratio | |
| keep_gpu_size = total_weight_size - cpu_weight_size | |
| while limit < keep_gpu_size: | |
| schedule_nodes_list = [(i, s_node["ratio"]) for i, s_node in schedule_nodes.items()] | |
| schedule_nodes_list.sort(key=lambda x: x[1], reverse=True) | |
| changed = False | |
| for s_node_idx, _ in schedule_nodes_list: | |
| s_node = schedule_nodes[s_node_idx] | |
| compute_start_idx = s_node["compute_start_idx"] | |
| for i in range(s_node["compute_start_idx"] + 1, s_node["compute_end_idx"] + 1): | |
| if self.submod_nodes[i].name in submod_cpu_name or weight_sizes.get(self.submod_nodes[i].name, 0.0) == 0.0: | |
| continue | |
| offload_idx = i | |
| schedule_node_compute_time = sum(compute_times[compute_start_idx:i]) | |
| schedule_nodes[offload_idx].update( | |
| { | |
| "compute_start_idx": compute_start_idx, | |
| "compute_end_idx": max(i - 1, compute_start_idx), | |
| "h2d_idx": offload_idx, | |
| "compute_time": schedule_node_compute_time, | |
| "h2d_time": h2d_times[offload_idx], | |
| "ratio": schedule_node_compute_time / (self.safety_margin + h2d_times[offload_idx]), | |
| } | |
| ) | |
| submod_cpu_name.add(self.submod_nodes[i].name) | |
| keep_gpu_size -= weight_sizes.get(self.submod_nodes[i].name, 0) | |
| schedule_nodes[s_node_idx].update( | |
| { | |
| "compute_start_idx": i, | |
| "compute_end_idx": s_node["compute_end_idx"], | |
| "h2d_idx": s_node_idx, | |
| "compute_time": s_node["compute_time"] - schedule_node_compute_time, | |
| "h2d_time": h2d_times[s_node_idx], | |
| "ratio": (s_node["compute_time"] - schedule_node_compute_time) | |
| / (self.safety_margin + h2d_times[s_node_idx]), | |
| } | |
| ) | |
| changed = True | |
| break | |
| if changed: | |
| break | |
| if not changed: | |
| break | |
| self.submod_cpu_names = submod_cpu_name | |
| for i, s_node in schedule_nodes.items(): | |
| self.prefetch_schedule[self.submod_nodes[s_node["compute_start_idx"]].name] = s_node["h2d_idx"] | |
| # add weights of submod that are not on CPU | |
| keep_submod_names = set() | |
| for i in range(len(self.submod_nodes)): | |
| if self.submod_nodes[i].name not in submod_cpu_name: | |
| for w_name in self.submod_weights_map.get(self.submod_nodes[i].name, []): | |
| self.kept_weights.add(w_name) | |
| keep_submod_names.add(self.submod_nodes[i].name) | |
| def prefetch(self, current_node_name: str, ctx: OffloadRuntimeContext): | |
| """ | |
| 1. check if the current node is on CPU, if yes, wait for the submod load event | |
| 2. if not on CPU, skip | |
| 3. prefetch the weights of the submod that the current node is pointing to | |
| """ | |
| if ctx.need_profile: | |
| return super().prefetch(current_node_name, ctx) | |
| env = ctx.env | |
| h2d_stream = ctx.h2d_stream | |
| buffers = ctx.buffers | |
| submod_0_next_iter_weights = ctx.submod_0_handoff | |
| compute_stream = ctx.compute_stream | |
| if current_node_name in self.submod_cpu_names: | |
| compute_stream.wait_event(self.submod_load_events[current_node_name]) | |
| h2d_idx = self.prefetch_schedule.get(current_node_name, None) | |
| if h2d_idx is None: | |
| return | |
| compute_complete_event = torch.cuda.Event() | |
| compute_complete_event.record(compute_stream) | |
| h2d_stream.wait_event(compute_complete_event) | |
| h2d_node = self.submod_nodes[h2d_idx] | |
| is_last_step = h2d_idx == len(self.submod_nodes) - 1 | |
| weight_names = self.submod_weights_map.get(h2d_node.name, []) | |
| with torch.cuda.stream(h2d_stream): | |
| for w_name in weight_names: | |
| if w_name in self.kept_weights: | |
| continue | |
| w_node = self.name_node_map.get(w_name) | |
| if w_node in env and hasattr(env[w_node], 'device') and env[w_node].device.type == "cpu": | |
| env[w_node] = env[w_node].to("cuda", non_blocking=True) | |
| elif is_last_step and w_name in buffers: | |
| submod_0_next_iter_weights[w_node] = buffers[w_name].to("cuda", non_blocking=True) | |
| self.submod_load_events[h2d_node.name].record(h2d_stream) | |
Xet Storage Details
- Size:
- 17.7 kB
- Xet hash:
- 0bf328e052b35f5b19d7063c7a0e3d7b61cb8ad5dcad504c10c96f0c78dcc9c4
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.