File size: 6,425 Bytes
84f0b80
97e312a
 
84f0b80
97e312a
 
 
 
 
 
84f0b80
b79954f
 
84f0b80
 
b79954f
84f0b80
 
97e312a
 
 
 
 
 
 
84f0b80
b79954f
84f0b80
 
97e312a
 
 
 
 
 
 
 
 
 
 
 
 
 
b79954f
97e312a
 
 
 
84f0b80
b79954f
84f0b80
 
97e312a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84f0b80
97e312a
 
 
 
 
 
 
 
 
 
 
 
b79954f
97e312a
 
 
 
 
 
84f0b80
97e312a
 
 
 
 
 
 
 
 
 
84f0b80
97e312a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84f0b80
97e312a
84f0b80
97e312a
 
84f0b80
97e312a
b79954f
97e312a
 
 
 
 
b79954f
84f0b80
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import gradio as gr
import pandas as pd
from functools import partial
from defaults import DEFAULTS
from state import Model, Parallelism, Training
from calculator import MemoryCalculation
from dtypes import DType

# Factory for natural-number inputs: a gr.Number constrained to positive
# integers (minimum=1, integer step, 0 decimal places) and user-editable.
# Call it like gr.Number, e.g. NaturalNumber(label="...", value=1).
NaturalNumber = partial(gr.Number, minimum=1, step=1, precision=0, interactive=True)


def greet(name, intensity) -> str:
    """Return a greeting for *name* followed by int(intensity) exclamation marks."""
    bang_count = int(intensity)
    return f"Hello, {name}{'!' * bang_count}"


def create_parallelism_block():
    """Render the parallelism input column.

    Returns the four degree-of-parallelism Number components as
    (tensor, pipeline, context, expert), each defaulting to 1.
    """
    with gr.Column():
        gr.Markdown("# Parallelism")
        with gr.Group():
            # One natural-number field per parallelism dimension.
            fields = [
                NaturalNumber(label=f"{kind} Parallelism", value=1)
                for kind in ("Tensor", "Pipeline", "Context", "Expert")
            ]
        return tuple(fields)


def create_model_block():
    """Render the model-architecture input column.

    Returns (layers, vocab, hidden, intermediate, active_experts,
    total_experts, is_moe, presets). The expert fields are hidden until the
    MoE checkbox is ticked. `presets` is currently None: the preset dropdown
    is not ready yet, but the slot is kept so callers that unpack eight
    values keep working.
    """
    with gr.Column():
        gr.Markdown("# Model Architecture")
        layers = NaturalNumber(label="Number of Layers", value=32)
        vocab = NaturalNumber(label="Vocab Size", value=32000)
        hidden = NaturalNumber(label="Hidden Dim", value=4096)
        intermediate = NaturalNumber(label="Intermediate Dim", value=11008)
        is_moe = gr.Checkbox(label="Mixture of Experts (MoE)", value=False)
        active_experts = NaturalNumber(label="Active Experts", value=2, visible=False)
        total_experts = NaturalNumber(label="Total Experts", value=8, visible=False)

        # Show the expert fields only when the MoE checkbox is enabled.
        is_moe.change(
            fn=lambda x: [gr.update(visible=x), gr.update(visible=x)],
            inputs=is_moe,
            outputs=[active_experts, total_experts]
        )

        # Presets dropdown is not ready yet; returning the undefined name
        # raised NameError before. Keep a None placeholder so the caller's
        # 8-way unpack still works.
        # presets = gr.Dropdown(list(DEFAULTS.keys()), label="Presets", interactive=True)
        presets = None
        return layers, vocab, hidden, intermediate, active_experts, total_experts, is_moe, presets


def create_training_block():
    """Render the training-configuration input column.

    Returns (seq_len, batch_size, gradient_checkpointing, grad_accumulation,
    precision, mixed_precision, param_dtype, reduce_dtype). The two dtype
    dropdowns stay hidden until the mixed-precision checkbox is ticked.
    """
    with gr.Column():
        gr.Markdown("# Training Config")
        seq_len = NaturalNumber(label="Sequence Length", value=8192)
        batch_size = NaturalNumber(
            label="Batch Size",
            info="If you are using gradient accumulation, enter microbatch size",
            value=8,
        )
        with gr.Row():
            gradient_checkpointing = gr.Checkbox(label="Gradient Checkpointing", value=False)
            grad_accumulation = gr.Checkbox(label="Gradient Accumulation", value=False)
        precision = gr.Dropdown(
            DType.values(), label="Precision", value=DType.FP32.value, interactive=True
        )
        mixed_precision = gr.Checkbox(label="Mixed Precision", value=False)
        # The two dtype pickers share their configuration; both start hidden.
        hidden_dropdown = dict(interactive=True, visible=False)
        param_dtype = gr.Dropdown(
            DType.values(), label="Parameter Dtype", value=DType.FP32.value, **hidden_dropdown
        )
        reduce_dtype = gr.Dropdown(
            DType.values(), label="Reduce Dtype", value=DType.FP32.value, **hidden_dropdown
        )

        # Reveal the explicit dtype pickers only when mixed precision is on.
        mixed_precision.change(
            fn=lambda on: [gr.update(visible=on), gr.update(visible=on)],
            inputs=mixed_precision,
            outputs=[param_dtype, reduce_dtype],
        )

        return (
            seq_len,
            batch_size,
            gradient_checkpointing,
            grad_accumulation,
            precision,
            mixed_precision,
            param_dtype,
            reduce_dtype,
        )


def calculate(tp, pp, cp, ep, layers, vocab, hidden, intermediate, active_experts, total_experts, is_moe, seq_len, batch_size, gradient_checkpointing, grad_accumulation, precision, mixed_precision, param_dtype, reduce_dtype):
    """Run the memory calculator over the raw UI values.

    Builds Model/Parallelism/Training config objects (coercing the numeric
    widget values to int and dtype strings to DType), asks MemoryCalculation
    for each memory component in bytes, and returns a gr.BarPlot showing the
    breakdown in GB.
    """
    model_cfg = Model(
        vocab_size=int(vocab),
        num_layers=int(layers),
        hidden_dim=int(hidden),
        intermediate_size=int(intermediate),
        weight_tied_embeddings=True,  # Default assumption
        active_experts=int(active_experts),
        total_experts=int(total_experts),
        is_moe=is_moe,
    )

    parallel_cfg = Parallelism(
        tensor_parallelism=int(tp),
        pipeline_parallelism=int(pp),
        context_parallelism=int(cp),
        expert_parallelism=int(ep),
    )

    train_cfg = Training(
        sequence_length=int(seq_len),
        batch_size=int(batch_size),
        gradient_checkpointing=gradient_checkpointing,
        grad_accumulation=grad_accumulation,
        precision=DType(precision),
        mixed_precision=mixed_precision,
        param_dtype=DType(param_dtype),
        reduce_dtype=DType(reduce_dtype),
    )

    calc = MemoryCalculation(model_cfg, parallel_cfg, train_cfg)

    # Component label -> size in bytes, as reported by the calculator.
    component_bytes = {
        'Parameter Memory': calc.calculate_parameter_memory(),
        'Activation Memory': calc.calculate_activation_memory(),
        'Gradient Memory': calc.calculate_gradient_memory(),
        'Optimizer Memory': calc.calculate_optimizer_memory(),
    }

    # Bar-plot source table, with bytes converted to GB.
    breakdown = pd.DataFrame({
        'Component': list(component_bytes),
        'Memory (GB)': [size / 1e9 for size in component_bytes.values()],
    })

    return gr.BarPlot(
        value=breakdown,
        x="Component",
        y="Memory (GB)",
        title="LLM Memory Usage Breakdown",
        container=False,
        y_lim=[0, None],
    )


# --- UI assembly -------------------------------------------------------------
with gr.Blocks(theme='gstaff/xkcd') as demo:
    with gr.Sidebar():
        # Fix: gr.Textbox would display the literal "## ..." string in an
        # editable box; gr.Markdown renders it as the intended heading.
        gr.Markdown("## LLM Memory Visualizer")
    with gr.Column():
        with gr.Row(equal_height=True):
            tp, pp, cp, ep = create_parallelism_block()
            layers, vocab, hidden, intermediate, active_experts, total_experts, is_moe, presets = create_model_block()
            seq_len, batch_size, gradient_checkpointing, grad_accumulation, precision, mixed_precision, param_dtype, reduce_dtype = create_training_block()
        calculate_button = gr.Button("Calculate")
        output = gr.BarPlot(label="Memory Usage Breakdown")

        # Wire every input widget into the calculator; the click replaces the
        # placeholder BarPlot with the computed breakdown.
        calculate_button.click(
            fn=calculate,
            inputs=[tp, pp, cp, ep, layers, vocab, hidden, intermediate, active_experts, total_experts, is_moe, seq_len, batch_size, gradient_checkpointing, grad_accumulation, precision, mixed_precision, param_dtype, reduce_dtype],
            outputs=output
        )


# Guard the launch so importing this module (e.g. for tests) does not start
# the server; running the file as a script behaves exactly as before.
if __name__ == "__main__":
    demo.launch()