"""Gradio visualizer for Tuned Lens and Logit Lens trajectories on Sarvam-1.

Importing this module has heavy side effects: it may pip-install ``kaleido``,
tries to install git-lfs and clone the tuned-lens weights from the Hugging
Face Hub (only when the local ``tuned-lens`` folder is missing), loads the
``sarvamai/sarvam-1`` model and tokenizer, and builds the Gradio UI.
"""

import datetime
import json
import os
import subprocess
import sys

import gradio as gr
import plotly.graph_objects as go  # noqa: F401
import plotly.io as pio
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from tuned_lens.nn.lenses import LogitLens, TunedLens, TunedLensConfig
from tuned_lens.nn.unembed import Unembed
from tuned_lens.plotting import PredictionTrajectory

# Font stack used for all plot text (intended to cover Devanagari etc.).
FONT_FAMILY = "Mangal, Arial, sans-serif"
# Statistics offered in the UI. generate_plot rejects anything else so a
# crafted API request cannot call arbitrary PredictionTrajectory attributes.
ALLOWED_STATISTICS = ("entropy", "cross_entropy", "forward_kl")


def ensure_kaleido():
    """Ensure ``kaleido`` (Plotly's static-image backend) is importable.

    If the import fails, the package is installed with pip into the running
    interpreter and imported again. ``subprocess.CalledProcessError`` is
    raised if pip fails.
    """
    try:
        import kaleido  # noqa: F401
    except ImportError:
        print("Installing kaleido...")
        subprocess.check_call([sys.executable, "-m", "pip", "install", "-U", "kaleido"])
        import kaleido  # noqa: F401
        print("Kaleido installed successfully!")


def _run_best_effort(cmd):
    """Run ``cmd`` (an argument list) without raising on failure.

    Returns the exit code, or ``None`` if the executable could not be started
    at all (the error is printed).
    """
    try:
        return subprocess.run(cmd, check=False).returncode
    except OSError as err:
        print(f"Could not run {cmd[0]!r}: {err}")
        return None


# Call this at the start
ensure_kaleido()

# Set font for multilingual (Devanagari etc.)
pio.templates.default = "plotly"
pio.templates["plotly"].layout.font = dict(family=FONT_FAMILY, size=12)

# Load model and tokenizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoModelForCausalLM.from_pretrained(
    "sarvamai/sarvam-1",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
)
model = model.to(device)
tokenizer = AutoTokenizer.from_pretrained("sarvamai/sarvam-1")

# Load tuned lens from HF. A clone into an existing folder would fail anyway,
# so skip the setup commands when the folder is already there (e.g. restarts).
lens_folder_path = "tuned-lens"
if not os.path.isdir(lens_folder_path):
    _run_best_effort(["apt-get", "-y", "install", "git-lfs"])
    _run_best_effort(["git", "lfs", "install"])
    _run_best_effort(
        ["git", "clone", f"https://huggingface.co/AnonymousAccountACL/{lens_folder_path}"]
    )

new_lens_folder_path = lens_folder_path + "/sarvam_lens/sarvamai/sarvam-1"
config_path = os.path.join(new_lens_folder_path, "config.json")
params_path = os.path.join(new_lens_folder_path, "params.pt")

with open(config_path, "r", encoding="utf-8") as f:
    config_dict = json.load(f)

# Map the saved config onto TunedLensConfig's field names, falling back to
# the loaded model's own dimensions for anything missing.
fixed_config = {
    "base_model_name_or_path": config_dict.get("base_model_name_or_path", "sarvamai/sarvam-1"),
    "d_model": config_dict.get("d_model", model.config.hidden_size),
    "num_hidden_layers": config_dict.get("num_layers", model.config.num_hidden_layers),
    "bias": config_dict.get("bias", True),
    "base_model_revision": config_dict.get("base_model_revision", None),
    "unembed_hash": config_dict.get("unembed_hash", None),
    "lens_type": config_dict.get("lens_type", "linear_tuned_lens"),
}
config = TunedLensConfig.from_dict(fixed_config)
unembed = Unembed(model)
tuned_lens = TunedLens(unembed, config)

# Load and align parameters. The checkpoint comes from a remote repository,
# so restrict unpickling to tensors and plain containers.
state_dict = torch.load(params_path, map_location=device, weights_only=True)
_TRANSLATOR_PREFIX = "layer_translators."
layer_translator_state_dict = {
    key[len(_TRANSLATOR_PREFIX):]: value
    for key, value in state_dict.items()
    if key.startswith(_TRANSLATOR_PREFIX)
}

# Keys containing a dot are expected to start with an integer layer index;
# rebuild the lens if the checkpoint holds a different number of translators.
layer_indices = {
    int(key.split(".")[0]) for key in layer_translator_state_dict if "." in key
}
actual_num_translators = max(layer_indices) + 1 if layer_indices else 0
if actual_num_translators != len(tuned_lens.layer_translators):
    fixed_config["num_hidden_layers"] = actual_num_translators
    config = TunedLensConfig.from_dict(fixed_config)
    tuned_lens = TunedLens(unembed, config)
tuned_lens.layer_translators.load_state_dict(layer_translator_state_dict)

# Patch TunedLens forward: layer indices with no saved translator fall back
# to plain unembedding instead of going through the original forward.
original_forward = tuned_lens.forward


def patched_forward(self, hidden_states, layer_idx):
    """Use the tuned translator for ``layer_idx`` if one exists, else unembed directly."""
    if layer_idx < len(self.layer_translators):
        return original_forward(hidden_states, layer_idx)
    return self.unembed(hidden_states)


tuned_lens.forward = patched_forward.__get__(tuned_lens, TunedLens)

# Load LogitLens
logit_lens = LogitLens.from_model(model)


def fix_colorbar_properties(fig):
    """Strip colorbar properties treated as invalid (currently ``titleside``).

    Inspects the colorbar of every trace and of its ``marker`` / ``line``
    sub-objects. Dict colorbars have the keys popped; object colorbars get a
    best-effort ``delattr``. The figure is modified in place and returned.
    """
    invalid_props = ("titleside",)

    def clean_colorbar(colorbar_obj):
        if colorbar_obj is None:
            return
        if isinstance(colorbar_obj, dict):
            for prop in invalid_props:
                colorbar_obj.pop(prop, None)
            return
        for prop in invalid_props:
            if hasattr(colorbar_obj, prop):
                try:
                    delattr(colorbar_obj, prop)
                except (AttributeError, KeyError, TypeError, ValueError):
                    # Best-effort: some objects refuse attribute deletion.
                    pass

    for trace in fig.data:
        if hasattr(trace, "colorbar") and trace.colorbar is not None:
            clean_colorbar(trace.colorbar)
        # Some trace types nest their colorbar under marker or line.
        if hasattr(trace, "marker") and hasattr(trace.marker, "colorbar"):
            clean_colorbar(trace.marker.colorbar)
        if hasattr(trace, "line") and hasattr(trace.line, "colorbar"):
            clean_colorbar(trace.line.colorbar)
    return fig


def generate_plot(text, lens_type, statistic, stride, start, end):
    """Build a lens prediction-trajectory figure for ``text`` and save it as PNG.

    Args:
        text: Input text to tokenize (with special tokens) and run through the model.
        lens_type: ``"Tuned Lens"`` selects the tuned lens; anything else the logit lens.
        statistic: PredictionTrajectory statistic to plot; must be in ``ALLOWED_STATISTICS``.
        stride: Layer stride; any numeric value, converted to ``int``.
        start: First token position of the plotted slice (inclusive); converted to ``int``.
        end: End token position of the plotted slice (exclusive); converted to ``int``.

    Returns:
        ``(filename, fig)`` on success, ``(None, fig)`` when the figure was
        built but PNG export failed, or ``(None, "Error: ...")`` on failure.
    """
    try:
        # Gradio sliders pass floats, which slice()/stride() cannot use.
        start, end, stride = int(start), int(end), int(stride)
        if start >= end:
            return None, "Error: Token Range Start must be smaller than Token Range End"
        if statistic not in ALLOWED_STATISTICS:
            return None, f"Error: unsupported statistic {statistic!r}"

        input_ids = tokenizer.encode(text, add_special_tokens=True)
        # Next-token targets; the final position is scored against EOS.
        targets = input_ids[1:] + [tokenizer.eos_token_id]
        lens = tuned_lens if lens_type == "Tuned Lens" else logit_lens

        # Pure inference: no autograd graph is needed through model or lens.
        with torch.no_grad():
            pred_traj = PredictionTrajectory.from_lens_and_model(
                lens=lens,
                model=model,
                input_ids=input_ids,
                tokenizer=tokenizer,
                targets=targets,
            ).slice_sequence(slice(start, end))
            fig = getattr(pred_traj, statistic)().stride(stride).figure(
                title=f"{lens_type} ({model.name_or_path}) {statistic}"
            )

        # Apply comprehensive colorbar fix
        fig = fix_colorbar_properties(fig)
        # Update layout with proper font settings
        fig.update_layout(
            font=dict(family=FONT_FAMILY, size=12),
            title_font=dict(family=FONT_FAMILY, size=14),
        )
    except Exception as e:  # UI boundary: report to the user instead of crashing
        return None, f"Error: {e}"

    filename = f"lens_plot_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png"
    try:
        fig.write_image(filename, width=1200, height=800, scale=2, format="png")
    except Exception as e:
        # Static export can fail independently of plotting; keep the figure.
        print(f"Saving {filename} failed: {e}")
        return None, fig
    return filename, fig


def lens_ui(text, lens_type, statistic, stride, token_range_start, token_range_end):
    """Gradio callback: return ``(status_message, figure_or_None)``."""
    filename, result = generate_plot(
        text, lens_type, statistic, stride, token_range_start, token_range_end
    )
    if isinstance(result, str):
        return result, None
    if filename is None:
        return "⚠️ Plot generated, but saving the PNG failed (see server log)", result
    return f"✅ Plot saved as {filename}", result


with gr.Blocks() as demo:
    gr.Markdown("## 🔬 Tuned Lens + Logit Lens Visualizer for Sarvam-1")
    with gr.Row():
        with gr.Column():
            text = gr.Textbox(
                label="Input Text",
                lines=8,
                value="Question: मूरस के अनुसार 'आदर्श उपयोगितावाद' के अनुसार, सही क्रिया वह है जो सबसे अधिक लाभ लाती है:\nOptions:\n1. सुख।\n2. खुशी।\n3. अच्छाई।\n4. गुणवत्ता।\n\nPlease select the correct option.",
            )
            lens_type = gr.Radio(["Tuned Lens", "Logit Lens"], value="Tuned Lens", label="Select Lens")
            statistic = gr.Radio(list(ALLOWED_STATISTICS), value="entropy", label="Statistic")
        with gr.Column():
            token_range_start = gr.Slider(0, 63, value=0, step=1, label="Token Range Start")
            # End is exclusive; its range includes 64 so the default lies inside it.
            token_range_end = gr.Slider(0, 64, value=64, step=1, label="Token Range End")
            stride = gr.Slider(1, 10, value=2, step=1, label="Layer Stride")
            btn = gr.Button("🔍 Generate Plot")
    output_text = gr.Textbox(label="Status")
    plot_output = gr.Plot(label="Visualization")
    btn.click(
        fn=lens_ui,
        inputs=[text, lens_type, statistic, stride, token_range_start, token_range_end],
        outputs=[output_text, plot_output],
    )


if __name__ == "__main__":
    demo.launch()