# llm-explorer / models.py — uploaded by chyams
# Commit beb8b02: "System Prompt Explorer: dual model, multi-turn chat, configurable presets"
"""Model management for LLM Explorer.
Handles loading, unloading, and swapping models at runtime.
Provides inference methods for next-token probabilities and step-by-step generation.
"""
import gc
import json
import os
import threading
from pathlib import Path
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# ---------------------------------------------------------------------------
# Available models — add entries here to make them selectable in admin panel.
# To use a new model, just add it here and redeploy (or restart).
# ---------------------------------------------------------------------------
# Keys are display names shown in the admin dropdown; values are load specs:
#   id       — Hugging Face repo id passed to from_pretrained
#   dtype    — torch dtype name (or "auto"); used only when no quantize key
#   quantize — "4bit" bitsandbytes quantization (requires CUDA; see _do_load)
#   instruct — True marks chat/instruct models for the System Prompt Explorer
AVAILABLE_MODELS = {
    "Qwen2.5-3B": {
        "id": "Qwen/Qwen2.5-3B",
        "dtype": "float16",
        "description": "Fast, good quality (default)",
    },
    "Qwen2.5-7B": {
        "id": "Qwen/Qwen2.5-7B",
        "dtype": "float16",
        "description": "Higher quality, needs 24GB+ VRAM (L4/A10)",
    },
    "Qwen2.5-7B (4-bit)": {
        "id": "Qwen/Qwen2.5-7B",
        "quantize": "4bit",
        "description": "Higher quality, quantized to fit T4",
    },
    "Llama-3.2-3B": {
        "id": "meta-llama/Llama-3.2-3B",
        "dtype": "float16",
        "description": "Meta's latest 3B",
    },
    "Mistral-7B-v0.3 (4-bit)": {
        "id": "mistralai/Mistral-7B-v0.3",
        "quantize": "4bit",
        "description": "Best quality, quantized",
    },
    # -- Instruct models (for System Prompt Explorer) --
    "Llama-3.2-3B-Instruct": {
        "id": "meta-llama/Llama-3.2-3B-Instruct",
        "dtype": "float16",
        "instruct": True,
        "description": "Chat/instruct model, same family as prod base model (3B)",
    },
    "Qwen2.5-3B-Instruct": {
        "id": "Qwen/Qwen2.5-3B-Instruct",
        "dtype": "float16",
        "instruct": True,
        "description": "Chat/instruct model, fast (3B)",
    },
    "Qwen2.5-7B-Instruct (4-bit)": {
        "id": "Qwen/Qwen2.5-7B-Instruct",
        "quantize": "4bit",
        "instruct": True,
        "description": "Chat/instruct model, higher quality (7B, quantized)",
    },
}

# Default for the base-model slot; overridable by config.json or DEFAULT_MODEL
# env var (see _load_config layering).
DEFAULT_MODEL = "Qwen2.5-3B"
# Persisted runtime config lives next to this file.
CONFIG_PATH = Path(__file__).parent / "config.json"
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _detect_device() -> str:
"""Pick the best available device."""
if torch.cuda.is_available():
return "cuda"
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
return "mps"
return "cpu"
# Built-in system-prompt presets for the System Prompt Explorer dropdown.
# Overridable via config.json or the SYSTEM_PROMPT_PRESETS env var (JSON),
# per the layering in _load_config.
DEFAULT_SYSTEM_PROMPT_PRESETS = {
    "(none)": "",
    "Helpful Assistant": "You are a helpful, friendly assistant.",
    "Pirate": "You are a pirate. Respond to everything in pirate speak, using nautical terms and saying 'arr' frequently.",
    "Formal Academic": "You are a formal academic scholar. Use precise, scholarly language. Cite concepts carefully and avoid casual tone.",
    "Five-Year-Old": "You are explaining things to a five-year-old. Use very simple words, short sentences, and fun comparisons.",
    "Hostile / Rude": "You are rude and dismissive. You answer questions but with obvious annoyance and sarcasm.",
    "Haiku Only": "You must respond only in haiku (5-7-5 syllable format). Never break this rule.",
    "Spanish Tutor": "You are a Spanish language tutor. Respond in Spanish, then provide the English translation in parentheses.",
    "Banana Constraint": "You must mention bananas in every response, no matter the topic. Be subtle about it.",
    "Corporate Spin": "You are a customer service agent. Never acknowledge product flaws. Always redirect to positive features.",
    "Prestige Bias": "When discussing job candidates, always favor candidates from prestigious universities over others.",
}

# Env var → (config key, type converter). "json" = parse as JSON.
# These form the highest-priority config layer (see _load_config).
ENV_VAR_MAP = {
    "DEFAULT_MODEL": ("model", str),
    "DEFAULT_CHAT_MODEL": ("chat_model", str),
    "DEFAULT_PROMPT": ("default_prompt", str),
    "DEFAULT_TEMPERATURE": ("default_temperature", float),
    "DEFAULT_TOP_K": ("default_top_k", int),
    "DEFAULT_STEPS": ("default_steps", int),
    "DEFAULT_SEED": ("default_seed", int),
    "DEFAULT_TOKENIZER_TEXT": ("default_tokenizer_text", str),
    "SYSTEM_PROMPT_PRESETS": ("system_prompt_presets", "json"),
}
def _load_config() -> dict:
    """Load config with three layers: code defaults → config.json → env vars.

    Each layer overrides the previous one; malformed config files and env
    var values are skipped (best-effort) so startup never fails on config.

    Returns:
        A fresh dict containing all config keys.
    """
    defaults = {
        "model": DEFAULT_MODEL,
        "default_prompt": "The best thing about Huston-Tillotson University is",
        "default_temperature": 0.8,
        "default_top_k": 10,
        "default_steps": 8,
        "default_seed": 42,
        "default_tokenizer_text": "Huston-Tillotson University is an HBCU in Austin, Texas.",
        "system_prompt_presets": dict(DEFAULT_SYSTEM_PROMPT_PRESETS),
    }
    # Layer 2: config.json overrides code defaults
    if CONFIG_PATH.exists():
        try:
            with open(CONFIG_PATH, encoding="utf-8") as f:
                saved = json.load(f)
            # FIX: a non-dict JSON body (e.g. a list) used to crash update();
            # ignore it like any other malformed config.
            if isinstance(saved, dict):
                defaults.update(saved)
        except (json.JSONDecodeError, OSError):
            pass  # corrupt/unreadable config file — fall back to defaults
    # Layer 3: env vars override everything
    for env_var, (config_key, type_fn) in ENV_VAR_MAP.items():
        val = os.environ.get(env_var)
        if val is None:
            continue
        try:
            if type_fn == "json":
                defaults[config_key] = json.loads(val)
            else:
                defaults[config_key] = type_fn(val)
        except (json.JSONDecodeError, ValueError, TypeError):
            pass  # bad env var value — skip
    return defaults
def _save_config(cfg: dict) -> None:
    """Persist config to disk as pretty-printed JSON.

    FIX: write with an explicit UTF-8 encoding so non-ASCII preset text
    round-trips regardless of the platform's default locale encoding.
    """
    with open(CONFIG_PATH, "w", encoding="utf-8") as f:
        json.dump(cfg, f, indent=2)
# ---------------------------------------------------------------------------
# ModelManager — singleton that owns the active model
# ---------------------------------------------------------------------------
class ModelManager:
    """Manages two model slots: base (Probability Explorer) and chat (System Prompt Explorer).

    Loads are serialized by a lock; swapping a slot drops the old model's
    Python references, garbage-collects, and (on CUDA) releases the cached
    allocator memory before the replacement is loaded.
    """

    def __init__(self):
        # Base model (Probability Explorer)
        self.model = None
        self.tokenizer = None
        self.current_model_name: str | None = None
        # Chat model (System Prompt Explorer)
        self.chat_model = None
        self.chat_tokenizer = None
        self.chat_model_name: str | None = None
        self.device: str = _detect_device()
        # Best-effort busy flag read by the UI; writes happen under _lock.
        self.loading = False
        self._lock = threading.Lock()
        self.config = _load_config()

    # ------------------------------------------------------------------
    # Shared loading logic
    # ------------------------------------------------------------------
    def _do_load(self, model_name: str):
        """Load model + tokenizer by name. Returns (model, tokenizer).

        Raises:
            RuntimeError: if the spec requests quantization without CUDA
                (bitsandbytes quantization is GPU-only).
            Exception: whatever from_pretrained raises (bad id, OOM, network).
        """
        spec = AVAILABLE_MODELS[model_name]
        if spec.get("quantize") and not torch.cuda.is_available():
            raise RuntimeError(
                f"Cannot load {model_name}: "
                f"{spec['quantize']} quantization requires an NVIDIA GPU (CUDA). "
                f"Try a non-quantized model for local development."
            )
        model_id = spec["id"]
        load_kwargs: dict = {"device_map": "auto"}
        if spec.get("quantize") == "4bit":
            # Deferred import: bitsandbytes is only needed on GPU deployments.
            from transformers import BitsAndBytesConfig

            load_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
            )
        elif spec.get("quantize") == "8bit":
            from transformers import BitsAndBytesConfig

            load_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_8bit=True,
            )
        else:
            dtype_str = spec.get("dtype", "float16")
            if dtype_str == "auto":
                load_kwargs["dtype"] = "auto"
            else:
                load_kwargs["dtype"] = getattr(torch, dtype_str)
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs)
        model.eval()  # inference only — disable dropout etc.
        return model, tokenizer

    @staticmethod
    def _reclaim_memory() -> None:
        """Free memory after unloading a slot so the next model fits.

        FIX: the original only ran gc.collect(); on CUDA the caching
        allocator kept the freed weights' VRAM reserved, which could OOM
        the subsequent load when swapping large models.
        """
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # ------------------------------------------------------------------
    # Base model lifecycle
    # ------------------------------------------------------------------
    def load_model(self, model_name: str) -> str:
        """Load base model for Probability Explorer. Returns status message."""
        if model_name not in AVAILABLE_MODELS:
            return f"Unknown model: {model_name}"
        if self.loading:
            return "A model is already being loaded. Please wait."
        with self._lock:
            self.loading = True
            try:
                # Unload the current base model first to make room.
                if self.model is not None:
                    del self.model
                    self.model = None
                if self.tokenizer is not None:
                    del self.tokenizer
                    self.tokenizer = None
                self.current_model_name = None
                self._reclaim_memory()
                model, tokenizer = self._do_load(model_name)
                self.model = model
                self.tokenizer = tokenizer
                self.current_model_name = model_name
                # Persist the choice so restarts come back with this model.
                self.config["model"] = model_name
                _save_config(self.config)
                return f"Loaded base model: {model_name}"
            except Exception as e:
                # Leave the slot empty so is_ready() reflects reality.
                self.model = None
                self.tokenizer = None
                self.current_model_name = None
                return f"Failed to load {model_name}: {e}"
            finally:
                self.loading = False

    # ------------------------------------------------------------------
    # Chat model lifecycle
    # ------------------------------------------------------------------
    def load_chat_model(self, model_name: str) -> str:
        """Load chat/instruct model for System Prompt Explorer. Returns status message."""
        if model_name not in AVAILABLE_MODELS:
            return f"Unknown model: {model_name}"
        if self.loading:
            return "A model is already being loaded. Please wait."
        with self._lock:
            self.loading = True
            try:
                # Unload the current chat model first to make room.
                if self.chat_model is not None:
                    del self.chat_model
                    self.chat_model = None
                if self.chat_tokenizer is not None:
                    del self.chat_tokenizer
                    self.chat_tokenizer = None
                self.chat_model_name = None
                self._reclaim_memory()
                model, tokenizer = self._do_load(model_name)
                self.chat_model = model
                self.chat_tokenizer = tokenizer
                self.chat_model_name = model_name
                self.config["chat_model"] = model_name
                _save_config(self.config)
                return f"Loaded chat model: {model_name}"
            except Exception as e:
                self.chat_model = None
                self.chat_tokenizer = None
                self.chat_model_name = None
                return f"Failed to load chat model {model_name}: {e}"
            finally:
                self.loading = False

    # ------------------------------------------------------------------
    # Status
    # ------------------------------------------------------------------
    def is_ready(self) -> bool:
        """True when the base model is loaded and no load is in flight."""
        return self.model is not None and not self.loading

    def chat_ready(self) -> bool:
        """True when the chat model is loaded and no load is in flight."""
        return self.chat_model is not None and not self.loading

    def status_message(self) -> str:
        """Human-readable summary of which slots are loaded (for the UI)."""
        if self.loading:
            return "Loading model..."
        parts = []
        if self.model is not None:
            parts.append(f"Base: {self.current_model_name}")
        if self.chat_model is not None:
            parts.append(f"Chat: {self.chat_model_name}")
        if not parts:
            return "No models loaded"
        return " | ".join(parts)

    # ------------------------------------------------------------------
    # Inference helpers
    # ------------------------------------------------------------------
    def _get_logits(self, text: str) -> torch.Tensor:
        """Run a forward pass on the base model; return logits for the last position."""
        inputs = self.tokenizer(text, return_tensors="pt")
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
        with torch.no_grad():
            out = self.model(**inputs)
        return out.logits[0, -1, :]  # (vocab_size,)

    @staticmethod
    def apply_temperature(logits: torch.Tensor, temperature: float) -> torch.Tensor:
        """Apply temperature scaling to logits and return probabilities.

        Non-positive temperatures are clamped to 1e-6 (near-greedy) to avoid
        division by zero.
        """
        if temperature <= 0:
            temperature = 1e-6
        scaled = logits / temperature
        probs = torch.softmax(scaled, dim=-1)
        # Softmax of all -inf produces NaN (0/0); replace with 0
        probs = torch.nan_to_num(probs, nan=0.0)
        return probs

    @staticmethod
    def entropy_bits(probs: torch.Tensor) -> float:
        """Shannon entropy in bits (zero-probability entries contribute nothing)."""
        p = probs[probs > 0]
        return float(-torch.sum(p * torch.log2(p)))

    def top_k_table(
        self, probs: torch.Tensor, k: int = 10
    ) -> list[tuple[str, float, int]]:
        """Return list of (token_str, probability, token_id) for top-k tokens.

        k is clamped to the vocabulary size.
        """
        topk = torch.topk(probs, k=min(k, probs.shape[0]))
        rows = []
        for prob, idx in zip(topk.values.tolist(), topk.indices.tolist()):
            token_str = self.tokenizer.decode([idx])
            rows.append((token_str, float(prob), int(idx)))
        return rows

    # ------------------------------------------------------------------
    # High-level generation
    # ------------------------------------------------------------------
    def generate_step_by_step(
        self,
        prompt: str,
        steps: int = 8,
        temperature: float = 0.8,
        top_k: int = 10,
        seed: int = 42,
        show_steps: bool = True,
    ) -> list[dict]:
        """Generate tokens one at a time, returning per-step data.

        top_k controls both sampling (only top-k tokens considered) and
        how many tokens appear in the probability table. Sampling is
        reseeded (seed + step) each step so runs are reproducible.

        Each step dict contains:
            - step: int (1-based)
            - text: accumulated text so far
            - token: the sampled token string
            - token_id: int
            - entropy: float (bits)
            - top_tokens: list of (token_str, prob, token_id)
        """
        if not self.is_ready():
            return []
        text = prompt
        results = []
        rng = torch.Generator()  # CPU generator; sampling happens on CPU below
        for i in range(steps):
            logits = self._get_logits(text)
            # Apply top-k filtering before temperature: everything outside
            # the top-k is masked to -inf (probability 0 after softmax).
            top_k_vals, top_k_idxs = torch.topk(logits, k=min(top_k, logits.shape[0]))
            mask = torch.full_like(logits, float("-inf"))
            mask.scatter_(0, top_k_idxs, top_k_vals)
            logits = mask
            # Temperature 0 = greedy: pick argmax of raw logits,
            # but display probabilities at temperature=1 so the table is meaningful.
            if temperature == 0:
                probs = self.apply_temperature(logits, temperature=1.0)
                idx = torch.argmax(probs).item()
            else:
                probs = self.apply_temperature(logits, temperature)
            entropy = self.entropy_bits(probs)
            top_tokens = self.top_k_table(probs, k=top_k) if show_steps else []
            if temperature != 0:
                rng.manual_seed(seed + i)
                idx = torch.multinomial(probs.cpu(), num_samples=1, generator=rng).item()
            token_str = self.tokenizer.decode([idx])
            text += token_str
            results.append({
                "step": i + 1,
                "text": text,
                "token": token_str,
                "token_id": int(idx),
                "entropy": entropy,
                "top_tokens": top_tokens,
            })
        return results

    def generate_chat(
        self,
        messages: list[dict],
        max_new_tokens: int = 256,
        temperature: float = 0.7,
        seed: int = 42,
    ) -> dict:
        """Generate a chat response using the dedicated chat model.

        Args:
            messages: Full conversation as list of {"role": ..., "content": ...} dicts,
                including system prompt and all previous turns.
            max_new_tokens: cap on generated tokens.
            temperature: 0 disables sampling (greedy decode).
            seed: RNG seed applied before generation for reproducibility.

        Returns dict with:
            - formatted_display: the full template including the response (for terminal)
            - response: the model's generated response text
            or {"error": ...} when no chat model is loaded.
        """
        if not self.chat_ready():
            return {"error": "Chat model not loaded"}
        # Format input (everything up to and including the generation prompt)
        formatted = self.chat_tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True,
        )
        # Tokenize input
        inputs = self.chat_tokenizer(formatted, return_tensors="pt")
        inputs = {k: v.to(self.chat_model.device) for k, v in inputs.items()}
        input_len = inputs["input_ids"].shape[1]
        # Generate
        gen_kwargs = {
            "max_new_tokens": max_new_tokens,
            "do_sample": temperature > 0,
            "pad_token_id": self.chat_tokenizer.eos_token_id,
        }
        if temperature > 0:
            gen_kwargs["temperature"] = temperature
            if self.chat_model.device.type == "cuda":
                torch.cuda.manual_seed(seed)
            torch.manual_seed(seed)
        with torch.no_grad():
            output_ids = self.chat_model.generate(**inputs, **gen_kwargs)
        # Decode only the new tokens
        new_ids = output_ids[0][input_len:]
        response = self.chat_tokenizer.decode(new_ids, skip_special_tokens=True).strip()
        # Build display template (includes the response) for green terminal
        display_messages = messages + [{"role": "assistant", "content": response}]
        formatted_display = self.chat_tokenizer.apply_chat_template(
            display_messages, tokenize=False, add_generation_prompt=False,
        )
        return {
            "formatted_display": formatted_display,
            "response": response,
        }

    def format_chat_template(self, messages: list[dict]) -> str:
        """Format messages using the chat model's template (for terminal display)."""
        if not self.chat_tokenizer:
            return ""
        return self.chat_tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True,
        )

    def tokenize(self, text: str) -> list[tuple[str, int]]:
        """Tokenize text with the base tokenizer; return (token_str, token_id) pairs."""
        if self.tokenizer is None:
            return []
        ids = self.tokenizer.encode(text)
        return [(self.tokenizer.decode([tid]), tid) for tid in ids]

    # ------------------------------------------------------------------
    # Config helpers
    # ------------------------------------------------------------------
    def get_config(self) -> dict:
        """Return a shallow copy of the current config."""
        return dict(self.config)

    def update_config(self, **kwargs) -> None:
        """Merge kwargs into the config and persist to disk."""
        self.config.update(kwargs)
        _save_config(self.config)
# ---------------------------------------------------------------------------
# Separate tokenizer for demo purposes (GPT-2 shows more interesting splits)
# ---------------------------------------------------------------------------
class DemoTokenizer:
    """Lightweight tokenizer for the Tokenizer tab.

    Uses GPT-2's BPE tokenizer which has a smaller vocabulary and produces
    more interesting subword splits than modern tokenizers like Qwen's.
    """

    def __init__(self):
        self._loaded = False
        self.tokenizer = None

    def ensure_loaded(self):
        """Fetch the GPT-2 tokenizer the first time it is needed (lazy loading)."""
        if self._loaded:
            return
        self.tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
        self._loaded = True

    def tokenize(self, text: str) -> list[tuple[str, int]]:
        """Tokenize text and return list of (token_str, token_id)."""
        self.ensure_loaded()
        return [
            (self.tokenizer.decode([tok]), tok)
            for tok in self.tokenizer.encode(text)
        ]
# Module-level singleton for the demo tokenizer. Construction is cheap — the
# GPT-2 tokenizer is downloaded lazily on first tokenize() call.
demo_tokenizer = DemoTokenizer()
# Module-level singleton owning both model slots. Importing this module reads
# config.json / env vars (via _load_config) but loads no model weights.
manager = ModelManager()