# Provenance (Hugging Face page residue): uploaded by nvan13 --
# "Upload folder using huggingface_hub", commit ecadbd9 (verified).
import torch
import torch.nn as nn
from typing import Optional, Set
from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge
import einops
from layer import SharedMonarch, SingleSharedBlockDiag
def test_single():
    """Smoke-test SingleSharedBlockDiag and verify its first output block by hand.

    Builds a layer with ``num_blocks // share`` unique weight matrices, runs a
    random (2, 3, num_blocks * block_size_c) input through it, and compares the
    first ``block_size_r`` outputs of the first token with the explicit product
    ``x_block_0 @ weights[0].T``.
    """
    # Layout: input is num_blocks * block_size_c wide, output num_blocks * block_size_r.
    num_blocks = 20
    block_size_r, block_size_c = 5, 4
    share = 2
    model = SingleSharedBlockDiag(
        num_unique_weights=num_blocks // share,
        share_factor=share,
        block_size_c=block_size_c,
        block_size_r=block_size_r,
    )
    x = torch.rand(2, 3, num_blocks * block_size_c)
    y = model(x)
    print(f"Weights Shape: {model.weights.shape}, input {x.shape}")
    print(f"Output Shape: {y.shape}")

    # Reference result: first input block times the first unique weight matrix.
    expected = x[0, 0, :block_size_c] @ model.weights[0].T
    actual = y[0, 0, :block_size_r]
    print("\n--- Correctness Check ---")
    print(f"Manual: {expected.detach().numpy()}")
    print(f"Model : {actual.detach().numpy()}")
    assert torch.allclose(expected, actual, atol=1e-6)
    print(">> Logic Verified: Dense Matrix Multiplication per block is correct.")
def test_monarch():
    """Run SharedMonarch on a single ``arange`` input and print the shapes."""
    rL, cL = 4, 6
    rR, cR = 4, 2
    share_factor = 2
    model = SharedMonarch(share_factor, rR, cR, rL, cL)
    in_features = cL * cR
    x = torch.arange(in_features, dtype=torch.float32)[None, :]
    y = model(x)
    for label, tensor in (("Input", x), ("Output", y)):
        print(f"{label} shape: {tensor.shape}")
def manual_override_weights(model, val_r=1.0, val_l=1.0, method='identity'):
    """Force the Monarch factor weights to known values for testing.

    ``model`` must expose ``sama_R`` and ``sama_L`` sub-layers. Each needs a
    ``weights`` tensor indexed as ``weights[i, r, c]`` and the attributes
    ``num_unique_weights``, ``block_size_r`` and ``block_size_c``.

    Args:
        model: Model whose ``sama_R`` / ``sama_L`` weights are overwritten in place.
        val_r: Scale for every identity block of ``sama_R`` ('identity' only).
        val_l: Scale for every identity block of ``sama_L`` ('identity' only).
        method:
            - ``'identity'``: every unique block of R and L becomes a scaled
              identity, so ``x @ W.T == val * x`` per block.
            - ``'seq'``: zeroes R and fills its blocks with 1, 2, 3, ... in
              (unique, row, col) order so each entry's position can be traced.
              L is left unchanged.

    Raises:
        ValueError: If ``method`` is neither ``'identity'`` nor ``'seq'``.
            (An unknown method used to be a silent no-op.)
    """
    if method not in ('identity', 'seq'):
        raise ValueError(f"Unknown method {method!r}; expected 'identity' or 'seq'.")

    with torch.no_grad():
        if method == 'identity':
            _fill_scaled_identity(model.sama_R, val_r)
            _fill_scaled_identity(model.sama_L, val_l)
        else:  # 'seq'
            layer = model.sama_R
            nn.init.constant_(layer.weights, 0)
            n_u = layer.num_unique_weights
            n_r = layer.block_size_r
            n_c = layer.block_size_c
            # Same values as a nested (i, r, c) counter starting at 1, written
            # in one vectorized slice assignment.
            seq = torch.arange(1, n_u * n_r * n_c + 1,
                               dtype=layer.weights.dtype,
                               device=layer.weights.device)
            layer.weights[:n_u, :n_r, :n_c] = seq.view(n_u, n_r, n_c)


def _fill_scaled_identity(layer, scale):
    """Set each of ``layer``'s unique weight blocks to ``scale * I`` in place."""
    for i in range(layer.num_unique_weights):
        nn.init.eye_(layer.weights[i])
        layer.weights[i] *= scale
def test_monarch_correctness():
    """Step-by-step sanity checks for SharedMonarch on a 4-feature (2x2) grid.

    Runs three checks and prints a trace for each:

    1. Identity: with R = I and L = I the layer must reproduce its input, so
       the permutation P and its inverse P_T cancel.
    2. Scaling: with R = 2I and L = 3I, a one-hot input must come out as
       6 * input.
    3. Permutation: with R = I, L block 0 = all-ones and L block 1 = I, the
       input [1, 0, 1, 0] must map to [2, 0, 2, 0]. This shows that L mixes
       strided features rather than neighbouring ones.

    Raises:
        AssertionError: If any check fails. Previously failures were only
            printed, so this function could never fail under a test runner.
    """
    print("\n=== STARTING MONARCH LAYER DEBUG/TEST ===")
    # N = 4 (2x2 grid): small enough to print every value.
    rL, cL = 2, 2
    rR, cR = 2, 2
    share_factor = 1
    model = SharedMonarch(share_factor, rR, cR, rL, cL)
    _check_identity(model, cL * cR)
    _check_scaling(model)
    _check_permutation(share_factor, rR, cR, rL, cL)


def _check_identity(model, in_features):
    """Check 1: with R = I and L = I the layer must return its input unchanged."""
    print("\n--- Test 1: Identity/Reconstruction ---")
    print("Goal: If R=Identity and L=Identity, output must equal input.")
    print("Logic: x -> I -> P -> I -> P_T -> x")
    manual_override_weights(model, method='identity')
    x = torch.randn(1, in_features)
    y = model(x)
    print(f"Input: {x.detach().numpy().round(2)}")
    print(f"Output: {y.detach().numpy().round(2)}")
    if not torch.allclose(x, y, atol=1e-6):
        print(">> [FAIL] Identity Test: Output does not match Input!")
        # Stop here: the remaining checks depend on P / P_T being inverses.
        raise AssertionError("Identity Test: Output does not match Input")
    print(">> [PASS] Identity Test: P and P_T are correctly inverse.")


def _check_scaling(model):
    """Check 2: with R = 2I and L = 3I the output must be exactly 6 * input."""
    print("\n--- Test 2: Manual Calculation Trace ---")
    print("Goal: Trace a specific value through the layers.")
    manual_override_weights(model, val_r=2.0, val_l=3.0, method='identity')
    # One-hot input at index 1: [0, 1, 0, 0].
    x_one_hot = torch.tensor([[0.0, 1.0, 0.0, 0.0]])
    y_test = model(x_one_hot)
    print(f"Input One-Hot: {x_one_hot.numpy()}")
    print(f"Output: {y_test.detach().numpy()}")
    expected_value = 1.0 * 2.0 * 3.0  # = 6.0
    got = y_test[0, 1].item()
    # Compare the whole vector, not only index 1. A value leaking into another
    # position would otherwise go unnoticed.
    if torch.allclose(y_test, expected_value * x_one_hot):
        print(f">> [PASS] Value Scaling: Input 1.0 became {got} (Expected 6.0)")
    else:
        print(f">> [FAIL] Value Scaling: Expected {(expected_value * x_one_hot).numpy()} "
              f"but got {y_test.detach().numpy()}")
        raise AssertionError("Value Scaling: output is not 6 * input")


def _check_permutation(share_factor, rR, cR, rL, cL):
    """Check 3: L must mix strided features (indices 0 and 2), not neighbours."""
    print("\n--- Test 3: Permutation Logic (Complex) ---")
    print("Goal: Ensure R acts on local blocks (neighbors) and L acts on distant blocks (strided).")
    model = SharedMonarch(share_factor, rR, cR, rL, cL)  # fresh model
    # R = I (pass-through); L starts as I and is then modified below.
    manual_override_weights(model, val_r=1.0, val_l=1.0, method='identity')
    with torch.no_grad():
        # L block 0 sums its two inputs (all-ones matrix); block 1 stays identity.
        model.sama_L.weights[0] = torch.ones_like(model.sama_L.weights[0])
        nn.init.eye_(model.sama_L.weights[1])

    x_input = torch.tensor([[1.0, 0.0, 1.0, 0.0]])
    # Manual trace:
    # 1. R (identity):   [1, 0, 1, 0]
    # 2. P: view (2,2) [[0,1],[2,3]] -> transpose [[0,2],[1,3]] -> [1, 1, 0, 0]
    # 3. L: block 0 [1,1] @ ones.T = [2, 2]; block 1 [0,0] @ I = [0, 0]
    #    -> [2, 2, 0, 0] (still in permuted space)
    # 4. P_T: view (2,2) [[2,2],[0,0]] -> transpose [[2,0],[2,0]] -> [2, 0, 2, 0]
    expected = torch.tensor([[2., 0., 2., 0.]])
    y_complex = model(x_input)
    print(f"Input: {x_input.numpy()}")
    print("Expected: [2., 0., 2., 0.]")
    print(f"Actual: {y_complex.detach().numpy()}")
    if torch.allclose(y_complex, expected):
        print(">> [PASS] Permutation Mixing Logic is correct!")
    else:
        print(">> [FAIL] Permutation Mixing Logic is flawed.")
        raise AssertionError("Permutation Mixing Logic is flawed")
import matplotlib.pyplot as plt
import torch
import numpy as np
def visualize_monarch_structure(save_path="monarch_viz1.png"):
    """Plot the effective weight matrices of a SharedMonarch layer.

    An identity matrix is fed through the model. Row i of the result is the
    layer applied to basis vector e_i, i.e. the transposed weight, so it is
    transposed back to the (out, in) layout. Two heatmaps of absolute values
    are saved side by side:

    - the R factor alone (block diagonal);
    - the full R -> P -> L -> P_T map.

    Args:
        save_path: File the figure is written to. The default matches the file
            name the original version wrote. The printed message previously
            named a different file ('monarch_viz.png').
    """
    rL, cL = 4, 4
    rR, cR = 4, 4
    share_factor = 1
    print(f"Visualizing Monarch Structure for N={cL*cR} (Grid {cL}x{cR})...")
    model = SharedMonarch(share_factor, rR, cR, rL, cL)

    with torch.no_grad():
        # Keep weights away from 0 so the sparsity pattern is clearly visible.
        nn.init.uniform_(model.sama_R.weights, 0.5, 1.0)
        nn.init.uniform_(model.sama_L.weights, 0.5, 1.0)

        identity_matrix = torch.eye(cL * cR)
        w_r = model.sama_R(identity_matrix).T   # R only
        w_total = model(identity_matrix).T      # R -> P -> L -> P_T

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    try:
        # Left: local block-diagonal factor R (blocks are rR x cR in (out, in) layout).
        axes[0].imshow(w_r.abs().numpy(), cmap='Blues', interpolation='nearest')
        axes[0].set_title(f"Layer R (Block Diagonal)\nBlocks of size {rR}x{cR}")
        axes[0].set_xlabel("Input Index")
        axes[0].set_ylabel("Intermediate Index")

        # Right: full Monarch map.
        axes[1].imshow(w_total.abs().numpy(), cmap='Reds', interpolation='nearest')
        axes[1].set_title("Full Monarch (Butterfly)\nGlobal Mixing")
        axes[1].set_xlabel("Input Index")
        axes[1].set_ylabel("Output Index")

        fig.tight_layout()
        fig.savefig(save_path)
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)
    print(f"Graph saved to '{save_path}'. Please check your file explorer.")
def test_merge():
    """Check that ``SharedMonarch.get_delta_weight()`` agrees with ``forward``.

    The weight returned by ``get_delta_weight()`` is compared with the matrix
    recovered by feeding an identity matrix through the model. ``forward(I)``
    yields W^T, so the result is transposed back.

    Raises:
        AssertionError: If the shapes differ or the maximum absolute
            difference is not below 1e-6. Previously a mismatch was silently
            ignored.
    """
    block_size_cR, block_size_rR = 3, 4
    block_size_cL, block_size_rL = 4, 5
    model = SharedMonarch(share_factor=1,
                          block_size_cR=block_size_cR, block_size_rR=block_size_rR,
                          block_size_cL=block_size_cL, block_size_rL=block_size_rL)
    # The input width is cR * cL (12 here). Derive it instead of hard-coding,
    # so changing a block size keeps the identity trick consistent.
    in_features = block_size_cR * block_size_cL

    with torch.no_grad():
        # Method 1: weight computed directly by the model.
        W_manual = model.get_delta_weight()
        # Method 2: identity trick. forward(I) returns W^T.
        W_forward = model(torch.eye(in_features)).T

    print(f"Manual Shape: {W_manual.shape}")  # expected (20, 12)
    print(f"Forward Shape: {W_forward.shape}")  # expected (20, 12)
    assert W_manual.shape == W_forward.shape, (
        f"Shape mismatch: manual {tuple(W_manual.shape)} vs forward {tuple(W_forward.shape)}"
    )

    diff = (W_manual - W_forward).abs().max()
    print(f"Maximum Difference: {diff.item()}")
    assert diff < 1e-6, f"get_delta_weight() and forward disagree (max diff {diff.item()})"
    print(">> Accurate!")
if __name__ == "__main__":
    # Only the merge check runs by default. test_monarch_correctness() and
    # visualize_monarch_structure() are also available for manual debugging.
    test_merge()