IndicTunedLens / app.py
AnonymousAccountACL's picture
Update app.py
92fcb27 verified
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from tuned_lens.nn.lenses import TunedLens, LogitLens
from tuned_lens.nn.lenses import TunedLensConfig
from tuned_lens.nn.unembed import Unembed
from tuned_lens.plotting import PredictionTrajectory
import json
import os
import gradio as gr
import plotly.io as pio
import plotly.graph_objects as go
import datetime
def ensure_kaleido():
    """
    Import ``kaleido`` (needed by ``fig.write_image``), installing it with pip if missing.

    Returns:
        The imported ``kaleido`` module.

    Raises:
        subprocess.CalledProcessError: if the ``pip install`` command fails.
        ImportError: if the package still cannot be imported after installing.
    """
    import importlib

    try:
        import kaleido
        return kaleido
    except ImportError:
        pass

    import subprocess
    import sys

    print("Installing kaleido...")
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-U", "kaleido"])
    # The import system caches directory listings; refresh them so a package
    # installed while this interpreter is running can actually be found.
    importlib.invalidate_caches()
    import kaleido
    print("Kaleido installed successfully!")
    return kaleido
# Make sure static image export is available before any figure is written.
ensure_kaleido()

# Plotly defaults: a font stack that can render Devanagari (and Latin) text.
pio.templates.default = "plotly"
pio.templates["plotly"].layout.font = {"family": "Mangal, Arial, sans-serif", "size": 12}

# Base model and tokenizer. Half precision is used only when a GPU is present.
_SARVAM_MODEL_ID = "sarvamai/sarvam-1"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
_model_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model = AutoModelForCausalLM.from_pretrained(_SARVAM_MODEL_ID, torch_dtype=_model_dtype)
model = model.to(device)
tokenizer = AutoTokenizer.from_pretrained(_SARVAM_MODEL_ID)
# Load the tuned-lens checkpoint from the Hugging Face Hub (cloned via git + git-lfs).
lens_folder_path = "tuned-lens"
if not os.path.isdir(lens_folder_path):
    # Best-effort setup: these commands may fail (e.g. apt-get without root) and
    # their exit codes are ignored; a failed clone is reported just below.
    os.system("apt-get -y install git-lfs")
    os.system("git lfs install")
    os.system(f"git clone https://huggingface.co/AnonymousAccountACL/{lens_folder_path}")
new_lens_folder_path = os.path.join(lens_folder_path, "sarvam_lens", "sarvamai", "sarvam-1")
config_path = os.path.join(new_lens_folder_path, "config.json")
params_path = os.path.join(new_lens_folder_path, "params.pt")
if not os.path.isfile(config_path):
    raise FileNotFoundError(
        f"Tuned-lens config not found at {config_path!r}; cloning "
        f"https://huggingface.co/AnonymousAccountACL/{lens_folder_path} may have failed."
    )
with open(config_path, "r", encoding="utf-8") as f:
    config_dict = json.load(f)
# Keep only the fields passed to TunedLensConfig, falling back to the base
# model's own dimensions when the saved config omits them.
fixed_config = {
    "base_model_name_or_path": config_dict.get("base_model_name_or_path", "sarvamai/sarvam-1"),
    "d_model": config_dict.get("d_model", model.config.hidden_size),
    # The config field is "num_hidden_layers"; "num_layers" is accepted as a legacy alias.
    "num_hidden_layers": config_dict.get(
        "num_hidden_layers",
        config_dict.get("num_layers", model.config.num_hidden_layers),
    ),
    "bias": config_dict.get("bias", True),
    "base_model_revision": config_dict.get("base_model_revision", None),
    "unembed_hash": config_dict.get("unembed_hash", None),
    "lens_type": config_dict.get("lens_type", "linear_tuned_lens"),
}
config = TunedLensConfig.from_dict(fixed_config)
unembed = Unembed(model)
tuned_lens = TunedLens(unembed, config)
# Load the trained translator weights and align the lens with them.
# NOTE(review): torch.load unpickles the file; weights_only=True restricts it to
# tensors and plain containers, since the checkpoint comes from a downloaded repo.
state_dict = torch.load(params_path, map_location=device, weights_only=True)

# Translator weights may be saved with a "layer_translators." prefix (a full
# module state dict) or without it (e.g. "0.weight"); accept both layouts.
_translator_prefix = "layer_translators."
layer_translator_state_dict = {
    key[len(_translator_prefix):]: value
    for key, value in state_dict.items()
    if key.startswith(_translator_prefix)
}
if not layer_translator_state_dict:
    layer_translator_state_dict = {
        key: value
        for key, value in state_dict.items()
        if key.split(".")[0].isdigit()
    }

# Derive the translator count from the numeric layer prefix of each "<idx>.<param>" key.
layer_indices = {
    int(key.split(".")[0])
    for key in layer_translator_state_dict
    if "." in key and key.split(".")[0].isdigit()
}
actual_num_translators = max(layer_indices) + 1 if layer_indices else 0
if actual_num_translators != len(tuned_lens.layer_translators):
    # Rebuild the lens so its number of translators matches the checkpoint.
    fixed_config["num_hidden_layers"] = actual_num_translators
    config = TunedLensConfig.from_dict(fixed_config)
    tuned_lens = TunedLens(unembed, config)
tuned_lens.layer_translators.load_state_dict(layer_translator_state_dict)
# TunedLens(...) is constructed here without a device argument; move it next to
# the model so translator weights and the model's hidden states share a device.
tuned_lens = tuned_lens.to(device)
# Wrap this lens instance's forward so that layer indices without a trained
# translator are decoded directly with the lens's unembed instead of failing.
original_forward = tuned_lens.forward


def patched_forward(self, hidden_states, layer_idx):
    """Decode ``hidden_states`` at ``layer_idx``; layers past the last translator use ``self.unembed``."""
    has_translator = layer_idx < len(self.layer_translators)
    if not has_translator:
        return self.unembed(hidden_states)
    return original_forward(hidden_states, layer_idx)


# Bind as an instance attribute: only this tuned_lens object is affected.
tuned_lens.forward = patched_forward.__get__(tuned_lens, TunedLens)

# Logit lens built straight from the base model (no learned translators).
logit_lens = LogitLens.from_model(model)
def fix_colorbar_properties(fig, invalid_props=("titleside",)):
    """
    Remove unwanted colorbar properties from every trace of a Plotly figure.

    For each trace in ``fig.data`` this cleans ``trace.colorbar``,
    ``trace.marker.colorbar`` and ``trace.line.colorbar`` when present. A colorbar
    may be a plain ``dict`` (keys are popped) or an attribute-style object
    (attributes are deleted on a best-effort basis). The figure is modified in
    place.

    Args:
        fig: Figure-like object exposing an iterable ``data`` of traces.
        invalid_props: Iterable of property names to strip; defaults to ``("titleside",)``.

    Returns:
        The same ``fig`` object, for chaining.
    """

    def clean_colorbar(colorbar_obj):
        if colorbar_obj is None:
            return
        for prop in invalid_props:
            if isinstance(colorbar_obj, dict):
                colorbar_obj.pop(prop, None)
            elif hasattr(colorbar_obj, prop):
                try:
                    delattr(colorbar_obj, prop)
                except (AttributeError, TypeError, ValueError):
                    # Best-effort: some objects refuse attribute deletion
                    # (e.g. read-only properties); leave those untouched.
                    pass

    for trace in fig.data:
        if getattr(trace, "colorbar", None) is not None:
            clean_colorbar(trace.colorbar)
        # Colorbars can also hang off a trace's marker or line.
        if hasattr(trace, "marker") and hasattr(trace.marker, "colorbar"):
            clean_colorbar(trace.marker.colorbar)
        if hasattr(trace, "line") and hasattr(trace.line, "colorbar"):
            clean_colorbar(trace.line.colorbar)
    return fig
# Plot generation (colorbar clean-up is applied before export)
def generate_plot(text, lens_type, statistic, stride, start, end):
    """
    Decode every layer of the model with a lens and plot one trajectory statistic.

    Args:
        text: Prompt to analyse; must be non-empty.
        lens_type: ``"Tuned Lens"`` selects the tuned lens; any other value selects the logit lens.
        statistic: Name of a ``PredictionTrajectory`` method to call (e.g. ``"entropy"``).
        stride: Layer stride passed to the statistic's ``.stride()``; must be >= 1.
        start: Start of the token slice (``None`` allowed).
        end: Exclusive end of the token slice (``None`` allowed).

    Returns:
        ``(filename, fig)`` on success, where ``filename`` is the PNG written to the
        working directory; ``(None, "Error: <message>")`` on any failure.
    """
    import traceback

    try:
        if not text or not text.strip():
            raise ValueError("input text is empty")
        # Slider values are cast defensively (they may arrive as floats);
        # slice() and stride() need integers.
        stride = int(stride)
        start = None if start is None else int(start)
        end = None if end is None else int(end)
        if stride < 1:
            raise ValueError("layer stride must be at least 1")
        if start is not None and end is not None and 0 <= end <= start:
            raise ValueError(f"token range is empty (start={start}, end={end})")

        input_ids = tokenizer.encode(text, add_special_tokens=True)
        # Next-token targets: inputs shifted left by one, closed with EOS.
        targets = input_ids[1:] + [tokenizer.eos_token_id]
        lens = tuned_lens if lens_type == "Tuned Lens" else logit_lens

        pred_traj = PredictionTrajectory.from_lens_and_model(
            lens=lens,
            model=model,
            input_ids=input_ids,
            tokenizer=tokenizer,
            targets=targets,
        ).slice_sequence(slice(start, end))
        fig = (
            getattr(pred_traj, statistic)()
            .stride(stride)
            .figure(title=f"{lens_type} ({model.name_or_path}) {statistic}")
        )
        fig = fix_colorbar_properties(fig)
        # Fonts that can render Devanagari labels.
        fig.update_layout(
            font=dict(family="Mangal, Arial, sans-serif", size=12),
            title_font=dict(family="Mangal, Arial, sans-serif", size=14),
        )
        # Microsecond resolution keeps near-simultaneous requests from overwriting each other's PNG.
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        filename = f"lens_plot_{timestamp}.png"
        fig.write_image(filename, width=1200, height=800, scale=2, format="png")
        return filename, fig
    except Exception as e:  # UI boundary: report the failure instead of crashing the app
        traceback.print_exc()
        return None, f"Error: {e}"
# Gradio callback: adapt generate_plot's result to the (status text, plot) outputs.
def lens_ui(text, lens_type, statistic, stride, token_range_start, token_range_end):
    """Run ``generate_plot`` and return ``(status_message, figure_or_None)`` for the UI."""
    saved_path, outcome = generate_plot(
        text, lens_type, statistic, stride, token_range_start, token_range_end
    )
    if not isinstance(outcome, str):
        return f"✅ Plot saved as {saved_path}", outcome
    # generate_plot signals failure by returning an error string in place of the figure.
    return outcome, None
# Gradio layout: prompt and lens options on the left, token range / stride
# controls on the right, status and plot below, wired to lens_ui.
with gr.Blocks() as demo:
    gr.Markdown("## 🔬 Tuned Lens + Logit Lens Visualizer for Sarvam-1")
    with gr.Row():
        with gr.Column():
            text = gr.Textbox(label="Input Text", lines=8, value="Question: मूरस के अनुसार 'आदर्श उपयोगितावाद' के अनुसार, सही क्रिया वह है जो सबसे अधिक लाभ लाती है:\nOptions:\n1. सुख।\n2. खुशी।\n3. अच्छाई।\n4. गुणवत्ता।\n\nPlease select the correct option.")
            lens_type = gr.Radio(["Tuned Lens", "Logit Lens"], value="Tuned Lens", label="Select Lens")
            statistic = gr.Radio(["entropy", "cross_entropy", "forward_kl"], value="entropy", label="Statistic")
        with gr.Column():
            token_range_start = gr.Slider(0, 63, value=0, step=1, label="Token Range Start")
            # The end bound is exclusive (generate_plot slices start:end), so the
            # maximum must be 64 for the default value=64 to be in range.
            token_range_end = gr.Slider(0, 64, value=64, step=1, label="Token Range End")
            stride = gr.Slider(1, 10, value=2, step=1, label="Layer Stride")
            btn = gr.Button("🔍 Generate Plot")
    output_text = gr.Textbox(label="Status")
    plot_output = gr.Plot(label="Visualization")
    btn.click(
        fn=lens_ui,
        inputs=[text, lens_type, statistic, stride, token_range_start, token_range_end],
        outputs=[output_text, plot_output],
    )

if __name__ == "__main__":
    demo.launch()