""" Steerling-8B Demo — Hugging Face Spaces (ZeroGPU) An interactive demo for Steerling, an interpretable causal diffusion language model with concept steering. Uses confidence-based unmasking to generate text. Tokens are streamed live so you can watch the diffusion process fill in words out of order — the signature behavior of this model. https://huggingface.co/guidelabs/steerling-8b """ from __future__ import annotations # --------------------------------------------------------------------------- # Install steerling at startup — its metadata says >=3.13 but the code works # fine on 3.12 (every module uses `from __future__ import annotations`). # --------------------------------------------------------------------------- import subprocess import sys subprocess.check_call( [ sys.executable, "-m", "pip", "install", "--quiet", "--no-deps", "--ignore-requires-python", "steerling>=0.1.2", ] ) # --------------------------------------------------------------------------- # Imports # --------------------------------------------------------------------------- import html as html_lib import logging import math import time from textwrap import dedent import gradio as gr import spaces import torch from steerling import SteerlingGenerator logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- # Model loading # --------------------------------------------------------------------------- MODEL_ID = "guidelabs/steerling-8b" logger.info("Loading Steerling-8B model …") generator = SteerlingGenerator.from_pretrained(MODEL_ID, device="cuda") logger.info("Model loaded successfully.") # --------------------------------------------------------------------------- # Constants # --------------------------------------------------------------------------- MASK_CHAR = "▒" # visual placeholder for masked positions STEERING_PRESETS: dict[str, dict[str, dict[int, float] | None]] = { "None": {"steer_known": None, "steer_unknown": None}, "More formal / academic": { "steer_known": {102: 2.5, 3421: 2.0, 8910: -1.5}, "steer_unknown": None, }, "More creative / poetic": { "steer_known": {541: 2.0, 7723: 2.5, 102: -2.0}, "steer_unknown": None, }, "More concise / factual": { "steer_known": {102: 1.5, 3421: 1.5, 541: -2.0, 7723: -1.5}, "steer_unknown": None, }, } EXAMPLE_PROMPTS = [ "The key to understanding neural networks is", "In the year 2050, renewable energy", "The theory of general relativity explains", "Once upon a time, in a land where machines could dream,", "The most important breakthrough in modern medicine was", "Artificial intelligence will", "The future of space exploration depends on", "A poem about the ocean:\n", ] # --------------------------------------------------------------------------- # HTML rendering helpers # --------------------------------------------------------------------------- def _render_html( prompt_text: str, gen_tokens: list[int], is_finalized: list[bool], just_unmasked: set[int], tokenizer, total_gen_slots: int, step: int, elapsed: float, ) -> str: """Build an HTML snippet showing prompt + generation with color-coded tokens.""" # Escape prompt for safe HTML escaped_prompt = html_lib.escape(prompt_text) parts: list[str] = [] for i in range(total_gen_slots): if i < len(gen_tokens) and is_finalized[i]: tok_text = tokenizer.decode([gen_tokens[i]]) escaped = html_lib.escape(tok_text) # Replace newlines with
and preserve spaces escaped = escaped.replace("\n", "
") escaped = escaped.replace(" ", "  ") if i in just_unmasked: # Newly revealed this step — bright highlight parts.append( f'' f"{escaped}" ) else: # Previously revealed parts.append(f'{escaped}') else: # Still masked parts.append( f'{MASK_CHAR}' ) gen_html = "".join(parts) n_filled = sum(1 for f in is_finalized if f) n_total = total_gen_slots pct = int(100 * n_filled / n_total) if n_total > 0 else 0 # Progress bar bar_html = ( f'
' f'
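
    # Color legend (mirrored in the ARTICLE text below): the prompt renders
    # in blue, tokens revealed this step get a bright highlight, previously
    # revealed tokens render plain, and still-masked slots show MASK_CHAR.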
    parts: list[str] = []
    for i in range(total_gen_slots):
        if i < len(gen_tokens) and is_finalized[i]:
            tok_text = tokenizer.decode([gen_tokens[i]])
            escaped = html_lib.escape(tok_text)
            # Replace newlines with <br> and preserve spaces
            escaped = escaped.replace("\n", "<br>")
            escaped = escaped.replace(" ", "&nbsp;")
            if i in just_unmasked:
                # Newly revealed this step — bright highlight
                parts.append(
                    f'<span style="background:#ffe082;border-radius:3px;">'
                    f"{escaped}</span>"
                )
            else:
                # Previously revealed
                parts.append(f'<span style="color:#212121;">{escaped}</span>')
        else:
            # Still masked
            parts.append(f'<span style="color:#b0bec5;">{MASK_CHAR}</span>')

    gen_html = "".join(parts)

    n_filled = sum(1 for f in is_finalized if f)
    n_total = total_gen_slots
    pct = int(100 * n_filled / n_total) if n_total > 0 else 0

    # Progress bar
    bar_html = (
        f'<div style="margin-top:10px;">'
        f'<div style="background:#eceff1;border-radius:4px;height:8px;">'
        f'<div style="background:#42a5f5;border-radius:4px;height:8px;width:{pct}%;"></div></div>'
        f'<div style="font-size:0.85em;color:#78909c;margin-top:4px;">'
        f"Step {step} — {n_filled}/{n_total} tokens unmasked ({pct}%)"
        f" — {elapsed:.1f}s elapsed</div></div>"
    )

    return (
        f'<div style="font-family:monospace;line-height:1.8;padding:12px;'
        f'border:1px solid #e0e0e0;border-radius:8px;">'
        f'<span style="color:#1565c0;">{escaped_prompt}</span>'
        f"{gen_html}"
        f"{bar_html}"
        f"</div>"
    )


def _render_final_info(
    prompt_tokens: int,
    generated_tokens: int,
    total_steps: int,
    elapsed: float,
    steering_preset: str,
    steer_known: dict[int, float] | None,
) -> str:
    tok_per_sec = generated_tokens / elapsed if elapsed > 0 else 0.0
    info = (
        f"**Prompt tokens:** {prompt_tokens}  \n"
        f"**Generated tokens:** {generated_tokens}  \n"
        f"**Diffusion steps:** {total_steps}  \n"
        f"**Time:** {elapsed:.2f}s ({tok_per_sec:.1f} tok/s)  \n"
        f"**Steering:** {steering_preset}"
    )
    if steer_known:
        info += f"  \n**Known concept IDs:** `{steer_known}`"
    return info


# ---------------------------------------------------------------------------
# Streaming generation — reimplements the core loop from
# SteerlingGenerator.generate_full so we can yield after each step.
# ---------------------------------------------------------------------------
@spaces.GPU(duration=60)
def generate_streaming(
    prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    repetition_penalty: float,
    tokens_per_step: int,
    use_entropy_sampling: bool,
    seed: int | None,
    steering_preset: str,
    custom_steer_known: str,
):
    """Generator: yields (html_viz, generated_text, info_md) tuples."""
    if not prompt or not prompt.strip():
        yield (
            '<div style="color:#c62828;">⚠️ Please enter a prompt.</div>',
            "",
            "",
        )
        return

    # --- resolve steering ---------------------------------------------------
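    # Steering maps a concept ID to a signed weight: positive amplifies the
    # concept, negative suppresses it. A preset is resolved first; a
    # non-empty custom string (e.g. "102:2.5, 541:-1.0") then overrides it.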
    steer_known: dict[int, float] | None = None
    steer_unknown: dict[int, float] | None = None
    if steering_preset != "None" and steering_preset in STEERING_PRESETS:
        preset = STEERING_PRESETS[steering_preset]
        steer_known = preset.get("steer_known")
        steer_unknown = preset.get("steer_unknown")

    if custom_steer_known and custom_steer_known.strip():
        try:
            parsed: dict[int, float] = {}
            for pair in custom_steer_known.split(","):
                pair = pair.strip()
                if not pair:
                    continue
                cid, val = pair.split(":")
                parsed[int(cid.strip())] = float(val.strip())
            if parsed:
                steer_known = parsed
        except Exception as exc:
            yield (
                f'<div style="color:#c62828;">'
                f"⚠️ Could not parse custom steering: {exc}</div>",
                "",
                "",
            )
            return

    # --- setup ---------------------------------------------------------------
    gen = generator  # alias
    tokenizer = gen.tokenizer
    model = gen.model
    device = gen.device
    mask_id = gen.mask_token_id
    eos_id = gen.eos_token_id
    pad_id = gen.pad_token_id

    if seed is not None and seed >= 0:
        torch.manual_seed(int(seed))

    prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
    prompt_len = len(prompt_ids)
    total_len = prompt_len + max_new_tokens

    # Initialize sequence: prompt tokens + mask tokens for generation slots
    x = torch.full((1, total_len), mask_id, dtype=torch.long, device=device)
    if prompt_len > 0:
        x[0, :prompt_len] = torch.tensor(prompt_ids, dtype=torch.long, device=device)

    # Track what's finalized
    is_prompt_mask = torch.zeros(total_len, dtype=torch.bool, device=device)
    is_prompt_mask[:prompt_len] = True
    gen_region = ~is_prompt_mask
    is_finalized = is_prompt_mask.clone()

    # Banned token ids
    banned_ids = {mask_id}
    if pad_id is not None:
        banned_ids.add(pad_id)

    # Build steering intervention tensors
    int_known_ids, int_known_vals = None, None
    int_unknown_ids, int_unknown_vals = None, None
    if gen.is_interpretable and steer_known:
        int_known_ids, int_known_vals = gen._build_intervention_tensors(
            steer_known, total_len
        )
    if gen.is_interpretable and steer_unknown:
        int_unknown_ids, int_unknown_vals = gen._build_intervention_tensors(
            steer_unknown, total_len
        )

    # Lists to track generation-region state for rendering
    gen_token_ids: list[int] = [mask_id] * max_new_tokens
    gen_finalized: list[bool] = [False] * max_new_tokens

    t0 = time.perf_counter()
    tokens_generated = 0
    step_count = 0

    # --- yield initial state (all masked) -----------------------------------
    yield (
        _render_html(
            prompt,
            gen_token_ids,
            gen_finalized,
            set(),
            tokenizer,
            max_new_tokens,
            0,
            0.0,
        ),
        "",
        "*Generating…*",
    )

    # --- diffusion loop (under inference_mode for perf) ---------------------
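    # Each step: one full forward pass, score every still-masked position by
    # its max token probability (confidence), unmask the top
    # `tokens_per_step` positions, then re-render the HTML view.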
    with torch.inference_mode():
        while tokens_generated < max_new_tokens:
            still_masked = (x[0] == mask_id) & gen_region
            masked_indices = still_masked.nonzero(as_tuple=False).squeeze(-1)
            if masked_indices.numel() == 0:
                break
            if masked_indices.dim() == 0:
                masked_indices = masked_indices.unsqueeze(0)

            # Forward pass
            if gen.is_interpretable:
                logits, _ = model(
                    x,
                    use_teacher_forcing=False,
                    intervene_known_ids=int_known_ids,
                    intervene_known_vals=int_known_vals,
                    intervene_unknown_ids=int_unknown_ids,
                    intervene_unknown_vals=int_unknown_vals,
                    minimal_output=True,
                )
            else:
                logits = model(x)

            masked_logits = logits[0, masked_indices].clone()

            # Ban special tokens
            for tid in banned_ids:
                masked_logits[:, tid] = -1e9

            # Repetition penalty (CTRL-style: shrink positive logits, amplify
            # negative ones, so already-used tokens are always discouraged)
            if repetition_penalty != 1.0:
                finalized_tokens = x[0, is_finalized].tolist()
                for tok in set(finalized_tokens):
                    if tok not in banned_ids:
                        col = masked_logits[:, tok]
                        masked_logits[:, tok] = torch.where(
                            col > 0,
                            col / repetition_penalty,
                            col * repetition_penalty,
                        )

            # Confidence-based position selection
            probs_for_conf = torch.softmax(masked_logits, dim=-1)
            confidences = probs_for_conf.max(dim=-1).values
            k = min(int(tokens_per_step), masked_indices.numel())
            _, selected_pos_indices = confidences.topk(k)

            step_count += 1
            just_unmasked: set[int] = set()

            # Fill selected positions
            for pos_idx in selected_pos_indices:
                seq_idx = int(masked_indices[pos_idx].item())
                gen_slot = seq_idx - prompt_len  # index into gen arrays
                pos_logits = masked_logits[pos_idx]

                # Temperature (entropy-adaptive or fixed)
                if use_entropy_sampling:
                    pos_probs_raw = torch.softmax(pos_logits, dim=-1)
                    sorted_probs, _ = torch.sort(pos_probs_raw, descending=True)
                    cumsum = torch.cumsum(sorted_probs, dim=-1)
                    effective_k = max((cumsum <= top_p).sum().item() + 1, 2)
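                    # effective_k approximates the nucleus size here; dividing
                    # the entropy by log(effective_k) below normalizes it to
                    # [0, 1], which adaptive_temp maps onto the 0.3–0.7 range
                    # advertised in the UI (low entropy → sharper sampling).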
                    entropy = -torch.sum(
                        pos_probs_raw * torch.log(pos_probs_raw + 1e-10)
                    )
                    normalized_entropy = min(
                        1.0, entropy.item() / math.log(effective_k)
                    )
                    adaptive_temp = 0.3 + 0.4 * normalized_entropy
                    pos_probs = torch.softmax(pos_logits / adaptive_temp, dim=-1)
                else:
                    pos_probs = torch.softmax(
                        pos_logits / max(temperature, 1e-8), dim=-1
                    )

                tok = _sample_top_p(pos_probs, top_p)

                x[0, seq_idx] = tok
                is_finalized[seq_idx] = True
                tokens_generated += 1

                if 0 <= gen_slot < max_new_tokens:
                    gen_token_ids[gen_slot] = tok
                    gen_finalized[gen_slot] = True
                    just_unmasked.add(gen_slot)

                if eos_id is not None and tok == eos_id:
                    break

            elapsed = time.perf_counter() - t0

            # Decode the current generated text (finalized tokens only, in order)
            current_gen_tokens = []
            for i in range(max_new_tokens):
                if gen_finalized[i]:
                    current_gen_tokens.append(gen_token_ids[i])
                else:
                    break
            current_text = (
                tokenizer.decode(current_gen_tokens) if current_gen_tokens else ""
            )

            yield (
                _render_html(
                    prompt,
                    gen_token_ids,
                    gen_finalized,
                    just_unmasked,
                    tokenizer,
                    max_new_tokens,
                    step_count,
                    elapsed,
                ),
                current_text,
                f"*Step {step_count} — {tokens_generated}/{max_new_tokens} tokens"
                f" — {elapsed:.1f}s*",
            )

            # Check EOS
            if eos_id is not None and (x[0, gen_region] == eos_id).any():
                break

    # --- final yield ---------------------------------------------------------
    elapsed = time.perf_counter() - t0

    final_tokens = []
    for i in range(max_new_tokens):
        if gen_finalized[i]:
            final_tokens.append(gen_token_ids[i])
        else:
            break
    final_text = tokenizer.decode(final_tokens) if final_tokens else ""

    final_info = _render_final_info(
        prompt_tokens=prompt_len,
        generated_tokens=len(final_tokens),
        total_steps=step_count,
        elapsed=elapsed,
        steering_preset=steering_preset,
        steer_known=steer_known,
    )

    # Final HTML without any highlight
    yield (
        _render_html(
            prompt,
            gen_token_ids,
            gen_finalized,
            set(),
            tokenizer,
            max_new_tokens,
            step_count,
            elapsed,
        ),
        final_text,
        final_info,
    )


def _sample_top_p(probs: torch.Tensor, top_p: float) -> int:
    """Nucleus sampling: renormalize the smallest high-probability prefix
    whose cumulative mass covers top_p (plus one boundary token), then draw."""
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    cutoff_mask = cumulative <= top_p
    cutoff_mask[0] = True
    cutoff_idx = min(cutoff_mask.sum().item() + 1, len(sorted_probs))
    truncated = sorted_probs[:cutoff_idx]
    truncated = truncated / truncated.sum()
    return int(sorted_indices[torch.multinomial(truncated, 1)].item())


# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
DESCRIPTION = dedent("""\
# 🧭 Steerling-8B Demo

**[Steerling-8B](https://huggingface.co/guidelabs/steerling-8b)** is an
8-billion-parameter *causal diffusion* language model with interpretable
concept steering, built by [Guide Labs](https://www.guidelabs.ai/).

Unlike standard autoregressive LLMs, Steerling generates text by
**iteratively unmasking tokens in order of confidence** — the model fills in
the positions where it is most certain first. Watch the diffusion process
live below!

### ✨ Key Features

| Feature | Description |
|---|---|
| 🎲 **Diffusion decoding** | Confidence-based unmasking instead of left-to-right |
| 🔍 **Interpretability** | Hidden states → known + unknown concept decomposition |
| 🎛️ **Concept steering** | Amplify or suppress concepts to guide generation |
| 📐 **Block-causal attention** | Bidirectional within 64-token blocks, causal across |

> ℹ️ This Space runs on **ZeroGPU** (NVIDIA H200). Generation may be
> queued briefly while a GPU is allocated.
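
### 🖥️ Run it locally

A minimal sketch of local usage (the `generate_full` call is illustrative:
it is the non-streaming entry point whose loop this demo re-implements, and
its exact signature may differ; see the model card):

```python
# pip install --no-deps --ignore-requires-python "steerling>=0.1.2"
from steerling import SteerlingGenerator

gen = SteerlingGenerator.from_pretrained("guidelabs/steerling-8b", device="cuda")

# Amplify concept 102 and suppress 541 (illustrative concept IDs).
# NOTE: generate_full's exact arguments may differ; check the steerling docs.
text = gen.generate_full(
    "The key to understanding neural networks is",
    steer_known={102: 2.5, 541: -1.0},
)
```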
""")

ARTICLE = dedent("""\
---

### How It Works

```
hidden → known_features + unknown_features + ε = composed → logits
```

- **known_features** — weighted sum of top-k learned concept embeddings (interpretable)
- **unknown_features** — residual captured by a factorized unknown head
- **ε** — small correction for reconstruction fidelity

The **live visualization** above shows the diffusion process in action:

- **Blue text** = your prompt
- **Highlighted** = just unmasked this step
- `▒` = still masked (waiting to be filled)

Unlike autoregressive models that generate left-to-right, Steerling fills in
the **most confident positions first**, regardless of order.

### Links

- 📄 [Model Card](https://huggingface.co/guidelabs/steerling-8b)
- 💻 [GitHub](https://github.com/guidelabs/steerling)
- 🏢 [Guide Labs](https://www.guidelabs.ai/)
- 📝 [Architecture Blog Post](https://www.guidelabs.ai/post/block-causal-diffusion-language-model/)
""")

CSS = """
footer { display: none !important; }
.generating { border: none !important; }
"""

with gr.Blocks(css=CSS, title="Steerling-8B Demo", theme=gr.themes.Soft()) as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        # ── Left column: inputs ──────────────────────────────────────────
        with gr.Column(scale=1):
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="Enter your prompt here…",
                lines=4,
                value=EXAMPLE_PROMPTS[0],
            )

            with gr.Accordion("⚙️ Generation Settings", open=False):
                max_new_tokens = gr.Slider(
                    16, 512, value=128, step=16,
                    label="Max new tokens",
                )
                temperature = gr.Slider(
                    0.0, 2.0, value=1.0, step=0.05,
                    label="Temperature",
                    info="Overridden when entropy sampling is on",
                )
                top_p = gr.Slider(
                    0.1, 1.0, value=0.9, step=0.05,
                    label="Top-p (nucleus)",
                )
                repetition_penalty = gr.Slider(
                    1.0, 2.0, value=1.2, step=0.05,
                    label="Repetition penalty",
                )
                tokens_per_step = gr.Slider(
                    1, 64, value=1, step=1,
                    label="Tokens per step",
                    info="Unmask multiple positions per diffusion step (faster but noisier)",
                )
                use_entropy_sampling = gr.Checkbox(
                    value=True,
                    label="Entropy-adaptive sampling",
                    info="Automatically adjusts temperature (0.3–0.7) based on model uncertainty",
                )
                seed = gr.Number(
                    value=42,
                    label="Seed (-1 = random)",
                    precision=0,
                )

            with gr.Accordion("🎛️ Concept Steering", open=False):
                gr.Markdown(
                    "Steerling decomposes hidden states into **known concepts**. "
                    "You can amplify (positive weight) or suppress (negative weight) "
                    "specific concept IDs to steer generation."
                )
                steering_preset = gr.Dropdown(
                    choices=list(STEERING_PRESETS.keys()),
                    value="None",
                    label="Steering preset",
                )
                custom_steer_known = gr.Textbox(
                    label="Custom known-concept overrides",
                    placeholder="e.g. 102:2.5, 541:-1.0",
                    info="Comma-separated id:weight pairs. Overrides the preset.",
                )

            generate_btn = gr.Button(
                "🚀 Generate",
                variant="primary",
                size="lg",
            )

            gr.Examples(
                examples=[[p] for p in EXAMPLE_PROMPTS],
                inputs=[prompt],
                label="Example prompts",
            )

        # ── Right column: outputs ────────────────────────────────────────
        with gr.Column(scale=1):
            viz_html = gr.HTML(
                label="Live diffusion",
                value=(
                    '<div style="color:#90a4ae;text-align:center;padding:24px;">'
                    "Press Generate to watch the diffusion process unfold…</div>"
                ),
            )
            generated_output = gr.Textbox(
                label="Generated text (plain)",
                lines=6,
                interactive=False,
            )
            info_md = gr.Markdown(label="Generation info")

    # Wire inputs list
    inputs = [
        prompt,
        max_new_tokens,
        temperature,
        top_p,
        repetition_penalty,
        tokens_per_step,
        use_entropy_sampling,
        seed,
        steering_preset,
        custom_steer_known,
    ]
    outputs = [viz_html, generated_output, info_md]

    generate_btn.click(
        fn=generate_streaming,
        inputs=inputs,
        outputs=outputs,
    )
    prompt.submit(
        fn=generate_streaming,
        inputs=inputs,
        outputs=outputs,
    )

    gr.Markdown(ARTICLE)

if __name__ == "__main__":
    demo.launch()