"""Manual checks and visualizations for the shared block-diagonal / Monarch layers.

Each ``test_*`` function builds a small layer from ``layer``, runs a forward
pass, and prints (and in ``test_single``'s case asserts) whether the result
matches a hand-computed expectation.
"""
from typing import Optional, Set

import einops
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge

from layer import SharedMonarch, SingleSharedBlockDiag


def test_single():
    """Check SingleSharedBlockDiag against a manual matmul on the first block.

    The input's last dim is ``B * block_size_c`` and the output is expected to
    be ``B * block_size_r``. Only block 0 of the first sample is compared, using
    ``model.weights[0]``.
    """
    # input B * c, output B * r
    B = 20
    block_size_r = 5
    block_size_c = 4
    r = 2
    uni_w = B // r
    model = SingleSharedBlockDiag(
        num_unique_weights=uni_w,
        share_factor=r,
        block_size_c=block_size_c,
        block_size_r=block_size_r,
    )
    x = torch.rand(2, 3, B * block_size_c)
    y = model(x)
    print(f"Weights Shape: {model.weights.shape}, input {x.shape}")
    print(f"Output Shape: {y.shape}")

    # NOTE(review): assumes input block 0 is served by unique weight 0 -- confirm
    # against the sharing scheme in SingleSharedBlockDiag.
    W_group_0 = model.weights[0]
    x_blk_0 = x[0, 0, 0:block_size_c]
    y_manual = x_blk_0 @ W_group_0.T
    y_model = y[0, 0, 0:block_size_r]

    print("\n--- Correctness Check ---")
    print(f"Manual: {y_manual.detach().numpy()}")
    print(f"Model : {y_model.detach().numpy()}")

    # Assert close
    assert torch.allclose(y_manual, y_model, atol=1e-6)
    print(">> Logic Verified: Dense Matrix Multiplication per block is correct.")


def test_monarch():
    """Smoke test: run SharedMonarch on a (1, cL * cR) input and print shapes."""
    rL, cL = 4, 6
    rR, cR = 4, 2
    share_factor = 2
    model = SharedMonarch(share_factor, rR, cR, rL, cL)
    x = torch.arange(cL * cR, dtype=torch.float32).unsqueeze(0)
    y = model(x)
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {y.shape}")


def manual_override_weights(model, val_r=1.0, val_l=1.0, method='identity'):
    """Force a SharedMonarch model's weights to known values for testing.

    Args:
        model: Object exposing ``sama_R`` and ``sama_L``. Each has a ``weights``
            tensor indexed as ``(num_unique, r, c)`` plus
            ``num_unique_weights``. For ``'seq'``, ``sama_R`` must also expose
            ``block_size_r`` and ``block_size_c``.
        val_r: Scale applied to every identity block of ``sama_R``
            (``'identity'`` only).
        val_l: Scale applied to every identity block of ``sama_L``
            (``'identity'`` only).
        method: ``'identity'`` sets every block of both layers to a scaled
            identity. ``'seq'`` fills ``sama_R`` with 1, 2, 3, ... in
            (block, row, col) order and leaves ``sama_L`` untouched.

    Raises:
        ValueError: If ``method`` is neither ``'identity'`` nor ``'seq'``.
    """
    # Fail loudly on a typo instead of silently leaving weights random.
    if method not in ('identity', 'seq'):
        raise ValueError(f"Unknown method {method!r}; expected 'identity' or 'seq'")

    with torch.no_grad():
        # --- Override Layer R ---
        if method == 'identity':
            # Weight shape: (num_unique, r, c). We want weights s.t. x @ W.T = x,
            # so each W block must be the identity.
            for i in range(model.sama_R.num_unique_weights):
                nn.init.eye_(model.sama_R.weights[i])
                model.sama_R.weights[i] *= val_r  # Optional scaling
        elif method == 'seq':
            # Fill with sequential integers so each entry's position is traceable.
            nn.init.constant_(model.sama_R.weights, 0)
            count = 1
            for i in range(model.sama_R.num_unique_weights):
                for r in range(model.sama_R.block_size_r):
                    for c in range(model.sama_R.block_size_c):
                        model.sama_R.weights[i, r, c] = count
                        count += 1

        # --- Override Layer L ---
        if method == 'identity':
            for i in range(model.sama_L.num_unique_weights):
                nn.init.eye_(model.sama_L.weights[i])
                model.sama_L.weights[i] *= val_l


def test_monarch_correctness():
    """Step-by-step debug checks of SharedMonarch on a tiny 2x2 grid (N = 4).

    Runs three printed checks:
    1. Identity weights must reproduce the input (P and P_T are inverses).
    2. R = 2*I and L = 3*I must scale a one-hot input by 6.
    3. An all-ones L block must mix the strided pair of positions.
    Returns early if the identity check fails.
    """
    print("\n=== STARTING MONARCH LAYER DEBUG/TEST ===")

    # 1. Setup minimal model for easy calculation.
    # N = 4 (2x2 grid) -- small enough to print everything.
    rL, cL = 2, 2
    rR, cR = 2, 2
    share_factor = 1
    model = SharedMonarch(share_factor, rR, cR, rL, cL)

    # ==========================================
    # TEST CASE 1: The "Passthrough" (Identity) Test
    # ==========================================
    print("\n--- Test 1: Identity/Reconstruction ---")
    print("Goal: If R=Identity and L=Identity, output must equal input.")
    print("Logic: x -> I -> P -> I -> P_T -> x")

    # Force weights to Identity
    manual_override_weights(model, method='identity')

    # Create random input
    x = torch.randn(1, cL * cR)
    y = model(x)

    print(f"Input: {x.detach().numpy().round(2)}")
    print(f"Output: {y.detach().numpy().round(2)}")

    if torch.allclose(x, y, atol=1e-6):
        print(">> [PASS] Identity Test: P and P_T are correctly inverse.")
    else:
        print(">> [FAIL] Identity Test: Output does not match Input!")
        return  # Stop if basic logic fails

    # ==========================================
    # TEST CASE 2: The "Specific Path" (One-Hot) Test
    # ==========================================
    print("\n--- Test 2: Manual Calculation Trace ---")
    print("Goal: Trace a specific value through the layers.")

    # Setup:
    #   R doubles the values (Scale = 2 * Identity)
    #   L triples the values (Scale = 3 * Identity)
    # Expected: Output = Input * 2 * 3 = Input * 6
    manual_override_weights(model, val_r=2.0, val_l=3.0, method='identity')

    # Input: One-hot vector at index 1 (value 1.0) -> [0, 1, 0, 0]
    x_one_hot = torch.tensor([[0.0, 1.0, 0.0, 0.0]])
    y_test = model(x_one_hot)

    print(f"Input One-Hot: {x_one_hot.numpy()}")
    print(f"Output: {y_test.detach().numpy()}")

    expected_value = 1.0 * 2.0 * 3.0  # = 6.0
    if torch.allclose(y_test[0, 1], torch.tensor(expected_value)):
        print(f">> [PASS] Value Scaling: Input 1.0 became {y_test[0,1].item()} (Expected 6.0)")
    else:
        print(f">> [FAIL] Value Scaling: Expected 6.0 but got {y_test[0,1].item()}")

    # ==========================================
    # TEST CASE 3: The "Permutation Logic" Check
    # ==========================================
    print("\n--- Test 3: Permutation Logic (Complex) ---")
    print("Goal: Ensure R acts on local blocks (neighbors) and L acts on distant blocks (strided).")

    # Reset model
    model = SharedMonarch(share_factor, rR, cR, rL, cL)

    # Configuration:
    #   We want R to mix [0,1] and [2,3].
    #   We want L to mix [0,2] and [1,3] (after permute).
    # Start with R (and L) as Identity (pass through).
    manual_override_weights(model, val_r=1.0, val_l=1.0, method='identity')

    # Now CHANGE Layer L (the second one) manually.
    # L has 2 blocks of size 2. Block 0 of L acts on indices that WERE [0, 2]
    # in original space (due to the transpose).
    # Make Block 0 of L sum its inputs: matrix [[1, 1], [1, 1]].
    with torch.no_grad():
        # L weights shape: (num_unique, block_r, block_c) -> (2, 2, 2)
        # Set Block 0 to all ones
        model.sama_L.weights[0] = torch.ones(2, 2)
        # Set Block 1 to Identity (keep simple)
        nn.init.eye_(model.sama_L.weights[1])

    # Input: activate index 0 and 2 -> x = [1, 0, 1, 0]
    x_input = torch.tensor([[1.0, 0.0, 1.0, 0.0]])

    # Manual trace:
    # 1. x -> R (Identity) -> [1, 0, 1, 0]
    # 2. P (Permute):
    #    Original indices: 0, 1, 2, 3
    #    View (2,2):  [[0,1], [2,3]]
    #    Transpose:   [[0,2], [1,3]]
    #    New sequence: [val_at_0, val_at_2, val_at_1, val_at_3] = [1, 1, 0, 0]
    # 3. L (Block Diag):
    #    Block 0 input [1, 1], Block 1 input [0, 0].
    #    Block 0 is all ones: [1, 1] @ [[1,1],[1,1]].T = [2, 2]
    #    Block 1 is Identity: [0, 0] @ I = [0, 0]
    #    Result after L: [2, 2, 0, 0] (in permuted space!)
    # 4. P_T (Un-permute):
    #    Current view (2,2): [[2, 2], [0, 0]]
    #    Transpose back:     [[2, 0], [2, 0]]
    #    Flatten:            [2, 0, 2, 0]
    y_complex = model(x_input)

    print(f"Input: {x_input.numpy()}")
    print(f"Expected: [2., 0., 2., 0.]")
    print(f"Actual: {y_complex.detach().numpy()}")

    if torch.allclose(y_complex, torch.tensor([[2., 0., 2., 0.]])):
        print(">> [PASS] Permutation Mixing Logic is correct!")
    else:
        print(">> [FAIL] Permutation Mixing Logic is flawed.")


def visualize_monarch_structure(save_path="monarch_viz1.png"):
    """Plot the effective weight matrices of a SharedMonarch layer.

    Feeds the identity matrix through the R layer alone and through the full
    model, then transposes each output to (out, in) form. Saves a two-panel
    heatmap of their absolute values to ``save_path``.

    Args:
        save_path: Output image path (defaults to the original file name).
    """
    # 1. Initialize Model
    rL, cL = 4, 4
    rR, cR = 4, 4
    share_factor = 1

    print(f"Visualizing Monarch Structure for N={cL*cR} (Grid {cL}x{cR})...")

    model = SharedMonarch(share_factor, rR, cR, rL, cL)

    # 2. Force weights away from 0 so the non-zero pattern is clearly visible.
    with torch.no_grad():
        nn.init.uniform_(model.sama_R.weights, 0.5, 1.0)
        nn.init.uniform_(model.sama_L.weights, 0.5, 1.0)

    # 3. Identity input: row k of the output is the layer applied to basis e_k,
    #    i.e. the output is the effective weight transposed.
    input_dim = cL * cR
    identity_matrix = torch.eye(input_dim)

    # 4. Extract Matrices
    with torch.no_grad():
        # A. Matrix R (right layer only), transposed back to (Out, In).
        w_r = model.sama_R(identity_matrix).T
        # B. Total effective matrix (R -> P -> L -> P_T).
        w_total = model(identity_matrix).T

    # 5. Plotting
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))

    # Plot A: The Local Block Diagonal Matrix (Layer R)
    axes[0].imshow(w_r.abs().numpy(), cmap='Blues', interpolation='nearest')
    axes[0].set_title(f"Layer R (Block Diagonal)\nBlocks of size {rR}x{cR}")
    axes[0].set_xlabel("Input Index")
    axes[0].set_ylabel("Intermediate Index")

    # Plot B: The Full Monarch Matrix (Butterfly Structure)
    axes[1].imshow(w_total.abs().numpy(), cmap='Reds', interpolation='nearest')
    axes[1].set_title("Full Monarch (Butterfly)\nGlobal Mixing")
    axes[1].set_xlabel("Input Index")
    axes[1].set_ylabel("Output Index")

    plt.tight_layout()
    plt.savefig(save_path)
    # Report the path actually written (previously the message named a different file).
    print(f"Graph saved to '{save_path}'. Please check your file explorer.")
    plt.close(fig)


def test_merge():
    """Check that ``get_delta_weight()`` matches the matrix recovered via forward.

    Feeding the identity matrix through the layer yields ``W^T``, so its
    transpose should equal the merged weight returned by ``get_delta_weight``.
    """
    rR, cR = 4, 3
    rL, cL = 5, 4
    model = SharedMonarch(
        share_factor=1,
        block_size_cR=cR,
        block_size_rR=rR,
        block_size_cL=cL,
        block_size_rL=rL,
    )
    # Input width is cL * cR (= 12 here), as in the other tests in this file.
    in_features = cL * cR

    # Method 1 (Manual)
    with torch.no_grad():
        W_manual = model.get_delta_weight()

    # Method 2: Identity Trick (Forward)
    # Note: forward(I) returns W^T.
    eye = torch.eye(in_features)
    with torch.no_grad():
        W_forward = model(eye).T

    print(f"Manual Shape: {W_manual.shape}")  # (20, 12)
    print(f"Forward Shape: {W_forward.shape}")  # (20, 12)

    # Guard: subtracting mismatched shapes would raise or silently broadcast.
    if W_manual.shape != W_forward.shape:
        print(">> [FAIL] Shape mismatch between manual and forward weights.")
        return

    diff = (W_manual - W_forward).abs().max()
    print(f"Maximum Difference: {diff.item()}")
    if diff < 1e-6:
        print(">> Accurate!")
    else:
        print(">> [FAIL] Merged weight does not match the forward pass.")


if __name__ == "__main__":
    # test_monarch_correctness()
    # visualize_monarch_structure()
    test_merge()