File size: 4,348 Bytes
713f666 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 | from src.configs.model_configs import AnalysisConfig
from utils import *
import plotly.graph_objects as go
import numpy as np
from tqdm import tqdm
import json
# Models whose per-layer JSD statistics are plotted, one PDF per model.
MODELS = ["llama3", "llama2", "qwen", "mistral", "gemma"]

# Number of layers at EACH end of the stack treated as "edge" layers: they are
# rescaled before plotting and excluded from the colour-normalisation range.
NUM_EDGE_LAYERS = 5
# Factor applied to edge-layer values before plotting.
# NOTE(review): an older comment said "Multiply by 0.2" while the code used
# 0.02; 0.02 is kept here. The plotted edge bars are therefore NOT raw JSD
# values -- confirm the intended factor and disclose it in the figure caption.
EDGE_LAYER_SCALE = 0.02
# Bar fill colour; the alpha channel of each bar encodes relative magnitude.
BASE_COLOR = "#00695C"
# Colour assigned to layers whose stored value is +inf.
INF_COLOR = 'rgba(255, 0, 0, 0.8)'


def _is_edge_layer(index, total_layers, num_edge=NUM_EDGE_LAYERS):
    """Return True if layer *index* is one of the first or last *num_edge* layers."""
    return index < num_edge or index >= total_layers - num_edge


def _plot_values(data, num_edge=NUM_EDGE_LAYERS, edge_scale=EDGE_LAYER_SCALE):
    """Convert a ``{"0": v0, "1": v1, ...}`` mapping into a list of bar heights.

    Layers are read by the string keys ``"0"`` .. ``str(len(data) - 1)``; a
    missing key raises ``KeyError``. Values equal to ``+inf`` become ``None``
    (a missing value for plotly), edge layers are multiplied by *edge_scale*,
    and every other value passes through unchanged.
    """
    total_layers = len(data)
    values = []
    for i in range(total_layers):
        val = data[str(i)]
        if val == float('inf'):
            values.append(None)
        elif _is_edge_layer(i, total_layers, num_edge):
            values.append(val * edge_scale)
        else:
            values.append(val)
    return values


def _hex_to_rgb(hex_color):
    """Return the ``(r, g, b)`` integer triple for a ``#RRGGBB`` colour string."""
    digits = hex_color.lstrip('#')
    return tuple(int(digits[k:k + 2], 16) for k in (0, 2, 4))


def _bar_colors(values, num_edge=NUM_EDGE_LAYERS, base_color=BASE_COLOR):
    """Return one rgba colour string per entry of *values*.

    Non-``None`` values are min-max normalised against the MIDDLE layers only
    (edge layers and ``None`` entries excluded), mapped to an alpha of
    ``0.2 + 0.8 * normalised`` and clamped to ``[0, 1]``. If no middle values
    exist the range defaults to ``[0, 1]``; if the range is degenerate every
    value normalises to 0.5. ``None`` entries get ``INF_COLOR``.
    """
    total_layers = len(values)
    middle = [
        v for i, v in enumerate(values)
        if v is not None and not _is_edge_layer(i, total_layers, num_edge)
    ]
    min_val = min(middle) if middle else 0
    max_val = max(middle) if middle else 1
    # Loop-invariant: parse the base colour once instead of once per bar.
    r, g, b = _hex_to_rgb(base_color)
    colors = []
    for val in values:
        if val is None:
            # NOTE(review): plotly draws no bar for a None height, so this
            # colour is probably never visible -- confirm.
            colors.append(INF_COLOR)
            continue
        normalized = (val - min_val) / (max_val - min_val) if max_val != min_val else 0.5
        intensity = 0.2 + (0.8 * normalized)
        # Edge layers lie outside the normalisation range, so their alpha can
        # drop below 0.2 (down to a fully transparent 0.0) or be capped at 1.0.
        intensity = max(0.0, min(1.0, intensity))
        colors.append(f'rgba({r}, {g}, {b}, {intensity})')
    return colors


def _build_figure(values, colors, model_name):
    """Build the per-layer JSD bar chart (one bar per layer) for one model."""
    fig = go.Figure(data=[
        go.Bar(
            x=list(range(len(values))),
            y=values,
            marker_color=colors,
            marker_line_color='rgba(0, 105, 92, 0.2)',
            marker_line_width=0.5,
        )
    ])
    fig.update_layout(
        title=dict(
            text=f'{model_name.capitalize()} Jensen-Shannon Divergence',
            x=0.5,
            font=dict(size=28, color='#2E4057'),
        ),
        xaxis=dict(
            title='Layer Index',
            title_font=dict(size=22, color='#2E4057'),
            tickfont=dict(size=18),
            # Treat layer indices as discrete categories (one slot per layer).
            type='category',
        ),
        yaxis=dict(
            title='JS Divergence',
            title_font=dict(size=22, color='#2E4057'),
            tickfont=dict(size=18),
        ),
        plot_bgcolor='#FFFEF7',
        paper_bgcolor='white',
        font=dict(family="Arial, sans-serif"),
        showlegend=False,
        margin=dict(t=80, b=60, l=80, r=40),
        # NOTE: the explicit width/height passed to write_image in the loop
        # below take precedence over these for the exported PDF.
        height=600,
        width=1000,
        bargap=0,  # no gaps between adjacent bars
    )
    return fig


for model in tqdm(MODELS):
    with open(f"utils/data/{model}/jsd_stats.json", "r", encoding="utf-8") as f:
        data = json.load(f)
    config = AnalysisConfig(model)
    values = _plot_values(data)
    colors = _bar_colors(values)
    fig = _build_figure(values, colors, config.model_name)
    fig.write_image(f"utils/data/{model}/{model}_jsd_stats.pdf", width=1200, height=400, scale=2)