Steelskull committed on
Commit
3aed6d3
·
verified ·
1 Parent(s): 7eb72e8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +169 -85
app.py CHANGED
@@ -1,129 +1,213 @@
1
- from transformers import AutoModelForCausalLM, AutoTokenizer
2
  import torch
 
 
3
  import matplotlib.pyplot as plt
4
  import seaborn as sns
5
- from tqdm import tqdm
6
  import gradio as gr
7
- import io
8
  import PIL.Image
 
 
 
 
 
9
 
10
  def calculate_weight_diff(base_weight, chat_weight):
 
11
  return torch.abs(base_weight - chat_weight).mean().item()
12
 
13
- def calculate_layer_diffs(base_model, chat_model, load_one_at_a_time=False):
 
14
  layer_diffs = []
 
 
15
  layers = zip(base_model.model.layers, chat_model.model.layers)
 
16
 
17
- if load_one_at_a_time:
18
- for base_layer, chat_layer in tqdm(layers, total=len(base_model.model.layers)):
19
- layer_diff = {
20
- 'input_layernorm': calculate_weight_diff(base_layer.input_layernorm.weight, chat_layer.input_layernorm.weight),
21
- 'mlp_down_proj': calculate_weight_diff(base_layer.mlp.down_proj.weight, chat_layer.mlp.down_proj.weight),
22
- 'mlp_gate_proj': calculate_weight_diff(base_layer.mlp.gate_proj.weight, chat_layer.mlp.gate_proj.weight),
23
- 'mlp_up_proj': calculate_weight_diff(base_layer.mlp.up_proj.weight, chat_layer.mlp.up_proj.weight),
24
- 'post_attention_layernorm': calculate_weight_diff(base_layer.post_attention_layernorm.weight, chat_layer.post_attention_layernorm.weight),
25
- 'self_attn_q_proj': calculate_weight_diff(base_layer.self_attn.q_proj.weight, chat_layer.self_attn.q_proj.weight),
26
- 'self_attn_k_proj': calculate_weight_diff(base_layer.self_attn.k_proj.weight, chat_layer.self_attn.k_proj.weight),
27
- 'self_attn_v_proj': calculate_weight_diff(base_layer.self_attn.v_proj.weight, chat_layer.self_attn.v_proj.weight),
28
- 'self_attn_o_proj': calculate_weight_diff(base_layer.self_attn.o_proj.weight, chat_layer.self_attn.o_proj.weight)
29
- }
30
- layer_diffs.append(layer_diff)
31
-
32
- base_layer, chat_layer = None, None
33
- del base_layer, chat_layer
34
- else:
35
- for base_layer, chat_layer in tqdm(layers, total=len(base_model.model.layers)):
36
- layer_diff = {
37
- 'input_layernorm': calculate_weight_diff(base_layer.input_layernorm.weight, chat_layer.input_layernorm.weight),
38
- 'mlp_down_proj': calculate_weight_diff(base_layer.mlp.down_proj.weight, chat_layer.mlp.down_proj.weight),
39
- 'mlp_gate_proj': calculate_weight_diff(base_layer.mlp.gate_proj.weight, chat_layer.mlp.gate_proj.weight),
40
- 'mlp_up_proj': calculate_weight_diff(base_layer.mlp.up_proj.weight, chat_layer.mlp.up_proj.weight),
41
- 'post_attention_layernorm': calculate_weight_diff(base_layer.post_attention_layernorm.weight, chat_layer.post_attention_layernorm.weight),
42
- 'self_attn_q_proj': calculate_weight_diff(base_layer.self_attn.q_proj.weight, chat_layer.self_attn.q_proj.weight),
43
- 'self_attn_k_proj': calculate_weight_diff(base_layer.self_attn.k_proj.weight, chat_layer.self_attn.k_proj.weight),
44
- 'self_attn_v_proj': calculate_weight_diff(base_layer.self_attn.v_proj.weight, chat_layer.self_attn.v_proj.weight),
45
- 'self_attn_o_proj': calculate_weight_diff(base_layer.self_attn.o_proj.weight, chat_layer.self_attn.o_proj.weight)
46
- }
47
- layer_diffs.append(layer_diff)
48
 
49
  return layer_diffs
50
 
51
- def visualize_layer_diffs(layer_diffs, base_model_name, chat_model_name):
 
 
 
 
52
  num_layers = len(layer_diffs)
53
- num_components = len(layer_diffs[0])
 
54
 
55
- # Dynamically adjust figure size based on number of layers
56
- height = max(8, num_layers / 8) # Minimum height of 8, scales up for more layers
57
- width = max(24, num_components * 3) # Minimum width of 24, scales with components
58
 
59
- # Create figure with subplots arranged in 2 rows if there are many components
60
  if num_components > 6:
61
  nrows = 2
62
  ncols = (num_components + 1) // 2
63
- fig, axs = plt.subplots(nrows, ncols, figsize=(width, height * 1.5))
64
- axs = axs.flatten()
65
  else:
66
  nrows = 1
67
  ncols = num_components
68
- fig, axs = plt.subplots(1, num_components, figsize=(width, height))
69
 
70
- fig.suptitle(f"{base_model_name} <> {chat_model_name}", fontsize=16)
 
71
 
72
- # Adjust font sizes based on number of layers
73
- tick_font_size = max(6, min(10, 300 / num_layers))
74
- annot_font_size = max(6, min(10, 200 / num_layers))
75
 
76
- for i, component in tqdm(enumerate(layer_diffs[0].keys()), total=len(layer_diffs[0].keys())):
77
- component_diffs = [[layer_diff[component]] for layer_diff in layer_diffs]
78
- sns.heatmap(component_diffs,
 
 
 
 
 
79
  annot=True,
80
- fmt=".9f",
81
- cmap="YlGnBu",
82
  ax=axs[i],
83
  cbar=False,
84
- annot_kws={'size': annot_font_size})
85
 
86
- axs[i].set_title(component, fontsize=max(10, tick_font_size * 1.2))
87
- axs[i].set_xlabel("Difference", fontsize=tick_font_size)
88
- axs[i].set_ylabel("Layer", fontsize=tick_font_size)
89
- axs[i].set_xticks([])
90
  axs[i].set_yticks(range(num_layers))
91
  axs[i].set_yticklabels(range(num_layers), fontsize=tick_font_size)
92
- axs[i].invert_yaxis()
 
 
 
93
 
94
- # Remove any empty subplots if using 2 rows
95
- if num_components > 6:
96
- for j in range(i + 1, len(axs)):
97
- fig.delaxes(axs[j])
98
 
99
- plt.tight_layout(rect=[0, 0.03, 1, 0.95]) # Adjust layout to prevent overlap
100
 
101
- # Convert plot to image
102
  buf = io.BytesIO()
103
- fig.savefig(buf, format='png', dpi=300, bbox_inches='tight')
104
  buf.seek(0)
105
- plt.close(fig) # Close the figure to free memory
106
  return PIL.Image.open(buf)
107
 
108
- def gradio_interface(base_model_name, chat_model_name, hf_token, load_one_at_a_time=False):
109
- base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=torch.bfloat16, token=hf_token)
110
- chat_model = AutoModelForCausalLM.from_pretrained(chat_model_name, torch_dtype=torch.bfloat16, token=hf_token)
 
111
 
112
- layer_diffs = calculate_layer_diffs(base_model, chat_model, load_one_at_a_time=load_one_at_a_time)
113
- return visualize_layer_diffs(layer_diffs, base_model_name, chat_model_name)
 
 
 
 
 
 
 
 
 
114
 
115
- if __name__ == "__main__":
116
- iface = gr.Interface(
117
- fn=gradio_interface,
118
- inputs=[
119
- gr.Textbox(label="Base Model Name", lines=2),
120
- gr.Textbox(label="Chat Model Name", lines=2),
121
- gr.Textbox(label="Hugging Face Token", type="password", lines=2),
122
- gr.Checkbox(label="Load one layer at a time")
123
- ],
124
- outputs=gr.Image(type="pil", label="Weight Differences Visualization"),
125
- title="Model Weight Difference Visualizer",
126
- cache_examples=False
127
  )
128
 
129
- iface.launch(share=False, server_port=7860)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
  import torch
3
+ import pandas as pd
4
+ import numpy as np
5
  import matplotlib.pyplot as plt
6
  import seaborn as sns
7
+ import plotly.graph_objects as go
8
  import gradio as gr
 
9
  import PIL.Image
10
+ from transformers import AutoModelForCausalLM
11
+ from tqdm import tqdm
12
+
13
# Apply the seaborn "whitegrid" theme to every matplotlib figure this app renders.
sns.set_theme(style="whitegrid")
15
 
16
def calculate_weight_diff(base_weight, chat_weight):
    """Return the mean absolute element-wise difference between two weight tensors."""
    delta = base_weight - chat_weight
    return delta.abs().mean().item()
19
 
20
def calculate_layer_diffs(base_model, chat_model):
    """Compute per-layer weight differences between a base model and a tuned model.

    Returns a list with one dict per transformer layer, mapping each tracked
    component name to the mean |base - chat| weight difference. Components a
    layer does not expose are reported as 0.0.
    """
    # Accessors for every component we track, in display order.
    component_getters = {
        'input_layernorm': lambda layer: layer.input_layernorm.weight,
        'self_attn_q_proj': lambda layer: layer.self_attn.q_proj.weight,
        'self_attn_k_proj': lambda layer: layer.self_attn.k_proj.weight,
        'self_attn_v_proj': lambda layer: layer.self_attn.v_proj.weight,
        'self_attn_o_proj': lambda layer: layer.self_attn.o_proj.weight,
        'post_attention_layernorm': lambda layer: layer.post_attention_layernorm.weight,
        'mlp_gate_proj': lambda layer: layer.mlp.gate_proj.weight,
        'mlp_up_proj': lambda layer: layer.mlp.up_proj.weight,
        'mlp_down_proj': lambda layer: layer.mlp.down_proj.weight,
    }

    def _safe_diff(name, base_layer, chat_layer):
        # Architectures can differ slightly; report 0.0 for missing components.
        getter = component_getters[name]
        try:
            return calculate_weight_diff(getter(base_layer), getter(chat_layer))
        except AttributeError:
            return 0.0

    print("Calculating differences...")
    layer_pairs = zip(base_model.model.layers, chat_model.model.layers)
    layer_diffs = []
    for base_layer, chat_layer in tqdm(layer_pairs, total=len(base_model.model.layers)):
        layer_diffs.append({name: _safe_diff(name, base_layer, chat_layer)
                            for name in component_getters})

    return layer_diffs
55
 
56
def visualize_2d_heatmap(layer_diffs, base_model_name, chat_model_name):
    """Render per-component weight differences as a grid of single-column heatmaps.

    Args:
        layer_diffs: List of dicts, one per layer, mapping component name -> diff value.
        base_model_name: Name of the base model (shown in the figure title).
        chat_model_name: Name of the tuned model (shown in the figure title).

    Returns:
        A PIL.Image of the rendered figure, or None when there is nothing to plot.
    """
    # Guard both "no layers" and "no tracked components" — the latter previously
    # crashed with a NameError on `i` in the axis-cleanup loop below.
    if not layer_diffs or not layer_diffs[0]:
        return None

    num_layers = len(layer_diffs)
    components = list(layer_diffs[0].keys())
    num_components = len(components)

    # Scale the canvas with the amount of data so annotations stay legible.
    height = max(8, num_layers / 6)
    width = max(20, num_components * 2.5)

    # Wrap onto two rows when there are many components.
    if num_components > 6:
        nrows = 2
        ncols = (num_components + 1) // 2
    else:
        nrows = 1
        ncols = num_components

    fig, axs = plt.subplots(nrows, ncols, figsize=(width, height * (1.2 if nrows > 1 else 1)))
    # plt.subplots returns a bare Axes (not an array) for a 1x1 grid.
    axs = axs.flatten() if num_components > 1 else [axs]

    fig.suptitle(f"Weight Differences: {base_model_name} vs {chat_model_name}", fontsize=16, y=0.98)

    # Shrink tick labels as the layer count grows, clamped to a readable range.
    tick_font_size = max(6, min(10, 300 / num_layers))

    last_used = -1
    for i, component in enumerate(components):
        last_used = i
        # One column of values: layer index on the y-axis, diff magnitude as color.
        data = [[row[component]] for row in layer_diffs]

        sns.heatmap(data,
                    annot=True,
                    fmt=".6f",
                    cmap="viridis",
                    ax=axs[i],
                    cbar=False,
                    annot_kws={'size': tick_font_size * 0.8})

        axs[i].set_title(component, fontsize=12, fontweight='bold')
        axs[i].set_yticks(range(num_layers))
        axs[i].set_yticklabels(range(num_layers), fontsize=tick_font_size)
        axs[i].set_xticks([])  # single-column heatmap needs no x-axis ticks
        axs[i].invert_yaxis()  # layer 0 at the bottom so depth reads upward

    # Remove any unused axes left over from the 2-row grid.
    for j in range(last_used + 1, len(axs)):
        fig.delaxes(axs[j])

    plt.tight_layout(rect=[0, 0, 1, 0.96])

    # Serialize the figure to an in-memory PNG and hand back a PIL image.
    buf = io.BytesIO()
    fig.savefig(buf, format='png', dpi=150, bbox_inches='tight')
    buf.seek(0)
    plt.close(fig)  # free figure memory
    return PIL.Image.open(buf)
116
 
117
def visualize_3d_surface(layer_diffs):
    """Build an interactive Plotly 3D surface of weight diffs (layers x components).

    Returns a plotly Figure, or None when there is no data.
    """
    if not layer_diffs:
        return None

    # One row per layer, one column per component.
    frame = pd.DataFrame(layer_diffs)

    # Plotly indexes z as z[y][x]; the DataFrame already has layers as rows (y)
    # and components as columns (x), so its values are used as-is.
    surface = go.Surface(
        z=frame.values,
        x=frame.columns.tolist(),
        y=frame.index.tolist(),
        colorscale='Viridis',
    )
    fig = go.Figure(data=[surface])

    fig.update_layout(
        title='3D Landscape of Weight Differences',
        scene=dict(
            xaxis_title='Model Components',
            yaxis_title='Layer Index',
            zaxis_title='Mean Weight Diff',
            xaxis=dict(tickangle=45),
        ),
        autosize=True,
        height=800,
        margin=dict(l=65, r=50, b=65, t=90)
    )

    return fig
148
+
149
def process_models(base_name, chat_name, hf_token):
    """Load both models, compute per-layer weight diffs, and return (2D image, 3D plot).

    Args:
        base_name: Hub id of the base model.
        chat_name: Hub id of the chat/tuned model.
        hf_token: Optional Hugging Face access token; blank means "no token".

    Raises:
        gr.Error: wraps any failure so Gradio surfaces it in the UI.
    """
    # The token textbox is optional: an empty string must become None, otherwise
    # it is sent verbatim to the hub API as an (invalid) credential.
    token = hf_token or None
    try:
        print(f"Loading {base_name}...")
        base_model = AutoModelForCausalLM.from_pretrained(
            base_name,
            torch_dtype=torch.bfloat16,
            token=token,
            device_map="cpu",  # Force CPU to avoid GPU OOM during comparison if models are large
            trust_remote_code=True
        )

        print(f"Loading {chat_name}...")
        chat_model = AutoModelForCausalLM.from_pretrained(
            chat_name,
            torch_dtype=torch.bfloat16,
            token=token,
            device_map="cpu",
            trust_remote_code=True
        )

        diffs = calculate_layer_diffs(base_model, chat_model)

        # Release the (potentially huge) models before rendering.
        del base_model
        del chat_model
        torch.cuda.empty_cache()  # safe no-op on CPU-only hosts

        img_2d = visualize_2d_heatmap(diffs, base_name, chat_name)
        plot_3d = visualize_3d_surface(diffs)

        return img_2d, plot_3d

    except Exception as e:
        # Chain the cause so server logs keep the original traceback.
        raise gr.Error(f"Error processing models: {str(e)}") from e
183
+
184
# --- Gradio UI Layout ---
with gr.Blocks(title="Model Diff Visualizer") as demo:
    gr.Markdown("# 🧠 LLM Weight Difference Visualizer")
    gr.Markdown("Compare the weights of a Base model vs. its Instruct/Chat tuned version layer by layer.")

    # Input column: model ids, optional token, and the trigger button.
    with gr.Row():
        with gr.Column(scale=1):
            base_model_box = gr.Textbox(label="Base Model Name", placeholder="e.g., meta-llama/Llama-2-7b-hf")
            chat_model_box = gr.Textbox(label="Chat/Tuned Model Name", placeholder="e.g., meta-llama/Llama-2-7b-chat-hf")
            hf_token_box = gr.Textbox(label="Hugging Face Token (Optional)", type="password")
            analyze_button = gr.Button("🚀 Analyze Differences", variant="primary")

    # Output: static heatmap image.
    with gr.Row():
        with gr.Column():
            gr.Markdown("### 2D Layer-wise Heatmap")
            heatmap_image = gr.Image(label="2D Visualization", type="pil")

    # Output: interactive 3D surface.
    with gr.Row():
        with gr.Column():
            gr.Markdown("### 3D Interactive Landscape")
            surface_plot = gr.Plot(label="3D Visualization")

    # Wire the button to the analysis pipeline.
    analyze_button.click(
        fn=process_models,
        inputs=[base_model_box, chat_model_box, hf_token_box],
        outputs=[heatmap_image, surface_plot]
    )

if __name__ == "__main__":
    # Local-only server on the default Gradio port.
    demo.launch(share=False, server_port=7860)