File size: 7,970 Bytes
a452a60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92fcb27
a452a60
 
92fcb27
a452a60
92fcb27
a452a60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f55c43e
a452a60
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from tuned_lens.nn.lenses import TunedLens, LogitLens
from tuned_lens.nn.lenses import TunedLensConfig
from tuned_lens.nn.unembed import Unembed
from tuned_lens.plotting import PredictionTrajectory
import json
import os
import gradio as gr
import plotly.io as pio
import plotly.graph_objects as go
import datetime

def ensure_kaleido():
    """Ensure the ``kaleido`` package (static image export for Plotly) is importable.

    If it is missing, install it with pip into the current interpreter's
    environment and import it again so a failed install raises loudly.
    """
    try:
        import kaleido  # noqa: F401 -- presence check only
    except ImportError:
        import subprocess
        import sys

        print("Installing kaleido...")
        # Use the running interpreter's pip so the install lands in this env.
        subprocess.check_call(
            [sys.executable, "-m", "pip", "install", "-U", "kaleido"]
        )
        import kaleido  # noqa: F401 -- verify the install actually worked
        print("Kaleido installed successfully!")

# Install the image-export dependency before any plotting happens.
# Call this at the start
ensure_kaleido()

# Set font for multilingual (Devanagari etc.)
# Mangal covers the Devanagari glyphs used in the default prompt below.
pio.templates.default = "plotly"
pio.templates["plotly"].layout.font = dict(family="Mangal, Arial, sans-serif", size=12)

# Load model and tokenizer
# fp16 on GPU halves memory; CPU stays fp32 (half precision is slow/unsupported there).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoModelForCausalLM.from_pretrained("sarvamai/sarvam-1", torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32)
model = model.to(device)
tokenizer = AutoTokenizer.from_pretrained("sarvamai/sarvam-1")

# Load tuned lens from HF
# NOTE(review): the os.system-based clone assumes a Debian-like host with
# apt-get and network access; failures are silently ignored — confirm env.
lens_folder_path = "tuned-lens"
os.system("apt-get -y install git-lfs")
os.system("git lfs install")
os.system(f"git clone https://huggingface.co/AnonymousAccountACL/{lens_folder_path}")

new_lens_folder_path = lens_folder_path + "/sarvam_lens/sarvamai/sarvam-1"
config_path = os.path.join(new_lens_folder_path, "config.json")
params_path = os.path.join(new_lens_folder_path, "params.pt")

with open(config_path, "r") as f:
    config_dict = json.load(f)

# Normalize the on-disk config to the keys TunedLensConfig expects,
# falling back to the loaded model's own dimensions for missing fields.
fixed_config = {
    "base_model_name_or_path": config_dict.get("base_model_name_or_path", "sarvamai/sarvam-1"),
    "d_model": config_dict.get("d_model", model.config.hidden_size),
    "num_hidden_layers": config_dict.get("num_layers", model.config.num_hidden_layers),
    "bias": config_dict.get("bias", True),
    "base_model_revision": config_dict.get("base_model_revision", None),
    "unembed_hash": config_dict.get("unembed_hash", None),
    "lens_type": config_dict.get("lens_type", "linear_tuned_lens"),
}

config = TunedLensConfig.from_dict(fixed_config)
unembed = Unembed(model)
tuned_lens = TunedLens(unembed, config)

# Load and align parameters
# The checkpoint prefixes translator weights with "layer_translators.";
# strip that prefix so keys match the submodule's own state dict.
state_dict = torch.load(params_path, map_location=device)
layer_translator_state_dict = {}
for key, value in state_dict.items():
    if key.startswith("layer_translators."):
        new_key = key.replace("layer_translators.", "")
        layer_translator_state_dict[new_key] = value

# Infer how many translator layers the checkpoint actually contains
# (stripped keys look like "<layer_idx>.<param_name>").
layer_indices = set()
for key in layer_translator_state_dict.keys():
    if "." in key:
        layer_idx = int(key.split(".")[0])
        layer_indices.add(layer_idx)

# If the checkpoint's layer count differs from what the config implied,
# rebuild the lens with the checkpoint's count so load_state_dict matches.
actual_num_translators = max(layer_indices) + 1 if layer_indices else 0
if actual_num_translators != len(tuned_lens.layer_translators):
    fixed_config["num_hidden_layers"] = actual_num_translators
    config = TunedLensConfig.from_dict(fixed_config)
    tuned_lens = TunedLens(unembed, config)

tuned_lens.layer_translators.load_state_dict(layer_translator_state_dict)

# Patch TunedLens forward
# For layer indices beyond the trained translators (e.g. the final layer),
# fall back to the plain unembedding instead of indexing out of range.
original_forward = tuned_lens.forward
def patched_forward(self, hidden_states, layer_idx):
    if layer_idx < len(self.layer_translators):
        return original_forward(hidden_states, layer_idx)
    else:
        return self.unembed(hidden_states)
# Bind as a method on this one instance; original bound forward stays captured above.
tuned_lens.forward = patched_forward.__get__(tuned_lens, TunedLens)

# Load LogitLens
logit_lens = LogitLens.from_model(model)

def fix_colorbar_properties(fig):
    """Strip colorbar properties that newer Plotly versions reject (e.g. ``titleside``).

    Some figure-producing libraries still emit deprecated colorbar attributes
    that raise on serialization; this removes them from every trace in place.

    Args:
        fig: A Plotly figure (anything exposing iterable ``.data`` traces).

    Returns:
        The same figure object, mutated in place.
    """
    # Deprecated/invalid colorbar properties to remove. Hoisted out of the
    # helper so it is not rebuilt on every call.
    invalid_props = ('titleside',)

    def clean_colorbar(colorbar_obj):
        # Colorbars may be plain dicts or Plotly graph objects; handle both.
        if colorbar_obj is None:
            return

        if isinstance(colorbar_obj, dict):
            for prop in invalid_props:
                colorbar_obj.pop(prop, None)
        else:
            for prop in invalid_props:
                if hasattr(colorbar_obj, prop):
                    # Plotly graph objects may refuse attribute deletion;
                    # ignore only that failure instead of a bare except,
                    # which would also swallow KeyboardInterrupt/SystemExit.
                    try:
                        delattr(colorbar_obj, prop)
                    except (AttributeError, TypeError):
                        pass

    for trace in fig.data:
        # Heatmap-style traces carry the colorbar directly on the trace...
        if hasattr(trace, 'colorbar') and trace.colorbar is not None:
            clean_colorbar(trace.colorbar)

        # ...while scatter-style traces nest it under marker/line.
        if hasattr(trace, 'marker') and hasattr(trace.marker, 'colorbar'):
            clean_colorbar(trace.marker.colorbar)

        if hasattr(trace, 'line') and hasattr(trace.line, 'colorbar'):
            clean_colorbar(trace.line.colorbar)

    return fig

# Plot generation with comprehensive colorbar fix
def generate_plot(text, lens_type, statistic, stride, start, end):
    """Render a lens prediction-trajectory heatmap for ``text`` and save it as PNG.

    Args:
        text: Input string to tokenize and analyze.
        lens_type: "Tuned Lens" selects the tuned lens; anything else the logit lens.
        statistic: Name of a PredictionTrajectory statistic method
            ("entropy", "cross_entropy", "forward_kl").
        stride: Layer stride for the plot.
        start, end: Token-position slice bounds.

    Returns:
        (filename, figure) on success, or (None, error_message) on failure.
    """
    token_ids = tokenizer.encode(text, add_special_tokens=True)
    # Next-token targets: shift left by one and close with EOS.
    shifted_targets = token_ids[1:] + [tokenizer.eos_token_id]
    selected_lens = tuned_lens if lens_type == "Tuned Lens" else logit_lens

    try:
        trajectory = PredictionTrajectory.from_lens_and_model(
            lens=selected_lens,
            model=model,
            input_ids=token_ids,
            tokenizer=tokenizer,
            targets=shifted_targets,
        )
        trajectory = trajectory.slice_sequence(slice(start, end))

        stat_traj = getattr(trajectory, statistic)()
        fig = stat_traj.stride(stride).figure(
            title=f"{lens_type} ({model.name_or_path}) {statistic}"
        )

        # Drop colorbar properties that newer Plotly versions reject.
        fig = fix_colorbar_properties(fig)

        # Re-assert the multilingual font on this specific figure.
        fig.update_layout(
            font=dict(family="Mangal, Arial, sans-serif", size=12),
            title_font=dict(family="Mangal, Arial, sans-serif", size=14),
        )

        timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
        filename = f"lens_plot_{timestamp}.png"
        fig.write_image(filename, width=1200, height=800, scale=2, format="png")
        return filename, fig

    except Exception as e:
        # Surface any failure to the UI as a status string instead of crashing.
        return None, f"Error: {e}"

# Gradio UI
def lens_ui(text, lens_type, statistic, stride, token_range_start, token_range_end):
    """Gradio callback: run generate_plot and map its result to UI outputs.

    Returns:
        (status_message, figure_or_None) for the status textbox and plot pane.
    """
    filename, result = generate_plot(text, lens_type, statistic, stride, token_range_start, token_range_end)
    # generate_plot signals failure by returning the error text as `result`.
    if isinstance(result, str):
        return result, None
    # Bug fix: the success message previously contained the literal text
    # "(unknown)" instead of interpolating the saved file's name.
    return f"✅ Plot saved as {filename}", result

# Assemble the Gradio app: prompt/lens controls on the left, range/stride on
# the right, with a status line and the rendered plot below.
with gr.Blocks() as demo:
    gr.Markdown("## 🔬 Tuned Lens + Logit Lens Visualizer for Sarvam-1")

    with gr.Row():
        with gr.Column():
            # Default prompt is a Hindi multiple-choice question (Devanagari).
            text = gr.Textbox(label="Input Text", lines=8, value="Question: मूरस के अनुसार 'आदर्श उपयोगितावाद' के अनुसार, सही क्रिया वह है जो सबसे अधिक लाभ लाती है:\nOptions:\n1. सुख।\n2. खुशी।\n3. अच्छाई।\n4. गुणवत्ता।\n\nPlease select the correct option.")
            lens_type = gr.Radio(["Tuned Lens", "Logit Lens"], value="Tuned Lens", label="Select Lens")
            statistic = gr.Radio(["entropy", "cross_entropy", "forward_kl"], value="entropy", label="Statistic")

        with gr.Column():
            token_range_start = gr.Slider(0, 63, value=0, step=1, label="Token Range Start")
            # NOTE(review): default value 64 exceeds this slider's max of 63 —
            # confirm whether the max should be 64 (exclusive slice end).
            token_range_end = gr.Slider(0, 63, value=64, step=1, label="Token Range End")
            stride = gr.Slider(1, 10, value=2, step=1, label="Layer Stride")

    btn = gr.Button("🔍 Generate Plot")
    output_text = gr.Textbox(label="Status")
    plot_output = gr.Plot(label="Visualization")

    # Wire the button to the callback; outputs map to (status, plot).
    btn.click(
        fn=lens_ui,
        inputs=[text, lens_type, statistic, stride, token_range_start, token_range_end],
        outputs=[output_text, plot_output]
    )

if __name__ == "__main__":
    demo.launch()