{ "cells": [ { "cell_type": "code", "execution_count": 3, "id": "7e7899f4", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[ 0.7923, -0.1882, 0.8791, 0.8785],\n", " [-0.3649, 0.3171, 0.2766, 0.4714],\n", " [ 0.5661, 0.4688, -1.5763, -0.7690],\n", " [ 1.2863, -0.4760, -2.0309, -2.4342],\n", " [ 0.1591, 1.2439, -1.0475, 0.6328],\n", " [ 0.3351, 0.5378, -1.2086, 0.9963]], grad_fn=) tensor([[-0.4674, 0.0799, 0.8670, 0.0765],\n", " [ 0.2153, -0.1855, 0.0422, -0.1279],\n", " [-0.3339, -0.3323, -0.2219, -0.1967],\n", " [-0.7588, 0.2398, -0.3984, -0.1867],\n", " [-0.0939, -0.8113, 0.1191, -0.3375],\n", " [-0.1977, -0.3647, -0.1560, 0.8890]], grad_fn=) tensor([[-1.6951, 0.1378, 2.0534, 1.5385],\n", " [ 0.0000, -1.5491, 1.3466, -1.2221],\n", " [ 0.0000, 0.0000, 1.9967, 1.8420],\n", " [ 0.0000, 0.0000, 0.0000, 1.2846]], grad_fn=)\n", "Output Shape: torch.Size([4, 6])\n", "\n", "Gram Matrix (M @ M.T):\n", "tensor([[ 1.0000e+00, -2.5258e-08, -1.9981e-07, 6.6143e-09],\n", " [-2.5258e-08, 1.0000e+00, 5.1411e-08, 3.9070e-08],\n", " [-1.9981e-07, 5.1411e-08, 1.0000e+00, 1.6955e-08],\n", " [ 6.6143e-09, 3.9070e-08, 1.6955e-08, 1.0000e+00]],\n", " grad_fn=)\n", "\n", "Orthogonality Error: 0.000000\n" ] } ], "source": [ "import torch\n", "\n", "def create_orthogonal_rows_matrix(rows: int, cols: int):\n", " \"\"\"\n", " Creates a rectangular matrix (rows x cols) with orthonormal rows using QR decomposition.\n", " Condition: cols >= rows.\n", " \"\"\"\n", " # Create a random input matrix (requires_grad=True to test differentiability later)\n", " # Shape: (rows, cols)\n", " X = torch.randn(rows, cols, requires_grad=True)\n", " \n", " # 1. Transpose to get a \"tall\" matrix (cols x rows)\n", " # We do this because standard QR produces orthogonal columns.\n", " X_T = X.T\n", " \n", " # 2. Apply QR Decomposition\n", " # Q will have shape (cols, rows) with orthogonal columns\n", " # R will be upper triangular\n", " Q_T, R = torch.linalg.qr(X_T, mode='reduced')\n", " print(X_T, Q_T, R)\n", " \n", " # 3. Transpose Q back to get the desired shape (rows, cols)\n", " # Now, M has orthogonal rows.\n", " M = Q_T.T\n", " \n", " return X, M\n", "\n", "# --- Usage Example ---\n", "\n", "# Configuration: 3 rows, 5 columns (Rectangular \"fat\" matrix)\n", "m, n = 4, 6\n", "\n", "# Create the matrix\n", "input_tensor, ortho_matrix = create_orthogonal_rows_matrix(m, n)\n", "\n", "print(f\"Output Shape: {ortho_matrix.shape}\") # Should be torch.Size([3, 5])\n", "\n", "# --- Verification ---\n", "\n", "# Check orthogonality: M @ M.T should be Identity matrix (3x3)\n", "gram_matrix = torch.matmul(ortho_matrix, ortho_matrix.T)\n", "identity = torch.eye(m)\n", "\n", "print(\"\\nGram Matrix (M @ M.T):\")\n", "print(gram_matrix)\n", "\n", "# Check error\n", "error = torch.dist(gram_matrix, identity)\n", "print(f\"\\nOrthogonality Error: {error.item():.6f}\")\n", "\n", "# Note: The result is very close to 0, confirming orthogonality." ] }, { "cell_type": "code", "execution_count": 1, "id": "0b4c7963", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "NumPy: 2.2.6, SciPy:\n" ] } ], "source": [ "import numpy; import scipy\n", "print(f'NumPy: {numpy.__version__}, SciPy:')" ] }, { "cell_type": "code", "execution_count": 2, "id": "6241e0ab", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00, 1.22s/it]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Số tham số của Llama2-7B: 6,738,415,616\n", "n = model.embed_tokens.weight, shape torch.Size([32000, 4096])\n", "n = model.layers.0.self_attn.q_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.0.self_attn.k_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.0.self_attn.v_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.0.self_attn.o_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.0.mlp.gate_proj.weight, shape torch.Size([11008, 4096])\n", "n = model.layers.0.mlp.up_proj.weight, shape torch.Size([11008, 4096])\n", "n = model.layers.0.mlp.down_proj.weight, shape torch.Size([4096, 11008])\n", "n = model.layers.0.input_layernorm.weight, shape torch.Size([4096])\n", "n = model.layers.0.post_attention_layernorm.weight, shape torch.Size([4096])\n", "n = model.layers.1.self_attn.q_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.1.self_attn.k_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.1.self_attn.v_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.1.self_attn.o_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.1.mlp.gate_proj.weight, shape torch.Size([11008, 4096])\n", "n = model.layers.1.mlp.up_proj.weight, shape torch.Size([11008, 4096])\n", "n = model.layers.1.mlp.down_proj.weight, shape torch.Size([4096, 11008])\n", "n = model.layers.1.input_layernorm.weight, shape torch.Size([4096])\n", "n = model.layers.1.post_attention_layernorm.weight, shape torch.Size([4096])\n", "n = model.layers.2.self_attn.q_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.2.self_attn.k_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.2.self_attn.v_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.2.self_attn.o_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.2.mlp.gate_proj.weight, shape torch.Size([11008, 4096])\n", "n = model.layers.2.mlp.up_proj.weight, shape torch.Size([11008, 4096])\n", "n = model.layers.2.mlp.down_proj.weight, shape torch.Size([4096, 11008])\n", "n = model.layers.2.input_layernorm.weight, shape torch.Size([4096])\n", "n = model.layers.2.post_attention_layernorm.weight, shape torch.Size([4096])\n", "n = model.layers.3.self_attn.q_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.3.self_attn.k_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.3.self_attn.v_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.3.self_attn.o_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.3.mlp.gate_proj.weight, shape torch.Size([11008, 4096])\n", "n = model.layers.3.mlp.up_proj.weight, shape torch.Size([11008, 4096])\n", "n = model.layers.3.mlp.down_proj.weight, shape torch.Size([4096, 11008])\n", "n = model.layers.3.input_layernorm.weight, shape torch.Size([4096])\n", "n = model.layers.3.post_attention_layernorm.weight, shape torch.Size([4096])\n", "n = model.layers.4.self_attn.q_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.4.self_attn.k_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.4.self_attn.v_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.4.self_attn.o_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.4.mlp.gate_proj.weight, shape torch.Size([11008, 4096])\n", "n = model.layers.4.mlp.up_proj.weight, shape torch.Size([11008, 4096])\n", "n = model.layers.4.mlp.down_proj.weight, shape torch.Size([4096, 11008])\n", "n = model.layers.4.input_layernorm.weight, shape torch.Size([4096])\n", "n = model.layers.4.post_attention_layernorm.weight, shape torch.Size([4096])\n", "n = model.layers.5.self_attn.q_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.5.self_attn.k_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.5.self_attn.v_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.5.self_attn.o_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.5.mlp.gate_proj.weight, shape torch.Size([11008, 4096])\n", "n = model.layers.5.mlp.up_proj.weight, shape torch.Size([11008, 4096])\n", "n = model.layers.5.mlp.down_proj.weight, shape torch.Size([4096, 11008])\n", "n = model.layers.5.input_layernorm.weight, shape torch.Size([4096])\n", "n = model.layers.5.post_attention_layernorm.weight, shape torch.Size([4096])\n", "n = model.layers.6.self_attn.q_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.6.self_attn.k_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.6.self_attn.v_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.6.self_attn.o_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.6.mlp.gate_proj.weight, shape torch.Size([11008, 4096])\n", "n = model.layers.6.mlp.up_proj.weight, shape torch.Size([11008, 4096])\n", "n = model.layers.6.mlp.down_proj.weight, shape torch.Size([4096, 11008])\n", "n = model.layers.6.input_layernorm.weight, shape torch.Size([4096])\n", "n = model.layers.6.post_attention_layernorm.weight, shape torch.Size([4096])\n", "n = model.layers.7.self_attn.q_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.7.self_attn.k_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.7.self_attn.v_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.7.self_attn.o_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.7.mlp.gate_proj.weight, shape torch.Size([11008, 4096])\n", "n = model.layers.7.mlp.up_proj.weight, shape torch.Size([11008, 4096])\n", "n = model.layers.7.mlp.down_proj.weight, shape torch.Size([4096, 11008])\n", "n = model.layers.7.input_layernorm.weight, shape torch.Size([4096])\n", "n = model.layers.7.post_attention_layernorm.weight, shape torch.Size([4096])\n", "n = model.layers.8.self_attn.q_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.8.self_attn.k_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.8.self_attn.v_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.8.self_attn.o_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.8.mlp.gate_proj.weight, shape torch.Size([11008, 4096])\n", "n = model.layers.8.mlp.up_proj.weight, shape torch.Size([11008, 4096])\n", "n = model.layers.8.mlp.down_proj.weight, shape torch.Size([4096, 11008])\n", "n = model.layers.8.input_layernorm.weight, shape torch.Size([4096])\n", "n = model.layers.8.post_attention_layernorm.weight, shape torch.Size([4096])\n", "n = model.layers.9.self_attn.q_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.9.self_attn.k_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.9.self_attn.v_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.9.self_attn.o_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.9.mlp.gate_proj.weight, shape torch.Size([11008, 4096])\n", "n = model.layers.9.mlp.up_proj.weight, shape torch.Size([11008, 4096])\n", "n = model.layers.9.mlp.down_proj.weight, shape torch.Size([4096, 11008])\n", "n = model.layers.9.input_layernorm.weight, shape torch.Size([4096])\n", "n = model.layers.9.post_attention_layernorm.weight, shape torch.Size([4096])\n", "n = model.layers.10.self_attn.q_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.10.self_attn.k_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.10.self_attn.v_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.10.self_attn.o_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.10.mlp.gate_proj.weight, shape torch.Size([11008, 4096])\n", "n = model.layers.10.mlp.up_proj.weight, shape torch.Size([11008, 4096])\n", "n = model.layers.10.mlp.down_proj.weight, shape torch.Size([4096, 11008])\n", "n = model.layers.10.input_layernorm.weight, shape torch.Size([4096])\n", "n = model.layers.10.post_attention_layernorm.weight, shape torch.Size([4096])\n", "n = model.layers.11.self_attn.q_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.11.self_attn.k_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.11.self_attn.v_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.11.self_attn.o_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.11.mlp.gate_proj.weight, shape torch.Size([11008, 4096])\n", "n = model.layers.11.mlp.up_proj.weight, shape torch.Size([11008, 4096])\n", "n = model.layers.11.mlp.down_proj.weight, shape torch.Size([4096, 11008])\n", "n = model.layers.11.input_layernorm.weight, shape torch.Size([4096])\n", "n = model.layers.11.post_attention_layernorm.weight, shape torch.Size([4096])\n", "n = model.layers.12.self_attn.q_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.12.self_attn.k_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.12.self_attn.v_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.12.self_attn.o_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.12.mlp.gate_proj.weight, shape torch.Size([11008, 4096])\n", "n = model.layers.12.mlp.up_proj.weight, shape torch.Size([11008, 4096])\n", "n = model.layers.12.mlp.down_proj.weight, shape torch.Size([4096, 11008])\n", "n = model.layers.12.input_layernorm.weight, shape torch.Size([4096])\n", "n = model.layers.12.post_attention_layernorm.weight, shape torch.Size([4096])\n", "n = model.layers.13.self_attn.q_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.13.self_attn.k_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.13.self_attn.v_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.13.self_attn.o_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.13.mlp.gate_proj.weight, shape torch.Size([11008, 4096])\n", "n = model.layers.13.mlp.up_proj.weight, shape torch.Size([11008, 4096])\n", "n = model.layers.13.mlp.down_proj.weight, shape torch.Size([4096, 11008])\n", "n = model.layers.13.input_layernorm.weight, shape torch.Size([4096])\n", "n = model.layers.13.post_attention_layernorm.weight, shape torch.Size([4096])\n", "n = model.layers.14.self_attn.q_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.14.self_attn.k_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.14.self_attn.v_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.14.self_attn.o_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.14.mlp.gate_proj.weight, shape torch.Size([11008, 4096])\n", "n = model.layers.14.mlp.up_proj.weight, shape torch.Size([11008, 4096])\n", "n = model.layers.14.mlp.down_proj.weight, shape torch.Size([4096, 11008])\n", "n = model.layers.14.input_layernorm.weight, shape torch.Size([4096])\n", "n = model.layers.14.post_attention_layernorm.weight, shape torch.Size([4096])\n", "n = model.layers.15.self_attn.q_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.15.self_attn.k_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.15.self_attn.v_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.15.self_attn.o_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.15.mlp.gate_proj.weight, shape torch.Size([11008, 4096])\n", "n = model.layers.15.mlp.up_proj.weight, shape torch.Size([11008, 4096])\n", "n = model.layers.15.mlp.down_proj.weight, shape torch.Size([4096, 11008])\n", "n = model.layers.15.input_layernorm.weight, shape torch.Size([4096])\n", "n = model.layers.15.post_attention_layernorm.weight, shape torch.Size([4096])\n", "n = model.layers.16.self_attn.q_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.16.self_attn.k_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.16.self_attn.v_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.16.self_attn.o_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.16.mlp.gate_proj.weight, shape torch.Size([11008, 4096])\n", "n = model.layers.16.mlp.up_proj.weight, shape torch.Size([11008, 4096])\n", "n = model.layers.16.mlp.down_proj.weight, shape torch.Size([4096, 11008])\n", "n = model.layers.16.input_layernorm.weight, shape torch.Size([4096])\n", "n = model.layers.16.post_attention_layernorm.weight, shape torch.Size([4096])\n", "n = model.layers.17.self_attn.q_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.17.self_attn.k_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.17.self_attn.v_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.17.self_attn.o_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.17.mlp.gate_proj.weight, shape torch.Size([11008, 4096])\n", "n = model.layers.17.mlp.up_proj.weight, shape torch.Size([11008, 4096])\n", "n = model.layers.17.mlp.down_proj.weight, shape torch.Size([4096, 11008])\n", "n = model.layers.17.input_layernorm.weight, shape torch.Size([4096])\n", "n = model.layers.17.post_attention_layernorm.weight, shape torch.Size([4096])\n", "n = model.layers.18.self_attn.q_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.18.self_attn.k_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.18.self_attn.v_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.18.self_attn.o_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.18.mlp.gate_proj.weight, shape torch.Size([11008, 4096])\n", "n = model.layers.18.mlp.up_proj.weight, shape torch.Size([11008, 4096])\n", "n = model.layers.18.mlp.down_proj.weight, shape torch.Size([4096, 11008])\n", "n = model.layers.18.input_layernorm.weight, shape torch.Size([4096])\n", "n = model.layers.18.post_attention_layernorm.weight, shape torch.Size([4096])\n", "n = model.layers.19.self_attn.q_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.19.self_attn.k_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.19.self_attn.v_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.19.self_attn.o_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.19.mlp.gate_proj.weight, shape torch.Size([11008, 4096])\n", "n = model.layers.19.mlp.up_proj.weight, shape torch.Size([11008, 4096])\n", "n = model.layers.19.mlp.down_proj.weight, shape torch.Size([4096, 11008])\n", "n = model.layers.19.input_layernorm.weight, shape torch.Size([4096])\n", "n = model.layers.19.post_attention_layernorm.weight, shape torch.Size([4096])\n", "n = model.layers.20.self_attn.q_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.20.self_attn.k_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.20.self_attn.v_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.20.self_attn.o_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.20.mlp.gate_proj.weight, shape torch.Size([11008, 4096])\n", "n = model.layers.20.mlp.up_proj.weight, shape torch.Size([11008, 4096])\n", "n = model.layers.20.mlp.down_proj.weight, shape torch.Size([4096, 11008])\n", "n = model.layers.20.input_layernorm.weight, shape torch.Size([4096])\n", "n = model.layers.20.post_attention_layernorm.weight, shape torch.Size([4096])\n", "n = model.layers.21.self_attn.q_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.21.self_attn.k_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.21.self_attn.v_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.21.self_attn.o_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.21.mlp.gate_proj.weight, shape torch.Size([11008, 4096])\n", "n = model.layers.21.mlp.up_proj.weight, shape torch.Size([11008, 4096])\n", "n = model.layers.21.mlp.down_proj.weight, shape torch.Size([4096, 11008])\n", "n = model.layers.21.input_layernorm.weight, shape torch.Size([4096])\n", "n = model.layers.21.post_attention_layernorm.weight, shape torch.Size([4096])\n", "n = model.layers.22.self_attn.q_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.22.self_attn.k_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.22.self_attn.v_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.22.self_attn.o_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.22.mlp.gate_proj.weight, shape torch.Size([11008, 4096])\n", "n = model.layers.22.mlp.up_proj.weight, shape torch.Size([11008, 4096])\n", "n = model.layers.22.mlp.down_proj.weight, shape torch.Size([4096, 11008])\n", "n = model.layers.22.input_layernorm.weight, shape torch.Size([4096])\n", "n = model.layers.22.post_attention_layernorm.weight, shape torch.Size([4096])\n", "n = model.layers.23.self_attn.q_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.23.self_attn.k_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.23.self_attn.v_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.23.self_attn.o_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.23.mlp.gate_proj.weight, shape torch.Size([11008, 4096])\n", "n = model.layers.23.mlp.up_proj.weight, shape torch.Size([11008, 4096])\n", "n = model.layers.23.mlp.down_proj.weight, shape torch.Size([4096, 11008])\n", "n = model.layers.23.input_layernorm.weight, shape torch.Size([4096])\n", "n = model.layers.23.post_attention_layernorm.weight, shape torch.Size([4096])\n", "n = model.layers.24.self_attn.q_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.24.self_attn.k_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.24.self_attn.v_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.24.self_attn.o_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.24.mlp.gate_proj.weight, shape torch.Size([11008, 4096])\n", "n = model.layers.24.mlp.up_proj.weight, shape torch.Size([11008, 4096])\n", "n = model.layers.24.mlp.down_proj.weight, shape torch.Size([4096, 11008])\n", "n = model.layers.24.input_layernorm.weight, shape torch.Size([4096])\n", "n = model.layers.24.post_attention_layernorm.weight, shape torch.Size([4096])\n", "n = model.layers.25.self_attn.q_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.25.self_attn.k_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.25.self_attn.v_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.25.self_attn.o_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.25.mlp.gate_proj.weight, shape torch.Size([11008, 4096])\n", "n = model.layers.25.mlp.up_proj.weight, shape torch.Size([11008, 4096])\n", "n = model.layers.25.mlp.down_proj.weight, shape torch.Size([4096, 11008])\n", "n = model.layers.25.input_layernorm.weight, shape torch.Size([4096])\n", "n = model.layers.25.post_attention_layernorm.weight, shape torch.Size([4096])\n", "n = model.layers.26.self_attn.q_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.26.self_attn.k_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.26.self_attn.v_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.26.self_attn.o_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.26.mlp.gate_proj.weight, shape torch.Size([11008, 4096])\n", "n = model.layers.26.mlp.up_proj.weight, shape torch.Size([11008, 4096])\n", "n = model.layers.26.mlp.down_proj.weight, shape torch.Size([4096, 11008])\n", "n = model.layers.26.input_layernorm.weight, shape torch.Size([4096])\n", "n = model.layers.26.post_attention_layernorm.weight, shape torch.Size([4096])\n", "n = model.layers.27.self_attn.q_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.27.self_attn.k_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.27.self_attn.v_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.27.self_attn.o_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.27.mlp.gate_proj.weight, shape torch.Size([11008, 4096])\n", "n = model.layers.27.mlp.up_proj.weight, shape torch.Size([11008, 4096])\n", "n = model.layers.27.mlp.down_proj.weight, shape torch.Size([4096, 11008])\n", "n = model.layers.27.input_layernorm.weight, shape torch.Size([4096])\n", "n = model.layers.27.post_attention_layernorm.weight, shape torch.Size([4096])\n", "n = model.layers.28.self_attn.q_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.28.self_attn.k_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.28.self_attn.v_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.28.self_attn.o_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.28.mlp.gate_proj.weight, shape torch.Size([11008, 4096])\n", "n = model.layers.28.mlp.up_proj.weight, shape torch.Size([11008, 4096])\n", "n = model.layers.28.mlp.down_proj.weight, shape torch.Size([4096, 11008])\n", "n = model.layers.28.input_layernorm.weight, shape torch.Size([4096])\n", "n = model.layers.28.post_attention_layernorm.weight, shape torch.Size([4096])\n", "n = model.layers.29.self_attn.q_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.29.self_attn.k_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.29.self_attn.v_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.29.self_attn.o_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.29.mlp.gate_proj.weight, shape torch.Size([11008, 4096])\n", "n = model.layers.29.mlp.up_proj.weight, shape torch.Size([11008, 4096])\n", "n = model.layers.29.mlp.down_proj.weight, shape torch.Size([4096, 11008])\n", "n = model.layers.29.input_layernorm.weight, shape torch.Size([4096])\n", "n = model.layers.29.post_attention_layernorm.weight, shape torch.Size([4096])\n", "n = model.layers.30.self_attn.q_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.30.self_attn.k_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.30.self_attn.v_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.30.self_attn.o_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.30.mlp.gate_proj.weight, shape torch.Size([11008, 4096])\n", "n = model.layers.30.mlp.up_proj.weight, shape torch.Size([11008, 4096])\n", "n = model.layers.30.mlp.down_proj.weight, shape torch.Size([4096, 11008])\n", "n = model.layers.30.input_layernorm.weight, shape torch.Size([4096])\n", "n = model.layers.30.post_attention_layernorm.weight, shape torch.Size([4096])\n", "n = model.layers.31.self_attn.q_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.31.self_attn.k_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.31.self_attn.v_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.31.self_attn.o_proj.weight, shape torch.Size([4096, 4096])\n", "n = model.layers.31.mlp.gate_proj.weight, shape torch.Size([11008, 4096])\n", "n = model.layers.31.mlp.up_proj.weight, shape torch.Size([11008, 4096])\n", "n = model.layers.31.mlp.down_proj.weight, shape torch.Size([4096, 11008])\n", "n = model.layers.31.input_layernorm.weight, shape torch.Size([4096])\n", "n = model.layers.31.post_attention_layernorm.weight, shape torch.Size([4096])\n", "n = model.norm.weight, shape torch.Size([4096])\n", "n = lm_head.weight, shape torch.Size([32000, 4096])\n" ] } ], "source": [ "from transformers import AutoModelForCausalLM\n", "\n", "# Tải mô hình Llama2-7B từ Hugging Face\n", "model = AutoModelForCausalLM.from_pretrained(\"meta-llama/Llama-2-7b-hf\")\n", "\n", "# Đếm số lượng tham số\n", "num_params = sum(p.numel() for p in model.parameters())\n", "print(f\"Số tham số của Llama2-7B: {num_params:,}\")\n", "#print(model)\n", "\n", "for n, p in model.named_parameters():\n", " print(f'n = {n}, shape {p.shape}')\n" ] }, { "cell_type": "code", "execution_count": null, "id": "9538f476", "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "def debug_shared_weights(model):\n", " \"\"\"\n", " Debug utility to verify if hypernetxs is truly shared across layers.\n", " It checks the memory addresses of the parameters.\n", " \"\"\"\n", " print(\"--- Debugging Shared Weights ---\")\n", " \n", " # Access the shared module in different layers\n", " # Adjust the path based on your actual structure (e.g. model.model.layers...)\n", " layer_0_param = model.model.layers[0].hypernetxs.latent_proj.weight\n", " layer_1_param = model.model.layers[1].hypernetxs.latent_proj.weight\n", " main_param = model.model.hypernetxs.latent_proj.weight\n", " \n", " # 1. Check Memory Address (The most reliable check)\n", " addr_0 = layer_0_param.data_ptr()\n", " addr_1 = layer_1_param.data_ptr()\n", " addr_main = main_param.data_ptr()\n", " \n", " print(f\"Layer 0 Param Address: {addr_0}\")\n", " print(f\"Layer 1 Param Address: {addr_1}\")\n", " print(f\"Main Model Param Address: {addr_main}\")\n", " \n", " if addr_0 == addr_1 == addr_main:\n", " print(\">> SUCCESS: Parameters are sharing the same memory.\")\n", " else:\n", " print(\">> WARNING: Parameters are NOT shared. They are copies!\")\n", "\n", " # 2. Functional Check (Modify one, check others)\n", " with torch.no_grad():\n", " # Add a small value to layer 0\n", " original_val = layer_1_param[0,0].item()\n", " layer_0_param[0,0] += 1.0\n", " new_val = layer_1_param[0,0].item()\n", " \n", " if new_val == original_val + 1.0:\n", " print(\">> SUCCESS: Modification in Layer 0 reflected in Layer 1.\")\n", " else:\n", " print(\">> FAILURE: Modification did not propagate.\")\n", " \n", " # Revert change\n", " layer_0_param[0,0] -= 1.0\n", "\n", "# Usage inside your main flow\n", "debug_shared_weights(my_xs_model)" ] }, { "cell_type": "code", "execution_count": 11, "id": "cee557d0", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/home/work/an_nguyen/Instance-based-FT/src\n" ] } ], "source": [ "!pwd" ] }, { "cell_type": "code", "execution_count": 12, "id": "f60f82a4", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ ">>> Loading checkpoint: ../SVD_llama2/pytorch_model.bin\n", ">>> Please wait, mapping to CPU...\n", "\n", "============================================================\n", "KEY NAME | SHAPE | DTYPE | SIZE (MB) \n", "============================================================\n", "model.hypernetxs_cross_attn_tokens | [4, 128] | torch.float32 | 0.0020\n", "model.embed_tokens.weight | [32000, 128] | torch.float32 | 15.6250\n", "model.layers.0.layer_idx_hyperxs | [] | torch.int64 | 0.0000\n", "model.layers.0.self_attn.q_proj.weight | [128, 128] | torch.float32 | 0.0625\n", "model.layers.0.self_attn.q_proj.lora_A | [128, 32] | torch.float32 | 0.0156\n", "model.layers.0.self_attn.q_proj.lora_B | [32, 128] | torch.float32 | 0.0156\n", "model.layers.0.self_attn.k_proj.weight | [128, 128] | torch.float32 | 0.0625\n", "model.layers.0.self_attn.k_proj.lora_A | [128, 32] | torch.float32 | 0.0156\n", "model.layers.0.self_attn.k_proj.lora_B | [32, 128] | torch.float32 | 0.0156\n", "model.layers.0.self_attn.v_proj.weight | [128, 128] | torch.float32 | 0.0625\n", "model.layers.0.self_attn.v_proj.lora_A | [128, 32] | torch.float32 | 0.0156\n", "model.layers.0.self_attn.v_proj.lora_B | [32, 128] | torch.float32 | 0.0156\n", "model.layers.0.self_attn.o_proj.weight | [128, 128] | torch.float32 | 0.0625\n", "model.layers.0.self_attn.o_proj.lora_A | [128, 32] | torch.float32 | 0.0156\n", "model.layers.0.self_attn.o_proj.lora_B | [32, 128] | torch.float32 | 0.0156\n", "model.layers.0.mlp.gate_proj.weight | [290, 128] | torch.float32 | 0.1416\n", "model.layers.0.mlp.gate_proj.lora_A | [128, 32] | torch.float32 | 0.0156\n", "model.layers.0.mlp.gate_proj.lora_B | [32, 290] | torch.float32 | 0.0354\n", "model.layers.0.mlp.up_proj.weight | [290, 128] | torch.float32 | 0.1416\n", "model.layers.0.mlp.up_proj.lora_A | [128, 32] | torch.float32 | 0.0156\n", "model.layers.0.mlp.up_proj.lora_B | [32, 290] | torch.float32 | 0.0354\n", "model.layers.0.mlp.down_proj.weight | [128, 290] | torch.float32 | 0.1416\n", "model.layers.0.mlp.down_proj.lora_A | [290, 32] | torch.float32 | 0.0354\n", "model.layers.0.mlp.down_proj.lora_B | [32, 128] | torch.float32 | 0.0156\n", "model.layers.0.input_layernorm.weight | [128] | torch.float32 | 0.0005\n", "model.layers.0.post_attention_layernorm.weight | [128] | torch.float32 | 0.0005\n", "model.layers.0.hypernetxs.latent_proj.weight | [256, 128] | torch.float32 | 0.1250\n", "model.layers.0.hypernetxs.latent_proj.bias | [256] | torch.float32 | 0.0010\n", "model.layers.0.hypernetxs.mixture.weight | [256, 1088] | torch.float32 | 1.0625\n", "model.layers.0.hypernetxs.mixture.bias | [256] | torch.float32 | 0.0010\n", "model.layers.0.hypernetxs.c_fc.weight | [1024, 256] | torch.float32 | 1.0000\n", "model.layers.0.hypernetxs.c_fc.bias | [1024] | torch.float32 | 0.0039\n", "model.layers.0.hypernetxs.c_proj.weight | [1024, 1024] | torch.float32 | 4.0000\n", "model.layers.0.hypernetxs.c_proj.bias | [1024] | torch.float32 | 0.0039\n", "model.layers.0.hypernetxs.ln_latent.weight | [256] | torch.float32 | 0.0010\n", "model.layers.0.hypernetxs.ln_latent.bias | [256] | torch.float32 | 0.0010\n", "model.layers.0.hypernetxs.ln_1.weight | [256] | torch.float32 | 0.0010\n", "model.layers.0.hypernetxs.ln_1.bias | [256] | torch.float32 | 0.0010\n", "model.layers.0.hypernetxs.ln_2.weight | [1024] | torch.float32 | 0.0039\n", "model.layers.0.hypernetxs.ln_2.bias | [1024] | torch.float32 | 0.0039\n", "model.layers.0.hypernetxs.layer_embedding.weight | [3, 48] | torch.float32 | 0.0005\n", "model.layers.0.hypernetxs.module_embedding.weight | [7, 16] | torch.float32 | 0.0004\n", "model.layers.1.layer_idx_hyperxs | [] | torch.int64 | 0.0000\n", "model.layers.1.self_attn.q_proj.weight | [128, 128] | torch.float32 | 0.0625\n", "model.layers.1.self_attn.q_proj.lora_A | [128, 32] | torch.float32 | 0.0156\n", "model.layers.1.self_attn.q_proj.lora_B | [32, 128] | torch.float32 | 0.0156\n", "model.layers.1.self_attn.k_proj.weight | [128, 128] | torch.float32 | 0.0625\n", "model.layers.1.self_attn.k_proj.lora_A | [128, 32] | torch.float32 | 0.0156\n", "model.layers.1.self_attn.k_proj.lora_B | [32, 128] | torch.float32 | 0.0156\n", "model.layers.1.self_attn.v_proj.weight | [128, 128] | torch.float32 | 0.0625\n", "model.layers.1.self_attn.v_proj.lora_A | [128, 32] | torch.float32 | 0.0156\n", "model.layers.1.self_attn.v_proj.lora_B | [32, 128] | torch.float32 | 0.0156\n", "model.layers.1.self_attn.o_proj.weight | [128, 128] | torch.float32 | 0.0625\n", "model.layers.1.self_attn.o_proj.lora_A | [128, 32] | torch.float32 | 0.0156\n", "model.layers.1.self_attn.o_proj.lora_B | [32, 128] | torch.float32 | 0.0156\n", "model.layers.1.mlp.gate_proj.weight | [290, 128] | torch.float32 | 0.1416\n", "model.layers.1.mlp.gate_proj.lora_A | [128, 32] | torch.float32 | 0.0156\n", "model.layers.1.mlp.gate_proj.lora_B | [32, 290] | torch.float32 | 0.0354\n", "model.layers.1.mlp.up_proj.weight | [290, 128] | torch.float32 | 0.1416\n", "model.layers.1.mlp.up_proj.lora_A | [128, 32] | torch.float32 | 0.0156\n", "model.layers.1.mlp.up_proj.lora_B | [32, 290] | torch.float32 | 0.0354\n", "model.layers.1.mlp.down_proj.weight | [128, 290] | torch.float32 | 0.1416\n", "model.layers.1.mlp.down_proj.lora_A | [290, 32] | torch.float32 | 0.0354\n", "model.layers.1.mlp.down_proj.lora_B | [32, 128] | torch.float32 | 0.0156\n", "model.layers.1.input_layernorm.weight | [128] | torch.float32 | 0.0005\n", "model.layers.1.post_attention_layernorm.weight | [128] | torch.float32 | 0.0005\n", "model.layers.1.hypernetxs.latent_proj.weight | [256, 128] | torch.float32 | 0.1250\n", "model.layers.1.hypernetxs.latent_proj.bias | [256] | torch.float32 | 0.0010\n", "model.layers.1.hypernetxs.mixture.weight | [256, 1088] | torch.float32 | 1.0625\n", "model.layers.1.hypernetxs.mixture.bias | [256] | torch.float32 | 0.0010\n", "model.layers.1.hypernetxs.c_fc.weight | [1024, 256] | torch.float32 | 1.0000\n", "model.layers.1.hypernetxs.c_fc.bias | [1024] | torch.float32 | 0.0039\n", "model.layers.1.hypernetxs.c_proj.weight | [1024, 1024] | torch.float32 | 4.0000\n", "model.layers.1.hypernetxs.c_proj.bias | [1024] | torch.float32 | 0.0039\n", "model.layers.1.hypernetxs.ln_latent.weight | [256] | torch.float32 | 0.0010\n", "model.layers.1.hypernetxs.ln_latent.bias | [256] | torch.float32 | 0.0010\n", "model.layers.1.hypernetxs.ln_1.weight | [256] | torch.float32 | 0.0010\n", "model.layers.1.hypernetxs.ln_1.bias | [256] | torch.float32 | 0.0010\n", "model.layers.1.hypernetxs.ln_2.weight | [1024] | torch.float32 | 0.0039\n", "model.layers.1.hypernetxs.ln_2.bias | [1024] | torch.float32 | 0.0039\n", "model.layers.1.hypernetxs.layer_embedding.weight | [3, 48] | torch.float32 | 0.0005\n", "model.layers.1.hypernetxs.module_embedding.weight | [7, 16] | torch.float32 | 0.0004\n", "model.layers.2.layer_idx_hyperxs | [] | torch.int64 | 0.0000\n", "model.layers.2.self_attn.q_proj.weight | [128, 128] | torch.float32 | 0.0625\n", "model.layers.2.self_attn.q_proj.lora_A | [128, 32] | torch.float32 | 0.0156\n", "model.layers.2.self_attn.q_proj.lora_B | [32, 128] | torch.float32 | 0.0156\n", "model.layers.2.self_attn.k_proj.weight | [128, 128] | torch.float32 | 0.0625\n", "model.layers.2.self_attn.k_proj.lora_A | [128, 32] | torch.float32 | 0.0156\n", "model.layers.2.self_attn.k_proj.lora_B | [32, 128] | torch.float32 | 0.0156\n", "model.layers.2.self_attn.v_proj.weight | [128, 128] | torch.float32 | 0.0625\n", "model.layers.2.self_attn.v_proj.lora_A | [128, 32] | torch.float32 | 0.0156\n", "model.layers.2.self_attn.v_proj.lora_B | [32, 128] | torch.float32 | 0.0156\n", "model.layers.2.self_attn.o_proj.weight | [128, 128] | torch.float32 | 0.0625\n", "model.layers.2.self_attn.o_proj.lora_A | [128, 32] | torch.float32 | 0.0156\n", "model.layers.2.self_attn.o_proj.lora_B | [32, 128] | torch.float32 | 0.0156\n", "model.layers.2.mlp.gate_proj.weight | [290, 128] | torch.float32 | 0.1416\n", "model.layers.2.mlp.gate_proj.lora_A | [128, 32] | torch.float32 | 0.0156\n", "model.layers.2.mlp.gate_proj.lora_B | [32, 290] | torch.float32 | 0.0354\n", "model.layers.2.mlp.up_proj.weight | [290, 128] | torch.float32 | 0.1416\n", "model.layers.2.mlp.up_proj.lora_A | [128, 32] | torch.float32 | 0.0156\n", "model.layers.2.mlp.up_proj.lora_B | [32, 290] | torch.float32 | 0.0354\n", "model.layers.2.mlp.down_proj.weight | [128, 290] | torch.float32 | 0.1416\n", "model.layers.2.mlp.down_proj.lora_A | [290, 32] | torch.float32 | 0.0354\n", "model.layers.2.mlp.down_proj.lora_B | [32, 128] | torch.float32 | 0.0156\n", "model.layers.2.input_layernorm.weight | [128] | torch.float32 | 0.0005\n", "model.layers.2.post_attention_layernorm.weight | [128] | torch.float32 | 0.0005\n", "model.layers.2.hypernetxs.latent_proj.weight | [256, 128] | torch.float32 | 0.1250\n", "model.layers.2.hypernetxs.latent_proj.bias | [256] | torch.float32 | 0.0010\n", "model.layers.2.hypernetxs.mixture.weight | [256, 1088] | torch.float32 | 1.0625\n", "model.layers.2.hypernetxs.mixture.bias | [256] | torch.float32 | 0.0010\n", "model.layers.2.hypernetxs.c_fc.weight | [1024, 256] | torch.float32 | 1.0000\n", "model.layers.2.hypernetxs.c_fc.bias | [1024] | torch.float32 | 0.0039\n", "model.layers.2.hypernetxs.c_proj.weight | [1024, 1024] | torch.float32 | 4.0000\n", "model.layers.2.hypernetxs.c_proj.bias | [1024] | torch.float32 | 0.0039\n", "model.layers.2.hypernetxs.ln_latent.weight | [256] | torch.float32 | 0.0010\n", "model.layers.2.hypernetxs.ln_latent.bias | [256] | torch.float32 | 0.0010\n", "model.layers.2.hypernetxs.ln_1.weight | [256] | torch.float32 | 0.0010\n", "model.layers.2.hypernetxs.ln_1.bias | [256] | torch.float32 | 0.0010\n", "model.layers.2.hypernetxs.ln_2.weight | [1024] | torch.float32 | 0.0039\n", "model.layers.2.hypernetxs.ln_2.bias | [1024] | torch.float32 | 0.0039\n", "model.layers.2.hypernetxs.layer_embedding.weight | [3, 48] | torch.float32 | 0.0005\n", "model.layers.2.hypernetxs.module_embedding.weight | [7, 16] | torch.float32 | 0.0004\n", "model.norm.weight | [128] | torch.float32 | 0.0005\n", "model.hypernetxs.latent_proj.weight | [256, 128] | torch.float32 | 0.1250\n", "model.hypernetxs.latent_proj.bias | [256] | torch.float32 | 0.0010\n", "model.hypernetxs.mixture.weight | [256, 1088] | torch.float32 | 1.0625\n", "model.hypernetxs.mixture.bias | [256] | torch.float32 | 0.0010\n", "model.hypernetxs.c_fc.weight | [1024, 256] | torch.float32 | 1.0000\n", "model.hypernetxs.c_fc.bias | [1024] | torch.float32 | 0.0039\n", "model.hypernetxs.c_proj.weight | [1024, 1024] | torch.float32 | 4.0000\n", "model.hypernetxs.c_proj.bias | [1024] | torch.float32 | 0.0039\n", "model.hypernetxs.ln_latent.weight | [256] | torch.float32 | 0.0010\n", "model.hypernetxs.ln_latent.bias | [256] | torch.float32 | 0.0010\n", "model.hypernetxs.ln_1.weight | [256] | torch.float32 | 0.0010\n", "model.hypernetxs.ln_1.bias | [256] | torch.float32 | 0.0010\n", "model.hypernetxs.ln_2.weight | [1024] | torch.float32 | 0.0039\n", "model.hypernetxs.ln_2.bias | [1024] | torch.float32 | 0.0039\n", "model.hypernetxs.layer_embedding.weight | [3, 48] | torch.float32 | 0.0005\n", "model.hypernetxs.module_embedding.weight | [7, 16] | torch.float32 | 0.0004\n", "lm_head.weight | [32000, 128] | torch.float32 | 15.6250\n", "============================================================\n", "\n", ">>> SUMMARY STATISTICS:\n", "Total Keys found: 140\n", "Total Parameters: 15,454,403\n", "Total Size (calculated): 58.95 MB\n", "\n", ">>> GROUP ANALYSIS (Where are the weights?):\n", " - Prefix 'model': 139 items found.\n", " - Prefix 'lm_head': 1 items found.\n", "\n", "[!!!] CRITICAL INSIGHT: Layers exist but are extremely small.\n", "Check if you saved 'Float8' or empty tensors, or if Rank is effectively 0.\n" ] } ], "source": [ "import torch\n", "import os\n", "import sys\n", "\n", "def inspect_checkpoint(file_path):\n", " \"\"\"\n", " Loads a pytorch_model.bin file and analyzes its content:\n", " keys, shapes, dtypes, and memory footprint.\n", " \"\"\"\n", " \n", " if not os.path.exists(file_path):\n", " print(f\"Error: File not found at {file_path}\")\n", " return\n", "\n", " print(f\">>> Loading checkpoint: {file_path}\")\n", " print(\">>> Please wait, mapping to CPU...\")\n", " \n", " try:\n", " # Load state_dict to CPU to avoid OOM\n", " state_dict = torch.load(file_path, map_location=\"cpu\", weights_only=True)\n", " except Exception as e:\n", " print(f\"Error loading file: {e}\")\n", " return\n", "\n", " print(\"\\n\" + \"=\"*60)\n", " print(f\"{'KEY NAME':<50} | {'SHAPE':<20} | {'DTYPE':<10} | {'SIZE (MB)':<10}\")\n", " print(\"=\"*60)\n", "\n", " total_size_bytes = 0\n", " total_params = 0\n", " grouped_keys = {}\n", "\n", " for key, tensor in state_dict.items():\n", " # Calculate size in MB\n", " numel = tensor.numel()\n", " element_size = tensor.element_size()\n", " size_mb = (numel * element_size) / (1024 * 1024)\n", " \n", " total_size_bytes += numel * element_size\n", " total_params += numel\n", "\n", " # Print details for every key (or uncomment logic below to summarize)\n", " # To avoid flooding console, we categorize by prefix\n", " prefix = key.split('.')[0]\n", " if prefix not in grouped_keys:\n", " grouped_keys[prefix] = []\n", " grouped_keys[prefix].append(key)\n", "\n", " # Print only if it's a \"suspiciously\" large or small tensor, or just print all\n", " # For debugging your 40MB issue, let's print everything if < 100 keys, \n", " # otherwise just print the first few of each group.\n", " print(f\"{key:<50} | {str(list(tensor.shape)):<20} | {str(tensor.dtype):<10} | {size_mb:.4f}\")\n", "\n", " print(\"=\"*60)\n", " print(\"\\n>>> SUMMARY STATISTICS:\")\n", " print(f\"Total Keys found: {len(state_dict)}\")\n", " print(f\"Total Parameters: {total_params:,}\")\n", " print(f\"Total Size (calculated): {total_size_bytes / (1024*1024):.2f} MB\")\n", " \n", " print(\"\\n>>> GROUP ANALYSIS (Where are the weights?):\")\n", " for prefix, keys in grouped_keys.items():\n", " print(f\" - Prefix '{prefix}': {len(keys)} items found.\")\n", " # Check if 'model' prefix exists (standard for Llama)\n", " \n", " # Heuristics based on your 40MB issue\n", " has_layers = any(\"layers\" in k for k in state_dict.keys())\n", " has_backbone = any(\"model.layers\" in k for k in state_dict.keys())\n", " \n", " if not has_backbone:\n", " print(\"\\n[!!!] CRITICAL INSIGHT: The 'model.layers' keys are MISSING.\")\n", " print(\"This means the main backbone weights were NOT saved.\")\n", " print(\"Only the HyperNet or Head weights seem to be present.\")\n", " elif total_size_bytes / (1024*1024) < 100:\n", " print(\"\\n[!!!] CRITICAL INSIGHT: Layers exist but are extremely small.\")\n", " print(\"Check if you saved 'Float8' or empty tensors, or if Rank is effectively 0.\")\n", "\n", "if __name__ == \"__main__\":\n", " # Replace with the actual path to your bin file\n", " # Example: \"xs_model_output/pytorch_model.bin\"\n", " chk_path = \"../SVD_llama2/pytorch_model.bin\" \n", " \n", " \n", " inspect_checkpoint(chk_path)" ] }, { "cell_type": "code", "execution_count": null, "id": "1a69ffd1", "metadata": {}, "outputs": [], "source": [ " # print('model', self.model)\n", " # for n, p in self.model.named_parameters():\n", " # print('n,p', n, p.shape)\n", " # exit()" ] } ], "metadata": { "kernelspec": { "display_name": "allm", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.3" } }, "nbformat": 4, "nbformat_minor": 5 }