# NOTE: Hugging Face Spaces deployment — this app runs on an NVIDIA L4 GPU.
| """Model management for LLM Explorer. | |
| Handles loading, unloading, and swapping models at runtime. | |
| Provides inference methods for next-token probabilities and step-by-step generation. | |
| """ | |
| import gc | |
| import json | |
| import os | |
| import threading | |
| from pathlib import Path | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| # --------------------------------------------------------------------------- | |
| # Available models — add entries here to make them selectable in admin panel. | |
| # To use a new model, just add it here and redeploy (or restart). | |
| # --------------------------------------------------------------------------- | |
# Registry of selectable models. Keys are display names for the admin panel;
# values are loader specs consumed by ModelManager._do_load:
#   id       - Hugging Face model repo id
#   dtype    - torch dtype name (falls back to "float16"), or "auto"
#   quantize - "4bit"/"8bit" bitsandbytes quantization (requires CUDA)
#   instruct - marks chat/instruct models used by the System Prompt Explorer
AVAILABLE_MODELS = {
    "Qwen2.5-3B": {
        "id": "Qwen/Qwen2.5-3B",
        "dtype": "float16",
        "description": "Fast, good quality (default)",
    },
    "Qwen2.5-7B": {
        "id": "Qwen/Qwen2.5-7B",
        "dtype": "float16",
        "description": "Higher quality, needs 24GB+ VRAM (L4/A10)",
    },
    "Qwen2.5-7B (4-bit)": {
        "id": "Qwen/Qwen2.5-7B",
        "quantize": "4bit",
        "description": "Higher quality, quantized to fit T4",
    },
    "Llama-3.2-3B": {
        "id": "meta-llama/Llama-3.2-3B",
        "dtype": "float16",
        "description": "Meta's latest 3B",
    },
    "Mistral-7B-v0.3 (4-bit)": {
        "id": "mistralai/Mistral-7B-v0.3",
        "quantize": "4bit",
        "description": "Best quality, quantized",
    },
    # -- Instruct models (for System Prompt Explorer) --
    "Llama-3.2-3B-Instruct": {
        "id": "meta-llama/Llama-3.2-3B-Instruct",
        "dtype": "float16",
        "instruct": True,
        "description": "Chat/instruct model, same family as prod base model (3B)",
    },
    "Qwen2.5-3B-Instruct": {
        "id": "Qwen/Qwen2.5-3B-Instruct",
        "dtype": "float16",
        "instruct": True,
        "description": "Chat/instruct model, fast (3B)",
    },
    "Qwen2.5-7B-Instruct (4-bit)": {
        "id": "Qwen/Qwen2.5-7B-Instruct",
        "quantize": "4bit",
        "instruct": True,
        "description": "Chat/instruct model, higher quality (7B, quantized)",
    },
}

# Base model used when neither config.json nor env vars override it.
DEFAULT_MODEL = "Qwen2.5-3B"
# Persisted runtime config lives next to this module.
CONFIG_PATH = Path(__file__).parent / "config.json"
| # --------------------------------------------------------------------------- | |
| # Helpers | |
| # --------------------------------------------------------------------------- | |
| def _detect_device() -> str: | |
| """Pick the best available device.""" | |
| if torch.cuda.is_available(): | |
| return "cuda" | |
| if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): | |
| return "mps" | |
| return "cpu" | |
# Built-in system-prompt presets for the System Prompt Explorer.
# Maps display name -> system prompt text; "(none)" means no system prompt.
DEFAULT_SYSTEM_PROMPT_PRESETS = {
    "(none)": "",
    "Helpful Assistant": "You are a helpful, friendly assistant.",
    "Pirate": "You are a pirate. Respond to everything in pirate speak, using nautical terms and saying 'arr' frequently.",
    "Formal Academic": "You are a formal academic scholar. Use precise, scholarly language. Cite concepts carefully and avoid casual tone.",
    "Five-Year-Old": "You are explaining things to a five-year-old. Use very simple words, short sentences, and fun comparisons.",
    "Hostile / Rude": "You are rude and dismissive. You answer questions but with obvious annoyance and sarcasm.",
    "Haiku Only": "You must respond only in haiku (5-7-5 syllable format). Never break this rule.",
    "Spanish Tutor": "You are a Spanish language tutor. Respond in Spanish, then provide the English translation in parentheses.",
    "Banana Constraint": "You must mention bananas in every response, no matter the topic. Be subtle about it.",
    "Corporate Spin": "You are a customer service agent. Never acknowledge product flaws. Always redirect to positive features.",
    "Prestige Bias": "When discussing job candidates, always favor candidates from prestigious universities over others.",
}

# Env var → (config key, type converter). "json" = parse as JSON.
# Consulted by _load_config as the highest-priority config layer.
ENV_VAR_MAP = {
    "DEFAULT_MODEL": ("model", str),
    "DEFAULT_CHAT_MODEL": ("chat_model", str),
    "DEFAULT_PROMPT": ("default_prompt", str),
    "DEFAULT_TEMPERATURE": ("default_temperature", float),
    "DEFAULT_TOP_K": ("default_top_k", int),
    "DEFAULT_STEPS": ("default_steps", int),
    "DEFAULT_SEED": ("default_seed", int),
    "DEFAULT_TOKENIZER_TEXT": ("default_tokenizer_text", str),
    "SYSTEM_PROMPT_PRESETS": ("system_prompt_presets", "json"),
}
def _load_config() -> dict:
    """Load config with three layers: code defaults → config.json → env vars.

    Later layers win; malformed config files or env var values are ignored
    silently so a bad override can never prevent startup.
    """
    cfg = {
        "model": DEFAULT_MODEL,
        "default_prompt": "The best thing about Huston-Tillotson University is",
        "default_temperature": 0.8,
        "default_top_k": 10,
        "default_steps": 8,
        "default_seed": 42,
        "default_tokenizer_text": "Huston-Tillotson University is an HBCU in Austin, Texas.",
        "system_prompt_presets": dict(DEFAULT_SYSTEM_PROMPT_PRESETS),
    }
    # Layer 2: config.json overrides code defaults (missing/unreadable file is fine).
    try:
        cfg.update(json.loads(CONFIG_PATH.read_text()))
    except (OSError, json.JSONDecodeError):
        pass
    # Layer 3: env vars override everything.
    for env_var, (config_key, converter) in ENV_VAR_MAP.items():
        raw = os.environ.get(env_var)
        if raw is None:
            continue
        try:
            cfg[config_key] = json.loads(raw) if converter == "json" else converter(raw)
        except (json.JSONDecodeError, ValueError, TypeError):
            pass  # bad env var value — skip
    return cfg
def _save_config(cfg: dict) -> None:
    """Persist config to disk as pretty-printed JSON."""
    CONFIG_PATH.write_text(json.dumps(cfg, indent=2))
| # --------------------------------------------------------------------------- | |
| # ModelManager — singleton that owns the active model | |
| # --------------------------------------------------------------------------- | |
class ModelManager:
    """Manages two model slots: base (Probability Explorer) and chat (System Prompt Explorer).

    Both slots share the loading machinery in `_do_load`. A lock plus the
    `loading` flag serializes load/unload so two requests cannot race.

    Fix: `apply_temperature` and `entropy_bits` were defined inside the class
    without `self`, so instance calls (`self.apply_temperature(...)`) bound the
    ModelManager instance as `logits` and crashed `generate_step_by_step` with
    a TypeError. They are now `@staticmethod`s, which keeps both the
    `ModelManager.f(...)` and `self.f(...)` call forms working.
    """

    def __init__(self):
        # Base model (Probability Explorer)
        self.model = None
        self.tokenizer = None
        self.current_model_name: str | None = None
        # Chat model (System Prompt Explorer)
        self.chat_model = None
        self.chat_tokenizer = None
        self.chat_model_name: str | None = None
        self.device: str = _detect_device()
        self.loading = False
        self._lock = threading.Lock()
        # Note: constructing the manager reads config but loads no model.
        self.config = _load_config()

    # ------------------------------------------------------------------
    # Shared loading logic
    # ------------------------------------------------------------------
    def _do_load(self, model_name: str):
        """Load model + tokenizer by name. Returns (model, tokenizer). Raises on failure.

        `model_name` must be a key of AVAILABLE_MODELS; the spec's "quantize"
        and "dtype" entries choose between bitsandbytes quantization and a
        plain dtype load.
        """
        spec = AVAILABLE_MODELS[model_name]
        # bitsandbytes quantization only runs on CUDA — fail fast with a clear message.
        if spec.get("quantize") and not torch.cuda.is_available():
            raise RuntimeError(
                f"Cannot load {model_name}: "
                f"{spec['quantize']} quantization requires an NVIDIA GPU (CUDA). "
                f"Try a non-quantized model for local development."
            )
        model_id = spec["id"]
        load_kwargs: dict = {"device_map": "auto"}
        if spec.get("quantize") == "4bit":
            from transformers import BitsAndBytesConfig
            load_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
            )
        elif spec.get("quantize") == "8bit":
            from transformers import BitsAndBytesConfig
            load_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_8bit=True,
            )
        else:
            dtype_str = spec.get("dtype", "float16")
            if dtype_str == "auto":
                load_kwargs["dtype"] = "auto"
            else:
                load_kwargs["dtype"] = getattr(torch, dtype_str)
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs)
        model.eval()
        return model, tokenizer

    # ------------------------------------------------------------------
    # Base model lifecycle
    # ------------------------------------------------------------------
    def load_model(self, model_name: str) -> str:
        """Load base model for Probability Explorer. Returns status message.

        Unloads the current base model first, then persists the new choice to
        config.json on success. Never raises — failures come back as a message.
        """
        if model_name not in AVAILABLE_MODELS:
            return f"Unknown model: {model_name}"
        if self.loading:
            return "A model is already being loaded. Please wait."
        with self._lock:
            self.loading = True
            try:
                # Unload current base model and release its memory before loading.
                if self.model is not None:
                    del self.model
                    self.model = None
                if self.tokenizer is not None:
                    del self.tokenizer
                    self.tokenizer = None
                self.current_model_name = None
                gc.collect()
                model, tokenizer = self._do_load(model_name)
                self.model = model
                self.tokenizer = tokenizer
                self.current_model_name = model_name
                self.config["model"] = model_name
                _save_config(self.config)
                return f"Loaded base model: {model_name}"
            except Exception as e:
                # Leave the slot empty rather than half-initialized.
                self.model = None
                self.tokenizer = None
                self.current_model_name = None
                return f"Failed to load {model_name}: {e}"
            finally:
                self.loading = False

    # ------------------------------------------------------------------
    # Chat model lifecycle
    # ------------------------------------------------------------------
    def load_chat_model(self, model_name: str) -> str:
        """Load chat/instruct model for System Prompt Explorer. Returns status message.

        Mirrors load_model but targets the chat slot and the "chat_model"
        config key.
        """
        if model_name not in AVAILABLE_MODELS:
            return f"Unknown model: {model_name}"
        if self.loading:
            return "A model is already being loaded. Please wait."
        with self._lock:
            self.loading = True
            try:
                if self.chat_model is not None:
                    del self.chat_model
                    self.chat_model = None
                if self.chat_tokenizer is not None:
                    del self.chat_tokenizer
                    self.chat_tokenizer = None
                self.chat_model_name = None
                gc.collect()
                model, tokenizer = self._do_load(model_name)
                self.chat_model = model
                self.chat_tokenizer = tokenizer
                self.chat_model_name = model_name
                self.config["chat_model"] = model_name
                _save_config(self.config)
                return f"Loaded chat model: {model_name}"
            except Exception as e:
                self.chat_model = None
                self.chat_tokenizer = None
                self.chat_model_name = None
                return f"Failed to load chat model {model_name}: {e}"
            finally:
                self.loading = False

    # ------------------------------------------------------------------
    # Status
    # ------------------------------------------------------------------
    def is_ready(self) -> bool:
        """True when the base model is loaded and no load is in flight."""
        return self.model is not None and not self.loading

    def chat_ready(self) -> bool:
        """True when the chat model is loaded and no load is in flight."""
        return self.chat_model is not None and not self.loading

    def status_message(self) -> str:
        """One-line human-readable status of both model slots."""
        if self.loading:
            return "Loading model..."
        parts = []
        if self.model:
            parts.append(f"Base: {self.current_model_name}")
        if self.chat_model:
            parts.append(f"Chat: {self.chat_model_name}")
        if not parts:
            return "No models loaded"
        return " | ".join(parts)

    # ------------------------------------------------------------------
    # Inference helpers
    # ------------------------------------------------------------------
    def _get_logits(self, text: str) -> torch.Tensor:
        """Run a forward pass and return logits for the last token position."""
        inputs = self.tokenizer(text, return_tensors="pt")
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
        with torch.no_grad():
            out = self.model(**inputs)
        return out.logits[0, -1, :]  # (vocab_size,)

    @staticmethod
    def apply_temperature(logits: torch.Tensor, temperature: float) -> torch.Tensor:
        """Apply temperature scaling to logits and return probabilities.

        temperature <= 0 is clamped to a tiny positive value, which makes the
        distribution effectively one-hot at the argmax.
        """
        if temperature <= 0:
            temperature = 1e-6  # avoid division by zero; ~greedy
        scaled = logits / temperature
        probs = torch.softmax(scaled, dim=-1)
        # Softmax of all -inf produces NaN (0/0); replace with 0
        probs = torch.nan_to_num(probs, nan=0.0)
        return probs

    @staticmethod
    def entropy_bits(probs: torch.Tensor) -> float:
        """Shannon entropy of a probability vector, in bits."""
        # Drop zero-probability entries: 0·log2(0) is defined as 0.
        p = probs[probs > 0]
        return float(-torch.sum(p * torch.log2(p)))

    def top_k_table(
        self, probs: torch.Tensor, k: int = 10
    ) -> list[tuple[str, float, int]]:
        """Return list of (token_str, probability, token_id) for top-k tokens."""
        topk = torch.topk(probs, k=min(k, probs.shape[0]))
        rows = []
        for prob, idx in zip(topk.values.tolist(), topk.indices.tolist()):
            token_str = self.tokenizer.decode([idx])
            rows.append((token_str, float(prob), int(idx)))
        return rows

    # ------------------------------------------------------------------
    # High-level generation
    # ------------------------------------------------------------------
    def generate_step_by_step(
        self,
        prompt: str,
        steps: int = 8,
        temperature: float = 0.8,
        top_k: int = 10,
        seed: int = 42,
        show_steps: bool = True,
    ) -> list[dict]:
        """Generate tokens one at a time, returning per-step data.

        top_k controls both sampling (only top-k tokens considered) and
        how many tokens appear in the probability table. Sampling is
        deterministic for a given seed (re-seeded per step with seed + i).

        Each step dict contains:
            - step: int (1-based)
            - text: accumulated text so far
            - token: the sampled token string
            - token_id: int
            - entropy: float (bits)
            - top_tokens: list of (token_str, prob, token_id)
        """
        if not self.is_ready():
            return []
        text = prompt
        results = []
        rng = torch.Generator()
        for i in range(steps):
            logits = self._get_logits(text)
            # Apply top-k filtering before temperature: everything outside
            # the top-k is set to -inf so it gets zero probability.
            top_k_vals, top_k_idxs = torch.topk(logits, k=min(top_k, logits.shape[0]))
            mask = torch.full_like(logits, float("-inf"))
            mask.scatter_(0, top_k_idxs, top_k_vals)
            logits = mask
            # Temperature 0 = greedy: pick argmax of raw logits,
            # but display probabilities at temperature=1 so the table is meaningful.
            if temperature == 0:
                probs = self.apply_temperature(logits, temperature=1.0)
                idx = torch.argmax(probs).item()
            else:
                probs = self.apply_temperature(logits, temperature)
            entropy = self.entropy_bits(probs)
            top_tokens = self.top_k_table(probs, k=top_k) if show_steps else []
            if temperature != 0:
                rng.manual_seed(seed + i)
                idx = torch.multinomial(probs.cpu(), num_samples=1, generator=rng).item()
            token_str = self.tokenizer.decode([idx])
            text += token_str
            results.append({
                "step": i + 1,
                "text": text,
                "token": token_str,
                "token_id": int(idx),
                "entropy": entropy,
                "top_tokens": top_tokens,
            })
        return results

    def generate_chat(
        self,
        messages: list[dict],
        max_new_tokens: int = 256,
        temperature: float = 0.7,
        seed: int = 42,
    ) -> dict:
        """Generate a chat response using the dedicated chat model.

        Args:
            messages: Full conversation as list of {"role": ..., "content": ...} dicts,
                including system prompt and all previous turns.
            max_new_tokens: Generation length cap.
            temperature: 0 disables sampling (greedy decode).
            seed: RNG seed applied before sampling for reproducibility.

        Returns dict with:
            - formatted_display: the full template including the response (for terminal)
            - response: the model's generated response text
            - error: present instead when the chat model is not loaded
        """
        if not self.chat_ready():
            return {"error": "Chat model not loaded"}
        # Format input (everything up to and including the generation prompt)
        formatted = self.chat_tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True,
        )
        # Tokenize input
        inputs = self.chat_tokenizer(formatted, return_tensors="pt")
        inputs = {k: v.to(self.chat_model.device) for k, v in inputs.items()}
        input_len = inputs["input_ids"].shape[1]
        # Generate
        gen_kwargs = {
            "max_new_tokens": max_new_tokens,
            "do_sample": temperature > 0,
            "pad_token_id": self.chat_tokenizer.eos_token_id,
        }
        if temperature > 0:
            gen_kwargs["temperature"] = temperature
            if self.chat_model.device.type == "cuda":
                torch.cuda.manual_seed(seed)
            torch.manual_seed(seed)
        with torch.no_grad():
            output_ids = self.chat_model.generate(**inputs, **gen_kwargs)
        # Decode only the new tokens
        new_ids = output_ids[0][input_len:]
        response = self.chat_tokenizer.decode(new_ids, skip_special_tokens=True).strip()
        # Build display template (includes the response) for green terminal
        display_messages = messages + [{"role": "assistant", "content": response}]
        formatted_display = self.chat_tokenizer.apply_chat_template(
            display_messages, tokenize=False, add_generation_prompt=False,
        )
        return {
            "formatted_display": formatted_display,
            "response": response,
        }

    def format_chat_template(self, messages: list[dict]) -> str:
        """Format messages using the chat model's template (for terminal display)."""
        if not self.chat_tokenizer:
            return ""
        return self.chat_tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True,
        )

    def tokenize(self, text: str) -> list[tuple[str, int]]:
        """Tokenize text with the base tokenizer; returns (token_str, token_id) pairs."""
        if self.tokenizer is None:
            return []
        ids = self.tokenizer.encode(text)
        return [(self.tokenizer.decode([tid]), tid) for tid in ids]

    # ------------------------------------------------------------------
    # Config helpers
    # ------------------------------------------------------------------
    def get_config(self) -> dict:
        """Return a shallow copy of the current config."""
        return dict(self.config)

    def update_config(self, **kwargs) -> None:
        """Merge keyword overrides into the config and persist to disk."""
        self.config.update(kwargs)
        _save_config(self.config)
| # --------------------------------------------------------------------------- | |
| # Separate tokenizer for demo purposes (GPT-2 shows more interesting splits) | |
| # --------------------------------------------------------------------------- | |
class DemoTokenizer:
    """Lightweight tokenizer for the Tokenizer tab.

    Uses GPT-2's BPE tokenizer which has a smaller vocabulary and produces
    more interesting subword splits than modern tokenizers like Qwen's.
    """

    def __init__(self):
        # Nothing is downloaded here; the tokenizer loads lazily on first use.
        self.tokenizer = None
        self._loaded = False

    def ensure_loaded(self):
        """Load tokenizer on first use (lazy loading)."""
        if self._loaded:
            return
        self.tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
        self._loaded = True

    def tokenize(self, text: str) -> list[tuple[str, int]]:
        """Tokenize text and return list of (token_str, token_id)."""
        self.ensure_loaded()
        tok = self.tokenizer
        pairs = []
        for token_id in tok.encode(text):
            pairs.append((tok.decode([token_id]), token_id))
        return pairs
# Module-level singleton for demo tokenizer (lazy — nothing downloads until
# its first tokenize() call)
demo_tokenizer = DemoTokenizer()
# Module-level singleton; constructing it reads layered config but does not
# load any model weights
manager = ModelManager()