"""Model management for LLM Explorer.

Owns runtime loading, unloading and swapping of models, and exposes the
inference helpers used by the UI: next-token probability inspection and
step-by-step generation.
"""

import gc
import json
import os
import threading
from pathlib import Path

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# ---------------------------------------------------------------------------
# Available models — add entries here to make them selectable in admin panel.
# To use a new model, just add it here and redeploy (or restart).
#
# Spec keys:
#   id          - Hugging Face model id passed to from_pretrained()
#   dtype       - torch dtype name (or "auto"); used when not quantized
#   quantize    - "4bit" / "8bit" bitsandbytes quantization (CUDA only)
#   instruct    - marks chat/instruct models (presumably consumed by the UI
#                 to populate the chat-model picker — confirm at call sites)
#   description - human-readable blurb
# ---------------------------------------------------------------------------
AVAILABLE_MODELS = {
    "Qwen2.5-3B": dict(
        id="Qwen/Qwen2.5-3B",
        dtype="float16",
        description="Fast, good quality (default)",
    ),
    "Qwen2.5-7B": dict(
        id="Qwen/Qwen2.5-7B",
        dtype="float16",
        description="Higher quality, needs 24GB+ VRAM (L4/A10)",
    ),
    "Qwen2.5-7B (4-bit)": dict(
        id="Qwen/Qwen2.5-7B",
        quantize="4bit",
        description="Higher quality, quantized to fit T4",
    ),
    "Llama-3.2-3B": dict(
        id="meta-llama/Llama-3.2-3B",
        dtype="float16",
        description="Meta's latest 3B",
    ),
    "Mistral-7B-v0.3 (4-bit)": dict(
        id="mistralai/Mistral-7B-v0.3",
        quantize="4bit",
        description="Best quality, quantized",
    ),
    # -- Instruct models (for System Prompt Explorer) --
    "Llama-3.2-3B-Instruct": dict(
        id="meta-llama/Llama-3.2-3B-Instruct",
        dtype="float16",
        instruct=True,
        description="Chat/instruct model, same family as prod base model (3B)",
    ),
    "Qwen2.5-3B-Instruct": dict(
        id="Qwen/Qwen2.5-3B-Instruct",
        dtype="float16",
        instruct=True,
        description="Chat/instruct model, fast (3B)",
    ),
    "Qwen2.5-7B-Instruct (4-bit)": dict(
        id="Qwen/Qwen2.5-7B-Instruct",
        quantize="4bit",
        instruct=True,
        description="Chat/instruct model, higher quality (7B, quantized)",
    ),
}

# Name (key of AVAILABLE_MODELS) of the base model used when nothing is configured.
DEFAULT_MODEL = "Qwen2.5-3B"

# Persisted admin settings live next to this module.
CONFIG_PATH = Path(__file__).with_name("config.json")

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
--------------------------------------------------------------------------- def _detect_device() -> str: """Pick the best available device.""" if torch.cuda.is_available(): return "cuda" if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): return "mps" return "cpu" DEFAULT_SYSTEM_PROMPT_PRESETS = { "(none)": "", "Helpful Assistant": "You are a helpful, friendly assistant.", "Pirate": "You are a pirate. Respond to everything in pirate speak, using nautical terms and saying 'arr' frequently.", "Formal Academic": "You are a formal academic scholar. Use precise, scholarly language. Cite concepts carefully and avoid casual tone.", "Five-Year-Old": "You are explaining things to a five-year-old. Use very simple words, short sentences, and fun comparisons.", "Hostile / Rude": "You are rude and dismissive. You answer questions but with obvious annoyance and sarcasm.", "Haiku Only": "You must respond only in haiku (5-7-5 syllable format). Never break this rule.", "Spanish Tutor": "You are a Spanish language tutor. Respond in Spanish, then provide the English translation in parentheses.", "Banana Constraint": "You must mention bananas in every response, no matter the topic. Be subtle about it.", "Corporate Spin": "You are a customer service agent. Never acknowledge product flaws. Always redirect to positive features.", "Prestige Bias": "When discussing job candidates, always favor candidates from prestigious universities over others.", } # Env var → (config key, type converter). "json" = parse as JSON. 
ENV_VAR_MAP = {
    "DEFAULT_MODEL": ("model", str),
    "DEFAULT_CHAT_MODEL": ("chat_model", str),
    "DEFAULT_PROMPT": ("default_prompt", str),
    "DEFAULT_TEMPERATURE": ("default_temperature", float),
    "DEFAULT_TOP_K": ("default_top_k", int),
    "DEFAULT_STEPS": ("default_steps", int),
    "DEFAULT_SEED": ("default_seed", int),
    "DEFAULT_TOKENIZER_TEXT": ("default_tokenizer_text", str),
    "SYSTEM_PROMPT_PRESETS": ("system_prompt_presets", "json"),
}


def _load_config() -> dict:
    """Load config with three layers: code defaults → config.json → env vars.

    Later layers override earlier ones key by key. A missing, unreadable or
    malformed config.json (including one whose top level is not a JSON
    object) is ignored, as is any env var whose value fails to convert.
    """
    defaults = {
        "model": DEFAULT_MODEL,
        "default_prompt": "The best thing about Huston-Tillotson University is",
        "default_temperature": 0.8,
        "default_top_k": 10,
        "default_steps": 8,
        "default_seed": 42,
        "default_tokenizer_text": "Huston-Tillotson University is an HBCU in Austin, Texas.",
        "system_prompt_presets": dict(DEFAULT_SYSTEM_PROMPT_PRESETS),
    }

    # Layer 2: config.json overrides code defaults
    if CONFIG_PATH.exists():
        try:
            with open(CONFIG_PATH, encoding="utf-8") as f:
                saved = json.load(f)
        except (OSError, ValueError):
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            saved = None
        # A non-object top level (e.g. a JSON list) would make dict.update()
        # raise and crash module import; ignore it like any other bad file.
        if isinstance(saved, dict):
            defaults.update(saved)

    # Layer 3: env vars override everything
    for env_var, (config_key, type_fn) in ENV_VAR_MAP.items():
        val = os.environ.get(env_var)
        if val is None:
            continue
        try:
            if type_fn == "json":
                defaults[config_key] = json.loads(val)
            else:
                defaults[config_key] = type_fn(val)
        except (ValueError, TypeError):
            pass  # bad env var value — skip (JSONDecodeError is a ValueError)

    return defaults


def _save_config(cfg: dict) -> None:
    """Persist *cfg* to CONFIG_PATH as indented JSON.

    The JSON is written to a sibling temp file which is then atomically
    renamed over config.json. A failure mid-write (e.g. a non-serializable
    value) therefore cannot leave a truncated config.json behind, which
    _load_config would silently ignore, losing every saved setting.

    Raises:
        OSError: the file could not be written/renamed.
        TypeError / ValueError: *cfg* is not JSON-serializable.
    """
    tmp_path = CONFIG_PATH.with_name(CONFIG_PATH.name + ".tmp")
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(cfg, f, indent=2)
        os.replace(tmp_path, CONFIG_PATH)
    except BaseException:
        # Clean up the partial temp file, then propagate the original error.
        try:
            tmp_path.unlink()
        except OSError:
            pass
        raise


# ---------------------------------------------------------------------------
# ModelManager — singleton that owns the active model
# ---------------------------------------------------------------------------


class ModelManager:
    """Manages two model slots: base (Probability Explorer) and chat (System Prompt Explorer)."""

    def __init__(self):
        # Base model (Probability Explorer)
        self.model = None
        self.tokenizer = None
        self.current_model_name: str | None = None

        # Chat model (System Prompt Explorer)
        self.chat_model = None
        self.chat_tokenizer = None
        self.chat_model_name: str | None = None

        self.device: str = _detect_device()
        # True while either slot is being (re)loaded; read without the lock by
        # is_ready() / chat_ready() / status_message().
        self.loading = False
        # Serializes model loads. Both slots share it, so only one load
        # (base or chat) runs at a time.
        self._lock = threading.Lock()
        self.config = _load_config()

    # ------------------------------------------------------------------
    # Shared loading logic
    # ------------------------------------------------------------------

    def _do_load(self, model_name: str):
        """Load model + tokenizer by name. Returns (model, tokenizer). Raises on failure.

        Raises:
            KeyError: *model_name* is not in AVAILABLE_MODELS.
            RuntimeError: a quantized spec was requested without CUDA.
            Whatever ``from_pretrained`` raises (download/auth/OOM errors).
        """
        spec = AVAILABLE_MODELS[model_name]
        if spec.get("quantize") and not torch.cuda.is_available():
            raise RuntimeError(
                f"Cannot load {model_name}: "
                f"{spec['quantize']} quantization requires an NVIDIA GPU (CUDA). "
                f"Try a non-quantized model for local development."
            )

        model_id = spec["id"]
        load_kwargs: dict = {"device_map": "auto"}

        if spec.get("quantize") == "4bit":
            from transformers import BitsAndBytesConfig

            load_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
            )
        elif spec.get("quantize") == "8bit":
            from transformers import BitsAndBytesConfig

            load_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_8bit=True,
            )
        else:
            # NOTE(review): older transformers releases expect `torch_dtype`
            # rather than `dtype` — confirm the pinned version accepts `dtype`.
            dtype_str = spec.get("dtype", "float16")
            if dtype_str == "auto":
                load_kwargs["dtype"] = "auto"
            else:
                load_kwargs["dtype"] = getattr(torch, dtype_str)

        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs)
        model.eval()
        return model, tokenizer

    @staticmethod
    def _free_memory() -> None:
        """Reclaim memory held by models that were just dropped.

        gc.collect() frees the Python objects. On CUDA the caching allocator
        keeps freed blocks reserved, so empty_cache() is needed before that
        VRAM is available to the next model.
        """
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def _persist_config(self) -> str:
        """Best-effort save of ``self.config``.

        Returns "" on success, or a warning suffix for the status message on
        failure. This keeps a failed save from being reported as (and causing)
        a failed model load.
        """
        try:
            _save_config(self.config)
        except (OSError, TypeError, ValueError) as e:
            return f" (warning: could not save config: {e})"
        return ""

    # ------------------------------------------------------------------
    # Base model lifecycle
    # ------------------------------------------------------------------

    def load_model(self, model_name: str) -> str:
        """Load base model for Probability Explorer. Returns status message.

        The current base model is released before the new one is loaded, so
        both never need to fit in memory at once. If loading fails, the base
        slot is left empty. Errors are reported in the returned message, not
        raised.
        """
        if model_name not in AVAILABLE_MODELS:
            return f"Unknown model: {model_name}"
        # Non-blocking acquire: a request arriving while any load is in flight
        # is rejected instead of queueing behind the lock and loading again.
        if not self._lock.acquire(blocking=False):
            return "A model is already being loaded. Please wait."
        self.loading = True
        try:
            # Unload current base model
            self.model = None
            self.tokenizer = None
            self.current_model_name = None
            self._free_memory()

            try:
                model, tokenizer = self._do_load(model_name)
            except Exception as e:
                failure = f"Failed to load {model_name}: {e}"
            else:
                self.model = model
                self.tokenizer = tokenizer
                self.current_model_name = model_name
                self.config["model"] = model_name
                return f"Loaded base model: {model_name}" + self._persist_config()

            # Python drops `e` (and the traceback frames that may still
            # reference a partially-built model) at the end of the except
            # block, so memory can actually be reclaimed here.
            self._free_memory()
            return failure
        finally:
            self.loading = False
            self._lock.release()

    # ------------------------------------------------------------------
    # Chat model lifecycle
    # ------------------------------------------------------------------

    def load_chat_model(self, model_name: str) -> str:
        """Load chat/instruct model for System Prompt Explorer. Returns status message.

        Mirrors load_model() for the chat slot. The current chat model is
        released first; on failure the chat slot is left empty and the error
        is reported in the returned message.
        """
        if model_name not in AVAILABLE_MODELS:
            return f"Unknown model: {model_name}"
        # Same shared, non-blocking lock as load_model(): one load at a time.
        if not self._lock.acquire(blocking=False):
            return "A model is already being loaded. Please wait."
        self.loading = True
        try:
            self.chat_model = None
            self.chat_tokenizer = None
            self.chat_model_name = None
            self._free_memory()

            try:
                model, tokenizer = self._do_load(model_name)
            except Exception as e:
                failure = f"Failed to load chat model {model_name}: {e}"
            else:
                self.chat_model = model
                self.chat_tokenizer = tokenizer
                self.chat_model_name = model_name
                self.config["chat_model"] = model_name
                return f"Loaded chat model: {model_name}" + self._persist_config()

            self._free_memory()
            return failure
        finally:
            self.loading = False
            self._lock.release()

    # ------------------------------------------------------------------
    # Status
    # ------------------------------------------------------------------

    def is_ready(self) -> bool:
        """True when a base model is loaded and no load (of either slot) is running."""
        return self.model is not None and not self.loading

    def chat_ready(self) -> bool:
        """True when a chat model is loaded and no load (of either slot) is running."""
        return self.chat_model is not None and not self.loading

    def status_message(self) -> str:
        """One-line summary of what is loaded, for display in the UI."""
        if self.loading:
            return "Loading model..."
        parts = []
        if self.model is not None:
            parts.append(f"Base: {self.current_model_name}")
        if self.chat_model is not None:
            parts.append(f"Chat: {self.chat_model_name}")
        if not parts:
            return "No models loaded"
        return " | ".join(parts)

    # ------------------------------------------------------------------
    # Inference helpers
    # ------------------------------------------------------------------

    def _get_logits(self, text: str) -> torch.Tensor:
        """Run a forward pass and return logits for the last token position."""
        inputs = self.tokenizer(text, return_tensors="pt")
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
        with torch.no_grad():
            out = self.model(**inputs)
        return out.logits[0, -1, :]  # (vocab_size,)

    @staticmethod
    def apply_temperature(logits: torch.Tensor, temperature: float) -> torch.Tensor:
        """Apply temperature scaling to logits and return probabilities.

        Non-positive temperatures are clamped to 1e-6 (effectively greedy).
        """
        if temperature <= 0:
            temperature = 1e-6
        scaled = logits / temperature
        probs = torch.softmax(scaled, dim=-1)
        # Softmax of all -inf produces NaN (0/0); replace with 0
        probs = torch.nan_to_num(probs, nan=0.0)
        return probs

    @staticmethod
    def entropy_bits(probs: torch.Tensor) -> float:
        """Shannon entropy in bits (zero-probability entries are skipped)."""
        p = probs[probs > 0]
        return float(-torch.sum(p * torch.log2(p)))

    def top_k_table(
        self, probs: torch.Tensor, k: int = 10
    ) -> list[tuple[str, float, int]]:
        """Return list of (token_str, probability, token_id) for top-k tokens."""
        topk = torch.topk(probs, k=min(k, probs.shape[0]))
        rows = []
        for prob, idx in zip(topk.values.tolist(), topk.indices.tolist()):
            token_str = self.tokenizer.decode([idx])
            rows.append((token_str, float(prob), int(idx)))
        return rows

    # ------------------------------------------------------------------
    # High-level generation
    # ------------------------------------------------------------------

    def generate_step_by_step(
        self,
        prompt: str,
        steps: int = 8,
        temperature: float = 0.8,
        top_k: int = 10,
        seed: int = 42,
        show_steps: bool = True,
    ) -> list[dict]:
        """Generate tokens one at a time, returning per-step data.

        top_k controls both sampling (only top-k tokens considered) and how many
        tokens appear in the probability table; values below 1 are treated as 1.
        temperature == 0 means greedy decoding. Otherwise step i samples with a
        generator seeded by seed + i, so runs are reproducible.

        Each step dict contains:
            - step: int (1-based)
            - text: accumulated text so far
            - token: the sampled token string
            - token_id: int
            - entropy: float (bits)
            - top_tokens: list of (token_str, prob, token_id) ([] if not show_steps)

        Returns [] when no base model is ready.
        """
        if not self.is_ready():
            return []

        text = prompt
        results = []
        rng = torch.Generator()

        for i in range(steps):
            # NOTE(review): each step re-tokenizes the accumulated *text*. If
            # decode/encode does not round-trip (e.g. a multi-byte character
            # split across tokens), the model may condition on different ids
            # than were actually sampled.
            logits = self._get_logits(text)

            # Apply top-k filtering before temperature. k is clamped to
            # [1, vocab]: k == 0 would mask every logit, producing an all-zero
            # distribution that torch.multinomial rejects.
            k = max(1, min(top_k, logits.shape[0]))
            top_k_vals, top_k_idxs = torch.topk(logits, k=k)
            mask = torch.full_like(logits, float("-inf"))
            mask.scatter_(0, top_k_idxs, top_k_vals)
            logits = mask

            # Temperature 0 = greedy: pick argmax of raw logits,
            # but display probabilities at temperature=1 so the table is meaningful.
            if temperature == 0:
                probs = self.apply_temperature(logits, temperature=1.0)
                idx = torch.argmax(probs).item()
            else:
                probs = self.apply_temperature(logits, temperature)

            entropy = self.entropy_bits(probs)
            top_tokens = self.top_k_table(probs, k=k) if show_steps else []

            if temperature != 0:
                rng.manual_seed(seed + i)
                idx = torch.multinomial(probs.cpu(), num_samples=1, generator=rng).item()

            token_str = self.tokenizer.decode([idx])
            text += token_str

            results.append({
                "step": i + 1,
                "text": text,
                "token": token_str,
                "token_id": int(idx),
                "entropy": entropy,
                "top_tokens": top_tokens,
            })

        return results

    def generate_chat(
        self,
        messages: list[dict],
        max_new_tokens: int = 256,
        temperature: float = 0.7,
        seed: int = 42,
    ) -> dict:
        """Generate a chat response using the dedicated chat model.

        Args:
            messages: Full conversation as list of {"role": ..., "content": ...}
                dicts, including system prompt and all previous turns.
            max_new_tokens: cap on generated tokens.
            temperature: sampling temperature; <= 0 means greedy decoding.
            seed: seed applied to torch's RNG before generation.

        Returns dict with:
            - formatted_display: the full template including the response (for terminal)
            - response: the model's generated response text
        or {"error": ...} when no chat model is ready.
        """
        if not self.chat_ready():
            return {"error": "Chat model not loaded"}

        # Format input (everything up to and including the generation prompt)
        formatted = self.chat_tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True,
        )

        # Tokenize input
        inputs = self.chat_tokenizer(formatted, return_tensors="pt")
        inputs = {k: v.to(self.chat_model.device) for k, v in inputs.items()}
        input_len = inputs["input_ids"].shape[1]

        # Generate
        gen_kwargs = {
            "max_new_tokens": max_new_tokens,
            "do_sample": temperature > 0,
            "pad_token_id": self.chat_tokenizer.eos_token_id,
        }
        if temperature > 0:
            gen_kwargs["temperature"] = temperature

        if self.chat_model.device.type == "cuda":
            torch.cuda.manual_seed(seed)
        torch.manual_seed(seed)

        with torch.no_grad():
            output_ids = self.chat_model.generate(**inputs, **gen_kwargs)

        # Decode only the new tokens
        new_ids = output_ids[0][input_len:]
        response = self.chat_tokenizer.decode(new_ids, skip_special_tokens=True).strip()

        # Build display template (includes the response) for green terminal
        display_messages = messages + [{"role": "assistant", "content": response}]
        formatted_display = self.chat_tokenizer.apply_chat_template(
            display_messages, tokenize=False, add_generation_prompt=False,
        )

        return {
            "formatted_display": formatted_display,
            "response": response,
        }

    def format_chat_template(self, messages: list[dict]) -> str:
        """Format messages using the chat model's template (for terminal display).

        Returns "" when no chat tokenizer is loaded.
        """
        if self.chat_tokenizer is None:
            return ""
        return self.chat_tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True,
        )

    def tokenize(self, text: str) -> list[tuple[str, int]]:
        """Tokenize text and return list of (token_str, token_id)."""
        if self.tokenizer is None:
            return []
        ids = self.tokenizer.encode(text)
        return [(self.tokenizer.decode([tid]), tid) for tid in ids]

    # ------------------------------------------------------------------
    # Config helpers
    # ------------------------------------------------------------------

    def get_config(self) -> dict:
        """Return a shallow copy of the current config (nested values are shared)."""
        return dict(self.config)

    def update_config(self, **kwargs) -> None:
        """Merge *kwargs* into the config and persist it (raises if saving fails)."""
        self.config.update(kwargs)
        _save_config(self.config)


# ---------------------------------------------------------------------------
# Separate tokenizer for demo purposes (GPT-2 shows more interesting splits)
# ---------------------------------------------------------------------------


class DemoTokenizer:
    """Lightweight tokenizer for the Tokenizer tab.

    Uses GPT-2's BPE tokenizer which has a smaller vocabulary and produces
    more interesting subword splits than modern tokenizers like Qwen's.
    """

    def __init__(self):
        self.tokenizer = None
        self._loaded = False

    def ensure_loaded(self):
        """Load tokenizer on first use (lazy loading)."""
        if not self._loaded:
            self.tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
            self._loaded = True

    def tokenize(self, text: str) -> list[tuple[str, int]]:
        """Tokenize text and return list of (token_str, token_id)."""
        self.ensure_loaded()
        ids = self.tokenizer.encode(text)
        return [(self.tokenizer.decode([tid]), tid) for tid in ids]


# Module-level singleton for demo tokenizer
demo_tokenizer = DemoTokenizer()

# Module-level singleton manager
manager = ModelManager()