# NOTE(review): the original file began with Hugging Face Spaces page chrome
# captured by a web scrape ("Spaces:", "Sleeping" status, "File size: 7,970
# Bytes", commit hashes a452a60/92fcb27/f55c43e, and a 1..202 line-number
# gutter). It was not Python and broke parsing; preserved here as a comment.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from tuned_lens.nn.lenses import TunedLens, LogitLens
from tuned_lens.nn.lenses import TunedLensConfig
from tuned_lens.nn.unembed import Unembed
from tuned_lens.plotting import PredictionTrajectory
import json
import os
import gradio as gr
import plotly.io as pio
import plotly.graph_objects as go
import datetime
def ensure_kaleido():
    """Make sure Plotly's ``kaleido`` static-image export backend is importable.

    Probes with a plain import; if the package is missing, installs it via
    ``pip`` into the currently running interpreter and re-imports to confirm
    the installation actually succeeded.
    """
    try:
        import kaleido  # noqa: F401 — import is only an availability probe
        return
    except ImportError:
        pass
    print("Installing kaleido...")
    import subprocess
    import sys
    # sys.executable targets the interpreter running this script, so the
    # package lands in the right environment (important inside virtualenvs).
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-U", "kaleido"])
    import kaleido  # noqa: F401 — raises if pip succeeded but import still fails
    print("Kaleido installed successfully!")
# Install the image-export backend up front so fig.write_image() works later.
ensure_kaleido()
# Default font stack for multilingual text: the demo's prompts are in Hindi,
# and "Mangal" carries Devanagari glyphs, with Arial/sans-serif as fallbacks.
pio.templates.default = "plotly"
pio.templates["plotly"].layout.font = dict(family="Mangal, Arial, sans-serif", size=12)
# Load model and tokenizer: fp16 on GPU to halve memory, fp32 on CPU where
# half precision is poorly supported.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoModelForCausalLM.from_pretrained("sarvamai/sarvam-1", torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32)
model = model.to(device)
tokenizer = AutoTokenizer.from_pretrained("sarvamai/sarvam-1")

# Fetch the pre-trained tuned-lens checkpoint by cloning its Hub repo; the
# weights are stored via git-lfs, hence the apt/git-lfs setup first.
# NOTE(review): os.system return codes are ignored — a failed clone only
# surfaces later when config.json fails to open. Assumes the Space container
# allows apt-get (root); confirm for other deployment targets.
lens_folder_path = "tuned-lens"
os.system("apt-get -y install git-lfs")
os.system("git lfs install")
os.system(f"git clone https://huggingface.co/AnonymousAccountACL/{lens_folder_path}")
# Checkpoint layout inside the cloned repo: <repo>/sarvam_lens/sarvamai/sarvam-1/
new_lens_folder_path = lens_folder_path + "/sarvam_lens/sarvamai/sarvam-1"
config_path = os.path.join(new_lens_folder_path, "config.json")
params_path = os.path.join(new_lens_folder_path, "params.pt")
with open(config_path, "r") as f:
    config_dict = json.load(f)
# Normalise the checkpoint's config into the field names TunedLensConfig
# expects, falling back to the loaded model's dimensions where the file omits
# a value. Note the key rename: the checkpoint stores "num_layers" while the
# config class wants "num_hidden_layers".
fixed_config = {
    "base_model_name_or_path": config_dict.get("base_model_name_or_path", "sarvamai/sarvam-1"),
    "d_model": config_dict.get("d_model", model.config.hidden_size),
    "num_hidden_layers": config_dict.get("num_layers", model.config.num_hidden_layers),
    "bias": config_dict.get("bias", True),
    "base_model_revision": config_dict.get("base_model_revision", None),
    "unembed_hash": config_dict.get("unembed_hash", None),
    "lens_type": config_dict.get("lens_type", "linear_tuned_lens"),
}
config = TunedLensConfig.from_dict(fixed_config)
unembed = Unembed(model)
tuned_lens = TunedLens(unembed, config)

# Load and align parameters: strip the "layer_translators." prefix from the
# checkpoint keys so they match the translator ModuleList's own state-dict
# layout (e.g. "layer_translators.3.weight" -> "3.weight").
state_dict = torch.load(params_path, map_location=device)
layer_translator_state_dict = {}
for key, value in state_dict.items():
    if key.startswith("layer_translators."):
        new_key = key.replace("layer_translators.", "")
        layer_translator_state_dict[new_key] = value

# Count how many translator layers the checkpoint actually contains; after
# the strip above, keys look like "<layer_idx>.weight" / "<layer_idx>.bias".
layer_indices = set()
for key in layer_translator_state_dict.keys():
    if "." in key:
        layer_idx = int(key.split(".")[0])
        layer_indices.add(layer_idx)
actual_num_translators = max(layer_indices) + 1 if layer_indices else 0

# If the checkpoint holds a different number of translators than the config
# implied, rebuild the lens with the checkpoint's count so load_state_dict
# below doesn't fail on a size mismatch.
if actual_num_translators != len(tuned_lens.layer_translators):
    fixed_config["num_hidden_layers"] = actual_num_translators
    config = TunedLensConfig.from_dict(fixed_config)
    tuned_lens = TunedLens(unembed, config)
tuned_lens.layer_translators.load_state_dict(layer_translator_state_dict)
# Patch TunedLens.forward so layer indices past the last trained translator
# fall back to the plain unembedding instead of indexing out of range.
original_forward = tuned_lens.forward  # bound method, captured by the closure below

def patched_forward(self, hidden_states, layer_idx):
    # In-range layers use the trained translator via the original forward;
    # out-of-range layers go straight through the unembedding (logit-lens
    # behaviour for the final hidden state).
    # NOTE(review): assumes TunedLens exposes an `unembed` attribute — it is
    # constructed with one above, but confirm the attribute name in the
    # tuned_lens package version in use.
    if layer_idx < len(self.layer_translators):
        return original_forward(hidden_states, layer_idx)
    else:
        return self.unembed(hidden_states)

# __get__ binds the plain function to this instance so `self` is supplied
# automatically on calls to tuned_lens.forward(...).
tuned_lens.forward = patched_forward.__get__(tuned_lens, TunedLens)

# The logit lens needs no trained weights; build it directly from the model.
logit_lens = LogitLens.from_model(model)
def fix_colorbar_properties(fig):
    """Strip Plotly-invalid colorbar properties from every trace in *fig*.

    The tuned-lens plotting code can emit ``colorbar.titleside``, which
    recent Plotly versions reject. This removes such properties wherever a
    colorbar can live: on the trace itself, or on its ``marker`` / ``line``
    sub-objects.

    Parameters:
        fig: a Plotly figure (anything exposing a ``data`` iterable of traces).

    Returns:
        The same figure object, mutated in place.
    """
    # Properties Plotly no longer accepts on a colorbar.
    invalid_props = ('titleside',)

    def _clean(colorbar_obj):
        # Remove invalid props from either a plain dict or a graph object.
        if colorbar_obj is None:
            return
        if isinstance(colorbar_obj, dict):
            for prop in invalid_props:
                colorbar_obj.pop(prop, None)
        else:
            for prop in invalid_props:
                if hasattr(colorbar_obj, prop):
                    # Plotly graph objects may refuse attribute deletion; a
                    # failed delete just leaves the property in place.
                    # Fix: was a bare `except:`, which also swallowed
                    # KeyboardInterrupt/SystemExit.
                    try:
                        delattr(colorbar_obj, prop)
                    except Exception:
                        pass

    for trace in fig.data:
        if getattr(trace, 'colorbar', None) is not None:
            _clean(trace.colorbar)
        # Some trace types keep their colorbar on marker/line sub-objects.
        if hasattr(trace, 'marker') and hasattr(trace.marker, 'colorbar'):
            _clean(trace.marker.colorbar)
        if hasattr(trace, 'line') and hasattr(trace.line, 'colorbar'):
            _clean(trace.line.colorbar)
    return fig
# Build the prediction-trajectory figure for one request.
def generate_plot(text, lens_type, statistic, stride, start, end):
    """Render a prediction-trajectory heatmap for *text* and save it as PNG.

    Returns ``(filename, figure)`` on success, or ``(None, "Error: ...")``
    when anything in the lens/plotting pipeline raises.
    """
    token_ids = tokenizer.encode(text, add_special_tokens=True)
    # Next-token targets: shift left by one and close with EOS.
    shifted_targets = token_ids[1:] + [tokenizer.eos_token_id]
    chosen_lens = tuned_lens if lens_type == "Tuned Lens" else logit_lens
    try:
        trajectory = PredictionTrajectory.from_lens_and_model(
            lens=chosen_lens,
            model=model,
            input_ids=token_ids,
            tokenizer=tokenizer,
            targets=shifted_targets,
        ).slice_sequence(slice(start, end))
        # `statistic` names a PredictionTrajectory method (entropy, ...).
        stat_method = getattr(trajectory, statistic)
        fig = stat_method().stride(stride).figure(
            title=f"{lens_type} ({model.name_or_path}) {statistic}"
        )
        # Drop colorbar properties newer Plotly versions reject.
        fig = fix_colorbar_properties(fig)
        # Re-apply the Devanagari-capable font to this specific figure.
        fig.update_layout(
            font=dict(family="Mangal, Arial, sans-serif", size=12),
            title_font=dict(family="Mangal, Arial, sans-serif", size=14),
        )
        stamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
        filename = f"lens_plot_{stamp}.png"
        fig.write_image(filename, width=1200, height=800, scale=2, format="png")
        return filename, fig
    except Exception as e:
        # Surface the failure as a status string for the UI layer.
        return None, f"Error: {e}"
# Gradio UI callback.
def lens_ui(text, lens_type, statistic, stride, token_range_start, token_range_end):
    """Run generate_plot and adapt its result to the Gradio outputs.

    Returns ``(status_message, figure_or_None)`` matching the
    ``[output_text, plot_output]`` outputs wired up in the Blocks below.
    """
    filename, result = generate_plot(text, lens_type, statistic, stride, token_range_start, token_range_end)
    # generate_plot signals failure with (None, "Error: ...").
    if isinstance(result, str):
        return result, None
    # Bug fix: the status message previously contained the literal text
    # "(unknown)" instead of interpolating the saved file's name.
    return f"✅ Plot saved as {filename}", result
# Gradio front-end: prompt + lens controls on the left, range/stride sliders
# on the right, status text and the rendered plot below.
with gr.Blocks() as demo:
    gr.Markdown("## 🔬 Tuned Lens + Logit Lens Visualizer for Sarvam-1")
    with gr.Row():
        with gr.Column():
            text = gr.Textbox(label="Input Text", lines=8, value="Question: मूरस के अनुसार 'आदर्श उपयोगितावाद' के अनुसार, सही क्रिया वह है जो सबसे अधिक लाभ लाती है:\nOptions:\n1. सुख।\n2. खुशी।\n3. अच्छाई।\n4. गुणवत्ता।\n\nPlease select the correct option.")
            lens_type = gr.Radio(["Tuned Lens", "Logit Lens"], value="Tuned Lens", label="Select Lens")
            statistic = gr.Radio(["entropy", "cross_entropy", "forward_kl"], value="entropy", label="Statistic")
        with gr.Column():
            token_range_start = gr.Slider(0, 63, value=0, step=1, label="Token Range Start")
            # Bug fix: the default value (64) exceeded the slider's previous
            # maximum (63); the maximum is raised to 64 so the full sequence
            # slice [start, 64) stays selectable.
            token_range_end = gr.Slider(0, 64, value=64, step=1, label="Token Range End")
            stride = gr.Slider(1, 10, value=2, step=1, label="Layer Stride")
    btn = gr.Button("🔍 Generate Plot")
    output_text = gr.Textbox(label="Status")
    plot_output = gr.Plot(label="Visualization")
    btn.click(
        fn=lens_ui,
        inputs=[text, lens_type, statistic, stride, token_range_start, token_range_end],
        outputs=[output_text, plot_output]
    )

if __name__ == "__main__":
    # Bug fix: the original line ended with a stray " |" scrape artifact,
    # which is a syntax error.
    demo.launch()