# llm_steer / app.py
# Author: Mihai Chirculescu
# Last change: Revert "Add Qwen 0.8 3.5" (commit 5195c1d)
import threading
from dataclasses import dataclass
from typing import Any, Iterable
import gradio as gr
import spaces
import torch
from llm_steer import Steer, DecaySchedule
from transformers import AutoModelForCausalLM, AutoTokenizer
# Models served by this app, smallest to largest; all LiquidAI LFM2 checkpoints.
MODEL_IDS = ["LiquidAI/LFM2-350M", "LiquidAI/LFM2-700M", "LiquidAI/LFM2-1.2B"]
# Default parameters for llm_steer's DecaySchedule (see the schedule accordion).
DEFAULT_SCHEDULE_ENABLED = True
DEFAULT_SCHEDULE_RATE = 0.85
DEFAULT_SCHEDULE_MIN_MULTIPLIER = 0.35
DEFAULT_SCHEDULE_RESTARTS = 14
# Steering vectors are applied by default (checkbox in the UI).
DEFAULT_VECTORS_ENABLED = True
# One-click presets: each entry fills the model picker, vector table and
# schedule controls at once via _apply_preset.
PRESETS = [
    {
        "label": "Logical thinking",
        "model_id": MODEL_IDS[0],
        # Rows follow the vector table's columns: (text, coeff, layer).
        "vectors": [
            ["But wait", 0.4, 7],
            ["But wait", 0.4, 8],
            ["overthink", -0.4, 7],
            ["overthink", -0.4, 8],
        ],
        "schedule": {
            "enabled": DEFAULT_SCHEDULE_ENABLED,
            "rate": DEFAULT_SCHEDULE_RATE,
            "min_multiplier": DEFAULT_SCHEDULE_MIN_MULTIPLIER,
            "times_restart": DEFAULT_SCHEDULE_RESTARTS,
        },
    }
]
@dataclass
class ModelBundle:
    """Everything needed to serve one model: tokenizer, weights, steering hooks.

    The per-bundle ``lock`` serializes requests, because steering vectors are
    attached to the shared ``steer`` instance before each generation.
    """
    model_id: str  # Hugging Face hub id (one of MODEL_IDS)
    tokenizer: Any  # transformers tokenizer for this model
    base_model: Any  # the unwrapped AutoModelForCausalLM
    steer: Steer  # llm_steer wrapper; generation runs on steer.model
    lock: threading.Lock  # guards add/reset of steering vectors per request
class IsolatedSteer(Steer):
    """Steer subclass whose steering-vector registry starts empty.

    After the base constructor runs, ``self.steers`` is replaced with a fresh
    dict so this instance shares no steering state with any other Steer object
    (presumably the base class can otherwise reuse/share that mapping —
    NOTE(review): confirm against llm_steer internals).
    """
    def __init__(self, model, tokenizer, copyModel: bool):
        # Let the base class wire up the (optionally copied) model first,
        # then drop any steering vectors it may have registered.
        super().__init__(model, tokenizer, copyModel=copyModel)
        self.steers = {}
def _resolve_dtype() -> torch.dtype:
if torch.cuda.is_available():
if torch.cuda.is_bf16_supported():
return torch.bfloat16
return torch.float16
return torch.float32
def _load_model(model_id: str) -> ModelBundle:
    """Download tokenizer and weights for *model_id* and wrap them for steering.

    Returns a ModelBundle whose ``steer.model`` lives on the best available
    device and is ready for inference.
    """
    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tok = AutoTokenizer.from_pretrained(
        model_id,
        use_fast=True,
        trust_remote_code=True,
    )
    # Ensure a pad token exists and pad on the left for causal generation.
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token
    tok.padding_side = "left"

    lm = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=_resolve_dtype(),
        trust_remote_code=True,
    )
    lm.eval()
    # Mirror the tokenizer's pad token into the model config if unset.
    if lm.config.pad_token_id is None:
        lm.config.pad_token_id = tok.pad_token_id

    # copyModel=True — presumably makes Steer operate on its own copy of the
    # model rather than hooking the original in place; confirm in llm_steer.
    steered = IsolatedSteer(lm, tok, copyModel=True)
    steered.model.to(target_device)
    steered.device = target_device
    steered.model.eval()

    return ModelBundle(
        model_id=model_id,
        tokenizer=tok,
        base_model=lm,
        steer=steered,
        lock=threading.Lock(),
    )
# Eagerly load and steer-wrap every supported model at import time, keyed by id.
MODELS = {model_id: _load_model(model_id) for model_id in MODEL_IDS}
def _parse_vectors(
raw_vectors: Iterable[Iterable[Any]],
) -> list[tuple[str, float, int]]:
cleaned: list[tuple[str, float, int]] = []
if not raw_vectors:
return cleaned
for row in raw_vectors:
if not row or len(row) < 3:
continue
text = "" if row[0] is None else str(row[0]).strip()
if not text:
continue
try:
coeff = float(row[1])
except (TypeError, ValueError):
continue
if coeff == 0:
continue
try:
layer_idx = int(row[2])
except (TypeError, ValueError):
try:
layer_idx = int(float(row[2]))
except (TypeError, ValueError):
continue
if layer_idx < 0:
continue
cleaned.append((text, coeff, layer_idx))
return cleaned
def _format_plain_prompt(messages: list[dict[str, str]]) -> str:
lines = []
for message in messages:
role = message.get("role", "user")
content = message.get("content", "")
if role == "assistant":
prefix = "Assistant"
elif role == "system":
prefix = "System"
else:
prefix = "User"
lines.append(f"{prefix}: {content}")
lines.append("Assistant:")
return "\n".join(lines)
def _build_prompt(
history: list[dict[str, str]],
user_message: str,
tokenizer: Any,
) -> str:
messages: list[dict[str, str]] = []
if history:
messages.extend(history)
messages.append({"role": "user", "content": user_message})
if getattr(tokenizer, "chat_template", None):
return tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
return _format_plain_prompt(messages)
def _generate_reply(
    bundle: ModelBundle,
    prompt: str,
    max_new_tokens: int,
    temperature: float,
    min_p: float,
    repetition_penalty: float,
) -> str:
    """Run generation on the bundle's steered model and decode only new text.

    temperature > 0 enables sampling (with optional min-p filtering);
    temperature == 0 means deterministic greedy decoding.
    """
    tok = bundle.tokenizer
    model = bundle.steer.model
    encoded = tok(prompt, return_tensors="pt")
    ids = encoded["input_ids"].to(model.device)

    kwargs = {
        "input_ids": ids,
        "max_new_tokens": int(max_new_tokens),
        "pad_token_id": tok.pad_token_id,
        "eos_token_id": tok.eos_token_id,
    }
    mask = encoded.get("attention_mask")
    if mask is not None:
        kwargs["attention_mask"] = mask.to(model.device)
    # A penalty of exactly 1.0 is a no-op, so only pass it when it differs.
    if repetition_penalty and repetition_penalty != 1.0:
        kwargs["repetition_penalty"] = float(repetition_penalty)
    if temperature and temperature > 0:
        kwargs["do_sample"] = True
        kwargs["temperature"] = float(temperature)
        if min_p and min_p > 0:
            kwargs["min_p"] = float(min_p)
    else:
        kwargs["do_sample"] = False

    generated = model.generate(**kwargs)
    # Drop the prompt tokens; decode only what the model appended.
    fresh_tokens = generated[0][ids.shape[1]:]
    return tok.decode(fresh_tokens, skip_special_tokens=True)
@spaces.GPU(duration=20)
def respond(
    message: str,
    history: list[dict[str, str]],
    model_id: str,
    vectors: list[list[Any]],
    vectors_enabled: bool,
    schedule_enabled: bool,
    schedule_rate: float,
    schedule_min_multiplier: float,
    schedule_times_restart: float,
    max_new_tokens: int,
    temperature: float,
    min_p: float,
    repetition_penalty: float,
):
    """Handle one chat turn: attach steering vectors, generate, detach.

    Returns ``(updated_history, "")``; the empty string clears the input
    textbox. An empty or whitespace-only message is a no-op that returns
    the history unchanged.
    """
    history = history or []
    message = (message or "").strip()
    if not message:
        return history, ""
    bundle = MODELS[model_id]
    steer_vectors = _parse_vectors(vectors)
    schedule = None
    # The decay schedule only matters when vectors are actually applied.
    if vectors_enabled and schedule_enabled:
        schedule = DecaySchedule(
            rate=float(schedule_rate),
            min_multiplier=float(schedule_min_multiplier),
            times_restart=int(schedule_times_restart),
        )
    # The bundle's Steer instance is shared across requests, so attaching
    # vectors and generating must happen under the per-model lock.
    with bundle.lock:
        steer = bundle.steer
        steer.reset_all()  # start clean in case a prior request leaked state
        layers_used: set[int] = set()  # NOTE(review): collected but never read
        try:
            if vectors_enabled:
                for text, coeff, layer_idx in steer_vectors:
                    steer.add(
                        layer_idx=layer_idx,
                        coeff=coeff,
                        text=text,
                        coeff_schedule=schedule,
                    )
                    layers_used.add(layer_idx)
            prompt = _build_prompt(history, message, bundle.tokenizer)
            with torch.inference_mode():
                reply = _generate_reply(
                    bundle,
                    prompt,
                    max_new_tokens,
                    temperature,
                    min_p,
                    repetition_penalty,
                )
        finally:
            # Always detach vectors so later requests start unsteered,
            # even if generation raised.
            steer.reset_all()
    updated_history = history + [
        {"role": "user", "content": message},
        {"role": "assistant", "content": reply},
    ]
    return updated_history, ""
def _clear_chat():
return [], ""
def _add_vector(rows: list[list[Any]] | None):
rows = rows or []
rows.append(["", 0.6, 8])
return rows
def _apply_preset(preset: dict[str, Any]):
    """Translate a preset definition into the tuple of UI component values.

    Returns (model_id, vectors, schedule_enabled, rate, min_multiplier,
    times_restart), substituting module defaults for any schedule key the
    preset omits.
    """
    sched = preset.get("schedule", {})
    schedule_values = tuple(
        sched.get(key, fallback)
        for key, fallback in (
            ("enabled", DEFAULT_SCHEDULE_ENABLED),
            ("rate", DEFAULT_SCHEDULE_RATE),
            ("min_multiplier", DEFAULT_SCHEDULE_MIN_MULTIPLIER),
            ("times_restart", DEFAULT_SCHEDULE_RESTARTS),
        )
    )
    return (preset["model_id"], preset["vectors"], *schedule_values)
# Light, card-style Gradio theme used by the Blocks UI below.
THEME = gr.themes.Base(
    font=["Space Grotesk", "IBM Plex Sans", "sans-serif"],
    primary_hue="teal",
    secondary_hue="orange",
    neutral_hue="slate",
)
# Custom CSS: gradient page background, rounded bordered panels for the
# vector/schedule/preset sections, and non-clickable headers on the vector
# table (so Gradio's column-sort buttons are disabled).
CSS = """
:root {
--body-background-fill: linear-gradient(135deg, #f6f2ea 0%, #f1f7f4 60%, #eef3f9 100%);
--block-background-fill: rgba(255, 255, 255, 0.92);
--block-border-color: #d6d8db;
--block-shadow: 0 14px 30px rgba(15, 23, 42, 0.08);
}
#title {
letter-spacing: 0.03em;
}
#vector-table thead th button.header-button {
pointer-events: none;
cursor: default;
}
#vector-table thead th {
user-select: none;
}
#preset-panel {
margin-top: 0.75rem;
padding: 0.75rem;
border: 1px dashed #cbd5e1;
border-radius: 14px;
background: rgba(255, 255, 255, 0.65);
}
#vectors-panel,
#schedule-panel {
margin-top: 0.75rem;
border: 1px solid #e2e8f0;
border-radius: 14px;
background: rgba(255, 255, 255, 0.7);
}
.preset-btn button {
background: linear-gradient(140deg, #ffffff 0%, #fff1e4 100%);
border: 1px solid #e2e8f0;
font-weight: 600;
}
.preset-btn button:hover {
border-color: #0f766e;
color: #0f766e;
}
"""
with gr.Blocks(theme=THEME, css=CSS) as demo:
    gr.Markdown("# LLM Steer Playground", elem_id="title")
    gr.Markdown("Pick a model, add steering vectors, and chat with the steered model.")
    with gr.Row():
        # --- Left column: model picker, steering controls, presets ----------
        with gr.Column(scale=1, min_width=320):
            model_choice = gr.Radio(
                choices=MODEL_IDS,
                value=MODEL_IDS[2],
                label="Model",
            )
            with gr.Accordion(
                "Steering vectors",
                open=True,
                elem_id="vectors-panel",
            ):
                vectors_enabled = gr.Checkbox(
                    value=DEFAULT_VECTORS_ENABLED,
                    label="Enable steering vectors",
                )
                # One row per steering vector: (text, coeff, layer).
                # Rows are validated/cleaned later by _parse_vectors.
                vector_table = gr.Dataframe(
                    headers=["text", "coeff", "layer"],
                    row_count=(1, "dynamic"),
                    col_count=(3, "fixed"),
                    type="array",
                    datatype=["str", "number", "number"],
                    value=[
                        ["But wait", 0.4, 7],
                        ["But wait", 0.4, 8],
                        ["overthink", -0.4, 7],
                        ["overthink", -0.4, 8],
                    ],
                    label="Vectors",
                    elem_id="vector-table",
                    interactive=True,
                )
                add_vector = gr.Button("Add new vector row")
                gr.Markdown(
                    "Coeff can be negative. Layer should have value between 0 and 15."
                )
            with gr.Accordion(
                "Steering schedule",
                open=False,
                elem_id="schedule-panel",
            ):
                schedule_enabled = gr.Checkbox(
                    value=DEFAULT_SCHEDULE_ENABLED,
                    label="Enable DecaySchedule",
                )
                schedule_rate = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=DEFAULT_SCHEDULE_RATE,
                    step=0.01,
                    label="Decay rate",
                )
                schedule_min_multiplier = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=DEFAULT_SCHEDULE_MIN_MULTIPLIER,
                    step=0.01,
                    label="Min multiplier",
                )
                schedule_times_restart = gr.Slider(
                    minimum=0,
                    maximum=100,
                    value=DEFAULT_SCHEDULE_RESTARTS,
                    step=1,
                    label="Restarts",
                )
            # Preset buttons fill model/vectors/schedule controls in one click.
            with gr.Column(elem_id="preset-panel"):
                gr.Markdown("Presets")
                with gr.Row():
                    for preset in PRESETS[:2]:
                        button = gr.Button(
                            preset["label"],
                            elem_classes=["preset-btn"],
                        )
                        # p=preset binds the current preset at definition time
                        # (avoids the late-binding closure pitfall).
                        button.click(
                            fn=lambda p=preset: _apply_preset(p),
                            outputs=[
                                model_choice,
                                vector_table,
                                schedule_enabled,
                                schedule_rate,
                                schedule_min_multiplier,
                                schedule_times_restart,
                            ],
                            inputs=None,
                        )
                # Second row of presets (empty with the current single preset).
                with gr.Row():
                    for preset in PRESETS[2:]:
                        button = gr.Button(
                            preset["label"],
                            elem_classes=["preset-btn"],
                        )
                        button.click(
                            fn=lambda p=preset: _apply_preset(p),
                            outputs=[
                                model_choice,
                                vector_table,
                                schedule_enabled,
                                schedule_rate,
                                schedule_min_multiplier,
                                schedule_times_restart,
                            ],
                            inputs=None,
                        )
            with gr.Accordion("Generation options", open=False):
                max_new_tokens = gr.Slider(
                    minimum=512 / 2,
                    maximum=512 * 4,
                    value=1024,
                    step=8,
                    label="Max new tokens",
                )
                temperature = gr.Slider(
                    minimum=0.0,
                    maximum=1.5,
                    value=0.0,
                    step=0.05,
                    label="Temperature",
                )
                min_p = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.15,
                    step=0.01,
                    label="Min-p",
                )
                repetition_penalty = gr.Slider(
                    minimum=0.8,
                    maximum=2.0,
                    value=1.2,
                    step=0.05,
                    label="Repetition penalty",
                )
        # --- Right column: the chat itself -----------------------------------
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(
                label="Chat",
                type="messages",
                height=520,
                show_copy_button=True,
            )
            message = gr.Textbox(
                label="Message",
                placeholder="Ask something...",
                lines=2,
            )
            with gr.Row():
                send = gr.Button("Send", variant="primary")
                clear = gr.Button("Clear chat")
    # Event wiring: the Send button and Enter in the textbox both run respond().
    send.click(
        respond,
        inputs=[
            message,
            chatbot,
            model_choice,
            vector_table,
            vectors_enabled,
            schedule_enabled,
            schedule_rate,
            schedule_min_multiplier,
            schedule_times_restart,
            max_new_tokens,
            temperature,
            min_p,
            repetition_penalty,
        ],
        outputs=[chatbot, message],
    )
    message.submit(
        respond,
        inputs=[
            message,
            chatbot,
            model_choice,
            vector_table,
            vectors_enabled,
            schedule_enabled,
            schedule_rate,
            schedule_min_multiplier,
            schedule_times_restart,
            max_new_tokens,
            temperature,
            min_p,
            repetition_penalty,
        ],
        outputs=[chatbot, message],
    )
    clear.click(_clear_chat, outputs=[chatbot, message])
    add_vector.click(_add_vector, inputs=[vector_table], outputs=[vector_table])
if __name__ == "__main__":
    # Enable Gradio's request queue, then start the server.
    demo.queue()
    demo.launch()