| import torch |
| import torch.nn as nn |
| from typing import Optional, Set |
|
|
| from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge |
| import einops |
|
|
|
|
| from layer import SharedMonarch, SingleSharedBlockDiag |
|
|
|
|
def test_single():
    """Check the first output block of SingleSharedBlockDiag by hand.

    Builds a layer with 10 unique weights shared by a factor of 2, feeds it
    a random ``(2, 3, 80)`` input, and compares the first ``block_size_r``
    outputs of the first sample against ``x_block @ weights[0].T`` computed
    directly from the first ``block_size_c`` inputs.

    Raises:
        AssertionError: if the manual and model results differ by more than
            ``atol=1e-6``.
    """
    num_blocks = 20
    rows_per_block, cols_per_block = 5, 4
    share = 2

    layer = SingleSharedBlockDiag(
        num_unique_weights=num_blocks // share,
        share_factor=share,
        block_size_c=cols_per_block,
        block_size_r=rows_per_block,
    )
    inputs = torch.rand(2, 3, num_blocks * cols_per_block)
    outputs = layer(inputs)

    print(f"Weights Shape: {layer.weights.shape}, input {inputs.shape}")
    print(f"Output Shape: {outputs.shape}")

    # Reference result for the very first block of the very first sample.
    first_weight = layer.weights[0]
    first_input_block = inputs[0, 0, :cols_per_block]
    expected = first_input_block @ first_weight.T
    actual = outputs[0, 0, :rows_per_block]

    print("\n--- Correctness Check ---")
    print(f"Manual: {expected.detach().numpy()}")
    print(f"Model : {actual.detach().numpy()}")

    assert torch.allclose(expected, actual, atol=1e-6)
    print(">> Logic Verified: Dense Matrix Multiplication per block is correct.")
|
|
def test_monarch():
    """Run one SharedMonarch forward pass and print input/output shapes.

    The input is a single row holding ``0 .. cL*cR - 1`` as float32.
    Nothing is asserted; this is a shape smoke test only.
    """
    r_left, c_left = 4, 6
    r_right, c_right = 4, 2
    layer = SharedMonarch(2, r_right, c_right, r_left, c_left)

    ramp = torch.arange(c_left * c_right, dtype=torch.float32).unsqueeze(0)
    result = layer(ramp)

    print(f"Input shape: {ramp.shape}")
    print(f"Output shape: {result.shape}")
|
|
|
|
def _scaled_identity_(block_diag, scale):
    """Overwrite each unique weight of ``block_diag`` with ``scale * eye``.

    Works in place; the caller is expected to hold ``torch.no_grad()`` so
    the in-place scaling is legal on parameters that require grad.
    """
    for idx in range(block_diag.num_unique_weights):
        nn.init.eye_(block_diag.weights[idx])
        block_diag.weights[idx] *= scale


def manual_override_weights(model, val_r=1.0, val_l=1.0, method='identity'):
    """Force the weights of a Monarch-style model to known values for testing.

    Args:
        model: object exposing ``sama_R`` and ``sama_L``, each with
            ``weights`` and ``num_unique_weights``; ``sama_R`` must also
            expose ``block_size_r`` and ``block_size_c`` for ``'seq'``.
        val_r: scale applied to the identity written into ``sama_R``
            (``'identity'`` only).
        val_l: scale applied to the identity written into ``sama_L``
            (``'identity'`` only).
        method: ``'identity'`` writes ``val_r * I`` into every ``sama_R``
            weight and ``val_l * I`` into every ``sama_L`` weight.
            ``'seq'`` zeroes ``sama_R.weights`` and fills it with
            1, 2, 3, ... in (weight, row, column) order; ``sama_L`` is left
            untouched and ``val_r`` / ``val_l`` are ignored.

    Raises:
        ValueError: if ``method`` is neither ``'identity'`` nor ``'seq'``.
            Previously such a value was silently ignored, leaving whatever
            weights the model already had.
    """
    if method not in ('identity', 'seq'):
        raise ValueError(
            f"Unknown method {method!r}; expected 'identity' or 'seq'."
        )

    with torch.no_grad():
        if method == 'identity':
            _scaled_identity_(model.sama_R, val_r)
            _scaled_identity_(model.sama_L, val_l)
            return

        # method == 'seq': running counter over (weight, row, column).
        right = model.sama_R
        nn.init.constant_(right.weights, 0)
        n_w = right.num_unique_weights
        n_r = right.block_size_r
        n_c = right.block_size_c
        ramp = torch.arange(
            1,
            n_w * n_r * n_c + 1,
            dtype=right.weights.dtype,
            device=right.weights.device,
        ).reshape(n_w, n_r, n_c)
        # Slice assignment touches exactly the indices the per-element loop
        # would have visited; row-major reshape gives the same ordering.
        right.weights[:n_w, :n_r, :n_c] = ramp
|
|
def test_monarch_correctness():
    """Run three printed sanity checks on a 2x2 / 2x2 SharedMonarch layer.

    1. Identity: with both factors set to identity, the output must equal a
       random input; on failure the function returns early.
    2. Scaling: with R = 2*I and L = 3*I, element 1 of a one-hot input is
       expected to come out as 6.0.
    3. Mixing: on a fresh model with R = I, ``sama_L.weights[0]`` set to
       all ones and ``sama_L.weights[1]`` set to identity, the input
       ``[1, 0, 1, 0]`` is expected to map to ``[2, 0, 2, 0]``.

    Results are reported with ``[PASS]`` / ``[FAIL]`` prints; nothing is
    asserted.
    """
    print("\n=== STARTING MONARCH LAYER DEBUG/TEST ===")

    rL, cL = 2, 2
    rR, cR = 2, 2
    share_factor = 1
    layer = SharedMonarch(share_factor, rR, cR, rL, cL)

    # ---- Test 1 -----------------------------------------------------------
    print("\n--- Test 1: Identity/Reconstruction ---")
    print("Goal: If R=Identity and L=Identity, output must equal input.")
    print("Logic: x -> I -> P -> I -> P_T -> x")

    manual_override_weights(layer, method='identity')

    probe = torch.randn(1, cL * cR)
    echoed = layer(probe)

    print(f"Input: {probe.detach().numpy().round(2)}")
    print(f"Output: {echoed.detach().numpy().round(2)}")

    if not torch.allclose(probe, echoed, atol=1e-6):
        print(">> [FAIL] Identity Test: Output does not match Input!")
        return
    print(">> [PASS] Identity Test: P and P_T are correctly inverse.")

    # ---- Test 2 -----------------------------------------------------------
    print("\n--- Test 2: Manual Calculation Trace ---")
    print("Goal: Trace a specific value through the layers.")

    manual_override_weights(layer, val_r=2.0, val_l=3.0, method='identity')

    one_hot = torch.tensor([[0.0, 1.0, 0.0, 0.0]])
    scaled = layer(one_hot)

    print(f"Input One-Hot: {one_hot.numpy()}")
    print(f"Output: {scaled.detach().numpy()}")

    target = 1.0 * 2.0 * 3.0
    scaling_ok = torch.allclose(scaled[0, 1], torch.tensor(target))
    if scaling_ok:
        print(f">> [PASS] Value Scaling: Input 1.0 became {scaled[0,1].item()} (Expected 6.0)")
    else:
        print(f">> [FAIL] Value Scaling: Expected 6.0 but got {scaled[0,1].item()}")

    # ---- Test 3 -----------------------------------------------------------
    print("\n--- Test 3: Permutation Logic (Complex) ---")
    print("Goal: Ensure R acts on local blocks (neighbors) and L acts on distant blocks (strided).")

    # Start from a fresh model so earlier overrides cannot leak in.
    layer = SharedMonarch(share_factor, rR, cR, rL, cL)
    manual_override_weights(layer, val_r=1.0, val_l=1.0, method='identity')

    with torch.no_grad():
        # First L block mixes its two inputs; second L block stays identity.
        layer.sama_L.weights[0] = torch.ones(2, 2)
        nn.init.eye_(layer.sama_L.weights[1])

    mix_input = torch.tensor([[1.0, 0.0, 1.0, 0.0]])
    mixed = layer(mix_input)

    print(f"Input: {mix_input.numpy()}")
    print(f"Expected: [2., 0., 2., 0.]")
    print(f"Actual: {mixed.detach().numpy()}")

    wanted = torch.tensor([[2., 0., 2., 0.]])
    if torch.allclose(mixed, wanted):
        print(">> [PASS] Permutation Mixing Logic is correct!")
    else:
        print(">> [FAIL] Permutation Mixing Logic is flawed.")
|
|
| import matplotlib.pyplot as plt |
| import torch |
| import numpy as np |
|
|
def visualize_monarch_structure(output_path="monarch_viz1.png"):
    """Plot the effective matrices of the R factor and the full Monarch layer.

    A 4x4 / 4x4 ``SharedMonarch`` with ``share_factor=1`` is created and
    both factors are re-initialised uniformly in ``[0.5, 1.0]`` so no
    structurally non-zero entry is accidentally zero. An identity matrix is
    then pushed through ``model.sama_R`` and through the full model; the
    transposed outputs are drawn side by side as absolute-value heat maps.

    Args:
        output_path: file the figure is written to. Defaults to the
            previously hard-coded ``"monarch_viz1.png"``.

    Side effects:
        Writes the image file and prints its location. The printed name is
        now always the path actually written; it used to say
        ``monarch_viz.png`` while the file was ``monarch_viz1.png``.
    """
    rL, cL = 4, 4
    rR, cR = 4, 4
    share_factor = 1
    print(f"Visualizing Monarch Structure for N={cL*cR} (Grid {cL}x{cR})...")
    model = SharedMonarch(share_factor, rR, cR, rL, cL)

    with torch.no_grad():
        nn.init.uniform_(model.sama_R.weights, 0.5, 1.0)
        nn.init.uniform_(model.sama_L.weights, 0.5, 1.0)

        # Row i of the output is the layer's response to basis vector e_i,
        # so the transpose has one column per input index.
        identity_matrix = torch.eye(cL * cR)
        w_r = model.sama_R(identity_matrix).T
        w_total = model(identity_matrix).T

    _, axes = plt.subplots(1, 2, figsize=(12, 5))

    axes[0].imshow(w_r.abs().numpy(), cmap='Blues', interpolation='nearest')
    axes[0].set_title(f"Layer R (Block Diagonal)\nBlocks of size {cR}x{cR}")
    axes[0].set_xlabel("Input Index")
    axes[0].set_ylabel("Intermediate Index")

    axes[1].imshow(w_total.abs().numpy(), cmap='Reds', interpolation='nearest')
    axes[1].set_title("Full Monarch (Butterfly)\nGlobal Mixing")
    axes[1].set_xlabel("Input Index")
    axes[1].set_ylabel("Output Index")

    plt.tight_layout()
    plt.savefig(output_path)
    print(f"Graph saved to '{output_path}'. Please check your file explorer.")
    plt.close()
|
|
|
|
|
|
def test_merge():
    """Compare ``get_delta_weight()`` with the matrix implied by ``forward``.

    The forward matrix is obtained by pushing an identity matrix through
    the model and transposing the result. Both shapes and the maximum
    absolute difference are printed, followed by a verdict.

    A mismatch is now reported with an explicit ``[FAIL]`` line; it used to
    print nothing, so a failing run looked the same as a passing one minus
    one line.
    """
    block_size_cR, block_size_rR = 3, 4
    block_size_cL, block_size_rL = 4, 5
    model = SharedMonarch(
        share_factor=1,
        block_size_cR=block_size_cR,
        block_size_rR=block_size_rR,
        block_size_cL=block_size_cL,
        block_size_rL=block_size_rL,
    )

    # Input width follows the cL * cR convention used by the other
    # SharedMonarch checks in this file (3 * 4 = 12, the former constant).
    in_features = block_size_cL * block_size_cR

    with torch.no_grad():
        W_manual = model.get_delta_weight()
        W_forward = model(torch.eye(in_features)).T

    print(f"Manual Shape: {W_manual.shape}")
    print(f"Forward Shape: {W_forward.shape}")

    if W_manual.shape != W_forward.shape:
        # Subtracting would raise or silently broadcast; report it instead.
        print(">> [FAIL] Shape mismatch between merged weight and forward pass.")
        return

    diff = (W_manual - W_forward).abs().max()
    print(f"Maximum Difference: {diff.item()}")

    if diff < 1e-6:
        print(">> Accurate!")
    else:
        print(">> [FAIL] Merged weight does not match forward pass (tolerance 1e-6).")
|
|
if __name__ == "__main__":
    # Script entry point: only the merge-equivalence check is executed.
    test_merge()