Steelskull commited on
Commit
893fdc7
·
verified ·
1 Parent(s): 3aed6d3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -29
app.py CHANGED
@@ -1,9 +1,8 @@
1
  import io
2
  import torch
3
  import pandas as pd
4
- import numpy as np
5
- import matplotlib.pyplot as plt
6
  import seaborn as sns
 
7
  import plotly.graph_objects as go
8
  import gradio as gr
9
  import PIL.Image
@@ -15,17 +14,19 @@ sns.set_theme(style="whitegrid")
15
 
16
  def calculate_weight_diff(base_weight, chat_weight):
17
  """Calculates the mean absolute difference between two tensors."""
18
- return torch.abs(base_weight - chat_weight).mean().item()
 
 
 
19
 
20
  def calculate_layer_diffs(base_model, chat_model):
21
  """Iterates through layers and calculates differences for specific projections."""
22
  layer_diffs = []
23
 
24
- # We zip the layers to iterate through them simultaneously
25
  layers = zip(base_model.model.layers, chat_model.model.layers)
26
  total_layers = len(base_model.model.layers)
27
 
28
- # List of components we want to track
29
  components_to_track = [
30
  ('input_layernorm', lambda l: l.input_layernorm.weight),
31
  ('self_attn_q_proj', lambda l: l.self_attn.q_proj.weight),
@@ -46,7 +47,6 @@ def calculate_layer_diffs(base_model, chat_model):
46
  val = calculate_weight_diff(getter(base_layer), getter(chat_layer))
47
  layer_data[name] = val
48
  except AttributeError:
49
- # Handle cases where architecture might differ slightly (e.g., bias terms)
50
  layer_data[name] = 0.0
51
 
52
  layer_diffs.append(layer_data)
@@ -62,11 +62,9 @@ def visualize_2d_heatmap(layer_diffs, base_model_name, chat_model_name):
62
  components = list(layer_diffs[0].keys())
63
  num_components = len(components)
64
 
65
- # Dynamically adjust figure size
66
  height = max(8, num_layers / 6)
67
  width = max(20, num_components * 2.5)
68
 
69
- # Logic for subplot arrangement
70
  if num_components > 6:
71
  nrows = 2
72
  ncols = (num_components + 1) // 2
@@ -79,11 +77,9 @@ def visualize_2d_heatmap(layer_diffs, base_model_name, chat_model_name):
79
 
80
  fig.suptitle(f"Weight Differences: {base_model_name} vs {chat_model_name}", fontsize=16, y=0.98)
81
 
82
- # Font sizing logic
83
  tick_font_size = max(6, min(10, 300 / num_layers))
84
 
85
  for i, component in enumerate(components):
86
- # Extract data for this specific component across all layers
87
  data = [[row[component]] for row in layer_diffs]
88
 
89
  sns.heatmap(data,
@@ -97,12 +93,9 @@ def visualize_2d_heatmap(layer_diffs, base_model_name, chat_model_name):
97
  axs[i].set_title(component, fontsize=12, fontweight='bold')
98
  axs[i].set_yticks(range(num_layers))
99
  axs[i].set_yticklabels(range(num_layers), fontsize=tick_font_size)
100
- axs[i].set_xticks([]) # Hide x-axis ticks for the single column heatmap
101
- axs[i].invert_yaxis() # Layer 0 at bottom or top? Usually 0 is bottom in diagrams, but top in matrices.
102
- # Let's keep 0 at top (standard matrix view) or remove invert for 0 at bottom.
103
- # Standard heatmap has index 0 at top.
104
 
105
- # Remove empty subplots
106
  for j in range(i + 1, len(axs)):
107
  fig.delaxes(axs[j])
108
 
@@ -114,19 +107,14 @@ def visualize_2d_heatmap(layer_diffs, base_model_name, chat_model_name):
114
  plt.close(fig)
115
  return PIL.Image.open(buf)
116
 
117
- def visualize_3d_surface(layer_diffs):
118
- """Generates an interactive 3D Surface plot using Plotly."""
119
  if not layer_diffs:
120
- return None
121
 
122
- # Convert list of dicts to DataFrame for easier handling
123
  df = pd.DataFrame(layer_diffs)
124
-
125
- # X axis: Components
126
  x_labels = df.columns.tolist()
127
- # Y axis: Layers
128
  y_labels = df.index.tolist()
129
- # Z axis: Values (Transposed because Plotly expects Z[y][x])
130
  z_data = df.values
131
 
132
  fig = go.Figure(data=[go.Surface(z=z_data, x=x_labels, y=y_labels, colorscale='Viridis')])
@@ -140,20 +128,25 @@ def visualize_3d_surface(layer_diffs):
140
  xaxis=dict(tickangle=45),
141
  ),
142
  autosize=True,
143
- height=800,
144
  margin=dict(l=65, r=50, b=65, t=90)
145
  )
146
 
147
- return fig
 
148
 
149
  def process_models(base_name, chat_name, hf_token):
 
 
 
 
150
  try:
151
  print(f"Loading {base_name}...")
152
  base_model = AutoModelForCausalLM.from_pretrained(
153
  base_name,
154
  torch_dtype=torch.bfloat16,
155
  token=hf_token,
156
- device_map="cpu", # Force CPU to avoid GPU OOM during comparison if models are large
157
  trust_remote_code=True
158
  )
159
 
@@ -174,11 +167,13 @@ def process_models(base_name, chat_name, hf_token):
174
  torch.cuda.empty_cache()
175
 
176
  img_2d = visualize_2d_heatmap(diffs, base_name, chat_name)
177
- plot_3d = visualize_3d_surface(diffs)
178
 
179
- return img_2d, plot_3d
180
 
181
  except Exception as e:
 
 
182
  raise gr.Error(f"Error processing models: {str(e)}")
183
 
184
  # --- Gradio UI Layout ---
@@ -201,7 +196,8 @@ with gr.Blocks(title="Model Diff Visualizer") as demo:
201
  with gr.Row():
202
  with gr.Column():
203
  gr.Markdown("### 3D Interactive Landscape")
204
- output_3d = gr.Plot(label="3D Visualization")
 
205
 
206
  submit_btn.click(
207
  fn=process_models,
 
1
  import io
2
  import torch
3
  import pandas as pd
 
 
4
  import seaborn as sns
5
+ import matplotlib.pyplot as plt
6
  import plotly.graph_objects as go
7
  import gradio as gr
8
  import PIL.Image
 
14
 
15
  def calculate_weight_diff(base_weight, chat_weight):
16
  """Calculates the mean absolute difference between two tensors."""
17
+ # Move to CPU for calculation to save GPU memory and ensure numpy compatibility
18
+ b_w = base_weight.detach().cpu()
19
+ c_w = chat_weight.detach().cpu()
20
+ return torch.abs(b_w - c_w).mean().item()
21
 
22
  def calculate_layer_diffs(base_model, chat_model):
23
  """Iterates through layers and calculates differences for specific projections."""
24
  layer_diffs = []
25
 
 
26
  layers = zip(base_model.model.layers, chat_model.model.layers)
27
  total_layers = len(base_model.model.layers)
28
 
29
+ # Components to track
30
  components_to_track = [
31
  ('input_layernorm', lambda l: l.input_layernorm.weight),
32
  ('self_attn_q_proj', lambda l: l.self_attn.q_proj.weight),
 
47
  val = calculate_weight_diff(getter(base_layer), getter(chat_layer))
48
  layer_data[name] = val
49
  except AttributeError:
 
50
  layer_data[name] = 0.0
51
 
52
  layer_diffs.append(layer_data)
 
62
  components = list(layer_diffs[0].keys())
63
  num_components = len(components)
64
 
 
65
  height = max(8, num_layers / 6)
66
  width = max(20, num_components * 2.5)
67
 
 
68
  if num_components > 6:
69
  nrows = 2
70
  ncols = (num_components + 1) // 2
 
77
 
78
  fig.suptitle(f"Weight Differences: {base_model_name} vs {chat_model_name}", fontsize=16, y=0.98)
79
 
 
80
  tick_font_size = max(6, min(10, 300 / num_layers))
81
 
82
  for i, component in enumerate(components):
 
83
  data = [[row[component]] for row in layer_diffs]
84
 
85
  sns.heatmap(data,
 
93
  axs[i].set_title(component, fontsize=12, fontweight='bold')
94
  axs[i].set_yticks(range(num_layers))
95
  axs[i].set_yticklabels(range(num_layers), fontsize=tick_font_size)
96
+ axs[i].set_xticks([])
97
+ axs[i].invert_yaxis()
 
 
98
 
 
99
  for j in range(i + 1, len(axs)):
100
  fig.delaxes(axs[j])
101
 
 
107
  plt.close(fig)
108
  return PIL.Image.open(buf)
109
 
110
+ def generate_3d_html(layer_diffs):
111
+ """Generates an interactive 3D Surface plot as an HTML string."""
112
  if not layer_diffs:
113
+ return "<p>No data to display</p>"
114
 
 
115
  df = pd.DataFrame(layer_diffs)
 
 
116
  x_labels = df.columns.tolist()
 
117
  y_labels = df.index.tolist()
 
118
  z_data = df.values
119
 
120
  fig = go.Figure(data=[go.Surface(z=z_data, x=x_labels, y=y_labels, colorscale='Viridis')])
 
128
  xaxis=dict(tickangle=45),
129
  ),
130
  autosize=True,
131
+ height=700,
132
  margin=dict(l=65, r=50, b=65, t=90)
133
  )
134
 
135
+ # Return HTML string instead of Figure object to avoid Gradio schema bugs
136
+ return fig.to_html(include_plotlyjs='cdn', full_html=False)
137
 
138
  def process_models(base_name, chat_name, hf_token):
139
+ # Set default values if empty to prevent crash
140
+ if not base_name or not chat_name:
141
+ raise gr.Error("Please provide both model names.")
142
+
143
  try:
144
  print(f"Loading {base_name}...")
145
  base_model = AutoModelForCausalLM.from_pretrained(
146
  base_name,
147
  torch_dtype=torch.bfloat16,
148
  token=hf_token,
149
+ device_map="cpu",
150
  trust_remote_code=True
151
  )
152
 
 
167
  torch.cuda.empty_cache()
168
 
169
  img_2d = visualize_2d_heatmap(diffs, base_name, chat_name)
170
+ html_3d = generate_3d_html(diffs)
171
 
172
+ return img_2d, html_3d
173
 
174
  except Exception as e:
175
+ import traceback
176
+ traceback.print_exc()
177
  raise gr.Error(f"Error processing models: {str(e)}")
178
 
179
  # --- Gradio UI Layout ---
 
196
  with gr.Row():
197
  with gr.Column():
198
  gr.Markdown("### 3D Interactive Landscape")
199
+ # Using HTML component avoids Pydantic/Gradio schema validation bugs
200
+ output_3d = gr.HTML(label="3D Visualization")
201
 
202
  submit_btn.click(
203
  fn=process_models,