henryholloway commited on
Commit
86b55e7
·
1 Parent(s): bb43133

Created calculator

Browse files
Files changed (1) hide show
  1. app.py +95 -0
app.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
# Bits occupied by one model parameter under each supported storage /
# quantization format (keys follow the llama.cpp GGUF naming scheme,
# roughly Q<bits>_<variant>). Grouped by bit width; insertion order is
# preserved because the UI builds its dropdown from list(keys()).
quantization_bit_sizes = {
    fmt: bits
    for bits, fmts in [
        (32, ('float32',)),
        (16, ('float16',)),
        (2, ('Q2_K',)),
        (3, ('Q3_K_L', 'Q3_K_M', 'Q3_K_S')),
        (4, ('Q4_0', 'Q4_1', 'Q4_K_M', 'Q4_K_S')),
        (5, ('Q5_0', 'Q5_1', 'Q5_K_M', 'Q5_K_S')),
        (6, ('Q6_K',)),
        (8, ('Q8_0',)),
    ]
    for fmt in fmts
}
22
+
23
# Per-parameter storage cost for each precision mode. Reading the values:
#   'full'  -> 4  (fp32 weights, 4 bytes each)
#   'mixed' -> 6  (mixed-precision training: fp32 master copy + fp16 working
#                  copy, 4 + 2 bytes)
#   'half'  -> 2  (fp16 weights, 2 bytes each)
precision_options = {'full': 4, 'mixed': 6, 'half': 2}
29
+
30
def calculate_memory_usage(parameter_count, context_length, data_type, is_training, batch_size, vocab_size, precision):
    """Estimate the total memory footprint of an LLM, in gigabytes.

    The estimate is the sum of three parts:
      * model weights   — parameter_count scaled by the byte width of the
                          chosen quantization format (``data_type``),
      * activations     — delegated to :func:`calculate_activations`,
      * output logits   — fp32 logits over the vocabulary for every token in
                          the batch, doubled in training mode (forward values
                          plus gradients).

    ``precision`` selects an entry of ``precision_options`` and only affects
    the activation term.
    """
    # Quantization table stores bits per weight; convert to bytes.
    bytes_per_weight = quantization_bit_sizes[data_type] / 8
    weight_bytes = parameter_count * bytes_per_weight

    activation_bytes = calculate_activations(
        parameter_count, context_length, batch_size, vocab_size, precision, is_training
    )

    # fp32 logits (4 bytes) per vocab entry per token; x2 when training to
    # account for the gradient copy.
    grad_factor = 2 if is_training else 1
    logit_bytes = 4 * batch_size * context_length * vocab_size * grad_factor

    # Sum everything and report in GiB.
    return (weight_bytes + activation_bytes + logit_bytes) / (1024 ** 3)
50
+
51
def calculate_activations(parameter_count, context_length, batch_size, vocab_size, precision, is_training):
    """Estimate activation memory for one forward pass, in bytes.

    Uses a deliberately simplified transformer model:
      * hidden size ~ sqrt(parameter_count)  (crude heuristic),
      * 16 attention heads,
      * 4x MLP expansion,
      * 12 layers retained when training (backprop must keep every layer's
        activations); inference counts a single layer, assuming per-layer
        activations can be discarded as the forward pass proceeds.

    NOTE: ``vocab_size`` is accepted for interface compatibility but is not
    used here — output-logit memory is accounted for by the caller.
    """
    hidden_size = parameter_count ** 0.5  # assume hidden dim grows with sqrt of params
    num_attention_heads = 16              # a typical number of attention heads
    intermediate_size = hidden_size * 4   # standard 4x expansion in transformer MLPs

    # BUG FIX: precision_options values are already BYTES per value
    # (full=fp32=4, half=fp16=2, mixed=4+2=6); the original divided by 8 as
    # if they were bits, understating activation memory by 8x.
    bytes_per_param = precision_options[precision]

    # Attention block: block input, Q/K/V projections (per-head dims sum back
    # to hidden_size), softmax attention scores, and output-projection input.
    attention_input = bytes_per_param * batch_size * context_length * hidden_size
    qkv = 3 * bytes_per_param * batch_size * context_length * hidden_size
    softmax_output = bytes_per_param * batch_size * num_attention_heads * (context_length ** 2)
    out_proj_input = bytes_per_param * batch_size * context_length * hidden_size
    attention_block = attention_input + qkv + softmax_output + out_proj_input

    # MLP block: input, post-activation tensor, and down-projection input.
    mlp_input = bytes_per_param * batch_size * context_length * hidden_size
    activation_input = bytes_per_param * batch_size * context_length * intermediate_size
    down_proj_input = bytes_per_param * batch_size * context_length * intermediate_size
    mlp_block = mlp_input + activation_input + down_proj_input

    # Two layer norms per transformer layer.
    layer_norms = bytes_per_param * batch_size * context_length * hidden_size * 2

    per_layer = attention_block + mlp_block + layer_norms

    # Training retains all layers' activations for backprop; assume 12 layers.
    layers_held = 12 if is_training else 1
    return per_layer * layers_held
79
+
80
# ---- Streamlit UI ----
st.title("Memory Usage Calculator for Large Language Models")

# User inputs.
# FIX: the original used value=1, step=1 for the parameter count, which makes
# Streamlit render an integer-only widget — fractional-billion models (0.5B,
# 1.5B, ...) could not be entered. A float default/step generalizes it.
# min_value guards keep every size physically meaningful (no negatives).
parameter_count = st.number_input("Parameter Count (in billions)", min_value=0.0, value=1.0, step=0.1) * 1e9
context_length = st.number_input("Context Length (number of tokens)", min_value=1, value=512, step=1)
data_type = st.selectbox("Data Type", options=list(quantization_bit_sizes.keys()))
is_training = st.checkbox("Training Mode", value=False)
batch_size = st.number_input("Batch Size", min_value=1, value=1, step=1)
vocab_size = st.number_input("Vocabulary Size", min_value=1, value=30000, step=1000)
precision = st.selectbox("Precision", options=list(precision_options.keys()))

# Run the estimate on demand and report it in GB.
if st.button("Calculate Memory Usage"):
    memory_usage = calculate_memory_usage(parameter_count, context_length, data_type, is_training, batch_size, vocab_size, precision)
    st.write(f"Estimated Memory Usage for {'Training' if is_training else 'Inference'}: {memory_usage:.2f} GB")