nvan13 committed (verified)
Commit 1f2bd2b · Parent(s): a0d95b0

Upload folder using huggingface_hub

Files changed (2):
  1. .gitignore +1 -2
  2. src/bb.ipynb +806 -0
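
For context, a minimal sketch of how a commit like this is typically produced with huggingface_hub (the repo id and folder path below are placeholders, not taken from this page):

    from huggingface_hub import upload_folder

    # Hypothetical values; the actual repo and local folder are not shown on this page.
    upload_folder(
        folder_path="path/to/local/folder",
        repo_id="nvan13/example-repo",
        commit_message="Upload folder using huggingface_hub",
    )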
.gitignore CHANGED
@@ -210,7 +210,6 @@ __marimo__/
 
 trainer_output/
 outputs/
-src/note.ipynb
 wandb/
 runs/
-src/*.ipynb
+# src/*.ipynb
src/bb.ipynb ADDED
@@ -0,0 +1,806 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 3,
6
+ "id": "7e7899f4",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stdout",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "tensor([[ 0.7923, -0.1882, 0.8791, 0.8785],\n",
14
+ " [-0.3649, 0.3171, 0.2766, 0.4714],\n",
15
+ " [ 0.5661, 0.4688, -1.5763, -0.7690],\n",
16
+ " [ 1.2863, -0.4760, -2.0309, -2.4342],\n",
17
+ " [ 0.1591, 1.2439, -1.0475, 0.6328],\n",
18
+ " [ 0.3351, 0.5378, -1.2086, 0.9963]], grad_fn=<PermuteBackward0>) tensor([[-0.4674, 0.0799, 0.8670, 0.0765],\n",
19
+ " [ 0.2153, -0.1855, 0.0422, -0.1279],\n",
20
+ " [-0.3339, -0.3323, -0.2219, -0.1967],\n",
21
+ " [-0.7588, 0.2398, -0.3984, -0.1867],\n",
22
+ " [-0.0939, -0.8113, 0.1191, -0.3375],\n",
23
+ " [-0.1977, -0.3647, -0.1560, 0.8890]], grad_fn=<LinalgQrBackward0>) tensor([[-1.6951, 0.1378, 2.0534, 1.5385],\n",
24
+ " [ 0.0000, -1.5491, 1.3466, -1.2221],\n",
25
+ " [ 0.0000, 0.0000, 1.9967, 1.8420],\n",
26
+ " [ 0.0000, 0.0000, 0.0000, 1.2846]], grad_fn=<LinalgQrBackward0>)\n",
27
+ "Output Shape: torch.Size([4, 6])\n",
28
+ "\n",
29
+ "Gram Matrix (M @ M.T):\n",
30
+ "tensor([[ 1.0000e+00, -2.5258e-08, -1.9981e-07, 6.6143e-09],\n",
31
+ " [-2.5258e-08, 1.0000e+00, 5.1411e-08, 3.9070e-08],\n",
32
+ " [-1.9981e-07, 5.1411e-08, 1.0000e+00, 1.6955e-08],\n",
33
+ " [ 6.6143e-09, 3.9070e-08, 1.6955e-08, 1.0000e+00]],\n",
34
+ " grad_fn=<MmBackward0>)\n",
35
+ "\n",
36
+ "Orthogonality Error: 0.000000\n"
37
+ ]
38
+ }
39
+ ],
40
+ "source": [
41
+ "import torch\n",
42
+ "\n",
43
+ "def create_orthogonal_rows_matrix(rows: int, cols: int):\n",
44
+ " \"\"\"\n",
45
+ " Creates a rectangular matrix (rows x cols) with orthonormal rows using QR decomposition.\n",
46
+ " Condition: cols >= rows.\n",
47
+ " \"\"\"\n",
48
+ " # Create a random input matrix (requires_grad=True to test differentiability later)\n",
49
+ " # Shape: (rows, cols)\n",
50
+ " X = torch.randn(rows, cols, requires_grad=True)\n",
51
+ " \n",
52
+ " # 1. Transpose to get a \"tall\" matrix (cols x rows)\n",
53
+ " # We do this because standard QR produces orthogonal columns.\n",
54
+ " X_T = X.T\n",
55
+ " \n",
56
+ " # 2. Apply QR Decomposition\n",
57
+ " # Q will have shape (cols, rows) with orthogonal columns\n",
58
+ " # R will be upper triangular\n",
59
+ " Q_T, R = torch.linalg.qr(X_T, mode='reduced')\n",
60
+ " print(X_T, Q_T, R)\n",
61
+ " \n",
62
+ " # 3. Transpose Q back to get the desired shape (rows, cols)\n",
63
+ " # Now, M has orthogonal rows.\n",
64
+ " M = Q_T.T\n",
65
+ " \n",
66
+ " return X, M\n",
67
+ "\n",
68
+ "# --- Usage Example ---\n",
69
+ "\n",
70
+ "# Configuration: 3 rows, 5 columns (Rectangular \"fat\" matrix)\n",
71
+ "m, n = 4, 6\n",
72
+ "\n",
73
+ "# Create the matrix\n",
74
+ "input_tensor, ortho_matrix = create_orthogonal_rows_matrix(m, n)\n",
75
+ "\n",
76
+ "print(f\"Output Shape: {ortho_matrix.shape}\") # Should be torch.Size([3, 5])\n",
77
+ "\n",
78
+ "# --- Verification ---\n",
79
+ "\n",
80
+ "# Check orthogonality: M @ M.T should be Identity matrix (3x3)\n",
81
+ "gram_matrix = torch.matmul(ortho_matrix, ortho_matrix.T)\n",
82
+ "identity = torch.eye(m)\n",
83
+ "\n",
84
+ "print(\"\\nGram Matrix (M @ M.T):\")\n",
85
+ "print(gram_matrix)\n",
86
+ "\n",
87
+ "# Check error\n",
88
+ "error = torch.dist(gram_matrix, identity)\n",
89
+ "print(f\"\\nOrthogonality Error: {error.item():.6f}\")\n",
90
+ "\n",
91
+ "# Note: The result is very close to 0, confirming orthogonality."
92
+ ]
93
+ },
94
+ {
95
+ "cell_type": "code",
96
+ "execution_count": 1,
97
+ "id": "0b4c7963",
98
+ "metadata": {},
99
+ "outputs": [
100
+ {
101
+ "name": "stdout",
102
+ "output_type": "stream",
103
+ "text": [
104
+ "NumPy: 2.2.6, SciPy:\n"
105
+ ]
106
+ }
107
+ ],
108
+ "source": [
109
+ "import numpy; import scipy\n",
110
+ "print(f'NumPy: {numpy.__version__}, SciPy:')"
111
+ ]
112
+ },
113
+ {
114
+ "cell_type": "code",
115
+ "execution_count": 2,
116
+ "id": "6241e0ab",
117
+ "metadata": {},
118
+ "outputs": [
119
+ {
120
+ "name": "stderr",
121
+ "output_type": "stream",
122
+ "text": [
123
+ "Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00, 1.22s/it]\n"
124
+ ]
125
+ },
126
+ {
127
+ "name": "stdout",
128
+ "output_type": "stream",
129
+ "text": [
130
+ "Số tham số của Llama2-7B: 6,738,415,616\n",
131
+ "n = model.embed_tokens.weight, shape torch.Size([32000, 4096])\n",
132
+ "n = model.layers.0.self_attn.q_proj.weight, shape torch.Size([4096, 4096])\n",
133
+ "n = model.layers.0.self_attn.k_proj.weight, shape torch.Size([4096, 4096])\n",
134
+ "n = model.layers.0.self_attn.v_proj.weight, shape torch.Size([4096, 4096])\n",
135
+ "n = model.layers.0.self_attn.o_proj.weight, shape torch.Size([4096, 4096])\n",
136
+ "n = model.layers.0.mlp.gate_proj.weight, shape torch.Size([11008, 4096])\n",
137
+ "n = model.layers.0.mlp.up_proj.weight, shape torch.Size([11008, 4096])\n",
138
+ "n = model.layers.0.mlp.down_proj.weight, shape torch.Size([4096, 11008])\n",
139
+ "n = model.layers.0.input_layernorm.weight, shape torch.Size([4096])\n",
140
+ "n = model.layers.0.post_attention_layernorm.weight, shape torch.Size([4096])\n",
141
+ "n = model.layers.1.self_attn.q_proj.weight, shape torch.Size([4096, 4096])\n",
142
+ "n = model.layers.1.self_attn.k_proj.weight, shape torch.Size([4096, 4096])\n",
143
+ "n = model.layers.1.self_attn.v_proj.weight, shape torch.Size([4096, 4096])\n",
144
+ "n = model.layers.1.self_attn.o_proj.weight, shape torch.Size([4096, 4096])\n",
145
+ "n = model.layers.1.mlp.gate_proj.weight, shape torch.Size([11008, 4096])\n",
146
+ "n = model.layers.1.mlp.up_proj.weight, shape torch.Size([11008, 4096])\n",
147
+ "n = model.layers.1.mlp.down_proj.weight, shape torch.Size([4096, 11008])\n",
148
+ "n = model.layers.1.input_layernorm.weight, shape torch.Size([4096])\n",
149
+ "n = model.layers.1.post_attention_layernorm.weight, shape torch.Size([4096])\n",
150
+ "n = model.layers.2.self_attn.q_proj.weight, shape torch.Size([4096, 4096])\n",
151
+ "n = model.layers.2.self_attn.k_proj.weight, shape torch.Size([4096, 4096])\n",
152
+ "n = model.layers.2.self_attn.v_proj.weight, shape torch.Size([4096, 4096])\n",
153
+ "n = model.layers.2.self_attn.o_proj.weight, shape torch.Size([4096, 4096])\n",
154
+ "n = model.layers.2.mlp.gate_proj.weight, shape torch.Size([11008, 4096])\n",
155
+ "n = model.layers.2.mlp.up_proj.weight, shape torch.Size([11008, 4096])\n",
156
+ "n = model.layers.2.mlp.down_proj.weight, shape torch.Size([4096, 11008])\n",
157
+ "n = model.layers.2.input_layernorm.weight, shape torch.Size([4096])\n",
158
+ "n = model.layers.2.post_attention_layernorm.weight, shape torch.Size([4096])\n",
159
+ "n = model.layers.3.self_attn.q_proj.weight, shape torch.Size([4096, 4096])\n",
160
+ "n = model.layers.3.self_attn.k_proj.weight, shape torch.Size([4096, 4096])\n",
161
+ "n = model.layers.3.self_attn.v_proj.weight, shape torch.Size([4096, 4096])\n",
162
+ "n = model.layers.3.self_attn.o_proj.weight, shape torch.Size([4096, 4096])\n",
163
+ "n = model.layers.3.mlp.gate_proj.weight, shape torch.Size([11008, 4096])\n",
164
+ "n = model.layers.3.mlp.up_proj.weight, shape torch.Size([11008, 4096])\n",
165
+ "n = model.layers.3.mlp.down_proj.weight, shape torch.Size([4096, 11008])\n",
166
+ "n = model.layers.3.input_layernorm.weight, shape torch.Size([4096])\n",
167
+ "n = model.layers.3.post_attention_layernorm.weight, shape torch.Size([4096])\n",
168
+ "n = model.layers.4.self_attn.q_proj.weight, shape torch.Size([4096, 4096])\n",
169
+ "n = model.layers.4.self_attn.k_proj.weight, shape torch.Size([4096, 4096])\n",
170
+ "n = model.layers.4.self_attn.v_proj.weight, shape torch.Size([4096, 4096])\n",
171
+ "n = model.layers.4.self_attn.o_proj.weight, shape torch.Size([4096, 4096])\n",
172
+ "n = model.layers.4.mlp.gate_proj.weight, shape torch.Size([11008, 4096])\n",
173
+ "n = model.layers.4.mlp.up_proj.weight, shape torch.Size([11008, 4096])\n",
174
+ "n = model.layers.4.mlp.down_proj.weight, shape torch.Size([4096, 11008])\n",
175
+ "n = model.layers.4.input_layernorm.weight, shape torch.Size([4096])\n",
176
+ "n = model.layers.4.post_attention_layernorm.weight, shape torch.Size([4096])\n",
177
+ "n = model.layers.5.self_attn.q_proj.weight, shape torch.Size([4096, 4096])\n",
178
+ "n = model.layers.5.self_attn.k_proj.weight, shape torch.Size([4096, 4096])\n",
179
+ "n = model.layers.5.self_attn.v_proj.weight, shape torch.Size([4096, 4096])\n",
180
+ "n = model.layers.5.self_attn.o_proj.weight, shape torch.Size([4096, 4096])\n",
181
+ "n = model.layers.5.mlp.gate_proj.weight, shape torch.Size([11008, 4096])\n",
182
+ "n = model.layers.5.mlp.up_proj.weight, shape torch.Size([11008, 4096])\n",
183
+ "n = model.layers.5.mlp.down_proj.weight, shape torch.Size([4096, 11008])\n",
184
+ "n = model.layers.5.input_layernorm.weight, shape torch.Size([4096])\n",
185
+ "n = model.layers.5.post_attention_layernorm.weight, shape torch.Size([4096])\n",
186
+ "n = model.layers.6.self_attn.q_proj.weight, shape torch.Size([4096, 4096])\n",
187
+ "n = model.layers.6.self_attn.k_proj.weight, shape torch.Size([4096, 4096])\n",
188
+ "n = model.layers.6.self_attn.v_proj.weight, shape torch.Size([4096, 4096])\n",
189
+ "n = model.layers.6.self_attn.o_proj.weight, shape torch.Size([4096, 4096])\n",
190
+ "n = model.layers.6.mlp.gate_proj.weight, shape torch.Size([11008, 4096])\n",
191
+ "n = model.layers.6.mlp.up_proj.weight, shape torch.Size([11008, 4096])\n",
192
+ "n = model.layers.6.mlp.down_proj.weight, shape torch.Size([4096, 11008])\n",
193
+ "n = model.layers.6.input_layernorm.weight, shape torch.Size([4096])\n",
194
+ "n = model.layers.6.post_attention_layernorm.weight, shape torch.Size([4096])\n",
195
+ "n = model.layers.7.self_attn.q_proj.weight, shape torch.Size([4096, 4096])\n",
196
+ "n = model.layers.7.self_attn.k_proj.weight, shape torch.Size([4096, 4096])\n",
197
+ "n = model.layers.7.self_attn.v_proj.weight, shape torch.Size([4096, 4096])\n",
198
+ "n = model.layers.7.self_attn.o_proj.weight, shape torch.Size([4096, 4096])\n",
199
+ "n = model.layers.7.mlp.gate_proj.weight, shape torch.Size([11008, 4096])\n",
200
+ "n = model.layers.7.mlp.up_proj.weight, shape torch.Size([11008, 4096])\n",
201
+ "n = model.layers.7.mlp.down_proj.weight, shape torch.Size([4096, 11008])\n",
202
+ "n = model.layers.7.input_layernorm.weight, shape torch.Size([4096])\n",
203
+ "n = model.layers.7.post_attention_layernorm.weight, shape torch.Size([4096])\n",
204
+ "n = model.layers.8.self_attn.q_proj.weight, shape torch.Size([4096, 4096])\n",
205
+ "n = model.layers.8.self_attn.k_proj.weight, shape torch.Size([4096, 4096])\n",
206
+ "n = model.layers.8.self_attn.v_proj.weight, shape torch.Size([4096, 4096])\n",
207
+ "n = model.layers.8.self_attn.o_proj.weight, shape torch.Size([4096, 4096])\n",
208
+ "n = model.layers.8.mlp.gate_proj.weight, shape torch.Size([11008, 4096])\n",
209
+ "n = model.layers.8.mlp.up_proj.weight, shape torch.Size([11008, 4096])\n",
210
+ "n = model.layers.8.mlp.down_proj.weight, shape torch.Size([4096, 11008])\n",
211
+ "n = model.layers.8.input_layernorm.weight, shape torch.Size([4096])\n",
212
+ "n = model.layers.8.post_attention_layernorm.weight, shape torch.Size([4096])\n",
213
+ "n = model.layers.9.self_attn.q_proj.weight, shape torch.Size([4096, 4096])\n",
214
+ "n = model.layers.9.self_attn.k_proj.weight, shape torch.Size([4096, 4096])\n",
215
+ "n = model.layers.9.self_attn.v_proj.weight, shape torch.Size([4096, 4096])\n",
216
+ "n = model.layers.9.self_attn.o_proj.weight, shape torch.Size([4096, 4096])\n",
217
+ "n = model.layers.9.mlp.gate_proj.weight, shape torch.Size([11008, 4096])\n",
218
+ "n = model.layers.9.mlp.up_proj.weight, shape torch.Size([11008, 4096])\n",
219
+ "n = model.layers.9.mlp.down_proj.weight, shape torch.Size([4096, 11008])\n",
220
+ "n = model.layers.9.input_layernorm.weight, shape torch.Size([4096])\n",
221
+ "n = model.layers.9.post_attention_layernorm.weight, shape torch.Size([4096])\n",
222
+ "n = model.layers.10.self_attn.q_proj.weight, shape torch.Size([4096, 4096])\n",
223
+ "n = model.layers.10.self_attn.k_proj.weight, shape torch.Size([4096, 4096])\n",
224
+ "n = model.layers.10.self_attn.v_proj.weight, shape torch.Size([4096, 4096])\n",
225
+ "n = model.layers.10.self_attn.o_proj.weight, shape torch.Size([4096, 4096])\n",
226
+ "n = model.layers.10.mlp.gate_proj.weight, shape torch.Size([11008, 4096])\n",
227
+ "n = model.layers.10.mlp.up_proj.weight, shape torch.Size([11008, 4096])\n",
228
+ "n = model.layers.10.mlp.down_proj.weight, shape torch.Size([4096, 11008])\n",
229
+ "n = model.layers.10.input_layernorm.weight, shape torch.Size([4096])\n",
230
+ "n = model.layers.10.post_attention_layernorm.weight, shape torch.Size([4096])\n",
231
+ "n = model.layers.11.self_attn.q_proj.weight, shape torch.Size([4096, 4096])\n",
232
+ "n = model.layers.11.self_attn.k_proj.weight, shape torch.Size([4096, 4096])\n",
233
+ "n = model.layers.11.self_attn.v_proj.weight, shape torch.Size([4096, 4096])\n",
234
+ "n = model.layers.11.self_attn.o_proj.weight, shape torch.Size([4096, 4096])\n",
235
+ "n = model.layers.11.mlp.gate_proj.weight, shape torch.Size([11008, 4096])\n",
236
+ "n = model.layers.11.mlp.up_proj.weight, shape torch.Size([11008, 4096])\n",
237
+ "n = model.layers.11.mlp.down_proj.weight, shape torch.Size([4096, 11008])\n",
238
+ "n = model.layers.11.input_layernorm.weight, shape torch.Size([4096])\n",
239
+ "n = model.layers.11.post_attention_layernorm.weight, shape torch.Size([4096])\n",
240
+ "n = model.layers.12.self_attn.q_proj.weight, shape torch.Size([4096, 4096])\n",
241
+ "n = model.layers.12.self_attn.k_proj.weight, shape torch.Size([4096, 4096])\n",
242
+ "n = model.layers.12.self_attn.v_proj.weight, shape torch.Size([4096, 4096])\n",
243
+ "n = model.layers.12.self_attn.o_proj.weight, shape torch.Size([4096, 4096])\n",
244
+ "n = model.layers.12.mlp.gate_proj.weight, shape torch.Size([11008, 4096])\n",
245
+ "n = model.layers.12.mlp.up_proj.weight, shape torch.Size([11008, 4096])\n",
246
+ "n = model.layers.12.mlp.down_proj.weight, shape torch.Size([4096, 11008])\n",
247
+ "n = model.layers.12.input_layernorm.weight, shape torch.Size([4096])\n",
248
+ "n = model.layers.12.post_attention_layernorm.weight, shape torch.Size([4096])\n",
249
+ "n = model.layers.13.self_attn.q_proj.weight, shape torch.Size([4096, 4096])\n",
250
+ "n = model.layers.13.self_attn.k_proj.weight, shape torch.Size([4096, 4096])\n",
251
+ "n = model.layers.13.self_attn.v_proj.weight, shape torch.Size([4096, 4096])\n",
252
+ "n = model.layers.13.self_attn.o_proj.weight, shape torch.Size([4096, 4096])\n",
253
+ "n = model.layers.13.mlp.gate_proj.weight, shape torch.Size([11008, 4096])\n",
254
+ "n = model.layers.13.mlp.up_proj.weight, shape torch.Size([11008, 4096])\n",
255
+ "n = model.layers.13.mlp.down_proj.weight, shape torch.Size([4096, 11008])\n",
256
+ "n = model.layers.13.input_layernorm.weight, shape torch.Size([4096])\n",
257
+ "n = model.layers.13.post_attention_layernorm.weight, shape torch.Size([4096])\n",
258
+ "n = model.layers.14.self_attn.q_proj.weight, shape torch.Size([4096, 4096])\n",
259
+ "n = model.layers.14.self_attn.k_proj.weight, shape torch.Size([4096, 4096])\n",
260
+ "n = model.layers.14.self_attn.v_proj.weight, shape torch.Size([4096, 4096])\n",
261
+ "n = model.layers.14.self_attn.o_proj.weight, shape torch.Size([4096, 4096])\n",
262
+ "n = model.layers.14.mlp.gate_proj.weight, shape torch.Size([11008, 4096])\n",
263
+ "n = model.layers.14.mlp.up_proj.weight, shape torch.Size([11008, 4096])\n",
264
+ "n = model.layers.14.mlp.down_proj.weight, shape torch.Size([4096, 11008])\n",
265
+ "n = model.layers.14.input_layernorm.weight, shape torch.Size([4096])\n",
266
+ "n = model.layers.14.post_attention_layernorm.weight, shape torch.Size([4096])\n",
267
+ "n = model.layers.15.self_attn.q_proj.weight, shape torch.Size([4096, 4096])\n",
268
+ "n = model.layers.15.self_attn.k_proj.weight, shape torch.Size([4096, 4096])\n",
269
+ "n = model.layers.15.self_attn.v_proj.weight, shape torch.Size([4096, 4096])\n",
270
+ "n = model.layers.15.self_attn.o_proj.weight, shape torch.Size([4096, 4096])\n",
271
+ "n = model.layers.15.mlp.gate_proj.weight, shape torch.Size([11008, 4096])\n",
272
+ "n = model.layers.15.mlp.up_proj.weight, shape torch.Size([11008, 4096])\n",
273
+ "n = model.layers.15.mlp.down_proj.weight, shape torch.Size([4096, 11008])\n",
274
+ "n = model.layers.15.input_layernorm.weight, shape torch.Size([4096])\n",
275
+ "n = model.layers.15.post_attention_layernorm.weight, shape torch.Size([4096])\n",
276
+ "n = model.layers.16.self_attn.q_proj.weight, shape torch.Size([4096, 4096])\n",
277
+ "n = model.layers.16.self_attn.k_proj.weight, shape torch.Size([4096, 4096])\n",
278
+ "n = model.layers.16.self_attn.v_proj.weight, shape torch.Size([4096, 4096])\n",
279
+ "n = model.layers.16.self_attn.o_proj.weight, shape torch.Size([4096, 4096])\n",
280
+ "n = model.layers.16.mlp.gate_proj.weight, shape torch.Size([11008, 4096])\n",
281
+ "n = model.layers.16.mlp.up_proj.weight, shape torch.Size([11008, 4096])\n",
282
+ "n = model.layers.16.mlp.down_proj.weight, shape torch.Size([4096, 11008])\n",
283
+ "n = model.layers.16.input_layernorm.weight, shape torch.Size([4096])\n",
284
+ "n = model.layers.16.post_attention_layernorm.weight, shape torch.Size([4096])\n",
285
+ "n = model.layers.17.self_attn.q_proj.weight, shape torch.Size([4096, 4096])\n",
286
+ "n = model.layers.17.self_attn.k_proj.weight, shape torch.Size([4096, 4096])\n",
287
+ "n = model.layers.17.self_attn.v_proj.weight, shape torch.Size([4096, 4096])\n",
288
+ "n = model.layers.17.self_attn.o_proj.weight, shape torch.Size([4096, 4096])\n",
289
+ "n = model.layers.17.mlp.gate_proj.weight, shape torch.Size([11008, 4096])\n",
290
+ "n = model.layers.17.mlp.up_proj.weight, shape torch.Size([11008, 4096])\n",
291
+ "n = model.layers.17.mlp.down_proj.weight, shape torch.Size([4096, 11008])\n",
292
+ "n = model.layers.17.input_layernorm.weight, shape torch.Size([4096])\n",
293
+ "n = model.layers.17.post_attention_layernorm.weight, shape torch.Size([4096])\n",
294
+ "n = model.layers.18.self_attn.q_proj.weight, shape torch.Size([4096, 4096])\n",
295
+ "n = model.layers.18.self_attn.k_proj.weight, shape torch.Size([4096, 4096])\n",
296
+ "n = model.layers.18.self_attn.v_proj.weight, shape torch.Size([4096, 4096])\n",
297
+ "n = model.layers.18.self_attn.o_proj.weight, shape torch.Size([4096, 4096])\n",
298
+ "n = model.layers.18.mlp.gate_proj.weight, shape torch.Size([11008, 4096])\n",
299
+ "n = model.layers.18.mlp.up_proj.weight, shape torch.Size([11008, 4096])\n",
300
+ "n = model.layers.18.mlp.down_proj.weight, shape torch.Size([4096, 11008])\n",
301
+ "n = model.layers.18.input_layernorm.weight, shape torch.Size([4096])\n",
302
+ "n = model.layers.18.post_attention_layernorm.weight, shape torch.Size([4096])\n",
303
+ "n = model.layers.19.self_attn.q_proj.weight, shape torch.Size([4096, 4096])\n",
304
+ "n = model.layers.19.self_attn.k_proj.weight, shape torch.Size([4096, 4096])\n",
305
+ "n = model.layers.19.self_attn.v_proj.weight, shape torch.Size([4096, 4096])\n",
306
+ "n = model.layers.19.self_attn.o_proj.weight, shape torch.Size([4096, 4096])\n",
307
+ "n = model.layers.19.mlp.gate_proj.weight, shape torch.Size([11008, 4096])\n",
308
+ "n = model.layers.19.mlp.up_proj.weight, shape torch.Size([11008, 4096])\n",
309
+ "n = model.layers.19.mlp.down_proj.weight, shape torch.Size([4096, 11008])\n",
310
+ "n = model.layers.19.input_layernorm.weight, shape torch.Size([4096])\n",
311
+ "n = model.layers.19.post_attention_layernorm.weight, shape torch.Size([4096])\n",
312
+ "n = model.layers.20.self_attn.q_proj.weight, shape torch.Size([4096, 4096])\n",
313
+ "n = model.layers.20.self_attn.k_proj.weight, shape torch.Size([4096, 4096])\n",
314
+ "n = model.layers.20.self_attn.v_proj.weight, shape torch.Size([4096, 4096])\n",
315
+ "n = model.layers.20.self_attn.o_proj.weight, shape torch.Size([4096, 4096])\n",
316
+ "n = model.layers.20.mlp.gate_proj.weight, shape torch.Size([11008, 4096])\n",
317
+ "n = model.layers.20.mlp.up_proj.weight, shape torch.Size([11008, 4096])\n",
318
+ "n = model.layers.20.mlp.down_proj.weight, shape torch.Size([4096, 11008])\n",
319
+ "n = model.layers.20.input_layernorm.weight, shape torch.Size([4096])\n",
320
+ "n = model.layers.20.post_attention_layernorm.weight, shape torch.Size([4096])\n",
321
+ "n = model.layers.21.self_attn.q_proj.weight, shape torch.Size([4096, 4096])\n",
322
+ "n = model.layers.21.self_attn.k_proj.weight, shape torch.Size([4096, 4096])\n",
323
+ "n = model.layers.21.self_attn.v_proj.weight, shape torch.Size([4096, 4096])\n",
324
+ "n = model.layers.21.self_attn.o_proj.weight, shape torch.Size([4096, 4096])\n",
325
+ "n = model.layers.21.mlp.gate_proj.weight, shape torch.Size([11008, 4096])\n",
326
+ "n = model.layers.21.mlp.up_proj.weight, shape torch.Size([11008, 4096])\n",
327
+ "n = model.layers.21.mlp.down_proj.weight, shape torch.Size([4096, 11008])\n",
328
+ "n = model.layers.21.input_layernorm.weight, shape torch.Size([4096])\n",
329
+ "n = model.layers.21.post_attention_layernorm.weight, shape torch.Size([4096])\n",
330
+ "n = model.layers.22.self_attn.q_proj.weight, shape torch.Size([4096, 4096])\n",
331
+ "n = model.layers.22.self_attn.k_proj.weight, shape torch.Size([4096, 4096])\n",
332
+ "n = model.layers.22.self_attn.v_proj.weight, shape torch.Size([4096, 4096])\n",
333
+ "n = model.layers.22.self_attn.o_proj.weight, shape torch.Size([4096, 4096])\n",
334
+ "n = model.layers.22.mlp.gate_proj.weight, shape torch.Size([11008, 4096])\n",
335
+ "n = model.layers.22.mlp.up_proj.weight, shape torch.Size([11008, 4096])\n",
336
+ "n = model.layers.22.mlp.down_proj.weight, shape torch.Size([4096, 11008])\n",
337
+ "n = model.layers.22.input_layernorm.weight, shape torch.Size([4096])\n",
338
+ "n = model.layers.22.post_attention_layernorm.weight, shape torch.Size([4096])\n",
339
+ "n = model.layers.23.self_attn.q_proj.weight, shape torch.Size([4096, 4096])\n",
340
+ "n = model.layers.23.self_attn.k_proj.weight, shape torch.Size([4096, 4096])\n",
341
+ "n = model.layers.23.self_attn.v_proj.weight, shape torch.Size([4096, 4096])\n",
342
+ "n = model.layers.23.self_attn.o_proj.weight, shape torch.Size([4096, 4096])\n",
343
+ "n = model.layers.23.mlp.gate_proj.weight, shape torch.Size([11008, 4096])\n",
344
+ "n = model.layers.23.mlp.up_proj.weight, shape torch.Size([11008, 4096])\n",
345
+ "n = model.layers.23.mlp.down_proj.weight, shape torch.Size([4096, 11008])\n",
346
+ "n = model.layers.23.input_layernorm.weight, shape torch.Size([4096])\n",
347
+ "n = model.layers.23.post_attention_layernorm.weight, shape torch.Size([4096])\n",
348
+ "n = model.layers.24.self_attn.q_proj.weight, shape torch.Size([4096, 4096])\n",
349
+ "n = model.layers.24.self_attn.k_proj.weight, shape torch.Size([4096, 4096])\n",
350
+ "n = model.layers.24.self_attn.v_proj.weight, shape torch.Size([4096, 4096])\n",
351
+ "n = model.layers.24.self_attn.o_proj.weight, shape torch.Size([4096, 4096])\n",
352
+ "n = model.layers.24.mlp.gate_proj.weight, shape torch.Size([11008, 4096])\n",
353
+ "n = model.layers.24.mlp.up_proj.weight, shape torch.Size([11008, 4096])\n",
354
+ "n = model.layers.24.mlp.down_proj.weight, shape torch.Size([4096, 11008])\n",
355
+ "n = model.layers.24.input_layernorm.weight, shape torch.Size([4096])\n",
356
+ "n = model.layers.24.post_attention_layernorm.weight, shape torch.Size([4096])\n",
357
+ "n = model.layers.25.self_attn.q_proj.weight, shape torch.Size([4096, 4096])\n",
358
+ "n = model.layers.25.self_attn.k_proj.weight, shape torch.Size([4096, 4096])\n",
359
+ "n = model.layers.25.self_attn.v_proj.weight, shape torch.Size([4096, 4096])\n",
360
+ "n = model.layers.25.self_attn.o_proj.weight, shape torch.Size([4096, 4096])\n",
361
+ "n = model.layers.25.mlp.gate_proj.weight, shape torch.Size([11008, 4096])\n",
362
+ "n = model.layers.25.mlp.up_proj.weight, shape torch.Size([11008, 4096])\n",
363
+ "n = model.layers.25.mlp.down_proj.weight, shape torch.Size([4096, 11008])\n",
364
+ "n = model.layers.25.input_layernorm.weight, shape torch.Size([4096])\n",
365
+ "n = model.layers.25.post_attention_layernorm.weight, shape torch.Size([4096])\n",
366
+ "n = model.layers.26.self_attn.q_proj.weight, shape torch.Size([4096, 4096])\n",
367
+ "n = model.layers.26.self_attn.k_proj.weight, shape torch.Size([4096, 4096])\n",
368
+ "n = model.layers.26.self_attn.v_proj.weight, shape torch.Size([4096, 4096])\n",
369
+ "n = model.layers.26.self_attn.o_proj.weight, shape torch.Size([4096, 4096])\n",
370
+ "n = model.layers.26.mlp.gate_proj.weight, shape torch.Size([11008, 4096])\n",
371
+ "n = model.layers.26.mlp.up_proj.weight, shape torch.Size([11008, 4096])\n",
372
+ "n = model.layers.26.mlp.down_proj.weight, shape torch.Size([4096, 11008])\n",
373
+ "n = model.layers.26.input_layernorm.weight, shape torch.Size([4096])\n",
374
+ "n = model.layers.26.post_attention_layernorm.weight, shape torch.Size([4096])\n",
375
+ "n = model.layers.27.self_attn.q_proj.weight, shape torch.Size([4096, 4096])\n",
376
+ "n = model.layers.27.self_attn.k_proj.weight, shape torch.Size([4096, 4096])\n",
377
+ "n = model.layers.27.self_attn.v_proj.weight, shape torch.Size([4096, 4096])\n",
378
+ "n = model.layers.27.self_attn.o_proj.weight, shape torch.Size([4096, 4096])\n",
379
+ "n = model.layers.27.mlp.gate_proj.weight, shape torch.Size([11008, 4096])\n",
380
+ "n = model.layers.27.mlp.up_proj.weight, shape torch.Size([11008, 4096])\n",
381
+ "n = model.layers.27.mlp.down_proj.weight, shape torch.Size([4096, 11008])\n",
382
+ "n = model.layers.27.input_layernorm.weight, shape torch.Size([4096])\n",
383
+ "n = model.layers.27.post_attention_layernorm.weight, shape torch.Size([4096])\n",
384
+ "n = model.layers.28.self_attn.q_proj.weight, shape torch.Size([4096, 4096])\n",
385
+ "n = model.layers.28.self_attn.k_proj.weight, shape torch.Size([4096, 4096])\n",
386
+ "n = model.layers.28.self_attn.v_proj.weight, shape torch.Size([4096, 4096])\n",
387
+ "n = model.layers.28.self_attn.o_proj.weight, shape torch.Size([4096, 4096])\n",
388
+ "n = model.layers.28.mlp.gate_proj.weight, shape torch.Size([11008, 4096])\n",
389
+ "n = model.layers.28.mlp.up_proj.weight, shape torch.Size([11008, 4096])\n",
390
+ "n = model.layers.28.mlp.down_proj.weight, shape torch.Size([4096, 11008])\n",
391
+ "n = model.layers.28.input_layernorm.weight, shape torch.Size([4096])\n",
392
+ "n = model.layers.28.post_attention_layernorm.weight, shape torch.Size([4096])\n",
393
+ "n = model.layers.29.self_attn.q_proj.weight, shape torch.Size([4096, 4096])\n",
394
+ "n = model.layers.29.self_attn.k_proj.weight, shape torch.Size([4096, 4096])\n",
395
+ "n = model.layers.29.self_attn.v_proj.weight, shape torch.Size([4096, 4096])\n",
396
+ "n = model.layers.29.self_attn.o_proj.weight, shape torch.Size([4096, 4096])\n",
397
+ "n = model.layers.29.mlp.gate_proj.weight, shape torch.Size([11008, 4096])\n",
398
+ "n = model.layers.29.mlp.up_proj.weight, shape torch.Size([11008, 4096])\n",
399
+ "n = model.layers.29.mlp.down_proj.weight, shape torch.Size([4096, 11008])\n",
400
+ "n = model.layers.29.input_layernorm.weight, shape torch.Size([4096])\n",
401
+ "n = model.layers.29.post_attention_layernorm.weight, shape torch.Size([4096])\n",
402
+ "n = model.layers.30.self_attn.q_proj.weight, shape torch.Size([4096, 4096])\n",
403
+ "n = model.layers.30.self_attn.k_proj.weight, shape torch.Size([4096, 4096])\n",
404
+ "n = model.layers.30.self_attn.v_proj.weight, shape torch.Size([4096, 4096])\n",
405
+ "n = model.layers.30.self_attn.o_proj.weight, shape torch.Size([4096, 4096])\n",
406
+ "n = model.layers.30.mlp.gate_proj.weight, shape torch.Size([11008, 4096])\n",
407
+ "n = model.layers.30.mlp.up_proj.weight, shape torch.Size([11008, 4096])\n",
408
+ "n = model.layers.30.mlp.down_proj.weight, shape torch.Size([4096, 11008])\n",
409
+ "n = model.layers.30.input_layernorm.weight, shape torch.Size([4096])\n",
410
+ "n = model.layers.30.post_attention_layernorm.weight, shape torch.Size([4096])\n",
411
+ "n = model.layers.31.self_attn.q_proj.weight, shape torch.Size([4096, 4096])\n",
412
+ "n = model.layers.31.self_attn.k_proj.weight, shape torch.Size([4096, 4096])\n",
413
+ "n = model.layers.31.self_attn.v_proj.weight, shape torch.Size([4096, 4096])\n",
414
+ "n = model.layers.31.self_attn.o_proj.weight, shape torch.Size([4096, 4096])\n",
415
+ "n = model.layers.31.mlp.gate_proj.weight, shape torch.Size([11008, 4096])\n",
416
+ "n = model.layers.31.mlp.up_proj.weight, shape torch.Size([11008, 4096])\n",
417
+ "n = model.layers.31.mlp.down_proj.weight, shape torch.Size([4096, 11008])\n",
418
+ "n = model.layers.31.input_layernorm.weight, shape torch.Size([4096])\n",
419
+ "n = model.layers.31.post_attention_layernorm.weight, shape torch.Size([4096])\n",
420
+ "n = model.norm.weight, shape torch.Size([4096])\n",
421
+ "n = lm_head.weight, shape torch.Size([32000, 4096])\n"
422
+ ]
423
+ }
424
+ ],
425
+ "source": [
426
+ "from transformers import AutoModelForCausalLM\n",
427
+ "\n",
428
+ "# Tải mô hình Llama2-7B từ Hugging Face\n",
429
+ "model = AutoModelForCausalLM.from_pretrained(\"meta-llama/Llama-2-7b-hf\")\n",
430
+ "\n",
431
+ "# Đếm số lượng tham số\n",
432
+ "num_params = sum(p.numel() for p in model.parameters())\n",
433
+ "print(f\"Số tham số của Llama2-7B: {num_params:,}\")\n",
434
+ "#print(model)\n",
435
+ "\n",
436
+ "for n, p in model.named_parameters():\n",
437
+ " print(f'n = {n}, shape {p.shape}')\n"
438
+ ]
439
+ },
440
+ {
441
+ "cell_type": "code",
442
+ "execution_count": null,
443
+ "id": "9538f476",
444
+ "metadata": {},
445
+ "outputs": [],
446
+ "source": [
447
+ "import torch\n",
448
+ "\n",
449
+ "def debug_shared_weights(model):\n",
450
+ " \"\"\"\n",
451
+ " Debug utility to verify if hypernetxs is truly shared across layers.\n",
452
+ " It checks the memory addresses of the parameters.\n",
453
+ " \"\"\"\n",
454
+ " print(\"--- Debugging Shared Weights ---\")\n",
455
+ " \n",
456
+ " # Access the shared module in different layers\n",
457
+ " # Adjust the path based on your actual structure (e.g. model.model.layers...)\n",
458
+ " layer_0_param = model.model.layers[0].hypernetxs.latent_proj.weight\n",
459
+ " layer_1_param = model.model.layers[1].hypernetxs.latent_proj.weight\n",
460
+ " main_param = model.model.hypernetxs.latent_proj.weight\n",
461
+ " \n",
462
+ " # 1. Check Memory Address (The most reliable check)\n",
463
+ " addr_0 = layer_0_param.data_ptr()\n",
464
+ " addr_1 = layer_1_param.data_ptr()\n",
465
+ " addr_main = main_param.data_ptr()\n",
466
+ " \n",
467
+ " print(f\"Layer 0 Param Address: {addr_0}\")\n",
468
+ " print(f\"Layer 1 Param Address: {addr_1}\")\n",
469
+ " print(f\"Main Model Param Address: {addr_main}\")\n",
470
+ " \n",
471
+ " if addr_0 == addr_1 == addr_main:\n",
472
+ " print(\">> SUCCESS: Parameters are sharing the same memory.\")\n",
473
+ " else:\n",
474
+ " print(\">> WARNING: Parameters are NOT shared. They are copies!\")\n",
475
+ "\n",
476
+ " # 2. Functional Check (Modify one, check others)\n",
477
+ " with torch.no_grad():\n",
478
+ " # Add a small value to layer 0\n",
479
+ " original_val = layer_1_param[0,0].item()\n",
480
+ " layer_0_param[0,0] += 1.0\n",
481
+ " new_val = layer_1_param[0,0].item()\n",
482
+ " \n",
483
+ " if new_val == original_val + 1.0:\n",
484
+ " print(\">> SUCCESS: Modification in Layer 0 reflected in Layer 1.\")\n",
485
+ " else:\n",
486
+ " print(\">> FAILURE: Modification did not propagate.\")\n",
487
+ " \n",
488
+ " # Revert change\n",
489
+ " layer_0_param[0,0] -= 1.0\n",
490
+ "\n",
491
+ "# Usage inside your main flow\n",
492
+ "debug_shared_weights(my_xs_model)"
493
+ ]
494
+ },
495
+ {
496
+ "cell_type": "code",
497
+ "execution_count": 11,
498
+ "id": "cee557d0",
499
+ "metadata": {},
500
+ "outputs": [
501
+ {
502
+ "name": "stdout",
503
+ "output_type": "stream",
504
+ "text": [
505
+ "/home/work/an_nguyen/Instance-based-FT/src\n"
506
+ ]
507
+ }
508
+ ],
509
+ "source": [
510
+ "!pwd"
511
+ ]
512
+ },
513
+ {
514
+ "cell_type": "code",
515
+ "execution_count": 12,
516
+ "id": "f60f82a4",
517
+ "metadata": {},
518
+ "outputs": [
519
+ {
520
+ "name": "stdout",
521
+ "output_type": "stream",
522
+ "text": [
523
+ ">>> Loading checkpoint: ../SVD_llama2/pytorch_model.bin\n",
524
+ ">>> Please wait, mapping to CPU...\n",
525
+ "\n",
526
+ "============================================================\n",
527
+ "KEY NAME | SHAPE | DTYPE | SIZE (MB) \n",
528
+ "============================================================\n",
529
+ "model.hypernetxs_cross_attn_tokens | [4, 128] | torch.float32 | 0.0020\n",
530
+ "model.embed_tokens.weight | [32000, 128] | torch.float32 | 15.6250\n",
531
+ "model.layers.0.layer_idx_hyperxs | [] | torch.int64 | 0.0000\n",
532
+ "model.layers.0.self_attn.q_proj.weight | [128, 128] | torch.float32 | 0.0625\n",
533
+ "model.layers.0.self_attn.q_proj.lora_A | [128, 32] | torch.float32 | 0.0156\n",
534
+ "model.layers.0.self_attn.q_proj.lora_B | [32, 128] | torch.float32 | 0.0156\n",
535
+ "model.layers.0.self_attn.k_proj.weight | [128, 128] | torch.float32 | 0.0625\n",
536
+ "model.layers.0.self_attn.k_proj.lora_A | [128, 32] | torch.float32 | 0.0156\n",
537
+ "model.layers.0.self_attn.k_proj.lora_B | [32, 128] | torch.float32 | 0.0156\n",
538
+ "model.layers.0.self_attn.v_proj.weight | [128, 128] | torch.float32 | 0.0625\n",
539
+ "model.layers.0.self_attn.v_proj.lora_A | [128, 32] | torch.float32 | 0.0156\n",
540
+ "model.layers.0.self_attn.v_proj.lora_B | [32, 128] | torch.float32 | 0.0156\n",
541
+ "model.layers.0.self_attn.o_proj.weight | [128, 128] | torch.float32 | 0.0625\n",
542
+ "model.layers.0.self_attn.o_proj.lora_A | [128, 32] | torch.float32 | 0.0156\n",
543
+ "model.layers.0.self_attn.o_proj.lora_B | [32, 128] | torch.float32 | 0.0156\n",
544
+ "model.layers.0.mlp.gate_proj.weight | [290, 128] | torch.float32 | 0.1416\n",
545
+ "model.layers.0.mlp.gate_proj.lora_A | [128, 32] | torch.float32 | 0.0156\n",
546
+ "model.layers.0.mlp.gate_proj.lora_B | [32, 290] | torch.float32 | 0.0354\n",
547
+ "model.layers.0.mlp.up_proj.weight | [290, 128] | torch.float32 | 0.1416\n",
548
+ "model.layers.0.mlp.up_proj.lora_A | [128, 32] | torch.float32 | 0.0156\n",
549
+ "model.layers.0.mlp.up_proj.lora_B | [32, 290] | torch.float32 | 0.0354\n",
550
+ "model.layers.0.mlp.down_proj.weight | [128, 290] | torch.float32 | 0.1416\n",
551
+ "model.layers.0.mlp.down_proj.lora_A | [290, 32] | torch.float32 | 0.0354\n",
552
+ "model.layers.0.mlp.down_proj.lora_B | [32, 128] | torch.float32 | 0.0156\n",
553
+ "model.layers.0.input_layernorm.weight | [128] | torch.float32 | 0.0005\n",
554
+ "model.layers.0.post_attention_layernorm.weight | [128] | torch.float32 | 0.0005\n",
555
+ "model.layers.0.hypernetxs.latent_proj.weight | [256, 128] | torch.float32 | 0.1250\n",
556
+ "model.layers.0.hypernetxs.latent_proj.bias | [256] | torch.float32 | 0.0010\n",
557
+ "model.layers.0.hypernetxs.mixture.weight | [256, 1088] | torch.float32 | 1.0625\n",
558
+ "model.layers.0.hypernetxs.mixture.bias | [256] | torch.float32 | 0.0010\n",
559
+ "model.layers.0.hypernetxs.c_fc.weight | [1024, 256] | torch.float32 | 1.0000\n",
560
+ "model.layers.0.hypernetxs.c_fc.bias | [1024] | torch.float32 | 0.0039\n",
561
+ "model.layers.0.hypernetxs.c_proj.weight | [1024, 1024] | torch.float32 | 4.0000\n",
562
+ "model.layers.0.hypernetxs.c_proj.bias | [1024] | torch.float32 | 0.0039\n",
563
+ "model.layers.0.hypernetxs.ln_latent.weight | [256] | torch.float32 | 0.0010\n",
564
+ "model.layers.0.hypernetxs.ln_latent.bias | [256] | torch.float32 | 0.0010\n",
565
+ "model.layers.0.hypernetxs.ln_1.weight | [256] | torch.float32 | 0.0010\n",
566
+ "model.layers.0.hypernetxs.ln_1.bias | [256] | torch.float32 | 0.0010\n",
567
+ "model.layers.0.hypernetxs.ln_2.weight | [1024] | torch.float32 | 0.0039\n",
568
+ "model.layers.0.hypernetxs.ln_2.bias | [1024] | torch.float32 | 0.0039\n",
569
+ "model.layers.0.hypernetxs.layer_embedding.weight | [3, 48] | torch.float32 | 0.0005\n",
570
+ "model.layers.0.hypernetxs.module_embedding.weight | [7, 16] | torch.float32 | 0.0004\n",
571
+ "model.layers.1.layer_idx_hyperxs | [] | torch.int64 | 0.0000\n",
572
+ "model.layers.1.self_attn.q_proj.weight | [128, 128] | torch.float32 | 0.0625\n",
573
+ "model.layers.1.self_attn.q_proj.lora_A | [128, 32] | torch.float32 | 0.0156\n",
574
+ "model.layers.1.self_attn.q_proj.lora_B | [32, 128] | torch.float32 | 0.0156\n",
575
+ "model.layers.1.self_attn.k_proj.weight | [128, 128] | torch.float32 | 0.0625\n",
576
+ "model.layers.1.self_attn.k_proj.lora_A | [128, 32] | torch.float32 | 0.0156\n",
577
+ "model.layers.1.self_attn.k_proj.lora_B | [32, 128] | torch.float32 | 0.0156\n",
578
+ "model.layers.1.self_attn.v_proj.weight | [128, 128] | torch.float32 | 0.0625\n",
579
+ "model.layers.1.self_attn.v_proj.lora_A | [128, 32] | torch.float32 | 0.0156\n",
580
+ "model.layers.1.self_attn.v_proj.lora_B | [32, 128] | torch.float32 | 0.0156\n",
581
+ "model.layers.1.self_attn.o_proj.weight | [128, 128] | torch.float32 | 0.0625\n",
582
+ "model.layers.1.self_attn.o_proj.lora_A | [128, 32] | torch.float32 | 0.0156\n",
583
+ "model.layers.1.self_attn.o_proj.lora_B | [32, 128] | torch.float32 | 0.0156\n",
584
+ "model.layers.1.mlp.gate_proj.weight | [290, 128] | torch.float32 | 0.1416\n",
585
+ "model.layers.1.mlp.gate_proj.lora_A | [128, 32] | torch.float32 | 0.0156\n",
586
+ "model.layers.1.mlp.gate_proj.lora_B | [32, 290] | torch.float32 | 0.0354\n",
587
+ "model.layers.1.mlp.up_proj.weight | [290, 128] | torch.float32 | 0.1416\n",
588
+ "model.layers.1.mlp.up_proj.lora_A | [128, 32] | torch.float32 | 0.0156\n",
589
+ "model.layers.1.mlp.up_proj.lora_B | [32, 290] | torch.float32 | 0.0354\n",
590
+ "model.layers.1.mlp.down_proj.weight | [128, 290] | torch.float32 | 0.1416\n",
591
+ "model.layers.1.mlp.down_proj.lora_A | [290, 32] | torch.float32 | 0.0354\n",
592
+ "model.layers.1.mlp.down_proj.lora_B | [32, 128] | torch.float32 | 0.0156\n",
593
+ "model.layers.1.input_layernorm.weight | [128] | torch.float32 | 0.0005\n",
594
+ "model.layers.1.post_attention_layernorm.weight | [128] | torch.float32 | 0.0005\n",
595
+ "model.layers.1.hypernetxs.latent_proj.weight | [256, 128] | torch.float32 | 0.1250\n",
596
+ "model.layers.1.hypernetxs.latent_proj.bias | [256] | torch.float32 | 0.0010\n",
597
+ "model.layers.1.hypernetxs.mixture.weight | [256, 1088] | torch.float32 | 1.0625\n",
598
+ "model.layers.1.hypernetxs.mixture.bias | [256] | torch.float32 | 0.0010\n",
599
+ "model.layers.1.hypernetxs.c_fc.weight | [1024, 256] | torch.float32 | 1.0000\n",
600
+ "model.layers.1.hypernetxs.c_fc.bias | [1024] | torch.float32 | 0.0039\n",
601
+ "model.layers.1.hypernetxs.c_proj.weight | [1024, 1024] | torch.float32 | 4.0000\n",
602
+ "model.layers.1.hypernetxs.c_proj.bias | [1024] | torch.float32 | 0.0039\n",
603
+ "model.layers.1.hypernetxs.ln_latent.weight | [256] | torch.float32 | 0.0010\n",
604
+ "model.layers.1.hypernetxs.ln_latent.bias | [256] | torch.float32 | 0.0010\n",
605
+ "model.layers.1.hypernetxs.ln_1.weight | [256] | torch.float32 | 0.0010\n",
606
+ "model.layers.1.hypernetxs.ln_1.bias | [256] | torch.float32 | 0.0010\n",
607
+ "model.layers.1.hypernetxs.ln_2.weight | [1024] | torch.float32 | 0.0039\n",
608
+ "model.layers.1.hypernetxs.ln_2.bias | [1024] | torch.float32 | 0.0039\n",
609
+ "model.layers.1.hypernetxs.layer_embedding.weight | [3, 48] | torch.float32 | 0.0005\n",
610
+ "model.layers.1.hypernetxs.module_embedding.weight | [7, 16] | torch.float32 | 0.0004\n",
611
+ "model.layers.2.layer_idx_hyperxs | [] | torch.int64 | 0.0000\n",
612
+ "model.layers.2.self_attn.q_proj.weight | [128, 128] | torch.float32 | 0.0625\n",
613
+ "model.layers.2.self_attn.q_proj.lora_A | [128, 32] | torch.float32 | 0.0156\n",
614
+ "model.layers.2.self_attn.q_proj.lora_B | [32, 128] | torch.float32 | 0.0156\n",
615
+ "model.layers.2.self_attn.k_proj.weight | [128, 128] | torch.float32 | 0.0625\n",
616
+ "model.layers.2.self_attn.k_proj.lora_A | [128, 32] | torch.float32 | 0.0156\n",
617
+ "model.layers.2.self_attn.k_proj.lora_B | [32, 128] | torch.float32 | 0.0156\n",
618
+ "model.layers.2.self_attn.v_proj.weight | [128, 128] | torch.float32 | 0.0625\n",
619
+ "model.layers.2.self_attn.v_proj.lora_A | [128, 32] | torch.float32 | 0.0156\n",
620
+ "model.layers.2.self_attn.v_proj.lora_B | [32, 128] | torch.float32 | 0.0156\n",
621
+ "model.layers.2.self_attn.o_proj.weight | [128, 128] | torch.float32 | 0.0625\n",
622
+ "model.layers.2.self_attn.o_proj.lora_A | [128, 32] | torch.float32 | 0.0156\n",
623
+ "model.layers.2.self_attn.o_proj.lora_B | [32, 128] | torch.float32 | 0.0156\n",
624
+ "model.layers.2.mlp.gate_proj.weight | [290, 128] | torch.float32 | 0.1416\n",
625
+ "model.layers.2.mlp.gate_proj.lora_A | [128, 32] | torch.float32 | 0.0156\n",
626
+ "model.layers.2.mlp.gate_proj.lora_B | [32, 290] | torch.float32 | 0.0354\n",
627
+ "model.layers.2.mlp.up_proj.weight | [290, 128] | torch.float32 | 0.1416\n",
628
+ "model.layers.2.mlp.up_proj.lora_A | [128, 32] | torch.float32 | 0.0156\n",
629
+ "model.layers.2.mlp.up_proj.lora_B | [32, 290] | torch.float32 | 0.0354\n",
630
+ "model.layers.2.mlp.down_proj.weight | [128, 290] | torch.float32 | 0.1416\n",
631
+ "model.layers.2.mlp.down_proj.lora_A | [290, 32] | torch.float32 | 0.0354\n",
632
+ "model.layers.2.mlp.down_proj.lora_B | [32, 128] | torch.float32 | 0.0156\n",
633
+ "model.layers.2.input_layernorm.weight | [128] | torch.float32 | 0.0005\n",
634
+ "model.layers.2.post_attention_layernorm.weight | [128] | torch.float32 | 0.0005\n",
635
+ "model.layers.2.hypernetxs.latent_proj.weight | [256, 128] | torch.float32 | 0.1250\n",
636
+ "model.layers.2.hypernetxs.latent_proj.bias | [256] | torch.float32 | 0.0010\n",
637
+ "model.layers.2.hypernetxs.mixture.weight | [256, 1088] | torch.float32 | 1.0625\n",
638
+ "model.layers.2.hypernetxs.mixture.bias | [256] | torch.float32 | 0.0010\n",
639
+ "model.layers.2.hypernetxs.c_fc.weight | [1024, 256] | torch.float32 | 1.0000\n",
640
+ "model.layers.2.hypernetxs.c_fc.bias | [1024] | torch.float32 | 0.0039\n",
641
+ "model.layers.2.hypernetxs.c_proj.weight | [1024, 1024] | torch.float32 | 4.0000\n",
642
+ "model.layers.2.hypernetxs.c_proj.bias | [1024] | torch.float32 | 0.0039\n",
643
+ "model.layers.2.hypernetxs.ln_latent.weight | [256] | torch.float32 | 0.0010\n",
644
+ "model.layers.2.hypernetxs.ln_latent.bias | [256] | torch.float32 | 0.0010\n",
645
+ "model.layers.2.hypernetxs.ln_1.weight | [256] | torch.float32 | 0.0010\n",
646
+ "model.layers.2.hypernetxs.ln_1.bias | [256] | torch.float32 | 0.0010\n",
647
+ "model.layers.2.hypernetxs.ln_2.weight | [1024] | torch.float32 | 0.0039\n",
648
+ "model.layers.2.hypernetxs.ln_2.bias | [1024] | torch.float32 | 0.0039\n",
649
+ "model.layers.2.hypernetxs.layer_embedding.weight | [3, 48] | torch.float32 | 0.0005\n",
650
+ "model.layers.2.hypernetxs.module_embedding.weight | [7, 16] | torch.float32 | 0.0004\n",
651
+ "model.norm.weight | [128] | torch.float32 | 0.0005\n",
652
+ "model.hypernetxs.latent_proj.weight | [256, 128] | torch.float32 | 0.1250\n",
653
+ "model.hypernetxs.latent_proj.bias | [256] | torch.float32 | 0.0010\n",
654
+ "model.hypernetxs.mixture.weight | [256, 1088] | torch.float32 | 1.0625\n",
655
+ "model.hypernetxs.mixture.bias | [256] | torch.float32 | 0.0010\n",
656
+ "model.hypernetxs.c_fc.weight | [1024, 256] | torch.float32 | 1.0000\n",
657
+ "model.hypernetxs.c_fc.bias | [1024] | torch.float32 | 0.0039\n",
658
+ "model.hypernetxs.c_proj.weight | [1024, 1024] | torch.float32 | 4.0000\n",
659
+ "model.hypernetxs.c_proj.bias | [1024] | torch.float32 | 0.0039\n",
660
+ "model.hypernetxs.ln_latent.weight | [256] | torch.float32 | 0.0010\n",
661
+ "model.hypernetxs.ln_latent.bias | [256] | torch.float32 | 0.0010\n",
662
+ "model.hypernetxs.ln_1.weight | [256] | torch.float32 | 0.0010\n",
663
+ "model.hypernetxs.ln_1.bias | [256] | torch.float32 | 0.0010\n",
664
+ "model.hypernetxs.ln_2.weight | [1024] | torch.float32 | 0.0039\n",
665
+ "model.hypernetxs.ln_2.bias | [1024] | torch.float32 | 0.0039\n",
666
+ "model.hypernetxs.layer_embedding.weight | [3, 48] | torch.float32 | 0.0005\n",
667
+ "model.hypernetxs.module_embedding.weight | [7, 16] | torch.float32 | 0.0004\n",
668
+ "lm_head.weight | [32000, 128] | torch.float32 | 15.6250\n",
669
+ "============================================================\n",
670
+ "\n",
671
+ ">>> SUMMARY STATISTICS:\n",
672
+ "Total Keys found: 140\n",
673
+ "Total Parameters: 15,454,403\n",
674
+ "Total Size (calculated): 58.95 MB\n",
675
+ "\n",
676
+ ">>> GROUP ANALYSIS (Where are the weights?):\n",
677
+ " - Prefix 'model': 139 items found.\n",
678
+ " - Prefix 'lm_head': 1 items found.\n",
679
+ "\n",
680
+ "[!!!] CRITICAL INSIGHT: Layers exist but are extremely small.\n",
681
+ "Check if you saved 'Float8' or empty tensors, or if Rank is effectively 0.\n"
682
+ ]
683
+ }
684
+ ],
685
+ "source": [
686
+ "import torch\n",
687
+ "import os\n",
688
+ "import sys\n",
689
+ "\n",
690
+ "def inspect_checkpoint(file_path):\n",
691
+ " \"\"\"\n",
692
+ " Loads a pytorch_model.bin file and analyzes its content:\n",
693
+ " keys, shapes, dtypes, and memory footprint.\n",
694
+ " \"\"\"\n",
695
+ " \n",
696
+ " if not os.path.exists(file_path):\n",
697
+ " print(f\"Error: File not found at {file_path}\")\n",
698
+ " return\n",
699
+ "\n",
700
+ " print(f\">>> Loading checkpoint: {file_path}\")\n",
701
+ " print(\">>> Please wait, mapping to CPU...\")\n",
702
+ " \n",
703
+ " try:\n",
704
+ " # Load state_dict to CPU to avoid OOM\n",
705
+ " state_dict = torch.load(file_path, map_location=\"cpu\", weights_only=True)\n",
706
+ " except Exception as e:\n",
707
+ " print(f\"Error loading file: {e}\")\n",
708
+ " return\n",
709
+ "\n",
710
+ " print(\"\\n\" + \"=\"*60)\n",
711
+ " print(f\"{'KEY NAME':<50} | {'SHAPE':<20} | {'DTYPE':<10} | {'SIZE (MB)':<10}\")\n",
712
+ " print(\"=\"*60)\n",
713
+ "\n",
714
+ " total_size_bytes = 0\n",
715
+ " total_params = 0\n",
716
+ " grouped_keys = {}\n",
717
+ "\n",
718
+ " for key, tensor in state_dict.items():\n",
719
+ " # Calculate size in MB\n",
720
+ " numel = tensor.numel()\n",
721
+ " element_size = tensor.element_size()\n",
722
+ " size_mb = (numel * element_size) / (1024 * 1024)\n",
723
+ " \n",
724
+ " total_size_bytes += numel * element_size\n",
725
+ " total_params += numel\n",
726
+ "\n",
727
+ " # Print details for every key (or uncomment logic below to summarize)\n",
728
+ " # To avoid flooding console, we categorize by prefix\n",
729
+ " prefix = key.split('.')[0]\n",
730
+ " if prefix not in grouped_keys:\n",
731
+ " grouped_keys[prefix] = []\n",
732
+ " grouped_keys[prefix].append(key)\n",
733
+ "\n",
734
+ " # Print only if it's a \"suspiciously\" large or small tensor, or just print all\n",
735
+ " # For debugging your 40MB issue, let's print everything if < 100 keys, \n",
736
+ " # otherwise just print the first few of each group.\n",
737
+ " print(f\"{key:<50} | {str(list(tensor.shape)):<20} | {str(tensor.dtype):<10} | {size_mb:.4f}\")\n",
738
+ "\n",
739
+ " print(\"=\"*60)\n",
740
+ " print(\"\\n>>> SUMMARY STATISTICS:\")\n",
741
+ " print(f\"Total Keys found: {len(state_dict)}\")\n",
742
+ " print(f\"Total Parameters: {total_params:,}\")\n",
743
+ " print(f\"Total Size (calculated): {total_size_bytes / (1024*1024):.2f} MB\")\n",
744
+ " \n",
745
+ " print(\"\\n>>> GROUP ANALYSIS (Where are the weights?):\")\n",
746
+ " for prefix, keys in grouped_keys.items():\n",
747
+ " print(f\" - Prefix '{prefix}': {len(keys)} items found.\")\n",
748
+ " # Check if 'model' prefix exists (standard for Llama)\n",
749
+ " \n",
750
+ " # Heuristics based on your 40MB issue\n",
751
+ " has_layers = any(\"layers\" in k for k in state_dict.keys())\n",
752
+ " has_backbone = any(\"model.layers\" in k for k in state_dict.keys())\n",
753
+ " \n",
754
+ " if not has_backbone:\n",
755
+ " print(\"\\n[!!!] CRITICAL INSIGHT: The 'model.layers' keys are MISSING.\")\n",
756
+ " print(\"This means the main backbone weights were NOT saved.\")\n",
757
+ " print(\"Only the HyperNet or Head weights seem to be present.\")\n",
758
+ " elif total_size_bytes / (1024*1024) < 100:\n",
759
+ " print(\"\\n[!!!] CRITICAL INSIGHT: Layers exist but are extremely small.\")\n",
760
+ " print(\"Check if you saved 'Float8' or empty tensors, or if Rank is effectively 0.\")\n",
761
+ "\n",
762
+ "if __name__ == \"__main__\":\n",
763
+ " # Replace with the actual path to your bin file\n",
764
+ " # Example: \"xs_model_output/pytorch_model.bin\"\n",
765
+ " chk_path = \"../SVD_llama2/pytorch_model.bin\" \n",
766
+ " \n",
767
+ " \n",
768
+ " inspect_checkpoint(chk_path)"
769
+ ]
770
+ },
771
+ {
772
+ "cell_type": "code",
773
+ "execution_count": null,
774
+ "id": "1a69ffd1",
775
+ "metadata": {},
776
+ "outputs": [],
777
+ "source": [
778
+ " # print('model', self.model)\n",
779
+ " # for n, p in self.model.named_parameters():\n",
780
+ " # print('n,p', n, p.shape)\n",
781
+ " # exit()"
782
+ ]
783
+ }
784
+ ],
785
+ "metadata": {
786
+ "kernelspec": {
787
+ "display_name": "allm",
788
+ "language": "python",
789
+ "name": "python3"
790
+ },
791
+ "language_info": {
792
+ "codemirror_mode": {
793
+ "name": "ipython",
794
+ "version": 3
795
+ },
796
+ "file_extension": ".py",
797
+ "mimetype": "text/x-python",
798
+ "name": "python",
799
+ "nbconvert_exporter": "python",
800
+ "pygments_lexer": "ipython3",
801
+ "version": "3.11.3"
802
+ }
803
+ },
804
+ "nbformat": 4,
805
+ "nbformat_minor": 5
806
+ }