|
|
import os
import time
from collections import OrderedDict
from functools import partial, reduce
from typing import Any, Dict, Iterable, List, Optional, Tuple

import pyecharts
import torch

from internlm.core.naive_amp import NaiveAMPModel
|
|
|
|
|
# Number of bytes in one mebibyte; byte counts are divided by this for MB reports.
mb = 1024**2
|
|
|
|
|
|
|
|
class SimpleMemState:
    """
    A node in a tree that records the memory usage of a model layer.

    Every node stores the memory attributed directly to its own layer and an
    ordered mapping of child nodes (sub-layers). The aggregated total (own
    memory plus all descendants) is recomputed by ``update_total_memory``.

    Args:
        layer_name (str): The name of the layer.
        layer_mem (int): The memory usage of the layer in bytes.
    """

    def __init__(self, layer_name: str, layer_mem: int = 0) -> None:
        self.layer_name = layer_name

        # Memory attributed to this layer itself, in bytes.
        self._layer_mem: int = layer_mem
        # Memory of this layer plus its sub-layers; equals the layer memory
        # until sub-layers are added and update_total_memory() is called.
        self._total_mem: int = self._layer_mem
        # Child states keyed by sub-layer name, kept in insertion order.
        self.sub_model_stats = OrderedDict()

    @property
    def layer_mem(self) -> int:
        """
        Get the memory usage of the layer.

        Returns:
            int: The memory usage of the layer in bytes.
        """
        return self._layer_mem

    @layer_mem.setter
    def layer_mem(self, new_layer_mem: int) -> None:
        """
        Set the memory usage of the layer.

        The cached total of this node is shifted by the same delta; totals of
        ancestor nodes are only refreshed by ``update_total_memory``.

        Args:
            new_layer_mem (int): The new memory usage of the layer in bytes.
        """
        diff = new_layer_mem - self._layer_mem
        self._layer_mem = new_layer_mem
        self._total_mem += diff

    @property
    def total_mem(self) -> int:
        """
        Get the total memory usage of the model and sub-models.

        Returns:
            int: The total memory usage in bytes.
        """
        return self._total_mem

    def add(self, layer_name: str, layer_mem: int = 0, flush: bool = True) -> None:
        """
        Add a layer to the memory state, creating intermediate nodes as needed.

        Args:
            layer_name (str): The dot-separated name of the layer, relative to this node.
            layer_mem (int, optional): The memory usage of the layer in bytes. Defaults to 0.
            flush (bool, optional): Whether to update the total memory usage. Defaults to True.
        """
        path = layer_name.split(".")

        target = self.find_layer_state(path, create=True)
        target.layer_mem = layer_mem

        if flush:
            self.update_total_memory()

    def delete(self, layer_name: str, flush: bool = True) -> None:
        """
        Delete a layer from the memory state.

        Deleting a layer that does not exist is a no-op.

        Args:
            layer_name (str): The dot-separated name of the layer; must contain at least two components.
            flush (bool, optional): Whether to update the total memory usage. Defaults to True.
        """
        path = layer_name.split(".")
        assert len(path) >= 2, f"Only support deleting non-root layers, layer_name: {layer_name}"

        parent_path = path[0:-1]
        layer = path[-1]
        parent = self.find_layer_state(parent_path)

        if parent is not None and layer in parent.sub_model_stats:
            del parent.sub_model_stats[layer]

        if flush:
            self.update_total_memory()

    def update_total_memory(self) -> None:
        """
        Update the total memory usage of the model and sub-models.

        Recomputes the totals bottom-up for the whole subtree rooted at this node.
        """
        self._total_mem = self._layer_mem

        for stat in self.sub_model_stats.values():
            # Refresh the child first so its cached total is current before it is summed.
            stat.update_total_memory()
            self._total_mem += stat._total_mem

    def find_layer_state(self, path: Iterable[str], create: bool = False) -> Optional["SimpleMemState"]:
        """
        Find the memory state of a layer.

        Args:
            path (Iterable[str]): The sequence of sub-layer names leading to the layer.
            create (bool, optional): Whether to create the layer if it doesn't exist. Defaults to False.

        Returns:
            Optional[SimpleMemState]: The memory state of the layer, or None when a
                component of the path is missing and ``create`` is False.
        """
        current_node = self

        for _node in path:
            if _node not in current_node.sub_model_stats:
                if not create:
                    return None

                current_node.sub_model_stats[_node] = SimpleMemState(_node)

            current_node = current_node.sub_model_stats[_node]

        return current_node

    def dump(self, prefix: str = "") -> str:
        """
        Dump the memory state of the model and sub-models.

        Args:
            prefix (str, optional): The prefix to add to the layer names. Defaults to "".

        Returns:
            str: The memory state information, one line per layer in depth-first order.
        """
        cur_prefix = prefix + "." + self.layer_name if prefix != "" else self.layer_name
        lines = [
            f"layer: {cur_prefix}, layer_mem: {self.layer_mem / mb:.2f} MB, total_mem: {self.total_mem / mb:.2f} MB\n"
        ]
        lines.extend(sub_layer.dump(cur_prefix) for sub_layer in self.sub_model_stats.values())

        return "".join(lines)

    def to_json(self, base: int = 1024 * 1024) -> dict:
        """
        Convert the memory state to a JSON structure.

        Leaf nodes are emitted as ``{"name", "value"}`` and nodes with
        sub-layers as ``{"name", "children"}``; the own ``layer_mem`` of a
        node that has sub-layers is not included in the output.

        Args:
            base (int, optional): The divisor applied to byte counts (integer division).
                Defaults to 1024 * 1024, i.e. values are reported in MB.

        Returns:
            dict: The JSON structure of the memory state.
        """
        # Pass ``base`` down so every level of the tree uses the same unit.
        children = [child.to_json(base) for child in self.sub_model_stats.values()]
        if not children:
            return {"name": self.layer_name, "value": self.layer_mem // base}

        return {"name": self.layer_name, "children": children}
|
|
|
|
|
|
|
|
class ActivationMemState:
    """
    Activation Memory State

    Holds one ``SimpleMemState`` tree (named ``activations_<idx>``) and one
    "inited" flag per model chunk.

    Args:
        num_chunks (int): The number of model chunks to track.
    """

    def __init__(self, num_chunks: int) -> None:
        self._num_chunks = num_chunks

        # Per-chunk flag, initially False; this class never changes it itself.
        self.inited: List[bool] = [False for _ in range(num_chunks)]
        self.states: List[SimpleMemState] = [SimpleMemState(f"activations_{idx}") for idx in range(num_chunks)]

    @property
    def total_mem(self) -> int:
        """
        Get the summed total memory of all chunk states.

        Returns:
            int: The total memory usage in bytes.
        """
        return sum(state.total_mem for state in self.states)

    def dump(self, prefix: str = "") -> str:
        """
        Dump the memory state of every chunk.

        Args:
            prefix (str, optional): The prefix to add to the layer names. Defaults to "".

        Returns:
            str: The concatenated dumps of all chunks; an empty string when there are no chunks.
        """
        # join() copes with zero chunks, where reduce() without an initial value raises TypeError.
        return "".join(state.dump(prefix) for state in self.states)

    def to_json(self, base: int = 1024 * 1024) -> List:
        """
        Convert every chunk state to its JSON structure.

        Args:
            base (int, optional): The divisor forwarded to each chunk state's ``to_json``.
                Defaults to 1024 * 1024.

        Returns:
            List: One JSON structure per chunk.
        """
        return [state.to_json(base) for state in self.states]
|
|
|
|
|
|
|
|
def _unpack_naive_wrapper(model: torch.nn.Module) -> Tuple[torch.nn.Module, int]:
    """
    Strip ``NaiveAMPModel`` wrappers from a model or from each chunk of a model.

    Args:
        model (torch.nn.Module): A single model, or a ``torch.nn.ModuleList`` of model chunks.

    Returns:
        Tuple[torch.nn.Module, int]: The unwrapped model and the number of chunks. The chunk
            count is ``len(model)`` for a ``ModuleList`` and 1 otherwise; chunks are only
            unwrapped individually when there is more than one of them.
    """

    def _strip(module: Any) -> Any:
        # Return the wrapped inner model when present, otherwise the module itself.
        if isinstance(module, NaiveAMPModel):
            return module.model
        return module

    if isinstance(model, torch.nn.ModuleList):
        num_chunks = len(model)
    else:
        num_chunks = 1

    if num_chunks <= 1:
        return _strip(model), num_chunks

    unwrapped_chunks = [_strip(chunk) for chunk in model]
    return torch.nn.ModuleList(unwrapped_chunks), num_chunks
|
|
|
|
|
|
|
|
class SimpleMemoryProfiler:
    """
    A memory profiler for a llm model.

    On construction it records parameter, gradient and optimizer-parameter
    memory and registers forward/backward hooks that track activation memory.
    Each call to ``step`` records the optimizer state memory; on the last
    traced step a full report and sunburst charts are written to the log folder.

    Args:
        model (torch.nn.Module): The model to profile.
        optimizer (torch.optim.Optimizer): The optimizer used for training the model.
        log_folder (str): The folder to write the memory log and the charts to.
        total_steps: number of steps to trace. Values <= 0 stop tracing on the first ``step`` call.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        log_folder: str,
        total_steps: int = 5,
    ):
        self._model, self._num_model_chunks = _unpack_naive_wrapper(model)
        self._optimizer = optimizer
        self._log_folder = log_folder
        self._remaining_steps = total_steps

        self._stoped = False
        self._record_start_time = time.time()

        # Activation memory currently accounted for, its peak, and the per-chunk layout.
        self._activation_mem: int = 0
        self._activation_mem_max: int = 0
        self._activation_base_mems = ActivationMemState(self._num_model_chunks)

        # The log folder must exist before the first point() call below.
        os.makedirs(self._log_folder, exist_ok=True)

        # Register activation trace hooks on every chunk.
        if self._num_model_chunks > 1:
            for chunk_id in range(self._num_model_chunks):
                self._register_activation_trace_hooks(chunk_id, self._model[chunk_id])
        else:
            self._register_activation_trace_hooks(0, self._model)

        # Parameter and gradient memory are computed once from the model parameters.
        self._param_mem_state = SimpleMemState("param_mem")
        self._calc_tensor_memory(self._param_mem_state, self._model.named_parameters())

        self._grad_mem_state = SimpleMemState("grad_mem")
        self._calc_tensor_memory(self._grad_mem_state, self._model.named_parameters(), True)

        # Optimizer parameter memory is known now; optimizer state memory is filled in by step().
        self._os_params_mem_state = SimpleMemState("os_params_mem")
        self._os_state_mem_state = SimpleMemState("os_state_mem")
        self._calc_tensor_group_memory(self._os_params_mem_state, list(enumerate(self._optimizer.param_groups)))

        # Write the initial record, creating (truncating) the log file.
        self.point(with_options="params,grads,os_params", create=True)

    def point(self, with_options: str = "", create: bool = False) -> None:
        """
        Record the memory state.

        Args:
            with_options (str, optional): Comma-separated layouts to include in the record
                ("params", "grads", "os_params", "os_state", "activation_base"), or "all".
                Defaults to "".
            create (bool, optional): Whether to create a new memory record file. Defaults to False.

        Returns:
            None
        """
        now = time.time()
        file = os.path.join(self._log_folder, "memory.log")

        if with_options == "all":
            options = ["params", "grads", "os_params", "os_state", "activation_base"]
        else:
            options = with_options.split(",")

        total_mem = (
            self._param_mem_state.total_mem
            + self._grad_mem_state.total_mem
            + self._os_params_mem_state.total_mem
            + self._os_state_mem_state.total_mem
            + self._activation_mem
        ) / mb

        # Generate summary information for memory state
        summary_info = (
            f"total_memory: {total_mem:.2f} MB"
            + "\n"
            + f"params_memory: {self._param_mem_state.total_mem / mb:.2f} MB, "
            + f"grads_memory: {self._grad_mem_state.total_mem / mb:.2f} MB, "
            + f"os_params_memory: {self._os_params_mem_state.total_mem / mb:.2f} MB, "
            + f"os_state_memory: {self._os_state_mem_state.total_mem / mb:.2f} MB, "
            + f"activation_memory: {self._activation_mem / mb:.2f} MB"
        )

        # Generate layout information based on selected options
        layout_info = ""
        if "params" in options:
            layout_info += "params_layout:\n" + self._param_mem_state.dump()
        if "grads" in options:
            layout_info += "grads_layout:\n" + self._grad_mem_state.dump()
        if "os_params" in options:
            layout_info += "os_params_layout:\n" + self._os_params_mem_state.dump()
        if "os_state" in options:
            layout_info += "os_state_layout:\n" + self._os_state_mem_state.dump()
        if "activation_base" in options:
            layout_info += "activation_base_layout:\n" + self._activation_base_mems.dump()

        # Write memory state information to log file
        file_mode = "w" if create else "a"
        with open(file, file_mode, encoding="utf-8") as writer:
            writer.write(
                "Memory State:\n" + f"time: {now - self._record_start_time}\n" + "---summary---\n" + summary_info + "\n"
            )
            if layout_info != "":
                writer.write("---Layout---\n" + layout_info)
            writer.write("\n")

    def step(self) -> None:
        """
        Update the memory state of the optimizer state.

        Does nothing once the profiler has stopped. On the last traced step a
        full record is written and the sunburst charts are rendered.

        Returns:
            None
        """
        if self._stoped:
            return

        self._remaining_steps -= 1
        # "<=" so that a non-positive total_steps still terminates on the first step.
        if self._remaining_steps <= 0:
            self._stoped = True

        # Rebuild the optimizer state memory from the current optimizer state.
        self._os_state_mem_state = SimpleMemState("os_state_mem")
        self._calc_tensor_group_memory(self._os_state_mem_state, list(self._optimizer.state_dict()["state"].items()))

        if not self._stoped:
            # Still tracing: only the optimizer state layout changed.
            self.point(with_options="os_state")
        else:
            # Final step: dump everything and render the charts.
            self.point(with_options="all")

            # A tree without sub-layers serialises without a "children" key
            # (e.g. no parameter requires grad), so fall back to an empty list.
            self._render_sunburst_chart(
                self._param_mem_state.to_json().get("children", []), "params_memory_sunburst"
            )
            self._render_sunburst_chart(self._grad_mem_state.to_json().get("children", []), "grads_memory_sunburst")
            self._render_sunburst_chart(
                [self._os_params_mem_state.to_json(), self._os_state_mem_state.to_json()],
                "os_memory_sunburst",
            )
            self._render_sunburst_chart(self._activation_base_mems.to_json(), "activation_memory_sunburst")

            # The summary uses the peak activation memory rather than the current value.
            summary_sunburst_data = [
                {"name": "params", "value": self._param_mem_state.total_mem // mb},
                {"name": "grads", "value": self._grad_mem_state.total_mem // mb},
                {"name": "os_params", "value": self._os_params_mem_state.total_mem // mb},
                {"name": "os_state", "value": self._os_state_mem_state.total_mem // mb},
                {"name": "activation", "value": self._activation_mem_max // mb},
            ]

            self._render_sunburst_chart(summary_sunburst_data, "summary_sunburst")

    def _render_sunburst_chart(self, data: Any, name: str) -> None:
        """
        Render a sunburst chart to ``<log_folder>/<name>.html``.

        Args:
            data (Any): The chart data, passed to the chart as ``data_pair``.
            name (str): The series name, also used as the output file name (without extension).

        Returns:
            None
        """
        pyecharts.charts.Sunburst(init_opts=pyecharts.options.InitOpts(width="1000px", height="1000px")).add(
            name,
            data_pair=data,
            highlight_policy="ancestor",
            radius=[0, "95%"],
            levels=[
                {},
                {
                    "r0": "10%",
                    "r": "35%",
                    "itemStyle": {"borderWidth": 3},
                    "label": {"align": "left"},
                },
                {"r0": "35%", "r": "55%", "label": {"align": "left"}},
                {"r0": "55%", "r": "70%", "label": {"align": "left"}},
                {"r0": "70%", "r": "80%", "label": {"align": "left"}},
                {"r0": "80%", "r": "90%", "label": {"align": "left"}},
                {
                    "r0": "90%",
                    "r": "92%",
                    "label": {"position": "outside", "padding": 3, "silent": False},
                    "itemStyle": {"borderWidth": 3},
                },
            ],
        ).set_global_opts(title_opts=pyecharts.options.TitleOpts(title="CUDA Memory")).set_series_opts(
            label_opts=pyecharts.options.LabelOpts(formatter="{b}")
        ).render(
            os.path.join(self._log_folder, f"{name}.html")
        )

    def _inner_activation_trace_hook(
        self,
        chunk_id: int,
        layer_name: str,
        model: Any,
        inputs: Any,
        output: torch.Tensor,
    ) -> None:
        """
        Hook function to trace the activation memory usage for a inner layer.

        Records the byte size of the layer output in the layout of the given
        chunk, until that chunk's layout has been marked as initialised.

        Args:
            chunk_id (int): The index of the model chunk the layer belongs to.
            layer_name (str): The name of the layer.
            model (Any): The model.
            inputs (Any): The inputs to the layer.
            output (torch.Tensor): The output tensor.

        Returns:
            None
        """
        del model, inputs

        # A stopped profiler must be inert, so bail out before validating the output.
        if self._stoped:
            return

        assert isinstance(output, torch.Tensor), f"Invalid output type: {type(output)}"

        if self._activation_base_mems.inited[chunk_id]:
            return

        # Totals are refreshed once per chunk by the chunk-level forward hook, hence flush=False.
        self._activation_base_mems.states[chunk_id].add(
            layer_name, output.element_size() * output.nelement(), flush=False
        )

    def _activation_trace_hook_forward(self, chunk_id: int, model: Any, inputs: Any, output: torch.Tensor) -> None:
        """
        Hook function to trace the activation memory usage for a forward pass.

        Args:
            chunk_id (int): The index of the model chunk that ran forward.
            model (Any): The model.
            inputs (Any): The inputs to the model.
            output (torch.Tensor): The output tensor.

        Returns:
            None
        """
        del model, inputs

        # A stopped profiler must be inert, so bail out before validating the output.
        if self._stoped:
            return

        assert isinstance(output, torch.Tensor), f"invalid output type: {type(output)}"

        # The first forward of a chunk finalises its activation layout.
        if not self._activation_base_mems.inited[chunk_id]:
            self._activation_base_mems.inited[chunk_id] = True
            self._activation_base_mems.states[chunk_id].update_total_memory()
            with_options = "activation_base"
        else:
            with_options = ""

        # Every forward pass adds one chunk's worth of activations; track the peak.
        self._activation_mem += self._activation_base_mems.states[chunk_id].total_mem
        if self._activation_mem > self._activation_mem_max:
            self._activation_mem_max = self._activation_mem

        self.point(with_options)

    def _activation_tarce_hook_backward(self, chunk_id: int, model: Any, inputs: Any, grad_outputs: Any) -> None:
        """
        Hook function to trace the activation memory usage for a backward pass.

        Args:
            chunk_id (int): The index of the model chunk that ran backward.
            model (Any): The model.
            inputs (Any): The inputs to the model.
            grad_outputs (Any): The gradients of the outputs.

        Returns:
            None
        """
        del model, inputs, grad_outputs

        if self._stoped:
            return

        # Every backward pass subtracts one chunk's worth of activations again.
        self._activation_mem -= self._activation_base_mems.states[chunk_id].total_mem

        self.point()

    def _register_activation_trace_hooks(self, chunk_id: int, model_chunk: torch.nn.Module) -> None:
        """
        Register activation trace hooks for the model and each submodule in the model.

        Args:
            chunk_id (int): The index of the model chunk.
            model_chunk (torch.nn.Module): The model chunk to hook.

        Returns:
            None
        """
        # Per-layer hooks go on leaf modules only (modules without registered sub-modules).
        for layer_name, sub_model in model_chunk.named_modules():
            if len(sub_model._modules) != 0:
                continue

            sub_model.register_forward_hook(partial(self._inner_activation_trace_hook, chunk_id, layer_name))

        # Chunk-level hooks maintain the running activation memory.
        model_chunk.register_forward_hook(partial(self._activation_trace_hook_forward, chunk_id))
        model_chunk.register_full_backward_hook(partial(self._activation_tarce_hook_backward, chunk_id))

    def _calc_tensor_memory(
        self,
        root_stat: "SimpleMemState",
        named_tensors: Iterable[Tuple[str, torch.Tensor]],
        require_grad: bool = False,
    ) -> None:
        """
        Calculate the memory usage of tensors and update the memory state.

        Args:
            root_stat (SimpleMemState): The root memory state.
            named_tensors (Iterable[Tuple[str, torch.Tensor]]): ``(name, tensor)`` pairs; the
                dot-separated name determines the position of the tensor in the tree.
            require_grad (bool, optional): Whether to skip tensors that do not require gradients.
                Defaults to False.

        Returns:
            None
        """
        for name, tensor in named_tensors:
            if require_grad and not tensor.requires_grad:
                continue

            layer_splits = name.split(sep=".")
            layer_stat = root_stat.find_layer_state(layer_splits, create=True)
            layer_stat.layer_mem = tensor.element_size() * tensor.nelement()

        root_stat.update_total_memory()

    def _calc_tensor_group_memory(self, root_stat: "SimpleMemState", tensor_groups: List[Tuple[int, Dict[str, Any]]]):
        """
        Calculate the memory usage of a group of tensors.

        Args:
            root_stat (SimpleMemState): The root memory state.
            tensor_groups (List[Tuple[int, Dict[str, Any]]]): ``(index, mapping)`` pairs. Entries of
                each mapping that are not a tensor or a list/tuple/dict of tensors are ignored.

        Returns:
            None
        """

        def _normalize_helper(named_tensors: Dict[str, Any]) -> List[Tuple[str, Any]]:
            """
            Normalize the named tensors.

            Collections are flattened into one entry per element, named
            ``<name>.<index>`` or ``<name>.<key>``.

            Args:
                named_tensors (Dict[str, Any]): The named tensors to normalize.

            Returns:
                List[Tuple[str, Any]]: The normalized named tensors.
            """
            res = {}

            for name, tensors in named_tensors.items():
                if isinstance(tensors, torch.Tensor):
                    res[name] = tensors
                elif isinstance(tensors, (list, tuple)):
                    for index, tensor in enumerate(tensors):
                        res[f"{name}.{index}"] = tensor
                elif isinstance(tensors, dict):
                    for subname, tensor in tensors.items():
                        res[f"{name}.{subname}"] = tensor
                else:
                    raise TypeError(f"unsupported normalize value type: {type(tensors)}")

            return list(res.items())

        def _value_check(tensor_or_tensors):
            """
            Check if the input is a tensor or a collection of tensors.

            Args:
                tensor_or_tensors (Any): The input to check.

            Returns:
                bool: True if the input is a tensor or a collection of tensors, False otherwise.
            """
            if torch.is_tensor(tensor_or_tensors):
                return True
            if isinstance(tensor_or_tensors, (list, tuple)):
                return all(torch.is_tensor(x) for x in tensor_or_tensors)
            if isinstance(tensor_or_tensors, dict):
                return all(torch.is_tensor(x) for x in tensor_or_tensors.values())
            return False

        # Calculate the memory usage of a group of tensors.
        for idx, tensors in tensor_groups:
            # Prefix every entry with the group index so groups become separate sub-trees.
            named_tensors = {f"{idx}.{k}": v for k, v in tensors.items() if _value_check(v)}
            named_tensors = _normalize_helper(named_tensors)

            self._calc_tensor_memory(root_stat, named_tensors)
|
|
|
|
|
|
|
|
if __name__ == "__main__":

    class SimpleModel(torch.nn.Module):
        """
        A small test model: two linear layers with an optional nested model in between.

        Args:
            skip_layer2 (bool, optional): Whether to skip layer2. Defaults to False.
        """

        def __init__(self, skip_layer2: bool = False):
            super().__init__()
            self.layer1 = torch.nn.Linear(5120, 5120, True)
            self.layer3 = torch.nn.Linear(5120, 5120, False)
            # layer2 is assigned last, after layer1 and layer3; the nested model never nests further.
            self.layer2 = None if skip_layer2 else SimpleModel(skip_layer2=True)

        def forward(self, inputs: torch.Tensor) -> torch.Tensor:
            """
            Forward pass of the model.

            Args:
                inputs (torch.Tensor): The input tensor.

            Returns:
                torch.Tensor: The output tensor.
            """
            hidden = self.layer1(inputs)
            if self.layer2 is not None:
                hidden = self.layer2(hidden)

            return self.layer3(hidden)

    def _simple_schedule(_num_chunks, _model_chunks, _input) -> torch.Tensor:
        """Run the input through a single model, or through every chunk in order."""
        if _num_chunks <= 1:
            return _model_chunks(_input)

        _hidden = _input
        for _chunk in _model_chunks:
            _hidden = _chunk(_hidden)

        return _hidden

    # Number of model chunks used by this smoke test.
    _num_chunks = 1

    # Build the model (one module, or a ModuleList of chunks) and its optimizer.
    if _num_chunks <= 1:
        _model: torch.nn.Module = SimpleModel().cuda()
    else:
        _chunks = [SimpleModel(skip_layer2=idx % 2 == 0) for idx in range(_num_chunks)]
        _model = torch.nn.ModuleList(_chunks).cuda()
    _optimizer = torch.optim.Adam(_model.parameters())

    # Attach the profiler; a single step is traced.
    profiler = SimpleMemoryProfiler(_model, _optimizer, "./test_simple_memory_profiler", total_steps=1)

    _optimizer.zero_grad()

    # Two forward passes followed by two backward passes.
    x1 = torch.randn((128, 5120)).cuda()
    x2 = torch.randn((128, 5120)).cuda()

    out1 = _simple_schedule(_num_chunks, _model, x1)
    out2 = _simple_schedule(_num_chunks, _model, x2)

    out1.mean().backward()
    out2.mean().backward()

    _optimizer.step()

    # Record the optimizer state; with total_steps=1 this also writes the final report.
    profiler.step()
|
|
|