| | from src.configs.model_configs import AnalysisConfig |
| | from utils import * |
| | import plotly.graph_objects as go |
| | import numpy as np |
| | from tqdm import tqdm |
| | import json |
| |
|
| |
|
# Model identifiers to plot. Each name is used as the directory under
# utils/data/ that holds its jsd_stats.json and receives the output PDF.
# (A module-level `global` statement is a no-op, so none is needed here.)
MODELS = ["llama3", "llama2", "qwen", "mistral", "gemma"]
| |
|
| |
|
# For each model in MODELS: load utils/data/<model>/jsd_stats.json (a mapping
# from layer index, as the string "0".."N-1", to a JS-divergence value that may
# be +inf), down-weight the edge layers, and save a bar chart PDF next to it.

NUM_EDGE_LAYERS = 5      # layers at EACH end of the stack treated as "edge" layers
EDGE_LAYER_SCALE = 0.02  # multiplier applied to finite edge-layer values before plotting
BASE_COLOR = "#00695C"   # bar fill colour; its alpha channel encodes the normalised value
INF_COLOR = 'rgba(255, 0, 0, 0.8)'  # colour assigned to layers whose value is +inf


def _is_edge_layer(index, total_layers, num_edge_layers):
    """Return True if *index* is one of the first or last *num_edge_layers* layers."""
    return index < num_edge_layers or index >= total_layers - num_edge_layers


def _downweight_edge_layers(data, num_edge_layers=NUM_EDGE_LAYERS, scale=EDGE_LAYER_SCALE):
    """Return a new dict in which finite edge-layer values are multiplied by *scale*.

    *data* must have the string keys "0" .. str(len(data) - 1); a missing key
    raises KeyError. Values equal to +inf, and all non-edge values, are copied
    unchanged.
    """
    total_layers = len(data)
    scaled = {}
    for i in range(total_layers):
        key = str(i)
        value = data[key]
        if value != float('inf') and _is_edge_layer(i, total_layers, num_edge_layers):
            value = value * scale
        scaled[key] = value
    return scaled


def _plot_values(data):
    """Return the values of *data* in layer order, with +inf replaced by None."""
    ordered = (data[str(i)] for i in range(len(data)))
    return [None if value == float('inf') else value for value in ordered]


def _normalization_range(values, num_edge_layers=NUM_EDGE_LAYERS):
    """Return (min, max) over the non-None, non-edge *values*; (0, 1) if none remain.

    Edge layers are excluded so that their artificially scaled values do not
    dominate the colour scale of the middle layers.
    """
    total_layers = len(values)
    middle = [
        value
        for i, value in enumerate(values)
        if value is not None and not _is_edge_layer(i, total_layers, num_edge_layers)
    ]
    if not middle:
        return 0, 1
    return min(middle), max(middle)


def _bar_colors(values, min_val, max_val, base_color=BASE_COLOR):
    """Return one rgba() colour string per entry of *values*.

    None entries get INF_COLOR. Other values are mapped linearly from
    [min_val, max_val] to an alpha in [0.2, 1.0]; the result is then clamped
    to [0, 1], so out-of-range values saturate. If min_val == max_val,
    every value gets the midpoint alpha 0.6.
    """
    # Parse the base colour once instead of once per bar.
    hex_color = base_color.lstrip('#')
    r, g, b = (int(hex_color[k:k + 2], 16) for k in (0, 2, 4))

    colors = []
    for value in values:
        if value is None:
            # NOTE(review): Plotly is expected to draw no bar for a None y value,
            # so this colour may never be visible — confirm intended.
            colors.append(INF_COLOR)
            continue
        normalized = (value - min_val) / (max_val - min_val) if max_val != min_val else 0.5
        # NOTE(review): edge layers are scaled by EDGE_LAYER_SCALE but excluded
        # from [min_val, max_val]; if a scaled value falls far below min_val its
        # alpha clamps to 0.0 and the bar fill becomes transparent — confirm.
        intensity = max(0.0, min(1.0, 0.2 + (0.8 * normalized)))
        colors.append(f'rgba({r}, {g}, {b}, {intensity})')
    return colors


def _build_jsd_figure(layers, values, colors, model_name):
    """Build the per-layer JS-divergence bar chart for *model_name*."""
    fig = go.Figure(data=[
        go.Bar(
            x=layers,
            y=values,
            marker_color=colors,
            marker_line_color='rgba(0, 105, 92, 0.2)',
            marker_line_width=0.5,
        )
    ])
    fig.update_layout(
        title=dict(
            text=f'{model_name.capitalize()} Jensen-Shannon Divergence',
            x=0.5,
            font=dict(size=28, color='#2E4057')
        ),
        xaxis=dict(
            title='Layer Index',
            title_font=dict(size=22, color='#2E4057'),
            tickfont=dict(size=18),
            type='category'
        ),
        yaxis=dict(
            title='JS Divergence',
            title_font=dict(size=22, color='#2E4057'),
            tickfont=dict(size=18),
        ),
        plot_bgcolor='#FFFEF7',
        paper_bgcolor='white',
        font=dict(family="Arial, sans-serif"),
        showlegend=False,
        margin=dict(t=80, b=60, l=80, r=40),
        height=600,
        width=1000,
        bargap=0
    )
    return fig


for model in tqdm(MODELS):
    # JSON text is UTF-8 per RFC 8259; don't depend on the platform default.
    with open(f"utils/data/{model}/jsd_stats.json", "r", encoding="utf-8") as f:
        data = json.load(f)

    config = AnalysisConfig(model)

    data = _downweight_edge_layers(data)
    layers = list(range(len(data)))
    values = _plot_values(data)
    min_val, max_val = _normalization_range(values)
    colors = _bar_colors(values, min_val, max_val)

    fig = _build_jsd_figure(layers, values, colors, config.model_name)
    # The explicit width/height here override the 1000x600 set in update_layout.
    fig.write_image(f"utils/data/{model}/{model}_jsd_stats.pdf", width=1200, height=400, scale=2)