| import threading |
| from dataclasses import dataclass |
| from typing import Any, Iterable |
|
|
| import gradio as gr |
| import spaces |
| import torch |
| from llm_steer import Steer, DecaySchedule |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
# Hugging Face repo ids selectable in the UI (ordered smallest to largest).
MODEL_IDS = ["LiquidAI/LFM2-350M", "LiquidAI/LFM2-700M", "LiquidAI/LFM2-1.2B"]
# Default values for the DecaySchedule controls in the "Steering schedule" panel.
DEFAULT_SCHEDULE_ENABLED = True
DEFAULT_SCHEDULE_RATE = 0.85
DEFAULT_SCHEDULE_MIN_MULTIPLIER = 0.35
DEFAULT_SCHEDULE_RESTARTS = 14
# Steering vectors are applied by default.
DEFAULT_VECTORS_ENABLED = True
# One-click presets: each entry names a model, a list of [text, coeff, layer]
# steering rows, and the schedule settings to apply.
PRESETS = [
    {
        "label": "Logical thinking",
        "model_id": MODEL_IDS[0],
        "vectors": [
            ["But wait", 0.4, 7],
            ["But wait", 0.4, 8],
            ["overthink", -0.4, 7],
            ["overthink", -0.4, 8],
        ],
        "schedule": {
            "enabled": DEFAULT_SCHEDULE_ENABLED,
            "rate": DEFAULT_SCHEDULE_RATE,
            "min_multiplier": DEFAULT_SCHEDULE_MIN_MULTIPLIER,
            "times_restart": DEFAULT_SCHEDULE_RESTARTS,
        },
    }
]
|
|
|
|
@dataclass
class ModelBundle:
    """Everything needed to serve one model: tokenizer, weights, steer wrapper, lock."""

    model_id: str  # Hugging Face repo id this bundle was loaded from
    tokenizer: Any  # transformers tokenizer for the model
    base_model: Any  # unsteered model as loaded from the hub
    steer: Steer  # steer wrapper whose (copied) model is used for generation
    lock: threading.Lock  # serializes steering + generation per model
|
|
|
|
class IsolatedSteer(Steer):
    """A ``Steer`` whose vector registry is private to each instance.

    The registry is re-created after the parent constructor runs, so steering
    state can never leak between instances.
    """

    def __init__(self, model, tokenizer, copyModel: bool):
        super().__init__(model, tokenizer, copyModel=copyModel)
        # Start from an empty, per-instance registry of applied vectors.
        self.steers = {}
|
|
|
|
| def _resolve_dtype() -> torch.dtype: |
| if torch.cuda.is_available(): |
| if torch.cuda.is_bf16_supported(): |
| return torch.bfloat16 |
| return torch.float16 |
| return torch.float32 |
|
|
|
|
def _load_model(model_id: str) -> ModelBundle:
    """Load tokenizer and weights for *model_id* and wrap them in a ModelBundle.

    The steer wrapper operates on a copy of the model so the pristine base
    model is preserved; only the copy is moved to the run device.
    """
    run_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tok = AutoTokenizer.from_pretrained(
        model_id,
        use_fast=True,
        trust_remote_code=True,
    )
    # Some checkpoints ship without a pad token; fall back to EOS.
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token
    tok.padding_side = "left"

    lm = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=_resolve_dtype(),
        trust_remote_code=True,
    )
    lm.eval()
    # Mirror the tokenizer's pad token into the model config if unset.
    if lm.config.pad_token_id is None:
        lm.config.pad_token_id = tok.pad_token_id

    # copyModel=True: steering hooks go onto a copy, not the base weights.
    steerer = IsolatedSteer(lm, tok, copyModel=True)
    steerer.model.to(run_device)
    steerer.device = run_device
    steerer.model.eval()

    return ModelBundle(
        model_id=model_id,
        tokenizer=tok,
        base_model=lm,
        steer=steerer,
        lock=threading.Lock(),
    )
|
|
|
|
# Eagerly load every model at import time so chat requests never block on a download.
MODELS = {model_id: _load_model(model_id) for model_id in MODEL_IDS}
|
|
|
|
| def _parse_vectors( |
| raw_vectors: Iterable[Iterable[Any]], |
| ) -> list[tuple[str, float, int]]: |
| cleaned: list[tuple[str, float, int]] = [] |
| if not raw_vectors: |
| return cleaned |
| for row in raw_vectors: |
| if not row or len(row) < 3: |
| continue |
| text = "" if row[0] is None else str(row[0]).strip() |
| if not text: |
| continue |
| try: |
| coeff = float(row[1]) |
| except (TypeError, ValueError): |
| continue |
| if coeff == 0: |
| continue |
| try: |
| layer_idx = int(row[2]) |
| except (TypeError, ValueError): |
| try: |
| layer_idx = int(float(row[2])) |
| except (TypeError, ValueError): |
| continue |
| if layer_idx < 0: |
| continue |
| cleaned.append((text, coeff, layer_idx)) |
| return cleaned |
|
|
|
|
| def _format_plain_prompt(messages: list[dict[str, str]]) -> str: |
| lines = [] |
| for message in messages: |
| role = message.get("role", "user") |
| content = message.get("content", "") |
| if role == "assistant": |
| prefix = "Assistant" |
| elif role == "system": |
| prefix = "System" |
| else: |
| prefix = "User" |
| lines.append(f"{prefix}: {content}") |
| lines.append("Assistant:") |
| return "\n".join(lines) |
|
|
|
|
| def _build_prompt( |
| history: list[dict[str, str]], |
| user_message: str, |
| tokenizer: Any, |
| ) -> str: |
| messages: list[dict[str, str]] = [] |
| if history: |
| messages.extend(history) |
| messages.append({"role": "user", "content": user_message}) |
|
|
| if getattr(tokenizer, "chat_template", None): |
| return tokenizer.apply_chat_template( |
| messages, |
| tokenize=False, |
| add_generation_prompt=True, |
| ) |
| return _format_plain_prompt(messages) |
|
|
|
|
def _generate_reply(
    bundle: ModelBundle,
    prompt: str,
    max_new_tokens: int,
    temperature: float,
    min_p: float,
    repetition_penalty: float,
) -> str:
    """Generate a completion on the steered model and decode only the new tokens.

    Temperature <= 0 selects greedy decoding; otherwise sampling is enabled
    with the given temperature and (when > 0) min-p.
    """
    tokenizer = bundle.tokenizer
    model = bundle.steer.model

    encoded = tokenizer(prompt, return_tensors="pt")
    input_ids = encoded["input_ids"].to(model.device)

    gen_kwargs = {
        "input_ids": input_ids,
        "max_new_tokens": int(max_new_tokens),
        "pad_token_id": tokenizer.pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
    }
    mask = encoded.get("attention_mask")
    if mask is not None:
        gen_kwargs["attention_mask"] = mask.to(model.device)

    # A penalty of exactly 1.0 is a no-op, so only pass non-default values.
    if repetition_penalty and repetition_penalty != 1.0:
        gen_kwargs["repetition_penalty"] = float(repetition_penalty)

    sampling = bool(temperature) and temperature > 0
    gen_kwargs["do_sample"] = sampling
    if sampling:
        gen_kwargs["temperature"] = float(temperature)
        if min_p and min_p > 0:
            gen_kwargs["min_p"] = float(min_p)

    output_ids = model.generate(**gen_kwargs)
    # Strip the prompt tokens; return only the freshly generated text.
    prompt_len = input_ids.shape[1]
    return tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True)
|
|
|
|
@spaces.GPU(duration=20)
def respond(
    message: str,
    history: list[dict[str, str]],
    model_id: str,
    vectors: list[list[Any]],
    vectors_enabled: bool,
    schedule_enabled: bool,
    schedule_rate: float,
    schedule_min_multiplier: float,
    schedule_times_restart: float,
    max_new_tokens: int,
    temperature: float,
    min_p: float,
    repetition_penalty: float,
):
    """Chat handler: apply the configured steering, generate, return new history.

    Returns ``(updated_history, "")`` — the empty string clears the message
    box. Steering vectors are added and removed inside the bundle's lock so
    concurrent requests never observe each other's steering state.
    """
    history = history or []
    message = (message or "").strip()
    if not message:
        # Empty submit: leave the chat untouched.
        return history, ""

    bundle = MODELS[model_id]
    steer_vectors = _parse_vectors(vectors)
    schedule = None
    if vectors_enabled and schedule_enabled:
        schedule = DecaySchedule(
            rate=float(schedule_rate),
            min_multiplier=float(schedule_min_multiplier),
            times_restart=int(schedule_times_restart),
        )

    with bundle.lock:
        steer = bundle.steer
        # Drop any vectors left over from a previous (possibly failed) request.
        steer.reset_all()
        try:
            if vectors_enabled:
                for text, coeff, layer_idx in steer_vectors:
                    steer.add(
                        layer_idx=layer_idx,
                        coeff=coeff,
                        text=text,
                        coeff_schedule=schedule,
                    )
            prompt = _build_prompt(history, message, bundle.tokenizer)
            with torch.inference_mode():
                reply = _generate_reply(
                    bundle,
                    prompt,
                    max_new_tokens,
                    temperature,
                    min_p,
                    repetition_penalty,
                )
        finally:
            # Always strip steering so the next request starts clean.
            steer.reset_all()

    updated_history = history + [
        {"role": "user", "content": message},
        {"role": "assistant", "content": reply},
    ]
    return updated_history, ""
|
|
|
|
| def _clear_chat(): |
| return [], "" |
|
|
|
|
| def _add_vector(rows: list[list[Any]] | None): |
| rows = rows or [] |
| rows.append(["", 0.6, 8]) |
| return rows |
|
|
|
|
def _apply_preset(preset: dict[str, Any]):
    """Expand a preset dict into the tuple of component values the UI expects.

    Order matches the ``outputs`` wiring: model id, vector rows, then the four
    schedule settings (falling back to the module defaults for missing keys).
    """
    sched = preset.get("schedule", {})
    schedule_defaults = (
        ("enabled", DEFAULT_SCHEDULE_ENABLED),
        ("rate", DEFAULT_SCHEDULE_RATE),
        ("min_multiplier", DEFAULT_SCHEDULE_MIN_MULTIPLIER),
        ("times_restart", DEFAULT_SCHEDULE_RESTARTS),
    )
    schedule_values = tuple(sched.get(key, fallback) for key, fallback in schedule_defaults)
    return (preset["model_id"], preset["vectors"]) + schedule_values
|
|
|
|
# Shared Gradio theme for the whole app.
THEME = gr.themes.Base(
    font=["Space Grotesk", "IBM Plex Sans", "sans-serif"],
    primary_hue="teal",
    secondary_hue="orange",
    neutral_hue="slate",
)


# Custom CSS injected into the Blocks app: page gradient, non-clickable
# dataframe headers, and styling for the preset/vector/schedule panels.
CSS = """
:root {
  --body-background-fill: linear-gradient(135deg, #f6f2ea 0%, #f1f7f4 60%, #eef3f9 100%);
  --block-background-fill: rgba(255, 255, 255, 0.92);
  --block-border-color: #d6d8db;
  --block-shadow: 0 14px 30px rgba(15, 23, 42, 0.08);
}
#title {
  letter-spacing: 0.03em;
}
#vector-table thead th button.header-button {
  pointer-events: none;
  cursor: default;
}
#vector-table thead th {
  user-select: none;
}
#preset-panel {
  margin-top: 0.75rem;
  padding: 0.75rem;
  border: 1px dashed #cbd5e1;
  border-radius: 14px;
  background: rgba(255, 255, 255, 0.65);
}
#vectors-panel,
#schedule-panel {
  margin-top: 0.75rem;
  border: 1px solid #e2e8f0;
  border-radius: 14px;
  background: rgba(255, 255, 255, 0.7);
}
.preset-btn button {
  background: linear-gradient(140deg, #ffffff 0%, #fff1e4 100%);
  border: 1px solid #e2e8f0;
  font-weight: 600;
}
.preset-btn button:hover {
  border-color: #0f766e;
  color: #0f766e;
}
"""
|
|
|
|
# Build the two-column UI: steering controls on the left, chat on the right.
with gr.Blocks(theme=THEME, css=CSS) as demo:
    gr.Markdown("# LLM Steer Playground", elem_id="title")
    gr.Markdown("Pick a model, add steering vectors, and chat with the steered model.")


    with gr.Row():
        with gr.Column(scale=1, min_width=320):
            model_choice = gr.Radio(
                choices=MODEL_IDS,
                value=MODEL_IDS[2],
                label="Model",
            )
            with gr.Accordion(
                "Steering vectors",
                open=True,
                elem_id="vectors-panel",
            ):
                vectors_enabled = gr.Checkbox(
                    value=DEFAULT_VECTORS_ENABLED,
                    label="Enable steering vectors",
                )
                # Editable [text, coeff, layer] rows; initial value mirrors
                # the "Logical thinking" preset.
                vector_table = gr.Dataframe(
                    headers=["text", "coeff", "layer"],
                    row_count=(1, "dynamic"),
                    col_count=(3, "fixed"),
                    type="array",
                    datatype=["str", "number", "number"],
                    value=[
                        ["But wait", 0.4, 7],
                        ["But wait", 0.4, 8],
                        ["overthink", -0.4, 7],
                        ["overthink", -0.4, 8],
                    ],
                    label="Vectors",
                    elem_id="vector-table",
                    interactive=True,
                )
                add_vector = gr.Button("Add new vector row")
                gr.Markdown(
                    "Coeff can be negative. Layer should have value between 0 and 15."
                )
            with gr.Accordion(
                "Steering schedule",
                open=False,
                elem_id="schedule-panel",
            ):
                schedule_enabled = gr.Checkbox(
                    value=DEFAULT_SCHEDULE_ENABLED,
                    label="Enable DecaySchedule",
                )
                schedule_rate = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=DEFAULT_SCHEDULE_RATE,
                    step=0.01,
                    label="Decay rate",
                )
                schedule_min_multiplier = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=DEFAULT_SCHEDULE_MIN_MULTIPLIER,
                    step=0.01,
                    label="Min multiplier",
                )
                schedule_times_restart = gr.Slider(
                    minimum=0,
                    maximum=100,
                    value=DEFAULT_SCHEDULE_RESTARTS,
                    step=1,
                    label="Restarts",
                )
            # Preset buttons, laid out two per row.
            with gr.Column(elem_id="preset-panel"):
                gr.Markdown("Presets")
                with gr.Row():
                    for preset in PRESETS[:2]:
                        button = gr.Button(
                            preset["label"],
                            elem_classes=["preset-btn"],
                        )
                        # p=preset binds the current preset as a default,
                        # avoiding the late-binding closure pitfall.
                        button.click(
                            fn=lambda p=preset: _apply_preset(p),
                            outputs=[
                                model_choice,
                                vector_table,
                                schedule_enabled,
                                schedule_rate,
                                schedule_min_multiplier,
                                schedule_times_restart,
                            ],
                            inputs=None,
                        )
                with gr.Row():
                    for preset in PRESETS[2:]:
                        button = gr.Button(
                            preset["label"],
                            elem_classes=["preset-btn"],
                        )
                        button.click(
                            fn=lambda p=preset: _apply_preset(p),
                            outputs=[
                                model_choice,
                                vector_table,
                                schedule_enabled,
                                schedule_rate,
                                schedule_min_multiplier,
                                schedule_times_restart,
                            ],
                            inputs=None,
                        )
            with gr.Accordion("Generation options", open=False):
                max_new_tokens = gr.Slider(
                    minimum=512 / 2,
                    maximum=512 * 4,
                    value=1024,
                    step=8,
                    label="Max new tokens",
                )
                temperature = gr.Slider(
                    minimum=0.0,
                    maximum=1.5,
                    value=0.0,
                    step=0.05,
                    label="Temperature",
                )
                min_p = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.15,
                    step=0.01,
                    label="Min-p",
                )
                repetition_penalty = gr.Slider(
                    minimum=0.8,
                    maximum=2.0,
                    value=1.2,
                    step=0.05,
                    label="Repetition penalty",
                )
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(
                label="Chat",
                type="messages",
                height=520,
                show_copy_button=True,
            )
            message = gr.Textbox(
                label="Message",
                placeholder="Ask something...",
                lines=2,
            )
            with gr.Row():
                send = gr.Button("Send", variant="primary")
                clear = gr.Button("Clear chat")


    # Wire both the Send button and textbox Enter-submit to the same handler.
    send.click(
        respond,
        inputs=[
            message,
            chatbot,
            model_choice,
            vector_table,
            vectors_enabled,
            schedule_enabled,
            schedule_rate,
            schedule_min_multiplier,
            schedule_times_restart,
            max_new_tokens,
            temperature,
            min_p,
            repetition_penalty,
        ],
        outputs=[chatbot, message],
    )
    message.submit(
        respond,
        inputs=[
            message,
            chatbot,
            model_choice,
            vector_table,
            vectors_enabled,
            schedule_enabled,
            schedule_rate,
            schedule_min_multiplier,
            schedule_times_restart,
            max_new_tokens,
            temperature,
            min_p,
            repetition_penalty,
        ],
        outputs=[chatbot, message],
    )
    clear.click(_clear_chat, outputs=[chatbot, message])
    add_vector.click(_add_vector, inputs=[vector_table], outputs=[vector_table])
|
|
|
|
# Entry point: enable request queuing, then serve the UI.
if __name__ == "__main__":
    demo.queue()
    demo.launch()
|
|