AnonymousAccountACL commited on
Commit
a452a60
·
verified ·
1 Parent(s): 040f538

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +202 -0
app.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ from tuned_lens.nn.lenses import TunedLens, LogitLens
4
+ from tuned_lens.nn.lenses import TunedLensConfig
5
+ from tuned_lens.nn.unembed import Unembed
6
+ from tuned_lens.plotting import PredictionTrajectory
7
+ import json
8
+ import os
9
+ import gradio as gr
10
+ import plotly.io as pio
11
+ import plotly.graph_objects as go
12
+ import datetime
13
+
14
+ def ensure_kaleido():
15
+ try:
16
+ import kaleido
17
+ except ImportError:
18
+ print("Installing kaleido...")
19
+ import subprocess
20
+ import sys
21
+ subprocess.check_call([sys.executable, "-m", "pip", "install", "-U", "kaleido"])
22
+ import kaleido
23
+ print("Kaleido installed successfully!")
24
+
25
+ # Call this at the start
26
+ ensure_kaleido()
27
+
28
+ # Set font for multilingual (Devanagari etc.)
29
+ pio.templates.default = "plotly"
30
+ pio.templates["plotly"].layout.font = dict(family="Mangal, Arial, sans-serif", size=12)
31
+
32
+ # Load model and tokenizer
33
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
34
+ model = AutoModelForCausalLM.from_pretrained("sarvamai/sarvam-1", torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32)
35
+ model = model.to(device)
36
+ tokenizer = AutoTokenizer.from_pretrained("sarvamai/sarvam-1")
37
+
38
+ # Load tuned lens from HF
39
+ lens_folder_path = "tunedlens-model"
40
+ os.system("apt-get -y install git-lfs")
41
+ os.system("git lfs install")
42
+ os.system(f"git clone https://huggingface.co/MihirRajeshPanchal/{lens_folder_path}")
43
+
44
+ new_lens_folder_path = lens_folder_path + "/spaces/sarvamai/sarvam-1"
45
+ config_path = os.path.join(new_lens_folder_path, "config.json")
46
+ params_path = os.path.join(new_lens_folder_path, "params.pt")
47
+
48
+ with open(config_path, "r") as f:
49
+ config_dict = json.load(f)
50
+
51
+ fixed_config = {
52
+ "base_model_name_or_path": config_dict.get("base_model_name_or_path", "sarvamai/sarvam-1"),
53
+ "d_model": config_dict.get("d_model", model.config.hidden_size),
54
+ "num_hidden_layers": config_dict.get("num_layers", model.config.num_hidden_layers),
55
+ "bias": config_dict.get("bias", True),
56
+ "base_model_revision": config_dict.get("base_model_revision", None),
57
+ "unembed_hash": config_dict.get("unembed_hash", None),
58
+ "lens_type": config_dict.get("lens_type", "linear_tuned_lens"),
59
+ }
60
+
61
+ config = TunedLensConfig.from_dict(fixed_config)
62
+ unembed = Unembed(model)
63
+ tuned_lens = TunedLens(unembed, config)
64
+
65
+ # Load and align parameters
66
+ state_dict = torch.load(params_path, map_location=device)
67
+ layer_translator_state_dict = {}
68
+ for key, value in state_dict.items():
69
+ if key.startswith("layer_translators."):
70
+ new_key = key.replace("layer_translators.", "")
71
+ layer_translator_state_dict[new_key] = value
72
+
73
+ layer_indices = set()
74
+ for key in layer_translator_state_dict.keys():
75
+ if "." in key:
76
+ layer_idx = int(key.split(".")[0])
77
+ layer_indices.add(layer_idx)
78
+
79
+ actual_num_translators = max(layer_indices) + 1 if layer_indices else 0
80
+ if actual_num_translators != len(tuned_lens.layer_translators):
81
+ fixed_config["num_hidden_layers"] = actual_num_translators
82
+ config = TunedLensConfig.from_dict(fixed_config)
83
+ tuned_lens = TunedLens(unembed, config)
84
+
85
+ tuned_lens.layer_translators.load_state_dict(layer_translator_state_dict)
86
+
87
+ # Patch TunedLens forward
88
+ original_forward = tuned_lens.forward
89
+ def patched_forward(self, hidden_states, layer_idx):
90
+ if layer_idx < len(self.layer_translators):
91
+ return original_forward(hidden_states, layer_idx)
92
+ else:
93
+ return self.unembed(hidden_states)
94
+ tuned_lens.forward = patched_forward.__get__(tuned_lens, TunedLens)
95
+
96
+ # Load LogitLens
97
+ logit_lens = LogitLens.from_model(model)
98
+
99
+ def fix_colorbar_properties(fig):
100
+ """
101
+ Recursively fix invalid colorbar properties in a Plotly figure
102
+ """
103
+ def clean_colorbar(colorbar_obj):
104
+ if colorbar_obj is None:
105
+ return
106
+
107
+ # List of invalid properties to remove
108
+ invalid_props = ['titleside']
109
+
110
+ if isinstance(colorbar_obj, dict):
111
+ # Remove invalid properties from dict
112
+ for prop in invalid_props:
113
+ colorbar_obj.pop(prop, None)
114
+ else:
115
+ # Remove invalid properties from object
116
+ for prop in invalid_props:
117
+ if hasattr(colorbar_obj, prop):
118
+ try:
119
+ delattr(colorbar_obj, prop)
120
+ except:
121
+ pass
122
+
123
+ # Fix colorbar properties in all traces
124
+ for trace in fig.data:
125
+ if hasattr(trace, 'colorbar') and trace.colorbar is not None:
126
+ clean_colorbar(trace.colorbar)
127
+
128
+ # Handle different trace types that might have colorbars
129
+ if hasattr(trace, 'marker') and hasattr(trace.marker, 'colorbar'):
130
+ clean_colorbar(trace.marker.colorbar)
131
+
132
+ if hasattr(trace, 'line') and hasattr(trace.line, 'colorbar'):
133
+ clean_colorbar(trace.line.colorbar)
134
+
135
+ return fig
136
+
137
+ # Plot generation with comprehensive colorbar fix
138
+ def generate_plot(text, lens_type, statistic, stride, start, end):
139
+ input_ids = tokenizer.encode(text, add_special_tokens=True)
140
+ targets = input_ids[1:] + [tokenizer.eos_token_id]
141
+ lens = tuned_lens if lens_type == "Tuned Lens" else logit_lens
142
+
143
+ try:
144
+ pred_traj = PredictionTrajectory.from_lens_and_model(
145
+ lens=lens,
146
+ model=model,
147
+ input_ids=input_ids,
148
+ tokenizer=tokenizer,
149
+ targets=targets,
150
+ ).slice_sequence(slice(start, end))
151
+
152
+ fig = getattr(pred_traj, statistic)().stride(stride).figure(title=f"{lens_type} ({model.name_or_path}) {statistic}")
153
+
154
+ # Apply comprehensive colorbar fix
155
+ fig = fix_colorbar_properties(fig)
156
+
157
+ # Update layout with proper font settings
158
+ fig.update_layout(
159
+ font=dict(family="Mangal, Arial, sans-serif", size=12),
160
+ title_font=dict(family="Mangal, Arial, sans-serif", size=14),
161
+ )
162
+
163
+ filename = f"lens_plot_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png"
164
+ fig.write_image(filename, width=1200, height=800, scale=2, format="png")
165
+ return filename, fig
166
+
167
+ except Exception as e:
168
+ return None, f"Error: {e}"
169
+
170
+ # Gradio UI
171
+ def lens_ui(text, lens_type, statistic, stride, token_range_start, token_range_end):
172
+ filename, result = generate_plot(text, lens_type, statistic, stride, token_range_start, token_range_end)
173
+ if isinstance(result, str):
174
+ return result, None
175
+ return f"✅ Plot saved as {filename}", result
176
+
177
+ with gr.Blocks() as demo:
178
+ gr.Markdown("## 🔬 Tuned Lens + Logit Lens Visualizer for Sarvam-1")
179
+
180
+ with gr.Row():
181
+ with gr.Column():
182
+ text = gr.Textbox(label="Input Text", lines=8, value="Question: मूरस के अनुसार 'आदर्श उपयोगितावाद' के अनुसार, सही क्रिया वह है जो सबसे अधिक लाभ लाती है:\nOptions:\n1. सुख।\n2. खुशी।\n3. अच्छाई।\n4. गुणवत्ता।\n\nPlease select the correct option.")
183
+ lens_type = gr.Radio(["Tuned Lens", "Logit Lens"], value="Tuned Lens", label="Select Lens")
184
+ statistic = gr.Radio(["entropy", "cross_entropy", "forward_kl"], value="entropy", label="Statistic")
185
+
186
+ with gr.Column():
187
+ token_range_start = gr.Slider(0, 63, value=0, step=1, label="Token Range Start")
188
+ token_range_end = gr.Slider(0, 63, value=64, step=1, label="Token Range End")
189
+ stride = gr.Slider(1, 10, value=2, step=1, label="Layer Stride")
190
+
191
+ btn = gr.Button("🔍 Generate Plot")
192
+ output_text = gr.Textbox(label="Status")
193
+ plot_output = gr.Plot(label="Visualization")
194
+
195
+ btn.click(
196
+ fn=lens_ui,
197
+ inputs=[text, lens_type, statistic, stride, token_range_start, token_range_end],
198
+ outputs=[plot_output, output_text]
199
+ )
200
+
201
+ if __name__ == "__main__":
202
+ demo.launch()