Spaces:
Sleeping
Sleeping
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from tuned_lens.nn.lenses import TunedLens, LogitLens | |
| from tuned_lens.nn.lenses import TunedLensConfig | |
| from tuned_lens.nn.unembed import Unembed | |
| from tuned_lens.plotting import PredictionTrajectory | |
| import json | |
| import os | |
| import gradio as gr | |
| import plotly.io as pio | |
| import plotly.graph_objects as go | |
| import datetime | |
def ensure_kaleido():
    """Ensure the `kaleido` package (Plotly's static-image backend) is importable.

    If it is missing, install it with pip into the current interpreter and
    import it again. No-op when kaleido is already installed.

    Raises:
        subprocess.CalledProcessError: if the pip install fails.
        ImportError: if kaleido still cannot be imported after installation.
    """
    try:
        import kaleido  # noqa: F401 -- only probing availability
    except ImportError:
        print("Installing kaleido...")
        import importlib
        import subprocess
        import sys

        # List form (shell=False): nothing here is attacker-controlled, but it
        # avoids shell-quoting surprises.
        subprocess.check_call([sys.executable, "-m", "pip", "install", "-U", "kaleido"])
        # A package installed after interpreter start may be invisible to the
        # import system until finder caches are refreshed.
        importlib.invalidate_caches()
        import kaleido  # noqa: F401
        print("Kaleido installed successfully!")
# Install the static-image export backend before any figure is written.
ensure_kaleido()
# Set a default Plotly font that covers Devanagari glyphs (Mangal), falling
# back to Arial/sans-serif, so Hindi text in plots renders correctly.
pio.templates.default = "plotly"
pio.templates["plotly"].layout.font = dict(family="Mangal, Arial, sans-serif", size=12)
# Load the Sarvam-1 causal LM and its tokenizer. Use fp16 on GPU to halve
# memory; fall back to fp32 on CPU (fp16 CPU inference is poorly supported).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoModelForCausalLM.from_pretrained("sarvamai/sarvam-1", torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32)
model = model.to(device)
tokenizer = AutoTokenizer.from_pretrained("sarvamai/sarvam-1")
# Fetch the pretrained tuned-lens weights by git-cloning a HuggingFace repo.
# NOTE(review): os.system return codes are ignored here, so a failed apt/git
# step is only discovered later when config.json is missing.
lens_folder_path = "tuned-lens"
os.system("apt-get -y install git-lfs")
os.system("git lfs install")
os.system(f"git clone https://huggingface.co/AnonymousAccountACL/{lens_folder_path}")
# The checkpoint lives in a nested model-specific subdirectory of the clone.
new_lens_folder_path = lens_folder_path + "/sarvam_lens/sarvamai/sarvam-1"
config_path = os.path.join(new_lens_folder_path, "config.json")
params_path = os.path.join(new_lens_folder_path, "params.pt")
with open(config_path, "r") as f:
    config_dict = json.load(f)
# Normalize the on-disk config into the field names TunedLensConfig expects,
# filling gaps from the loaded model's own config (note: the file stores the
# layer count under "num_layers", not "num_hidden_layers").
fixed_config = {
    "base_model_name_or_path": config_dict.get("base_model_name_or_path", "sarvamai/sarvam-1"),
    "d_model": config_dict.get("d_model", model.config.hidden_size),
    "num_hidden_layers": config_dict.get("num_layers", model.config.num_hidden_layers),
    "bias": config_dict.get("bias", True),
    "base_model_revision": config_dict.get("base_model_revision", None),
    "unembed_hash": config_dict.get("unembed_hash", None),
    "lens_type": config_dict.get("lens_type", "linear_tuned_lens"),
}
config = TunedLensConfig.from_dict(fixed_config)
# Build an (untrained) lens around the model's unembedding; trained translator
# weights are loaded into it below.
unembed = Unembed(model)
tuned_lens = TunedLens(unembed, config)
# Load the trained translator parameters and align them with the lens.
# NOTE(review): torch.load unpickles arbitrary objects; params.pt comes from a
# freshly cloned repo, so only run this against sources you trust.
state_dict = torch.load(params_path, map_location=device)

# Keep only per-layer translator weights, stripping the "layer_translators."
# prefix so keys match tuned_lens.layer_translators. Slicing by prefix length
# (instead of str.replace, which would remove the substring anywhere in the
# key) guarantees only the leading prefix is removed.
_PREFIX = "layer_translators."
layer_translator_state_dict = {
    key[len(_PREFIX):]: value
    for key, value in state_dict.items()
    if key.startswith(_PREFIX)
}

# The checkpoint may contain a different number of translators than the config
# claims; infer the real count from the key indices ("<layer_idx>.<param>").
layer_indices = {
    int(key.split(".", 1)[0])
    for key in layer_translator_state_dict
    if "." in key
}
actual_num_translators = max(layer_indices) + 1 if layer_indices else 0
if actual_num_translators != len(tuned_lens.layer_translators):
    # Rebuild the lens sized to what the checkpoint actually provides.
    fixed_config["num_hidden_layers"] = actual_num_translators
    config = TunedLensConfig.from_dict(fixed_config)
    tuned_lens = TunedLens(unembed, config)
tuned_lens.layer_translators.load_state_dict(layer_translator_state_dict)
# Patch TunedLens.forward so that a layer index beyond the trained translators
# (e.g. the final layer) falls back to the raw unembedding instead of raising.
# Capture the bound original forward BEFORE rebinding the attribute.
original_forward = tuned_lens.forward
def patched_forward(self, hidden_states, layer_idx):
    # In-range layers go through the trained translator; everything else is
    # decoded directly with the model's unembedding matrix.
    if layer_idx < len(self.layer_translators):
        return original_forward(hidden_states, layer_idx)
    else:
        return self.unembed(hidden_states)
# __get__ binds the function as a method on this single lens instance only.
tuned_lens.forward = patched_forward.__get__(tuned_lens, TunedLens)
# Build the baseline LogitLens (no trained translators) from the same model.
logit_lens = LogitLens.from_model(model)
def fix_colorbar_properties(fig):
    """Strip Plotly-version-incompatible colorbar properties from *fig* in place.

    Older code/serialized figures may carry colorbar settings (currently just
    'titleside') that newer Plotly versions reject. Walks every trace,
    including nested marker/line colorbars, and removes those properties.

    Args:
        fig: a Plotly figure (anything exposing a `.data` iterable of traces).

    Returns:
        The same figure object, for call chaining.
    """
    invalid_props = ('titleside',)

    def clean_colorbar(colorbar_obj):
        # Colorbars may be plain dicts or Plotly graph objects; handle both.
        if colorbar_obj is None:
            return
        if isinstance(colorbar_obj, dict):
            for prop in invalid_props:
                colorbar_obj.pop(prop, None)
        else:
            for prop in invalid_props:
                if hasattr(colorbar_obj, prop):
                    # Graph objects may refuse attribute deletion; ignore only
                    # that failure instead of a bare `except:` that would also
                    # swallow KeyboardInterrupt/SystemExit.
                    try:
                        delattr(colorbar_obj, prop)
                    except (AttributeError, TypeError, ValueError):
                        pass

    # A colorbar can live directly on the trace or nested under marker/line.
    for trace in fig.data:
        if hasattr(trace, 'colorbar') and trace.colorbar is not None:
            clean_colorbar(trace.colorbar)
        if hasattr(trace, 'marker') and hasattr(trace.marker, 'colorbar'):
            clean_colorbar(trace.marker.colorbar)
        if hasattr(trace, 'line') and hasattr(trace.line, 'colorbar'):
            clean_colorbar(trace.line.colorbar)
    return fig
# Plot generation with comprehensive colorbar fix
def generate_plot(text, lens_type, statistic, stride, start, end):
    """Render a prediction-trajectory heatmap for *text*.

    Returns (png_filename, plotly_figure) on success, or (None, error_message)
    if anything in the pipeline raises.
    """
    input_ids = tokenizer.encode(text, add_special_tokens=True)
    # Next-token targets: shift left by one position and close with EOS.
    targets = input_ids[1:] + [tokenizer.eos_token_id]
    if lens_type == "Tuned Lens":
        lens = tuned_lens
    else:
        lens = logit_lens
    try:
        trajectory = PredictionTrajectory.from_lens_and_model(
            lens=lens,
            model=model,
            input_ids=input_ids,
            tokenizer=tokenizer,
            targets=targets,
        )
        trajectory = trajectory.slice_sequence(slice(start, end))
        # `statistic` names a trajectory method (entropy, cross_entropy, ...).
        stat = getattr(trajectory, statistic)()
        fig = stat.stride(stride).figure(
            title=f"{lens_type} ({model.name_or_path}) {statistic}"
        )
        # Scrub colorbar properties the installed Plotly version rejects.
        fig = fix_colorbar_properties(fig)
        # Re-apply the multilingual-safe font to the finished figure.
        fig.update_layout(
            font=dict(family="Mangal, Arial, sans-serif", size=12),
            title_font=dict(family="Mangal, Arial, sans-serif", size=14),
        )
        timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
        filename = f"lens_plot_{timestamp}.png"
        fig.write_image(filename, width=1200, height=800, scale=2, format="png")
        return filename, fig
    except Exception as e:
        return None, f"Error: {e}"
# Gradio UI
def lens_ui(text, lens_type, statistic, stride, token_range_start, token_range_end):
    """Gradio click callback.

    Returns a (status_message, figure_or_None) pair for the two outputs.
    """
    filename, result = generate_plot(text, lens_type, statistic, stride, token_range_start, token_range_end)
    # generate_plot signals failure by returning the error text in `result`.
    if isinstance(result, str):
        return result, None
    # BUG FIX: the saved filename was never interpolated into the status line
    # (it read "Plot saved as (unknown)" while `filename` went unused).
    return f"✅ Plot saved as {filename}", result
# Assemble the Gradio interface: input controls on the left, token-range and
# stride controls on the right, status + plot below.
with gr.Blocks() as demo:
    gr.Markdown("## 🔬 Tuned Lens + Logit Lens Visualizer for Sarvam-1")
    with gr.Row():
        with gr.Column():
            text = gr.Textbox(label="Input Text", lines=8, value="Question: मूरस के अनुसार 'आदर्श उपयोगितावाद' के अनुसार, सही क्रिया वह है जो सबसे अधिक लाभ लाती है:\nOptions:\n1. सुख।\n2. खुशी।\n3. अच्छाई।\n4. गुणवत्ता।\n\nPlease select the correct option.")
            lens_type = gr.Radio(["Tuned Lens", "Logit Lens"], value="Tuned Lens", label="Select Lens")
            statistic = gr.Radio(["entropy", "cross_entropy", "forward_kl"], value="entropy", label="Statistic")
        with gr.Column():
            token_range_start = gr.Slider(0, 63, value=0, step=1, label="Token Range Start")
            # BUG FIX: the default value (64) exceeded the slider maximum (63);
            # raise the maximum to 64 so slice(start, 64) can span all tokens.
            token_range_end = gr.Slider(0, 64, value=64, step=1, label="Token Range End")
            stride = gr.Slider(1, 10, value=2, step=1, label="Layer Stride")
    btn = gr.Button("🔍 Generate Plot")
    output_text = gr.Textbox(label="Status")
    plot_output = gr.Plot(label="Visualization")
    btn.click(
        fn=lens_ui,
        inputs=[text, lens_type, statistic, stride, token_range_start, token_range_end],
        outputs=[output_text, plot_output]
    )

if __name__ == "__main__":
    demo.launch()